diff --git a/cellbender/remove_background/argparser.py b/cellbender/remove_background/argparser.py index d53aca3..aebe670 100644 --- a/cellbender/remove_background/argparser.py +++ b/cellbender/remove_background/argparser.py @@ -48,10 +48,19 @@ def add_subparser_args(subparsers: argparse) -> argparse: subparser.add_argument("--checkpoint", nargs=None, type=str, dest='input_checkpoint_tarball', required=False, default=consts.CHECKPOINT_FILE_NAME, - help="Checkpoint tarball produced by the same version " + help="Checkpoint tarball produced by v0.3.0+ " "of CellBender remove-background. If present, " "and the workflow hashes match, training will " "restart from this checkpoint.") + subparser.add_argument("--force-use-checkpoint", + dest='force_use_checkpoint', action="store_true", + help="Normally, checkpoints can only be used if the CellBender " + "code and certain input args match exactly. This flag allows you " + "to bypass this requirement. An example use would be to create a new output " + "using a checkpoint from a run of v0.3.1, a redacted version with " + "faulty output count matrices. If you use this flag, " + "ensure that the input file and the checkpoint match, because " + "CellBender will not check.") subparser.add_argument("--expected-cells", nargs=None, type=int, default=None, dest="expected_cell_count", diff --git a/cellbender/remove_background/checkpoint.py b/cellbender/remove_background/checkpoint.py index 206737b..44ed12e 100644 --- a/cellbender/remove_background/checkpoint.py +++ b/cellbender/remove_background/checkpoint.py @@ -154,7 +154,8 @@ def save_checkpoint(filebase: str, def load_checkpoint(filebase: Optional[str], tarball_name: str = consts.CHECKPOINT_FILE_NAME, - force_device: Optional[str] = None)\ + force_device: Optional[str] = None, + force_use_checkpoint: bool = False)\ -> Dict[str, Union['RemoveBackgroundPyroModel', pyro.optim.PyroOptim, DataLoader, bool]]: """Load checkpoint and prepare a RemoveBackgroundPyroModel and optimizer.""" @@ -163,6 +164,7 @@ def load_checkpoint(filebase: Optional[str], tarball_name=tarball_name, to_load=['model', 'optim', 'param_store', 'dataloader', 'args', 'random_state'], force_device=force_device, + force_use_checkpoint=force_use_checkpoint, ) out.update({'loaded': True}) logger.info(f'Loaded partially-trained checkpoint from {tarball_name}') @@ -172,7 +174,8 @@ def load_checkpoint(filebase: Optional[str], def load_from_checkpoint(filebase: Optional[str], tarball_name: str = consts.CHECKPOINT_FILE_NAME, to_load: List[str] = ['model'], - force_device: Optional[str] = None) -> Dict: + force_device: Optional[str] = None, + force_use_checkpoint: bool = False) -> Dict: """Load specific files from a checkpoint tarball.""" load_kwargs = {} @@ -192,19 +195,24 @@ def load_from_checkpoint(filebase: Optional[str], else: # no tarball loaded, so do not continue trying to load files raise FileNotFoundError - - # See if files have a hash matching input filebase. - if filebase is not None: + + # If posterior is present, do not require run hash to match: will pick up + # after training and run estimation from existing posterior. + # This smoothly allows re-runs (including for problematic v0.3.1) + logger.debug(f'force_use_checkpoint: {force_use_checkpoint}') + if force_use_checkpoint or (filebase is None): + filebase = (glob.glob(os.path.join(tmp_dir, '*_model.torch'))[0] + .replace('_model.torch', '')) + logger.debug(f'Accepting any file hash, so loading {filebase}*') + + else: + # See if files have a hash matching input filebase. basename = os.path.basename(filebase) filebase = os.path.join(tmp_dir, basename) logger.debug(f'Looking for files with base name matching {filebase}*') if not os.path.exists(filebase + '_model.torch'): logger.info('Workflow hash does not match that of checkpoint.') - raise ValueError('Workflow hash does not match that of checkpoint.') - else: - filebase = (glob.glob(os.path.join(tmp_dir, '*_model.torch'))[0] - .replace('_model.torch', '')) - logger.debug(f'Accepting any file hash, so loading {filebase}*') + raise ValueError('Workflow hash does not match that of checkpoint.') out = {} @@ -265,9 +273,10 @@ def load_from_checkpoint(filebase: Optional[str], return out -def attempt_load_checkpoint(filebase: str, +def attempt_load_checkpoint(filebase: Optional[str], tarball_name: str = consts.CHECKPOINT_FILE_NAME, - force_device: Optional[str] = None)\ + force_device: Optional[str] = None, + force_use_checkpoint: bool = False)\ -> Dict[str, Union['RemoveBackgroundPyroModel', pyro.optim.PyroOptim, DataLoader, bool]]: """Load checkpoint and prepare a RemoveBackgroundPyroModel and optimizer, or return the inputs if loading fails.""" @@ -276,7 +285,8 @@ def attempt_load_checkpoint(filebase: str, logger.debug('Attempting to load checkpoint from ' + tarball_name) return load_checkpoint(filebase=filebase, tarball_name=tarball_name, - force_device=force_device) + force_device=force_device, + force_use_checkpoint=force_use_checkpoint) except FileNotFoundError: logger.debug('No tarball found') diff --git a/cellbender/remove_background/posterior.py b/cellbender/remove_background/posterior.py index d99e533..5bdd890 100644 --- a/cellbender/remove_background/posterior.py +++ b/cellbender/remove_background/posterior.py @@ -107,13 +107,15 @@ def _do_posterior_regularization(posterior: Posterior): try: ckpt_posterior = load_from_checkpoint(tarball_name=args.input_checkpoint_tarball, filebase=args.checkpoint_filename, - to_load=['posterior']) + to_load=['posterior'], + force_use_checkpoint=args.force_use_checkpoint) except ValueError: # input checkpoint tarball was not a match for this workflow # but we still may have saved a new tarball ckpt_posterior = load_from_checkpoint(tarball_name=consts.CHECKPOINT_FILE_NAME, filebase=args.checkpoint_filename, - to_load=['posterior']) + to_load=['posterior'], + force_use_checkpoint=args.force_use_checkpoint) if os.path.exists(ckpt_posterior.get('posterior_file', 'does_not_exist')): # Load posterior if it was saved in the checkpoint. posterior.load(file=ckpt_posterior['posterior_file']) diff --git a/cellbender/remove_background/run.py b/cellbender/remove_background/run.py index 6f646a0..7ac466b 100644 --- a/cellbender/remove_background/run.py +++ b/cellbender/remove_background/run.py @@ -630,7 +630,8 @@ def run_inference(dataset_obj: SingleCellRNACountsDataset, # Attempt to load from a previously-saved checkpoint. ckpt = attempt_load_checkpoint(filebase=checkpoint_filename, tarball_name=args.input_checkpoint_tarball, - force_device='cuda:0' if args.use_cuda else 'cpu') + force_device='cuda:0' if args.use_cuda else 'cpu', + force_use_checkpoint=args.force_use_checkpoint) ckpt_loaded = ckpt['loaded'] # True if a checkpoint was loaded successfully if ckpt_loaded: