Skip to content

[Pytorch] CP + THD + chunked attention support. #1887

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tests/pytorch/fused_attn/run_fused_attn_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def run_dpa_with_cp(
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
chunk_size=config.chunk_size,
)
core_attn = core_attn.cuda()

Expand Down Expand Up @@ -284,6 +285,7 @@ def run_dpa_with_cp(
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
)

if fp8_mha:
dout_fp8_ = dout_quantizer(dout_)
out_.backward(dout_fp8_)
Expand Down Expand Up @@ -401,8 +403,12 @@ def _error(a, b):
_error(a[0], b[0])
_error(a[1], b[1])
elif qkv_format == "thd":
i = 0
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a, b)
str_names = ["out_", "dq_", "dk_", "dv_"]
print(f"{str_names[i]} passed on rank {rank}")
i += 1
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"

Expand Down
2 changes: 2 additions & 0 deletions tests/pytorch/fused_attn/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
window_size: Tuple[int, int] = (-1, -1),
total_requests: int = None,
max_ctx_len: int = None,
chunk_size: int = None,
):
self.batch_size = batch_size
self.num_heads = num_heads
Expand All @@ -110,6 +111,7 @@ def __init__(
self.window_size = window_size
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.chunk_size = chunk_size


@contextmanager
Expand Down
8 changes: 8 additions & 0 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
"cp_1_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # MHA
"cp_1_4": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", chunk_size=1024
), # MHA with chunks
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(
Expand Down Expand Up @@ -100,6 +103,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
"cp_1_4": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_1_5": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", chunk_size=1024
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
Expand Down Expand Up @@ -144,6 +150,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format != "thd" and config.chunk_size is not None:
pytest.skip("Only THD format supports chunking!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if qkv_format == "thd" and "a2a" in cp_comm_type:
Expand Down
Loading