From cdefb4388fd594de94820dc6087b832b260bd40a Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 10 Sep 2023 13:15:47 -0700
Subject: [PATCH] just port over some logic from st moe

---
 assert.py                    |  2 +-
 setup.py                     |  2 +-
 soft_moe_pytorch/soft_moe.py | 70 +++++++++++++++++++++++++++++-------
 3 files changed, 59 insertions(+), 15 deletions(-)

diff --git a/assert.py b/assert.py
index c369ed3..d615560 100644
--- a/assert.py
+++ b/assert.py
@@ -76,7 +76,7 @@ def start(
     cleanup()
 
 if __name__ == '__main__':
-    world_size = 4
+    world_size = 9
     num_experts = 8
     batch_size = 2
     batch_size_var_len = False
diff --git a/setup.py b/setup.py
index e5146a0..2ee0811 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'soft-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.8',
+  version = '0.0.9',
   license='MIT',
   description = 'Soft MoE - Pytorch',
   author = 'Phil Wang',
diff --git a/soft_moe_pytorch/soft_moe.py b/soft_moe_pytorch/soft_moe.py
index bd293bf..558705d 100644
--- a/soft_moe_pytorch/soft_moe.py
+++ b/soft_moe_pytorch/soft_moe.py
@@ -24,6 +24,16 @@ def default(val, d):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def chunk_num(num, chunks):
+    num_per_chunk, remainder = divmod(num, chunks)
+
+    out = []
+    for i in range(chunks):
+        n = num_per_chunk
+        out.append(n + int(i < remainder))
+
+    return out
+
 def pack_one(t, pattern):
     return pack([t], pattern)
 
@@ -33,6 +43,12 @@ def unpack_one(t, ps, pattern):
 def l2norm(t):
     return F.normalize(t, dim = -1)
 
+def cumsum_exclusive(t, dim = -3):
+    assert dim < 0
+    num_pad_dims = -dim - 1
+    pre_padding = (0, 0) * num_pad_dims
+    return F.pad(t, (*pre_padding, 1, -1)).cumsum(dim = dim)
+
 # norm
 
 class RMSNorm(Module):
@@ -139,6 +155,7 @@ def forward(
             assert has_only_one_value(seq_sizes), 'number of tokens per expert must be the same'
 
             x, batch_sizes = self.all_gather(x)
+            total_batch_size = x.shape[0]
 
             world_size = dist.get_world_size()
             rank = dist.get_rank()
@@ -147,18 +164,38 @@
             rank = 0
 
         # the experts in use on the rank
-        # for now, make sure number of machines is right multiple
-        if world_size <= num_experts:
-            assert divisible_by(num_experts, world_size), 'if number of machines is less than the number of experts, the number of experts must be divisible by number of machines'
-            num_experts_per_rank = num_experts // world_size
-            expert_start_index = rank * num_experts_per_rank
-        else:
-            assert divisible_by(world_size, num_experts), 'if number of machines is greater than number of experts, machines must be divisible by number of experts, so experts are evenly distributed'
-            num_experts_per_rank = 1
-            expert_start_index = rank // (world_size // num_experts)
+        if is_distributed:
+            if world_size <= num_experts:
+                num_experts_across_ranks = chunk_num(num_experts, world_size)
+                start_indices = cumsum_exclusive(torch.tensor(num_experts_across_ranks), dim = -1)
+
+                num_experts_per_rank = num_experts_across_ranks[rank]
+                num_experts_batches_across_ranks = tuple(i * total_batch_size for i in num_experts_across_ranks)
+
+                expert_start_index = start_indices[rank].item()
+            else:
+                num_batch_chunks = world_size // num_experts
+                total_ranks_in_use = num_batch_chunks * num_experts
+
+                expert_start_index = rank // num_batch_chunks
 
-        expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
+                batch_splits = chunk_num(total_batch_size, num_batch_chunks)
+                num_experts_batches_across_ranks = batch_splits * num_experts
+
+                # for now, remaining machines just process nothing
+
+                remain_ranks = world_size % num_experts
+                num_experts_batches_across_ranks += (0,) * remain_ranks
+
+                num_experts_per_rank = int(rank < total_ranks_in_use)
+
+            assert len(num_experts_batches_across_ranks) == world_size
+
+            expert_slice = slice(expert_start_index, expert_start_index + num_experts_per_rank)
+        else:
+            num_experts_per_rank = num_experts
+            expert_slice = slice(0, num_experts)
 
         # if distributed, each machine only handles subset of experts and batch
@@ -166,9 +203,13 @@
 
         if is_distributed:
             x, expert_batch_packed_shape = pack_one(x, '* n d')
-            x = rearrange(x, '(r eb) n d -> r eb n d', r = world_size)
+            x = x.split(num_experts_batches_across_ranks, dim = 0)
             x = split_by_rank(x)
-            x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)
+
+            if num_experts_per_rank > 0:
+                x = rearrange(x, '(e b) n d -> e b n d', e = num_experts_per_rank)
+            else:
+                x = x.reshape(num_experts, *x.shape)
 
         # get the experts in use
@@ -183,7 +224,10 @@
             out = expert(expert_input)
             outs.append(out)
 
-        outs = torch.stack(outs)
+        if len(outs) > 0:
+            outs = torch.stack(outs)
+        else:
+            outs = torch.empty_like(x).requires_grad_()
 
         # all gather across merged expert batches dimensions
         # then split the batch dimension back
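
Note (outside the patch, for review): a minimal standalone sketch of the two helpers this commit ports over from st-moe. The bodies mirror the hunks above (chunk_num is condensed into a comprehension); the example values are illustrative, not from the commit.

import torch
import torch.nn.functional as F

def chunk_num(num, chunks):
    # split `num` into `chunks` near-equal integer parts,
    # handing the remainder out one per chunk from the front
    num_per_chunk, remainder = divmod(num, chunks)
    return [num_per_chunk + int(i < remainder) for i in range(chunks)]

def cumsum_exclusive(t, dim = -1):
    # exclusive prefix sum along `dim`: pad a zero in front,
    # drop the last element, then take the cumulative sum
    assert dim < 0
    pre_padding = (0, 0) * (-dim - 1)
    return F.pad(t, (*pre_padding, 1, -1)).cumsum(dim = dim)

print(chunk_num(8, 3))                                      # [3, 3, 2]
print(cumsum_exclusive(torch.tensor([3, 3, 2]), dim = -1))  # tensor([0, 3, 6])

Together these give, per rank, how many experts it owns and the index of its first expert, which is what lets this commit drop the old divisibility asserts.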
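A second sketch (also not part of the patch) traces the new branch in Experts.forward without any distributed setup. The arithmetic is copied from the diff; expert_layout, the rank loop, and the sizes (world_size = 9, num_experts = 8, batch of 2, matching the updated assert.py) are hypothetical scaffolding for illustration.

import torch
from soft_moe_pytorch.soft_moe import chunk_num, cumsum_exclusive

def expert_layout(world_size, num_experts, total_batch_size):
    # per rank: (expert_start_index, num_experts_per_rank, expert-batch rows received)
    layout = []

    if world_size <= num_experts:
        # more experts than ranks: each rank owns a contiguous chunk of experts
        num_experts_across_ranks = chunk_num(num_experts, world_size)
        start_indices = cumsum_exclusive(torch.tensor(num_experts_across_ranks), dim = -1)
        rows = [n * total_batch_size for n in num_experts_across_ranks]

        for rank in range(world_size):
            layout.append((start_indices[rank].item(), num_experts_across_ranks[rank], rows[rank]))
    else:
        # more ranks than experts: ranks gather in chunks around one expert each,
        # splitting the batch; the remaining world_size % num_experts ranks idle
        num_batch_chunks = world_size // num_experts
        total_ranks_in_use = num_batch_chunks * num_experts

        rows = chunk_num(total_batch_size, num_batch_chunks) * num_experts
        rows += [0] * (world_size % num_experts)

        for rank in range(world_size):
            layout.append((rank // num_batch_chunks, int(rank < total_ranks_in_use), rows[rank]))

    return layout

# world_size = 9, num_experts = 8: ranks 0-7 each serve one expert on the full
# batch, rank 8 receives nothing - the case the assert.py change now exercises
for rank, row in enumerate(expert_layout(9, 8, 2)):
    print(rank, row)

Rank 8 ends up with num_experts_per_rank == 0 and an empty expert slice, which is why the forward pass now guards the rearrange with a reshape fallback and substitutes torch.empty_like for outs when no expert ran.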