Commit 81e01ef

More typo fixes

tridao committed Jul 10, 2024
1 parent 72e27c6 commit 81e01ef
Showing 2 changed files with 16 additions and 2 deletions.
14 changes: 12 additions & 2 deletions flash_attn/flash_attn_interface.py
@@ -78,6 +78,7 @@ def _flash_attn_varlen_forward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     return_softmax,
     block_table,
@@ -102,6 +102,7 @@ def _flash_attn_varlen_forward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         return_softmax,
         None,
     )
@@ -300,6 +302,7 @@ def forward(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -318,6 +321,7 @@ def forward(
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
@@ -328,6 +332,7 @@ def forward(
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -355,12 +360,13 @@ def backward(ctx, dout, *args):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None, None


 class FlashAttnKVPackedFunc(torch.autograd.Function):
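Note on the return-tuple changes in this file: torch.autograd.Function.backward must return exactly one value per argument that forward received, so threading the new softcap argument through forward forces one extra None in each backward return (softcap is a plain float and receives no gradient). A minimal sketch of the pattern, using a hypothetical ScaleFunc that is not part of this repo:

import torch

class ScaleFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        # Non-tensor hyperparameters are stashed on ctx directly,
        # as the diff above does with ctx.softcap; save_for_backward
        # is reserved for tensors.
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, dout):
        # One return value per forward argument (x, scale):
        # a gradient for x, None for the non-differentiable float.
        return dout * ctx.scale, None

x = torch.randn(4, requires_grad=True)
ScaleFunc.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])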
@@ -373,6 +379,7 @@ def forward(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -387,6 +394,7 @@ def forward(
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
         )
@@ -395,6 +403,7 @@ def forward(
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -419,13 +428,14 @@ def backward(ctx, dout, *args):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None


 class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
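For context on what is being threaded through: softcap in FlashAttention applies tanh capping to the pre-softmax attention scores, bounding them to (-softcap, softcap), with 0.0 meaning disabled. A pure-PyTorch sketch of that math, as an illustration of the semantics rather than the kernel's actual implementation:

import torch

def softcap_scores(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Tanh soft-capping: near-identity for small scores, saturating
    # smoothly at +/- softcap for large ones. 0.0 means "off".
    if softcap == 0.0:
        return scores
    return softcap * torch.tanh(scores / softcap)

# Illustrative use in a naive attention reference:
q = torch.randn(1, 2, 8, 16)  # (batch, nheads, seqlen, headdim)
k = torch.randn(1, 2, 8, 16)
scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / 16**0.5
probs = torch.softmax(softcap_scores(scores, softcap=30.0), dim=-1)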
4 changes: 4 additions & 0 deletions tests/test_flash_attn.py
@@ -303,6 +303,7 @@ def attention_kvpacked_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):
@@ -318,6 +319,7 @@ def attention_kvpacked_ref(
         upcast=upcast,
         causal=causal,
         window_size=window_size,
+        softcap=softcap,
         reorder_ops=reorder_ops,
     )

@@ -330,6 +332,7 @@ def attention_qkvpacked_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):
@@ -345,6 +348,7 @@ def attention_qkvpacked_ref(
         upcast=upcast,
         causal=causal,
         window_size=window_size,
+        softcap=softcap,
         reorder_ops=reorder_ops,
     )

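A usage sketch of the updated test helper; the packed QKV layout (batch, seqlen, 3, nheads, headdim) and the (output, attention) return convention follow the surrounding helpers in this file, but treat both as assumptions here:

import torch
from tests.test_flash_attn import attention_qkvpacked_ref

qkv = torch.randn(2, 128, 3, 4, 64)  # (batch, seqlen, 3, nheads, headdim)
out, attn = attention_qkvpacked_ref(qkv, causal=True, softcap=30.0)
print(out.shape)  # torch.Size([2, 128, 4, 64])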
