diff --git a/src/config/base.py b/src/config/base.py
index 36f732d..a1cae7a 100644
--- a/src/config/base.py
+++ b/src/config/base.py
@@ -45,4 +45,6 @@ def parse_args(base_parser, args, namespace):
 
     # Distributed args
     parser.add_argument('--distributed_backend', default=None, type=str, required=False, choices=distributed.registered_backends())  # distributed backend type
+    parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False)
+
     return parser.parse_args(args, namespace)
diff --git a/src/config/sparse.py b/src/config/sparse.py
index 80f8d22..d17855f 100644
--- a/src/config/sparse.py
+++ b/src/config/sparse.py
@@ -47,4 +47,6 @@ def parse_args(base_parser, args, namespace):
 
     # Distributed args
     parser.add_argument('--distributed_backend', default=None, type=str, required=False, choices=distributed.registered_backends())  # distributed backend type
+    parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False)
+
     return parser.parse_args(args, namespace)
diff --git a/src/main.py b/src/main.py
index 301c1a9..d7554da 100755
--- a/src/main.py
+++ b/src/main.py
@@ -109,9 +109,31 @@
     if not os.path.exists(ckpt_path):
         if distributed_backend.is_master_process():
             os.makedirs(ckpt_path)
+        distributed_backend.sync()
     elif os.path.isfile(os.path.join(ckpt_path, "summary.json")):  # the experiment was already completed
         print(f"Already found experiment '{ckpt_path}'.\nSkipping.")
         sys.exit(0)
 
+    itr = 0
+    rng_state_dict = None
+    checkpoints = [file for file in os.listdir(ckpt_path) if 'ckpt_' in file]
+    if checkpoints:
+        last_ckpt_path = sorted(checkpoints, key=lambda name: int(name.split('_')[1].split('.')[0]))[-1]
+        print(f"Training interrupted, resuming from {last_ckpt_path}")
+        checkpoint = torch.load(os.path.join(ckpt_path, last_ckpt_path))
+        model_state_dict = {k.replace("_orig_mod.", ""): v for k, v in checkpoint['model'].items()}
+        # FIXME: checkpoints saved from a torch.compile'd model prefix keys with _orig_mod.
+
+        optimizer_state_dict = checkpoint['optimizer']
+        rng_state_dict = {
+            module: checkpoint[module] for module in ["cpu_rng_state", "gpu_rng_state", "numpy_rng_state", "py_rng_state"]
+        }
+
+        model.load_state_dict(model_state_dict)
+        opt.load_state_dict(optimizer_state_dict)
+        itr = checkpoint['itr']
+        if scheduler is not None:
+            scheduler_state_dict = checkpoint['scheduler']
+            scheduler.load_state_dict(scheduler_state_dict)
     if args.model == 'base':  # all train functions have the same interface
         train = train_base
@@ -125,7 +147,7 @@
     stats = train(model, opt, data, args.data_seed, scheduler, args.iterations, args.acc_steps,
                   args.batch_size, args.sequence_length,
                   eval_freq=args.eval_freq, distributed_backend=distributed_backend,
-                  ckpt_path=f"{ckpt_path}/ckpt.pt", extra_args=args)
+                  ckpt_path=f"{ckpt_path}/ckpt.pt", itr=itr, rng_state_dict=rng_state_dict, extra_args=args)
 
     args.device = None
     args.dtype = None
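
The resume block above picks the intermediate checkpoint with the highest iteration number; sorting the ckpt_<itr>.pt names by their integer suffix (rather than as strings) is what makes that selection correct. A minimal standalone sketch of that step, assuming the ckpt_<itr>.pt naming used in this patch; latest_checkpoint is an illustrative helper, not a function defined in this repo:

    import os

    def latest_checkpoint(ckpt_dir):
        # keep only intermediate checkpoints such as ckpt_500.pt, ckpt_1000.pt
        names = [f for f in os.listdir(ckpt_dir) if f.startswith("ckpt_") and f.endswith(".pt")]
        if not names:
            return None
        # sort by the integer iteration rather than lexicographically
        # (string order would rank ckpt_500.pt above ckpt_1000.pt)
        names.sort(key=lambda name: int(name.split("_")[1].split(".")[0]))
        return os.path.join(ckpt_dir, names[-1])
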
diff --git a/src/optim/base.py b/src/optim/base.py
index e73ecbc..865252c 100755
--- a/src/optim/base.py
+++ b/src/optim/base.py
@@ -7,15 +7,17 @@
 import time
 import itertools
 import copy
-
+import random
+import os
+import numpy as np
 from .utils import eval, get_batch, save_checkpoint
 
 
-def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args):
+def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args, itr=0, rng_state_dict=None):
     device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu'
     type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
         device_type=device_type, dtype=torch.bfloat16)  # extra_args.dtype)
-    itr, substep, best_val_loss, text_table = 0, 0, float('inf'), None  # best_val_loss not used atm, early stopping not recommended but possible
+    substep, best_val_loss, text_table = 0, float('inf'), None  # best_val_loss not used atm, early stopping not recommended but possible
 
     data["train"], train_sampler = get_dataloader(
         data["train"],
@@ -49,7 +51,13 @@
     t0 = time.time()
     train_epochs = 0
 
+    if rng_state_dict is not None:
+        torch.set_rng_state(rng_state_dict["cpu_rng_state"])
+        torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"])
+        np.random.set_state(rng_state_dict["numpy_rng_state"])
+        random.setstate(rng_state_dict["py_rng_state"])
     while itr < iterations:
+
         for microstep_idx in range(acc_steps):  # gradient accumulation
             x, y = get_batch(data_train_iter, device=extra_args.device)
             with type_ctx:
@@ -122,7 +130,20 @@
                 model.train()
             t0 = time.time()
-
+        if distributed_backend.is_master_process():
+            if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0:
+                print(f"saving checkpoint to {os.path.dirname(ckpt_path)}/ckpt_{itr}.pt")
+                save_checkpoint(distributed_backend=distributed_backend,
+                                model=model,
+                                opt=opt,
+                                scheduler=scheduler,
+                                itr=itr,
+                                cpu_rng_state=torch.get_rng_state(),
+                                gpu_rng_state=torch.cuda.get_rng_state(),
+                                numpy_rng_state=np.random.get_state(),
+                                py_rng_state=random.getstate(),
+                                ckpt_path=os.path.join(os.path.dirname(ckpt_path), f"ckpt_{itr}.pt"))
+
 
     if distributed_backend.is_master_process():
         print(f"saving checkpoint to {ckpt_path}")
         save_checkpoint(distributed_backend=distributed_backend,
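
The intermediate checkpoints carry four RNG states so that a resumed run continues with the same random sequence (dropout masks, any numpy- or python-level sampling) it would have produced without the interruption. A standalone sketch of that round trip using only standard PyTorch/NumPy/Python calls; the is_available() guard is an addition for the sketch, since the patch itself assumes a CUDA device:

    import random
    import numpy as np
    import torch

    # capture the four states, mirroring what train_base/train_sparse pass to save_checkpoint
    state = {
        "cpu_rng_state": torch.get_rng_state(),
        "gpu_rng_state": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
        "numpy_rng_state": np.random.get_state(),
        "py_rng_state": random.getstate(),
    }

    # ... later, after torch.load, restore the states before re-entering the training loop
    torch.set_rng_state(state["cpu_rng_state"])
    if state["gpu_rng_state"] is not None:
        torch.cuda.set_rng_state(state["gpu_rng_state"])
    np.random.set_state(state["numpy_rng_state"])
    random.setstate(state["py_rng_state"])

Note that torch.cuda.get_rng_state() only captures the current device; a multi-GPU run would need torch.cuda.get_rng_state_all() and torch.cuda.set_rng_state_all() to be fully reproducible.
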
diff --git a/src/optim/sparse.py b/src/optim/sparse.py
index 729e744..39864d6 100755
--- a/src/optim/sparse.py
+++ b/src/optim/sparse.py
@@ -5,15 +5,17 @@
 import wandb
 import time
 import copy
-
+import numpy as np
+import random
+import os
 from .utils import eval_sparse, get_batch, eval_sweep_dropk, save_checkpoint
 
 
-def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args):
+def train_sparse(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args, itr=0, rng_state_dict=None):
     device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu'
     type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
         device_type=device_type, dtype=torch.bfloat16)  # extra_args.dtype)
-    itr, substep, best_val_loss, text_table, sparsity_plot = 0, 0, float('inf'), None, None  # best_val_loss not used atm, early stopping not recommended but possible
+    substep, best_val_loss, text_table, sparsity_plot = 0, float('inf'), None, None  # best_val_loss not used atm, early stopping not recommended but possible
 
     data["train"] = get_dataloader(
         data["train"],
@@ -42,7 +43,14 @@
     model.train()
     t0 = time.time()
 
+    if rng_state_dict is not None:
+        torch.set_rng_state(rng_state_dict["cpu_rng_state"])
+        torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"])
+        np.random.set_state(rng_state_dict["numpy_rng_state"])
+        random.setstate(rng_state_dict["py_rng_state"])
+
     while itr < iterations:
+
         for microstep_idx in range(acc_steps):  # gradient accumulation
             x, y = get_batch(data_train_iter, device=extra_args.device)
             with type_ctx:
@@ -129,6 +137,20 @@
                 model.train()
             t0 = time.time()
 
+        if distributed_backend.is_master_process():
+            if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0:
+                print(f"saving checkpoint to {os.path.dirname(ckpt_path)}/ckpt_{itr}.pt")
+                save_checkpoint(distributed_backend=distributed_backend,
+                                model=model,
+                                opt=opt,
+                                scheduler=scheduler,
+                                itr=itr,
+                                cpu_rng_state=torch.get_rng_state(),
+                                gpu_rng_state=torch.cuda.get_rng_state(),
+                                numpy_rng_state=np.random.get_state(),
+                                py_rng_state=random.getstate(),
+                                ckpt_path=os.path.join(os.path.dirname(ckpt_path), f"ckpt_{itr}.pt"))
+
     if distributed_backend.is_master_process():
         print(f"saving checkpoint to {ckpt_path}")
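
For reference, the resume path in src/main.py expects each intermediate ckpt_<itr>.pt to contain the entries checked below. This presumes that save_checkpoint in optim/utils.py stores its keyword arguments as a flat dict, which is how the loading code treats it; the file name is a placeholder:

    import torch

    checkpoint = torch.load("ckpt_1000.pt")  # placeholder path to an intermediate checkpoint
    for key in ("model", "optimizer", "scheduler", "itr",
                "cpu_rng_state", "gpu_rng_state", "numpy_rng_state", "py_rng_state"):
        assert key in checkpoint, f"missing '{key}'; resuming in main.py would fail"
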