diff --git a/docs/INSTALL.md b/docs/INSTALL.md index 365221b8e..1fa44e1ae 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -154,3 +154,23 @@ on your machine. Please refer to the [official page](http://docs.openmm.org/latest/userguide/) of the project for a full description of the installation procedure. + +## `voroscoring` + +The use of the `[voroscoring]` module requires: +- A cluster with SLURM installed +- The setup of a conda environement (e.g.: ftdmp), in which you will install FTDMP +- A functional installation of [FTDMP](https://github.com/kliment-olechnovic) + +Once those three conditions are fulfilled, when using the `[voroscoring]` module in haddock3, the configuration file must be tuned to contain parameters describing how to load the appropriate conda env (ftdmp) and where to find FTDMP scripts and executables: + +```TOML +[voroscoring] +# This parameter defines the base directory where conda / miniconda is installed +conda_install_dir = "/absolute/path/to/conda/" +# This parameter defines the name of the conda env that you created and where FTDMP is installled +conda_env_name = "ftdmp" +# This parameter defines where FTDMP scripts / executables can be found +ftdmp_install_dir = "/absolute/path/to/FTDMP/" +``` + diff --git a/examples/scoring/voroscoring-test.cfg b/examples/scoring/voroscoring-test.cfg new file mode 100644 index 000000000..0f8f9d92f --- /dev/null +++ b/examples/scoring/voroscoring-test.cfg @@ -0,0 +1,33 @@ +# ==================================================================== +# Scoring example + +# directory in which the scoring will be done +run_dir = "run1-voroscoring-test" +clean = false + +# execution mode +ncores = 3 +mode = "local" + +# ensemble of different complexes to be scored +molecules = ["data/T161-rescoring-ens.pdb", + "data/HY3.pdb", + "data/protein-dna_1w.pdb", + "data/protein-protein_1w.pdb", + "data/protein-protein_2w.pdb", + "data/protein-trimer_1w.pdb" + ] + +# ==================================================================== +# Parameters for each stage are defined below + +[topoaa] + +[voroscoring] + +[seletop] +select = 3 + +[caprieval] + +# ==================================================================== \ No newline at end of file diff --git a/src/haddock/modules/scoring/__init__.py b/src/haddock/modules/scoring/__init__.py index 9d4f3aa0e..c68b95855 100644 --- a/src/haddock/modules/scoring/__init__.py +++ b/src/haddock/modules/scoring/__init__.py @@ -1,7 +1,8 @@ """HADDOCK3 modules to score models.""" +from os import linesep import pandas as pd -from haddock.core.typing import FilePath, Path, Any +from haddock.core.typing import FilePath, Path, Any, Optional from haddock.modules.base_cns_module import BaseCNSModule from haddock.modules import BaseHaddockModule, PDBFile @@ -14,6 +15,7 @@ def output( output_fname: FilePath, sep: str = "\t", ascending_sort: bool = True, + header_comments: Optional[str] = None, ) -> None: r"""Save the output in comprehensive tables. @@ -36,11 +38,23 @@ def output( df_sc = pd.DataFrame(sc_data, columns=df_columns) df_sc_sorted = df_sc.sort_values(by="score", ascending=ascending_sort) # writes to disk - df_sc_sorted.to_csv(output_fname, - sep=sep, - index=False, - na_rep="None", - float_format="%.3f") + output_file = open(output_fname, 'a') + # Check if some comment in header are here + if header_comments: + # Make sure the comments is ending by a new line + if header_comments[-1] != linesep: + header_comments += linesep + output_file.write(header_comments) + # Write the dataframe + df_sc_sorted.to_csv( + output_file, + sep=sep, + index=False, + na_rep="None", + float_format="%.3f", + lineterminator=linesep, + ) + return diff --git a/src/haddock/modules/scoring/emscoring/__init__.py b/src/haddock/modules/scoring/emscoring/__init__.py index a8cdfbf23..dfbf07b63 100644 --- a/src/haddock/modules/scoring/emscoring/__init__.py +++ b/src/haddock/modules/scoring/emscoring/__init__.py @@ -1,7 +1,8 @@ """EM scoring module. -This module performs energy minimization and scoring of the models generated in -the previous step of the workflow. No restraints are applied during this step. +This module performs energy minimization and scoring of the models generated +in the previous step of the workflow. +Note that no restraints (AIRs) are applied during this step. """ from pathlib import Path diff --git a/src/haddock/modules/scoring/mdscoring/__init__.py b/src/haddock/modules/scoring/mdscoring/__init__.py index a0d45a7e8..0d656edd4 100644 --- a/src/haddock/modules/scoring/mdscoring/__init__.py +++ b/src/haddock/modules/scoring/mdscoring/__init__.py @@ -1,7 +1,8 @@ """MD scoring module. This module will perform a short MD simulation on the input models and -score them. No restraints are applied during this step. +score them. +Note that no restraints (AIRs) are applied during this step. """ from pathlib import Path diff --git a/src/haddock/modules/scoring/voroscoring/__init__.py b/src/haddock/modules/scoring/voroscoring/__init__.py new file mode 100644 index 000000000..e9533b6bd --- /dev/null +++ b/src/haddock/modules/scoring/voroscoring/__init__.py @@ -0,0 +1,96 @@ +"""Voro scoring module. + +This module performs scoring of input pdb models using ftdmp voro-mqa-all tool. +For more information, please check: https://github.com/kliment-olechnovic/ftdmp + +It is a third party module, and requires the appropriate set up and intallation +for it to run without issue. +""" + +from os import linesep +from pathlib import Path + +from haddock.core.defaults import MODULE_DEFAULT_YAML +from haddock.core.typing import Any, FilePath +from haddock.modules import get_engine +from haddock.modules.scoring import ScoringModule +from haddock.modules.scoring.voroscoring.voroscoring import ( + VoroMQA, + update_models_with_scores, + ) + +RECIPE_PATH = Path(__file__).resolve().parent +DEFAULT_CONFIG = Path(RECIPE_PATH, MODULE_DEFAULT_YAML) + + +class HaddockModule(ScoringModule): + """.""" + + name = RECIPE_PATH.name + + def __init__( + self, + order: int, + path: Path, + *ignore: Any, + init_params: FilePath = DEFAULT_CONFIG, + **everything: Any, + ) -> None: + """Initialize class.""" + super().__init__(order, path, init_params) + + @classmethod + def confirm_installation(cls) -> None: + """Confirm module is installed.""" + # FIXME ? Check : + # - if conda env is accessible + # - if ftdmp is accessible + return + + def _run(self) -> None: + """Execute module.""" + # Retrieve previous models + try: + models_to_score = self.previous_io.retrieve_models( + individualize=True + ) + except Exception as e: + self.finish_with_error(e) + + # Initiate VoroMQA object + output_fname = Path("voro_mqa_all.tsv") + voromqa = VoroMQA( + models_to_score, + './', + self.params, + output=output_fname, + ) + + # Launch machinery + jobs: list[VoroMQA] = [voromqa] + # Run Job(s) + self.log("Running Voro-mqa scoring") + Engine = get_engine(self.params['mode'], self.params) + engine = Engine(jobs) + engine.run() + self.log("Voro-mqa scoring finished!") + + # Update score of output models + try: + self.output_models = update_models_with_scores( + output_fname, + models_to_score, + metric=self.params["metric"], + ) + except ValueError as e: + self.finish_with_error(e) + + # Write output file + scoring_tsv_fpath = f"{RECIPE_PATH.name}.tsv" + self.output( + scoring_tsv_fpath, + header_comments=f"# Note that negative of the value are reported in the case of non-energetical predictions{linesep}", # noqa : E501 + ) + # Export to next module + self.export_io_models() + diff --git a/src/haddock/modules/scoring/voroscoring/defaults.yaml b/src/haddock/modules/scoring/voroscoring/defaults.yaml new file mode 100644 index 000000000..baabb8b62 --- /dev/null +++ b/src/haddock/modules/scoring/voroscoring/defaults.yaml @@ -0,0 +1,54 @@ +metric: + default: jury_score + type: string + choices: + - jury_score + - GNN_sum_score + - GNN_pcadscore + - voromqa_dark + - voromqa_light + - voromqa_energy + - gen_voromqa_energy + - clash_score + - area + minchars: 1 + maxchars: 50 + title: VoroMQA metric used to score. + short: VoroMQA metric used to score. + long: VoroMQA metric used to score. + group: analysis + explevel: easy + +conda_install_dir: + default: "/trinity/login/vreys/miniconda3/" + type: string + minchars: 1 + maxchars: 158 + title: Path to conda intall directory. + short: Absolute path to conda intall directory. + long: Absolute path to conda intall directory. + group: execution + explevel: easy + +conda_env_name: + default: "ftdmp5" + type: string + minchars: 1 + maxchars: 100 + title: Name of the ftdmp conda env. + short: Name of the ftdmp conda env. + long: Name of the ftdmp conda env. + group: execution + explevel: easy + +ftdmp_install_dir: + default: "/trinity/login/vreys/Venclovas/ftdmp/" + type: string + minchars: 1 + maxchars: 158 + title: Path to ftdmp intall directory. + short: Absolute path to ftdmp intall directory. + long: Absolute path to ftdmp intall directory. + group: execution + explevel: easy + diff --git a/src/haddock/modules/scoring/voroscoring/voroscoring.py b/src/haddock/modules/scoring/voroscoring/voroscoring.py new file mode 100644 index 000000000..931ee8f1a --- /dev/null +++ b/src/haddock/modules/scoring/voroscoring/voroscoring.py @@ -0,0 +1,371 @@ +"""Voro scoring class. + +This class holds all the machinery to perform scoring of input pdb models using +ftdmp voro-mqa-all tool. +For more information, please check: https://github.com/kliment-olechnovic/ftdmp + +It is a third party module, and requires the appropriate set up and intallation +for it to run without issue. +""" + +import os +import subprocess +import glob +import time + +from random import randint + +from haddock import log +from haddock.core.typing import Any, Generator, Path, Union +from haddock.libs.libio import working_directory +from haddock.libs.libontology import NaN, PDBFile + + +# Defines the SLURM job template +# Notes: Please feel free to modify the #SBATCH entries to fit your needs/setup +SLURM_HEADER_GPU = """#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=1 +#SBATCH --partition=gpu +#SBATCH --gres=gpu:1 +""" + +SLURM_HEADER_CPU = """#SBATCH -J hd3-voroscoring-cpu +#SBATCH --partition haddock +#SBATCH --nodes=1 +#SBATCH --tasks-per-node=1 +""" + +# Job template +VOROMQA_CFG_TEMPLATE = """#!/bin/bash +#SBATCH -J {JOBNAME} +{HEADER} + +# Where to do the work +WORKDIR="{WORKDIR}" + +# Name of the outputfile (.ssv for space separated values) +OUTPUT_FNAME="voro_scores.ssv" + +# Define Constants +CONDA_INSTALL_DIR="{CONDA_INSTALL_DIR}" +CONDA_ENV_NAME="{CONDA_ENV_NAME}" +FTDMP_INSTALL_DIR="{FTDMP_INSTALL_DIR}" +VOROMQA_SCRIPT="ftdmp-qa-all" + +# Define workflow variables +OUTPUT_FPATH="$WORKDIR/$OUTPUT_FNAME" +PDB_LIST_PATH="{PDB_LIST_PATH}" +OUT_MSG="Output file is here: $OUTPUT_FPATH" + +# 1. Setup enviroments +# Load the gnu13 module... +# NOTE: specific to haddock-team users... +# This is made to get good gcc compiler +module load gnu13 +# Activate conda env +source "$CONDA_INSTALL_DIR/bin/activate" +conda activate $CONDA_ENV_NAME +echo "conda env: $CONDA_PREFIX" + +# 2. Setup run directory +# Create working directory +mkdir -p $WORKDIR + +# 3. Run voro-mqa (model quality assessment) +# Go to ftdmp install directory +cd $FTDMP_INSTALL_DIR +echo "Directory: $PWD" +# run voro-mqa +echo "./$VOROMQA_SCRIPT --conda-path $CONDA_INSTALL_DIR --conda-env $CONDA_ENV_NAME --workdir '$WORKDIR' --rank-names 'protein_protein_voromqa_and_global_and_gnn_no_sr' < $PDB_LIST_PATH > $OUTPUT_FPATH" +./$VOROMQA_SCRIPT --conda-path $CONDA_INSTALL_DIR --conda-env $CONDA_ENV_NAME --workdir $WORKDIR --rank-names 'protein_protein_voromqa_and_global_and_gnn_no_sr' --output-redundancy-threshold 1.0 < $PDB_LIST_PATH > $OUTPUT_FPATH +# Let the magic happen.. + +# 4. Analyze results +# Print final ouput file +echo $OUT_MSG +""" # noqa : E501 + + +class VoroMQA(): + """The Haddock3 implementation of voro-mqa-all as a python class.""" + + def __init__( + self, + models: list[PDBFile], + workdir: Union[str, Path], + params: dict[str, Any], + output: Union[str, Path] = "voroscoring_voro.tsv", + ): + """Init of the VoroMQA class. + + Parameters + ---------- + models : list[PDBFile] + List of input PDB files to be scored. + workdir : Union[str, Path] + Where to do the process. + params : dict[str, Any] + Config file parameters + output : Path, optional + Name of the generated file, by default Path("voroscoring_voro.tsv") + """ + self.models = models + self.workdir = workdir + self.params = params + self.output = Path(output) + + def run(self): + """Process class logic.""" + # Obtain absolute paths + self.workdir = Path(self.workdir).resolve() + all_pdbs = [ + str(Path(mdl.path, mdl.file_name).resolve()) + for mdl in self.models + ] + # Loop over batches + for bi, batch in enumerate(self.batched(all_pdbs, size=300)): + # Run slurm + self.run_voro_batch(batch, batch_index=bi + 1) + # Recombine all batches output files + scores_fpath = self.recombine_batches() + log.info(f"Generated output file: {scores_fpath}") + + def run_voro_batch( + self, + pdb_filepaths: list[str], + batch_index: int = 1, + ) -> None: + """Preset and launch predictions on subset of pdb files. + + Parameters + ---------- + pdb_filepaths : list[str] + List of absolute path to the PDBs to score + batch_index : int, optional + Index of the batch, by default 1 + """ + # Create workdir + batch_workdir = Path(self.workdir, f"batch_{batch_index}") + batch_workdir.mkdir(parents=True) + + # Create list of pdb files + pdb_files_list_path = Path(batch_workdir, "pdbs.list") + pdb_files_list_path.write_text(os.linesep.join(pdb_filepaths)) + + # Format config file + batch_cfg = VOROMQA_CFG_TEMPLATE.format( + HEADER=SLURM_HEADER_CPU, + CONDA_INSTALL_DIR=self.params["conda_install_dir"], + CONDA_ENV_NAME=self.params["conda_env_name"], + FTDMP_INSTALL_DIR=self.params["ftdmp_install_dir"], + JOBNAME=f"hd3_voro_b{batch_index}", + WORKDIR=batch_workdir, + PDB_LIST_PATH=pdb_files_list_path, + ) + + # Write it + batch_cfg_fpath = Path(batch_workdir, "vorobatchcfg.job") + batch_cfg_fpath.write_text(batch_cfg) + + # Launch script + self._launch_computation(batch_workdir, batch_cfg_fpath) + #initdir = os.getcwd() + #os.chdir(batch_workdir) + #log.info(f"sbatch {batch_cfg_fpath}") + #subprocess.run(f"sbatch {batch_cfg_fpath}", shell=True) + #os.chdir(initdir) + + def _launch_computation(self, batch_workdir: str, batch_cfg_fpath: str) -> None: + """Execute a given script from working directory. + + Parameters + ---------- + batch_workdir : str + Path to working directory + batch_cfg_fpath : str + Script to execute + """ + exec_tool = "sbatch" if self.params["mode"] == "batch" else "bash" + cmd_ = f"{exec_tool} {batch_cfg_fpath}" + with working_directory(batch_workdir): + log.info(cmd_) + subprocess.run(cmd_, shell=True) + + def recombine_batches(self) -> str: + """Recombine batches output file in a single one. + + Returns + ------- + finale_output_fpath : str + Filepath of the recombined scores + """ + # Wait for all results to be obtained + batches_result_paths = self.wait_for_termination() + # Loop over them + all_predictions: list[dict[str, str]] = [] + combined_header: list[str] = [] + for batch_results in batches_result_paths: + # Read voro results + with open(batch_results, 'r') as filin: + header = filin.readline().strip().split(' ') + for head in header: + if head not in combined_header: + combined_header.append(head) + for line in filin: + s_ = line.strip().split(' ') + all_predictions.append({ + head: s_[header.index(head)] + for head in header + }) + + # Sort all batches entries + sorted_entries = sorted( + all_predictions, + key=lambda k: float(k[self.params["metric"]]), + reverse="_energy" not in self.params["metric"], + ) + + # Write final output file + finale_output_fpath = f"{self.workdir}/{self.output}" + with open(finale_output_fpath, "w") as filout: + file_header = '\t'.join(combined_header) + filout.write(file_header + os.linesep) + for entry in sorted_entries: + ordered_data = [ + entry[h] if h in entry.keys() else '-' + for h in combined_header + ] + line = '\t'.join(ordered_data) + filout.write(line + os.linesep) + return finale_output_fpath + + def wait_for_termination(self, wait_time: float = 60) -> list[Path]: + """Wait until all results are accessible. + + Parameters + ---------- + wait_time : int, optional + Time in second between every termination checks, by default 60 + + Returns + ------- + output_files : list[Path] + List of voro scores results for every batches. + """ + batches_dirpath = glob.glob(f"{self.workdir}/batch_*/") + log.info( + f"Waiting for {len(batches_dirpath)} " + "voro-mqa prediction batch(es) to finish..." + ) + while True: + try: + output_files: list[Path] = [] + for batch_dir in batches_dirpath: + expected_outputfile = Path(batch_dir, "voro_scores.ssv") + assert expected_outputfile.exists() + assert expected_outputfile.stat().st_size != 0 + output_files.append(expected_outputfile) + except AssertionError: + log.info(f"Waiting {wait_time} sec more...") + time.sleep(wait_time) + else: + log.info( + "VoroMQA results are accessible: " + f"{len(output_files)} batch(es)" + ) + return output_files + + @staticmethod + def batched( + entries: list[str], + size: int = 300, + ) -> Generator[list[str], None, None]: + """Generate batches of defined size. + + Parameters + ---------- + entries : list[str] + List of pdb files. + size : int, optional + Maximum size in every batch, by default 300 + + Yields + ------ + batch : Generator[list[str], None, None] + List of pdb files <= size. + """ + batch = [] + for pdb in entries: + batch.append(pdb) + if len(batch) == size: + yield batch + batch = [] + if batch: + yield batch + + +def update_models_with_scores( + voro_scoring_fname: Union[str, Path], + models: list[PDBFile], + metric: str = "jury_score", + ) -> list[PDBFile]: + """Update PDBfiles with computed scores. + + Parameters + ---------- + output_fname : Union[str, Path] + Path to the file where to access scoring data. + models : list[PDBFile] + List of PDBFiles to be updated. + metric : str, optional + Name of the metric to be retrieved, by default "jury_score" + + Returns + ------- + models : list[PDBFile] + The updated list of PDBfiles now holding the score and rank attributes. + """ + scores_mapper: dict[str, float] = {} + ranking_mapper: dict[str, int] = {} + rank: int = 0 + # Read output file + with open(voro_scoring_fname, 'r') as filin: + for i, line in enumerate(filin): + s_ = line.strip().split('\t') + # Extract header + if i == 0: + header = s_ + continue + # Extract data + modelpath = str(s_[header.index("ID")]) + score = float(s_[header.index(metric)]) + # Only extract model filename + model_filename = modelpath.split('/')[-1] + # Reverse score if not an energy + if "_energy" not in metric: + score = -score + # Hold score + scores_mapper[model_filename] = score + rank += 1 + ranking_mapper[model_filename] = rank + + # Compute rankings + #ranking_mapper = { + # model_filename: rank + # for rank, model_filename in enumerate(sorted(scores_mapper), start=1) + # } + + # Loop over input models + for model in models: + # Add score and rank as attribute + if model.file_name in scores_mapper.keys(): + model.score = scores_mapper[model.file_name] + model.rank = ranking_mapper[model.file_name] + # In some cases computation may fail + else: + # Go for (garlic cheese) naans + model.score = NaN + model.rank = NaN + model.ori_name = model.file_name + return models diff --git a/tests/test_module_voroscoring.py b/tests/test_module_voroscoring.py new file mode 100644 index 000000000..40088c36a --- /dev/null +++ b/tests/test_module_voroscoring.py @@ -0,0 +1,163 @@ +"""Test the voroscoring module.""" +import os +import pytest +import pytest_mock # noqa : F401 +import tempfile +import subprocess +import shutil + +from numpy import isnan +from pathlib import Path + +from haddock.libs.libontology import NaN, PDBFile +from haddock.modules.scoring.voroscoring import ( + DEFAULT_CONFIG as params, + HaddockModule as VoroScoringModule, + ) +from haddock.modules.scoring.voroscoring.voroscoring import ( + VoroMQA, + update_models_with_scores, + ) + +from . import golden_data + + +@pytest.fixture +def output_models(): + """Prot-DNA models using for emscoring output.""" + return [ + PDBFile( + Path(golden_data, "protdna_complex_1.pdb"), + path=golden_data, + score=-0.28, + ), + PDBFile( + Path(golden_data, "protdna_complex_2.pdb"), + path=golden_data, + score=-0.42, + ), + PDBFile( + Path(golden_data, "protdna_complex_3.pdb"), + path=golden_data, + score=NaN, + ), + ] + + +@pytest.fixture +def voromqa(output_models): + with tempfile.TemporaryDirectory(dir=".") as tmpdir: + voromqa_object = VoroMQA( + output_models, + tmpdir, + params, + Path("raw_voromqa_scores.tsv"), + ) + yield voromqa_object + + +def test_voroscoring_output(output_models): + """Test voroscoring expected output.""" + voro_module = VoroScoringModule( + order=1, + path=Path("1_voroscoring"), + initial_params=params + ) + # original names + voro_module.output_models = output_models + for mod in range(len(output_models)): + ori_name = "original_name_" + str(mod) + ".pdb" + voro_module.output_models[mod].ori_name = ori_name + # creating output + output_fname = Path("voroscoring.tsv") + voro_module.output(output_fname) + observed_outf_l = [ + e.split() + for e in open(output_fname).readlines() + if not e.startswith('#') + ] + # expected output + expected_outf_l = [ + ["structure", "original_name", "md5", "score"], + ["protdna_complex_2.pdb", "original_name_1.pdb", "None", "-0.420"], + ["protdna_complex_1.pdb", "original_name_0.pdb", "None", "-0.280"], + ["protdna_complex_3.pdb", "original_name_2.pdb", "None", "None"], + ] + + assert observed_outf_l == expected_outf_l + output_fname.unlink() + + +def test_wait_for_termination(voromqa): + """Test waiting for results function behavior in voromqa.""" + nested_batch_dir = Path(voromqa.workdir, "batch_1") + os.mkdir(nested_batch_dir) + expected_ssv = Path(nested_batch_dir, "voro_scores.ssv") + # Trick to fake the generation of a file + delay_scriptpath = Path(nested_batch_dir, "delay.sh") + delay_scriptpath.write_text( + "\n".join(["sleep 0.1", f'echo "haddock3" > {expected_ssv}']) + ) + assert delay_scriptpath.exists() + os.system(f"chmod u+x {delay_scriptpath}") + os.system(f"./{delay_scriptpath} &") + assert not expected_ssv.exists() + # The actual test of the function + batches_ssv = voromqa.wait_for_termination(wait_time=0.1) + assert expected_ssv.exists() + assert batches_ssv[0] == expected_ssv + shutil.rmtree(nested_batch_dir) + + +def test_batched(voromqa): + """Test batched function behavior in voromqa.""" + for batch in voromqa.batched(list(range(10)), size=2): + assert len(batch) == 2 + batches = list(voromqa.batched(list(range(100)), size=99)) + assert len(batches[0]) == 99 + assert len(batches[1]) == 1 + + +def test_update_models_with_scores(output_models): + """Test to update PDBFiles with scores from voromqa tsv.""" + # Generate fake voro output file + output_fname = Path("fake_voro.tsv") + output_fname.write_text( + """ID\tjury_score\tfake_energy +protdna_complex_2.pdb\t0.5256\t-2 +protdna_complex_1.pdb\t0.1234\t-1 +""" + ) + updated_models = update_models_with_scores( + output_fname, + output_models, + metric="jury_score", + ) + assert updated_models[0].score == -0.1234 + assert updated_models[0].rank == 2 + assert updated_models[1].score == -0.5256 + assert updated_models[1].rank == 1 + assert isnan(updated_models[2].score) + assert isnan(updated_models[2].rank) + + updated_models = update_models_with_scores( + output_fname, + output_models, + metric="fake_energy", + ) + assert updated_models[0].score == -1 + assert updated_models[0].rank == 2 + assert updated_models[1].score == -2 + assert updated_models[1].rank == 1 + assert isnan(updated_models[2].score) + assert isnan(updated_models[2].rank) + + # Test error raising + with pytest.raises(ValueError): + updated_models2 = update_models_with_scores( + output_fname, + output_models, + metric="wrong", + ) + assert updated_models2 is None + output_fname.unlink()