Skip to content

Commit

Permalink
Merge pull request #94 from AdrianKs/package
Browse files Browse the repository at this point in the history
create packaged models from checkpoint
  • Loading branch information
rgemulla authored May 26, 2020
2 parents 9a4a817 + ced8f48 commit fde4565
Show file tree
Hide file tree
Showing 18 changed files with 467 additions and 176 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,12 @@ subject-relations pairs: ('Dominican Republic', 'has form of government', ?) and

```python
import torch
import kge.model
import kge.model import KgeModel
from kge.util.io import load_checkpoint

# download link for this checkpoint given under results above
model = kge.model.KgeModel.load_from_checkpoint('fb15k-237-rescal.pt')
checkpoint = load_checkpoint('fb15k-237-rescal.pt')
model = KgeModel.create_from(checkpoint)

s = torch.Tensor([0, 2,]).long() # subject indexes
p = torch.Tensor([0, 1,]).long() # relation indexes
Expand Down
38 changes: 25 additions & 13 deletions kge/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from kge.job import Job
from kge.misc import get_git_revision_short_hash, kge_base_dir, is_number
from kge.util.dump import add_dump_parsers, dump
from kge.util.io import get_checkpoint_file, load_checkpoint
from kge.util.package import package_model, add_package_parser


def argparse_bool_type(v):
Expand Down Expand Up @@ -130,6 +132,7 @@ def create_parser(config, additional_args=[]):
default="default",
)
add_dump_parsers(subparsers)
add_package_parser(subparsers)
return parser


Expand Down Expand Up @@ -163,6 +166,11 @@ def main():
dump(args)
exit()

# package command
if args.command == "package":
package_model(args)
exit()

# start command
if args.command == "start":
# use toy config file if no config given
Expand Down Expand Up @@ -231,16 +239,7 @@ def main():

# determine checkpoint to resume (if any)
if hasattr(args, "checkpoint"):
if args.checkpoint == "default":
if config.get("job.type") in ["eval", "valid"]:
checkpoint_file = config.checkpoint_file("best")
else:
checkpoint_file = None # means last
elif is_number(args.checkpoint, int) or args.checkpoint == "best":
checkpoint_file = config.checkpoint_file(args.checkpoint)
else:
# otherwise, treat it as a filename
checkpoint_file = args.checkpoint
checkpoint_file = get_checkpoint_file(config, args.checkpoint)

# disable processing of outdated cached dataset files globally
Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated
Expand Down Expand Up @@ -269,17 +268,30 @@ def main():
config.log("Job created successfully.")
else:
# load data
dataset = Dataset.load(config)
dataset = Dataset.create(config)

# let's go
job = Job.create(config, dataset)
if args.command == "resume":
job.resume(checkpoint_file)
if checkpoint_file is not None:
checkpoint = load_checkpoint(
checkpoint_file, config.get("job.device")
)
job = Job.create_from(
checkpoint, new_config=config, dataset=dataset
)
else:
job = Job.create(config, dataset)
job.config.log(
"No checkpoint found or specified, starting from scratch..."
)
else:
job = Job.create(config, dataset)
job.run()
except BaseException as e:
tb = traceback.format_exc()
config.log(tb, echo=False)
raise e from None


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion kge/config-default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ job:
type: train

# Main device to use for this job (e.g., 'cpu', 'cuda', 'cuda:0')
device: 'cuda'
device: 'cuda'

# The seeds of the PRNGs can be set manually for (increased) reproducability.
# Use -1 to use default seed.
Expand Down
61 changes: 51 additions & 10 deletions kge/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import collections
import copy
import datetime
Expand Down Expand Up @@ -130,6 +132,13 @@ def get_first(self, *keys: str, use_get_default=False) -> Any:
else:
return self.get(self.get_first_present_key(*keys))

def exists(self, key: str, remove_plusplusplus=True) -> bool:
try:
self.get(key, remove_plusplusplus)
return True
except KeyError:
return False

Overwrite = Enum("Overwrite", "Yes No Error")

def set(
Expand Down Expand Up @@ -298,12 +307,13 @@ def load(
"""
with open(filename, "r") as file:
new_options = yaml.load(file, Loader=yaml.SafeLoader)
self.load_options(
new_options,
create=create,
overwrite=overwrite,
allow_deprecated=allow_deprecated,
)
if new_options is not None:
self.load_options(
new_options,
create=create,
overwrite=overwrite,
allow_deprecated=allow_deprecated,
)

def load_options(
self, new_options, create=False, overwrite=Overwrite.Yes, allow_deprecated=True
Expand Down Expand Up @@ -331,11 +341,22 @@ def load_options(
# now set all options
self.set_all(new_options, create, overwrite)

def load_config(
self, config, create=False, overwrite=Overwrite.Yes, allow_deprecated=True
):
"Like `load`, but loads from a Config object."
self.load_options(config.options, create, overwrite, allow_deprecated)

def save(self, filename):
"""Save this configuration to the given file"""
with open(filename, "w+") as file:
file.write(yaml.dump(self.options))

def save_to(self, checkpoint: Dict) -> Dict:
"""Adds the config file to a checkpoint"""
checkpoint["config"] = self
return checkpoint

@staticmethod
def flatten(options: Dict[str, Any]) -> Dict[str, Any]:
"""Returns a dictionary of flattened configuration options."""
Expand Down Expand Up @@ -423,6 +444,26 @@ def init_folder(self):
return True
return False

@staticmethod
def create_from(checkpoint: Dict) -> Config:
"""Create a config from a checkpoint."""
config = Config() # round trip to handle deprecated configs
if "config" in checkpoint and checkpoint["config"] is not None:
config.load_config(checkpoint["config"].clone())
if "folder" in checkpoint and checkpoint["folder"] is not None:
config.folder = checkpoint["folder"]
return config

@staticmethod
def from_options(options: Dict[str, Any] = {}, **more_options) -> Config:
"""Convert given options or kwargs to a Config object.
Does not perform any checks for correctness."""
config = Config(load_default=False)
config.set_all(options, create=True)
config.set_all(more_options, create=True)
return config

def checkpoint_file(self, cpt_id: Union[str, int]) -> str:
"Return path of checkpoint file for given checkpoint id"
from kge.misc import is_number
Expand All @@ -432,8 +473,8 @@ def checkpoint_file(self, cpt_id: Union[str, int]) -> str:
else:
return os.path.join(self.folder, "checkpoint_{}.pt".format(cpt_id))

def last_checkpoint(self) -> Optional[int]:
"Return epoch number of latest checkpoint"
def last_checkpoint_number(self) -> Optional[int]:
"Return number (epoch) of latest checkpoint"
# stupid implementation, but works
tried_epoch = 0
found_epoch = 0
Expand All @@ -447,13 +488,13 @@ def last_checkpoint(self) -> Optional[int]:
return None

@staticmethod
def get_best_or_last_checkpoint(path: str) -> str:
def best_or_last_checkpoint_file(path: str) -> str:
"""Return best (if present) or last checkpoint path for a given folder path."""
config = Config(folder=path, load_default=False)
checkpoint_file = config.checkpoint_file("best")
if os.path.isfile(checkpoint_file):
return checkpoint_file
cpt_epoch = config.last_checkpoint()
cpt_epoch = config.last_checkpoint_number()
if cpt_epoch:
return config.checkpoint_file(cpt_epoch)
else:
Expand Down
71 changes: 69 additions & 2 deletions kge/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import csv
import os
import sys
Expand Down Expand Up @@ -31,7 +33,7 @@ class Dataset(Configurable):
def __init__(self, config, folder=None):
"""Constructor for internal use.
To load a dataset, use `Dataset.load()`."""
To load a dataset, use `Dataset.create()`."""
super().__init__(config, "dataset")

#: directory in which dataset is stored
Expand Down Expand Up @@ -69,8 +71,20 @@ def __init__(self, config, folder=None):

## LOADING ##########################################################################

def ensure_available(self, key):
"""Checks if key can be loaded"""
if self.folder is None or not os.path.exists(self.folder):
raise IOError(
"Dataset {} not found".format(self.config.get("dataset.name"))
)
filename = self.config.get(f"dataset.files.{key}.filename")
if filename is None:
raise IOError("Filename for key {} not specified in config".format(key))
if not os.path.exists(os.path.join(self.folder, filename)):
raise IOError("File {} for key {} could not be found".format(os.path.join(self.folder, filename), key))

@staticmethod
def load(config: Config, preload_data=True):
def create(config: Config, preload_data=True):
"""Loads a dataset.
If preload_data is set, loads entity and relation maps as well as all splits.
Expand All @@ -91,6 +105,57 @@ def load(config: Config, preload_data=True):
dataset.split(split)
return dataset

@staticmethod
def create_from(
checkpoint: Dict,
config: Config = None,
dataset: Optional[Dataset] = None,
preload_data=False,
) -> Dataset:
"""Creates dataset based on a checkpoint.
If a dataset is provided, only (!) its meta data will be updated with the values
from the checkpoint. No further checks are performed.
Args:
checkpoint: loaded checkpoint
config: config (should match the one of checkpoint if set)
dataset: dataset to update
preload_data: preload data
Returns: created/updated dataset
"""
if config is None:
config = Config.create_from(checkpoint)
if dataset is None:
dataset = Dataset.create(config, preload_data)
if "dataset" in checkpoint:
dataset_checkpoint = checkpoint["dataset"]
if (
"dataset.meta" in dataset_checkpoint
and dataset_checkpoint["meta"] is not None
):
dataset._meta.update(dataset_checkpoint["meta"])
dataset._num_entities = dataset_checkpoint["num_entities"]
dataset._num_relations = dataset_checkpoint["num_relations"]
return dataset

def save_to(self, checkpoint: Dict, meta_keys: Optional[List[str]] = None) -> Dict:
"""Adds meta data to a checkpoint"""
dataset_checkpoint = {
"num_entities": self.num_entities(),
"num_relations": self.num_relations(),
}
checkpoint["dataset"] = dataset_checkpoint
if meta_keys is None:
return checkpoint
meta_checkpoint = {}
for key in meta_keys:
meta_checkpoint[key] = self.map_indexes(None, key)
checkpoint["dataset"]["meta"] = dataset_checkpoint
return checkpoint

@staticmethod
def _to_valid_filename(s):
invalid_chars = "\n\t\\/"
Expand All @@ -117,6 +182,7 @@ def _load_triples(filename: str, delimiter="\t", use_pickle=False) -> Tensor:
def load_triples(self, key: str) -> Tensor:
"Load or return the triples with the specified key."
if key not in self._triples:
self.ensure_available(key)
filename = self.config.get(f"dataset.files.{key}.filename")
filetype = self.config.get(f"dataset.files.{key}.type")
if filetype != "triples":
Expand Down Expand Up @@ -204,6 +270,7 @@ def load_map(
"""
if key not in self._meta:
self.ensure_available(key)
filename = self.config.get(f"dataset.files.{key}.filename")
filetype = self.config.get(f"dataset.files.{key}.type")
if (maptype and filetype != maptype) or (
Expand Down
35 changes: 11 additions & 24 deletions kge/job/auto_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,6 @@ def __init__(self, config: Config, dataset, parent_job=None):
for f in Job.job_created_hooks:
f(self)

def load(self, filename):
self.config.log("Loading checkpoint from {}...".format(filename))
checkpoint = torch.load(filename, map_location="cpu")
self.parameters = checkpoint["parameters"]
self.results = checkpoint["results"]
return checkpoint.get("job_id")

def save(self, filename):
self.config.log("Saving checkpoint to {}...".format(filename))
torch.save(
Expand All @@ -49,24 +42,18 @@ def save(self, filename):
filename,
)

def resume(self, checkpoint_file=None):
if checkpoint_file is None:
last_checkpoint = self.config.last_checkpoint()
if last_checkpoint is not None:
checkpoint_file = self.config.checkpoint_file(last_checkpoint)

if checkpoint_file is not None:
self.resumed_from_job_id = self.load(checkpoint_file)
self.trace(
event="job_resumed", checkpoint_file=checkpoint_file
)
self.config.log(
"Resumed from {} of job {}".format(
checkpoint_file, self.resumed_from_job_id
)
def _load(self, checkpoint):
self.resumed_from_job_id = checkpoint.get("job_id")
self.parameters = checkpoint["parameters"]
self.results = checkpoint["results"]
self.trace(
event="job_resumed", checkpoint_file=checkpoint["file"]
)
self.config.log(
"Resuming search from {} of job {}".format(
checkpoint["file"], self.resumed_from_job_id
)
else:
self.config.log("No checkpoint found, starting from scratch...")
)

# -- Abstract methods --------------------------------------------------------------

Expand Down
Loading

0 comments on commit fde4565

Please sign in to comment.