diff --git a/pyprophet/pyprophet.py b/pyprophet/pyprophet.py index 54fe39e3..130989c0 100644 --- a/pyprophet/pyprophet.py +++ b/pyprophet/pyprophet.py @@ -347,7 +347,13 @@ def _build_result(self, table, final_classifier, score_columns, experiment): result = Result(summary_statistics, final_statistics, scored_table) + # Set feature names in XGBoost classifier + if self.classifier == "XGBoost": + classifier_table.feature_names = list(score_columns) + classifier_table.save_model('xgboostModel.json') + click.echo("Info: Finished scoring and estimation statistics.") + return result, scorer, classifier_table diff --git a/pyprophet/runner.py b/pyprophet/runner.py index 78003a89..308b631d 100644 --- a/pyprophet/runner.py +++ b/pyprophet/runner.py @@ -13,6 +13,8 @@ from .report import save_report from .data_handling import is_sqlite_file, check_sqlite_table from shutil import copyfile +import json +import xgboost as xgb try: profile @@ -403,17 +405,30 @@ def save_osw_weights(self, weights): elif self.classifier == "XGBoost": con = sqlite3.connect(self.outfile) + #### to ensure a consistent model save model as .json and then load .json and attach to sqlite + json_file_path = "tmp.json" + weights.save_model(json_file_path) + + json_object = '' + with open(json_file_path, 'r') as json_file: + json_object = json.load(json_file) + c = con.cursor() c.execute('SELECT count(name) FROM sqlite_master WHERE type="table" AND name="PYPROPHET_XGB";') if c.fetchone()[0] == 1: c.execute('DELETE FROM PYPROPHET_XGB WHERE LEVEL =="%s"' % self.level) else: - c.execute('CREATE TABLE PYPROPHET_XGB (level TEXT, xgb BLOB)') + c.execute('CREATE TABLE PYPROPHET_XGB (level TEXT, xgb TEXT)') + + c.execute('INSERT INTO PYPROPHET_XGB VALUES(?, ?)', [self.level, json.dumps(json_object)]) - c.execute('INSERT INTO PYPROPHET_XGB VALUES(?, ?)', [self.level, pickle.dumps(weights)]) con.commit() c.close() + ## delete the temporary json object + if os.path.exists(json_file_path): + os.remove(json_file_path) + 
def save_bin_weights(self, weights, extra_writes): trained_weights_path = extra_writes.get("trained_model_path_" + self.level) if trained_weights_path is not None: @@ -481,9 +496,22 @@ def __init__(self, infile, outfile, classifier, xgb_hyperparams, xgb_params, xgb if not check_sqlite_table(con, "PYPROPHET_XGB"): raise click.ClickException("PYPROPHET_XGB table is not present in file, cannot apply weights for XGBoost classifier! Make sure you have run the scoring on a subset of the data first, or that you supplied the right `--classifier` parameter.") - data = con.execute("SELECT xgb FROM PYPROPHET_XGB WHERE LEVEL=='%s'" % self.level).fetchone() + data = json.loads(con.execute("SELECT xgb FROM PYPROPHET_XGB WHERE LEVEL=='%s'" % self.level).fetchone()[0]) con.close() - self.persisted_weights = pickle.loads(data[0]) + + ## dump json to temporary file so it can be loaded properly with xgboost + json_file_path='tmp.json' + with open(json_file_path, 'w') as outfile: + json.dump(data, outfile) + + + self.persisted_weights = xgb.Booster() + self.persisted_weights.load_model(json_file_path) + + ## remove feature names for compatibility with other methods TODO: allow for apply weights to work with feature names + self.persisted_weights.feature_names = None + + except Exception: import traceback traceback.print_exc()