
Commit

account for key-padding mask
lucidrains committed Aug 14, 2023
1 parent a505e50 commit c32ed28
Showing 2 changed files with 10 additions and 2 deletions.
setup.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'soft-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.3',
+  version = '0.0.4',
   license='MIT',
   description = 'Soft MoE - Pytorch',
   author = 'Phil Wang',
soft_moe_pytorch/soft_moe.py (10 changes: 9 additions & 1 deletion)

@@ -91,7 +91,7 @@ def __init__(
             expert_klass(dim = dim, mult = expert_mult, dropout = dropout) for _ in range(num_experts)
         ])
 
-    def forward(self, x):
+    def forward(self, x, mask = None):
         """
         einstein notation
         b - batch
@@ -114,6 +114,14 @@ def forward(self, x):
 
         logits = einsum('b n d, e s d -> b n e s', x, slot_embeds)
 
+        # account for key padding mask
+
+        if exists(mask):
+            mask = rearrange(mask, 'b n -> b n 1 1')
+            logits = logits.masked_fill(~mask, -torch.finfo(logits.dtype).max)
+
         # get dispatch and combine weights (softmax across right dimensions)
 
         dispatch_weights = logits.softmax(dim = 1)
 
         combine_weights = rearrange(logits, 'b n e s -> b n (e s)')
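For context, a minimal usage sketch of the new mask argument (not part of the commit itself). It assumes the SoftMoE module exported by this package, with constructor arguments written as in the repository README; treat the constructor call as illustrative, since the exact signature may differ between versions. From the diff, the mask is a boolean tensor of shape (batch, seq) where True marks real tokens and False marks padding.

import torch
from soft_moe_pytorch import SoftMoE

# constructor arguments are assumed for illustration; check the README for the exact signature
moe = SoftMoE(
    dim = 512,
    seq_len = 1024,
    num_experts = 4
)

x = torch.randn(2, 1024, 512)

# key-padding mask: True marks real tokens, False marks padded positions
mask = torch.ones(2, 1024, dtype = torch.bool)
mask[:, 768:] = False   # pretend the tail of each sequence is padding

# padded positions get a large negative logit before the softmaxes,
# so they receive (near) zero dispatch weight
out = moe(x, mask = mask)   # (2, 1024, 512)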
