
Commit

just port over some logic from st moe
lucidrains committed Sep 10, 2023
1 parent fb138b1 commit cdefb43
Showing 3 changed files with 59 additions and 15 deletions.
2 changes: 1 addition & 1 deletion assert.py
@@ -76,7 +76,7 @@ def start(
    cleanup()

if __name__ == '__main__':
-    world_size = 4
+    world_size = 9
    num_experts = 8
    batch_size = 2
    batch_size_var_len = False
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'soft-moe-pytorch',
  packages = find_packages(exclude=[]),
-  version = '0.0.8',
+  version = '0.0.9',
  license='MIT',
  description = 'Soft MoE - Pytorch',
  author = 'Phil Wang',
70 changes: 57 additions & 13 deletions soft_moe_pytorch/soft_moe.py
@@ -24,6 +24,16 @@ def default(val, d):
def divisible_by(num, den):
    return (num % den) == 0

+def chunk_num(num, chunks):
+    num_per_chunk, remainder = divmod(num, chunks)
+
+    out = []
+    for i in range(chunks):
+        n = num_per_chunk
+        out.append(n + int(i < remainder))
+
+    return out

def pack_one(t, pattern):
    return pack([t], pattern)
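
The new chunk_num helper splits an integer count into chunks near-equal parts, handing the remainder out one unit at a time to the earliest chunks. A minimal standalone sketch of the expected behaviour (the copy of the helper and the asserts below are illustrative, not part of the commit):

def chunk_num(num, chunks):
    num_per_chunk, remainder = divmod(num, chunks)
    return [num_per_chunk + int(i < remainder) for i in range(chunks)]

assert chunk_num(8, 3) == [3, 3, 2]   # 8 experts over 3 ranks: first ranks absorb the remainder
assert chunk_num(8, 8) == [1] * 8     # even split, no remainder
assert chunk_num(2, 3) == [1, 1, 0]   # more ranks than items leaves trailing zeros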

@@ -33,6 +43,12 @@ def unpack_one(t, ps, pattern):
def l2norm(t):
    return F.normalize(t, dim = -1)

+def cumsum_exclusive(t, dim = -3):
+    assert dim < 0
+    num_pad_dims = -dim - 1
+    pre_padding = (0, 0) * num_pad_dims
+    return F.pad(t, (*pre_padding, 1, -1)).cumsum(dim = dim)

# norm

class RMSNorm(Module):
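
cumsum_exclusive left-pads the chosen dimension with a zero and drops the last entry before taking the cumulative sum, so each position holds the total of everything before it; paired with chunk_num's counts it yields per-rank start offsets. A self-contained check mirroring the helper above (the default dim is switched to -1 here just for a 1-d example, and the values are illustrative):

import torch
import torch.nn.functional as F

def cumsum_exclusive(t, dim = -1):
    assert dim < 0
    num_pad_dims = -dim - 1
    pre_padding = (0, 0) * num_pad_dims
    return F.pad(t, (*pre_padding, 1, -1)).cumsum(dim = dim)

counts = torch.tensor([3, 3, 2])           # e.g. chunk_num(8, 3)
print(cumsum_exclusive(counts, dim = -1))  # tensor([0, 3, 6]) -> expert start index per rank
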
@@ -139,6 +155,7 @@ def forward(
            assert has_only_one_value(seq_sizes), 'number of tokens per expert must be the same'

            x, batch_sizes = self.all_gather(x)
+            total_batch_size = x.shape[0]

            world_size = dist.get_world_size()
            rank = dist.get_rank()
@@ -147,28 +164,52 @@
            rank = 0

        # the experts in use on the rank
-        # for now, make sure number of machines is right multiple

-        if world_size <= num_experts:
-            assert divisible_by(num_experts, world_size), 'if number of machines is less than the number of experts, the number of experts must be divisible by number of machines'
-            num_experts_per_rank = num_experts // world_size
-            expert_start_index = rank * num_experts_per_rank
-        else:
-            assert divisible_by(world_size, num_experts), 'if number of machines is greater than number of experts, machines must be divisible by number of experts, so experts are evenly distributed'
-            num_experts_per_rank = 1
-            expert_start_index = rank // (world_size // num_experts)
+        if is_distributed:
+            if world_size <= num_experts:
+                num_experts_across_ranks = chunk_num(num_experts, world_size)
+                start_indices = cumsum_exclusive(torch.tensor(num_experts_across_ranks), dim = -1)

+                num_experts_per_rank = num_experts_across_ranks[rank]
+                num_experts_batches_across_ranks = tuple(i * total_batch_size for i in num_experts_across_ranks)

+                expert_start_index = start_indices[rank].item()
+            else:
+                num_batch_chunks = world_size // num_experts
+                total_ranks_in_use = num_batch_chunks * num_experts

+                expert_start_index = rank // num_batch_chunks

-        expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
+                batch_splits = chunk_num(total_batch_size, num_batch_chunks)
+                num_experts_batches_across_ranks = batch_splits * num_experts

+                # for now, remaining machines just process nothing

+                remain_ranks = world_size % num_experts
+                num_experts_batches_across_ranks += (0,) * remain_ranks

+                num_experts_per_rank = int(rank < total_ranks_in_use)

+            assert len(num_experts_batches_across_ranks) == world_size

+            expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
+        else:
+            num_experts_per_rank = num_experts
+            expert_slice = slice(0, num_experts)
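
With this change the expert-to-rank assignment no longer needs the two counts to divide evenly: when world_size <= num_experts, chunk_num hands each rank a near-equal share of experts and cumsum_exclusive supplies the start index; when world_size > num_experts, ranks are grouped into batch chunks per expert and any leftover ranks simply hold zero experts. A hedged, tensor-free trace of both branches (the assign helper is illustrative; the second case matches the world_size = 9, num_experts = 8 setting exercised in assert.py):

def chunk_num(num, chunks):
    num_per_chunk, remainder = divmod(num, chunks)
    return [num_per_chunk + int(i < remainder) for i in range(chunks)]

def assign(world_size, num_experts, rank):
    # mirrors the branch logic above, returning (expert_start_index, num_experts_per_rank)
    if world_size <= num_experts:
        counts = chunk_num(num_experts, world_size)
        return sum(counts[:rank]), counts[rank]       # exclusive cumsum as the start index
    num_batch_chunks = world_size // num_experts
    total_ranks_in_use = num_batch_chunks * num_experts
    return rank // num_batch_chunks, int(rank < total_ranks_in_use)

assert [assign(3, 8, r) for r in range(3)] == [(0, 3), (3, 3), (6, 2)]   # 8 experts over 3 ranks
assert [assign(9, 8, r) for r in range(9)][:2] == [(0, 1), (1, 1)]       # 9 ranks, 8 experts
assert assign(9, 8, 8) == (8, 0)                                         # the spare rank holds no experts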

        # if distributed, each machine only handles subset of experts and batch

        x = rearrange(x, 'b e n d -> e b n d')

        if is_distributed:
            x, expert_batch_packed_shape = pack_one(x, '* n d')
-            x = rearrange(x, '(r eb) n d -> r eb n d', r = world_size)
+            x = x.split(num_experts_batches_across_ranks, dim = 0)
            x = split_by_rank(x)
-            x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)

+            if num_experts_per_rank > 0:
+                x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)
+            else:
+                x = x.reshape(num_experts, *x.shape)
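
Downstream, the gathered tensor is rearranged to expert-major order, packed to shape (num_experts * total_batch, n, d), and carved along the first dimension with num_experts_batches_across_ranks, so each rank keeps exactly the rows belonging to its own experts (an empty chunk if it holds none). A small standalone sketch of that bookkeeping (shapes and the rank selection are illustrative stand-ins for split_by_rank):

import torch

num_experts, batch, n, d = 8, 2, 4, 16
experts_per_rank = [3, 3, 2]                              # chunk_num(8, 3) for world_size = 3
rows_per_rank = [e * batch for e in experts_per_rank]     # [6, 6, 4] rows of the packed tensor

packed = torch.randn(num_experts * batch, n, d)           # '(e b) n d'
chunks = packed.split(rows_per_rank, dim = 0)             # each rank would keep chunks[rank]

rank = 2
local = chunks[rank].reshape(experts_per_rank[rank], batch, n, d)
print(local.shape)                                        # torch.Size([2, 2, 4, 16])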

        # get the experts in use

@@ -183,7 +224,10 @@ def forward(
            out = expert(expert_input)
            outs.append(out)

-        outs = torch.stack(outs)
+        if len(outs) > 0:
+            outs = torch.stack(outs)
+        else:
+            outs = torch.empty_like(x).requires_grad_()

        # all gather across merged expert batches dimensions
        # then split the batch dimension back
