Skip to content

Commit

Permalink
Checkpoint Training Model (#187)
Browse files Browse the repository at this point in the history
* checkpointing with formatting

* indenting

* removed disable option

* checkpoint other test CI

* checkpoint warmup
  • Loading branch information
JustinBakerMath authored Jul 14, 2023
1 parent b95f136 commit 36682cc
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 3 deletions.
26 changes: 25 additions & 1 deletion hydragnn/train/train_validate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from hydragnn.utils.profile import Profiler
from hydragnn.utils.distributed import get_device, print_peak_memory
from hydragnn.preprocess.load_data import HydraDataLoader
from hydragnn.utils.model import EarlyStopping
from hydragnn.utils.model import Checkpoint, EarlyStopping

import os

Expand Down Expand Up @@ -58,6 +58,12 @@ def train_validate_test(
else False
)

SaveCheckpoint = (
config["Training"]["Checkpoint"]
if "Checkpoint" in config["Training"]
else False
)

device = get_device()
# total loss tracking for train/vali/test
total_loss_train = torch.zeros(num_epoch, device=device)
Expand Down Expand Up @@ -107,6 +113,14 @@ def train_validate_test(
if "patience" in config["Training"]:
earlystopper = EarlyStopping(patience=config["Training"]["patience"])

if SaveCheckpoint:
checkpoint = Checkpoint(name=model_with_config_name)
if "checkpoint_warmup" in config["Training"]:
checkpoint = Checkpoint(
name=model_with_config_name,
warmup=config["Training"]["checkpoint_warmup"],
)

timer = Timer("train_validate_test")
timer.start()

Expand Down Expand Up @@ -170,6 +184,16 @@ def train_validate_test(
output_names=config["Variables_of_interest"]["output_names"],
iepoch=epoch,
)

if SaveCheckpoint:
if checkpoint(model, optimizer, reduce_values_ranks(val_loss).item()):
print_distributed(
verbosity, "Creating Checkpoint: %f" % checkpoint.min_perf_metric
)
print_distributed(
verbosity, "Best Performance Metric: %f" % checkpoint.min_perf_metric
)

if EarlyStop:
if earlystopper(reduce_values_ranks(val_loss)):
print_distributed(
Expand Down
36 changes: 36 additions & 0 deletions hydragnn/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,39 @@ def __call__(self, val_loss):
self.val_loss_min = val_loss
self.count = 0
return False


class Checkpoint:
    """
    Checkpoints the model and optimizer when:
    + The performance metric is smaller than the stored performance metric
    Args
        warmup: (int) Number of epochs to warmup prior to checkpointing.
        path: (str) Path for checkpointing
        name: (str) Model name for the directory and the file to save.
        min_delta: (float) Minimum margin by which the metric must beat the
            stored best before a new checkpoint is written. Default 0
            (previously hard-coded), i.e. any non-worse value checkpoints.
    """

    def __init__(
        self,
        name: str,
        warmup: int = 0,
        path: str = "./logs/",
        min_delta: float = 0,
    ):
        # Number of __call__ invocations so far; starts at 1, so a warmup of
        # `w` skips the first `w - 1` calls.
        self.count = 1
        self.warmup = warmup
        self.path = path
        self.name = name
        # Best (lowest) performance metric observed so far.
        self.min_perf_metric = float("inf")
        # Generalized from a hard-coded 0; the default preserves the
        # original behavior exactly.
        self.min_delta = min_delta

    def __call__(self, model, optimizer, perf_metric):
        """
        Save a checkpoint if `perf_metric` improves on the stored best and
        the warmup period has elapsed.

        Args:
            model: Model to checkpoint (passed through to save_model).
            optimizer: Optimizer to checkpoint (passed through to save_model).
            perf_metric: (float) Current performance metric; lower is better.

        Returns:
            bool: True if a checkpoint was written, False otherwise.
        """
        if (perf_metric > self.min_perf_metric + self.min_delta) or (
            self.count < self.warmup
        ):
            # Not an improvement, or still warming up: skip checkpointing.
            # NOTE: during warmup min_perf_metric is intentionally left at
            # inf, so the first post-warmup metric always checkpoints.
            self.count += 1
            return False
        else:
            self.min_perf_metric = perf_metric
            # save_model is defined elsewhere in hydragnn.utils.model.
            save_model(model, optimizer, name=self.name, path=self.path)
            return True
6 changes: 4 additions & 2 deletions tests/inputs/ci.json
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@
"Training": {
"num_epoch": 100,
"perc_train": 0.7,
        "EarlyStopping": true,
        "patience": 10,
"Checkpoint": true,
"checkpoint_warmup": 10,
"loss_function_type": "mse",
"batch_size": 32,
"Optimizer": {
Expand Down
2 changes: 2 additions & 0 deletions tests/inputs/ci_multihead.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
},
"Training": {
"num_epoch": 100,
"Checkpoint": true,
"checkpoint_warmup": 10,
"perc_train": 0.7,
"loss_function_type": "mse",
"batch_size": 16,
Expand Down
2 changes: 2 additions & 0 deletions tests/inputs/ci_vectoroutput.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
},
"Training": {
"num_epoch": 80,
"Checkpoint": true,
"checkpoint_warmup": 10,
"perc_train": 0.7,
"loss_function_type": "mse",
"batch_size": 16,
Expand Down

0 comments on commit 36682cc

Please sign in to comment.