diff --git a/setup.py b/setup.py
index f9fb421..31e2682 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'soft-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.3',
+  version = '0.0.4',
   license='MIT',
   description = 'Soft MoE - Pytorch',
   author = 'Phil Wang',
diff --git a/soft_moe_pytorch/soft_moe.py b/soft_moe_pytorch/soft_moe.py
index c62869a..0b196b4 100644
--- a/soft_moe_pytorch/soft_moe.py
+++ b/soft_moe_pytorch/soft_moe.py
@@ -91,7 +91,7 @@ def __init__(
             expert_klass(dim = dim, mult = expert_mult, dropout = dropout) for _ in range(num_experts)
         ])
 
-    def forward(self, x):
+    def forward(self, x, mask = None):
         """
         einstein notation
         b - batch
@@ -114,6 +114,14 @@ def forward(self, x):
 
         logits = einsum('b n d, e s d -> b n e s', x, slot_embeds)
 
+        # account for key padding mask
+
+        if exists(mask):
+            mask = rearrange(mask, 'b n -> b n 1 1')
+            logits = logits.masked_fill(~mask, -torch.finfo(logits.dtype).max)
+
+        # get dispatch and combine weights (softmax across right dimensions)
+
         dispatch_weights = logits.softmax(dim = 1)
 
         combine_weights = rearrange(logits, 'b n e s -> b n (e s)')
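
A minimal usage sketch (not part of the diff) of the new `mask` argument added to `forward`. The constructor arguments (`dim`, `seq_len`, `num_experts`) are assumed to follow the repository README and may differ from the actual signature; the mask is a boolean key-padding mask with `True` marking real tokens.

```python
import torch
from soft_moe_pytorch import SoftMoE

# assumed constructor arguments, per the repository README
moe = SoftMoE(
    dim = 512,          # model dimension
    seq_len = 1024,     # sequence length, used to size the expert slots
    num_experts = 4     # number of expert feedforwards
)

x = torch.randn(1, 1024, 512)

# boolean key padding mask: True = real token, False = padding.
# per the patch, padded positions are filled with a large negative value
# in the dispatch/combine logits, so they receive and contribute
# negligible weight after the softmaxes.
mask = torch.ones(1, 1024, dtype = torch.bool)
mask[:, 700:] = False

out = moe(x, mask = mask)   # (1, 1024, 512)
```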