From b049f2bd9e53695d635f84a6b1c0f4f74759d2e0 Mon Sep 17 00:00:00 2001
From: NicolasRR
Date: Mon, 15 Apr 2024 16:09:06 +0000
Subject: [PATCH 1/4] implemented checkpointing and retrieval

---
 src/config/base.py   |  2 ++
 src/config/sparse.py |  2 ++
 src/main.py          | 22 +++++++++++++++++++++-
 src/optim/base.py    | 29 +++++++++++++++++++++++++----
 src/optim/sparse.py  | 27 ++++++++++++++++++++++++---
 5 files changed, 74 insertions(+), 8 deletions(-)

diff --git a/src/config/base.py b/src/config/base.py
index 36f732d..a1cae7a 100644
--- a/src/config/base.py
+++ b/src/config/base.py
@@ -45,4 +45,6 @@ def parse_args(base_parser, args, namespace):
 
     # Distributed args
     parser.add_argument('--distributed_backend', default=None, type=str, required=False, choices=distributed.registered_backends())  # distributed backend type
+    parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False)
+
     return parser.parse_args(args, namespace)
diff --git a/src/config/sparse.py b/src/config/sparse.py
index 80f8d22..d17855f 100644
--- a/src/config/sparse.py
+++ b/src/config/sparse.py
@@ -47,4 +47,6 @@ def parse_args(base_parser, args, namespace):
 
     # Distributed args
    parser.add_argument('--distributed_backend', default=None, type=str, required=False, choices=distributed.registered_backends())  # distributed backend type
+    parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False)
+
     return parser.parse_args(args, namespace)
diff --git a/src/main.py b/src/main.py
index 301c1a9..bcf4b23 100755
--- a/src/main.py
+++ b/src/main.py
@@ -112,6 +112,26 @@ def main(args):
     elif os.path.isfile(os.path.join(ckpt_path, "summary.json")):  # the experiment was already completed
         print(f"Already found experiment '{ckpt_path}'.\nSkipping.")
         sys.exit(0)
+    itr = 0
+    rng_state_dict = None
+    checkpoints = [file for file in os.listdir(ckpt_path) if 'ckpt_' in file]
+    if checkpoints:
+        last_ckpt_path = sorted(checkpoints)[-1]
+        print(f"Training interrupted, resuming from {last_ckpt_path}")
+        checkpoint = torch.load(os.path.join(ckpt_path, last_ckpt_path))
+        model_state_dict = {k.replace("_orig_mod.", ""):v for k,v in checkpoint['model'].items()}
+        # FIXME checkpoints from compiled model have _orig_mod keyword
+
+        optimizer_state_dict = checkpoint['optimizer']
+        scheduler_state_dict = checkpoint['scheduler']
+        rng_state_dict = {
+            module: checkpoint[module] for module in ["cpu_rng_state", "gpu_rng_state", "numpy_rng_state", "py_rng_state"]
+        }
+
+        model.load_state_dict(model_state_dict)
+        opt.load_state_dict(optimizer_state_dict)
+        scheduler.load_state_dict(scheduler_state_dict)
+        itr=checkpoint['itr']
 
     if args.model == 'base':  # all train functions have the same interface
         train = train_base
@@ -125,7 +145,7 @@ def main(args):
     stats = train(model, opt, data, args.data_seed, scheduler, args.iterations, args.acc_steps, args.batch_size, args.sequence_length,
                   eval_freq=args.eval_freq,
                   distributed_backend=distributed_backend,
-                  ckpt_path=f"{ckpt_path}/ckpt.pt", extra_args=args)
+                  ckpt_path=f"{ckpt_path}/ckpt.pt", itr=itr, rng_state_dict=rng_state_dict, extra_args=args)
 
     args.device = None
     args.dtype = None
diff --git a/src/optim/base.py b/src/optim/base.py
index e73ecbc..10ed0a0 100755
--- a/src/optim/base.py
+++ b/src/optim/base.py
@@ -7,15 +7,17 @@
 import time
 import itertools
 import copy
-
+import random
+import os
+import numpy as np
 from .utils import eval, get_batch, save_checkpoint
 
 
-def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args):
+def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend,extra_args, itr=0,rng_state_dict=None):
     device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu'
     type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
         device_type=device_type, dtype=torch.bfloat16)  # extra_args.dtype)
-    itr, substep, best_val_loss, text_table = 0, 0, float('inf'), None  # best_val_loss not used atm, early stopping not recommended but possible
+    substep, best_val_loss, text_table = 0, float('inf'), None  # best_val_loss not used atm, early stopping not recommended but possible
 
     data["train"], train_sampler = get_dataloader(
         data["train"],
@@ -50,6 +52,12 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba
     t0 = time.time()
     train_epochs = 0
     while itr < iterations:
+        if not rng_state_dict is None:
+            torch.set_rng_state(rng_state_dict["cpu_rng_state"])
+            torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"])
+            np.random.set_state(rng_state_dict["numpy_rng_state"])
+            random.setstate(rng_state_dict["py_rng_state"])
+
         for microstep_idx in range(acc_steps):  # gradient accumulation
             x, y = get_batch(data_train_iter, device=extra_args.device)
             with type_ctx:
@@ -122,7 +130,20 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba
 
                 model.train()
                 t0 = time.time()
-
+        if distributed_backend.is_master_process():
+            if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0:
+                print(f"saving checkpoint to {ckpt_path}/ckpt_{itr}.pt")
+                save_checkpoint(distributed_backend=distributed_backend,
+                                model=model,
+                                opt=opt,
+                                scheduler=scheduler,
+                                itr=itr,
+                                cpu_rng_state=torch.get_rng_state(),
+                                gpu_rng_state=torch.cuda.get_rng_state(),
+                                numpy_rng_state=np.random.get_state(),
+                                py_rng_state=random.getstate(),
+                                ckpt_path=os.path.join(os.path.dirname(ckpt_path), f"ckpt_{itr}.pt"))
+
     if distributed_backend.is_master_process():
         print(f"saving checkpoint to {ckpt_path}")
         save_checkpoint(distributed_backend=distributed_backend,
diff --git a/src/optim/sparse.py b/src/optim/sparse.py
index 729e744..e5715c6 100755
--- a/src/optim/sparse.py
+++ b/src/optim/sparse.py
@@ -5,15 +5,16 @@
 import wandb
 import time
 import copy
-
+import numpy as np
+import random
 from .utils import eval_sparse, get_batch, eval_sweep_dropk, save_checkpoint
 
 
-def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args):
+def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args, itr=0, rng_state_dict=None):
     device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu'
     type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
         device_type=device_type, dtype=torch.bfloat16)  # extra_args.dtype)
-    itr, substep, best_val_loss, text_table, sparsity_plot = 0, 0, float('inf'), None, None  # best_val_loss not used atm, early stopping not recommended but possible
+    substep, best_val_loss, text_table, sparsity_plot = 0, float('inf'), None, None  # best_val_loss not used atm, early stopping not recommended but possible
     data["train"] = get_dataloader(
         data["train"],
         sequence_length=sequence_length,
@@ -43,6 +44,12 @@ def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps,
 
     t0 = time.time()
     while itr < iterations:
+        if not rng_state_dict is None:
+            torch.set_rng_state(rng_state_dict["cpu_rng_state"])
+            torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"])
+            np.random.set_state(rng_state_dict["numpy_rng_state"])
+            random.setstate(rng_state_dict["py_rng_state"])
+
         for microstep_idx in range(acc_steps):  # gradient accumulation
             x, y = get_batch(data_train_iter, device=extra_args.device)
             with type_ctx:
@@ -129,6 +136,20 @@ def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps,
                 model.train()
                 t0 = time.time()
 
+        if distributed_backend.is_master_process():
+            if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0:
+                print(f"saving checkpoint to {ckpt_path}/ckpt_{itr}.pt")
+                save_checkpoint(distributed_backend=distributed_backend,
+                                model=model,
+                                opt=opt,
+                                scheduler=scheduler,
+                                itr=itr,
+                                cpu_rng_state=torch.get_rng_state(),
+                                gpu_rng_state=torch.cuda.get_rng_state(),
+                                numpy_rng_state=np.random.get_state(),
+                                py_rng_state=random.getstate(),
+                                ckpt_path=f"{ckpt_path}/ckpt_{itr}.pt")
+
     if distributed_backend.is_master_process():
         print(f"saving checkpoint to {ckpt_path}")
         save_checkpoint(distributed_backend=distributed_backend,

From b51d951ab70f98d8baaf4dc15021e5266d96932c Mon Sep 17 00:00:00 2001
From: NicolasRR
Date: Mon, 15 Apr 2024 18:45:31 +0200
Subject: [PATCH 2/4] fixed scheduler and random state dict

---
 src/main.py         | 5 +++--
 src/optim/base.py   | 4 ++--
 src/optim/sparse.py | 5 +++--
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/main.py b/src/main.py
index bcf4b23..9d11501 100755
--- a/src/main.py
+++ b/src/main.py
@@ -123,15 +123,16 @@ def main(args):
         # FIXME checkpoints from compiled model have _orig_mod keyword
 
         optimizer_state_dict = checkpoint['optimizer']
-        scheduler_state_dict = checkpoint['scheduler']
         rng_state_dict = {
             module: checkpoint[module] for module in ["cpu_rng_state", "gpu_rng_state", "numpy_rng_state", "py_rng_state"]
         }
 
         model.load_state_dict(model_state_dict)
         opt.load_state_dict(optimizer_state_dict)
-        scheduler.load_state_dict(scheduler_state_dict)
         itr=checkpoint['itr']
+        if not scheduler is None:
+            scheduler_state_dict = checkpoint['scheduler']
+            scheduler.load_state_dict(scheduler_state_dict)
 
     if args.model == 'base':  # all train functions have the same interface
         train = train_base
diff --git a/src/optim/base.py b/src/optim/base.py
index 10ed0a0..865252c 100755
--- a/src/optim/base.py
+++ b/src/optim/base.py
@@ -51,12 +51,12 @@ def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, ba
 
     t0 = time.time()
     train_epochs = 0
-    while itr < iterations:
-        if not rng_state_dict is None:
+    if not rng_state_dict is None:
             torch.set_rng_state(rng_state_dict["cpu_rng_state"])
             torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"])
             np.random.set_state(rng_state_dict["numpy_rng_state"])
             random.setstate(rng_state_dict["py_rng_state"])
+    while itr < iterations:
 
         for microstep_idx in range(acc_steps):  # gradient accumulation
             x, y = get_batch(data_train_iter, device=extra_args.device)
diff --git a/src/optim/sparse.py b/src/optim/sparse.py
index e5715c6..39864d6 100755
--- a/src/optim/sparse.py
+++ b/src/optim/sparse.py
@@ -43,13 +43,14 @@ def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps,
     model.train()
 
     t0 = time.time()
-    while itr < iterations:
-        if not rng_state_dict is None:
+    if not rng_state_dict is None:
             torch.set_rng_state(rng_state_dict["cpu_rng_state"])
             torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"])
             np.random.set_state(rng_state_dict["numpy_rng_state"])
             random.setstate(rng_state_dict["py_rng_state"])
+    while itr < iterations:
+
 
         for microstep_idx in range(acc_steps):  # gradient accumulation
             x, y = get_batch(data_train_iter, device=extra_args.device)
             with type_ctx:

From 4531e469c9623b72915a7ad81c0f3626cd59594a Mon Sep 17 00:00:00 2001
From: NicolasRR
Date: Mon, 15 Apr 2024 19:06:32 +0200
Subject: [PATCH 3/4] ensure master created the ckpt folder

---
 src/main.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/main.py b/src/main.py
index 9d11501..79a509f 100755
--- a/src/main.py
+++ b/src/main.py
@@ -114,6 +114,7 @@ def main(args):
         sys.exit(0)
     itr = 0
     rng_state_dict = None
+    distributed_backend.sync()
     checkpoints = [file for file in os.listdir(ckpt_path) if 'ckpt_' in file]
     if checkpoints:
         last_ckpt_path = sorted(checkpoints)[-1]

From 25884dac97bffa101beeff838a75795fd3e3e469 Mon Sep 17 00:00:00 2001
From: NicolasRR
Date: Tue, 16 Apr 2024 18:21:06 +0200
Subject: [PATCH 4/4] minor fixes

---
 src/main.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/main.py b/src/main.py
index 79a509f..d7554da 100755
--- a/src/main.py
+++ b/src/main.py
@@ -109,12 +109,12 @@ def main(args):
     if not os.path.exists(ckpt_path):
         if distributed_backend.is_master_process():
             os.makedirs(ckpt_path)
+        distributed_backend.sync()
     elif os.path.isfile(os.path.join(ckpt_path, "summary.json")):  # the experiment was already completed
         print(f"Already found experiment '{ckpt_path}'.\nSkipping.")
         sys.exit(0)
     itr = 0
     rng_state_dict = None
-    distributed_backend.sync()
     checkpoints = [file for file in os.listdir(ckpt_path) if 'ckpt_' in file]
     if checkpoints:
         last_ckpt_path = sorted(checkpoints)[-1]
@@ -131,7 +131,7 @@ def main(args):
         model.load_state_dict(model_state_dict)
         opt.load_state_dict(optimizer_state_dict)
         itr=checkpoint['itr']
-        if not scheduler is None:
+        if scheduler is not None:
             scheduler_state_dict = checkpoint['scheduler']
             scheduler.load_state_dict(scheduler_state_dict)
 
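
A condensed view of the resume path this series adds, for quick reference. The checkpoint key names, the "_orig_mod." prefix strip, and the sorted(...)[-1] file selection are taken from the src/main.py diffs above; the standalone helper below is only an illustrative sketch under those assumptions and is not code from the repository.

    import os
    import random
    import numpy as np
    import torch

    def resume_from_latest(ckpt_dir, model, opt, scheduler=None):
        # The patches write checkpoints as ckpt_<itr>.pt with keys:
        # 'model', 'optimizer', 'scheduler', 'itr',
        # 'cpu_rng_state', 'gpu_rng_state', 'numpy_rng_state', 'py_rng_state'
        candidates = [f for f in os.listdir(ckpt_dir) if 'ckpt_' in f]
        if not candidates:
            return 0  # nothing to resume from, start at iteration 0
        # Lexicographically last filename, mirroring sorted(checkpoints)[-1] in the patch
        ckpt = torch.load(os.path.join(ckpt_dir, sorted(candidates)[-1]))
        # Strip the '_orig_mod.' prefix that torch.compile adds to parameter names
        model.load_state_dict({k.replace("_orig_mod.", ""): v for k, v in ckpt['model'].items()})
        opt.load_state_dict(ckpt['optimizer'])
        if scheduler is not None:
            scheduler.load_state_dict(ckpt['scheduler'])
        # Restore every RNG stream so the resumed run replays the same batches/dropout
        torch.set_rng_state(ckpt['cpu_rng_state'])
        torch.cuda.set_rng_state(ckpt['gpu_rng_state'])
        np.random.set_state(ckpt['numpy_rng_state'])
        random.setstate(ckpt['py_rng_state'])
        return ckpt['itr']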