Skip to content

Commit

Permalink
Introduce --force-use-checkpoint for redoing v0.3.1 runs
Browse files: browse the repository at this point in the history
  • Loading branch information
sjfleming committed Apr 10, 2024
1 parent 411cc39 commit 1fd390f
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 17 deletions.
11 changes: 10 additions & 1 deletion cellbender/remove_background/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,19 @@ def add_subparser_args(subparsers: argparse) -> argparse:
subparser.add_argument("--checkpoint", nargs=None, type=str,
dest='input_checkpoint_tarball',
required=False, default=consts.CHECKPOINT_FILE_NAME,
help="Checkpoint tarball produced by the same version "
help="Checkpoint tarball produced by v0.3.0+ "
"of CellBender remove-background. If present, "
"and the workflow hashes match, training will "
"restart from this checkpoint.")
subparser.add_argument("--force-use-checkpoint",
dest='force_use_checkpoint', action="store_true",
help="Normally, checkpoints can only be used if the CellBender "
"code and certain input args match exactly. This flag allows you "
"to bypass this requirement. An example use would be to create a new output "
"using a checkpoint from a run of v0.3.1, a redacted version with "
"faulty output count matrices. If you use this flag, "
"ensure that the input file and the checkpoint match, because "
"CellBender will not check.")
subparser.add_argument("--expected-cells", nargs=None, type=int,
default=None,
dest="expected_cell_count",
Expand Down
36 changes: 23 additions & 13 deletions cellbender/remove_background/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def save_checkpoint(filebase: str,

def load_checkpoint(filebase: Optional[str],
tarball_name: str = consts.CHECKPOINT_FILE_NAME,
force_device: Optional[str] = None)\
force_device: Optional[str] = None,
force_use_checkpoint: bool = False)\
-> Dict[str, Union['RemoveBackgroundPyroModel', pyro.optim.PyroOptim, DataLoader, bool]]:
"""Load checkpoint and prepare a RemoveBackgroundPyroModel and optimizer."""

Expand All @@ -163,6 +164,7 @@ def load_checkpoint(filebase: Optional[str],
tarball_name=tarball_name,
to_load=['model', 'optim', 'param_store', 'dataloader', 'args', 'random_state'],
force_device=force_device,
force_use_checkpoint=force_use_checkpoint,
)
out.update({'loaded': True})
logger.info(f'Loaded partially-trained checkpoint from {tarball_name}')
Expand All @@ -172,7 +174,8 @@ def load_checkpoint(filebase: Optional[str],
def load_from_checkpoint(filebase: Optional[str],
tarball_name: str = consts.CHECKPOINT_FILE_NAME,
to_load: List[str] = ['model'],
force_device: Optional[str] = None) -> Dict:
force_device: Optional[str] = None,
force_use_checkpoint: bool = False) -> Dict:
"""Load specific files from a checkpoint tarball."""

load_kwargs = {}
Expand All @@ -192,19 +195,24 @@ def load_from_checkpoint(filebase: Optional[str],
else:
# no tarball loaded, so do not continue trying to load files
raise FileNotFoundError

# See if files have a hash matching input filebase.
if filebase is not None:

# If posterior is present, do not require run hash to match: will pick up
# after training and run estimation from existing posterior.
# This smoothly allows re-runs (including for problematic v0.3.1)
logger.debug(f'force_use_checkpoint: {force_use_checkpoint}')
if force_use_checkpoint or (filebase is None):
filebase = (glob.glob(os.path.join(tmp_dir, '*_model.torch'))[0]
.replace('_model.torch', ''))
logger.debug(f'Accepting any file hash, so loading {filebase}*')

else:
# See if files have a hash matching input filebase.
basename = os.path.basename(filebase)
filebase = os.path.join(tmp_dir, basename)
logger.debug(f'Looking for files with base name matching {filebase}*')
if not os.path.exists(filebase + '_model.torch'):
logger.info('Workflow hash does not match that of checkpoint.')
raise ValueError('Workflow hash does not match that of checkpoint.')
else:
filebase = (glob.glob(os.path.join(tmp_dir, '*_model.torch'))[0]
.replace('_model.torch', ''))
logger.debug(f'Accepting any file hash, so loading {filebase}*')
raise ValueError('Workflow hash does not match that of checkpoint.')

out = {}

Expand Down Expand Up @@ -265,9 +273,10 @@ def load_from_checkpoint(filebase: Optional[str],
return out


def attempt_load_checkpoint(filebase: str,
def attempt_load_checkpoint(filebase: Optional[str],
tarball_name: str = consts.CHECKPOINT_FILE_NAME,
force_device: Optional[str] = None)\
force_device: Optional[str] = None,
force_use_checkpoint: bool = False)\
-> Dict[str, Union['RemoveBackgroundPyroModel', pyro.optim.PyroOptim, DataLoader, bool]]:
"""Load checkpoint and prepare a RemoveBackgroundPyroModel and optimizer,
or return the inputs if loading fails."""
Expand All @@ -276,7 +285,8 @@ def attempt_load_checkpoint(filebase: str,
logger.debug('Attempting to load checkpoint from ' + tarball_name)
return load_checkpoint(filebase=filebase,
tarball_name=tarball_name,
force_device=force_device)
force_device=force_device,
force_use_checkpoint=force_use_checkpoint)

except FileNotFoundError:
logger.debug('No tarball found')
Expand Down
6 changes: 4 additions & 2 deletions cellbender/remove_background/posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,15 @@ def _do_posterior_regularization(posterior: Posterior):
try:
ckpt_posterior = load_from_checkpoint(tarball_name=args.input_checkpoint_tarball,
filebase=args.checkpoint_filename,
to_load=['posterior'])
to_load=['posterior'],
force_use_checkpoint=args.force_use_checkpoint)
except ValueError:
# input checkpoint tarball was not a match for this workflow
# but we still may have saved a new tarball
ckpt_posterior = load_from_checkpoint(tarball_name=consts.CHECKPOINT_FILE_NAME,
filebase=args.checkpoint_filename,
to_load=['posterior'])
to_load=['posterior'],
force_use_checkpoint=args.force_use_checkpoint)
if os.path.exists(ckpt_posterior.get('posterior_file', 'does_not_exist')):
# Load posterior if it was saved in the checkpoint.
posterior.load(file=ckpt_posterior['posterior_file'])
Expand Down
3 changes: 2 additions & 1 deletion cellbender/remove_background/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,8 @@ def run_inference(dataset_obj: SingleCellRNACountsDataset,
# Attempt to load from a previously-saved checkpoint.
ckpt = attempt_load_checkpoint(filebase=checkpoint_filename,
tarball_name=args.input_checkpoint_tarball,
force_device='cuda:0' if args.use_cuda else 'cpu')
force_device='cuda:0' if args.use_cuda else 'cpu',
force_use_checkpoint=args.force_use_checkpoint)
ckpt_loaded = ckpt['loaded'] # True if a checkpoint was loaded successfully

if ckpt_loaded:
Expand Down

0 comments on commit 1fd390f

Please sign in to comment.