From 36682ccee1214acc210925b8b4a65a4551c306e7 Mon Sep 17 00:00:00 2001 From: Justin Baker Date: Fri, 14 Jul 2023 11:34:26 -0400 Subject: [PATCH] Checkpoint Training Model (#187) * checkpointing with formatting * indenting * removed disable option * checkpoint other test CI * checkpoint warmup --- hydragnn/train/train_validate_test.py | 26 ++++++++++++++++++- hydragnn/utils/model.py | 36 +++++++++++++++++++++++++++ tests/inputs/ci.json | 6 +++-- tests/inputs/ci_multihead.json | 2 ++ tests/inputs/ci_vectoroutput.json | 2 ++ 5 files changed, 69 insertions(+), 3 deletions(-) diff --git a/hydragnn/train/train_validate_test.py b/hydragnn/train/train_validate_test.py index 7bc0a03d..8dd82843 100644 --- a/hydragnn/train/train_validate_test.py +++ b/hydragnn/train/train_validate_test.py @@ -22,7 +22,7 @@ from hydragnn.utils.profile import Profiler from hydragnn.utils.distributed import get_device, print_peak_memory from hydragnn.preprocess.load_data import HydraDataLoader -from hydragnn.utils.model import EarlyStopping +from hydragnn.utils.model import Checkpoint, EarlyStopping import os @@ -58,6 +58,12 @@ def train_validate_test( else False ) + SaveCheckpoint = ( + config["Training"]["Checkpoint"] + if "Checkpoint" in config["Training"] + else False + ) + device = get_device() # total loss tracking for train/vali/test total_loss_train = torch.zeros(num_epoch, device=device) @@ -107,6 +113,14 @@ def train_validate_test( if "patience" in config["Training"]: earlystopper = EarlyStopping(patience=config["Training"]["patience"]) + if SaveCheckpoint: + checkpoint = Checkpoint(name=model_with_config_name) + if "checkpoint_warmup" in config["Training"]: + checkpoint = Checkpoint( + name=model_with_config_name, + warmup=config["Training"]["checkpoint_warmup"], + ) + timer = Timer("train_validate_test") timer.start() @@ -170,6 +184,16 @@ def train_validate_test( output_names=config["Variables_of_interest"]["output_names"], iepoch=epoch, ) + + if SaveCheckpoint: + if checkpoint(model, optimizer, reduce_values_ranks(val_loss).item()): + print_distributed( + verbosity, "Creating Checkpoint: %f" % checkpoint.min_perf_metric + ) + print_distributed( + verbosity, "Best Performance Metric: %f" % checkpoint.min_perf_metric + ) + if EarlyStop: if earlystopper(reduce_values_ranks(val_loss)): print_distributed( diff --git a/hydragnn/utils/model.py b/hydragnn/utils/model.py index 236f81e8..b8f6dc85 100644 --- a/hydragnn/utils/model.py +++ b/hydragnn/utils/model.py @@ -159,3 +159,39 @@ def __call__(self, val_loss): self.val_loss_min = val_loss self.count = 0 return False + + +class Checkpoint: + """ + Checkpoints the model and optimizer when: + + The performance metric is smaller than the stored performance metric + Args + warmup: (int) Number of epochs to warmup prior to checkpointing. + path: (str) Path for checkpointing + name: (str) Model name for the directory and the file to save. + """ + + def __init__( + self, + name: str, + warmup: int = 0, + path: str = "./logs/", + ): + self.count = 1 + self.warmup = warmup + self.path = path + self.name = name + self.min_perf_metric = float("inf") + self.min_delta = 0 + + def __call__(self, model, optimizer, perf_metric): + + if (perf_metric > self.min_perf_metric + self.min_delta) or ( + self.count < self.warmup + ): + self.count += 1 + return False + else: + self.min_perf_metric = perf_metric + save_model(model, optimizer, name=self.name, path=self.path) + return True diff --git a/tests/inputs/ci.json b/tests/inputs/ci.json index dfb8fd9f..36613eaa 100644 --- a/tests/inputs/ci.json +++ b/tests/inputs/ci.json @@ -66,8 +66,10 @@ "Training": { "num_epoch": 100, "perc_train": 0.7, - "EarlyStopping": true, - "patience": 10, + "EarlyStopping": true, + "patience": 10, + "Checkpoint": true, + "checkpoint_warmup": 10, "loss_function_type": "mse", "batch_size": 32, "Optimizer": { diff --git a/tests/inputs/ci_multihead.json b/tests/inputs/ci_multihead.json index bcbb9ffd..4d0a743e 100644 --- a/tests/inputs/ci_multihead.json +++ b/tests/inputs/ci_multihead.json @@ -63,6 +63,8 @@ }, "Training": { "num_epoch": 100, + "Checkpoint": true, + "checkpoint_warmup": 10, "perc_train": 0.7, "loss_function_type": "mse", "batch_size": 16, diff --git a/tests/inputs/ci_vectoroutput.json b/tests/inputs/ci_vectoroutput.json index f310a244..ddd34e61 100644 --- a/tests/inputs/ci_vectoroutput.json +++ b/tests/inputs/ci_vectoroutput.json @@ -53,6 +53,8 @@ }, "Training": { "num_epoch": 80, + "Checkpoint": true, + "checkpoint_warmup": 10, "perc_train": 0.7, "loss_function_type": "mse", "batch_size": 16,