From 8362806d619d6ad5e6c47b97cdbd93a6059acf17 Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Wed, 27 Jan 2021 22:56:58 -0800 Subject: [PATCH 1/3] Fix for RPC parameter server --- .../parameter_server/rpc_parameter_server.py | 94 ++++++++++--------- 1 file changed, 51 insertions(+), 43 deletions(-) diff --git a/distributed/rpc/parameter_server/rpc_parameter_server.py b/distributed/rpc/parameter_server/rpc_parameter_server.py index a899292f19..024ff34486 100644 --- a/distributed/rpc/parameter_server/rpc_parameter_server.py +++ b/distributed/rpc/parameter_server/rpc_parameter_server.py @@ -12,6 +12,11 @@ from torch.distributed.optim import DistributedOptimizer from torchvision import datasets, transforms +# Constants +TRAINER_LOG_INTERVAL = 5 # How frequently to print out log information +TERMINATE_AT_ITER = 300 # for early stopping when debugging +PS_AVERAGE_EVERY_N = 25 # How often to average models between trainers + # --------- MNIST Network to train, from pytorch/examples ----- @@ -56,40 +61,19 @@ def forward(self, x): return output -# --------- Helper Methods -------------------- - -# On the local node, call a method with first arg as the value held by the -# RRef. Other args are passed in as arguments to the function called. -# Useful for calling instance methods. -def call_method(method, rref, *args, **kwargs): - return method(rref.local_value(), *args, **kwargs) - -# Given an RRef, return the result of calling the passed in method on the value -# held by the RRef. This call is done on the remote node that owns -# the RRef. args and kwargs are passed into the method. -# Example: If the value held by the RRef is of type Foo, then -# remote_method(Foo.bar, rref, arg1, arg2) is equivalent to calling -# .bar(arg1, arg2) on the remote node and getting the result -# back. - - -def remote_method(method, rref, *args, **kwargs): - args = [method, rref] + list(args) - return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs) - - # --------- Parameter Server -------------------- class ParameterServer(nn.Module): def __init__(self, num_gpus=0): super().__init__() - model = Net(num_gpus=num_gpus) - self.model = model + self.model_copy = Net(num_gpus=num_gpus) + self.num_gpus = num_gpus + self.models = {} self.input_device = torch.device( "cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu") - def forward(self, inp): + def forward(self, rank, inp): inp = inp.to(self.input_device) - out = self.model(inp) + out = self.models[rank](inp) # This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors. # Tensors must be moved in and out of GPU memory due to this. out = out.to("cpu") @@ -109,22 +93,40 @@ def get_dist_gradients(self, cid): # Wrap local parameters in a RRef. Needed for building the # DistributedOptimizer which optimizes parameters remotely. - def get_param_rrefs(self): - param_rrefs = [rpc.RRef(param) for param in self.model.parameters()] + def get_param_rrefs(self, rank): + param_rrefs = [rpc.RRef(param) + for param in self.models[rank].parameters()] return param_rrefs + def create_model_for_rank(self, rank, num_gpus): + assert num_gpus == self.num_gpus, f"Inconsistent no. 
of GPUs requested from rank vs initialized with on PS: {num_gpus} vs {self.num_gpus}" + if rank not in self.models: + self.models[rank] = Net(num_gpus=num_gpus) + + def average_models(self, rank): + # Load state dict of requested rank + state_dict_for_rank = self.models[rank].state_dict() + # Average all params + for key in state_dict_for_rank: + state_dict_for_rank[key] = sum(self.models[r].state_dict()[ + key] for r in self.models) / len(self.models) + # Rewrite back state dict + self.models[rank].load_state_dict(state_dict_for_rank) + param_server = None global_lock = Lock() -def get_parameter_server(num_gpus=0): +def get_parameter_server(rank, num_gpus=0): global param_server # Ensure that we get only one handle to the ParameterServer. with global_lock: if not param_server: # construct it once param_server = ParameterServer(num_gpus=num_gpus) + # Add model for this rank + param_server.create_model_for_rank(rank, num_gpus) return param_server @@ -147,45 +149,51 @@ def run_parameter_server(rank, world_size): # forward() method simply invokes the network on the given parameter # server. class TrainerNet(nn.Module): - def __init__(self, num_gpus=0): + def __init__(self, rank, num_gpus=0,): super().__init__() self.num_gpus = num_gpus + self.rank = rank self.param_server_rref = rpc.remote( - "parameter_server", get_parameter_server, args=(num_gpus,)) + "parameter_server", get_parameter_server, args=( + self.rank, num_gpus,)) def get_global_param_rrefs(self): - remote_params = remote_method( - ParameterServer.get_param_rrefs, - self.param_server_rref) + remote_params = self.param_server_rref.rpc_sync().get_param_rrefs(self.rank) return remote_params def forward(self, x): - model_output = remote_method( - ParameterServer.forward, self.param_server_rref, x) + model_output = self.param_server_rref.rpc_sync().forward(self.rank, x) return model_output + def average_model_across_trainers(self): + self.param_server_rref.rpc_sync().average_models(self.rank) + def run_training_loop(rank, num_gpus, train_loader, test_loader): # Runs the typical neural network forward + backward + optimizer step, but # in a distributed fashion. - net = TrainerNet(num_gpus=num_gpus) + net = TrainerNet(rank=rank, num_gpus=num_gpus) # Build DistributedOptimizer. param_rrefs = net.get_global_param_rrefs() opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03) for i, (data, target) in enumerate(train_loader): + if TERMINATE_AT_ITER is not None and i == TERMINATE_AT_ITER: + break + if i % PS_AVERAGE_EVERY_N == 0: + # Request server to update model with average params across all + # trainers. + print(f"Rank {rank} averaging model across all trainers.") + net.average_model_across_trainers() with dist_autograd.context() as cid: model_output = net(data) target = target.to(model_output.device) loss = F.nll_loss(model_output, target) - if i % 5 == 0: + if i % TRAINER_LOG_INTERVAL == 0: print(f"Rank {rank} training batch {i} loss {loss.item()}") dist_autograd.backward(cid, [loss]) # Ensure that dist autograd ran successfully and gradients were # returned. 
- assert remote_method( - ParameterServer.get_dist_gradients, - net.param_server_rref, - cid) != {} + assert net.param_server_rref.rpc_sync().get_dist_gradients(cid) != {} opt.step(cid) print("Training complete!") @@ -198,7 +206,7 @@ def get_accuracy(test_loader, model): correct_sum = 0 # Use GPU to evaluate if possible device = torch.device("cuda:0" if model.num_gpus > 0 - and torch.cuda.is_available() else "cpu") + and torch.cuda.is_available() else "cpu") with torch.no_grad(): for i, (data, target) in enumerate(test_loader): out = model(data) From 0f4c17681e42b0ffa9df2f67c5707c6e3c10c10e Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Wed, 27 Jan 2021 23:18:34 -0800 Subject: [PATCH 2/3] Fix --- .../parameter_server/rpc_parameter_server.py | 55 ++++++++++++------- 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/distributed/rpc/parameter_server/rpc_parameter_server.py b/distributed/rpc/parameter_server/rpc_parameter_server.py index 024ff34486..b3e448d8a9 100644 --- a/distributed/rpc/parameter_server/rpc_parameter_server.py +++ b/distributed/rpc/parameter_server/rpc_parameter_server.py @@ -1,6 +1,9 @@ import argparse import os from threading import Lock +import time +import logging +import sys import torch import torch.distributed.autograd as dist_autograd @@ -12,8 +15,11 @@ from torch.distributed.optim import DistributedOptimizer from torchvision import datasets, transforms +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) +logger = logging.getLogger() + # Constants -TRAINER_LOG_INTERVAL = 5 # How frequently to print out log information +TRAINER_LOG_INTERVAL = 5 # How frequently to log information TERMINATE_AT_ITER = 300 # for early stopping when debugging PS_AVERAGE_EVERY_N = 25 # How often to average models between trainers @@ -23,11 +29,11 @@ class Net(nn.Module): def __init__(self, num_gpus=0): super(Net, self).__init__() - print(f"Using {num_gpus} GPUs to train") + logger.info(f"Using {num_gpus} GPUs to train") self.num_gpus = num_gpus device = torch.device( "cuda:0" if torch.cuda.is_available() and self.num_gpus > 0 else "cpu") - print(f"Putting first 2 convs on {str(device)}") + logger.info(f"Putting first 2 convs on {str(device)}") # Put conv layers on the first cuda device self.conv1 = nn.Conv2d(1, 32, 3, 1).to(device) self.conv2 = nn.Conv2d(32, 64, 3, 1).to(device) @@ -35,7 +41,7 @@ def __init__(self, num_gpus=0): if "cuda" in str(device) and num_gpus > 1: device = torch.device("cuda:1") - print(f"Putting rest of layers on {str(device)}") + logger.info(f"Putting rest of layers on {str(device)}") self.dropout1 = nn.Dropout2d(0.25).to(device) self.dropout2 = nn.Dropout2d(0.5).to(device) self.fc1 = nn.Linear(9216, 128).to(device) @@ -65,9 +71,9 @@ def forward(self, x): class ParameterServer(nn.Module): def __init__(self, num_gpus=0): super().__init__() - self.model_copy = Net(num_gpus=num_gpus) self.num_gpus = num_gpus self.models = {} + self.models_init_lock = Lock() self.input_device = torch.device( "cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu") @@ -100,8 +106,13 @@ def get_param_rrefs(self, rank): def create_model_for_rank(self, rank, num_gpus): assert num_gpus == self.num_gpus, f"Inconsistent no. 
of GPUs requested from rank vs initialized with on PS: {num_gpus} vs {self.num_gpus}" - if rank not in self.models: - self.models[rank] = Net(num_gpus=num_gpus) + with self.models_init_lock: + if rank not in self.models: + self.models[rank] = Net(num_gpus=num_gpus) + + def get_num_models(self): + with self.models_init_lock: + return len(self.models) def average_models(self, rank): # Load state dict of requested rank @@ -136,11 +147,11 @@ def run_parameter_server(rank, world_size): # rpc.shutdown() will wait for all workers to complete by default, which # in this case means that the parameter server will wait for all trainers # to complete, and then exit. - print("PS master initializing RPC") + logger.info("PS master initializing RPC") rpc.init_rpc(name="parameter_server", rank=rank, world_size=world_size) - print("RPC initialized! Running parameter server...") + logger.info("RPC initialized! Running parameter server...") rpc.shutdown() - print("RPC shutdown on parameter server.") + logger.info("RPC shutdown on parameter server.") # --------- Trainers -------------------- @@ -169,10 +180,16 @@ def average_model_across_trainers(self): self.param_server_rref.rpc_sync().average_models(self.rank) -def run_training_loop(rank, num_gpus, train_loader, test_loader): +def run_training_loop(rank, world_size, num_gpus, train_loader, test_loader): # Runs the typical neural network forward + backward + optimizer step, but # in a distributed fashion. net = TrainerNet(rank=rank, num_gpus=num_gpus) + # Wait for all nets on PS to be created. + num_created = net.param_server_rref.rpc_sync().get_num_models() + while num_created != world_size - 1: + time.sleep(0.5) + num_created = net.param_server_rref.rpc_sync().get_num_models() + # Build DistributedOptimizer. param_rrefs = net.get_global_param_rrefs() opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03) @@ -182,22 +199,22 @@ def run_training_loop(rank, num_gpus, train_loader, test_loader): if i % PS_AVERAGE_EVERY_N == 0: # Request server to update model with average params across all # trainers. - print(f"Rank {rank} averaging model across all trainers.") + logger.info(f"Rank {rank} averaging model across all trainers.") net.average_model_across_trainers() with dist_autograd.context() as cid: model_output = net(data) target = target.to(model_output.device) loss = F.nll_loss(model_output, target) if i % TRAINER_LOG_INTERVAL == 0: - print(f"Rank {rank} training batch {i} loss {loss.item()}") + logger.info(f"Rank {rank} training batch {i} loss {loss.item()}") dist_autograd.backward(cid, [loss]) # Ensure that dist autograd ran successfully and gradients were # returned. assert net.param_server_rref.rpc_sync().get_dist_gradients(cid) != {} opt.step(cid) - print("Training complete!") - print("Getting accuracy....") + logger.info("Training complete!") + logger.info("Getting accuracy....") get_accuracy(test_loader, net) @@ -215,20 +232,20 @@ def get_accuracy(test_loader, model): correct = pred.eq(target.view_as(pred)).sum().item() correct_sum += correct - print(f"Accuracy {correct_sum / len(test_loader.dataset)}") + logger.info(f"Accuracy {correct_sum / len(test_loader.dataset)}") # Main loop for trainers. 
def run_worker(rank, world_size, num_gpus, train_loader, test_loader): - print(f"Worker rank {rank} initializing RPC") + logger.info(f"Worker rank {rank} initializing RPC") rpc.init_rpc( name=f"trainer_{rank}", rank=rank, world_size=world_size) - print(f"Worker {rank} done initializing RPC") + logger.info(f"Worker {rank} done initializing RPC") - run_training_loop(rank, num_gpus, train_loader, test_loader) + run_training_loop(rank, world_size, num_gpus, train_loader, test_loader) rpc.shutdown() # --------- Launcher -------------------- From 8315ac1ef4ed070627c81b161186c077dbb44659 Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Wed, 27 Jan 2021 23:22:15 -0800 Subject: [PATCH 3/3] Minor changes --- distributed/rpc/parameter_server/rpc_parameter_server.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/distributed/rpc/parameter_server/rpc_parameter_server.py b/distributed/rpc/parameter_server/rpc_parameter_server.py index b3e448d8a9..9285496b77 100644 --- a/distributed/rpc/parameter_server/rpc_parameter_server.py +++ b/distributed/rpc/parameter_server/rpc_parameter_server.py @@ -73,6 +73,8 @@ def __init__(self, num_gpus=0): super().__init__() self.num_gpus = num_gpus self.models = {} + # This lock is only used during init, and does not + # impact training perf. self.models_init_lock = Lock() self.input_device = torch.device( "cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu") @@ -184,7 +186,8 @@ def run_training_loop(rank, world_size, num_gpus, train_loader, test_loader): # Runs the typical neural network forward + backward + optimizer step, but # in a distributed fashion. net = TrainerNet(rank=rank, num_gpus=num_gpus) - # Wait for all nets on PS to be created. + # Wait for all nets on PS to be created, otherwise we could run + # into race conditions during training. num_created = net.param_server_rref.rpc_sync().get_num_models() while num_created != world_size - 1: time.sleep(0.5)
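Note on the API change in patch 1: the old call_method/remote_method helpers are dropped, and methods on the ParameterServer are instead invoked through the RRef.rpc_sync() proxy, which runs the call on the node that owns the RRef. Below is a minimal, self-contained sketch of that pattern outside the MNIST example; the Counter class, the worker names, and the hard-coded MASTER_ADDR/MASTER_PORT are illustrative assumptions, not part of the patch.

# Minimal sketch (not part of the patch): invoking a method on the object an
# RRef refers to, on the RRef's owner, via the RRef.rpc_sync() proxy.
import os

import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


class Counter:
    def __init__(self):
        self.value = 0

    def add(self, n):
        self.value += n
        return self.value


def run(rank, world_size):
    # Assumed rendezvous settings for a single-machine demo.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc(name=f"worker_{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        # Create a Counter instance owned by worker_1 and get an RRef to it.
        counter_rref = rpc.remote("worker_1", Counter)
        # Equivalent to the removed remote_method(Counter.add, counter_rref, 5):
        # the method runs on worker_1 against the RRef's local value.
        total = counter_rref.rpc_sync().add(5)
        print(f"Remote counter value: {total}")
    # shutdown() waits for all workers to finish outstanding RPC work.
    rpc.shutdown()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)

Here rref.rpc_sync().add(5) is shorthand for an rpc_sync to rref.owner() with the RRef's local value as the implicit self, which is exactly the substitution the patch makes for forward, get_param_rrefs, get_dist_gradients, and average_models on the TrainerNet side.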
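A second note, on the per-rank models: the patched ParameterServer keeps one Net per trainer rank in self.models and, every PS_AVERAGE_EVERY_N batches, a trainer asks the server to overwrite its replica with the element-wise mean of all replicas. The following is a minimal single-process sketch of that state_dict averaging; the tiny nn.Linear model and the replica count of three are assumptions for illustration only.

# Minimal sketch (assumptions: nn.Linear replicas, three "ranks") of the
# state_dict averaging performed by ParameterServer.average_models.
import torch
import torch.nn as nn


def average_into(models, rank):
    # Start from the requested rank's state dict, as the patch does.
    state_dict_for_rank = models[rank].state_dict()
    for key in state_dict_for_rank:
        # Replace each entry with the mean of that tensor across all replicas.
        state_dict_for_rank[key] = sum(
            models[r].state_dict()[key] for r in models) / len(models)
    # Write the averaged values back into that rank's model.
    models[rank].load_state_dict(state_dict_for_rank)


if __name__ == "__main__":
    torch.manual_seed(0)
    # One replica per "trainer rank", mirroring ParameterServer.models.
    models = {r: nn.Linear(4, 2) for r in range(3)}
    expected = sum(m.weight.detach().clone() for m in models.values()) / len(models)
    average_into(models, rank=0)
    # Rank 0's weights now equal the pre-averaging mean across replicas.
    print(torch.allclose(models[0].weight, expected))

Only the requesting rank's replica is overwritten; the other replicas keep their own parameters until they issue their own averaging request, which is why each trainer calls average_model_across_trainers on its own schedule in run_training_loop.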