class HistGBCLearner(AbstractLearner):
    """
    Implements a learner backed by scikit-learn's HistGradientBoostingClassifier.

    Attributes:
        classifier: The fitted model; ``None`` until ``learn`` has run.
        importance: Initialised to ``None``; ``learn`` does not populate it.
        autotune: Value stored from the constructor; not read by the
            methods defined in this class.
        threads: Value stored from the constructor; not read by the
            methods defined in this class.
    """

    def __init__(self, autotune=False, threads=1):
        """Store the configuration flags and start without a fitted model."""
        self.autotune = autotune
        self.threads = threads
        self.classifier = None
        self.importance = None

    def tune(
        self, decoy_peaks, target_peaks, use_main_score=True, cv_splits=3, n_jobs=-1
    ):
        """
        Placeholder for hyperparameter tuning.

        Raises:
            NotImplementedError: Always; no tuning is available for this learner.
        """
        raise NotImplementedError(
            "Hyperparameter tuning for HistGradientBoostingClassifier is not implemented."
        )

    def learn(self, decoy_peaks, target_peaks, use_main_score=True):
        """
        Fit the classifier to separate decoy peaks from target peaks.

        Args:
            decoy_peaks: ``Experiment`` whose rows are labelled 0.
            target_peaks: ``Experiment`` whose rows are labelled 1.
            use_main_score: Forwarded unchanged to ``get_feature_matrix``.

        Returns:
            This learner, with the fitted model stored in ``self.classifier``.
        """
        for peaks in (decoy_peaks, target_peaks):
            assert isinstance(peaks, Experiment)

        decoy_features = decoy_peaks.get_feature_matrix(use_main_score)
        target_features = target_peaks.get_feature_matrix(use_main_score)
        features = np.vstack((decoy_features, target_features))

        # Decoy rows come first in the stacked matrix, so every row index at
        # or beyond the decoy count belongs to a target and gets label 1.0.
        n_decoys = decoy_features.shape[0]
        labels = (np.arange(features.shape[0]) >= n_decoys).astype(float)

        settings = {
            "random_state": 42,
            "max_iter": 100,
            "early_stopping": True,
            "validation_fraction": 0.1,
        }
        model = HistGradientBoostingClassifier(**settings)
        model.fit(features, labels)

        # NOTE(review): self.importance is left as None here. The fitted
        # estimator does not appear to expose `feature_importances_` -- confirm
        # before wiring up feature importances for this learner.
        self.classifier = model

        return self