expose an extra flag
lucidrains committed Dec 18, 2023
1 parent 17e4d2c commit 4dc07b0
Showing 2 changed files with 13 additions and 4 deletions.
setup.py: 2 changes (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'soft-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.2',
+  version = '0.1.4',
   license='MIT',
   description = 'Soft MoE - Pytorch',
   author = 'Phil Wang',
soft_moe_pytorch/soft_moe.py: 15 changes (12 additions, 3 deletions)
@@ -107,7 +107,8 @@ class Experts(nn.Module):
     def __init__(
         self,
         experts,
-        is_distributed = None
+        is_distributed = None,
+        offload_unused_experts_to_cpu = True
     ):
         super().__init__()
         self.num_experts = len(experts)
@@ -117,6 +118,9 @@ def __init__(
         if not exists(self.is_distributed):
             self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1
 
+        # whether to offload unused experts to cpu, will require optimizer handles conversion of gradients to right device when accumulating
+        self.offload_unused_experts_to_cpu = offload_unused_experts_to_cpu
+
         self.all_gather = AllGather()
         self.register_buffer('dummy', torch.ones(1), persistent = False)
 
@@ -125,6 +129,9 @@ def device(self):
         return self.dummy.device
 
     def all_experts_to_cpu_besides(self, selection):
+        if not self.offload_unused_experts_to_cpu:
+            return
+
         if isinstance(selection, int):
             experts = [self.experts[selection]]
         if isinstance(selection, slice):
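
The flag defaults to True. When it is disabled, the new guard above turns the offloading path into a no-op and experts stay wherever they were placed; when enabled, experts a rank does not currently need are parked on CPU and only the selected ones are moved to the accelerator. As the in-code comment notes, this shifts some bookkeeping onto the optimizer, which has to move accumulated gradients onto whatever device a parameter lives on at step time. Below is a minimal, illustrative sketch of the pattern this flag controls, with made-up module sizes and devices (not code from this commit):

import torch
from torch import nn

# stand-ins for a bank of experts; only one is "selected" on this rank
experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# park every expert on cpu, then bring back only the one that will run
experts.to('cpu')
active = experts[2].to(device)

out = active(torch.randn(16, 8, device = device))
out.sum().backward()

# once the expert is offloaded again, gradients produced on the accelerator
# have to be reconciled with the parameter's current device, which is the
# bookkeeping the in-code comment asks of the optimizer
active.to('cpu')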
@@ -266,7 +273,8 @@ def __init__(
         expert_mult = 4,
         dropout = 0.,
         geglu = False,
-        is_distributed = None
+        is_distributed = None,
+        offload_unused_experts_to_cpu = True
     ):
         super().__init__()
         assert exists(seq_len) ^ exists(num_slots), 'either seq_len, or num_slots must be passed into SoftMoE'
@@ -282,7 +290,8 @@ def __init__(
 
         self.experts = Experts(
             experts = [expert_klass(dim = dim, mult = expert_mult, dropout = dropout) for _ in range(num_experts)],
-            is_distributed = is_distributed
+            is_distributed = is_distributed,
+            offload_unused_experts_to_cpu = offload_unused_experts_to_cpu
         )
 
     def forward(self, x, mask = None, add_noise = False, noise_mult = 1.):
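
For reference, a hedged usage sketch of the newly exposed flag. The keyword arguments mirror the SoftMoE constructor visible in this diff; the import path and the concrete dim / seq_len / num_experts values are assumptions, not taken from this commit:

import torch
from soft_moe_pytorch import SoftMoE

moe = SoftMoE(
    dim = 512,                                # assumed model dimension
    seq_len = 1024,                           # either seq_len or num_slots must be given, per the assert above
    num_experts = 4,
    offload_unused_experts_to_cpu = False     # opt out of the default cpu offloading
)

x = torch.randn(1, 1024, 512)
out = moe(x)                                  # expected to keep the input shape

Leaving the flag at its default trades lower per-rank accelerator memory for extra host-to-device transfers whenever an expert is selected.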
