Skip to content

Commit

Permalink
Merge branch 'master' into bug4813
Browse files Browse the repository at this point in the history
  • Loading branch information
jxysoft committed Dec 15, 2023
2 parents 9e099f4 + 7914b19 commit a612e48
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/unit/alexnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import deepspeed
import deepspeed.comm as dist
import deepspeed.runtime.utils as ds_utils
from deepspeed.runtime.utils import required_torch_version
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec

Expand Down Expand Up @@ -111,8 +112,11 @@ def cifar_trainset(fp16=False):


def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123):
with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()],
device_type=get_accelerator().device_name()):
if required_torch_version(min_version=2.1):
fork_kwargs = {"device_type": get_accelerator().device_name()}
else:
fork_kwargs = {}
with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()], **fork_kwargs):
ds_utils.set_random_seed(seed)

# disable dropout
Expand Down

0 comments on commit a612e48

Please sign in to comment.