diff --git a/src/optim/base.py b/src/optim/base.py
index fa176f0..018934a 100755
--- a/src/optim/base.py
+++ b/src/optim/base.py
@@ -40,7 +40,8 @@ def train_base(model, opt, data, scheduler, iterations, acc_steps, batch_size, s
         if extra_args.grad_clip != 0.0:
             torch.nn.utils.clip_grad_norm_(model.parameters(), extra_args.grad_clip)
         opt.step()
-        scheduler.step()
+        if scheduler != None:
+            scheduler.step()
         opt.zero_grad(set_to_none=True)
         itr += 1
diff --git a/src/optim/sparse.py b/src/optim/sparse.py
index b223f17..6669174 100755
--- a/src/optim/sparse.py
+++ b/src/optim/sparse.py
@@ -37,7 +37,8 @@ def train_sparse(model, opt, data, scheduler, iterations, acc_steps, batch_size,
             substep += 1
         opt.step()
-        scheduler.step()
+        if scheduler != None:
+            scheduler.step()
         opt.zero_grad(set_to_none=True)
         itr += 1
diff --git a/src/optim/utils.py b/src/optim/utils.py
index dd00621..957fad8 100755
--- a/src/optim/utils.py
+++ b/src/optim/utils.py
@@ -115,7 +115,7 @@ def save_checkpoint(distributed_backend, model, opt, scheduler, itr, ckpt_path,
     checkpoint = dict({
         'model': distributed_backend.get_raw_model(model).state_dict(),
         'optimizer': opt.state_dict(),
-        'scheduler': scheduler.state_dict(),
+        'scheduler': scheduler.state_dict() if scheduler != None else None,
         'itr': itr,
     }, **extra_args)
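
The diff makes the learning-rate scheduler optional: the training loops only call scheduler.step() when a scheduler is present, and save_checkpoint stores None instead of calling state_dict() on a missing object. Below is a minimal standalone sketch of that pattern, assuming a plain PyTorch optimizer and scheduler; training_step and checkpoint_state are illustrative helpers written for this example, not functions from this repository, and they use the idiomatic "is not None" check rather than the diff's "!= None".

import torch

def training_step(model, opt, scheduler, loss, grad_clip=0.0):
    # One optimizer update; scheduler may be None, in which case the
    # learning rate simply stays at the optimizer's configured value.
    loss.backward()
    if grad_clip != 0.0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    opt.step()
    if scheduler is not None:
        scheduler.step()
    opt.zero_grad(set_to_none=True)

def checkpoint_state(model, opt, scheduler, itr):
    # The scheduler entry falls back to None, so the checkpoint can be
    # written (and later loaded) whether or not a scheduler was used.
    return {
        'model': model.state_dict(),
        'optimizer': opt.state_dict(),
        'scheduler': scheduler.state_dict() if scheduler is not None else None,
        'itr': itr,
    }

With the guarded calls in place, callers can pass scheduler=None to train_base / train_sparse to run with a constant learning rate, and save_checkpoint no longer raises AttributeError in that configuration.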