create packaged models from checkpoint #94

Merged: 37 commits (May 26, 2020)
Changes from 6 commits
Commits (37)
3e11144
create packaged models from checkpoint
AdrianKs Apr 15, 2020
dfa089a
adopt packaged models to PR notes
AdrianKs Apr 15, 2020
2f86e7d
replace method resume with static load_from
AdrianKs Apr 20, 2020
5dcd265
improve resuming from checkpoints
AdrianKs Apr 20, 2020
a79927a
remove accidentally committed remove_key function
AdrianKs Apr 20, 2020
5b8878b
add load function to auto search
AdrianKs Apr 20, 2020
927c417
fix resume of search jobs
AdrianKs Apr 20, 2020
cfa552e
improve reuming from checkpoint and packaging of models
AdrianKs Apr 22, 2020
5fa535e
add num_entities and num_relations to dataset.save_to
AdrianKs Apr 23, 2020
80adebb
don't package search job checkpoints
AdrianKs Apr 23, 2020
45a2883
allow evaluation without config with EntityRankingJob.create_from
AdrianKs Apr 23, 2020
c5c163a
fix formatting
AdrianKs Apr 24, 2020
e4fecf1
address pr comments for package model
AdrianKs Apr 27, 2020
afa52fd
improve logging message in find_and_create
AdrianKs Apr 27, 2020
7801e61
fix saving of checkpoint
AdrianKs Apr 29, 2020
e8f283a
raise error if dataset not found or None
AdrianKs Apr 29, 2020
5f7968d
merge functions create_from and find_and_create_from
AdrianKs Apr 29, 2020
3942cbc
reformat job.py
AdrianKs Apr 29, 2020
f87ed13
address package-PR comments
AdrianKs May 4, 2020
8d807c1
Merge remote-tracking branch 'remotes/upstream/master' into package
AdrianKs May 22, 2020
f2a0924
separate loading of checkpoint from job.create_from
AdrianKs May 22, 2020
ddee973
separate loading of checkpoint from model.create_from
AdrianKs May 22, 2020
ebda06a
Merge branch 'package' of https://github.com/AdrianKs/kge-1 into Adri…
rgemulla May 25, 2020
cddb2e6
Support loading of old checkpoints with current embedders
rgemulla May 25, 2020
265a674
Merge branch 'master' of https://github.com/uma-pi1/kge into AdrianKs…
rgemulla May 25, 2020
cacc485
Package revisions part 1
rgemulla May 25, 2020
dea2d28
Update docs
rgemulla May 25, 2020
8f4f86d
Support empty config files
rgemulla May 25, 2020
e469544
Documentation udpate
rgemulla May 25, 2020
5db17f7
Renamed some checkpoint functions
rgemulla May 25, 2020
e90e299
SImplify Config.create_from
rgemulla May 25, 2020
a22ff1f
Minor revision of Dataset
rgemulla May 25, 2020
b0bb3a8
Fix resume of auto search jobs with new API
rgemulla May 25, 2020
30bfeff
Add some additional keys to packages
rgemulla May 25, 2020
277c68d
Add Config#load_config
rgemulla May 25, 2020
69666f3
Consistent method names in Dataset
rgemulla May 25, 2020
ced8f48
Revised Job loading
rgemulla May 25, 2020
2 changes: 1 addition & 1 deletion README.md
@@ -328,7 +328,7 @@ import torch
import kge.model

# download link for this checkpoint given under results above
model = kge.model.KgeModel.load_from_checkpoint('fb15k-237-rescal.pt')
model = kge.model.KgeModel.load_from('fb15k-237-rescal.pt')

s = torch.Tensor([0, 2,]).long() # subject indexes
p = torch.Tensor([0, 1,]).long() # relation indexes
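Note: the README snippet is truncated in this diff view. A minimal sketch of the complete usage under the API in this revision (checkpoint filename as in the README; score_sp assumed unchanged):

import torch
import kge.model

# load a packaged model from a single checkpoint file
model = kge.model.KgeModel.load_from('fb15k-237-rescal.pt')

s = torch.Tensor([0, 2,]).long()   # subject indexes
p = torch.Tensor([0, 1,]).long()   # relation indexes
scores = model.score_sp(s, p)      # scores of all objects for (s, p, ?)
o = torch.argmax(scores, dim=-1)   # index of the highest-scoring object per query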
12 changes: 10 additions & 2 deletions kge/cli.py
@@ -10,6 +10,7 @@
from kge.job import Job
from kge.misc import get_git_revision_short_hash, kge_base_dir, is_number
from kge.util.dump import add_dump_parsers, dump
from kge.util.package import package_model, add_package_parser


def argparse_bool_type(v):
@@ -130,6 +131,7 @@ def create_parser(config, additional_args=[]):
default="default",
)
add_dump_parsers(subparsers)
add_package_parser(subparsers)
return parser


@@ -163,6 +165,11 @@ def main():
dump(args)
exit()

# package command
if args.command == "package":
package_model(args.checkpoint)
exit()

# start command
if args.command == "start":
# use toy config file if no config given
@@ -272,9 +279,10 @@ def main():
dataset = Dataset.load(config)

# let's go
job = Job.create(config, dataset)
if args.command == "resume":
job.resume(checkpoint_file)
job = Job.resume(checkpoint_file, config=config)
else:
job = Job.create(config, dataset)
job.run()
except BaseException as e:
tb = traceback.format_exc()
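The new package subcommand above simply forwards the checkpoint path to package_model, so the same step can presumably be run from Python as well; a sketch with an illustrative path:

from kge.util.package import package_model

# create a self-contained ("packaged") model file from a training checkpoint;
# only the checkpoint path is passed, mirroring the CLI dispatch above
package_model("local/experiments/20200526-example/checkpoint_best.pt")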
22 changes: 22 additions & 0 deletions kge/config.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import collections
import copy
import datetime
@@ -423,6 +425,26 @@ def init_folder(self):
return True
return False

@staticmethod
def load_from(checkpoint_config: Config, new_config=None, folder=None) -> Config:
"""
Load a config from checkpoint and update all deprecated keys.
Overwrite options in checkpoint config with a new config.
Args:
checkpoint_config: Config stored in checkpoint
new_config: new config with options to overwrite

Returns: Config object

"""
config = Config() # round trip to handle deprecated configs
config.load_options(checkpoint_config.options)
if folder is not None:
config.folder = folder
if new_config is not None:
config.load_options(new_config.options)
return config

def checkpoint_file(self, cpt_id: Union[str, int]) -> str:
"Return path of checkpoint file for given checkpoint id"
from kge.misc import is_number
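A sketch of the intended precedence, assuming the checkpoint dict layout used by Job.load_from below (a "config" and a "folder" entry): options from new_config are loaded last and therefore override those stored in the checkpoint.

from kge import Config
from kge.util import load_checkpoint

checkpoint = load_checkpoint("checkpoint_best.pt", "cpu")  # illustrative path

overrides = Config()                   # fresh config
overrides.set("job.device", "cuda:0")  # option that should win over the checkpoint

config = Config.load_from(
    checkpoint["config"],              # config stored in the checkpoint
    new_config=overrides,              # overwrites checkpoint options
    folder=checkpoint["folder"],       # keep the run folder recorded in the checkpoint
)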
16 changes: 16 additions & 0 deletions kge/dataset.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import csv
import os
import sys
@@ -91,6 +93,20 @@ def load(config: Config, preload_data=True):
dataset.split(split)
return dataset

@staticmethod
def load_meta(config: Config, meta_checkpoint: Dict, preload_data=True) -> Dataset:
"""Loads a dataset and overrides meta data."""
dataset = Dataset.load(config, preload_data)
dataset._meta.update(meta_checkpoint)
return dataset

def save_meta(self, keys: List[str]):
"""Creates a dataset_meta dictionary for a checkpoint."""
meta_checkpoint = {}
for key in keys:
meta_checkpoint[key] = self._meta[key]
return meta_checkpoint

@staticmethod
def _to_valid_filename(s):
invalid_chars = "\n\t\\/"
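A sketch of how the two helpers could be used together when packaging and restoring a model; the meta keys are illustrative, not fixed by this diff:

from kge import Config, Dataset

config = Config()               # assumes the default config points to a valid dataset
dataset = Dataset.load(config)

# collect selected meta data for inclusion in a packaged checkpoint
meta = dataset.save_meta(["entity_ids", "relation_ids"])

# when restoring from that checkpoint: reload the dataset and override its meta data
restored = Dataset.load_meta(config, meta)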
26 changes: 9 additions & 17 deletions kge/job/auto_search.py
@@ -49,24 +49,16 @@ def save(self, filename):
filename,
)

def resume(self, checkpoint_file=None):
if checkpoint_file is None:
last_checkpoint = self.config.last_checkpoint()
if last_checkpoint is not None:
checkpoint_file = self.config.checkpoint_file(last_checkpoint)

if checkpoint_file is not None:
self.resumed_from_job_id = self.load(checkpoint_file)
self.trace(
event="job_resumed", checkpoint_file=checkpoint_file
)
self.config.log(
"Resumed from {} of job {}".format(
checkpoint_file, self.resumed_from_job_id
)
def load(self, checkpoint):
self.resumed_from_job_id = checkpoint.get("job_id")
self.trace(
event="job_resumed", checkpoint_file=checkpoint["file"]
)
self.config.log(
"Resumed from {} of job {}".format(
checkpoint["file"], self.resumed_from_job_id
)
else:
self.config.log("No checkpoint found, starting from scratch...")
)

# -- Abstract methods --------------------------------------------------------------

33 changes: 22 additions & 11 deletions kge/job/eval.py
@@ -1,6 +1,10 @@
import torch

from kge import Config, Dataset
from kge.job import Job
from kge.model import KgeModel

from typing import Dict, Union


class EvaluationJob(Job):
@@ -83,20 +87,27 @@ def run(self) -> dict:
""" Compute evaluation metrics, output results to trace file """
raise NotImplementedError

def resume(self, checkpoint_file=None):
"""Load model state from last or specified checkpoint."""
# load model
from kge.job import TrainingJob

training_job = TrainingJob.create(self.config, self.dataset)
training_job.resume(checkpoint_file)
self.model = training_job.model
self.epoch = training_job.epoch
self.resumed_from_job_id = training_job.resumed_from_job_id
def load(self, checkpoint, model):
self.resumed_from_job_id = checkpoint.get("job_id")
self.epoch = checkpoint["epoch"]
self.trace(
event="job_resumed", epoch=self.epoch, checkpoint_file=checkpoint_file
event="job_resumed", epoch=self.epoch, checkpoint_file=checkpoint["file"]
)

@classmethod
def load_from(
cls,
checkpoint: Union[str, Dict],
config: Config = None,
dataset: Dataset = None,
parent_job=None,
) -> Job:
if config is None:
config = Config()
config.set("job.type", "eval")
eval_job = super().load_from(checkpoint, config, dataset, parent_job=parent_job)
return eval_job


# HISTOGRAM COMPUTATION ###############################################################

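With this classmethod, an evaluation job can presumably be constructed straight from a checkpoint, without preparing a config first; a minimal sketch (path illustrative):

from kge.job import EvaluationJob

# builds a default config with job.type="eval", restores model and dataset from
# the checkpoint, and returns a ready-to-run evaluation job
eval_job = EvaluationJob.load_from("fb15k-237-rescal.pt")
eval_job.run()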
1 change: 0 additions & 1 deletion kge/job/grid_search.py
@@ -69,7 +69,6 @@ def run(self):
# and run it
if self.config.get("grid_search.run"):
job = Job.create(self.config, self.dataset, parent_job=self)
job.resume()
job.run()
else:
self.config.log("Skipping running of search job as requested by user...")
100 changes: 76 additions & 24 deletions kge/job/job.py
@@ -1,10 +1,14 @@
from __future__ import annotations

from kge import Config, Dataset
from kge.model import KgeModel
from kge.util import load_checkpoint
import uuid

from kge.misc import get_git_revision_short_hash
import os
import socket
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union


def _trace_job_creation(job: "Job"):
@@ -55,36 +59,84 @@ def __init__(self, config: Config, dataset: Dataset, parent_job: "Job" = None):
for f in Job.job_created_hooks:
f(self)

def resume(self, checkpoint_file: str = None):
"""Load job state from last or specified checkpoint.

Restores all relevant state to resume a previous job. To run the restored job,
use :func:`run`.

Should set `resumed_from_job` to the job ID of the previous job.

"""
raise NotImplementedError

def run(self):
raise NotImplementedError

def create(config: Config, dataset: Dataset, parent_job: "Job" = None) -> "Job":
"""Creates a job for a given configuration."""

@staticmethod
def create(config: Config, dataset: Dataset, parent_job=None, model=None):
from kge.job import TrainingJob, EvaluationJob, SearchJob

if config.get("job.type") == "train":
job = TrainingJob.create(config, dataset, parent_job)
elif config.get("job.type") == "search":
job = SearchJob.create(config, dataset, parent_job)
elif config.get("job.type") == "eval":
job = EvaluationJob.create(config, dataset, parent_job)
job_type = config.get("job.type")
if job_type == "train":
return TrainingJob.create(
config, dataset, parent_job=parent_job, model=model
)
elif job_type == "search":
return SearchJob.create(config, dataset, parent_job=parent_job)
elif job_type == "eval":
return EvaluationJob.create(
config, dataset, parent_job=parent_job, model=model
)
else:
raise ValueError("unknown job type")

@staticmethod
def resume(
checkpoint: str = None,
config: Config = None,
dataset: Dataset = None,
parent_job=None,
) -> Job:
if checkpoint is None and config is None:
raise ValueError(
"Please provide either the config file located in the folder structure "
"containing the checkpoint or the checkpoint itself."
)
elif checkpoint is None:
last_checkpoint = config.last_checkpoint()
if last_checkpoint is not None:
checkpoint = config.checkpoint_file(last_checkpoint)

if checkpoint is not None:
job = Job.load_from(checkpoint, config, dataset, parent_job=parent_job)
if type(checkpoint) == str:
job.config.log("Loading checkpoint from {}...".format(checkpoint))
else:
job = Job.create(config, dataset, parent_job=parent_job)
job.config.log("No checkpoint found, starting from scratch...")
return job

@classmethod
def load_from(
cls,
checkpoint: Union[str, Dict],
config: Config = None,
dataset: Dataset = None,
parent_job=None,
) -> Job:
if config is not None:
device = config.get("job.device")
else:
device = "cpu"
if type(checkpoint) == str:
checkpoint = load_checkpoint(checkpoint, device)
config = Config.load_from(
checkpoint["config"], config, folder=checkpoint["folder"]
)
model: KgeModel = None
dataset = None
if checkpoint["model"] is not None:
model = KgeModel.load_from(checkpoint, config=config, dataset=dataset)
dataset = model.dataset
if dataset is None:
dataset = Dataset.load(config)
job = Job.create(config, dataset, parent_job, model)
job.load(checkpoint, model)
return job

def load(self, checkpoint, model):
pass

def run(self):
raise NotImplementedError

def trace(self, **kwargs) -> Dict[str, Any]:
"""Write a set of key-value pairs to the trace file and automatically append
information about this job. See `Config.trace` for more information."""
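As revised above, Job.resume either takes a checkpoint (a path or an already-loaded dict) or falls back to the last checkpoint recorded in the given config, while Job.load_from rebuilds config, model, and dataset from the loaded checkpoint (which is expected to carry at least "config", "folder", and "model" entries). A hedged sketch of resuming directly from a checkpoint file (path illustrative):

from kge.job import Job

# config, dataset, and model are all reconstructed from the checkpoint itself;
# run() then continues from the restored state
job = Job.resume("local/experiments/20200526-example/checkpoint_00010.pt")
job.run()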
4 changes: 0 additions & 4 deletions kge/job/manual_search.py
@@ -32,10 +32,6 @@ def __init__(self, config: Config, dataset: Dataset, parent_job=None):
for f in Job.job_created_hooks:
f(self)

def resume(self, checkpoint_file=None):
# no need to do anything here; run code automatically resumes
pass

def run(self):
# read search configurations and expand them to full configs
search_configs = copy.deepcopy(self.config.get("manual_search.configurations"))
9 changes: 7 additions & 2 deletions kge/job/search.py
@@ -41,6 +41,7 @@ def __init__(self, config, dataset, parent_job=None):
for f in Job.job_created_hooks:
f(self)

@staticmethod
def create(config, dataset, parent_job=None):
"""Factory method to create a search job."""

@@ -130,8 +131,12 @@ def _run_train_job(sicnk, device=None):
train_job_config.get("job.device"),
)
)
job = Job.create(train_job_config, search_job.dataset, parent_job=search_job)
job.resume()
job = Job.resume(
checkpoint=None,
config=train_job_config,
dataset=search_job.dataset,
parent_job=search_job,
)

# process the trace entries to far (in case of a resumed job)
metric_name = search_job.config.get("valid.metric")