Fix RPC Param server example for multiple trainers #877

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 3 commits into main
142 changes: 85 additions & 57 deletions distributed/rpc/parameter_server/rpc_parameter_server.py
@@ -1,6 +1,9 @@
import argparse
import os
from threading import Lock
import time
import logging
import sys

import torch
import torch.distributed.autograd as dist_autograd
@@ -12,25 +15,33 @@
from torch.distributed.optim import DistributedOptimizer
from torchvision import datasets, transforms

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging.getLogger()

# Constants
TRAINER_LOG_INTERVAL = 5 # How frequently to log information
TERMINATE_AT_ITER = 300 # for early stopping when debugging
PS_AVERAGE_EVERY_N = 25 # How often to average models between trainers

# --------- MNIST Network to train, from pytorch/examples -----


class Net(nn.Module):
def __init__(self, num_gpus=0):
super(Net, self).__init__()
print(f"Using {num_gpus} GPUs to train")
logger.info(f"Using {num_gpus} GPUs to train")
self.num_gpus = num_gpus
device = torch.device(
"cuda:0" if torch.cuda.is_available() and self.num_gpus > 0 else "cpu")
print(f"Putting first 2 convs on {str(device)}")
logger.info(f"Putting first 2 convs on {str(device)}")
# Put conv layers on the first cuda device
self.conv1 = nn.Conv2d(1, 32, 3, 1).to(device)
self.conv2 = nn.Conv2d(32, 64, 3, 1).to(device)
# Put rest of the network on the 2nd cuda device, if there is one
if "cuda" in str(device) and num_gpus > 1:
device = torch.device("cuda:1")

print(f"Putting rest of layers on {str(device)}")
logger.info(f"Putting rest of layers on {str(device)}")
self.dropout1 = nn.Dropout2d(0.25).to(device)
self.dropout2 = nn.Dropout2d(0.5).to(device)
self.fc1 = nn.Linear(9216, 128).to(device)
@@ -56,40 +67,21 @@ def forward(self, x):
return output


# --------- Helper Methods --------------------

# On the local node, call a method with first arg as the value held by the
# RRef. Other args are passed in as arguments to the function called.
# Useful for calling instance methods.
def call_method(method, rref, *args, **kwargs):
return method(rref.local_value(), *args, **kwargs)

# Given an RRef, return the result of calling the passed in method on the value
# held by the RRef. This call is done on the remote node that owns
# the RRef. args and kwargs are passed into the method.
# Example: If the value held by the RRef is of type Foo, then
# remote_method(Foo.bar, rref, arg1, arg2) is equivalent to calling
# <foo_instance>.bar(arg1, arg2) on the remote node and getting the result
# back.


def remote_method(method, rref, *args, **kwargs):
args = [method, rref] + list(args)
return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs)
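
For context, `remote_method` above is the helper that the rest of this patch retires: recent PyTorch releases expose an `rpc_sync()` proxy directly on an `RRef`, which issues the same blocking RPC to the RRef's owner and invokes the method on the wrapped object. A minimal sketch of the equivalence, assuming RPC is already initialized and `rref` refers to a remote `ParameterServer`; the helper names below are illustrative and not part of the patch:

def fetch_param_rrefs_via_helper(rref, rank):
    # Route the call through the remote_method helper defined above.
    return remote_method(ParameterServer.get_param_rrefs, rref, rank)

def fetch_param_rrefs_via_proxy(rref, rank):
    # Pattern used in this PR: rpc_sync() on the RRef performs the same
    # blocking call on rref.owner() against the referenced object.
    return rref.rpc_sync().get_param_rrefs(rank)
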


# --------- Parameter Server --------------------
class ParameterServer(nn.Module):
def __init__(self, num_gpus=0):
super().__init__()
model = Net(num_gpus=num_gpus)
self.model = model
self.num_gpus = num_gpus
self.models = {}
# This lock is only used during init, and does not
# impact training perf.
self.models_init_lock = Lock()
self.input_device = torch.device(
"cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu")

def forward(self, inp):
def forward(self, rank, inp):
inp = inp.to(self.input_device)
out = self.model(inp)
out = self.models[rank](inp)
# This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors.
# Tensors must be moved in and out of GPU memory due to this.
out = out.to("cpu")
@@ -109,22 +101,45 @@ def get_dist_gradients(self, cid):

# Wrap local parameters in a RRef. Needed for building the
# DistributedOptimizer which optimizes parameters remotely.
def get_param_rrefs(self):
param_rrefs = [rpc.RRef(param) for param in self.model.parameters()]
def get_param_rrefs(self, rank):
param_rrefs = [rpc.RRef(param)
for param in self.models[rank].parameters()]
return param_rrefs
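
As the comment above notes, these per-rank parameter RRefs exist so that a trainer can hand them to `DistributedOptimizer`, which then runs optimizer steps on the parameter server's copy of that trainer's model inside a distributed autograd context. A minimal sketch of the consuming side, assuming RPC is initialized, `net` is a `TrainerNet` as defined further down in this file, and `data`/`target` stand in for one MNIST batch:

import torch.distributed.autograd as dist_autograd
import torch.nn.functional as F
from torch import optim
from torch.distributed.optim import DistributedOptimizer

param_rrefs = net.get_global_param_rrefs()      # RRefs owned by the PS
opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03)
with dist_autograd.context() as cid:
    loss = F.nll_loss(net(data), target)        # forward runs on the PS
    dist_autograd.backward(cid, [loss])         # gradients accumulate remotely
    opt.step(cid)                               # SGD step on the PS parameters
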

def create_model_for_rank(self, rank, num_gpus):
assert num_gpus == self.num_gpus, f"Inconsistent no. of GPUs requested from rank vs initialized with on PS: {num_gpus} vs {self.num_gpus}"
with self.models_init_lock:
if rank not in self.models:
self.models[rank] = Net(num_gpus=num_gpus)

def get_num_models(self):
with self.models_init_lock:
return len(self.models)

def average_models(self, rank):
# Load state dict of requested rank
state_dict_for_rank = self.models[rank].state_dict()
# Average all params
for key in state_dict_for_rank:
state_dict_for_rank[key] = sum(self.models[r].state_dict()[
key] for r in self.models) / len(self.models)
# Rewrite back state dict
self.models[rank].load_state_dict(state_dict_for_rank)
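
The averaging above is plain, unweighted parameter averaging: each entry of the requested rank's state dict is replaced by the element-wise mean of that entry across every model currently registered on the parameter server. A self-contained sketch of the same operation, assuming all replicas share one architecture; the helper name and the `nn.Linear` toy models are illustrative only:

import torch
import torch.nn as nn

def average_into(models, rank):
    # Overwrite models[rank] with the element-wise mean of all replicas' params.
    averaged = models[rank].state_dict()
    for key in averaged:
        averaged[key] = torch.stack(
            [m.state_dict()[key] for m in models.values()]
        ).mean(dim=0)
    models[rank].load_state_dict(averaged)

# Example: three toy replicas, averaged into rank 1's copy.
models = {r: nn.Linear(4, 2) for r in range(1, 4)}
average_into(models, rank=1)
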


param_server = None
global_lock = Lock()


def get_parameter_server(num_gpus=0):
def get_parameter_server(rank, num_gpus=0):
global param_server
# Ensure that we get only one handle to the ParameterServer.
with global_lock:
if not param_server:
# construct it once
param_server = ParameterServer(num_gpus=num_gpus)
# Add model for this rank
param_server.create_model_for_rank(rank, num_gpus)
return param_server


@@ -134,11 +149,11 @@ def run_parameter_server(rank, world_size):
# rpc.shutdown() will wait for all workers to complete by default, which
# in this case means that the parameter server will wait for all trainers
# to complete, and then exit.
print("PS master initializing RPC")
logger.info("PS master initializing RPC")
rpc.init_rpc(name="parameter_server", rank=rank, world_size=world_size)
print("RPC initialized! Running parameter server...")
logger.info("RPC initialized! Running parameter server...")
rpc.shutdown()
print("RPC shutdown on parameter server.")
logger.info("RPC shutdown on parameter server.")


# --------- Trainers --------------------
@@ -147,49 +162,62 @@ def run_parameter_server(rank, world_size):
# forward() method simply invokes the network on the given parameter
# server.
class TrainerNet(nn.Module):
def __init__(self, num_gpus=0):
def __init__(self, rank, num_gpus=0):
super().__init__()
self.num_gpus = num_gpus
self.rank = rank
self.param_server_rref = rpc.remote(
"parameter_server", get_parameter_server, args=(num_gpus,))
"parameter_server", get_parameter_server, args=(
self.rank, num_gpus,))

def get_global_param_rrefs(self):
remote_params = remote_method(
ParameterServer.get_param_rrefs,
self.param_server_rref)
remote_params = self.param_server_rref.rpc_sync().get_param_rrefs(self.rank)
return remote_params

def forward(self, x):
model_output = remote_method(
ParameterServer.forward, self.param_server_rref, x)
model_output = self.param_server_rref.rpc_sync().forward(self.rank, x)
return model_output

def average_model_across_trainers(self):
self.param_server_rref.rpc_sync().average_models(self.rank)


def run_training_loop(rank, num_gpus, train_loader, test_loader):
def run_training_loop(rank, world_size, num_gpus, train_loader, test_loader):
# Runs the typical neural network forward + backward + optimizer step, but
# in a distributed fashion.
net = TrainerNet(num_gpus=num_gpus)
net = TrainerNet(rank=rank, num_gpus=num_gpus)
# Wait for all nets on PS to be created, otherwise we could run
# into race conditions during training.
num_created = net.param_server_rref.rpc_sync().get_num_models()
while num_created != world_size - 1:
time.sleep(0.5)
num_created = net.param_server_rref.rpc_sync().get_num_models()
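
Because the parameter server itself holds rank 0, exactly world_size - 1 trainer models must be registered before any trainer builds its optimizer; the loop above simply polls the server until that count is reached. A minimal sketch of the same barrier factored into a helper, assuming the rank-0 convention used here; the function name and the timeout are illustrative additions, not part of the patch:

import time

def wait_for_all_trainer_models(ps_rref, world_size, poll_secs=0.5, timeout_secs=300):
    # Poll the parameter server until every trainer has registered its model.
    deadline = time.time() + timeout_secs
    while ps_rref.rpc_sync().get_num_models() != world_size - 1:
        if time.time() > deadline:
            raise RuntimeError("timed out waiting for trainer models on the PS")
        time.sleep(poll_secs)
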

# Build DistributedOptimizer.
param_rrefs = net.get_global_param_rrefs()
opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03)
for i, (data, target) in enumerate(train_loader):
if TERMINATE_AT_ITER is not None and i == TERMINATE_AT_ITER:
break
if i % PS_AVERAGE_EVERY_N == 0:
# Request server to update model with average params across all
# trainers.
logger.info(f"Rank {rank} averaging model across all trainers.")
net.average_model_across_trainers()
with dist_autograd.context() as cid:
model_output = net(data)
target = target.to(model_output.device)
loss = F.nll_loss(model_output, target)
if i % 5 == 0:
print(f"Rank {rank} training batch {i} loss {loss.item()}")
if i % TRAINER_LOG_INTERVAL == 0:
logger.info(f"Rank {rank} training batch {i} loss {loss.item()}")
dist_autograd.backward(cid, [loss])
# Ensure that dist autograd ran successfully and gradients were
# returned.
assert remote_method(
ParameterServer.get_dist_gradients,
net.param_server_rref,
cid) != {}
assert net.param_server_rref.rpc_sync().get_dist_gradients(cid) != {}
opt.step(cid)

print("Training complete!")
print("Getting accuracy....")
logger.info("Training complete!")
logger.info("Getting accuracy....")
get_accuracy(test_loader, net)


@@ -198,7 +226,7 @@ def get_accuracy(test_loader, model):
correct_sum = 0
# Use GPU to evaluate if possible
device = torch.device("cuda:0" if model.num_gpus > 0
and torch.cuda.is_available() else "cpu")
with torch.no_grad():
for i, (data, target) in enumerate(test_loader):
out = model(data)
@@ -207,20 +235,20 @@ def get_accuracy(test_loader, model):
correct = pred.eq(target.view_as(pred)).sum().item()
correct_sum += correct

print(f"Accuracy {correct_sum / len(test_loader.dataset)}")
logger.info(f"Accuracy {correct_sum / len(test_loader.dataset)}")


# Main loop for trainers.
def run_worker(rank, world_size, num_gpus, train_loader, test_loader):
print(f"Worker rank {rank} initializing RPC")
logger.info(f"Worker rank {rank} initializing RPC")
rpc.init_rpc(
name=f"trainer_{rank}",
rank=rank,
world_size=world_size)

print(f"Worker {rank} done initializing RPC")
logger.info(f"Worker {rank} done initializing RPC")

run_training_loop(rank, num_gpus, train_loader, test_loader)
run_training_loop(rank, world_size, num_gpus, train_loader, test_loader)
rpc.shutdown()

# --------- Launcher --------------------