diff --git a/README.md b/README.md
index 93ccbe3f9..983369c84 100644
--- a/README.md
+++ b/README.md
@@ -326,10 +326,12 @@ subject-relations pairs: ('Dominican Republic', 'has form of government', ?) and
 ```python
 import torch
-import kge.model
+from kge.model import KgeModel
+from kge.util.io import load_checkpoint
 
 # download link for this checkpoint given under results above
-model = kge.model.KgeModel.load_from_checkpoint('fb15k-237-rescal.pt')
+checkpoint = load_checkpoint('fb15k-237-rescal.pt')
+model = KgeModel.create_from(checkpoint)
 
 s = torch.Tensor([0, 2,]).long()    # subject indexes
 p = torch.Tensor([0, 1,]).long()    # relation indexes
diff --git a/kge/cli.py b/kge/cli.py
index 6b10b8e9e..476b26cfd 100755
--- a/kge/cli.py
+++ b/kge/cli.py
@@ -10,6 +10,8 @@ from kge.job import Job
 from kge.misc import get_git_revision_short_hash, kge_base_dir, is_number
 from kge.util.dump import add_dump_parsers, dump
+from kge.util.io import get_checkpoint_file, load_checkpoint
+from kge.util.package import package_model, add_package_parser
 
 
 def argparse_bool_type(v):
@@ -130,6 +132,7 @@ def create_parser(config, additional_args=[]):
         default="default",
     )
     add_dump_parsers(subparsers)
+    add_package_parser(subparsers)
 
     return parser
 
@@ -163,6 +166,11 @@ def main():
         dump(args)
         exit()
 
+    # package command
+    if args.command == "package":
+        package_model(args)
+        exit()
+
     # start command
     if args.command == "start":
         # use toy config file if no config given
@@ -231,16 +239,7 @@ def main():
 
     # determine checkpoint to resume (if any)
     if hasattr(args, "checkpoint"):
-        if args.checkpoint == "default":
-            if config.get("job.type") in ["eval", "valid"]:
-                checkpoint_file = config.checkpoint_file("best")
-            else:
-                checkpoint_file = None  # means last
-        elif is_number(args.checkpoint, int) or args.checkpoint == "best":
-            checkpoint_file = config.checkpoint_file(args.checkpoint)
-        else:
-            # otherwise, treat it as a filename
-            checkpoint_file = args.checkpoint
+        checkpoint_file = get_checkpoint_file(config, args.checkpoint)
 
     # disable processing of outdated cached dataset files globally
     Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated
@@ -269,17 +268,30 @@ def main():
             config.log("Job created successfully.")
         else:
             # load data
-            dataset = Dataset.load(config)
+            dataset = Dataset.create(config)
 
             # let's go
-            job = Job.create(config, dataset)
             if args.command == "resume":
-                job.resume(checkpoint_file)
+                if checkpoint_file is not None:
+                    checkpoint = load_checkpoint(
+                        checkpoint_file, config.get("job.device")
+                    )
+                    job = Job.create_from(
+                        checkpoint, new_config=config, dataset=dataset
+                    )
+                else:
+                    job = Job.create(config, dataset)
+                    job.config.log(
+                        "No checkpoint found or specified, starting from scratch..."
+                    )
+            else:
+                job = Job.create(config, dataset)
             job.run()
     except BaseException as e:
         tb = traceback.format_exc()
        config.log(tb, echo=False)
         raise e from None
 
+
 if __name__ == "__main__":
     main()
diff --git a/kge/config-default.yaml b/kge/config-default.yaml
index bce72d8dd..13861d1e5 100644
--- a/kge/config-default.yaml
+++ b/kge/config-default.yaml
@@ -5,7 +5,7 @@ job:
   type: train
 
   # Main device to use for this job (e.g., 'cpu', 'cuda', 'cuda:0')
-  device: 'cuda'
+  device: 'cuda'
 
   # The seeds of the PRNGs can be set manually for (increased) reproducibility.
   # Use -1 to use default seed.
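With this change, loading a pretrained model is a two-step process: `load_checkpoint` reads the checkpoint into a dict (adding `file` and `folder` entries), and `KgeModel.create_from` rebuilds config, dataset, and model from that dict. A minimal usage sketch extending the README snippet above; the checkpoint filename is the README's example, and `score_sp` is KgeModel's existing scoring method:

```python
import torch
from kge.model import KgeModel
from kge.util.io import load_checkpoint

checkpoint = load_checkpoint('fb15k-237-rescal.pt')  # dict with model, config, dataset info
model = KgeModel.create_from(checkpoint)             # reconstructs config, dataset, and weights

s = torch.Tensor([0, 2]).long()      # subject indexes
p = torch.Tensor([0, 1]).long()      # relation indexes
scores = model.score_sp(s, p)        # scores of all objects for (s, p, ?)
print(torch.argmax(scores, dim=-1))  # index of the highest-scoring object per query
```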
diff --git a/kge/config.py b/kge/config.py
index 70f98568b..0bc6d88d6 100644
--- a/kge/config.py
+++ b/kge/config.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import collections
 import copy
 import datetime
@@ -130,6 +132,13 @@ def get_first(self, *keys: str, use_get_default=False) -> Any:
         else:
             return self.get(self.get_first_present_key(*keys))
 
+    def exists(self, key: str, remove_plusplusplus=True) -> bool:
+        try:
+            self.get(key, remove_plusplusplus)
+            return True
+        except KeyError:
+            return False
+
     Overwrite = Enum("Overwrite", "Yes No Error")
 
     def set(
@@ -298,12 +307,13 @@ def load(
         """
         with open(filename, "r") as file:
             new_options = yaml.load(file, Loader=yaml.SafeLoader)
-        self.load_options(
-            new_options,
-            create=create,
-            overwrite=overwrite,
-            allow_deprecated=allow_deprecated,
-        )
+        if new_options is not None:
+            self.load_options(
+                new_options,
+                create=create,
+                overwrite=overwrite,
+                allow_deprecated=allow_deprecated,
+            )
 
     def load_options(
         self, new_options, create=False, overwrite=Overwrite.Yes, allow_deprecated=True
@@ -331,11 +341,22 @@ def load_options(
         # now set all options
         self.set_all(new_options, create, overwrite)
 
+    def load_config(
+        self, config, create=False, overwrite=Overwrite.Yes, allow_deprecated=True
+    ):
+        "Like `load`, but loads from a Config object."
+        self.load_options(config.options, create, overwrite, allow_deprecated)
+
     def save(self, filename):
         """Save this configuration to the given file"""
         with open(filename, "w+") as file:
             file.write(yaml.dump(self.options))
 
+    def save_to(self, checkpoint: Dict) -> Dict:
+        """Adds this config to a checkpoint"""
+        checkpoint["config"] = self
+        return checkpoint
+
     @staticmethod
     def flatten(options: Dict[str, Any]) -> Dict[str, Any]:
         """Returns a dictionary of flattened configuration options."""
@@ -423,6 +444,26 @@ def init_folder(self):
             return True
         return False
 
+    @staticmethod
+    def create_from(checkpoint: Dict) -> Config:
+        """Create a config from a checkpoint."""
+        config = Config()  # round trip to handle deprecated configs
+        if "config" in checkpoint and checkpoint["config"] is not None:
+            config.load_config(checkpoint["config"].clone())
+        if "folder" in checkpoint and checkpoint["folder"] is not None:
+            config.folder = checkpoint["folder"]
+        return config
+
+    @staticmethod
+    def from_options(options: Dict[str, Any] = {}, **more_options) -> Config:
+        """Convert given options or kwargs to a Config object.
+
+        Does not perform any checks for correctness."""
+        config = Config(load_default=False)
+        config.set_all(options, create=True)
+        config.set_all(more_options, create=True)
+        return config
+
     def checkpoint_file(self, cpt_id: Union[str, int]) -> str:
         "Return path of checkpoint file for given checkpoint id"
         from kge.misc import is_number
@@ -432,8 +473,8 @@ def checkpoint_file(self, cpt_id: Union[str, int]) -> str:
         else:
             return os.path.join(self.folder, "checkpoint_{}.pt".format(cpt_id))
 
-    def last_checkpoint(self) -> Optional[int]:
-        "Return epoch number of latest checkpoint"
+    def last_checkpoint_number(self) -> Optional[int]:
+        "Return number (epoch) of latest checkpoint"
         # stupid implementation, but works
         tried_epoch = 0
         found_epoch = 0
@@ -447,13 +488,13 @@ def last_checkpoint(self) -> Optional[int]:
         return None
 
     @staticmethod
-    def get_best_or_last_checkpoint(path: str) -> str:
+    def best_or_last_checkpoint_file(path: str) -> str:
         """Return best (if present) or last checkpoint path for a given folder path."""
         config = Config(folder=path, load_default=False)
         checkpoint_file = config.checkpoint_file("best")
         if os.path.isfile(checkpoint_file):
             return checkpoint_file
-        cpt_epoch = config.last_checkpoint()
+        cpt_epoch = config.last_checkpoint_number()
         if cpt_epoch:
             return config.checkpoint_file(cpt_epoch)
         else:
diff --git a/kge/dataset.py b/kge/dataset.py
index 5bb28b03a..33cbefc24 100644
--- a/kge/dataset.py
+++ b/kge/dataset.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import csv
 import os
 import sys
@@ -31,7 +33,7 @@ class Dataset(Configurable):
     def __init__(self, config, folder=None):
         """Constructor for internal use.
 
-        To load a dataset, use `Dataset.load()`."""
+        To load a dataset, use `Dataset.create()`."""
         super().__init__(config, "dataset")
 
         #: directory in which dataset is stored
@@ -69,8 +71,20 @@ def __init__(self, config, folder=None):
     ## LOADING ##########################################################################
 
+    def ensure_available(self, key):
+        """Checks if key can be loaded"""
+        if self.folder is None or not os.path.exists(self.folder):
+            raise IOError(
+                "Dataset {} not found".format(self.config.get("dataset.name"))
+            )
+        filename = self.config.get(f"dataset.files.{key}.filename")
+        if filename is None:
+            raise IOError("Filename for key {} not specified in config".format(key))
+        if not os.path.exists(os.path.join(self.folder, filename)):
+            raise IOError("File {} for key {} could not be found".format(os.path.join(self.folder, filename), key))
+
     @staticmethod
-    def load(config: Config, preload_data=True):
+    def create(config: Config, preload_data=True):
         """Loads a dataset.
 
         If preload_data is set, loads entity and relation maps as well as all splits.
@@ -91,6 +105,57 @@ def load(config: Config, preload_data=True):
                 dataset.split(split)
         return dataset
 
+    @staticmethod
+    def create_from(
+        checkpoint: Dict,
+        config: Config = None,
+        dataset: Optional[Dataset] = None,
+        preload_data=False,
+    ) -> Dataset:
+        """Creates dataset based on a checkpoint.
+
+        If a dataset is provided, only (!) its meta data will be updated with the values
+        from the checkpoint. No further checks are performed.
+
+        Args:
+            checkpoint: loaded checkpoint
+            config: config (should match the one of checkpoint if set)
+            dataset: dataset to update
+            preload_data: preload data
+
+        Returns: created/updated dataset
+
+        """
+        if config is None:
+            config = Config.create_from(checkpoint)
+        if dataset is None:
+            dataset = Dataset.create(config, preload_data)
+        if "dataset" in checkpoint:
+            dataset_checkpoint = checkpoint["dataset"]
+            if (
+                "meta" in dataset_checkpoint
+                and dataset_checkpoint["meta"] is not None
+            ):
+                dataset._meta.update(dataset_checkpoint["meta"])
+            dataset._num_entities = dataset_checkpoint["num_entities"]
+            dataset._num_relations = dataset_checkpoint["num_relations"]
+        return dataset
+
+    def save_to(self, checkpoint: Dict, meta_keys: Optional[List[str]] = None) -> Dict:
+        """Adds meta data to a checkpoint"""
+        dataset_checkpoint = {
+            "num_entities": self.num_entities(),
+            "num_relations": self.num_relations(),
+        }
+        checkpoint["dataset"] = dataset_checkpoint
+        if meta_keys is None:
+            return checkpoint
+        meta_checkpoint = {}
+        for key in meta_keys:
+            meta_checkpoint[key] = self.map_indexes(None, key)
+        checkpoint["dataset"]["meta"] = meta_checkpoint
+        return checkpoint
+
     @staticmethod
     def _to_valid_filename(s):
         invalid_chars = "\n\t\\/"
@@ -117,6 +182,7 @@ def _load_triples(filename: str, delimiter="\t", use_pickle=False) -> Tensor:
     def load_triples(self, key: str) -> Tensor:
         "Load or return the triples with the specified key."
         if key not in self._triples:
+            self.ensure_available(key)
             filename = self.config.get(f"dataset.files.{key}.filename")
             filetype = self.config.get(f"dataset.files.{key}.type")
             if filetype != "triples":
@@ -204,6 +270,7 @@ def load_map(
         """
         if key not in self._meta:
+            self.ensure_available(key)
             filename = self.config.get(f"dataset.files.{key}.filename")
             filetype = self.config.get(f"dataset.files.{key}.type")
             if (maptype and filetype != maptype) or (
diff --git a/kge/job/auto_search.py b/kge/job/auto_search.py
index 3207146db..f3168e0e3 100644
--- a/kge/job/auto_search.py
+++ b/kge/job/auto_search.py
@@ -30,13 +30,6 @@ def __init__(self, config: Config, dataset, parent_job=None):
             for f in Job.job_created_hooks:
                 f(self)
 
-    def load(self, filename):
-        self.config.log("Loading checkpoint from {}...".format(filename))
-        checkpoint = torch.load(filename, map_location="cpu")
-        self.parameters = checkpoint["parameters"]
-        self.results = checkpoint["results"]
-        return checkpoint.get("job_id")
-
     def save(self, filename):
         self.config.log("Saving checkpoint to {}...".format(filename))
         torch.save(
@@ -49,24 +42,18 @@ def save(self, filename):
             filename,
         )
 
-    def resume(self, checkpoint_file=None):
-        if checkpoint_file is None:
-            last_checkpoint = self.config.last_checkpoint()
-            if last_checkpoint is not None:
-                checkpoint_file = self.config.checkpoint_file(last_checkpoint)
-
-        if checkpoint_file is not None:
-            self.resumed_from_job_id = self.load(checkpoint_file)
-            self.trace(
-                event="job_resumed", checkpoint_file=checkpoint_file
-            )
-            self.config.log(
-                "Resumed from {} of job {}".format(
-                    checkpoint_file, self.resumed_from_job_id
-                )
+    def _load(self, checkpoint):
+        self.resumed_from_job_id = checkpoint.get("job_id")
+        self.parameters = checkpoint["parameters"]
+        self.results = checkpoint["results"]
+        self.trace(
+            event="job_resumed", checkpoint_file=checkpoint["file"]
+        )
+        self.config.log(
+            "Resuming search from {} of job {}".format(
+                checkpoint["file"], self.resumed_from_job_id
             )
-        else:
-            self.config.log("No checkpoint found, starting from scratch...")
+        )
 
     # -- Abstract methods --------------------------------------------------------------
diff --git a/kge/job/ax_search.py b/kge/job/ax_search.py
index 699bcbb4e..6f619ee9e 100644
--- a/kge/job/ax_search.py
+++ b/kge/job/ax_search.py
@@ -94,7 +94,7 @@ def init_search(self):
         num_generated = len(self.parameters)
         if num_generated > 0:
             num_sobol_generated = min(
-                self.ax_client.generation_strategy._curr.num_arms, num_generated
+                self.ax_client.generation_strategy._curr.num_trials, num_generated
             )
             for i in range(num_sobol_generated):
                 generator_run = self.ax_client.generation_strategy.gen(
@@ -104,7 +104,7 @@ def init_search(self):
             self.config.log(
                 "Skipped {} of {} Sobol trials due to prior data.".format(
                     num_sobol_generated,
-                    self.ax_client.generation_strategy._curr.num_arms,
+                    self.ax_client.generation_strategy._curr.num_trials,
                 )
             )
 
diff --git a/kge/job/eval.py b/kge/job/eval.py
index 9bc1f4759..de0c2f60c 100644
--- a/kge/job/eval.py
+++ b/kge/job/eval.py
@@ -1,6 +1,10 @@
 import torch
 
+from kge import Config, Dataset
 from kge.job import Job
+from kge.model import KgeModel
+
+from typing import Dict, Union, Optional
 
 
 class EvaluationJob(Job):
@@ -83,20 +87,52 @@ def run(self) -> dict:
         """ Compute evaluation metrics, output results to trace file """
         raise NotImplementedError
 
-    def resume(self, checkpoint_file=None):
-        """Load model state from last or specified checkpoint."""
-        # load model
-        from kge.job import TrainingJob
-
-        training_job = TrainingJob.create(self.config, self.dataset)
-        training_job.resume(checkpoint_file)
-        self.model = training_job.model
-        self.epoch = training_job.epoch
-        self.resumed_from_job_id = training_job.resumed_from_job_id
+    def _load(self, checkpoint: Dict):
+        if checkpoint["type"] not in ["train", "package"]:
+            raise ValueError("Can only evaluate train and package checkpoints.")
+        self.resumed_from_job_id = checkpoint.get("job_id")
+        self.epoch = checkpoint["epoch"]
         self.trace(
-            event="job_resumed", epoch=self.epoch, checkpoint_file=checkpoint_file
+            event="job_resumed", epoch=self.epoch, checkpoint_file=checkpoint["file"]
         )
 
+    @classmethod
+    def create_from(
+        cls,
+        checkpoint: Dict,
+        new_config: Config = None,
+        dataset: Dataset = None,
+        parent_job=None,
+        eval_split: Optional[str] = None,
+    ) -> Job:
+        """
+        Creates a Job based on a checkpoint
+
+        Args:
+            checkpoint: loaded checkpoint
+            new_config: optional config object - overwrites options of config
+                        stored in checkpoint
+            dataset: dataset object
+            parent_job: parent job (e.g. search job)
+            eval_split: 'valid' or 'test'.
+                        Defines the split to evaluate on.
+                        Overwrites split defined in new_config or config of
+                        checkpoint.
+
+        Returns: Evaluation-Job based on checkpoint
+
+        """
+        if new_config is None:
+            new_config = Config(load_default=False)
+        if (
+            not new_config.exists("job.type")
+            or new_config.get("job.type") != "eval"
+        ):
+            new_config.set("job.type", "eval", create=True)
+        if eval_split is not None:
+            new_config.set("eval.split", eval_split, create=True)
+
+        return super().create_from(checkpoint, new_config, dataset, parent_job)
+
     # HISTOGRAM COMPUTATION ###############################################################
 
diff --git a/kge/job/grid_search.py b/kge/job/grid_search.py
index 0c8a750aa..8708678d3 100644
--- a/kge/job/grid_search.py
+++ b/kge/job/grid_search.py
@@ -69,7 +69,6 @@ def run(self):
         # and run it
         if self.config.get("grid_search.run"):
             job = Job.create(self.config, self.dataset, parent_job=self)
-            job.resume()
             job.run()
         else:
             self.config.log("Skipping running of search job as requested by user...")
diff --git a/kge/job/job.py b/kge/job/job.py
index 3b9353512..5822348af 100644
--- a/kge/job/job.py
+++ b/kge/job/job.py
@@ -1,10 +1,13 @@
+from __future__ import annotations
+
 from kge import Config, Dataset
+from kge.util import load_checkpoint
 import uuid
 from kge.misc import get_git_revision_short_hash
 import os
 import socket
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 
 def _trace_job_creation(job: "Job"):
@@ -55,35 +58,81 @@ def __init__(self, config: Config, dataset: Dataset, parent_job: "Job" = None):
             for f in Job.job_created_hooks:
                 f(self)
 
-    def resume(self, checkpoint_file: str = None):
-        """Load job state from last or specified checkpoint.
-
-        Restores all relevant state to resume a previous job. To run the restored job,
-        use :func:`run`.
+    @staticmethod
+    def create(
+        config: Config, dataset: Optional[Dataset] = None, parent_job=None, model=None
+    ):
+        "Create a new job."
+        from kge.job import TrainingJob, EvaluationJob, SearchJob
 
-        Should set `resumed_from_job` to the job ID of the previous job.
+        if dataset is None:
+            dataset = Dataset.create(config)
+
+        job_type = config.get("job.type")
+        if job_type == "train":
+            return TrainingJob.create(
+                config, dataset, parent_job=parent_job, model=model
+            )
+        elif job_type == "search":
+            return SearchJob.create(config, dataset, parent_job=parent_job)
+        elif job_type == "eval":
+            return EvaluationJob.create(
+                config, dataset, parent_job=parent_job, model=model
+            )
+        else:
+            raise ValueError("unknown job type")
 
+    @classmethod
+    def create_from(
+        cls,
+        checkpoint: Dict,
+        new_config: Config = None,
+        dataset: Dataset = None,
+        parent_job=None,
+    ) -> Job:
         """
-        raise NotImplementedError
+        Creates a Job based on a checkpoint
+
+        Args:
+            checkpoint: loaded checkpoint
+            new_config: optional config object - overwrites options of config
+                        stored in checkpoint
+            dataset: dataset object
+            parent_job: parent job (e.g. search job)
 
-    def run(self):
-        raise NotImplementedError
+        Returns: Job based on checkpoint
 
-    def create(config: Config, dataset: Dataset, parent_job: "Job" = None) -> "Job":
-        """Creates a job for a given configuration."""
+        """
+        from kge.model import KgeModel
+
+        model: KgeModel = None
+        # search jobs don't have a model
+        if "model" in checkpoint and checkpoint["model"] is not None:
+            model = KgeModel.create_from(
+                checkpoint, new_config=new_config, dataset=dataset
+            )
+            config = model.config
+            dataset = model.dataset
+        else:
+            config = Config.create_from(checkpoint)
+            if new_config:
+                config.load_config(new_config)
+            dataset = Dataset.create_from(checkpoint, config, dataset)
+        job = Job.create(config, dataset, parent_job, model)
+        job._load(checkpoint)
+        job.config.log("Loaded checkpoint from {}...".format(checkpoint["file"]))
+        return job
 
-        from kge.job import TrainingJob, EvaluationJob, SearchJob
+    def _load(self, checkpoint: Dict):
+        """Job type specific operations when created from checkpoint.
 
-        if config.get("job.type") == "train":
-            job = TrainingJob.create(config, dataset, parent_job)
-        elif config.get("job.type") == "search":
-            job = SearchJob.create(config, dataset, parent_job)
-        elif config.get("job.type") == "eval":
-            job = EvaluationJob.create(config, dataset, parent_job)
-        else:
-            raise ValueError("unknown job type")
+        Called during `create_from`. Assumes that config, dataset, and model have
+        already been loaded from the specified checkpoint.
 
-        return job
+        """
+        pass
+
+    def run(self):
+        raise NotImplementedError
 
     def trace(self, **kwargs) -> Dict[str, Any]:
         """Write a set of key-value pairs to the trace file and automatically append
diff --git a/kge/job/manual_search.py b/kge/job/manual_search.py
index c0ad2db41..5087cfe75 100644
--- a/kge/job/manual_search.py
+++ b/kge/job/manual_search.py
@@ -32,10 +32,6 @@ def __init__(self, config: Config, dataset: Dataset, parent_job=None):
             for f in Job.job_created_hooks:
                 f(self)
 
-    def resume(self, checkpoint_file=None):
-        # no need to do anything here; run code automatically resumes
-        pass
-
     def run(self):
         # read search configurations and expand them to full configs
         search_configs = copy.deepcopy(self.config.get("manual_search.configurations"))
diff --git a/kge/job/search.py b/kge/job/search.py
index da7e09700..e288f3f38 100644
--- a/kge/job/search.py
+++ b/kge/job/search.py
@@ -3,6 +3,7 @@
 import concurrent.futures
 from kge.job import Job, Trace
 from kge.config import _process_deprecated_options
+from kge.util.io import get_checkpoint_file, load_checkpoint
 
 
 class SearchJob(Job):
@@ -41,6 +42,7 @@ def __init__(self, config, dataset, parent_job=None):
             for f in Job.job_created_hooks:
                 f(self)
 
+    @staticmethod
     def create(config, dataset, parent_job=None):
         """Factory method to create a search job."""
 
@@ -130,8 +132,23 @@ def _run_train_job(sicnk, device=None):
                 train_job_config.get("job.device"),
             )
         )
-        job = Job.create(train_job_config, search_job.dataset, parent_job=search_job)
-        job.resume()
+        checkpoint_file = get_checkpoint_file(train_job_config)
+        if checkpoint_file is not None:
+            checkpoint = load_checkpoint(
+                checkpoint_file, train_job_config.get("job.device")
+            )
+            job = Job.create_from(
+                checkpoint=checkpoint,
+                new_config=train_job_config,
+                dataset=search_job.dataset,
+                parent_job=search_job,
+            )
+        else:
+            job = Job.create(
+                config=train_job_config,
+                dataset=search_job.dataset,
+                parent_job=search_job,
+            )
 
         # process the trace entries so far (in case of a resumed job)
         metric_name = search_job.config.get("valid.metric")
diff --git a/kge/job/train.py b/kge/job/train.py
index 4199cfddc..cdd24915a 100644
--- a/kge/job/train.py
+++ b/kge/job/train.py
@@ -15,7 +15,7 @@
 from kge.model import KgeModel
 from kge.util import KgeLoss, KgeOptimizer, KgeSampler, KgeLRScheduler
 
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import kge.job.util
 
 SLOTS = [0, 1, 2]
@@ -51,12 +51,15 @@ class TrainingJob(Job):
     """
 
     def __init__(
-        self, config: Config, dataset: Dataset, parent_job: Job = None
+        self, config: Config, dataset: Dataset, parent_job: Job = None, model=None
     ) -> None:
         from kge.job import EvaluationJob
 
         super().__init__(config, dataset, parent_job)
-        self.model: KgeModel = KgeModel.create(config, dataset)
+        if model is None:
+            self.model: KgeModel = KgeModel.create(config, dataset)
+        else:
+            self.model: KgeModel = model
         self.optimizer = KgeOptimizer.create(config, self.model)
         self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer)
         self.loss = KgeLoss.create(config)
@@ -64,6 +67,11 @@ def __init__(
         self.batch_size: int = config.get("train.batch_size")
         self.device: str = self.config.get("job.device")
         self.train_split = config.get("train.split")
+
+        self.config.check("train.trace_level", ["batch", "epoch"])
+        self.trace_batch: bool = self.config.get("train.trace_level") == "batch"
+        self.epoch: int = 0
+        self.valid_trace: List[Dict[str, Any]] = []
         valid_conf = config.clone()
         valid_conf.set("job.type", "eval")
         if self.config.get("valid.split") != "":
@@ -72,12 +80,7 @@ def __init__(
         self.valid_job = EvaluationJob.create(
             valid_conf, dataset, parent_job=self, model=self.model
         )
-        self.config.check("train.trace_level", ["batch", "epoch"])
-        self.trace_batch: bool = self.config.get("train.trace_level") == "batch"
-        self.epoch: int = 0
-        self.valid_trace: List[Dict[str, Any]] = []
         self.is_prepared = False
-        self.model.train()
 
         # attributes filled in by implementing classes
         self.loader = None
@@ -112,17 +115,19 @@ def __init__(
             for f in Job.job_created_hooks:
                 f(self)
 
+        self.model.train()
+
     @staticmethod
     def create(
-        config: Config, dataset: Dataset, parent_job: Job = None
+        config: Config, dataset: Dataset, parent_job: Job = None, model=None
     ) -> "TrainingJob":
         """Factory method to create a training job."""
         if config.get("train.type") == "KvsAll":
-            return TrainingJobKvsAll(config, dataset, parent_job)
+            return TrainingJobKvsAll(config, dataset, parent_job, model=model)
         elif config.get("train.type") == "negative_sampling":
-            return TrainingJobNegativeSampling(config, dataset, parent_job)
+            return TrainingJobNegativeSampling(config, dataset, parent_job, model=model)
         elif config.get("train.type") == "1vsAll":
-            return TrainingJob1vsAll(config, dataset, parent_job)
+            return TrainingJob1vsAll(config, dataset, parent_job, model=model)
         else:
             # perhaps TODO: try class with specified name -> extensibility
             raise ValueError("train.type")
@@ -246,32 +251,29 @@ def run(self) -> None:
     def save(self, filename) -> None:
         """Save current state to specified file"""
         self.config.log("Saving checkpoint to {}...".format(filename))
+        checkpoint = self.save_to({})
         torch.save(
-            {
-                "type": "train",
-                "config": self.config,
-                "epoch": self.epoch,
-                "valid_trace": self.valid_trace,
-                "model": self.model.save(),
-                "optimizer_state_dict": self.optimizer.state_dict(),
-                "lr_scheduler_state_dict": self.kge_lr_scheduler.state_dict(),
-                "job_id": self.job_id,
-            },
-            filename,
+            checkpoint, filename,
         )
 
-    def load(self, filename: str) -> str:
-        """Load job state from specified file.
-
-        Returns job id of the job that created the checkpoint."""
-        self.config.log("Loading checkpoint from {}...".format(filename))
-        checkpoint = torch.load(filename, map_location="cpu")
-        if "model" in checkpoint:
-            # new format
-            self.model.load(checkpoint["model"])
-        else:
-            # old format (deprecated, will eventually be removed)
-            self.model.load_state_dict(checkpoint["model_state_dict"])
+    def save_to(self, checkpoint: Dict) -> Dict:
+        """Adds trainjob specific information to the checkpoint"""
+        train_checkpoint = {
+            "type": "train",
+            "epoch": self.epoch,
+            "valid_trace": self.valid_trace,
+            "model": self.model.save(),
+            "optimizer_state_dict": self.optimizer.state_dict(),
+            "lr_scheduler_state_dict": self.kge_lr_scheduler.state_dict(),
+            "job_id": self.job_id,
+        }
+        train_checkpoint = self.config.save_to(train_checkpoint)
+        checkpoint.update(train_checkpoint)
+        return checkpoint
+
+    def _load(self, checkpoint: Dict):
+        if checkpoint["type"] != "train":
+            raise ValueError("Training can only be continued on trained checkpoints")
         self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
         if "lr_scheduler_state_dict" in checkpoint:
             # new format
@@ -279,26 +281,15 @@ def load(self, filename: str) -> str:
         self.epoch = checkpoint["epoch"]
         self.valid_trace = checkpoint["valid_trace"]
         self.model.train()
-        return checkpoint.get("job_id")
-
-    def resume(self, checkpoint_file: str = None) -> None:
-        if checkpoint_file is None:
-            last_checkpoint = self.config.last_checkpoint()
-            if last_checkpoint is not None:
-                checkpoint_file = self.config.checkpoint_file(last_checkpoint)
-
-        if checkpoint_file is not None:
-            self.resumed_from_job_id = self.load(checkpoint_file)
-            self.trace(
-                event="job_resumed", epoch=self.epoch, checkpoint_file=checkpoint_file
-            )
-            self.config.log(
-                "Resumed from {} of job {}".format(
-                    checkpoint_file, self.resumed_from_job_id
-                )
+        self.resumed_from_job_id = checkpoint.get("job_id")
+        self.trace(
+            event="job_resumed", epoch=self.epoch, checkpoint_file=checkpoint["file"],
+        )
+        self.config.log(
+            "Resuming training from {} of job {}".format(
+                checkpoint["file"], self.resumed_from_job_id
             )
-        else:
-            self.config.log("No checkpoint found, starting from scratch...")
+        )
 
     def run_epoch(self) -> Dict[str, Any]:
         "Runs an epoch and returns a trace entry."
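Checkpointing now round-trips through plain dicts: `save_to` assembles the dict that `save` passes to `torch.save`, and resuming inverts this via `load_checkpoint` plus `Job.create_from` (exactly what the cli.py hunk above does for the `resume` command). A short sketch of the round trip; the experiment path is hypothetical:

```python
from kge.util.io import load_checkpoint
from kge.job import Job

# TrainingJob.save(filename) is now equivalent to
#     torch.save(self.save_to({}), filename)
# where save_to adds type, epoch, valid trace, model, optimizer state, and config.

# Resuming mirrors kge/cli.py: load the dict, then rebuild the job from it.
checkpoint = load_checkpoint("local/experiments/my-run/checkpoint_00005.pt", device="cpu")
job = Job.create_from(checkpoint)  # recreates config, dataset, model, then calls job._load()
job.run()
```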
@@ -514,8 +505,8 @@ class TrainingJobKvsAll(TrainingJob):
     - Example: a query + its labels, e.g., (John,marriedTo), [Jane]
     """
 
-    def __init__(self, config, dataset, parent_job=None):
-        super().__init__(config, dataset, parent_job)
+    def __init__(self, config, dataset, parent_job=None, model=None):
+        super().__init__(config, dataset, parent_job, model=model)
         self.label_smoothing = config.check_range(
             "KvsAll.label_smoothing", float("-inf"), 1.0, max_inclusive=False
         )
@@ -777,8 +768,8 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult:
 
 class TrainingJobNegativeSampling(TrainingJob):
-    def __init__(self, config, dataset, parent_job=None):
-        super().__init__(config, dataset, parent_job)
+    def __init__(self, config, dataset, parent_job=None, model=None):
+        super().__init__(config, dataset, parent_job, model=model)
         self._sampler = KgeSampler.create(config, "negative_sampling", dataset)
         self.is_prepared = False
         self._implementation = self.config.check(
@@ -1027,8 +1018,8 @@ def _process_batch(self, batch_index, batch) -> TrainingJob._ProcessBatchResult:
 class TrainingJob1vsAll(TrainingJob):
     """Samples SPO pairs and queries sp_ and _po, treating all other entities as negative."""
 
-    def __init__(self, config, dataset, parent_job=None):
-        super().__init__(config, dataset, parent_job)
+    def __init__(self, config, dataset, parent_job=None, model=None):
+        super().__init__(config, dataset, parent_job, model=model)
         self.is_prepared = False
         config.log("Initializing spo training job...")
         self.type_str = "1vsAll"
diff --git a/kge/model/kge_model.py b/kge/model/kge_model.py
index f59e81fcd..b3c07d14a 100644
--- a/kge/model/kge_model.py
+++ b/kge/model/kge_model.py
@@ -9,6 +9,7 @@
 import kge
 from kge import Config, Configurable, Dataset
 from kge.misc import filename_in_module
+from kge.util import load_checkpoint
 
 from typing import Any, Dict, List, Optional, Union
 from typing import TYPE_CHECKING
@@ -315,12 +316,12 @@ def _init_configuration(self, config: Config, configuration_key: Optional[str]):
 
     @staticmethod
     def create(
-        config: Config, dataset: Dataset, configuration_key: str = None
+        config: Config, dataset: Dataset, configuration_key: Optional[str] = None
     ) -> "KgeModel":
         """Factory method for model creation."""
 
         try:
-            if configuration_key:
+            if configuration_key is not None:
                 model_name = config.get(configuration_key + ".type")
             else:
                 model_name = config.get("model")
@@ -344,7 +345,7 @@ def create(
     @staticmethod
     def create_default(
         model: Optional[str] = None,
-        dataset: Optional[Union[Dataset,str]] = None,
+        dataset: Optional[Union[Dataset, str]] = None,
         options: Dict[str, Any] = {},
         folder: Optional[str] = None,
     ) -> "KgeModel":
@@ -384,15 +385,18 @@ def create_default(
 
         # create dataset and model
         if not isinstance(dataset, Dataset):
-            dataset = Dataset.load(config)
+            dataset = Dataset.create(config)
         model = KgeModel.create(config, dataset)
         return model
 
     @staticmethod
-    def load_from_checkpoint(
-        filename: str, dataset=None, use_tmp_log_folder=True, device="cpu"
+    def create_from(
+        checkpoint: Dict,
+        dataset: Optional[Dataset] = None,
+        use_tmp_log_folder=True,
+        new_config: Config = None,
     ) -> "KgeModel":
-        """Loads a model from a checkpoint file of a training job.
+        """Loads a model from a checkpoint of a training job or a packaged model.
 
         If dataset is specified, associates this dataset with the model. Otherwise uses
         the dataset used to train the model.
 
@@ -402,23 +406,19 @@
         appended to) in the checkpoint's folder.
""" + config = Config.create_from(checkpoint) + if new_config: + config.load_config(new_config) - checkpoint = torch.load(filename, map_location=device) - - original_config = checkpoint["config"] - config = Config() # round trip to handle deprecated configs - config.load_options(original_config.options) - config.set("job.device", device) if use_tmp_log_folder: import tempfile config.log_folder = tempfile.mkdtemp(prefix="kge-") else: - config.log_folder = os.path.dirname(filename) - if not config.log_folder: + config.log_folder = checkpoint["folder"] + if not config.log_folder or not os.path.exists(config.log_folder): config.log_folder = "." - if dataset is None: - dataset = Dataset.load(config, preload_data=False) + dataset = Dataset.create_from(checkpoint, config, dataset, preload_data=False) model = KgeModel.create(config, dataset) model.load(checkpoint["model"]) model.eval() diff --git a/kge/util/__init__.py b/kge/util/__init__.py index 69ad062ad..8a6ec3c79 100644 --- a/kge/util/__init__.py +++ b/kge/util/__init__.py @@ -2,3 +2,4 @@ from kge.util.optimizer import KgeOptimizer from kge.util.optimizer import KgeLRScheduler from kge.util.sampler import KgeSampler +from kge.util.io import load_checkpoint diff --git a/kge/util/dump.py b/kge/util/dump.py index 57dd72471..8c15feaee 100644 --- a/kge/util/dump.py +++ b/kge/util/dump.py @@ -85,7 +85,7 @@ def _dump_checkpoint(args): if os.path.isfile(args.source): checkpoint_file = args.source else: - checkpoint_file = Config.get_best_or_last_checkpoint(args.source) + checkpoint_file = Config.best_or_last_checkpoint_file(args.source) # Load the checkpoint and strip some fieleds checkpoint = torch.load(checkpoint_file, map_location="cpu") @@ -316,7 +316,7 @@ def _dump_trace(args): else: # determine job_id and epoch from last/best checkpoint automatically if args.checkpoint: - checkpoint_path = Config.get_best_or_last_checkpoint(args.source) + checkpoint_path = Config.best_or_last_checkpoint_file(args.source) folder_path = args.source if not checkpoint_path and truncate_flag: sys.exit( @@ -673,11 +673,11 @@ def _dump_config(args): config_file = args.source config.load(config_file) else: # a checkpoint - checkpoint_file = torch.load(args.source, map_location="cpu") + checkpoint = torch.load(args.source, map_location="cpu") if args.raw: - config = checkpoint_file["config"] + config = checkpoint["config"] else: - config.load_options(checkpoint_file["config"].options) + config.load_config(checkpoint["config"]) def print_options(options): # drop all arguments that are not included diff --git a/kge/util/io.py b/kge/util/io.py new file mode 100644 index 000000000..5ba35007b --- /dev/null +++ b/kge/util/io.py @@ -0,0 +1,46 @@ +import os +import torch +from kge import Config +from kge.misc import is_number + + +def get_checkpoint_file(config: Config, checkpoint_argument: str = "default"): + """ + Gets the path to a checkpoint file based on a config. 
+
+    Args:
+        config: config specifying the folder
+        checkpoint_argument: which checkpoint to use: 'default', 'last', 'best',
+                             a number or a file name
+
+    Returns:
+        path to a checkpoint file
+    """
+    if checkpoint_argument == "default":
+        if config.get("job.type") in ["eval", "valid"]:
+            checkpoint_file = config.checkpoint_file("best")
+        else:
+            last_epoch = config.last_checkpoint_number()
+            if last_epoch is None:
+                checkpoint_file = None
+            else:
+                checkpoint_file = config.checkpoint_file(last_epoch)
+    elif is_number(checkpoint_argument, int) or checkpoint_argument == "best":
+        checkpoint_file = config.checkpoint_file(checkpoint_argument)
+    else:
+        # otherwise, treat it as a filename
+        checkpoint_file = checkpoint_argument
+    return checkpoint_file
+
+
+def load_checkpoint(checkpoint_file: str, device="cpu"):
+    if not os.path.exists(checkpoint_file):
+        raise IOError(
+            "Specified checkpoint file {} does not exist.".format(checkpoint_file)
+        )
+    checkpoint = torch.load(checkpoint_file, map_location=device)
+    if device is not None and "config" in checkpoint:
+        checkpoint["config"].set("job.device", device)
+    checkpoint["file"] = checkpoint_file
+    checkpoint["folder"] = os.path.dirname(checkpoint_file)
+    return checkpoint
diff --git a/kge/util/package.py b/kge/util/package.py
new file mode 100644
index 000000000..07a2e7296
--- /dev/null
+++ b/kge/util/package.py
@@ -0,0 +1,47 @@
+import os
+import torch
+from kge import Config, Dataset
+from kge.util import load_checkpoint
+
+
+def add_package_parser(subparsers):
+    """Creates the parser for the command package"""
+    package_parser = subparsers.add_parser(
+        "package", help="Create packaged model (checkpoint only containing model)",
+    )
+    package_parser.add_argument("checkpoint", type=str, help="filename of a checkpoint")
+    package_parser.add_argument(
+        "--file", type=str, help="output filename of packaged model"
+    )
+
+
+def package_model(args):
+    """
+    Converts a checkpoint to a packaged model.
+
+    A packaged model only contains the model, entity/relation ids, and the config.
+    """
+    checkpoint_file = args.checkpoint
+    filename = args.file
+    checkpoint = load_checkpoint(checkpoint_file, device="cpu")
+    if checkpoint["type"] != "train":
+        raise ValueError("Can only package trained checkpoints.")
+    config = Config.create_from(checkpoint)
+    dataset = Dataset.create_from(checkpoint, config, preload_data=False)
+    packaged_model = {
+        "type": "package",
+        "model": checkpoint["model"],
+        "epoch": checkpoint["epoch"],
+        "job_id": checkpoint["job_id"],
+        "valid_trace": checkpoint["valid_trace"],
+    }
+    packaged_model = config.save_to(packaged_model)
+    packaged_model = dataset.save_to(packaged_model, ["entity_ids", "relation_ids"],)
+    if filename is None:
+        output_folder, filename = os.path.split(checkpoint_file)
+        if "checkpoint" in filename:
+            filename = filename.replace("checkpoint", "model")
+        else:
+            filename = filename.split(".pt")[0] + "_package.pt"
+        filename = os.path.join(output_folder, filename)
+    print(f"Saving to {filename}...")
+    torch.save(packaged_model, filename)
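Taken together, a sketch of the intended packaging workflow (both filenames hypothetical): `package_model` strips a training checkpoint down to model weights, entity/relation ids, and config, and the resulting checkpoint of type `"package"` loads through the same API as a training checkpoint, for scoring or for evaluation (which the eval.py change above explicitly allows):

```python
from kge.model import KgeModel
from kge.util.io import load_checkpoint

# produced e.g. by the new CLI subcommand: kge package <checkpoint> --file <out>
checkpoint = load_checkpoint("fb15k-237-rescal-model.pt")
assert checkpoint["type"] == "package"    # no optimizer or scheduler state inside
model = KgeModel.create_from(checkpoint)  # accepts both "train" and "package" checkpoints
```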