From eea9a09e01526e415fb20f0af239bd7df3fbd93b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 6 Jun 2025 11:31:03 +0000 Subject: [PATCH 1/5] code drop Signed-off-by: Pawel Gadzinski --- tests/pytorch/fused_attn/test_fused_attn.py | 3 +- .../fused_attn/test_fused_attn_with_cp.py | 1 + .../dot_product_attention/backends.py | 4 + .../dot_product_attention/context_parallel.py | 34 ++- .../dot_product_attention.py | 15 ++ .../attention/dot_product_attention/utils.py | 203 ++++++++++++++++++ 6 files changed, 254 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 6ce8637bc7..807e7d8c48 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -90,6 +90,7 @@ def __init__( window_size: Tuple[int, int] = (-1, -1), total_requests: int = None, max_ctx_len: int = None, + chunk_size: int = None, ): self.batch_size = batch_size self.num_heads = num_heads @@ -110,7 +111,7 @@ def __init__( self.window_size = window_size self.total_requests = total_requests self.max_ctx_len = max_ctx_len - + self.chunk_size = chunk_size @contextmanager def logging_context(highest_level=logging.WARNING): diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index b17c85327c..249f465fd0 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -24,6 +24,7 @@ "cp_1_3": ModelConfig( 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) ), # MHA + "cp_1_4": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", chunk_size=1024), # MHA with chunks "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA "cp_2_2": ModelConfig( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 9feef64210..48bc352d64 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -465,6 +465,7 @@ def forward( max_seqlen_kv: Optional[int] = None, attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + chunk_size: Optional[int] = None, alibi_slopes: Optional[torch.Tensor] = None, cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, @@ -677,6 +678,7 @@ def forward( attn_mask_type=attn_mask_type, deterministic=self.deterministic, window_size=window_size, + chunk_size=chunk_size, quantizers=quantizers, pad_between_seqs=False, use_flash_attn_3=use_flash_attn_3, @@ -1426,6 +1428,7 @@ def forward( attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, window_size: Optional[Tuple[int, int]] = None, + chunk_size: Optional[int] = None, fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, @@ -1578,6 +1581,7 @@ def forward( cp_stream, cp_comm_type, softmax_scale=self.softmax_scale, + chunk_size=chunk_size, qkv_format=qkv_format, attn_mask_type=attn_mask_type, attn_bias_type=core_attention_bias_type, diff --git 
a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index f9a5d02496..e6127e54f5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -444,6 +444,7 @@ def forward( cu_seqlens_kv_padded, dropout_p, softmax_scale, + chunk_size, qkv_format, attn_mask_type, attn_bias_type, @@ -722,6 +723,8 @@ def forward( else: cu_seqlens_q_per_step[i] = cu_seqlens_q cu_seqlens_kv_per_step[i] = cu_seqlens_kv + cu_seqlens_q_padded_per_step = cu_seqlens_q_padded + cu_seqlens_kv_padded_per_step = cu_seqlens_kv_padded if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) @@ -738,6 +741,12 @@ def forward( ) elif qkv_format == "thd": q_inputs[i % 2] = q + if chunk_size is not None: + cu_seqlens_q_per_step[i], cu_seqlens_q_padded_per_step = dpa_utils.thd_chunkify( + cu_seqlens_q, cu_seqlens_q_padded, None, chunk_size, True, rank, cp_size) + cu_seqlens_kv_per_step[i], cu_seqlens_kv_padded_per_step = dpa_utils.thd_chunkify( + cu_seqlens_kv, cu_seqlens_kv_padded, None, chunk_size, True, rank, cp_size) + if use_fused_attention: if attn_bias is not None: idx = (rank - i) % cp_size @@ -791,8 +800,8 @@ def forward( attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded_per_step, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_per_step, **fp8_meta_kwargs, ) if fp8: @@ -855,6 +864,9 @@ def forward( else: cu_seqlens_q_per_step[i] = cu_seqlens_q cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half + + cu_seqlens_q_padded_per_step = cu_seqlens_q_padded + cu_seqlens_kv_padded_per_step = cu_seqlens_kv_padded if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) @@ -871,6 +883,16 @@ def forward( kv_inputs[i % 2] = tex.thd_read_half_tensor( kv_inputs[i % 2], cu_seqlens_kv_padded, 0 ) + + + if chunk_size is not None: + assert q_inputs[i % 2].shape[1] == 2 * kv_inputs[i % 2].shape[1], \ + "THD+chunking is not supported for cross attention - initial q length should be the same as initial kv length" + cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], cu_seqlens_q_padded_per_step, cu_seqlens_kv_padded_per_step =\ + dpa_utils.thd_seq_tweak_below_diagonal( + cu_seqlens_q, cu_seqlens_q_padded, chunk_size, rank, cp_size + ) + if use_fused_attention: kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() if attn_bias is not None: @@ -918,11 +940,11 @@ def forward( attn_mask_type="padding" if padding else "no_mask", attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded_per_step, cu_seqlens_kv_padded=( None - if cu_seqlens_kv_padded is None - else cu_seqlens_kv_padded // 2 + if cu_seqlens_q_padded_per_step is None + else cu_seqlens_q_padded_per_step // 2 ), **fp8_meta_kwargs, ) @@ -3477,6 +3499,7 @@ def attn_forward_func_with_cp( deterministic=False, use_fused_attention=False, window_size=None, + chunk_size=None, fp8=False, fp8_meta=None, quantizers=None, @@ -3597,6 +3620,7 @@ def attn_forward_func_with_cp( cu_seqlens_kv_padded, dropout_p, softmax_scale, + chunk_size, qkv_format, attn_mask_type, attn_bias_type, diff --git 
a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 7d50b9fa54..835d4d1459 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -217,6 +217,7 @@ def __init__( tp_group: Optional[dist_group_type] = None, layer_number: Optional[int] = None, attention_type: str = "self", + chunk_size: Optional[int] = None, cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, @@ -264,6 +265,9 @@ def __init__( num_attention_heads % self.num_gqa_groups == 0 ), "The number of attention heads must be divisible by the number of GQA groups!" + assert chunk_size is None or qkv_format == "thd", "Chunk size is only supported for thd format" + self.chunk_size = chunk_size + self.rng_states_tracker = None if sequence_parallel or get_rng_state_tracker is None: attention_dropout_ctx = nullcontext @@ -747,6 +751,15 @@ def forward( assert ( cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32 ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" + + if self.chunk_size is not None and self.cp_group is not None: + # todo: check if this condition is correct + cu_seqlens_q, cu_seqlens_q_padded = dpa_utils.thd_chunkify( + cu_seqlens_q, cu_seqlens_q_padded, 0, self.chunk_size) + cu_seqlens_kv, cu_seqlens_kv_padded = dpa_utils.thd_chunkify( + cu_seqlens_kv, cu_seqlens_kv_padded, 0, self.chunk_size) + + batch_size = len(cu_seqlens_q) - 1 if max_seqlen_q is None: if cu_seqlens_q_padded is not None: @@ -760,6 +773,7 @@ def forward( else: seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) + # update KV cache and retrieve saved tokens from cache for inference if inference_params is not None: @@ -947,6 +961,7 @@ def forward( head_dim_v=head_dim_v, attn_mask_type=attn_mask_type, window_size=window_size, + chunk_size=self.chunk_size, alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None, core_attention_bias_type=core_attention_bias_type, core_attention_bias_shape=core_attention_bias_shape, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 82fc04a69a..bb2d321370 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -190,6 +190,8 @@ class AttentionParams: `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} window_size: Tuple[int, int], default = None Sliding window attention size. + chunk_size: int, default = None + Chunk size for context parallelism. alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None` Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`. 
core_attention_bias_type: str, default = `no_bias` @@ -229,6 +231,7 @@ class AttentionParams: head_dim_v: int = 64 attn_mask_type: str = "no_mask" window_size: Union[Tuple[int, int], None] = None + chunk_size: int = None alibi_slopes_shape: Union[torch.Size, List, None] = None core_attention_bias_type: str = "no_bias" core_attention_bias_shape: str = "1hss" @@ -1793,3 +1796,203 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): ) return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer + + +def thd_chunkify( + cu_seqlens: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + start_idx: Optional[torch.Tensor] = None, + chunk_size: int = None, + cp_load_balance: bool = False, + cp_rank: Optional[int] = None, + cp_size: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Chunkify the cu_seqlens tensor. + Returns new cu_seqlens, cu_seqlens_padded tensors + + First and last chunks in every sequence can be not full. + """ + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + pad_seq_lens = cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + pad_lens = pad_seq_lens - seq_lens + new_cu_seqlens = [] + new_cu_seqlens_padded = [] + + + new_seq_lens = [] + new_seq_lens_padded = [] + for i in range(cu_seqlens.size(0)): + seq_len = seq_lens[i] + new_seq_len = 0 + pad_len = pad_lens[i] + if start_idx is None: + start_id = seq_len // 2 * cp_rank + else: + start_id = start_idx[i] + + # first chunk + first_chunk_lenght = (chunk_size - start_id) % chunk_size + + if cp_load_balance: + first_chunk_lenght = min(first_chunk_lenght, seq_len // 2) + first_chunk_lenght = min(first_chunk_lenght, seq_len) + + new_seq_lens.append(first_chunk_lenght) + new_seq_lens_padded.append(first_chunk_lenght) + new_seq_len += first_chunk_lenght + + if new_seq_len == seq_len: + continue + + while True: + if cp_load_balance: + if new_seq_len + chunk_size > seq_len // 2: + break + else: + if new_seq_len + chunk_size > seq_len: + break + + new_seq_lens.append(chunk_size) + new_seq_lens_padded.append(chunk_size) + new_seq_len += chunk_size + + if cp_load_balance: + last_token_first_part_id = start_id + seq_len // 2 - 1 + total_seq_size = seq_len * cp_size // 2 + first_token_second_part_id = total_seq_size - last_token_first_part_id # is symmetrical to last token of first part with respect to the middle of the sequence + + last_chunk_of_first_part_id = last_token_first_part_id // chunk_size + first_chunk_of_second_part_id = first_token_second_part_id // chunk_size + + + extend_last_chunk = last_chunk_of_first_part_id == first_chunk_of_second_part_id + + # first chunk of second part + first_chunk_lenght = (chunk_size - first_token_second_part_id) % chunk_size + + first_chunk_lenght = min(first_chunk_lenght, seq_len // 2) + + if extend_last_chunk: + new_seq_lens[-1] += first_chunk_lenght + new_seq_lens_padded[-1] += first_chunk_lenght + else: + new_seq_lens.append(first_chunk_lenght) + new_seq_lens_padded.append(first_chunk_lenght) + new_seq_len += first_chunk_lenght + + while True: + if new_seq_len + chunk_size > seq_len // 2: + break + + new_seq_lens.append(chunk_size) + new_seq_lens_padded.append(chunk_size) + new_seq_len += chunk_size + + last_chunk_lenght = seq_len - new_seq_len + new_seq_lens.append(last_chunk_lenght) + new_seq_len += last_chunk_lenght + # add last_chunk + padding to new_seq_lens_padded + new_seq_lens_padded.append(last_chunk_lenght + pad_len) + assert new_seq_len == seq_len + + new_cu_seqlens = torch.cumsum(torch.tensor(new_seq_lens), dim=0) + 
new_cu_seqlens_padded = torch.cumsum(torch.tensor(new_seq_lens_padded), dim=0) + + return new_cu_seqlens, new_cu_seqlens_padded + + + +@jit_fuser +def thd_seq_tweak_below_diagonal( + cu_seqlens: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + cp_rank: int, + cp_size: int, + chunk_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + pad_lens = ( + cu_seqlens_padded[1:] + - cu_seqlens_padded[:-1] + - seq_lens + ) + + seq_plus_pad = seq_lens + pad_lens + half_seq_lens = seq_plus_pad // 2 + + last_kv_first_half_id = cp_rank * (half_seq_lens + 1) - 1 + last_kv_chunk_pos = half_seq_lens - (half_seq_lens % chunk_size) + last_kv_chunk_id = last_kv_first_half_id // chunk_size + + first_half_q_first_chunk_id = (cp_rank * seq_plus_pad) // chunk_size + first_half_q_second_chunk_id = ( + (2 * cp_size - 1 - cp_rank) * half_seq_lens + ) // chunk_size + + first_half_q_first_chunk_last_pos = torch.min( + half_seq_lens, (first_half_q_first_chunk_id * chunk_size) % half_seq_lens) + second_half_q_first_chunk_last_pos = torch.min( + seq_lens, half_seq_lens + (first_half_q_second_chunk_id * chunk_size) % half_seq_lens) + + take_0 = last_kv_chunk_id != first_half_q_first_chunk_last_pos + take_first_half_q = last_kv_chunk_id != second_half_q_first_chunk_last_pos + take_second_half_q = (~take_0) & (~take_first_half_q) + + chunk_end_q_seqs = torch.zeros_like(seq_lens) + chunk_end_q_seqs[take_0] = 0 + chunk_end_q_seqs[take_first_half_q] = first_half_q_first_chunk_last_pos + chunk_end_q_seqs[take_second_half_q] = second_half_q_first_chunk_last_pos + + chunk_start_kv_seqs = torch.zeros_like(seq_lens) + chunk_start_kv_seqs[~take_0] = last_kv_chunk_pos[~take_0] + + # Helper aliases + zeros_like = torch.zeros_like + minimum = torch.minimum + + # 1. Build per-sequence chunk sizes + # Q chunks : [0, first_part, second_part] + # KV chunks : [first_part, second_part, 0 third_part] + q_0 = zeros_like(seq_lens) + q_1 = chunk_end_q_seqs + q_2 = seq_plus_pad - chunk_end_q_seqs + q_3 = zeros_like(seq_lens) + q_chunks = torch.stack((q_0, q_1, q_2, q_3), dim=1) + + kv_0 = chunk_start_kv_seqs + kv_1 = zeros_like(seq_lens) + kv_2 = half_seq_lens - chunk_start_kv_seqs + kv_3 = half_seq_lens + kv_chunks = torch.stack((kv_0, kv_1, kv_2, kv_3), dim=1) + + # 2. 
Padded variants – keep padding only in the final chunk + q_chunks_padded = torch.stack( + (q_0, q_1, q_2, seq_plus_pad - q_0 - q_1 - q_2), dim=1 + ) + kv_chunks_padded = torch.stack( + (kv_0, kv_1, kv_2, seq_plus_pad - kv_0 - kv_1 - kv_2), dim=1 + ) + + cu_seqlens_q_per_step = q_chunks.flatten().cumsum(0) + cu_seqlens_kv_per_step = kv_chunks.flatten().cumsum(0) + cu_seqlens_q_padded_per_step = q_chunks_padded.flatten().cumsum(0) + cu_seqlens_kv_padded_per_step = kv_chunks_padded.flatten().cumsum(0) + + return cu_seqlens_q_per_step, cu_seqlens_kv_per_step, cu_seqlens_q_padded_per_step, cu_seqlens_kv_padded_per_step + + +@jit_fuser +def thd_seq_tweak_above_diagonal( + cu_seqlens: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + cp_rank: int, + cp_size: int, + chunk_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + pad_lens = ( + cu_seqlens_padded[1:] + - cu_seqlens_padded[:-1] + - seq_lens \ No newline at end of file From ff8f27c09f1e1d7e9a36d60fb5dcbb85e1dd953a Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 16 Jun 2025 14:04:49 +0200 Subject: [PATCH 2/5] code drop Signed-off-by: Pawel Gadzinski --- .../fused_attn/run_fused_attn_with_cp.py | 10 +- .../fused_attn/test_fused_attn_with_cp.py | 3 + .../common/fused_attn/context_parallel.cu | 592 ++++++++++++++++++ .../include/transformer_engine/fused_attn.h | 73 +++ .../dot_product_attention/context_parallel.py | 95 +-- .../dot_product_attention.py | 11 +- .../attention/dot_product_attention/utils.py | 326 ++++++---- .../pytorch/cpp_extensions/fused_attn.py | 1 + transformer_engine/pytorch/csrc/common.cpp | 5 +- .../pytorch/csrc/extensions/attention.cpp | 145 +++++ 10 files changed, 1097 insertions(+), 164 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index ad3bc32079..918f878c03 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -41,7 +41,7 @@ def run_dpa_with_cp( if kernel_backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" config = model_configs_fused_attn[model] - + assert config.attn_mask_type in [ "causal", "no_mask", @@ -95,6 +95,7 @@ def run_dpa_with_cp( qkv_format=qkv_format, attn_mask_type=config.attn_mask_type, window_size=config.window_size, + chunk_size=config.chunk_size, ) core_attn = core_attn.cuda() @@ -272,6 +273,7 @@ def run_dpa_with_cp( else: fp8_context = nullcontext() + with fp8_context: out_ = core_attn( q_, @@ -295,6 +297,8 @@ def run_dpa_with_cp( assert isinstance(out_, Float8Tensor) out = out.dequantize() out_ = out_.dequantize() + + for x in [out_, q_.grad, k_.grad, v_.grad]: assert torch.all(~torch.isnan(x)) @@ -401,8 +405,12 @@ def _error(a, b): _error(a[0], b[0]) _error(a[1], b[1]) elif qkv_format == "thd": + i = 0 for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): _error(a, b) + str_names = ["out_", "dq_", "dk_", "dv_"] + print(f"{str_names[i]} passed on rank {rank}") + i += 1 else: assert False, f"{qkv_format} is an unsupported qkv_format!" 
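For context on how the new `chunk_size` knob is driven end to end: the CP test above simply forwards `config.chunk_size` into `DotProductAttention`, which (per the assert added in PATCH 1/5) only accepts it together with `qkv_format="thd"`. Below is a minimal single-GPU sketch of that wiring; the tensor shapes, dtypes and cu_seqlens values are illustrative assumptions and not part of this patch.

    import torch
    from transformer_engine.pytorch import DotProductAttention

    # chunk_size is only valid with the packed THD layout; DotProductAttention
    # asserts `chunk_size is None or qkv_format == "thd"` in __init__.
    core_attn = DotProductAttention(
        num_attention_heads=12,
        kv_channels=128,
        attn_mask_type="causal",
        qkv_format="thd",
        chunk_size=1024,  # split every sequence into 1024-token chunks
    ).cuda()

    # Two packed sequences of 4096 tokens each (t = 8192 tokens in total).
    t, h, d = 8192, 12, 128
    q, k, v = (torch.randn(t, h, d, dtype=torch.bfloat16, device="cuda") for _ in range(3))
    cu_seqlens = torch.tensor([0, 4096, 8192], dtype=torch.int32, device="cuda")

    out = core_attn(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens)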
diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 249f465fd0..e02b8339a0 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -101,6 +101,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_4": ModelConfig( 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) ), # MHA + "cp_1_5": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", chunk_size=1024), # MHA "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA @@ -133,6 +134,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha config = model_configs_fused_attn[model] if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": pytest.skip("THD format does not support post_scale_bias yet!") + if qkv_format != "thd" and config.chunk_size is not None: + pytest.skip("Only THD format supports chunking!") if qkv_format == "thd" and cp_comm_type == "all_gather": pytest.skip("CP implementation with KV all-gather does not support THD format yet!") if qkv_format == "thd" and "a2a" in cp_comm_type: diff --git a/transformer_engine/common/fused_attn/context_parallel.cu b/transformer_engine/common/fused_attn/context_parallel.cu index e340242c63..4c164cbb48 100644 --- a/transformer_engine/common/fused_attn/context_parallel.cu +++ b/transformer_engine/common/fused_attn/context_parallel.cu @@ -302,6 +302,373 @@ __global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, in } } + +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks. + **************************************************************************************************/ + +__global__ void thd_chunkify_kernel(const int* __restrict__ d_cu_seqlens, + const int* __restrict__ d_cu_seqlens_padded, + int* __restrict__ d_out_cu_seqlens, + int* __restrict__ d_out_cu_seqlens_padded, + int batch, // = len(cu_seqlens)-1 + int output_len, + int chunk_size) +{ + const int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= output_len) return; + if (i == 0) + { + d_out_cu_seqlens[0] = 0; + d_out_cu_seqlens_padded[0] = 0; + return; + } + + int pos_id = 0; + int seq_len = 0; + int pad_len = 0; + int seq_start = d_cu_seqlens[batch]; + int pad_start = d_cu_seqlens_padded[batch]; + int cur_start = 0; + int cur_last = -1; + bool found = false; + + for (int j = 0; j < batch; ++j) + { + cur_start = (j == 0) ? 1 : cur_last + 1; + + int d_cu_seqlens_padded_j = d_cu_seqlens_padded[j]; + int d_cu_seqlens_padded_j_1 = d_cu_seqlens_padded[j + 1]; + int d_cu_seqlens_j = d_cu_seqlens[j]; + int d_cu_seqlens_j_1 = d_cu_seqlens[j + 1]; + + int num_chunks = (d_cu_seqlens_padded_j_1 - d_cu_seqlens_padded_j + (chunk_size - 1)) / chunk_size; + cur_last = cur_start + num_chunks - 1; + + bool match = (i >= cur_start) && (i <= cur_last); + pos_id = match ? (i - cur_start) : pos_id; + seq_len = match ? (d_cu_seqlens_j_1 - d_cu_seqlens_j) : seq_len; + pad_len = match ? (d_cu_seqlens_padded_j_1 - d_cu_seqlens_padded_j) : pad_len; + seq_start = match ? d_cu_seqlens_j : seq_start; + pad_start = match ? 
d_cu_seqlens_padded_j : pad_start; + found = match || found; + } + + if (!found) + { + d_out_cu_seqlens[i] = d_cu_seqlens[batch]; + d_out_cu_seqlens_padded[i] = d_cu_seqlens_padded[batch]; + } + else{ + + int32_t out_seq = ((pos_id > (seq_len / chunk_size)) ? seq_len : chunk_size * pos_id) + seq_start; + int32_t out_pad = ((pos_id > (pad_len / chunk_size)) ? pad_len : chunk_size * pos_id) + pad_start; + + d_out_cu_seqlens[i] = out_seq; + d_out_cu_seqlens_padded[i] = out_pad; + } +} + + +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal. + **************************************************************************************************/ + + +__global__ void thd_chunkify_p2p_kernel( + const int* __restrict__ d_cu_seqlens, + const int* __restrict__ d_cu_seqlens_padded, + int* __restrict__ d_out_cu_seqlens, + int* __restrict__ d_out_cu_seqlens_padded, + int batch, + int output_len, + int chunk_size, + int cp_rank, + int cp_size) +{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= output_len) return; + if (i == 0) + { + d_out_cu_seqlens[0] = 0; + d_out_cu_seqlens_padded[0] = 0; + return; + } + + int pos_id = 0; + int seq_len = 0; + int total_seq_len = 0; + int seq_start_offset = d_cu_seqlens[batch]; + int pad_start_offset = d_cu_seqlens_padded[batch]; + int cur_start = 1; + int cur_last = 0; + + for (int j = 0; j < batch; ++j) { + cur_start = (j == 0) ? 1 : cur_last + 1; + int padded_len = d_cu_seqlens_padded[j + 1] - d_cu_seqlens_padded[j]; + int num_chunks = (padded_len + chunk_size - 1) / chunk_size; + cur_last = cur_start + num_chunks + 3; + if (i >= cur_start && (i <= cur_last || j == batch - 1)) { + pos_id = i - cur_start; + seq_len = d_cu_seqlens[j + 1] - d_cu_seqlens[j]; + total_seq_len = padded_len; + seq_start_offset = d_cu_seqlens[j]; + pad_start_offset = d_cu_seqlens_padded[j]; + break; + } + } + + + if (total_seq_len == 0) { + d_out_cu_seqlens[i] = d_cu_seqlens[batch]; + d_out_cu_seqlens_padded[i] = d_cu_seqlens_padded[batch]; + return; + } + + int middle = total_seq_len / 2; + + int start_id_1 = (total_seq_len * cp_rank) / 2; + int temp = (chunk_size - start_id_1 - 1) % chunk_size; + int first_chunk_size_1 = ((temp < 0) ? temp + chunk_size : temp) + 1; + first_chunk_size_1 = (first_chunk_size_1 >= middle) ? 0 : first_chunk_size_1; + int num_chunks_1 = (middle - first_chunk_size_1 - 1) / chunk_size; + int last_chunk_size_1 = middle - num_chunks_1 * chunk_size - first_chunk_size_1; + int last_token_id_1 = start_id_1 + middle - 1; + int last_chunk_id_1 = last_token_id_1 / chunk_size; + + int start_id_2 = total_seq_len * cp_size - last_token_id_1 - 1; + int first_chunk_id_2 = start_id_2 / chunk_size; + int temp2 = (chunk_size - start_id_2 - 1) % chunk_size; + int first_chunk_size_2 = ((temp2 < 0) ? temp2 + chunk_size : temp2) + 1; + int num_chunks_2 = (middle - first_chunk_size_2) / chunk_size; + + bool merge_chunks = (last_chunk_id_1 == first_chunk_id_2); + int middle_chunk_1 = merge_chunks ? (last_chunk_size_1 + first_chunk_size_2) : last_chunk_size_1; + int middle_chunk_2 = merge_chunks ? 
0 : first_chunk_size_2; + + int out_pad_len = 0; + if (pos_id == 0) { + out_pad_len = first_chunk_size_1; + } else if (pos_id <= num_chunks_1 + 1) { + out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size; + } else if (pos_id == num_chunks_1 + 2) { + out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size + middle_chunk_1; + } else if (pos_id == num_chunks_1 + 3) { + out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size + middle_chunk_1 + middle_chunk_2; + } else if (pos_id <= num_chunks_1 + 3 + num_chunks_2) { + out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size + middle_chunk_1 + middle_chunk_2 + num_chunks_2 * chunk_size; + } else { + out_pad_len = total_seq_len; + } + + int out_seq_len = (out_pad_len < seq_len) ? out_pad_len : seq_len; + + d_out_cu_seqlens[i] = seq_start_offset + out_seq_len; + d_out_cu_seqlens_padded[i] = pad_start_offset + out_pad_len; +} + + +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part below diagonal. + **************************************************************************************************/ + +__global__ void thd_seq_tweak_below_diag_kernel( + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ cu_seqlens_kv_halfs, + const int* __restrict__ cu_seqlens_padded, + int* __restrict__ q_chunks, + int* __restrict__ kv_chunks, + int* __restrict__ q_pads, + int* __restrict__ kv_pads, + int cp_rank_q, + int cp_rank_kv, + int cp_size, + int chunk_size, + int batch) +{ + const int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= batch) return; + + // ───────── prefix-sum diffs ──────────────────────────────── + const int32_t q_start = cu_seqlens_q[i]; + const int32_t q_end = cu_seqlens_q[i + 1]; + const int32_t kv_start = cu_seqlens_kv_halfs[i]; + const int32_t kv_end = cu_seqlens_kv_halfs[i + 1]; + const int32_t pad_start = cu_seqlens_padded[i]; + const int32_t pad_end = cu_seqlens_padded[i + 1]; + + const int32_t seq_len_q = q_end - q_start; + const int32_t seq_len_kv_half = kv_end - kv_start; + const int32_t seq_plus_pad = pad_end - pad_start; + const int32_t half_seq_len = seq_plus_pad >> 1; + const int32_t pad_len_q = seq_plus_pad - seq_len_q; + const int32_t pad_len_kv = half_seq_len - seq_len_kv_half; + + // ───────── below-diagonal logic ─────────────────────────── + const int32_t last_kv_id = (cp_rank_kv + 1) * half_seq_len - 1; + const int32_t last_kv_chunk_id = last_kv_id / chunk_size; + const int32_t last_kv_chunk_len = min( + half_seq_len - + (last_kv_chunk_id * chunk_size - cp_rank_kv * half_seq_len), + half_seq_len); + + const int32_t first_half_id = (cp_rank_q * half_seq_len) / chunk_size; + const int32_t second_half_id = ((2 * cp_size - cp_rank_q - 1) * half_seq_len) + / chunk_size; + + const int32_t first_half_len = min( + half_seq_len, + (first_half_id + 1) * chunk_size - cp_rank_q * half_seq_len); + + const int32_t second_half_len = min( + seq_len_q, + half_seq_len + + (second_half_id + 1) * chunk_size - + (2 * cp_size - 1 - cp_rank_q) * half_seq_len); + + const bool take_nothing = (last_kv_chunk_id != first_half_id); + const bool take_first_half_q = (!take_nothing) && + (last_kv_chunk_id != second_half_id); + const bool take_second_half_q = (!take_nothing) && (!take_first_half_q); + + int32_t q_seq_len = 0; + if (take_first_half_q) q_seq_len = first_half_len; + else if (take_second_half_q) q_seq_len = second_half_len; + + int32_t kv_seq_len = half_seq_len; + 
if (!take_nothing) kv_seq_len = last_kv_chunk_len; + + q_seq_len = min(q_seq_len, max(0, seq_plus_pad - pad_len_q)); + kv_seq_len = max(0, kv_seq_len - pad_len_kv); + + // ───────── flat output (row-major) ───────────────────────── + const int out_base = 3 * i; + + q_chunks[out_base + 0] = q_start; + q_chunks[out_base + 1] = q_start + q_seq_len; + q_chunks[out_base + 2] = q_start + q_seq_len; + + q_pads [out_base + 0] = pad_start; + q_pads [out_base + 1] = pad_start + q_seq_len; + q_pads [out_base + 2] = pad_start + seq_plus_pad; + + kv_chunks[out_base + 0] = kv_start; + kv_chunks[out_base + 1] = kv_start + kv_seq_len; + kv_chunks[out_base + 2] = kv_start + kv_seq_len; + + const int32_t kv_pad_base = pad_start >> 1; + kv_pads[out_base + 0] = kv_pad_base + (half_seq_len - kv_seq_len - pad_len_kv); + kv_pads[out_base + 1] = kv_pad_base + (half_seq_len - pad_len_kv); + kv_pads[out_base + 2] = kv_pad_base + half_seq_len; +} + + +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal. + **************************************************************************************************/ + + +__global__ void thd_seq_tweak_above_diag_kernel( + const int* __restrict__ cu_seqlens_q_halfs, // len = batch+1 + const int* __restrict__ cu_seqlens_kv, // len = batch+1 + const int* __restrict__ cu_seqlens_padded, // len = batch+1 (full‑len prefix sums) + int* __restrict__ q_chunks, // len = 3·batch (row‑major) + int* __restrict__ kv_chunks, // len = 3·batch + int* __restrict__ q_pads, // len = 3·batch + int* __restrict__ kv_pads, // len = 3·batch + int cp_rank_q, + int cp_rank_kv, + int cp_size, + int chunk_size, + int batch) +{ + const int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= batch) return; + + // ───────── prefix‑sum diffs ─────────────────────────────── + const int32_t q_start = cu_seqlens_q_halfs[i]; + const int32_t q_end = cu_seqlens_q_halfs[i + 1]; + const int32_t kv_start = cu_seqlens_kv[i]; + const int32_t kv_end = cu_seqlens_kv[i + 1]; + const int32_t pad_start = cu_seqlens_padded[i]; + const int32_t pad_end = cu_seqlens_padded[i + 1]; + + const int32_t seq_len_q_half = q_end - q_start; // |Q|/2 (actual) + const int32_t seq_len_kv = kv_end - kv_start; // |KV| (full) + const int32_t seq_plus_pad = pad_end - pad_start; // |KV|+pads (full) + + const int32_t half_seq_len = seq_plus_pad >> 1; // |KV|/2 + pads/2 + + // pad lengths for later clamping + const int32_t pad_len_kv = seq_plus_pad - seq_len_kv; + + // ───────── above‑diagonal core logic ───────────────────── + // 1. Tokens from Q (first half from the *opposite* side) + const int32_t first_q_id = (2 * cp_size - 1 - cp_rank_q) * half_seq_len; + const int32_t first_q_chunk_id = first_q_id / chunk_size; + const int32_t first_q_chunk_len = min( + seq_len_q_half, + (first_q_chunk_id + 1) * chunk_size - first_q_id); + + // 2. 
Tokens from KV (might come from first or second half) + const int32_t first_half_kv_last_el_total_id = ((cp_rank_kv + 1) * half_seq_len) - 1; + const int32_t first_half_kv_last_chunk_id = first_half_kv_last_el_total_id / chunk_size; + + const int32_t second_half_kv_last_el_total_id = ((2 * cp_size - cp_rank_kv) * half_seq_len) - 1; + const int32_t second_half_kv_last_chunk_id = second_half_kv_last_el_total_id / chunk_size; + + // last‑trimmed chunk lengths + const int32_t first_half_kv_last_el_id_in_chunk = first_half_kv_last_el_total_id % chunk_size; + const int32_t first_half_kv_last_chunk_len = min( + seq_plus_pad, + half_seq_len + first_half_kv_last_el_id_in_chunk + 1); + + const int32_t second_half_kv_last_el_id_in_chunk = second_half_kv_last_el_total_id % chunk_size; + const int32_t second_half_kv_last_chunk_len = min( + half_seq_len, + second_half_kv_last_el_id_in_chunk + 1); + + // 3. Decide which half(s) we actually take + const bool take_nothing = (first_q_chunk_id != second_half_kv_last_chunk_id); + const bool take_second_half_kv = (!take_nothing) && (first_q_chunk_id != first_half_kv_last_chunk_id); + const bool take_first_half_kv = (!take_nothing) && (!take_second_half_kv); + + // ───────── resulting subseq lengths ─────────────────────── + int32_t q_seq_len = take_nothing ? 0 : first_q_chunk_len; + + int32_t kv_seq_len = 0; + if (take_second_half_kv) kv_seq_len = second_half_kv_last_chunk_len; + else if (take_first_half_kv) kv_seq_len = first_half_kv_last_chunk_len; + + // clamp against padding + kv_seq_len = max(0, kv_seq_len - pad_len_kv); + + // ───────── flat output (row‑major) ──────────────────────── + const int out_base = 3 * i; + + // Q chunks: [0, sequence, 0] + q_chunks[out_base + 0] = q_start; // beginning of Q half‑sequence + q_chunks[out_base + 1] = q_start + q_seq_len; // after the chunk we keep + q_chunks[out_base + 2] = q_start + q_seq_len; // stays flat afterwards + + // Q pads: [0, 0, garbage] (pads live in the *first* half of padded area) + const int32_t half_pad_start = pad_start >> 1; // start of Q‑related pad area + q_pads[out_base + 0] = half_pad_start; + q_pads[out_base + 1] = half_pad_start + q_seq_len; + q_pads[out_base + 2] = half_pad_start + half_seq_len; // complete half padded length + + // KV chunks: [0, sequence, 0] + kv_chunks[out_base + 0] = kv_start; + kv_chunks[out_base + 1] = kv_start + kv_seq_len; + kv_chunks[out_base + 2] = kv_start + seq_len_kv; + + // KV pads: [garbage, 0, 0] (pads precede KV if from first half) + kv_pads[out_base + 0] = pad_start + (seq_len_kv - kv_seq_len); + kv_pads[out_base + 1] = pad_start + seq_len_kv; + kv_pads[out_base + 2] = pad_start + seq_plus_pad; +} + + /*************************************************************************************************** * Support THD format for Context Parallel: Read the half of a THD tensor **************************************************************************************************/ @@ -669,6 +1036,180 @@ void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int to batch, total_tokens, world_size, rank); } +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal. 
+ **************************************************************************************************/ + +void thd_chunkify( + const Tensor &cu_seqlens, + const Tensor &cu_seqlens_padded, + Tensor &out_cu_seqlens, + Tensor &out_cu_seqlens_padded, + int batch, + int output_len, + int chunk_size, + cudaStream_t stream +) { + using namespace transformer_engine; + NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32); + NVTE_CHECK(cu_seqlens_padded.dtype() == DType::kInt32); + NVTE_CHECK(out_cu_seqlens.dtype() == DType::kInt32); + NVTE_CHECK(out_cu_seqlens_padded.dtype() == DType::kInt32); + + NVTE_CHECK(cu_seqlens.dim() == 1); + NVTE_CHECK(cu_seqlens_padded.dim() == 1); + NVTE_CHECK(out_cu_seqlens.dim() == 1); + NVTE_CHECK(out_cu_seqlens_padded.dim() == 1); + + NVTE_CHECK(cu_seqlens_shape[0] == batch + 1); + NVTE_CHECK(cu_seqlens_padded_shape[0] == batch + 1); + + NVTE_CHECK(out_cu_seqlens_shape[0] == output_len); + NVTE_CHECK(out_cu_seqlens_padded_shape[0] == output_len); + + const unsigned int block = 256; + const unsigned int grid = (output_len + block - 1) / block; + thd_chunkify_kernel<<>>( + reinterpret_cast(cu_seqlens.data.dptr), + reinterpret_cast(cu_seqlens_padded.data.dptr), + reinterpret_cast(out_cu_seqlens.data.dptr), + reinterpret_cast(out_cu_seqlens_padded.data.dptr), batch, output_len, chunk_size); +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal. + **************************************************************************************************/ + +void thd_chunkify_p2p( + const Tensor &cu_seqlens, + const Tensor &cu_seqlens_padded, + Tensor &out_cu_seqlens, + Tensor &out_cu_seqlens_padded, + int batch, + int output_len, + int chunk_size, + cudaStream_t stream +) { + using namespace transformer_engine; + NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32); + NVTE_CHECK(cu_seqlens_padded.dtype() == DType::kInt32); + NVTE_CHECK(out_cu_seqlens.dtype() == DType::kInt32); + NVTE_CHECK(out_cu_seqlens_padded.dtype() == DType::kInt32); + + // This tensors should be one dimensional + NVTE_CHECK(cu_seqlens.dim() == 1); + NVTE_CHECK(cu_seqlens_padded.dim() == 1); + NVTE_CHECK(out_cu_seqlens.dim() == 1); + NVTE_CHECK(out_cu_seqlens_padded.dim() == 1); + + // length of cu_seqlens and cu_seqlens_padded should be batch + 1 + NVTE_CHECK(cu_seqlens_shape[0] == batch + 1); + NVTE_CHECK(cu_seqlens_padded_shape[0] == batch + 1); + + // length of out_cu_seqlens and out_cu_seqlens_padded should be output_len + NVTE_CHECK(out_cu_seqlens_shape[0] == output_len); + NVTE_CHECK(out_cu_seqlens_padded_shape[0] == output_len); + + const unsigned int block = 256; + const unsigned int grid = (output_len + block - 1) / block; + thd_chunkify_p2p_kernel<<>>( + reinterpret_cast(cu_seqlens.data.dptr), + reinterpret_cast(cu_seqlens_padded.data.dptr), + reinterpret_cast(out_cu_seqlens.data.dptr), + reinterpret_cast(out_cu_seqlens_padded.data.dptr), batch, output_len, chunk_size); +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part below diagonal. 
+ **************************************************************************************************/ + +void thd_seq_tweak_below_diag( + const Tensor &cu_seqlens_q, + const Tensor &cu_seqlens_kv_halfs, + Tensor &out_cu_seqlens_q, + Tensor &out_cu_seqlens_kv_halfs, + int batch, + int output_len, + int chunk_size, + cudaStream_t stream +) { + using namespace transformer_engine; + + NVTE_CHECK(cu_seqlens_q.dtype() == DType::kInt32); + NVTE_CHECK(cu_seqlens_kv_halfs.dtype() == DType::kInt32); + NVTE_CHECK(out_cu_seqlens_q.dtype() == DType::kInt32); + NVTE_CHECK(out_cu_seqlens_kv_halfs.dtype() == DType::kInt32); + + // This tensors should be one dimensional + NVTE_CHECK(cu_seqlens_q.dim() == 1); + NVTE_CHECK(cu_seqlens_kv_halfs.dim() == 1); + NVTE_CHECK(out_cu_seqlens_q.dim() == 1); + NVTE_CHECK(out_cu_seqlens_kv_halfs.dim() == 1); + + // length of cu_seqlens_q and cu_seqlens_kv_halfs should be batch + 1 + NVTE_CHECK(cu_seqlens_q_shape[0] == batch + 1); + NVTE_CHECK(cu_seqlens_kv_halfs_shape[0] == batch + 1); + + // length of out_cu_seqlens_q and out_cu_seqlens_kv_halfs should be output_len + NVTE_CHECK(out_cu_seqlens_q_shape[0] == output_len); + NVTE_CHECK(out_cu_seqlens_kv_halfs_shape[0] == output_len); + + const unsigned int block = 256; + const unsigned int grid = (output_len + block - 1) / block; + thd_seq_tweak_below_diag_kernel<<>>( + reinterpret_cast(cu_seqlens_q.data.dptr), + reinterpret_cast(cu_seqlens_kv_halfs.data.dptr), + reinterpret_cast(out_cu_seqlens_q.data.dptr), + reinterpret_cast(out_cu_seqlens_kv_halfs.data.dptr), batch, output_len, chunk_size); +} + + +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal. 
+ **************************************************************************************************/ + + +void thd_seq_tweak_above_diag( + + const Tensor &cu_seqlens_q_halfs, + const Tensor &cu_seqlens_kv_halfs, + Tensor &out_cu_seqlens_q_halfs, + Tensor &out_cu_seqlens_kv_halfs, + int batch, + int output_len, + int chunk_size, + cudaStream_t stream +) { + using namespace transformer_engine; + + NVTE_CHECK(cu_seqlens_q_halfs.dtype() == DType::kInt32); + NVTE_CHECK(cu_seqlens_kv_halfs.dtype() == DType::kInt32); + NVTE_CHECK(out_cu_seqlens_q_halfs.dtype() == DType::kInt32); + NVTE_CHECK(out_cu_seqlens_kv_halfs.dtype() == DType::kInt32); + + // This tensors should be one dimensional + NVTE_CHECK(cu_seqlens_q_halfs.dim() == 1); + NVTE_CHECK(cu_seqlens_kv_halfs.dim() == 1); + NVTE_CHECK(out_cu_seqlens_q_halfs.dim() == 1); + NVTE_CHECK(out_cu_seqlens_kv_halfs.dim() == 1); + + // length of cu_seqlens_q_halfs and cu_seqlens_kv_halfs should be batch + 1 + NVTE_CHECK(cu_seqlens_q_halfs_shape[0] == batch + 1); + NVTE_CHECK(cu_seqlens_kv_halfs_shape[0] == batch + 1); + + // length of out_cu_seqlens_q_halfs and out_cu_seqlens_kv_halfs should be output_len + NVTE_CHECK(out_cu_seqlens_q_halfs_shape[0] == output_len); + NVTE_CHECK(out_cu_seqlens_kv_halfs_shape[0] == output_len); + + const unsigned int block = 256; + const unsigned int grid = (output_len + block - 1) / block; + thd_seq_tweak_above_diag_kernel<<>>( + reinterpret_cast(cu_seqlens_q_halfs.data.dptr), + reinterpret_cast(cu_seqlens_kv_halfs.data.dptr), + reinterpret_cast(out_cu_seqlens_q_halfs.data.dptr), + reinterpret_cast(out_cu_seqlens_kv_halfs.data.dptr), batch, output_len, chunk_size); +} + } // namespace context_parallel } // namespace transformer_engine @@ -741,3 +1282,54 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso *convertNVTETensorCheck(output), total_tokens, world_size, rank, stream); } + +void nvte_cp_thd_chunkify(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded, + NVTETensor &out_cu_seqlens, NVTETensor &out_cu_seqlens_padded, + int batch, int output_len, int chunk_size, cudaStream_t stream) { + NVTE_API_CALL(nvte_thd_chunkify); + using namespace transformer_engine; + + context_parallel::thd_chunkify(*convertNVTETensorCheck(cu_seqlens), + *convertNVTETensorCheck(cu_seqlens_padded), + *convertNVTETensorCheck(out_cu_seqlens), + *convertNVTETensorCheck(out_cu_seqlens_padded), batch, output_len, chunk_size, stream); +} + +void nvte_cp_thd_chunkify_p2p(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded, + NVTETensor &out_cu_seqlens, NVTETensor &out_cu_seqlens_padded, + int batch, int output_len, int chunk_size, cudaStream_t stream) { + NVTE_API_CALL(nvte_thd_chunkify_p2p); + using namespace transformer_engine; + + context_parallel::thd_chunkify_p2p(*convertNVTETensorCheck(cu_seqlens), + *convertNVTETensorCheck(cu_seqlens_padded), + *convertNVTETensorCheck(out_cu_seqlens), + *convertNVTETensorCheck(out_cu_seqlens_padded), batch, output_len, chunk_size, stream); +} + +void nvte_cp_thd_seq_tweak_below_diag(const NVTETensor &cu_seqlens_q, const NVTETensor &cu_seqlens_kv_halfs, + NVTETensor &out_cu_seqlens_q, NVTETensor &out_cu_seqlens_kv_halfs, + int batch, int output_len, int chunk_size, cudaStream_t stream) { + NVTE_API_CALL(nvte_thd_seq_tweak_below_diag); + using namespace transformer_engine; + + context_parallel::thd_seq_tweak_below_diag(*convertNVTETensorCheck(cu_seqlens_q), + *convertNVTETensorCheck(cu_seqlens_kv_halfs), + *convertNVTETensorCheck(out_cu_seqlens_q), 
+ *convertNVTETensorCheck(out_cu_seqlens_kv_halfs), + batch, output_len, chunk_size, stream); +} + +void nvte_cp_thd_seq_tweak_above_diag(const NVTETensor &cu_seqlens_q_halfs, const NVTETensor &cu_seqlens_kv_halfs, + NVTETensor &out_cu_seqlens_q_halfs, NVTETensor &out_cu_seqlens_kv_halfs, + int batch, int output_len, cudaStream_t stream) { + NVTE_API_CALL(nvte_thd_seq_tweak_above_diag); + using namespace transformer_engine; + + context_parallel::thd_seq_tweak_above_diag(*convertNVTETensorCheck(cu_seqlens_q_halfs), + *convertNVTETensorCheck(cu_seqlens_kv_halfs), + *convertNVTETensorCheck(out_cu_seqlens_q_halfs), + *convertNVTETensorCheck(out_cu_seqlens_kv_halfs), + batch, output_len, chunk_size, stream); +} + diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index ebe8341cca..4f795faaa5 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -741,6 +741,79 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso int total_tokens, int world_size, int rank, cudaStream_t stream); +/*! \brief Split sequence into chunks for one P2P part on diagonal. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[in] cu_seqlens_padded Cumulative sequence lengths, [batch_size + 1]. + * \param[out] out_cu_seqlens Output tensor. + * \param[out] out_cu_seqlens_padded Output tensor. + * \param[in] batch Batch size. + * \param[in] output_len Output length. + * \param[in] chunk_size Chunk size. + * \param[in] stream CUDA stream used for this operation. + */ + +void nvte_cp_thd_chunkify(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded, + NVTETensor &out_cu_seqlens, NVTETensor &out_cu_seqlens_padded, + int batch, int output_len, int chunk_size, cudaStream_t stream); + +/*! \brief Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[in] cu_seqlens_padded Cumulative sequence lengths, [batch_size + 1]. + * \param[out] out_cu_seqlens Output tensor. + * \param[out] out_cu_seqlens_padded Output tensor. + * \param[in] batch Batch size. + * \param[in] output_len Output length. + * \param[in] chunk_size Chunk size. + * \param[in] stream CUDA stream used for this operation. + */ + +void nvte_cp_thd_chunkify_p2p(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded, + NVTETensor &out_cu_seqlens, NVTETensor &out_cu_seqlens_padded, + int batch, int output_len, int chunk_size, cudaStream_t stream) + +/*! \brief Support THD format for Context Parallel: Split sequence into chunks for one P2P part below diagonal. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] cu_seqlens_q Cumulative sequence lengths, [batch_size + 1]. + * \param[in] cu_seqlens_kv_halfs Cumulative sequence lengths, [batch_size + 1]. + * \param[out] out_cu_seqlens_q Output tensor. + * \param[out] out_cu_seqlens_kv_halfs Output tensor. + * \param[in] batch Batch size. + * \param[in] output_len Output length. + * \param[in] chunk_size Chunk size. + * \param[in] stream CUDA stream used for this operation. 
+ */ + +void nvte_cp_thd_seq_tweak_below_diag(const NVTETensor &cu_seqlens_q, const NVTETensor &cu_seqlens_kv_halfs, + NVTETensor &out_cu_seqlens_q, NVTETensor &out_cu_seqlens_kv_halfs, + int batch, int output_len, int chunk_size, cudaStream_t stream); + + +/*! \brief Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] cu_seqlens_q_halfs Cumulative sequence lengths, [batch_size + 1]. + * \param[in] cu_seqlens_kv_halfs Cumulative sequence lengths, [batch_size + 1]. + * \param[out] out_cu_seqlens_q_halfs Output tensor. + * \param[in] batch Batch size. + * \param[in] output_len Output length. + * \param[in] chunk_size Chunk size. + * \param[in] stream CUDA stream used for this operation. + */ + +void nvte_cp_thd_seq_tweak_above_diag(const NVTETensor &cu_seqlens_q_halfs, const NVTETensor &cu_seqlens_kv_halfs, + NVTETensor &out_cu_seqlens_q_halfs, NVTETensor &out_cu_seqlens_kv_halfs, + int batch, int output_len, int chunk_size, cudaStream_t stream); + + /*! \brief Convert tensor from THD to BSHD format. * * \warning This API is **experimental** and subject to change. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index e6127e54f5..fbf4599d91 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -518,6 +518,8 @@ def forward( max_seqlen_kv = max_seqlen_kv // cp_size cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] + cu_seqlens_q_padded_per_step = [None for _ in range(cp_size)] + cu_seqlens_kv_padded_per_step = [None for _ in range(cp_size)] fused_attn_backend = None qkv_dtype = q.dtype @@ -683,6 +685,11 @@ def forward( p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) send_recv_reqs = [[], []] + import torch.distributed as dist + if dist.get_rank() == 0: + print(f"rank = {rank}") + dist.barrier() + out = None for i in range(cp_size + 1): if i < cp_size: @@ -723,8 +730,8 @@ def forward( else: cu_seqlens_q_per_step[i] = cu_seqlens_q cu_seqlens_kv_per_step[i] = cu_seqlens_kv - cu_seqlens_q_padded_per_step = cu_seqlens_q_padded - cu_seqlens_kv_padded_per_step = cu_seqlens_kv_padded + cu_seqlens_q_padded_per_step[i] = cu_seqlens_q_padded + cu_seqlens_kv_padded_per_step[i] = cu_seqlens_kv_padded if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) @@ -742,10 +749,10 @@ def forward( elif qkv_format == "thd": q_inputs[i % 2] = q if chunk_size is not None: - cu_seqlens_q_per_step[i], cu_seqlens_q_padded_per_step = dpa_utils.thd_chunkify( - cu_seqlens_q, cu_seqlens_q_padded, None, chunk_size, True, rank, cp_size) - cu_seqlens_kv_per_step[i], cu_seqlens_kv_padded_per_step = dpa_utils.thd_chunkify( - cu_seqlens_kv, cu_seqlens_kv_padded, None, chunk_size, True, rank, cp_size) + cu_seqlens_q_per_step[i], cu_seqlens_q_padded_per_step[i] = dpa_utils.thd_chunkify( + cu_seqlens_q_per_step[i], cu_seqlens_q_padded_per_step[i], None, chunk_size, True, rank, cp_size) + cu_seqlens_kv_per_step[i], cu_seqlens_kv_padded_per_step[i] = dpa_utils.thd_chunkify( + cu_seqlens_kv_per_step[i], cu_seqlens_kv_padded_per_step[i], None, chunk_size, True, rank, cp_size) if use_fused_attention: if 
attn_bias is not None: @@ -800,10 +807,12 @@ def forward( attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded_per_step, - cu_seqlens_kv_padded=cu_seqlens_kv_padded_per_step, + cu_seqlens_q_padded=cu_seqlens_q_padded_per_step[i], + cu_seqlens_kv_padded=cu_seqlens_kv_padded_per_step[i], **fp8_meta_kwargs, ) + + if fp8: softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors else: @@ -865,8 +874,9 @@ def forward( cu_seqlens_q_per_step[i] = cu_seqlens_q cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half - cu_seqlens_q_padded_per_step = cu_seqlens_q_padded - cu_seqlens_kv_padded_per_step = cu_seqlens_kv_padded + cu_seqlens_q_padded_per_step[i] = cu_seqlens_q_padded + cu_seqlens_kv_padded_per_step[i] = cu_seqlens_kv_padded // 2 \ + if cu_seqlens_kv_padded is not None else None if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) @@ -884,15 +894,12 @@ def forward( kv_inputs[i % 2], cu_seqlens_kv_padded, 0 ) - if chunk_size is not None: - assert q_inputs[i % 2].shape[1] == 2 * kv_inputs[i % 2].shape[1], \ - "THD+chunking is not supported for cross attention - initial q length should be the same as initial kv length" - cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], cu_seqlens_q_padded_per_step, cu_seqlens_kv_padded_per_step =\ + cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], cu_seqlens_q_padded_per_step[i], cu_seqlens_kv_padded_per_step[i] =\ dpa_utils.thd_seq_tweak_below_diagonal( - cu_seqlens_q, cu_seqlens_q_padded, chunk_size, rank, cp_size - ) - + cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], cu_seqlens_q_padded,\ + rank, rank - i, cp_size, chunk_size + ) if use_fused_attention: kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() if attn_bias is not None: @@ -940,12 +947,8 @@ def forward( attn_mask_type="padding" if padding else "no_mask", attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded_per_step, - cu_seqlens_kv_padded=( - None - if cu_seqlens_q_padded_per_step is None - else cu_seqlens_q_padded_per_step // 2 - ), + cu_seqlens_q_padded=cu_seqlens_q_padded_per_step[i], + cu_seqlens_kv_padded=cu_seqlens_kv_padded_per_step[i], **fp8_meta_kwargs, ) if fp8: @@ -1015,6 +1018,10 @@ def forward( else: cu_seqlens_q_per_step[i] = cu_seqlens_q_half cu_seqlens_kv_per_step[i] = cu_seqlens_kv + + cu_seqlens_q_padded_per_step[i] = cu_seqlens_q_padded // 2 \ + if cu_seqlens_q_padded is not None else None + cu_seqlens_kv_padded_per_step[i] = cu_seqlens_kv_padded if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] q_inputs[i % 2] = q[:, 1, ...] 
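All of the per-step `cu_seqlens*_per_step` bookkeeping above boils down to one idea: chunked attention is expressed purely through the `cu_seqlens` / `cu_seqlens_padded` tensors, so the fused kernel treats every chunk as its own (sub)sequence rather than applying an extra mask. As a rough, CPU-only illustration of what `thd_chunkify` produces in the simplest case (no context parallelism, no padding, start offset 0), the sketch below inserts chunk boundaries into `cu_seqlens`; the real helper additionally offsets the first chunk by the CP rank and keeps the padded prefix sums consistent, so this is not the library implementation.

    import torch

    def chunk_aligned_cu_seqlens(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
        # Illustrative sketch: split every packed (THD) sequence into chunk_size
        # pieces by inserting extra boundaries into cu_seqlens. The trailing chunk
        # of a sequence may be shorter than chunk_size.
        new_seq_lens = []
        for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist():
            full, rem = divmod(seq_len, chunk_size)
            new_seq_lens += [chunk_size] * full + ([rem] if rem else [])
        out = torch.zeros(len(new_seq_lens) + 1, dtype=cu_seqlens.dtype)
        out[1:] = torch.cumsum(torch.tensor(new_seq_lens, dtype=cu_seqlens.dtype), dim=0)
        return out

    # Two packed sequences of lengths 5 and 7 with chunk_size=4:
    # cu_seqlens [0, 5, 12] -> [0, 4, 5, 9, 12], i.e. chunks of 4, 1, 4 and 3 tokens.
    print(chunk_aligned_cu_seqlens(torch.tensor([0, 5, 12], dtype=torch.int32), 4))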
@@ -1034,6 +1041,13 @@ def forward( q_inputs[i % 2] = tex.thd_read_half_tensor( q, cu_seqlens_q_padded, 1 ) + if chunk_size is not None: + cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], cu_seqlens_q_padded_per_step[i], cu_seqlens_kv_padded_per_step[i] =\ + dpa_utils.thd_seq_tweak_above_diagonal( + cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], cu_seqlens_q_padded,\ + rank, rank - i + cp_size, cp_size, chunk_size + ) + if use_fused_attention: q_inputs[i % 2] = q_inputs[i % 2].contiguous() if attn_bias is not None: @@ -1087,12 +1101,8 @@ def forward( attn_mask_type="padding" if padding else "no_mask", attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=( - None - if cu_seqlens_q_padded is None - else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded_per_step[i], + cu_seqlens_kv_padded=cu_seqlens_kv_padded_per_step[i], **fp8_meta_kwargs, ) if fp8: @@ -1409,6 +1419,8 @@ def forward( cu_seqlens_kv_padded, *cu_seqlens_q_per_step, *cu_seqlens_kv_per_step, + *cu_seqlens_q_padded_per_step, + *cu_seqlens_kv_padded_per_step, *rng_states, *attn_biases, ) @@ -1479,8 +1491,10 @@ def backward(ctx, dout): ) cu_seqlens_q_per_step = other_tensors[:cp_size] cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2] - rng_states = other_tensors[cp_size * 2 : cp_size * 3] - attn_biases = other_tensors[cp_size * 3 : cp_size * 4] + cu_seqlens_q_padded_per_step = other_tensors[cp_size * 2 : cp_size * 3] + cu_seqlens_kv_padded_per_step = other_tensors[cp_size * 3 : cp_size * 4] + rng_states = other_tensors[cp_size * 4 : cp_size * 5] + attn_biases = other_tensors[cp_size * 5 : cp_size * 6] causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type @@ -1760,8 +1774,8 @@ def backward(ctx, dout): fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded_per_step[cp_size - i - 1], + cu_seqlens_kv_padded=cu_seqlens_kv_padded_per_step[cp_size - i - 1], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, @@ -1870,6 +1884,7 @@ def backward(ctx, dout): ) fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, @@ -1884,10 +1899,8 @@ def backward(ctx, dout): fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 - ), + cu_seqlens_q_padded=cu_seqlens_q_padded_per_step[cp_size - i - 1], + cu_seqlens_kv_padded=cu_seqlens_kv_padded_per_step[cp_size - i - 1], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, @@ -1998,6 +2011,7 @@ def backward(ctx, dout): ) fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, @@ -2012,10 +2026,8 @@ def backward(ctx, dout): fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, - cu_seqlens_q_padded=( - None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded_per_step[cp_size - i - 1], + 
cu_seqlens_kv_padded=cu_seqlens_kv_padded_per_step[cp_size - i - 1], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, @@ -2391,6 +2403,7 @@ def backward(ctx, dout): None, None, None, + None, attn_dbias, None, None, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 835d4d1459..03e8bbf6b6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -752,14 +752,11 @@ def forward( cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32 ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" - if self.chunk_size is not None and self.cp_group is not None: - # todo: check if this condition is correct + if self.chunk_size is not None and self.cp_group is None: cu_seqlens_q, cu_seqlens_q_padded = dpa_utils.thd_chunkify( - cu_seqlens_q, cu_seqlens_q_padded, 0, self.chunk_size) + cu_seqlens_q, cu_seqlens_q_padded, torch.zeros_like(cu_seqlens_q, device=cu_seqlens_q.device), self.chunk_size) cu_seqlens_kv, cu_seqlens_kv_padded = dpa_utils.thd_chunkify( - cu_seqlens_kv, cu_seqlens_kv_padded, 0, self.chunk_size) - - + cu_seqlens_kv, cu_seqlens_kv_padded, torch.zeros_like(cu_seqlens_kv, device=cu_seqlens_kv.device), self.chunk_size) batch_size = len(cu_seqlens_q) - 1 if max_seqlen_q is None: if cu_seqlens_q_padded is not None: @@ -1095,6 +1092,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + chunk_size=self.chunk_size, fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, @@ -1122,6 +1120,7 @@ def forward( max_seqlen_kv=max_seqlen_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, + chunk_size=self.chunk_size, window_size=window_size, fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index bb2d321370..c9a69115c0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1797,7 +1797,6 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer - def thd_chunkify( cu_seqlens: torch.Tensor, cu_seqlens_padded: torch.Tensor, @@ -1820,179 +1819,276 @@ def thd_chunkify( new_cu_seqlens_padded = [] - new_seq_lens = [] - new_seq_lens_padded = [] - for i in range(cu_seqlens.size(0)): + all_new_seq_lens = [] + all_new_seq_lens_padded = [] + for i in range(seq_lens.size(0)): + new_seq_lens = [] + new_seq_lens_padded = [] seq_len = seq_lens[i] new_seq_len = 0 pad_len = pad_lens[i] + total_seq_len = seq_len + pad_len + middle = total_seq_len // 2 if start_idx is None: - start_id = seq_len // 2 * cp_rank + start_id = total_seq_len // 2 * cp_rank else: start_id = start_idx[i] - # first chunk - first_chunk_lenght = (chunk_size - start_id) % chunk_size - - if cp_load_balance: - first_chunk_lenght = min(first_chunk_lenght, seq_len // 2) - first_chunk_lenght = min(first_chunk_lenght, seq_len) - - 
new_seq_lens.append(first_chunk_lenght) - new_seq_lens_padded.append(first_chunk_lenght) - new_seq_len += first_chunk_lenght - - if new_seq_len == seq_len: - continue - while True: + if new_seq_len == 0: + new_chunk_size = (chunk_size - start_id) % chunk_size or chunk_size + else: + new_chunk_size = chunk_size + if cp_load_balance: - if new_seq_len + chunk_size > seq_len // 2: + if new_seq_len + new_chunk_size >= middle: break else: - if new_seq_len + chunk_size > seq_len: + if new_seq_len + new_chunk_size > total_seq_len: break - new_seq_lens.append(chunk_size) - new_seq_lens_padded.append(chunk_size) - new_seq_len += chunk_size + new_seq_lens_padded.append(new_chunk_size) + new_seq_len += new_chunk_size if cp_load_balance: - last_token_first_part_id = start_id + seq_len // 2 - 1 - total_seq_size = seq_len * cp_size // 2 - first_token_second_part_id = total_seq_size - last_token_first_part_id # is symmetrical to last token of first part with respect to the middle of the sequence + last_token_first_part_id = start_id + middle - 1 + total_seq_size = total_seq_len * cp_size + first_token_second_part_id = total_seq_size - last_token_first_part_id - 1 # is symmetrical to last token of first part with respect to the middle of the sequence last_chunk_of_first_part_id = last_token_first_part_id // chunk_size first_chunk_of_second_part_id = first_token_second_part_id // chunk_size - extend_last_chunk = last_chunk_of_first_part_id == first_chunk_of_second_part_id + # last chunk of first part + last_chunk_lenght = middle - new_seq_len + # first chunk of second part - first_chunk_lenght = (chunk_size - first_token_second_part_id) % chunk_size + first_chunk_lenght = (chunk_size - first_token_second_part_id) % chunk_size or chunk_size + first_chunk_lenght = min(first_chunk_lenght, middle) - first_chunk_lenght = min(first_chunk_lenght, seq_len // 2) if extend_last_chunk: - new_seq_lens[-1] += first_chunk_lenght - new_seq_lens_padded[-1] += first_chunk_lenght + new_seq_lens_padded.append(int(last_chunk_lenght + first_chunk_lenght)) else: - new_seq_lens.append(first_chunk_lenght) - new_seq_lens_padded.append(first_chunk_lenght) - new_seq_len += first_chunk_lenght + new_seq_lens_padded.append(int(last_chunk_lenght)) + new_seq_lens_padded.append(int(first_chunk_lenght)) + new_seq_len += first_chunk_lenght + last_chunk_lenght while True: - if new_seq_len + chunk_size > seq_len // 2: + if new_seq_len + chunk_size > total_seq_len: break - new_seq_lens.append(chunk_size) new_seq_lens_padded.append(chunk_size) new_seq_len += chunk_size - last_chunk_lenght = seq_len - new_seq_len - new_seq_lens.append(last_chunk_lenght) - new_seq_len += last_chunk_lenght - # add last_chunk + padding to new_seq_lens_padded - new_seq_lens_padded.append(last_chunk_lenght + pad_len) - assert new_seq_len == seq_len - - new_cu_seqlens = torch.cumsum(torch.tensor(new_seq_lens), dim=0) - new_cu_seqlens_padded = torch.cumsum(torch.tensor(new_seq_lens_padded), dim=0) + last_chunk_lenght = 0 + if new_seq_len != total_seq_len: + last_chunk_lenght = total_seq_len - new_seq_len + new_seq_lens_padded.append(int(last_chunk_lenght)) + new_seq_len += last_chunk_lenght - return new_cu_seqlens, new_cu_seqlens_padded + unpadded_left = int(seq_len) + for chunk in new_seq_lens_padded: + new_seq_lens.append(int(min(unpadded_left, chunk))) + unpadded_left -= min(unpadded_left, chunk) - + all_new_seq_lens.extend(new_seq_lens) + all_new_seq_lens_padded.extend(new_seq_lens_padded) -@jit_fuser + assert new_seq_len == total_seq_len + + new_cu_seqlens = 
torch.zeros(len(all_new_seq_lens) + 1, device=cu_seqlens.device).to(torch.int32) + new_cu_seqlens[1:] = torch.cumsum(torch.tensor(all_new_seq_lens), dim=0) + new_cu_seqlens_padded = torch.zeros(len(all_new_seq_lens_padded) + 1, device=cu_seqlens.device).to(torch.int32) + new_cu_seqlens_padded[1:] = torch.cumsum(torch.tensor(all_new_seq_lens_padded), dim=0) + + return new_cu_seqlens, new_cu_seqlens_padded + +#@jit_fuser def thd_seq_tweak_below_diagonal( - cu_seqlens: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv_halfs: torch.Tensor, cu_seqlens_padded: torch.Tensor, - cp_rank: int, + cp_rank_q: int, + cp_rank_kv: int, cp_size: int, chunk_size: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] - pad_lens = ( - cu_seqlens_padded[1:] - - cu_seqlens_padded[:-1] - - seq_lens + assert cp_rank_q > cp_rank_kv + seq_lens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seq_lens_kv_halfs = cu_seqlens_kv_halfs[1:] - cu_seqlens_kv_halfs[:-1] + seq_plus_pad = cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + half_seq_lens = seq_plus_pad // 2 + pad_lens_q = seq_plus_pad - seq_lens_q + pad_lens_kv = half_seq_lens - seq_lens_kv_halfs + + last_kv_id = (cp_rank_kv + 1) * half_seq_lens - 1 + last_kv_chunk_id = last_kv_id // chunk_size + last_kv_chunk_len = torch.min( + half_seq_lens - (last_kv_chunk_id * chunk_size - cp_rank_kv * half_seq_lens), + half_seq_lens ) - - seq_plus_pad = seq_lens + pad_lens - half_seq_lens = seq_plus_pad // 2 - - last_kv_first_half_id = cp_rank * (half_seq_lens + 1) - 1 - last_kv_chunk_pos = half_seq_lens - (half_seq_lens % chunk_size) - last_kv_chunk_id = last_kv_first_half_id // chunk_size - - first_half_q_first_chunk_id = (cp_rank * seq_plus_pad) // chunk_size - first_half_q_second_chunk_id = ( - (2 * cp_size - 1 - cp_rank) * half_seq_lens + + first_half_q_first_chunk_id = (cp_rank_q * half_seq_lens) // chunk_size + second_half_q_first_chunk_id = ( + (2 * cp_size - cp_rank_q - 1) * half_seq_lens ) // chunk_size - first_half_q_first_chunk_last_pos = torch.min( - half_seq_lens, (first_half_q_first_chunk_id * chunk_size) % half_seq_lens) - second_half_q_first_chunk_last_pos = torch.min( - seq_lens, half_seq_lens + (first_half_q_second_chunk_id * chunk_size) % half_seq_lens) + first_half_q_first_chunk_len = torch.min( + half_seq_lens, ((first_half_q_first_chunk_id + 1) * chunk_size) - (cp_rank_q * half_seq_lens)) + second_half_q_first_chunk_len = torch.min( + seq_lens_q, half_seq_lens + ((second_half_q_first_chunk_id + 1) * chunk_size) - ((2 * cp_size - 1 - cp_rank_q) * half_seq_lens)) - take_0 = last_kv_chunk_id != first_half_q_first_chunk_last_pos - take_first_half_q = last_kv_chunk_id != second_half_q_first_chunk_last_pos - take_second_half_q = (~take_0) & (~take_first_half_q) + take_nothing = last_kv_chunk_id != first_half_q_first_chunk_id + take_first_half_q = (~take_nothing) & (last_kv_chunk_id != second_half_q_first_chunk_id) + take_second_half_q = (~take_nothing) & (~take_first_half_q) - chunk_end_q_seqs = torch.zeros_like(seq_lens) - chunk_end_q_seqs[take_0] = 0 - chunk_end_q_seqs[take_first_half_q] = first_half_q_first_chunk_last_pos - chunk_end_q_seqs[take_second_half_q] = second_half_q_first_chunk_last_pos + zeros = lambda: torch.zeros_like(seq_lens_q) - chunk_start_kv_seqs = torch.zeros_like(seq_lens) - chunk_start_kv_seqs[~take_0] = last_kv_chunk_pos[~take_0] + q_seq_len = zeros() + + q_seq_len[take_nothing] = 0 + q_seq_len[take_first_half_q] = first_half_q_first_chunk_len[take_first_half_q] + 
q_seq_len[take_second_half_q] = second_half_q_first_chunk_len[take_second_half_q] - # Helper aliases - zeros_like = torch.zeros_like - minimum = torch.minimum + kv_seq_len = zeros() + half_seq_lens + kv_seq_len[~take_nothing] = last_kv_chunk_len[~take_nothing] + + + q_seq_len = torch.min(q_seq_len, torch.max(zeros(), seq_plus_pad - pad_lens_q)) + kv_seq_len = torch.max(zeros(), kv_seq_len - pad_lens_kv) # 1. Build per-sequence chunk sizes - # Q chunks : [0, first_part, second_part] - # KV chunks : [first_part, second_part, 0 third_part] - q_0 = zeros_like(seq_lens) - q_1 = chunk_end_q_seqs - q_2 = seq_plus_pad - chunk_end_q_seqs - q_3 = zeros_like(seq_lens) - q_chunks = torch.stack((q_0, q_1, q_2, q_3), dim=1) - - kv_0 = chunk_start_kv_seqs - kv_1 = zeros_like(seq_lens) - kv_2 = half_seq_lens - chunk_start_kv_seqs - kv_3 = half_seq_lens - kv_chunks = torch.stack((kv_0, kv_1, kv_2, kv_3), dim=1) + # Q chunks : [0, sequence, 0] + # Q pads: [0, 0, garbage] + # KV chunks : [0, sequence, 0] + # KV pads: [garbage, 0, 0] + q_0 = zeros() + q_1 = q_seq_len + q_2 = zeros() + q_chunks = torch.stack((q_0, q_1, q_2), dim=1) + + q_0_pad = zeros() + q_1_pad = zeros() + q_2_pad = seq_plus_pad - q_seq_len + q_pads = torch.stack((q_0_pad, q_1_pad, q_2_pad), dim=1) + + kv_0 = zeros() + kv_1 = kv_seq_len + kv_2 = zeros() + kv_chunks = torch.stack((kv_0, kv_1, kv_2), dim=1) - # 2. Padded variants – keep padding only in the final chunk - q_chunks_padded = torch.stack( - (q_0, q_1, q_2, seq_plus_pad - q_0 - q_1 - q_2), dim=1 - ) - kv_chunks_padded = torch.stack( - (kv_0, kv_1, kv_2, seq_plus_pad - kv_0 - kv_1 - kv_2), dim=1 - ) + + kv_0_pad = half_seq_lens - kv_seq_len - pad_lens_kv + kv_1_pad = zeros() + kv_2_pad = pad_lens_kv + kv_pads = torch.stack((kv_0_pad, kv_1_pad, kv_2_pad), dim=1) + + + q_chunks_padded = q_chunks + q_pads + kv_chunks_padded = kv_chunks + kv_pads - cu_seqlens_q_per_step = q_chunks.flatten().cumsum(0) - cu_seqlens_kv_per_step = kv_chunks.flatten().cumsum(0) - cu_seqlens_q_padded_per_step = q_chunks_padded.flatten().cumsum(0) - cu_seqlens_kv_padded_per_step = kv_chunks_padded.flatten().cumsum(0) + cu_seqlens_q_per_step = q_chunks.flatten().cumsum(0).to(torch.int32) + cu_seqlens_kv_per_step = kv_chunks.flatten().cumsum(0).to(torch.int32) + cu_seqlens_q_padded_per_step = q_chunks_padded.flatten().cumsum(0).to(torch.int32) + cu_seqlens_kv_padded_per_step = kv_chunks_padded.flatten().cumsum(0).to(torch.int32) return cu_seqlens_q_per_step, cu_seqlens_kv_per_step, cu_seqlens_q_padded_per_step, cu_seqlens_kv_padded_per_step +#@jit_fuser -@jit_fuser def thd_seq_tweak_above_diagonal( - cu_seqlens: torch.Tensor, + cu_seqlens_q_halfs: torch.Tensor, + cu_seqlens_kv: torch.Tensor, cu_seqlens_padded: torch.Tensor, - cp_rank: int, + cp_rank_q: int, + cp_rank_kv: int, cp_size: int, chunk_size: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + assert cp_rank_q < cp_rank_kv + + seq_lens_q_halfs = cu_seqlens_q_halfs[1:] - cu_seqlens_q_halfs[:-1] + seq_lens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + seq_plus_pad = cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + half_seq_lens = seq_plus_pad // 2 + pad_lens_q = half_seq_lens - seq_lens_q_halfs + pad_lens_kv = seq_plus_pad - seq_lens_kv + + first_q_id = (2 * cp_size - 1 - cp_rank_q) * half_seq_lens + first_q_chunk_id = first_q_id // chunk_size + first_q_chunk_len = torch.min( + seq_lens_q_halfs, + (first_q_chunk_id + 1) * chunk_size - first_q_id + ) - seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] - pad_lens = ( - cu_seqlens_padded[1:] 
- - cu_seqlens_padded[:-1] - - seq_lens \ No newline at end of file + first_half_kv_last_el_total_id = ((cp_rank_kv + 1) * half_seq_lens - 1) + first_half_kv_last_chunk_id = first_half_kv_last_el_total_id // chunk_size + second_half_kv_last_el_total_id = (2 * cp_size - cp_rank_kv) * half_seq_lens - 1 + second_half_kv_last_chunk_id = second_half_kv_last_el_total_id // chunk_size + + + # these 2 are not easy + first_half_kv_last_el_id_in_chunk = first_half_kv_last_el_total_id % chunk_size + first_half_kv_last_chunk_len = torch.min(seq_plus_pad, half_seq_lens + first_half_kv_last_el_id_in_chunk + 1) + + second_half_kv_ast_el_id_in_chunk = second_half_kv_last_el_total_id % chunk_size + second_half_kv_last_chunk_len = torch.min(half_seq_lens, second_half_kv_ast_el_id_in_chunk + 1) + take_nothing = first_q_chunk_id != second_half_kv_last_chunk_id + take_second_half_kv = (~take_nothing) & (first_q_chunk_id != first_half_kv_last_chunk_id) + take_first_half_kv = (~take_nothing) & (~take_second_half_kv) + + + # Helper aliases + zeros = lambda: torch.zeros_like(seq_lens_q_halfs) + + q_seq_len = zeros() + q_seq_len[~take_nothing] = (first_q_chunk_len)[~take_nothing] + + kv_seq_len = zeros() + kv_seq_len[take_nothing] = 0 + kv_seq_len[take_second_half_kv] = second_half_kv_last_chunk_len[take_second_half_kv] + kv_seq_len[take_first_half_kv] = first_half_kv_last_chunk_len[take_first_half_kv] + + kv_seq_len = torch.max(zeros(), kv_seq_len - pad_lens_kv) + + # 1. Build per-sequence chunk sizes + # Q chunks : [0, sequence, 0] + # Q pads : [0, 0, garbage] + # KV chunks : [0, sequence, 0] + # KV pads : [garbage, 0, 0] + q_0 = zeros() + q_1 = q_seq_len + q_2 = zeros() + q_chunks = torch.stack((q_0, q_1, q_2), dim=1) + + q_0_pad = zeros() + q_1_pad = zeros() + q_2_pad = half_seq_lens - q_seq_len + q_pads = torch.stack((q_0_pad, q_1_pad, q_2_pad), dim=1) + + + kv_0 = zeros() + kv_1 = kv_seq_len + kv_2 = zeros() + kv_chunks = torch.stack((kv_0, kv_1, kv_2), dim=1) + + kv_0_pad = seq_lens_kv - kv_seq_len + kv_1_pad = zeros() + kv_2_pad = seq_plus_pad - kv_0_pad - kv_1 + kv_pads = torch.stack((kv_0_pad, kv_1_pad, kv_2_pad), dim=1) + + # 2. 
Padded variants – keep padding only in the final chunk + q_chunks_padded = q_chunks + q_pads + + kv_chunks_padded = kv_chunks + kv_pads + + # moze trzeba dorzucic zero + + cu_seqlens_q_per_step = q_chunks.flatten().cumsum(0).to(torch.int32) + cu_seqlens_kv_per_step = kv_chunks.flatten().cumsum(0).to(torch.int32) + cu_seqlens_q_padded_per_step = q_chunks_padded.flatten().cumsum(0).to(torch.int32) + cu_seqlens_kv_padded_per_step = kv_chunks_padded.flatten().cumsum(0).to(torch.int32) + + return cu_seqlens_q_per_step, cu_seqlens_kv_per_step, cu_seqlens_q_padded_per_step, cu_seqlens_kv_padded_per_step \ No newline at end of file diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index b9810bf861..3eca9dc51b 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -304,6 +304,7 @@ def fused_attn_fwd( rng_elts_per_thread, ) + # out, aux_ctx_tensors return output_tensors[0], output_tensors[1:] diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index f86b60f612..29e3979952 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -239,7 +239,10 @@ at::Tensor allocateSpace(const std::vector& shape, const transformer_eng if (init_to_zeros) { return at::zeros(ar_shape, at::CUDA(GetATenDType(type))); } else { - return at::empty(ar_shape, at::CUDA(GetATenDType(type))); + // init to minus inf + at::Tensor ret = at::empty(ar_shape, at::CUDA(GetATenDType(type))); + ret.fill_(-std::numeric_limits::infinity()); + return ret; } } diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index efe825f0db..f041c3300b 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -739,6 +739,151 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t return output; } + +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal. 
+ **************************************************************************************************/ + +std::vector thd_chunkify(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded, + int total_seq_len, int chunk_size) { + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + NVTE_CHECK(cu_seqlens.size(0) >= 2); + NVTE_CHECK(cu_seqlens_padded.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens_padded.dim() == 1); + NVTE_CHECK(cu_seqlens_padded.size(0) >= 2); + NVTE_CHECK(chunk_size > 0); + + int batch = cu_seqlens.size(0) - 1; + int output_len = cu_seqlens_padded.size(0) + total_seq_len / chunk_size; + + // Allocate output tensors + at::Tensor out_cu_seqlens = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_cu_seqlens_padded = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); + + // Create tensor wrappers + auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens); + auto te_cu_seqlens_padded = makeTransformerEngineTensor(cu_seqlens_padded); + auto te_out_cu_seqlens = makeTransformerEngineTensor(out_cu_seqlens); + auto te_out_cu_seqlens_padded = makeTransformerEngineTensor(out_cu_seqlens_padded); + + nvte_cp_thd_chunkify( + te_cu_seqlens.data(), te_cu_seqlens_padded.data(), + te_out_cu_seqlens.data(), te_out_cu_seqlens_padded.data(), + batch, output_len, chunk_size, at::cuda::getCurrentCUDAStream() + ); + + return {out_cu_seqlens, out_cu_seqlens_padded}; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal. + **************************************************************************************************/ + +std::vector thd_chunkify_p2p(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded, + int total_seq_len, int chunk_size, int world_size, int rank) { + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + NVTE_CHECK(cu_seqlens.size(0) >= 2); + NVTE_CHECK(cu_seqlens_padded.scalar_type() == at::ScalarType::Int); + + int batch = cu_seqlens.size(0) - 1; + int output_len = cu_seqlens_padded.size(0) + total_seq_len / chunk_size; + + at::Tensor out_cu_seqlens = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_cu_seqlens_padded = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); + + auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens); + auto te_cu_seqlens_padded = makeTransformerEngineTensor(cu_seqlens_padded); + auto te_out_cu_seqlens = makeTransformerEngineTensor(out_cu_seqlens); + auto te_out_cu_seqlens_padded = makeTransformerEngineTensor(out_cu_seqlens_padded); + + nvte_cp_thd_chunkify_p2p( + te_cu_seqlens.data(), te_cu_seqlens_padded.data(), + te_out_cu_seqlens.data(), te_out_cu_seqlens_padded.data(), + batch, output_len, chunk_size, world_size, rank, at::cuda::getCurrentCUDAStream() + ); + + return {out_cu_seqlens, out_cu_seqlens_padded}; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part below diagonal. 
+ **************************************************************************************************/ + +std::vector thd_seq_tweak_below_diag(at::Tensor cu_seqlens_q, at::Tensor cu_seqlens_kv_halfs, + at::Tensor cu_seqlens_padded, int cp_rank_q, + int cp_rank_kv, int cp_size, int chunk_size, int total_seq_len) { + NVTE_CHECK(cu_seqlens_q.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens_kv_halfs.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens_padded.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens_q.dim() == 1); + NVTE_CHECK(cu_seqlens_kv_halfs.dim() == 1); + NVTE_CHECK(cu_seqlens_padded.dim() == 1); + NVTE_CHECK(cu_seqlens_q.size(0) >= 2); + NVTE_CHECK(cu_seqlens_kv_halfs.size(0) >= 2); + NVTE_CHECK(cu_seqlens_padded.size(0) >= 2); + NVTE_CHECK(cp_rank_q >= 0 && cp_rank_q < cp_size); + NVTE_CHECK(cp_rank_kv >= 0 && cp_rank_kv < cp_size); + NVTE_CHECK(cp_size > 0); + NVTE_CHECK(chunk_size > 0); + + int batch = cu_seqlens_q.size(0) - 1; + int output_len = 5 * cu_seqlens_padded.size(0) + total_seq_len / chunk_size; + + at::Tensor out_cu_seqlens_q = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_cu_seqlens_kv_halfs = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); + + auto te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q); + auto te_cu_seqlens_kv_halfs = makeTransformerEngineTensor(cu_seqlens_kv_halfs); + auto te_cu_seqlens_padded = makeTransformerEngineTensor(cu_seqlens_padded); + auto te_out_cu_seqlens_q = makeTransformerEngineTensor(out_cu_seqlens_q); + auto te_out_cu_seqlens_kv_halfs = makeTransformerEngineTensor(out_cu_seqlens_kv_halfs); + + nvte_cp_thd_seq_tweak_below_diag( + te_cu_seqlens_q.data(), te_cu_seqlens_kv_halfs.data(), + te_cu_seqlens_padded.data(), te_out_cu_seqlens_q.data(), te_out_cu_seqlens_kv_halfs.data(), + batch, output_len, chunk_size, cp_rank_q, cp_rank_kv, cp_size, at::cuda::getCurrentCUDAStream() + ); + + return {out_cu_seqlens_q, out_cu_seqlens_kv_halfs}; +} + + + +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal. 
+ **************************************************************************************************/ + +std::vector thd_seq_tweak_above_diag(at::Tensor cu_seqlens_q_halfs, at::Tensor cu_seqlens_kv_halfs, + at::Tensor cu_seqlens_padded, int cp_rank_q, + int cp_rank_kv, int cp_size, int chunk_size, int total_seq_len) { + NVTE_CHECK(cu_seqlens_q_halfs.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens_kv_halfs.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens_padded.scalar_type() == at::ScalarType::Int); + + int batch = cu_seqlens_q_halfs.size(0) - 1; + int output_len = 5 * cu_seqlens_padded.size(0) + total_seq_len / chunk_size; + + at::Tensor out_cu_seqlens_q_halfs = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_cu_seqlens_kv_halfs = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); + + auto te_cu_seqlens_q_halfs = makeTransformerEngineTensor(cu_seqlens_q_halfs); + auto te_cu_seqlens_kv_halfs = makeTransformerEngineTensor(cu_seqlens_kv_halfs); + auto te_cu_seqlens_padded = makeTransformerEngineTensor(cu_seqlens_padded); + auto te_out_cu_seqlens_q_halfs = makeTransformerEngineTensor(out_cu_seqlens_q_halfs); + auto te_out_cu_seqlens_kv_halfs = makeTransformerEngineTensor(out_cu_seqlens_kv_halfs); + + nvte_cp_thd_seq_tweak_above_diag( + te_cu_seqlens_q_halfs.data(), te_cu_seqlens_kv_halfs.data(), + te_cu_seqlens_padded.data(), te_out_cu_seqlens_q_halfs.data(), te_out_cu_seqlens_kv_halfs.data(), + batch, output_len, chunk_size, cp_rank_q, cp_rank_kv, cp_size, at::cuda::getCurrentCUDAStream() + ); + + return {out_cu_seqlens_q_halfs, out_cu_seqlens_kv_halfs}; +} + + /*************************************************************************************************** * KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd **************************************************************************************************/ From 09942a1f612273ade21c69b818c007b38a50e4a0 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 17 Jun 2025 13:05:04 +0000 Subject: [PATCH 3/5] is almost works Signed-off-by: Pawel Gadzinski --- .../common/fused_attn/context_parallel.cu | 201 +++++++----- .../include/transformer_engine/fused_attn.h | 35 ++- .../debug/features/utils/stats_buffer.py | 5 + .../dot_product_attention/context_parallel.py | 9 +- .../dot_product_attention.py | 5 +- .../attention/dot_product_attention/utils.py | 289 ++---------------- transformer_engine/pytorch/csrc/extensions.h | 14 + .../pytorch/csrc/extensions/attention.cpp | 70 +++-- .../pytorch/csrc/extensions/pybind.cpp | 9 + 9 files changed, 255 insertions(+), 382 deletions(-) diff --git a/transformer_engine/common/fused_attn/context_parallel.cu b/transformer_engine/common/fused_attn/context_parallel.cu index 4c164cbb48..27fa301680 100644 --- a/transformer_engine/common/fused_attn/context_parallel.cu +++ b/transformer_engine/common/fused_attn/context_parallel.cu @@ -487,6 +487,15 @@ __global__ void thd_seq_tweak_below_diag_kernel( int batch) { const int i = blockIdx.x * blockDim.x + threadIdx.x; + + if(i == 0) { + q_chunks[0] = 0; + kv_chunks[0] = 0; + q_pads[0] = 0; + kv_pads[0] = 0; + } + __syncthreads(); + if (i >= batch) return; // ───────── prefix-sum diffs ──────────────────────────────── @@ -544,22 +553,22 @@ __global__ void thd_seq_tweak_below_diag_kernel( // ───────── flat output (row-major) ───────────────────────── const int out_base = 3 * i; - q_chunks[out_base + 0] = q_start; - q_chunks[out_base + 1] = q_start + q_seq_len; - q_chunks[out_base + 2] = 
q_start + q_seq_len; + q_chunks[out_base + 0 + 1] = 0; + q_chunks[out_base + 1 + 1] = q_seq_len; + q_chunks[out_base + 2 + 1] = 0; - q_pads [out_base + 0] = pad_start; - q_pads [out_base + 1] = pad_start + q_seq_len; - q_pads [out_base + 2] = pad_start + seq_plus_pad; + q_pads [out_base + 0 + 1] = pad_start; + q_pads [out_base + 1 + 1] = pad_start + q_seq_len; + q_pads [out_base + 2 + 1] = pad_start + seq_plus_pad; - kv_chunks[out_base + 0] = kv_start; - kv_chunks[out_base + 1] = kv_start + kv_seq_len; - kv_chunks[out_base + 2] = kv_start + kv_seq_len; + kv_chunks[out_base + 0 + 1] = 0; + kv_chunks[out_base + 1 + 1] = kv_seq_len; + kv_chunks[out_base + 2 + 1] = 0; const int32_t kv_pad_base = pad_start >> 1; - kv_pads[out_base + 0] = kv_pad_base + (half_seq_len - kv_seq_len - pad_len_kv); - kv_pads[out_base + 1] = kv_pad_base + (half_seq_len - pad_len_kv); - kv_pads[out_base + 2] = kv_pad_base + half_seq_len; + kv_pads[out_base + 0 + 1] = kv_pad_base + (half_seq_len - kv_seq_len - pad_len_kv); + kv_pads[out_base + 1 + 1] = kv_pad_base + (half_seq_len - pad_len_kv); + kv_pads[out_base + 2 + 1] = kv_pad_base + half_seq_len; } @@ -583,6 +592,14 @@ __global__ void thd_seq_tweak_above_diag_kernel( int batch) { const int i = blockIdx.x * blockDim.x + threadIdx.x; + + if(i == 0) { + q_chunks[0] = 0; + kv_chunks[0] = 0; + q_pads[0] = 0; + kv_pads[0] = 0; + } + __syncthreads(); if (i >= batch) return; // ───────── prefix‑sum diffs ─────────────────────────────── @@ -647,25 +664,25 @@ __global__ void thd_seq_tweak_above_diag_kernel( const int out_base = 3 * i; // Q chunks: [0, sequence, 0] - q_chunks[out_base + 0] = q_start; // beginning of Q half‑sequence - q_chunks[out_base + 1] = q_start + q_seq_len; // after the chunk we keep - q_chunks[out_base + 2] = q_start + q_seq_len; // stays flat afterwards + q_chunks[out_base + 0 + 1] = 0; // beginning of Q half‑sequence + q_chunks[out_base + 1 + 1] = q_seq_len; // after the chunk we keep + q_chunks[out_base + 2 + 1] = 0; // stays flat afterwards // Q pads: [0, 0, garbage] (pads live in the *first* half of padded area) const int32_t half_pad_start = pad_start >> 1; // start of Q‑related pad area - q_pads[out_base + 0] = half_pad_start; - q_pads[out_base + 1] = half_pad_start + q_seq_len; - q_pads[out_base + 2] = half_pad_start + half_seq_len; // complete half padded length + q_pads[out_base + 0 + 1] = half_pad_start; + q_pads[out_base + 1 + 1] = half_pad_start + q_seq_len; + q_pads[out_base + 2 + 1] = half_pad_start + half_seq_len; // complete half padded length // KV chunks: [0, sequence, 0] - kv_chunks[out_base + 0] = kv_start; - kv_chunks[out_base + 1] = kv_start + kv_seq_len; - kv_chunks[out_base + 2] = kv_start + seq_len_kv; + kv_chunks[out_base + 0 + 1] = 0; + kv_chunks[out_base + 1 + 1] = kv_seq_len; + kv_chunks[out_base + 2 + 1] = 0; // KV pads: [garbage, 0, 0] (pads precede KV if from first half) - kv_pads[out_base + 0] = pad_start + (seq_len_kv - kv_seq_len); - kv_pads[out_base + 1] = pad_start + seq_len_kv; - kv_pads[out_base + 2] = pad_start + seq_plus_pad; + kv_pads[out_base + 0 + 1] = pad_start + (seq_len_kv - kv_seq_len); + kv_pads[out_base + 1 + 1] = pad_start + seq_len_kv; + kv_pads[out_base + 2 + 1] = pad_start + seq_plus_pad; } @@ -1061,12 +1078,6 @@ void thd_chunkify( NVTE_CHECK(out_cu_seqlens.dim() == 1); NVTE_CHECK(out_cu_seqlens_padded.dim() == 1); - NVTE_CHECK(cu_seqlens_shape[0] == batch + 1); - NVTE_CHECK(cu_seqlens_padded_shape[0] == batch + 1); - - NVTE_CHECK(out_cu_seqlens_shape[0] == output_len); - 
NVTE_CHECK(out_cu_seqlens_padded_shape[0] == output_len); - const unsigned int block = 256; const unsigned int grid = (output_len + block - 1) / block; thd_chunkify_kernel<<>>( @@ -1088,6 +1099,8 @@ void thd_chunkify_p2p( int batch, int output_len, int chunk_size, + int cp_rank, + int cp_size, cudaStream_t stream ) { using namespace transformer_engine; @@ -1102,13 +1115,6 @@ void thd_chunkify_p2p( NVTE_CHECK(out_cu_seqlens.dim() == 1); NVTE_CHECK(out_cu_seqlens_padded.dim() == 1); - // length of cu_seqlens and cu_seqlens_padded should be batch + 1 - NVTE_CHECK(cu_seqlens_shape[0] == batch + 1); - NVTE_CHECK(cu_seqlens_padded_shape[0] == batch + 1); - - // length of out_cu_seqlens and out_cu_seqlens_padded should be output_len - NVTE_CHECK(out_cu_seqlens_shape[0] == output_len); - NVTE_CHECK(out_cu_seqlens_padded_shape[0] == output_len); const unsigned int block = 256; const unsigned int grid = (output_len + block - 1) / block; @@ -1116,7 +1122,7 @@ void thd_chunkify_p2p( reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(cu_seqlens_padded.data.dptr), reinterpret_cast(out_cu_seqlens.data.dptr), - reinterpret_cast(out_cu_seqlens_padded.data.dptr), batch, output_len, chunk_size); + reinterpret_cast(out_cu_seqlens_padded.data.dptr), batch, output_len, chunk_size, cp_rank, cp_size); } /*************************************************************************************************** @@ -1126,11 +1132,17 @@ void thd_chunkify_p2p( void thd_seq_tweak_below_diag( const Tensor &cu_seqlens_q, const Tensor &cu_seqlens_kv_halfs, + const Tensor &cu_seqlens_padded, Tensor &out_cu_seqlens_q, Tensor &out_cu_seqlens_kv_halfs, + Tensor &out_cu_seqlens_q_padded, + Tensor &out_cu_seqlens_kv_halfs_padded, int batch, int output_len, int chunk_size, + int cp_rank_q, + int cp_rank_kv, + int cp_size, cudaStream_t stream ) { using namespace transformer_engine; @@ -1146,35 +1158,40 @@ void thd_seq_tweak_below_diag( NVTE_CHECK(out_cu_seqlens_q.dim() == 1); NVTE_CHECK(out_cu_seqlens_kv_halfs.dim() == 1); - // length of cu_seqlens_q and cu_seqlens_kv_halfs should be batch + 1 - NVTE_CHECK(cu_seqlens_q_shape[0] == batch + 1); - NVTE_CHECK(cu_seqlens_kv_halfs_shape[0] == batch + 1); - - // length of out_cu_seqlens_q and out_cu_seqlens_kv_halfs should be output_len - NVTE_CHECK(out_cu_seqlens_q_shape[0] == output_len); - NVTE_CHECK(out_cu_seqlens_kv_halfs_shape[0] == output_len); - const unsigned int block = 256; const unsigned int grid = (output_len + block - 1) / block; thd_seq_tweak_below_diag_kernel<<>>( reinterpret_cast(cu_seqlens_q.data.dptr), reinterpret_cast(cu_seqlens_kv_halfs.data.dptr), + reinterpret_cast(cu_seqlens_padded.data.dptr), reinterpret_cast(out_cu_seqlens_q.data.dptr), - reinterpret_cast(out_cu_seqlens_kv_halfs.data.dptr), batch, output_len, chunk_size); + reinterpret_cast(out_cu_seqlens_kv_halfs.data.dptr), + reinterpret_cast(out_cu_seqlens_q_padded.data.dptr), + reinterpret_cast(out_cu_seqlens_kv_halfs_padded.data.dptr), + cp_rank_q, + cp_rank_kv, + cp_size, + chunk_size, + batch + ); } - /*************************************************************************************************** * Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal. 
**************************************************************************************************/ void thd_seq_tweak_above_diag( - const Tensor &cu_seqlens_q_halfs, - const Tensor &cu_seqlens_kv_halfs, + const Tensor &cu_seqlens_kv, + const Tensor &cu_seqlens_padded, Tensor &out_cu_seqlens_q_halfs, - Tensor &out_cu_seqlens_kv_halfs, + Tensor &out_cu_seqlens_kv, + Tensor &out_cu_seqlens_q_halfs_padded, + Tensor &out_cu_seqlens_kv_padded, + int cp_rank_q, + int cp_rank_kv, + int cp_size, int batch, int output_len, int chunk_size, @@ -1183,31 +1200,40 @@ void thd_seq_tweak_above_diag( using namespace transformer_engine; NVTE_CHECK(cu_seqlens_q_halfs.dtype() == DType::kInt32); - NVTE_CHECK(cu_seqlens_kv_halfs.dtype() == DType::kInt32); + NVTE_CHECK(cu_seqlens_kv.dtype() == DType::kInt32); + NVTE_CHECK(cu_seqlens_padded.dtype() == DType::kInt32); NVTE_CHECK(out_cu_seqlens_q_halfs.dtype() == DType::kInt32); - NVTE_CHECK(out_cu_seqlens_kv_halfs.dtype() == DType::kInt32); + NVTE_CHECK(out_cu_seqlens_kv.dtype() == DType::kInt32); + NVTE_CHECK(out_cu_seqlens_q_halfs_padded.dtype() == DType::kInt32); + NVTE_CHECK(out_cu_seqlens_kv_padded.dtype() == DType::kInt32); // This tensors should be one dimensional NVTE_CHECK(cu_seqlens_q_halfs.dim() == 1); - NVTE_CHECK(cu_seqlens_kv_halfs.dim() == 1); + NVTE_CHECK(cu_seqlens_kv.dim() == 1); + NVTE_CHECK(cu_seqlens_padded.dim() == 1); NVTE_CHECK(out_cu_seqlens_q_halfs.dim() == 1); - NVTE_CHECK(out_cu_seqlens_kv_halfs.dim() == 1); - - // length of cu_seqlens_q_halfs and cu_seqlens_kv_halfs should be batch + 1 - NVTE_CHECK(cu_seqlens_q_halfs_shape[0] == batch + 1); - NVTE_CHECK(cu_seqlens_kv_halfs_shape[0] == batch + 1); - - // length of out_cu_seqlens_q_halfs and out_cu_seqlens_kv_halfs should be output_len - NVTE_CHECK(out_cu_seqlens_q_halfs_shape[0] == output_len); - NVTE_CHECK(out_cu_seqlens_kv_halfs_shape[0] == output_len); + NVTE_CHECK(out_cu_seqlens_kv.dim() == 1); + NVTE_CHECK(out_cu_seqlens_q_halfs_padded.dim() == 1); + NVTE_CHECK(out_cu_seqlens_kv_padded.dim() == 1); const unsigned int block = 256; - const unsigned int grid = (output_len + block - 1) / block; + const unsigned int grid = (batch + 1 + block - 1) / block; thd_seq_tweak_above_diag_kernel<<>>( reinterpret_cast(cu_seqlens_q_halfs.data.dptr), - reinterpret_cast(cu_seqlens_kv_halfs.data.dptr), + reinterpret_cast(cu_seqlens_kv.data.dptr), + reinterpret_cast(cu_seqlens_padded.data.dptr), reinterpret_cast(out_cu_seqlens_q_halfs.data.dptr), - reinterpret_cast(out_cu_seqlens_kv_halfs.data.dptr), batch, output_len, chunk_size); + reinterpret_cast(out_cu_seqlens_kv.data.dptr), + reinterpret_cast(out_cu_seqlens_q_halfs_padded.data.dptr), + reinterpret_cast(out_cu_seqlens_kv_padded.data.dptr), + cp_rank_q, + cp_rank_kv, + cp_size, + chunk_size, + batch + ); + // synchronize + cudaStreamSynchronize(stream); } } // namespace context_parallel @@ -1284,7 +1310,7 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso } void nvte_cp_thd_chunkify(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded, - NVTETensor &out_cu_seqlens, NVTETensor &out_cu_seqlens_padded, + NVTETensor out_cu_seqlens, NVTETensor out_cu_seqlens_padded, int batch, int output_len, int chunk_size, cudaStream_t stream) { NVTE_API_CALL(nvte_thd_chunkify); using namespace transformer_engine; @@ -1296,40 +1322,55 @@ void nvte_cp_thd_chunkify(const NVTETensor &cu_seqlens, const NVTETensor &cu_seq } void nvte_cp_thd_chunkify_p2p(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded, - 
NVTETensor &out_cu_seqlens, NVTETensor &out_cu_seqlens_padded, - int batch, int output_len, int chunk_size, cudaStream_t stream) { + NVTETensor out_cu_seqlens, NVTETensor out_cu_seqlens_padded, + int batch, int output_len, int chunk_size, int cp_rank, int cp_size, + cudaStream_t stream) { NVTE_API_CALL(nvte_thd_chunkify_p2p); using namespace transformer_engine; context_parallel::thd_chunkify_p2p(*convertNVTETensorCheck(cu_seqlens), *convertNVTETensorCheck(cu_seqlens_padded), *convertNVTETensorCheck(out_cu_seqlens), - *convertNVTETensorCheck(out_cu_seqlens_padded), batch, output_len, chunk_size, stream); + *convertNVTETensorCheck(out_cu_seqlens_padded), + batch, output_len, chunk_size, cp_rank, cp_size, stream); } void nvte_cp_thd_seq_tweak_below_diag(const NVTETensor &cu_seqlens_q, const NVTETensor &cu_seqlens_kv_halfs, - NVTETensor &out_cu_seqlens_q, NVTETensor &out_cu_seqlens_kv_halfs, + const NVTETensor &cu_seqlens_padded, + NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv, + NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, + int cp_rank_q, int cp_rank_kv, int cp_size, int batch, int output_len, int chunk_size, cudaStream_t stream) { NVTE_API_CALL(nvte_thd_seq_tweak_below_diag); using namespace transformer_engine; - context_parallel::thd_seq_tweak_below_diag(*convertNVTETensorCheck(cu_seqlens_q), *convertNVTETensorCheck(cu_seqlens_kv_halfs), + *convertNVTETensorCheck(cu_seqlens_padded), *convertNVTETensorCheck(out_cu_seqlens_q), - *convertNVTETensorCheck(out_cu_seqlens_kv_halfs), - batch, output_len, chunk_size, stream); + *convertNVTETensorCheck(out_cu_seqlens_kv), + *convertNVTETensorCheck(out_cu_seqlens_q_padded), + *convertNVTETensorCheck(out_cu_seqlens_kv_padded), + batch, output_len, chunk_size, cp_rank_q, cp_rank_kv, cp_size, stream); } -void nvte_cp_thd_seq_tweak_above_diag(const NVTETensor &cu_seqlens_q_halfs, const NVTETensor &cu_seqlens_kv_halfs, - NVTETensor &out_cu_seqlens_q_halfs, NVTETensor &out_cu_seqlens_kv_halfs, - int batch, int output_len, cudaStream_t stream) { +void nvte_cp_thd_seq_tweak_above_diag(const NVTETensor &cu_seqlens_q_halfs, const NVTETensor &cu_seqlens_kv, + const NVTETensor &cu_seqlens_padded, + NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv, + NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, + int cp_rank_q, int cp_rank_kv, int cp_size, + int batch, int output_len, + int chunk_size, cudaStream_t stream) { NVTE_API_CALL(nvte_thd_seq_tweak_above_diag); using namespace transformer_engine; context_parallel::thd_seq_tweak_above_diag(*convertNVTETensorCheck(cu_seqlens_q_halfs), - *convertNVTETensorCheck(cu_seqlens_kv_halfs), - *convertNVTETensorCheck(out_cu_seqlens_q_halfs), - *convertNVTETensorCheck(out_cu_seqlens_kv_halfs), + *convertNVTETensorCheck(cu_seqlens_kv), + *convertNVTETensorCheck(cu_seqlens_padded), + *convertNVTETensorCheck(out_cu_seqlens_q), + *convertNVTETensorCheck(out_cu_seqlens_kv), + *convertNVTETensorCheck(out_cu_seqlens_q_padded), + *convertNVTETensorCheck(out_cu_seqlens_kv_padded), + cp_rank_q, cp_rank_kv, cp_size, batch, output_len, chunk_size, stream); } diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 4f795faaa5..5a1df3392c 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -756,7 +756,7 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, 
NVTETenso */ void nvte_cp_thd_chunkify(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded, - NVTETensor &out_cu_seqlens, NVTETensor &out_cu_seqlens_padded, + NVTETensor out_cu_seqlens, NVTETensor out_cu_seqlens_padded, int batch, int output_len, int chunk_size, cudaStream_t stream); /*! \brief Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal. @@ -770,21 +770,30 @@ void nvte_cp_thd_chunkify(const NVTETensor &cu_seqlens, const NVTETensor &cu_seq * \param[in] batch Batch size. * \param[in] output_len Output length. * \param[in] chunk_size Chunk size. + * \param[in] cp_rank Context Parallel rank. + * \param[in] cp_size Context Parallel size. * \param[in] stream CUDA stream used for this operation. */ void nvte_cp_thd_chunkify_p2p(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded, - NVTETensor &out_cu_seqlens, NVTETensor &out_cu_seqlens_padded, - int batch, int output_len, int chunk_size, cudaStream_t stream) - + NVTETensor out_cu_seqlens, NVTETensor out_cu_seqlens_padded, + int batch, int output_len, int chunk_size, int cp_rank, int cp_size, + cudaStream_t stream); + /*! \brief Support THD format for Context Parallel: Split sequence into chunks for one P2P part below diagonal. * * \warning This API is **experimental** and subject to change. * * \param[in] cu_seqlens_q Cumulative sequence lengths, [batch_size + 1]. * \param[in] cu_seqlens_kv_halfs Cumulative sequence lengths, [batch_size + 1]. + * \param[in] cu_seqlens_padded Cumulative sequence lengths, [batch_size + 1]. * \param[out] out_cu_seqlens_q Output tensor. - * \param[out] out_cu_seqlens_kv_halfs Output tensor. + * \param[out] out_cu_seqlens_kv Output tensor. + * \param[out] out_cu_seqlens_q_padded Output tensor. + * \param[out] out_cu_seqlens_kv_padded Output tensor. + * \param[in] cp_rank_q Context Parallel rank for Q. + * \param[in] cp_rank_kv Context Parallel rank for KV. + * \param[in] cp_size Context Parallel size. * \param[in] batch Batch size. * \param[in] output_len Output length. * \param[in] chunk_size Chunk size. @@ -792,10 +801,12 @@ void nvte_cp_thd_chunkify_p2p(const NVTETensor &cu_seqlens, const NVTETensor &cu */ void nvte_cp_thd_seq_tweak_below_diag(const NVTETensor &cu_seqlens_q, const NVTETensor &cu_seqlens_kv_halfs, - NVTETensor &out_cu_seqlens_q, NVTETensor &out_cu_seqlens_kv_halfs, + const NVTETensor &cu_seqlens_padded, + NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv, + NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, + int cp_rank_q, int cp_rank_kv, int cp_size, int batch, int output_len, int chunk_size, cudaStream_t stream); - /*! \brief Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal. * * \warning This API is **experimental** and subject to change. @@ -809,9 +820,13 @@ void nvte_cp_thd_seq_tweak_below_diag(const NVTETensor &cu_seqlens_q, const NVTE * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_cp_thd_seq_tweak_above_diag(const NVTETensor &cu_seqlens_q_halfs, const NVTETensor &cu_seqlens_kv_halfs, - NVTETensor &out_cu_seqlens_q_halfs, NVTETensor &out_cu_seqlens_kv_halfs, - int batch, int output_len, int chunk_size, cudaStream_t stream); +void nvte_cp_thd_seq_tweak_above_diag(const NVTETensor &cu_seqlens_q_halfs, const NVTETensor &cu_seqlens_kv, + const NVTETensor &cu_seqlens_padded, + NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv, + NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, + int cp_rank_q, int cp_rank_kv, int cp_size, + int batch, int output_len, + int chunk_size, cudaStream_t stream); /*! \brief Convert tensor from THD to BSHD format. diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py index 2313484054..f52136f5f9 100644 --- a/transformer_engine/debug/features/utils/stats_buffer.py +++ b/transformer_engine/debug/features/utils/stats_buffer.py @@ -84,6 +84,11 @@ def feed(self, tensor, iteration): # It is used for weights and microbatching. if self.modified[0] and not self.reduce_within_microbatch: return + + # We do not feed the tensor with 0 elements, + # we behave the same way as if feed() was not called. + if tensor.numel() == 0: + return # save stats for tensor to tmp buffer for stat_name in self.stats_to_compute: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index fbf4599d91..4e4f45a400 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -747,12 +747,13 @@ def forward( -1, k.shape[2], 2, *k.shape[-2:] ) elif qkv_format == "thd": + thd_total_seq_len = q.shape[0] q_inputs[i % 2] = q if chunk_size is not None: - cu_seqlens_q_per_step[i], cu_seqlens_q_padded_per_step[i] = dpa_utils.thd_chunkify( - cu_seqlens_q_per_step[i], cu_seqlens_q_padded_per_step[i], None, chunk_size, True, rank, cp_size) - cu_seqlens_kv_per_step[i], cu_seqlens_kv_padded_per_step[i] = dpa_utils.thd_chunkify( - cu_seqlens_kv_per_step[i], cu_seqlens_kv_padded_per_step[i], None, chunk_size, True, rank, cp_size) + cu_seqlens_q_per_step[i], cu_seqlens_q_padded_per_step[i] = dpa_utils.thd_chunkify_p2p( + cu_seqlens_q_per_step[i], cu_seqlens_q_padded_per_step[i], chunk_size, rank, cp_size, thd_total_seq_len) + cu_seqlens_kv_per_step[i], cu_seqlens_kv_padded_per_step[i] = dpa_utils.thd_chunkify_p2p( + cu_seqlens_kv_per_step[i], cu_seqlens_kv_padded_per_step[i], chunk_size, rank, cp_size, thd_total_seq_len) if use_fused_attention: if attn_bias is not None: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 03e8bbf6b6..7ee852a9a1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -753,10 +753,11 @@ def forward( ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" 
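For intuition about what the chunkify helpers are expected to produce, here is a hedged pure-Python reference for the simple no-CP path (it ignores padding and the P2P/load-balance variants handled by the CUDA kernels): each sequence described by cu_seqlens is cut at multiples of chunk_size measured from its own start, so only the last chunk of a sequence may be shorter. The helper name below is illustrative and not part of the patch.

import torch

def chunkify_reference(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Split every [start, end) range described by cu_seqlens into chunk_size
    # pieces; the final piece of each sequence may be shorter than chunk_size.
    boundaries = [0]
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        pos = start
        while pos + chunk_size < end:
            pos += chunk_size
            boundaries.append(pos)
        boundaries.append(end)
    return torch.tensor(boundaries, dtype=torch.int32)

cu_seqlens = torch.tensor([0, 5000, 7000], dtype=torch.int32)
print(chunkify_reference(cu_seqlens, 1024))
# tensor([   0, 1024, 2048, 3072, 4096, 5000, 6024, 7000], dtype=torch.int32)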
if self.chunk_size is not None and self.cp_group is None: + total_seq_len = query_layer.shape[0] cu_seqlens_q, cu_seqlens_q_padded = dpa_utils.thd_chunkify( - cu_seqlens_q, cu_seqlens_q_padded, torch.zeros_like(cu_seqlens_q, device=cu_seqlens_q.device), self.chunk_size) + cu_seqlens_q, cu_seqlens_q_padded, self.chunk_size, total_seq_len) cu_seqlens_kv, cu_seqlens_kv_padded = dpa_utils.thd_chunkify( - cu_seqlens_kv, cu_seqlens_kv_padded, torch.zeros_like(cu_seqlens_kv, device=cu_seqlens_kv.device), self.chunk_size) + cu_seqlens_kv, cu_seqlens_kv_padded, self.chunk_size, total_seq_len) batch_size = len(cu_seqlens_q) - 1 if max_seqlen_q is None: if cu_seqlens_q_padded is not None: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index c9a69115c0..aa7d93e9c6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1800,112 +1800,36 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): def thd_chunkify( cu_seqlens: torch.Tensor, cu_seqlens_padded: torch.Tensor, - start_idx: Optional[torch.Tensor] = None, - chunk_size: int = None, - cp_load_balance: bool = False, - cp_rank: Optional[int] = None, - cp_size: Optional[int] = None, + chunk_size: int, + total_seq_len: int = 20, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Chunkify the cu_seqlens tensor. Returns new cu_seqlens, cu_seqlens_padded tensors - - First and last chunks in every sequence can be not full. """ - seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] - pad_seq_lens = cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] - pad_lens = pad_seq_lens - seq_lens - new_cu_seqlens = [] - new_cu_seqlens_padded = [] - - - all_new_seq_lens = [] - all_new_seq_lens_padded = [] - for i in range(seq_lens.size(0)): - new_seq_lens = [] - new_seq_lens_padded = [] - seq_len = seq_lens[i] - new_seq_len = 0 - pad_len = pad_lens[i] - total_seq_len = seq_len + pad_len - middle = total_seq_len // 2 - if start_idx is None: - start_id = total_seq_len // 2 * cp_rank - else: - start_id = start_idx[i] - - while True: - if new_seq_len == 0: - new_chunk_size = (chunk_size - start_id) % chunk_size or chunk_size - else: - new_chunk_size = chunk_size - - if cp_load_balance: - if new_seq_len + new_chunk_size >= middle: - break - else: - if new_seq_len + new_chunk_size > total_seq_len: - break - - new_seq_lens_padded.append(new_chunk_size) - new_seq_len += new_chunk_size - - if cp_load_balance: - last_token_first_part_id = start_id + middle - 1 - total_seq_size = total_seq_len * cp_size - first_token_second_part_id = total_seq_size - last_token_first_part_id - 1 # is symmetrical to last token of first part with respect to the middle of the sequence - - last_chunk_of_first_part_id = last_token_first_part_id // chunk_size - first_chunk_of_second_part_id = first_token_second_part_id // chunk_size - - extend_last_chunk = last_chunk_of_first_part_id == first_chunk_of_second_part_id - - # last chunk of first part - last_chunk_lenght = middle - new_seq_len - - # first chunk of second part - first_chunk_lenght = (chunk_size - first_token_second_part_id) % chunk_size or chunk_size - first_chunk_lenght = min(first_chunk_lenght, middle) - - - if extend_last_chunk: - new_seq_lens_padded.append(int(last_chunk_lenght + first_chunk_lenght)) - else: - new_seq_lens_padded.append(int(last_chunk_lenght)) - new_seq_lens_padded.append(int(first_chunk_lenght)) - new_seq_len += 
first_chunk_lenght + last_chunk_lenght - - while True: - if new_seq_len + chunk_size > total_seq_len: - break - - new_seq_lens_padded.append(chunk_size) - new_seq_len += chunk_size - - last_chunk_lenght = 0 - if new_seq_len != total_seq_len: - last_chunk_lenght = total_seq_len - new_seq_len - new_seq_lens_padded.append(int(last_chunk_lenght)) - new_seq_len += last_chunk_lenght + new_cu_seqlens, new_cu_seqlens_padded = tex.thd_chunkify( + cu_seqlens, cu_seqlens_padded, total_seq_len, chunk_size) - unpadded_left = int(seq_len) - for chunk in new_seq_lens_padded: - new_seq_lens.append(int(min(unpadded_left, chunk))) - unpadded_left -= min(unpadded_left, chunk) - - all_new_seq_lens.extend(new_seq_lens) - all_new_seq_lens_padded.extend(new_seq_lens_padded) - - assert new_seq_len == total_seq_len + return new_cu_seqlens, new_cu_seqlens_padded - new_cu_seqlens = torch.zeros(len(all_new_seq_lens) + 1, device=cu_seqlens.device).to(torch.int32) - new_cu_seqlens[1:] = torch.cumsum(torch.tensor(all_new_seq_lens), dim=0) - new_cu_seqlens_padded = torch.zeros(len(all_new_seq_lens_padded) + 1, device=cu_seqlens.device).to(torch.int32) - new_cu_seqlens_padded[1:] = torch.cumsum(torch.tensor(all_new_seq_lens_padded), dim=0) +def thd_chunkify_p2p( + cu_seqlens: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + chunk_size: int, + cp_rank: int, + cp_size: int, + total_seq_len: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Chunkify the cu_seqlens tensor. + Returns new cu_seqlens, cu_seqlens_padded tensors + """ + + new_cu_seqlens, new_cu_seqlens_padded = tex.thd_chunkify_p2p( + cu_seqlens, cu_seqlens_padded, total_seq_len, chunk_size, cp_rank, cp_size) return new_cu_seqlens, new_cu_seqlens_padded -#@jit_fuser def thd_seq_tweak_below_diagonal( cu_seqlens_q: torch.Tensor, cu_seqlens_kv_halfs: torch.Tensor, @@ -1916,87 +1840,15 @@ def thd_seq_tweak_below_diagonal( chunk_size: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: assert cp_rank_q > cp_rank_kv - seq_lens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - seq_lens_kv_halfs = cu_seqlens_kv_halfs[1:] - cu_seqlens_kv_halfs[:-1] - seq_plus_pad = cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] - half_seq_lens = seq_plus_pad // 2 - pad_lens_q = seq_plus_pad - seq_lens_q - pad_lens_kv = half_seq_lens - seq_lens_kv_halfs - - last_kv_id = (cp_rank_kv + 1) * half_seq_lens - 1 - last_kv_chunk_id = last_kv_id // chunk_size - last_kv_chunk_len = torch.min( - half_seq_lens - (last_kv_chunk_id * chunk_size - cp_rank_kv * half_seq_lens), - half_seq_lens - ) - - first_half_q_first_chunk_id = (cp_rank_q * half_seq_lens) // chunk_size - second_half_q_first_chunk_id = ( - (2 * cp_size - cp_rank_q - 1) * half_seq_lens - ) // chunk_size - - first_half_q_first_chunk_len = torch.min( - half_seq_lens, ((first_half_q_first_chunk_id + 1) * chunk_size) - (cp_rank_q * half_seq_lens)) - second_half_q_first_chunk_len = torch.min( - seq_lens_q, half_seq_lens + ((second_half_q_first_chunk_id + 1) * chunk_size) - ((2 * cp_size - 1 - cp_rank_q) * half_seq_lens)) - - take_nothing = last_kv_chunk_id != first_half_q_first_chunk_id - take_first_half_q = (~take_nothing) & (last_kv_chunk_id != second_half_q_first_chunk_id) - take_second_half_q = (~take_nothing) & (~take_first_half_q) - - zeros = lambda: torch.zeros_like(seq_lens_q) - - q_seq_len = zeros() - - q_seq_len[take_nothing] = 0 - q_seq_len[take_first_half_q] = first_half_q_first_chunk_len[take_first_half_q] - q_seq_len[take_second_half_q] = second_half_q_first_chunk_len[take_second_half_q] - - kv_seq_len = 
zeros() + half_seq_lens - kv_seq_len[~take_nothing] = last_kv_chunk_len[~take_nothing] - - - q_seq_len = torch.min(q_seq_len, torch.max(zeros(), seq_plus_pad - pad_lens_q)) - kv_seq_len = torch.max(zeros(), kv_seq_len - pad_lens_kv) - - # 1. Build per-sequence chunk sizes - # Q chunks : [0, sequence, 0] - # Q pads: [0, 0, garbage] - # KV chunks : [0, sequence, 0] - # KV pads: [garbage, 0, 0] - q_0 = zeros() - q_1 = q_seq_len - q_2 = zeros() - q_chunks = torch.stack((q_0, q_1, q_2), dim=1) - - q_0_pad = zeros() - q_1_pad = zeros() - q_2_pad = seq_plus_pad - q_seq_len - q_pads = torch.stack((q_0_pad, q_1_pad, q_2_pad), dim=1) - - kv_0 = zeros() - kv_1 = kv_seq_len - kv_2 = zeros() - kv_chunks = torch.stack((kv_0, kv_1, kv_2), dim=1) - - - kv_0_pad = half_seq_lens - kv_seq_len - pad_lens_kv - kv_1_pad = zeros() - kv_2_pad = pad_lens_kv - kv_pads = torch.stack((kv_0_pad, kv_1_pad, kv_2_pad), dim=1) + new_seqlens_q, new_seqlens_kv_halfs, new_cu_seqlens_q_padded, new_cu_seqlens_kv_padded = tex.thd_seq_tweak_below_diag( + cu_seqlens_q, cu_seqlens_kv_halfs, cu_seqlens_padded, cp_rank_q, cp_rank_kv, cp_size, chunk_size + ) - q_chunks_padded = q_chunks + q_pads - kv_chunks_padded = kv_chunks + kv_pads + new_cu_seqlens_q = torch.cumsum(new_seqlens_q, dim=0) + new_cu_seqlens_kv_halfs = torch.cumsum(new_seqlens_kv_halfs, dim=0) - cu_seqlens_q_per_step = q_chunks.flatten().cumsum(0).to(torch.int32) - cu_seqlens_kv_per_step = kv_chunks.flatten().cumsum(0).to(torch.int32) - cu_seqlens_q_padded_per_step = q_chunks_padded.flatten().cumsum(0).to(torch.int32) - cu_seqlens_kv_padded_per_step = kv_chunks_padded.flatten().cumsum(0).to(torch.int32) - - return cu_seqlens_q_per_step, cu_seqlens_kv_per_step, cu_seqlens_q_padded_per_step, cu_seqlens_kv_padded_per_step - -#@jit_fuser + return new_cu_seqlens_q, new_cu_seqlens_kv_halfs, new_cu_seqlens_q_padded, new_cu_seqlens_kv_padded def thd_seq_tweak_above_diagonal( cu_seqlens_q_halfs: torch.Tensor, @@ -2008,87 +1860,12 @@ def thd_seq_tweak_above_diagonal( chunk_size: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: assert cp_rank_q < cp_rank_kv - - seq_lens_q_halfs = cu_seqlens_q_halfs[1:] - cu_seqlens_q_halfs[:-1] - seq_lens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - seq_plus_pad = cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] - half_seq_lens = seq_plus_pad // 2 - pad_lens_q = half_seq_lens - seq_lens_q_halfs - pad_lens_kv = seq_plus_pad - seq_lens_kv - - first_q_id = (2 * cp_size - 1 - cp_rank_q) * half_seq_lens - first_q_chunk_id = first_q_id // chunk_size - first_q_chunk_len = torch.min( - seq_lens_q_halfs, - (first_q_chunk_id + 1) * chunk_size - first_q_id - ) - - first_half_kv_last_el_total_id = ((cp_rank_kv + 1) * half_seq_lens - 1) - first_half_kv_last_chunk_id = first_half_kv_last_el_total_id // chunk_size - second_half_kv_last_el_total_id = (2 * cp_size - cp_rank_kv) * half_seq_lens - 1 - second_half_kv_last_chunk_id = second_half_kv_last_el_total_id // chunk_size - - - # these 2 are not easy - first_half_kv_last_el_id_in_chunk = first_half_kv_last_el_total_id % chunk_size - first_half_kv_last_chunk_len = torch.min(seq_plus_pad, half_seq_lens + first_half_kv_last_el_id_in_chunk + 1) - second_half_kv_ast_el_id_in_chunk = second_half_kv_last_el_total_id % chunk_size - second_half_kv_last_chunk_len = torch.min(half_seq_lens, second_half_kv_ast_el_id_in_chunk + 1) - take_nothing = first_q_chunk_id != second_half_kv_last_chunk_id - take_second_half_kv = (~take_nothing) & (first_q_chunk_id != first_half_kv_last_chunk_id) - 
take_first_half_kv = (~take_nothing) & (~take_second_half_kv) - - - # Helper aliases - zeros = lambda: torch.zeros_like(seq_lens_q_halfs) - - q_seq_len = zeros() - q_seq_len[~take_nothing] = (first_q_chunk_len)[~take_nothing] - - kv_seq_len = zeros() - kv_seq_len[take_nothing] = 0 - kv_seq_len[take_second_half_kv] = second_half_kv_last_chunk_len[take_second_half_kv] - kv_seq_len[take_first_half_kv] = first_half_kv_last_chunk_len[take_first_half_kv] - - kv_seq_len = torch.max(zeros(), kv_seq_len - pad_lens_kv) - - # 1. Build per-sequence chunk sizes - # Q chunks : [0, sequence, 0] - # Q pads : [0, 0, garbage] - # KV chunks : [0, sequence, 0] - # KV pads : [garbage, 0, 0] - q_0 = zeros() - q_1 = q_seq_len - q_2 = zeros() - q_chunks = torch.stack((q_0, q_1, q_2), dim=1) - - q_0_pad = zeros() - q_1_pad = zeros() - q_2_pad = half_seq_lens - q_seq_len - q_pads = torch.stack((q_0_pad, q_1_pad, q_2_pad), dim=1) - - - kv_0 = zeros() - kv_1 = kv_seq_len - kv_2 = zeros() - kv_chunks = torch.stack((kv_0, kv_1, kv_2), dim=1) - - kv_0_pad = seq_lens_kv - kv_seq_len - kv_1_pad = zeros() - kv_2_pad = seq_plus_pad - kv_0_pad - kv_1 - kv_pads = torch.stack((kv_0_pad, kv_1_pad, kv_2_pad), dim=1) - - # 2. Padded variants – keep padding only in the final chunk - q_chunks_padded = q_chunks + q_pads - - kv_chunks_padded = kv_chunks + kv_pads - - # moze trzeba dorzucic zero - - cu_seqlens_q_per_step = q_chunks.flatten().cumsum(0).to(torch.int32) - cu_seqlens_kv_per_step = kv_chunks.flatten().cumsum(0).to(torch.int32) - cu_seqlens_q_padded_per_step = q_chunks_padded.flatten().cumsum(0).to(torch.int32) - cu_seqlens_kv_padded_per_step = kv_chunks_padded.flatten().cumsum(0).to(torch.int32) + new_seqlens_q, new_seqlens_kv, new_cu_seqlens_q_padded, new_cu_seqlens_kv_padded = tex.thd_seq_tweak_above_diag( + cu_seqlens_q_halfs, cu_seqlens_kv, cu_seqlens_padded, + cp_rank_q, cp_rank_kv, cp_size, chunk_size + ) + new_cu_seqlens_q_halfs = torch.cumsum(new_seqlens_q, dim=0) + new_cu_seqlens_kv = torch.cumsum(new_seqlens_kv, dim=0) - return cu_seqlens_q_per_step, cu_seqlens_kv_per_step, cu_seqlens_q_padded_per_step, cu_seqlens_kv_padded_per_step \ No newline at end of file + return new_cu_seqlens_q_halfs, new_cu_seqlens_kv, new_cu_seqlens_q_padded, new_cu_seqlens_kv_padded \ No newline at end of file diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 361c24b22c..846220b6fd 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -303,6 +303,20 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens, int world_size, int rank); +std::vector thd_chunkify(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded, + int total_seq_len, int chunk_size) ; + +std::vector thd_chunkify_p2p(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded, + int total_seq_len, int chunk_size, int cp_rank, int cp_size) ; + +std::vector thd_seq_tweak_below_diag(at::Tensor cu_seqlens_q, at::Tensor cu_seqlens_kv_halfs, + at::Tensor cu_seqlens_padded, int cp_rank_q, + int cp_rank_kv, int cp_size, int chunk_size); + +std::vector thd_seq_tweak_above_diag(at::Tensor cu_seqlens_q_halfs, at::Tensor cu_seqlens_kv, + at::Tensor cu_seqlens_padded, int cp_rank_q, + int cp_rank_kv, int cp_size, int chunk_size); + /*************************************************************************************************** * multi_tensor_* kernels 
**************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index f041c3300b..9f578dad1f 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -755,11 +755,11 @@ std::vector thd_chunkify(at::Tensor cu_seqlens, at::Tensor cu_seqlen NVTE_CHECK(chunk_size > 0); int batch = cu_seqlens.size(0) - 1; - int output_len = cu_seqlens_padded.size(0) + total_seq_len / chunk_size; + int output_len = cu_seqlens_padded.size(0) + total_seq_len / chunk_size + 1; // Allocate output tensors - at::Tensor out_cu_seqlens = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); - at::Tensor out_cu_seqlens_padded = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_cu_seqlens = at::empty({output_len}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_cu_seqlens_padded = at::empty({output_len}, at::CUDA(at::ScalarType::Int)); // Create tensor wrappers auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens); @@ -781,17 +781,17 @@ std::vector thd_chunkify(at::Tensor cu_seqlens, at::Tensor cu_seqlen **************************************************************************************************/ std::vector thd_chunkify_p2p(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded, - int total_seq_len, int chunk_size, int world_size, int rank) { + int total_seq_len, int chunk_size, int cp_rank, int cp_size) { NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.dim() == 1); NVTE_CHECK(cu_seqlens.size(0) >= 2); NVTE_CHECK(cu_seqlens_padded.scalar_type() == at::ScalarType::Int); int batch = cu_seqlens.size(0) - 1; - int output_len = cu_seqlens_padded.size(0) + total_seq_len / chunk_size; + int output_len = 5 * cu_seqlens_padded.size(0) + total_seq_len / chunk_size; - at::Tensor out_cu_seqlens = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); - at::Tensor out_cu_seqlens_padded = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_cu_seqlens = at::zeros({output_len}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_cu_seqlens_padded = at::zeros({output_len}, at::CUDA(at::ScalarType::Int)); auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens); auto te_cu_seqlens_padded = makeTransformerEngineTensor(cu_seqlens_padded); @@ -801,7 +801,7 @@ std::vector thd_chunkify_p2p(at::Tensor cu_seqlens, at::Tensor cu_se nvte_cp_thd_chunkify_p2p( te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_out_cu_seqlens.data(), te_out_cu_seqlens_padded.data(), - batch, output_len, chunk_size, world_size, rank, at::cuda::getCurrentCUDAStream() + batch, output_len, chunk_size, cp_rank, cp_size, at::cuda::getCurrentCUDAStream() ); return {out_cu_seqlens, out_cu_seqlens_padded}; @@ -813,7 +813,7 @@ std::vector thd_chunkify_p2p(at::Tensor cu_seqlens, at::Tensor cu_se std::vector thd_seq_tweak_below_diag(at::Tensor cu_seqlens_q, at::Tensor cu_seqlens_kv_halfs, at::Tensor cu_seqlens_padded, int cp_rank_q, - int cp_rank_kv, int cp_size, int chunk_size, int total_seq_len) { + int cp_rank_kv, int cp_size, int chunk_size) { NVTE_CHECK(cu_seqlens_q.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens_kv_halfs.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens_padded.scalar_type() == at::ScalarType::Int); @@ -829,24 +829,29 @@ std::vector thd_seq_tweak_below_diag(at::Tensor cu_seqlens_q, 
at::Te NVTE_CHECK(chunk_size > 0); int batch = cu_seqlens_q.size(0) - 1; - int output_len = 5 * cu_seqlens_padded.size(0) + total_seq_len / chunk_size; + int output_len = 3 * batch; - at::Tensor out_cu_seqlens_q = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); - at::Tensor out_cu_seqlens_kv_halfs = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_seqlens_q = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_seqlens_kv = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_cu_seqlens_q_padded = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_cu_seqlens_kv_padded = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int)); auto te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q); auto te_cu_seqlens_kv_halfs = makeTransformerEngineTensor(cu_seqlens_kv_halfs); auto te_cu_seqlens_padded = makeTransformerEngineTensor(cu_seqlens_padded); - auto te_out_cu_seqlens_q = makeTransformerEngineTensor(out_cu_seqlens_q); - auto te_out_cu_seqlens_kv_halfs = makeTransformerEngineTensor(out_cu_seqlens_kv_halfs); + auto te_out_seqlens_q = makeTransformerEngineTensor(out_seqlens_q); + auto te_out_seqlens_kv = makeTransformerEngineTensor(out_seqlens_kv); + auto te_out_cu_seqlens_q_padded = makeTransformerEngineTensor(out_cu_seqlens_q_padded); + auto te_out_cu_seqlens_kv_padded = makeTransformerEngineTensor(out_cu_seqlens_kv_padded); nvte_cp_thd_seq_tweak_below_diag( te_cu_seqlens_q.data(), te_cu_seqlens_kv_halfs.data(), - te_cu_seqlens_padded.data(), te_out_cu_seqlens_q.data(), te_out_cu_seqlens_kv_halfs.data(), - batch, output_len, chunk_size, cp_rank_q, cp_rank_kv, cp_size, at::cuda::getCurrentCUDAStream() + te_cu_seqlens_padded.data(), te_out_seqlens_q.data(), te_out_seqlens_kv.data(), + te_out_cu_seqlens_q_padded.data(), te_out_cu_seqlens_kv_padded.data(), + cp_rank_q, cp_rank_kv, cp_size, batch, output_len, chunk_size, at::cuda::getCurrentCUDAStream() ); - return {out_cu_seqlens_q, out_cu_seqlens_kv_halfs}; + return {out_seqlens_q, out_seqlens_kv, out_cu_seqlens_q_padded, out_cu_seqlens_kv_padded}; } @@ -855,32 +860,37 @@ std::vector thd_seq_tweak_below_diag(at::Tensor cu_seqlens_q, at::Te * Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal. 
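 * Note: mirroring thd_seq_tweak_below_diag above, the first two returned tensors hold
 * per-sequence segment lengths rather than cumulative offsets; the Python wrapper
 * (thd_seq_tweak_above_diagonal in utils.py) converts them with torch.cumsum, while the
 * *_padded outputs are already cumulative.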
**************************************************************************************************/ -std::vector thd_seq_tweak_above_diag(at::Tensor cu_seqlens_q_halfs, at::Tensor cu_seqlens_kv_halfs, +std::vector thd_seq_tweak_above_diag(at::Tensor cu_seqlens_q_halfs, at::Tensor cu_seqlens_kv, at::Tensor cu_seqlens_padded, int cp_rank_q, - int cp_rank_kv, int cp_size, int chunk_size, int total_seq_len) { + int cp_rank_kv, int cp_size, int chunk_size) { NVTE_CHECK(cu_seqlens_q_halfs.scalar_type() == at::ScalarType::Int); - NVTE_CHECK(cu_seqlens_kv_halfs.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens_kv.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens_padded.scalar_type() == at::ScalarType::Int); int batch = cu_seqlens_q_halfs.size(0) - 1; - int output_len = 5 * cu_seqlens_padded.size(0) + total_seq_len / chunk_size; + int output_len = 3 * batch; - at::Tensor out_cu_seqlens_q_halfs = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); - at::Tensor out_cu_seqlens_kv_halfs = at::zeros({output_len + 1}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_seqlens_q_halfs = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_seqlens_kv = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_cu_seqlens_q_padded = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int)); + at::Tensor out_cu_seqlens_kv_padded = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int)); auto te_cu_seqlens_q_halfs = makeTransformerEngineTensor(cu_seqlens_q_halfs); - auto te_cu_seqlens_kv_halfs = makeTransformerEngineTensor(cu_seqlens_kv_halfs); + auto te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv); auto te_cu_seqlens_padded = makeTransformerEngineTensor(cu_seqlens_padded); - auto te_out_cu_seqlens_q_halfs = makeTransformerEngineTensor(out_cu_seqlens_q_halfs); - auto te_out_cu_seqlens_kv_halfs = makeTransformerEngineTensor(out_cu_seqlens_kv_halfs); + auto te_out_seqlens_q_halfs = makeTransformerEngineTensor(out_seqlens_q_halfs); + auto te_out_seqlens_kv = makeTransformerEngineTensor(out_seqlens_kv); + auto te_out_cu_seqlens_q_padded = makeTransformerEngineTensor(out_cu_seqlens_q_padded); + auto te_out_cu_seqlens_kv_padded = makeTransformerEngineTensor(out_cu_seqlens_kv_padded); nvte_cp_thd_seq_tweak_above_diag( - te_cu_seqlens_q_halfs.data(), te_cu_seqlens_kv_halfs.data(), - te_cu_seqlens_padded.data(), te_out_cu_seqlens_q_halfs.data(), te_out_cu_seqlens_kv_halfs.data(), - batch, output_len, chunk_size, cp_rank_q, cp_rank_kv, cp_size, at::cuda::getCurrentCUDAStream() + te_cu_seqlens_q_halfs.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_padded.data(), te_out_seqlens_q_halfs.data(), te_out_seqlens_kv.data(), + te_out_cu_seqlens_q_padded.data(), te_out_cu_seqlens_kv_padded.data(), + cp_rank_q, cp_rank_kv, cp_size, batch, output_len, chunk_size, at::cuda::getCurrentCUDAStream() ); - return {out_cu_seqlens_q_halfs, out_cu_seqlens_kv_halfs}; + return {out_seqlens_q_halfs, out_seqlens_kv, out_cu_seqlens_q_padded, out_cu_seqlens_kv_padded}; } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 0a1b76e697..8a6f54d45e 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -279,6 +279,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("thd_get_partitioned_indices", &transformer_engine::pytorch::thd_get_partitioned_indices, "Generate partitioned indices for inputs in THD 
format", py::call_guard()); + m.def("thd_chunkify", &transformer_engine::pytorch::thd_chunkify, "Chunkify THD tensor", + py::arg("cu_seqlens"), py::arg("cu_seqlens_padded"), py::arg("total_seq_len"), py::arg("chunk_size"), + py::call_guard()); + m.def("thd_chunkify_p2p", &transformer_engine::pytorch::thd_chunkify_p2p, "Chunkify THD tensor for P2P communication", + py::call_guard()); + m.def("thd_seq_tweak_below_diag", &transformer_engine::pytorch::thd_seq_tweak_below_diag, "Tweak the sequence below the diagonal for THD tensor", + py::call_guard()); + m.def("thd_seq_tweak_above_diag", &transformer_engine::pytorch::thd_seq_tweak_above_diag, "Tweak the sequence above the diagonal for THD tensor", + py::call_guard()); // nvshmem functions m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_nvshmem_backend, From 4563f977a00f57018e7a9e144c1a64660c6a174a Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 17 Jun 2025 16:26:00 +0200 Subject: [PATCH 4/5] tests passing Signed-off-by: Pawel Gadzinski --- tests/pytorch/fused_attn/run_fused_attn_with_cp.py | 1 + .../dot_product_attention/context_parallel.py | 10 ++-------- .../pytorch/attention/dot_product_attention/utils.py | 8 ++++---- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 918f878c03..9c1e0f3bb3 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -286,6 +286,7 @@ def run_dpa_with_cp( cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) + if fp8_mha: dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 4e4f45a400..eb8e04010a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -685,11 +685,6 @@ def forward( p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) send_recv_reqs = [[], []] - import torch.distributed as dist - if dist.get_rank() == 0: - print(f"rank = {rank}") - dist.barrier() - out = None for i in range(cp_size + 1): if i < cp_size: @@ -1885,7 +1880,7 @@ def backward(ctx, dout): ) fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - + dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, @@ -2012,7 +2007,7 @@ def backward(ctx, dout): ) fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - + dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, @@ -2140,7 +2135,6 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) - if ctx.fp8: dq_ = dq_._data dk_ = dk_._data diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index aa7d93e9c6..61b209cedb 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1845,8 +1845,8 @@ def thd_seq_tweak_below_diagonal( cu_seqlens_q, cu_seqlens_kv_halfs, cu_seqlens_padded, cp_rank_q, cp_rank_kv, cp_size, 
chunk_size ) - new_cu_seqlens_q = torch.cumsum(new_seqlens_q, dim=0) - new_cu_seqlens_kv_halfs = torch.cumsum(new_seqlens_kv_halfs, dim=0) + new_cu_seqlens_q = torch.cumsum(new_seqlens_q, dim=0, dtype=torch.int32) + new_cu_seqlens_kv_halfs = torch.cumsum(new_seqlens_kv_halfs, dim=0, dtype=torch.int32) return new_cu_seqlens_q, new_cu_seqlens_kv_halfs, new_cu_seqlens_q_padded, new_cu_seqlens_kv_padded @@ -1865,7 +1865,7 @@ def thd_seq_tweak_above_diagonal( cu_seqlens_q_halfs, cu_seqlens_kv, cu_seqlens_padded, cp_rank_q, cp_rank_kv, cp_size, chunk_size ) - new_cu_seqlens_q_halfs = torch.cumsum(new_seqlens_q, dim=0) - new_cu_seqlens_kv = torch.cumsum(new_seqlens_kv, dim=0) + new_cu_seqlens_q_halfs = torch.cumsum(new_seqlens_q, dim=0, dtype=torch.int32) + new_cu_seqlens_kv = torch.cumsum(new_seqlens_kv, dim=0, dtype=torch.int32) return new_cu_seqlens_q_halfs, new_cu_seqlens_kv, new_cu_seqlens_q_padded, new_cu_seqlens_kv_padded \ No newline at end of file From e32654ddf704746310e979cef6a1c82eb59f1e1d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Jun 2025 14:39:53 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../fused_attn/run_fused_attn_with_cp.py | 5 +- tests/pytorch/fused_attn/test_fused_attn.py | 1 + .../fused_attn/test_fused_attn_with_cp.py | 8 +- .../common/fused_attn/context_parallel.cu | 834 ++++++++---------- .../include/transformer_engine/fused_attn.h | 32 +- .../debug/features/utils/stats_buffer.py | 2 +- .../dot_product_attention/context_parallel.py | 81 +- .../dot_product_attention.py | 11 +- .../attention/dot_product_attention/utils.py | 55 +- .../pytorch/cpp_extensions/fused_attn.py | 1 - transformer_engine/pytorch/csrc/extensions.h | 19 +- .../pytorch/csrc/extensions/attention.cpp | 56 +- .../pytorch/csrc/extensions/pybind.cpp | 16 +- 13 files changed, 548 insertions(+), 573 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 9c1e0f3bb3..1e46ecb8c8 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -41,7 +41,7 @@ def run_dpa_with_cp( if kernel_backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" config = model_configs_fused_attn[model] - + assert config.attn_mask_type in [ "causal", "no_mask", @@ -273,7 +273,6 @@ def run_dpa_with_cp( else: fp8_context = nullcontext() - with fp8_context: out_ = core_attn( q_, @@ -298,8 +297,6 @@ def run_dpa_with_cp( assert isinstance(out_, Float8Tensor) out = out.dequantize() out_ = out_.dequantize() - - for x in [out_, q_.grad, k_.grad, v_.grad]: assert torch.all(~torch.isnan(x)) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 1857e11d8e..87b77b9b11 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -113,6 +113,7 @@ def __init__( self.max_ctx_len = max_ctx_len self.chunk_size = chunk_size + @contextmanager def logging_context(highest_level=logging.WARNING): previous_level = logging.root.manager.disable diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 1ef9a0659b..385e39bbfd 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -24,7 +24,9 @@ "cp_1_3": 
ModelConfig( 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) ), # MHA - "cp_1_4": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", chunk_size=1024), # MHA with chunks + "cp_1_4": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", chunk_size=1024 + ), # MHA with chunks "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA "cp_2_2": ModelConfig( @@ -101,7 +103,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_4": ModelConfig( 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) ), # MHA - "cp_1_5": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", chunk_size=1024), # MHA + "cp_1_5": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", chunk_size=1024 + ), # MHA "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA diff --git a/transformer_engine/common/fused_attn/context_parallel.cu b/transformer_engine/common/fused_attn/context_parallel.cu index 0a35479d67..e8e2fddbbb 100644 --- a/transformer_engine/common/fused_attn/context_parallel.cu +++ b/transformer_engine/common/fused_attn/context_parallel.cu @@ -302,389 +302,351 @@ __global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, in } } - /*************************************************************************************************** * Support THD format for Context Parallel: Split sequence into chunks. **************************************************************************************************/ -__global__ void thd_chunkify_kernel(const int* __restrict__ d_cu_seqlens, - const int* __restrict__ d_cu_seqlens_padded, - int* __restrict__ d_out_cu_seqlens, - int* __restrict__ d_out_cu_seqlens_padded, - int batch, // = len(cu_seqlens)-1 - int output_len, - int chunk_size) -{ - const int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= output_len) return; - if (i == 0) - { - d_out_cu_seqlens[0] = 0; - d_out_cu_seqlens_padded[0] = 0; - return; - } - - int pos_id = 0; - int seq_len = 0; - int pad_len = 0; - int seq_start = d_cu_seqlens[batch]; - int pad_start = d_cu_seqlens_padded[batch]; - int cur_start = 0; - int cur_last = -1; - bool found = false; - - for (int j = 0; j < batch; ++j) - { - cur_start = (j == 0) ? 1 : cur_last + 1; - - int d_cu_seqlens_padded_j = d_cu_seqlens_padded[j]; - int d_cu_seqlens_padded_j_1 = d_cu_seqlens_padded[j + 1]; - int d_cu_seqlens_j = d_cu_seqlens[j]; - int d_cu_seqlens_j_1 = d_cu_seqlens[j + 1]; - - int num_chunks = (d_cu_seqlens_padded_j_1 - d_cu_seqlens_padded_j + (chunk_size - 1)) / chunk_size; - cur_last = cur_start + num_chunks - 1; - - bool match = (i >= cur_start) && (i <= cur_last); - pos_id = match ? (i - cur_start) : pos_id; - seq_len = match ? (d_cu_seqlens_j_1 - d_cu_seqlens_j) : seq_len; - pad_len = match ? (d_cu_seqlens_padded_j_1 - d_cu_seqlens_padded_j) : pad_len; - seq_start = match ? d_cu_seqlens_j : seq_start; - pad_start = match ? 
d_cu_seqlens_padded_j : pad_start; - found = match || found; - } +__global__ void thd_chunkify_kernel(const int *__restrict__ d_cu_seqlens, + const int *__restrict__ d_cu_seqlens_padded, + int *__restrict__ d_out_cu_seqlens, + int *__restrict__ d_out_cu_seqlens_padded, + int batch, // = len(cu_seqlens)-1 + int output_len, int chunk_size) { + const int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= output_len) return; + if (i == 0) { + d_out_cu_seqlens[0] = 0; + d_out_cu_seqlens_padded[0] = 0; + return; + } - if (!found) - { - d_out_cu_seqlens[i] = d_cu_seqlens[batch]; - d_out_cu_seqlens_padded[i] = d_cu_seqlens_padded[batch]; - } - else{ + int pos_id = 0; + int seq_len = 0; + int pad_len = 0; + int seq_start = d_cu_seqlens[batch]; + int pad_start = d_cu_seqlens_padded[batch]; + int cur_start = 0; + int cur_last = -1; + bool found = false; + + for (int j = 0; j < batch; ++j) { + cur_start = (j == 0) ? 1 : cur_last + 1; + + int d_cu_seqlens_padded_j = d_cu_seqlens_padded[j]; + int d_cu_seqlens_padded_j_1 = d_cu_seqlens_padded[j + 1]; + int d_cu_seqlens_j = d_cu_seqlens[j]; + int d_cu_seqlens_j_1 = d_cu_seqlens[j + 1]; + + int num_chunks = + (d_cu_seqlens_padded_j_1 - d_cu_seqlens_padded_j + (chunk_size - 1)) / chunk_size; + cur_last = cur_start + num_chunks - 1; + + bool match = (i >= cur_start) && (i <= cur_last); + pos_id = match ? (i - cur_start) : pos_id; + seq_len = match ? (d_cu_seqlens_j_1 - d_cu_seqlens_j) : seq_len; + pad_len = match ? (d_cu_seqlens_padded_j_1 - d_cu_seqlens_padded_j) : pad_len; + seq_start = match ? d_cu_seqlens_j : seq_start; + pad_start = match ? d_cu_seqlens_padded_j : pad_start; + found = match || found; + } - int32_t out_seq = ((pos_id > (seq_len / chunk_size)) ? seq_len : chunk_size * pos_id) + seq_start; - int32_t out_pad = ((pos_id > (pad_len / chunk_size)) ? pad_len : chunk_size * pos_id) + pad_start; + if (!found) { + d_out_cu_seqlens[i] = d_cu_seqlens[batch]; + d_out_cu_seqlens_padded[i] = d_cu_seqlens_padded[batch]; + } else { + int32_t out_seq = + ((pos_id > (seq_len / chunk_size)) ? seq_len : chunk_size * pos_id) + seq_start; + int32_t out_pad = + ((pos_id > (pad_len / chunk_size)) ? pad_len : chunk_size * pos_id) + pad_start; - d_out_cu_seqlens[i] = out_seq; - d_out_cu_seqlens_padded[i] = out_pad; - } + d_out_cu_seqlens[i] = out_seq; + d_out_cu_seqlens_padded[i] = out_pad; + } } - /*************************************************************************************************** * Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal. 
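 * Each thread produces one entry of the (cu_seqlens, cu_seqlens_padded) output pair: it scans the
 * batch prefix sums to locate the sequence its output index belongs to, then derives chunk
 * boundaries separately for the two half-shards this CP rank holds (start_id_1 / start_id_2),
 * merging the two middle chunks when they fall into the same global chunk (merge_chunks).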
**************************************************************************************************/ +__global__ void thd_chunkify_p2p_kernel(const int *__restrict__ d_cu_seqlens, + const int *__restrict__ d_cu_seqlens_padded, + int *__restrict__ d_out_cu_seqlens, + int *__restrict__ d_out_cu_seqlens_padded, int batch, + int output_len, int chunk_size, int cp_rank, int cp_size) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= output_len) return; + if (i == 0) { + d_out_cu_seqlens[0] = 0; + d_out_cu_seqlens_padded[0] = 0; + return; + } -__global__ void thd_chunkify_p2p_kernel( - const int* __restrict__ d_cu_seqlens, - const int* __restrict__ d_cu_seqlens_padded, - int* __restrict__ d_out_cu_seqlens, - int* __restrict__ d_out_cu_seqlens_padded, - int batch, - int output_len, - int chunk_size, - int cp_rank, - int cp_size) -{ - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i >= output_len) return; - if (i == 0) - { - d_out_cu_seqlens[0] = 0; - d_out_cu_seqlens_padded[0] = 0; - return; - } - - int pos_id = 0; - int seq_len = 0; - int total_seq_len = 0; - int seq_start_offset = d_cu_seqlens[batch]; - int pad_start_offset = d_cu_seqlens_padded[batch]; - int cur_start = 1; - int cur_last = 0; - - for (int j = 0; j < batch; ++j) { - cur_start = (j == 0) ? 1 : cur_last + 1; - int padded_len = d_cu_seqlens_padded[j + 1] - d_cu_seqlens_padded[j]; - int num_chunks = (padded_len + chunk_size - 1) / chunk_size; - cur_last = cur_start + num_chunks + 3; - if (i >= cur_start && (i <= cur_last || j == batch - 1)) { - pos_id = i - cur_start; - seq_len = d_cu_seqlens[j + 1] - d_cu_seqlens[j]; - total_seq_len = padded_len; - seq_start_offset = d_cu_seqlens[j]; - pad_start_offset = d_cu_seqlens_padded[j]; - break; - } + int pos_id = 0; + int seq_len = 0; + int total_seq_len = 0; + int seq_start_offset = d_cu_seqlens[batch]; + int pad_start_offset = d_cu_seqlens_padded[batch]; + int cur_start = 1; + int cur_last = 0; + + for (int j = 0; j < batch; ++j) { + cur_start = (j == 0) ? 1 : cur_last + 1; + int padded_len = d_cu_seqlens_padded[j + 1] - d_cu_seqlens_padded[j]; + int num_chunks = (padded_len + chunk_size - 1) / chunk_size; + cur_last = cur_start + num_chunks + 3; + if (i >= cur_start && (i <= cur_last || j == batch - 1)) { + pos_id = i - cur_start; + seq_len = d_cu_seqlens[j + 1] - d_cu_seqlens[j]; + total_seq_len = padded_len; + seq_start_offset = d_cu_seqlens[j]; + pad_start_offset = d_cu_seqlens_padded[j]; + break; } + } + if (total_seq_len == 0) { + d_out_cu_seqlens[i] = d_cu_seqlens[batch]; + d_out_cu_seqlens_padded[i] = d_cu_seqlens_padded[batch]; + return; + } - if (total_seq_len == 0) { - d_out_cu_seqlens[i] = d_cu_seqlens[batch]; - d_out_cu_seqlens_padded[i] = d_cu_seqlens_padded[batch]; - return; - } - - int middle = total_seq_len / 2; - - int start_id_1 = (total_seq_len * cp_rank) / 2; - int temp = (chunk_size - start_id_1 - 1) % chunk_size; - int first_chunk_size_1 = ((temp < 0) ? temp + chunk_size : temp) + 1; - first_chunk_size_1 = (first_chunk_size_1 >= middle) ? 0 : first_chunk_size_1; - int num_chunks_1 = (middle - first_chunk_size_1 - 1) / chunk_size; - int last_chunk_size_1 = middle - num_chunks_1 * chunk_size - first_chunk_size_1; - int last_token_id_1 = start_id_1 + middle - 1; - int last_chunk_id_1 = last_token_id_1 / chunk_size; - - int start_id_2 = total_seq_len * cp_size - last_token_id_1 - 1; - int first_chunk_id_2 = start_id_2 / chunk_size; - int temp2 = (chunk_size - start_id_2 - 1) % chunk_size; - int first_chunk_size_2 = ((temp2 < 0) ? 
temp2 + chunk_size : temp2) + 1; - int num_chunks_2 = (middle - first_chunk_size_2) / chunk_size; - - bool merge_chunks = (last_chunk_id_1 == first_chunk_id_2); - int middle_chunk_1 = merge_chunks ? (last_chunk_size_1 + first_chunk_size_2) : last_chunk_size_1; - int middle_chunk_2 = merge_chunks ? 0 : first_chunk_size_2; - - int out_pad_len = 0; - if (pos_id == 0) { - out_pad_len = first_chunk_size_1; - } else if (pos_id <= num_chunks_1 + 1) { - out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size; - } else if (pos_id == num_chunks_1 + 2) { - out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size + middle_chunk_1; - } else if (pos_id == num_chunks_1 + 3) { - out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size + middle_chunk_1 + middle_chunk_2; - } else if (pos_id <= num_chunks_1 + 3 + num_chunks_2) { - out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size + middle_chunk_1 + middle_chunk_2 + num_chunks_2 * chunk_size; - } else { - out_pad_len = total_seq_len; - } + int middle = total_seq_len / 2; + + int start_id_1 = (total_seq_len * cp_rank) / 2; + int temp = (chunk_size - start_id_1 - 1) % chunk_size; + int first_chunk_size_1 = ((temp < 0) ? temp + chunk_size : temp) + 1; + first_chunk_size_1 = (first_chunk_size_1 >= middle) ? 0 : first_chunk_size_1; + int num_chunks_1 = (middle - first_chunk_size_1 - 1) / chunk_size; + int last_chunk_size_1 = middle - num_chunks_1 * chunk_size - first_chunk_size_1; + int last_token_id_1 = start_id_1 + middle - 1; + int last_chunk_id_1 = last_token_id_1 / chunk_size; + + int start_id_2 = total_seq_len * cp_size - last_token_id_1 - 1; + int first_chunk_id_2 = start_id_2 / chunk_size; + int temp2 = (chunk_size - start_id_2 - 1) % chunk_size; + int first_chunk_size_2 = ((temp2 < 0) ? temp2 + chunk_size : temp2) + 1; + int num_chunks_2 = (middle - first_chunk_size_2) / chunk_size; + + bool merge_chunks = (last_chunk_id_1 == first_chunk_id_2); + int middle_chunk_1 = merge_chunks ? (last_chunk_size_1 + first_chunk_size_2) : last_chunk_size_1; + int middle_chunk_2 = merge_chunks ? 0 : first_chunk_size_2; + + int out_pad_len = 0; + if (pos_id == 0) { + out_pad_len = first_chunk_size_1; + } else if (pos_id <= num_chunks_1 + 1) { + out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size; + } else if (pos_id == num_chunks_1 + 2) { + out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size + middle_chunk_1; + } else if (pos_id == num_chunks_1 + 3) { + out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size + middle_chunk_1 + middle_chunk_2; + } else if (pos_id <= num_chunks_1 + 3 + num_chunks_2) { + out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size + middle_chunk_1 + middle_chunk_2 + + num_chunks_2 * chunk_size; + } else { + out_pad_len = total_seq_len; + } - int out_seq_len = (out_pad_len < seq_len) ? out_pad_len : seq_len; + int out_seq_len = (out_pad_len < seq_len) ? out_pad_len : seq_len; - d_out_cu_seqlens[i] = seq_start_offset + out_seq_len; - d_out_cu_seqlens_padded[i] = pad_start_offset + out_pad_len; + d_out_cu_seqlens[i] = seq_start_offset + out_seq_len; + d_out_cu_seqlens_padded[i] = pad_start_offset + out_pad_len; } - /*************************************************************************************************** * Support THD format for Context Parallel: Split sequence into chunks for one P2P part below diagonal. 
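 * One thread per sequence in the batch: each thread emits three segment lengths
 * ([0, q_seq_len, 0] into q_chunks, [0, kv_seq_len, 0] into kv_chunks) together with the matching
 * padded offsets, and thread 0 additionally writes the leading zero of all four output tensors.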
**************************************************************************************************/ __global__ void thd_seq_tweak_below_diag_kernel( - const int* __restrict__ cu_seqlens_q, - const int* __restrict__ cu_seqlens_kv_halfs, - const int* __restrict__ cu_seqlens_padded, - int* __restrict__ q_chunks, - int* __restrict__ kv_chunks, - int* __restrict__ q_pads, - int* __restrict__ kv_pads, - int cp_rank_q, - int cp_rank_kv, - int cp_size, - int chunk_size, - int batch) -{ - const int i = blockIdx.x * blockDim.x + threadIdx.x; - - if(i == 0) { - q_chunks[0] = 0; - kv_chunks[0] = 0; - q_pads[0] = 0; - kv_pads[0] = 0; - } - __syncthreads(); - - if (i >= batch) return; - - // ───────── prefix-sum diffs ──────────────────────────────── - const int32_t q_start = cu_seqlens_q[i]; - const int32_t q_end = cu_seqlens_q[i + 1]; - const int32_t kv_start = cu_seqlens_kv_halfs[i]; - const int32_t kv_end = cu_seqlens_kv_halfs[i + 1]; - const int32_t pad_start = cu_seqlens_padded[i]; - const int32_t pad_end = cu_seqlens_padded[i + 1]; - - const int32_t seq_len_q = q_end - q_start; - const int32_t seq_len_kv_half = kv_end - kv_start; - const int32_t seq_plus_pad = pad_end - pad_start; - const int32_t half_seq_len = seq_plus_pad >> 1; - const int32_t pad_len_q = seq_plus_pad - seq_len_q; - const int32_t pad_len_kv = half_seq_len - seq_len_kv_half; - - // ───────── below-diagonal logic ─────────────────────────── - const int32_t last_kv_id = (cp_rank_kv + 1) * half_seq_len - 1; - const int32_t last_kv_chunk_id = last_kv_id / chunk_size; - const int32_t last_kv_chunk_len = min( - half_seq_len - - (last_kv_chunk_id * chunk_size - cp_rank_kv * half_seq_len), - half_seq_len); - - const int32_t first_half_id = (cp_rank_q * half_seq_len) / chunk_size; - const int32_t second_half_id = ((2 * cp_size - cp_rank_q - 1) * half_seq_len) - / chunk_size; - - const int32_t first_half_len = min( - half_seq_len, - (first_half_id + 1) * chunk_size - cp_rank_q * half_seq_len); - - const int32_t second_half_len = min( - seq_len_q, - half_seq_len + - (second_half_id + 1) * chunk_size - - (2 * cp_size - 1 - cp_rank_q) * half_seq_len); - - const bool take_nothing = (last_kv_chunk_id != first_half_id); - const bool take_first_half_q = (!take_nothing) && - (last_kv_chunk_id != second_half_id); - const bool take_second_half_q = (!take_nothing) && (!take_first_half_q); - - int32_t q_seq_len = 0; - if (take_first_half_q) q_seq_len = first_half_len; - else if (take_second_half_q) q_seq_len = second_half_len; - - int32_t kv_seq_len = half_seq_len; - if (!take_nothing) kv_seq_len = last_kv_chunk_len; - - q_seq_len = min(q_seq_len, max(0, seq_plus_pad - pad_len_q)); - kv_seq_len = max(0, kv_seq_len - pad_len_kv); - - // ───────── flat output (row-major) ───────────────────────── - const int out_base = 3 * i; - - q_chunks[out_base + 0 + 1] = 0; - q_chunks[out_base + 1 + 1] = q_seq_len; - q_chunks[out_base + 2 + 1] = 0; - - q_pads [out_base + 0 + 1] = pad_start; - q_pads [out_base + 1 + 1] = pad_start + q_seq_len; - q_pads [out_base + 2 + 1] = pad_start + seq_plus_pad; - - kv_chunks[out_base + 0 + 1] = 0; - kv_chunks[out_base + 1 + 1] = kv_seq_len; - kv_chunks[out_base + 2 + 1] = 0; - - const int32_t kv_pad_base = pad_start >> 1; - kv_pads[out_base + 0 + 1] = kv_pad_base + (half_seq_len - kv_seq_len - pad_len_kv); - kv_pads[out_base + 1 + 1] = kv_pad_base + (half_seq_len - pad_len_kv); - kv_pads[out_base + 2 + 1] = kv_pad_base + half_seq_len; -} + const int *__restrict__ cu_seqlens_q, const int *__restrict__ cu_seqlens_kv_halfs, + const 
int *__restrict__ cu_seqlens_padded, int *__restrict__ q_chunks, + int *__restrict__ kv_chunks, int *__restrict__ q_pads, int *__restrict__ kv_pads, int cp_rank_q, + int cp_rank_kv, int cp_size, int chunk_size, int batch) { + const int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i == 0) { + q_chunks[0] = 0; + kv_chunks[0] = 0; + q_pads[0] = 0; + kv_pads[0] = 0; + } + __syncthreads(); + if (i >= batch) return; + + // ───────── prefix-sum diffs ──────────────────────────────── + const int32_t q_start = cu_seqlens_q[i]; + const int32_t q_end = cu_seqlens_q[i + 1]; + const int32_t kv_start = cu_seqlens_kv_halfs[i]; + const int32_t kv_end = cu_seqlens_kv_halfs[i + 1]; + const int32_t pad_start = cu_seqlens_padded[i]; + const int32_t pad_end = cu_seqlens_padded[i + 1]; + + const int32_t seq_len_q = q_end - q_start; + const int32_t seq_len_kv_half = kv_end - kv_start; + const int32_t seq_plus_pad = pad_end - pad_start; + const int32_t half_seq_len = seq_plus_pad >> 1; + const int32_t pad_len_q = seq_plus_pad - seq_len_q; + const int32_t pad_len_kv = half_seq_len - seq_len_kv_half; + + // ───────── below-diagonal logic ─────────────────────────── + const int32_t last_kv_id = (cp_rank_kv + 1) * half_seq_len - 1; + const int32_t last_kv_chunk_id = last_kv_id / chunk_size; + const int32_t last_kv_chunk_len = + min(half_seq_len - (last_kv_chunk_id * chunk_size - cp_rank_kv * half_seq_len), half_seq_len); + + const int32_t first_half_id = (cp_rank_q * half_seq_len) / chunk_size; + const int32_t second_half_id = ((2 * cp_size - cp_rank_q - 1) * half_seq_len) / chunk_size; + + const int32_t first_half_len = + min(half_seq_len, (first_half_id + 1) * chunk_size - cp_rank_q * half_seq_len); + + const int32_t second_half_len = min(seq_len_q, half_seq_len + (second_half_id + 1) * chunk_size - + (2 * cp_size - 1 - cp_rank_q) * half_seq_len); + + const bool take_nothing = (last_kv_chunk_id != first_half_id); + const bool take_first_half_q = (!take_nothing) && (last_kv_chunk_id != second_half_id); + const bool take_second_half_q = (!take_nothing) && (!take_first_half_q); + + int32_t q_seq_len = 0; + if (take_first_half_q) + q_seq_len = first_half_len; + else if (take_second_half_q) + q_seq_len = second_half_len; + + int32_t kv_seq_len = half_seq_len; + if (!take_nothing) kv_seq_len = last_kv_chunk_len; + + q_seq_len = min(q_seq_len, max(0, seq_plus_pad - pad_len_q)); + kv_seq_len = max(0, kv_seq_len - pad_len_kv); + + // ───────── flat output (row-major) ───────────────────────── + const int out_base = 3 * i; + + q_chunks[out_base + 0 + 1] = 0; + q_chunks[out_base + 1 + 1] = q_seq_len; + q_chunks[out_base + 2 + 1] = 0; + + q_pads[out_base + 0 + 1] = pad_start; + q_pads[out_base + 1 + 1] = pad_start + q_seq_len; + q_pads[out_base + 2 + 1] = pad_start + seq_plus_pad; + + kv_chunks[out_base + 0 + 1] = 0; + kv_chunks[out_base + 1 + 1] = kv_seq_len; + kv_chunks[out_base + 2 + 1] = 0; + + const int32_t kv_pad_base = pad_start >> 1; + kv_pads[out_base + 0 + 1] = kv_pad_base + (half_seq_len - kv_seq_len - pad_len_kv); + kv_pads[out_base + 1 + 1] = kv_pad_base + (half_seq_len - pad_len_kv); + kv_pads[out_base + 2 + 1] = kv_pad_base + half_seq_len; +} /*************************************************************************************************** * Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal. 
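 * Per-sequence logic: q is taken from the first chunk of this rank's mirrored second half and kv
 * from the last chunk of the peer's KV half; when those chunks do not coincide (take_nothing) the
 * step contributes an empty segment, and the kv length is clamped against the per-sequence padding.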
**************************************************************************************************/ - __global__ void thd_seq_tweak_above_diag_kernel( - const int* __restrict__ cu_seqlens_q_halfs, // len = batch+1 - const int* __restrict__ cu_seqlens_kv, // len = batch+1 - const int* __restrict__ cu_seqlens_padded, // len = batch+1 (full‑len prefix sums) - int* __restrict__ q_chunks, // len = 3·batch (row‑major) - int* __restrict__ kv_chunks, // len = 3·batch - int* __restrict__ q_pads, // len = 3·batch - int* __restrict__ kv_pads, // len = 3·batch - int cp_rank_q, - int cp_rank_kv, - int cp_size, - int chunk_size, - int batch) -{ - const int i = blockIdx.x * blockDim.x + threadIdx.x; - - if(i == 0) { - q_chunks[0] = 0; - kv_chunks[0] = 0; - q_pads[0] = 0; - kv_pads[0] = 0; - } - __syncthreads(); - if (i >= batch) return; - - // ───────── prefix‑sum diffs ─────────────────────────────── - const int32_t q_start = cu_seqlens_q_halfs[i]; - const int32_t q_end = cu_seqlens_q_halfs[i + 1]; - const int32_t kv_start = cu_seqlens_kv[i]; - const int32_t kv_end = cu_seqlens_kv[i + 1]; - const int32_t pad_start = cu_seqlens_padded[i]; - const int32_t pad_end = cu_seqlens_padded[i + 1]; - - const int32_t seq_len_q_half = q_end - q_start; // |Q|/2 (actual) - const int32_t seq_len_kv = kv_end - kv_start; // |KV| (full) - const int32_t seq_plus_pad = pad_end - pad_start; // |KV|+pads (full) - - const int32_t half_seq_len = seq_plus_pad >> 1; // |KV|/2 + pads/2 - - // pad lengths for later clamping - const int32_t pad_len_kv = seq_plus_pad - seq_len_kv; - - // ───────── above‑diagonal core logic ───────────────────── - // 1. Tokens from Q (first half from the *opposite* side) - const int32_t first_q_id = (2 * cp_size - 1 - cp_rank_q) * half_seq_len; - const int32_t first_q_chunk_id = first_q_id / chunk_size; - const int32_t first_q_chunk_len = min( - seq_len_q_half, - (first_q_chunk_id + 1) * chunk_size - first_q_id); - - // 2. Tokens from KV (might come from first or second half) - const int32_t first_half_kv_last_el_total_id = ((cp_rank_kv + 1) * half_seq_len) - 1; - const int32_t first_half_kv_last_chunk_id = first_half_kv_last_el_total_id / chunk_size; - - const int32_t second_half_kv_last_el_total_id = ((2 * cp_size - cp_rank_kv) * half_seq_len) - 1; - const int32_t second_half_kv_last_chunk_id = second_half_kv_last_el_total_id / chunk_size; - - // last‑trimmed chunk lengths - const int32_t first_half_kv_last_el_id_in_chunk = first_half_kv_last_el_total_id % chunk_size; - const int32_t first_half_kv_last_chunk_len = min( - seq_plus_pad, - half_seq_len + first_half_kv_last_el_id_in_chunk + 1); - - const int32_t second_half_kv_last_el_id_in_chunk = second_half_kv_last_el_total_id % chunk_size; - const int32_t second_half_kv_last_chunk_len = min( - half_seq_len, - second_half_kv_last_el_id_in_chunk + 1); - - // 3. Decide which half(s) we actually take - const bool take_nothing = (first_q_chunk_id != second_half_kv_last_chunk_id); - const bool take_second_half_kv = (!take_nothing) && (first_q_chunk_id != first_half_kv_last_chunk_id); - const bool take_first_half_kv = (!take_nothing) && (!take_second_half_kv); - - // ───────── resulting subseq lengths ─────────────────────── - int32_t q_seq_len = take_nothing ? 
0 : first_q_chunk_len; - - int32_t kv_seq_len = 0; - if (take_second_half_kv) kv_seq_len = second_half_kv_last_chunk_len; - else if (take_first_half_kv) kv_seq_len = first_half_kv_last_chunk_len; - - // clamp against padding - kv_seq_len = max(0, kv_seq_len - pad_len_kv); - - // ───────── flat output (row‑major) ──────────────────────── - const int out_base = 3 * i; - - // Q chunks: [0, sequence, 0] - q_chunks[out_base + 0 + 1] = 0; // beginning of Q half‑sequence - q_chunks[out_base + 1 + 1] = q_seq_len; // after the chunk we keep - q_chunks[out_base + 2 + 1] = 0; // stays flat afterwards - - // Q pads: [0, 0, garbage] (pads live in the *first* half of padded area) - const int32_t half_pad_start = pad_start >> 1; // start of Q‑related pad area - q_pads[out_base + 0 + 1] = half_pad_start; - q_pads[out_base + 1 + 1] = half_pad_start + q_seq_len; - q_pads[out_base + 2 + 1] = half_pad_start + half_seq_len; // complete half padded length - - // KV chunks: [0, sequence, 0] - kv_chunks[out_base + 0 + 1] = 0; - kv_chunks[out_base + 1 + 1] = kv_seq_len; - kv_chunks[out_base + 2 + 1] = 0; - - // KV pads: [garbage, 0, 0] (pads precede KV if from first half) - kv_pads[out_base + 0 + 1] = pad_start + (seq_len_kv - kv_seq_len); - kv_pads[out_base + 1 + 1] = pad_start + seq_len_kv; - kv_pads[out_base + 2 + 1] = pad_start + seq_plus_pad; + const int *__restrict__ cu_seqlens_q_halfs, // len = batch+1 + const int *__restrict__ cu_seqlens_kv, // len = batch+1 + const int *__restrict__ cu_seqlens_padded, // len = batch+1 (full‑len prefix sums) + int *__restrict__ q_chunks, // len = 3·batch (row‑major) + int *__restrict__ kv_chunks, // len = 3·batch + int *__restrict__ q_pads, // len = 3·batch + int *__restrict__ kv_pads, // len = 3·batch + int cp_rank_q, int cp_rank_kv, int cp_size, int chunk_size, int batch) { + const int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i == 0) { + q_chunks[0] = 0; + kv_chunks[0] = 0; + q_pads[0] = 0; + kv_pads[0] = 0; + } + __syncthreads(); + if (i >= batch) return; + + // ───────── prefix‑sum diffs ─────────────────────────────── + const int32_t q_start = cu_seqlens_q_halfs[i]; + const int32_t q_end = cu_seqlens_q_halfs[i + 1]; + const int32_t kv_start = cu_seqlens_kv[i]; + const int32_t kv_end = cu_seqlens_kv[i + 1]; + const int32_t pad_start = cu_seqlens_padded[i]; + const int32_t pad_end = cu_seqlens_padded[i + 1]; + + const int32_t seq_len_q_half = q_end - q_start; // |Q|/2 (actual) + const int32_t seq_len_kv = kv_end - kv_start; // |KV| (full) + const int32_t seq_plus_pad = pad_end - pad_start; // |KV|+pads (full) + + const int32_t half_seq_len = seq_plus_pad >> 1; // |KV|/2 + pads/2 + + // pad lengths for later clamping + const int32_t pad_len_kv = seq_plus_pad - seq_len_kv; + + // ───────── above‑diagonal core logic ───────────────────── + // 1. Tokens from Q (first half from the *opposite* side) + const int32_t first_q_id = (2 * cp_size - 1 - cp_rank_q) * half_seq_len; + const int32_t first_q_chunk_id = first_q_id / chunk_size; + const int32_t first_q_chunk_len = + min(seq_len_q_half, (first_q_chunk_id + 1) * chunk_size - first_q_id); + + // 2. 
Tokens from KV (might come from first or second half) + const int32_t first_half_kv_last_el_total_id = ((cp_rank_kv + 1) * half_seq_len) - 1; + const int32_t first_half_kv_last_chunk_id = first_half_kv_last_el_total_id / chunk_size; + + const int32_t second_half_kv_last_el_total_id = ((2 * cp_size - cp_rank_kv) * half_seq_len) - 1; + const int32_t second_half_kv_last_chunk_id = second_half_kv_last_el_total_id / chunk_size; + + // last‑trimmed chunk lengths + const int32_t first_half_kv_last_el_id_in_chunk = first_half_kv_last_el_total_id % chunk_size; + const int32_t first_half_kv_last_chunk_len = + min(seq_plus_pad, half_seq_len + first_half_kv_last_el_id_in_chunk + 1); + + const int32_t second_half_kv_last_el_id_in_chunk = second_half_kv_last_el_total_id % chunk_size; + const int32_t second_half_kv_last_chunk_len = + min(half_seq_len, second_half_kv_last_el_id_in_chunk + 1); + + // 3. Decide which half(s) we actually take + const bool take_nothing = (first_q_chunk_id != second_half_kv_last_chunk_id); + const bool take_second_half_kv = + (!take_nothing) && (first_q_chunk_id != first_half_kv_last_chunk_id); + const bool take_first_half_kv = (!take_nothing) && (!take_second_half_kv); + + // ───────── resulting subseq lengths ─────────────────────── + int32_t q_seq_len = take_nothing ? 0 : first_q_chunk_len; + + int32_t kv_seq_len = 0; + if (take_second_half_kv) + kv_seq_len = second_half_kv_last_chunk_len; + else if (take_first_half_kv) + kv_seq_len = first_half_kv_last_chunk_len; + + // clamp against padding + kv_seq_len = max(0, kv_seq_len - pad_len_kv); + + // ───────── flat output (row‑major) ──────────────────────── + const int out_base = 3 * i; + + // Q chunks: [0, sequence, 0] + q_chunks[out_base + 0 + 1] = 0; // beginning of Q half‑sequence + q_chunks[out_base + 1 + 1] = q_seq_len; // after the chunk we keep + q_chunks[out_base + 2 + 1] = 0; // stays flat afterwards + + // Q pads: [0, 0, garbage] (pads live in the *first* half of padded area) + const int32_t half_pad_start = pad_start >> 1; // start of Q‑related pad area + q_pads[out_base + 0 + 1] = half_pad_start; + q_pads[out_base + 1 + 1] = half_pad_start + q_seq_len; + q_pads[out_base + 2 + 1] = half_pad_start + half_seq_len; // complete half padded length + + // KV chunks: [0, sequence, 0] + kv_chunks[out_base + 0 + 1] = 0; + kv_chunks[out_base + 1 + 1] = kv_seq_len; + kv_chunks[out_base + 2 + 1] = 0; + + // KV pads: [garbage, 0, 0] (pads precede KV if from first half) + kv_pads[out_base + 0 + 1] = pad_start + (seq_len_kv - kv_seq_len); + kv_pads[out_base + 1 + 1] = pad_start + seq_len_kv; + kv_pads[out_base + 2 + 1] = pad_start + seq_plus_pad; } - /*************************************************************************************************** * Support THD format for Context Parallel: Read the half of a THD tensor @@ -1057,16 +1019,9 @@ void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int to * Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal. 
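 * Host-side launcher: validates that the cu_seqlens tensors are int32 and fills the two output
 * tensors with cumulative chunk boundaries; the kernel is indexed with one thread per output entry.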
**************************************************************************************************/ -void thd_chunkify( - const Tensor &cu_seqlens, - const Tensor &cu_seqlens_padded, - Tensor &out_cu_seqlens, - Tensor &out_cu_seqlens_padded, - int batch, - int output_len, - int chunk_size, - cudaStream_t stream -) { +void thd_chunkify(const Tensor &cu_seqlens, const Tensor &cu_seqlens_padded, Tensor &out_cu_seqlens, + Tensor &out_cu_seqlens_padded, int batch, int output_len, int chunk_size, + cudaStream_t stream) { using namespace transformer_engine; NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32); NVTE_CHECK(cu_seqlens_padded.dtype() == DType::kInt32); @@ -1091,18 +1046,10 @@ void thd_chunkify( * Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal. **************************************************************************************************/ -void thd_chunkify_p2p( - const Tensor &cu_seqlens, - const Tensor &cu_seqlens_padded, - Tensor &out_cu_seqlens, - Tensor &out_cu_seqlens_padded, - int batch, - int output_len, - int chunk_size, - int cp_rank, - int cp_size, - cudaStream_t stream -) { +void thd_chunkify_p2p(const Tensor &cu_seqlens, const Tensor &cu_seqlens_padded, + Tensor &out_cu_seqlens, Tensor &out_cu_seqlens_padded, int batch, + int output_len, int chunk_size, int cp_rank, int cp_size, + cudaStream_t stream) { using namespace transformer_engine; NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32); NVTE_CHECK(cu_seqlens_padded.dtype() == DType::kInt32); @@ -1115,36 +1062,26 @@ void thd_chunkify_p2p( NVTE_CHECK(out_cu_seqlens.dim() == 1); NVTE_CHECK(out_cu_seqlens_padded.dim() == 1); - const unsigned int block = 256; const unsigned int grid = (output_len + block - 1) / block; thd_chunkify_p2p_kernel<<>>( reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(cu_seqlens_padded.data.dptr), reinterpret_cast(out_cu_seqlens.data.dptr), - reinterpret_cast(out_cu_seqlens_padded.data.dptr), batch, output_len, chunk_size, cp_rank, cp_size); + reinterpret_cast(out_cu_seqlens_padded.data.dptr), batch, output_len, chunk_size, + cp_rank, cp_size); } /*************************************************************************************************** * Support THD format for Context Parallel: Split sequence into chunks for one P2P part below diagonal. 
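 * Host-side launcher: forwards the three prefix-sum inputs plus the (cp_rank_q, cp_rank_kv) pair
 * to the below-diagonal kernel and fills four outputs: q / kv-half segment lengths and their
 * padded cumulative counterparts.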
**************************************************************************************************/ -void thd_seq_tweak_below_diag( - const Tensor &cu_seqlens_q, - const Tensor &cu_seqlens_kv_halfs, - const Tensor &cu_seqlens_padded, - Tensor &out_cu_seqlens_q, - Tensor &out_cu_seqlens_kv_halfs, - Tensor &out_cu_seqlens_q_padded, - Tensor &out_cu_seqlens_kv_halfs_padded, - int batch, - int output_len, - int chunk_size, - int cp_rank_q, - int cp_rank_kv, - int cp_size, - cudaStream_t stream -) { +void thd_seq_tweak_below_diag(const Tensor &cu_seqlens_q, const Tensor &cu_seqlens_kv_halfs, + const Tensor &cu_seqlens_padded, Tensor &out_cu_seqlens_q, + Tensor &out_cu_seqlens_kv_halfs, Tensor &out_cu_seqlens_q_padded, + Tensor &out_cu_seqlens_kv_halfs_padded, int batch, int output_len, + int chunk_size, int cp_rank_q, int cp_rank_kv, int cp_size, + cudaStream_t stream) { using namespace transformer_engine; NVTE_CHECK(cu_seqlens_q.dtype() == DType::kInt32); @@ -1167,36 +1104,20 @@ void thd_seq_tweak_below_diag( reinterpret_cast(out_cu_seqlens_q.data.dptr), reinterpret_cast(out_cu_seqlens_kv_halfs.data.dptr), reinterpret_cast(out_cu_seqlens_q_padded.data.dptr), - reinterpret_cast(out_cu_seqlens_kv_halfs_padded.data.dptr), - cp_rank_q, - cp_rank_kv, - cp_size, - chunk_size, - batch - ); + reinterpret_cast(out_cu_seqlens_kv_halfs_padded.data.dptr), cp_rank_q, cp_rank_kv, + cp_size, chunk_size, batch); } /*************************************************************************************************** * Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal. **************************************************************************************************/ - -void thd_seq_tweak_above_diag( - const Tensor &cu_seqlens_q_halfs, - const Tensor &cu_seqlens_kv, - const Tensor &cu_seqlens_padded, - Tensor &out_cu_seqlens_q_halfs, - Tensor &out_cu_seqlens_kv, - Tensor &out_cu_seqlens_q_halfs_padded, - Tensor &out_cu_seqlens_kv_padded, - int cp_rank_q, - int cp_rank_kv, - int cp_size, - int batch, - int output_len, - int chunk_size, - cudaStream_t stream -) { +void thd_seq_tweak_above_diag(const Tensor &cu_seqlens_q_halfs, const Tensor &cu_seqlens_kv, + const Tensor &cu_seqlens_padded, Tensor &out_cu_seqlens_q_halfs, + Tensor &out_cu_seqlens_kv, Tensor &out_cu_seqlens_q_halfs_padded, + Tensor &out_cu_seqlens_kv_padded, int cp_rank_q, int cp_rank_kv, + int cp_size, int batch, int output_len, int chunk_size, + cudaStream_t stream) { using namespace transformer_engine; NVTE_CHECK(cu_seqlens_q_halfs.dtype() == DType::kInt32); @@ -1225,13 +1146,8 @@ void thd_seq_tweak_above_diag( reinterpret_cast(out_cu_seqlens_q_halfs.data.dptr), reinterpret_cast(out_cu_seqlens_kv.data.dptr), reinterpret_cast(out_cu_seqlens_q_halfs_padded.data.dptr), - reinterpret_cast(out_cu_seqlens_kv_padded.data.dptr), - cp_rank_q, - cp_rank_kv, - cp_size, - chunk_size, - batch - ); + reinterpret_cast(out_cu_seqlens_kv_padded.data.dptr), cp_rank_q, cp_rank_kv, cp_size, + chunk_size, batch); // synchronize cudaStreamSynchronize(stream); } @@ -1310,67 +1226,57 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso } void nvte_cp_thd_chunkify(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded, - NVTETensor out_cu_seqlens, NVTETensor out_cu_seqlens_padded, - int batch, int output_len, int chunk_size, cudaStream_t stream) { + NVTETensor out_cu_seqlens, NVTETensor out_cu_seqlens_padded, int batch, + int output_len, int chunk_size, cudaStream_t stream) { 
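  // Thin C-API shim: unwrap the NVTETensor handles and forward to the
  // context_parallel::thd_chunkify implementation above.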
NVTE_API_CALL(nvte_thd_chunkify); using namespace transformer_engine; - context_parallel::thd_chunkify(*convertNVTETensorCheck(cu_seqlens), - *convertNVTETensorCheck(cu_seqlens_padded), - *convertNVTETensorCheck(out_cu_seqlens), - *convertNVTETensorCheck(out_cu_seqlens_padded), batch, output_len, chunk_size, stream); + context_parallel::thd_chunkify( + *convertNVTETensorCheck(cu_seqlens), *convertNVTETensorCheck(cu_seqlens_padded), + *convertNVTETensorCheck(out_cu_seqlens), *convertNVTETensorCheck(out_cu_seqlens_padded), + batch, output_len, chunk_size, stream); } void nvte_cp_thd_chunkify_p2p(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded, NVTETensor out_cu_seqlens, NVTETensor out_cu_seqlens_padded, - int batch, int output_len, int chunk_size, int cp_rank, int cp_size, + int batch, int output_len, int chunk_size, int cp_rank, int cp_size, cudaStream_t stream) { NVTE_API_CALL(nvte_thd_chunkify_p2p); using namespace transformer_engine; - context_parallel::thd_chunkify_p2p(*convertNVTETensorCheck(cu_seqlens), - *convertNVTETensorCheck(cu_seqlens_padded), - *convertNVTETensorCheck(out_cu_seqlens), - *convertNVTETensorCheck(out_cu_seqlens_padded), - batch, output_len, chunk_size, cp_rank, cp_size, stream); + context_parallel::thd_chunkify_p2p( + *convertNVTETensorCheck(cu_seqlens), *convertNVTETensorCheck(cu_seqlens_padded), + *convertNVTETensorCheck(out_cu_seqlens), *convertNVTETensorCheck(out_cu_seqlens_padded), + batch, output_len, chunk_size, cp_rank, cp_size, stream); } -void nvte_cp_thd_seq_tweak_below_diag(const NVTETensor &cu_seqlens_q, const NVTETensor &cu_seqlens_kv_halfs, - const NVTETensor &cu_seqlens_padded, - NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv, - NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, - int cp_rank_q, int cp_rank_kv, int cp_size, - int batch, int output_len, int chunk_size, cudaStream_t stream) { +void nvte_cp_thd_seq_tweak_below_diag( + const NVTETensor &cu_seqlens_q, const NVTETensor &cu_seqlens_kv_halfs, + const NVTETensor &cu_seqlens_padded, NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv, + NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, int cp_rank_q, + int cp_rank_kv, int cp_size, int batch, int output_len, int chunk_size, cudaStream_t stream) { NVTE_API_CALL(nvte_thd_seq_tweak_below_diag); using namespace transformer_engine; - context_parallel::thd_seq_tweak_below_diag(*convertNVTETensorCheck(cu_seqlens_q), - *convertNVTETensorCheck(cu_seqlens_kv_halfs), - *convertNVTETensorCheck(cu_seqlens_padded), - *convertNVTETensorCheck(out_cu_seqlens_q), - *convertNVTETensorCheck(out_cu_seqlens_kv), - *convertNVTETensorCheck(out_cu_seqlens_q_padded), - *convertNVTETensorCheck(out_cu_seqlens_kv_padded), - batch, output_len, chunk_size, cp_rank_q, cp_rank_kv, cp_size, stream); + context_parallel::thd_seq_tweak_below_diag( + *convertNVTETensorCheck(cu_seqlens_q), *convertNVTETensorCheck(cu_seqlens_kv_halfs), + *convertNVTETensorCheck(cu_seqlens_padded), *convertNVTETensorCheck(out_cu_seqlens_q), + *convertNVTETensorCheck(out_cu_seqlens_kv), *convertNVTETensorCheck(out_cu_seqlens_q_padded), + *convertNVTETensorCheck(out_cu_seqlens_kv_padded), batch, output_len, chunk_size, cp_rank_q, + cp_rank_kv, cp_size, stream); } -void nvte_cp_thd_seq_tweak_above_diag(const NVTETensor &cu_seqlens_q_halfs, const NVTETensor &cu_seqlens_kv, - const NVTETensor &cu_seqlens_padded, - NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv, - NVTETensor out_cu_seqlens_q_padded, NVTETensor 
out_cu_seqlens_kv_padded, - int cp_rank_q, int cp_rank_kv, int cp_size, - int batch, int output_len, - int chunk_size, cudaStream_t stream) { - NVTE_API_CALL(nvte_thd_seq_tweak_above_diag); +void nvte_cp_thd_seq_tweak_above_diag( + const NVTETensor &cu_seqlens_q_halfs, const NVTETensor &cu_seqlens_kv, + const NVTETensor &cu_seqlens_padded, NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv, + NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, int cp_rank_q, + int cp_rank_kv, int cp_size, int batch, int output_len, int chunk_size, cudaStream_t stream) { + NVTE_API_CALL(nvte_thd_seq_tweak_above_diag); using namespace transformer_engine; - context_parallel::thd_seq_tweak_above_diag(*convertNVTETensorCheck(cu_seqlens_q_halfs), - *convertNVTETensorCheck(cu_seqlens_kv), - *convertNVTETensorCheck(cu_seqlens_padded), - *convertNVTETensorCheck(out_cu_seqlens_q), - *convertNVTETensorCheck(out_cu_seqlens_kv), - *convertNVTETensorCheck(out_cu_seqlens_q_padded), - *convertNVTETensorCheck(out_cu_seqlens_kv_padded), - cp_rank_q, cp_rank_kv, cp_size, - batch, output_len, chunk_size, stream); + context_parallel::thd_seq_tweak_above_diag( + *convertNVTETensorCheck(cu_seqlens_q_halfs), *convertNVTETensorCheck(cu_seqlens_kv), + *convertNVTETensorCheck(cu_seqlens_padded), *convertNVTETensorCheck(out_cu_seqlens_q), + *convertNVTETensorCheck(out_cu_seqlens_kv), *convertNVTETensorCheck(out_cu_seqlens_q_padded), + *convertNVTETensorCheck(out_cu_seqlens_kv_padded), cp_rank_q, cp_rank_kv, cp_size, batch, + output_len, chunk_size, stream); } - diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 7b9c2d15a5..b6ee193b1f 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -757,8 +757,8 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso */ void nvte_cp_thd_chunkify(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded, - NVTETensor out_cu_seqlens, NVTETensor out_cu_seqlens_padded, - int batch, int output_len, int chunk_size, cudaStream_t stream); + NVTETensor out_cu_seqlens, NVTETensor out_cu_seqlens_padded, int batch, + int output_len, int chunk_size, cudaStream_t stream); /*! \brief Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal. * @@ -778,9 +778,9 @@ void nvte_cp_thd_chunkify(const NVTETensor &cu_seqlens, const NVTETensor &cu_seq void nvte_cp_thd_chunkify_p2p(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded, NVTETensor out_cu_seqlens, NVTETensor out_cu_seqlens_padded, - int batch, int output_len, int chunk_size, int cp_rank, int cp_size, + int batch, int output_len, int chunk_size, int cp_rank, int cp_size, cudaStream_t stream); - + /*! \brief Support THD format for Context Parallel: Split sequence into chunks for one P2P part below diagonal. * * \warning This API is **experimental** and subject to change. @@ -801,12 +801,11 @@ void nvte_cp_thd_chunkify_p2p(const NVTETensor &cu_seqlens, const NVTETensor &cu * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_cp_thd_seq_tweak_below_diag(const NVTETensor &cu_seqlens_q, const NVTETensor &cu_seqlens_kv_halfs, - const NVTETensor &cu_seqlens_padded, - NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv, - NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, - int cp_rank_q, int cp_rank_kv, int cp_size, - int batch, int output_len, int chunk_size, cudaStream_t stream); +void nvte_cp_thd_seq_tweak_below_diag( + const NVTETensor &cu_seqlens_q, const NVTETensor &cu_seqlens_kv_halfs, + const NVTETensor &cu_seqlens_padded, NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv, + NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, int cp_rank_q, + int cp_rank_kv, int cp_size, int batch, int output_len, int chunk_size, cudaStream_t stream); /*! \brief Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal. * @@ -821,14 +820,11 @@ void nvte_cp_thd_seq_tweak_below_diag(const NVTETensor &cu_seqlens_q, const NVTE * \param[in] stream CUDA stream used for this operation. */ -void nvte_cp_thd_seq_tweak_above_diag(const NVTETensor &cu_seqlens_q_halfs, const NVTETensor &cu_seqlens_kv, - const NVTETensor &cu_seqlens_padded, - NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv, - NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, - int cp_rank_q, int cp_rank_kv, int cp_size, - int batch, int output_len, - int chunk_size, cudaStream_t stream); - +void nvte_cp_thd_seq_tweak_above_diag( + const NVTETensor &cu_seqlens_q_halfs, const NVTETensor &cu_seqlens_kv, + const NVTETensor &cu_seqlens_padded, NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv, + NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, int cp_rank_q, + int cp_rank_kv, int cp_size, int batch, int output_len, int chunk_size, cudaStream_t stream); /*! \brief Convert tensor from THD to BSHD format. * diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py index f52136f5f9..43c2f95ee1 100644 --- a/transformer_engine/debug/features/utils/stats_buffer.py +++ b/transformer_engine/debug/features/utils/stats_buffer.py @@ -84,7 +84,7 @@ def feed(self, tensor, iteration): # It is used for weights and microbatching. if self.modified[0] and not self.reduce_within_microbatch: return - + # We do not feed the tensor with 0 elements, # we behave the same way as if feed() was not called. 
if tensor.numel() == 0: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 03f52808dc..016885e674 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -770,10 +770,26 @@ def forward( thd_total_seq_len = q.shape[0] q_inputs[i % 2] = q if chunk_size is not None: - cu_seqlens_q_per_step[i], cu_seqlens_q_padded_per_step[i] = dpa_utils.thd_chunkify_p2p( - cu_seqlens_q_per_step[i], cu_seqlens_q_padded_per_step[i], chunk_size, rank, cp_size, thd_total_seq_len) - cu_seqlens_kv_per_step[i], cu_seqlens_kv_padded_per_step[i] = dpa_utils.thd_chunkify_p2p( - cu_seqlens_kv_per_step[i], cu_seqlens_kv_padded_per_step[i], chunk_size, rank, cp_size, thd_total_seq_len) + cu_seqlens_q_per_step[i], cu_seqlens_q_padded_per_step[i] = ( + dpa_utils.thd_chunkify_p2p( + cu_seqlens_q_per_step[i], + cu_seqlens_q_padded_per_step[i], + chunk_size, + rank, + cp_size, + thd_total_seq_len, + ) + ) + cu_seqlens_kv_per_step[i], cu_seqlens_kv_padded_per_step[i] = ( + dpa_utils.thd_chunkify_p2p( + cu_seqlens_kv_per_step[i], + cu_seqlens_kv_padded_per_step[i], + chunk_size, + rank, + cp_size, + thd_total_seq_len, + ) + ) if use_fused_attention: if attn_bias is not None: @@ -836,7 +852,6 @@ def forward( **fp8_meta_kwargs, ) - if fp8: softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors else: @@ -900,8 +915,11 @@ def forward( cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half cu_seqlens_q_padded_per_step[i] = cu_seqlens_q_padded - cu_seqlens_kv_padded_per_step[i] = cu_seqlens_kv_padded // 2 \ - if cu_seqlens_kv_padded is not None else None + cu_seqlens_kv_padded_per_step[i] = ( + cu_seqlens_kv_padded // 2 + if cu_seqlens_kv_padded is not None + else None + ) if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) @@ -925,7 +943,7 @@ def forward( elif qkv_format == "thd": q_inputs[i % 2] = q # [2, t, np, hn] -> [2, t/2, np, hn] - + if enable_mla: assert chunk_size is None # [t, np, hn] -> [t/2, np, hn] @@ -938,11 +956,20 @@ def forward( else: # [2, t, np, hn] -> [2, t/2, np, hn] if chunk_size is not None: - cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], cu_seqlens_q_padded_per_step[i], cu_seqlens_kv_padded_per_step[i] =\ - dpa_utils.thd_seq_tweak_below_diagonal( - cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], cu_seqlens_q_padded,\ - rank, rank - i, cp_size, chunk_size - ) + ( + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded_per_step[i], + cu_seqlens_kv_padded_per_step[i], + ) = dpa_utils.thd_seq_tweak_below_diagonal( + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded, + rank, + rank - i, + cp_size, + chunk_size, + ) kv_inputs[i % 2] = tex.thd_read_half_tensor( kv_inputs[i % 2], cu_seqlens_kv_padded, 0 ) @@ -1070,9 +1097,12 @@ def forward( else: cu_seqlens_q_per_step[i] = cu_seqlens_q_half cu_seqlens_kv_per_step[i] = cu_seqlens_kv - - cu_seqlens_q_padded_per_step[i] = cu_seqlens_q_padded // 2 \ - if cu_seqlens_q_padded is not None else None + + cu_seqlens_q_padded_per_step[i] = ( + cu_seqlens_q_padded // 2 + if cu_seqlens_q_padded is not None + else None + ) cu_seqlens_kv_padded_per_step[i] = cu_seqlens_kv_padded if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] @@ -1104,12 +1134,21 @@ def forward( q, 
cu_seqlens_q_padded, 1 ) if chunk_size is not None: - cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], cu_seqlens_q_padded_per_step[i], cu_seqlens_kv_padded_per_step[i] =\ - dpa_utils.thd_seq_tweak_above_diagonal( - cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], cu_seqlens_q_padded,\ - rank, rank - i + cp_size, cp_size, chunk_size + ( + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded_per_step[i], + cu_seqlens_kv_padded_per_step[i], + ) = dpa_utils.thd_seq_tweak_above_diagonal( + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded, + rank, + rank - i + cp_size, + cp_size, + chunk_size, ) - + if use_fused_attention: q_inputs[i % 2] = q_inputs[i % 2].contiguous() if attn_bias is not None: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 7ee852a9a1..9eef6b8b57 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -265,7 +265,9 @@ def __init__( num_attention_heads % self.num_gqa_groups == 0 ), "The number of attention heads must be divisible by the number of GQA groups!" - assert chunk_size is None or qkv_format == "thd", "Chunk size is only supported for thd format" + assert ( + chunk_size is None or qkv_format == "thd" + ), "Chunk size is only supported for thd format" self.chunk_size = chunk_size self.rng_states_tracker = None @@ -755,9 +757,11 @@ def forward( if self.chunk_size is not None and self.cp_group is None: total_seq_len = query_layer.shape[0] cu_seqlens_q, cu_seqlens_q_padded = dpa_utils.thd_chunkify( - cu_seqlens_q, cu_seqlens_q_padded, self.chunk_size, total_seq_len) + cu_seqlens_q, cu_seqlens_q_padded, self.chunk_size, total_seq_len + ) cu_seqlens_kv, cu_seqlens_kv_padded = dpa_utils.thd_chunkify( - cu_seqlens_kv, cu_seqlens_kv_padded, self.chunk_size, total_seq_len) + cu_seqlens_kv, cu_seqlens_kv_padded, self.chunk_size, total_seq_len + ) batch_size = len(cu_seqlens_q) - 1 if max_seqlen_q is None: if cu_seqlens_q_padded is not None: @@ -771,7 +775,6 @@ def forward( else: seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) - # update KV cache and retrieve saved tokens from cache for inference if inference_params is not None: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 82b1b6e411..6b487a98de 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1806,6 +1806,7 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer + def thd_chunkify( cu_seqlens: torch.Tensor, cu_seqlens_padded: torch.Tensor, @@ -1817,10 +1818,12 @@ def thd_chunkify( Returns new cu_seqlens, cu_seqlens_padded tensors """ new_cu_seqlens, new_cu_seqlens_padded = tex.thd_chunkify( - cu_seqlens, cu_seqlens_padded, total_seq_len, chunk_size) + cu_seqlens, cu_seqlens_padded, total_seq_len, chunk_size + ) return new_cu_seqlens, new_cu_seqlens_padded + def thd_chunkify_p2p( cu_seqlens: torch.Tensor, cu_seqlens_padded: torch.Tensor, @@ -1833,12 +1836,14 @@ def thd_chunkify_p2p( Chunkify the cu_seqlens 
tensor. Returns new cu_seqlens, cu_seqlens_padded tensors """ - + new_cu_seqlens, new_cu_seqlens_padded = tex.thd_chunkify_p2p( - cu_seqlens, cu_seqlens_padded, total_seq_len, chunk_size, cp_rank, cp_size) + cu_seqlens, cu_seqlens_padded, total_seq_len, chunk_size, cp_rank, cp_size + ) return new_cu_seqlens, new_cu_seqlens_padded + def thd_seq_tweak_below_diagonal( cu_seqlens_q: torch.Tensor, cu_seqlens_kv_halfs: torch.Tensor, @@ -1850,14 +1855,28 @@ def thd_seq_tweak_below_diagonal( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: assert cp_rank_q > cp_rank_kv - new_seqlens_q, new_seqlens_kv_halfs, new_cu_seqlens_q_padded, new_cu_seqlens_kv_padded = tex.thd_seq_tweak_below_diag( - cu_seqlens_q, cu_seqlens_kv_halfs, cu_seqlens_padded, cp_rank_q, cp_rank_kv, cp_size, chunk_size + new_seqlens_q, new_seqlens_kv_halfs, new_cu_seqlens_q_padded, new_cu_seqlens_kv_padded = ( + tex.thd_seq_tweak_below_diag( + cu_seqlens_q, + cu_seqlens_kv_halfs, + cu_seqlens_padded, + cp_rank_q, + cp_rank_kv, + cp_size, + chunk_size, + ) ) - - new_cu_seqlens_q = torch.cumsum(new_seqlens_q, dim=0, dtype=torch.int32) + + new_cu_seqlens_q = torch.cumsum(new_seqlens_q, dim=0, dtype=torch.int32) new_cu_seqlens_kv_halfs = torch.cumsum(new_seqlens_kv_halfs, dim=0, dtype=torch.int32) - return new_cu_seqlens_q, new_cu_seqlens_kv_halfs, new_cu_seqlens_q_padded, new_cu_seqlens_kv_padded + return ( + new_cu_seqlens_q, + new_cu_seqlens_kv_halfs, + new_cu_seqlens_q_padded, + new_cu_seqlens_kv_padded, + ) + def thd_seq_tweak_above_diagonal( cu_seqlens_q_halfs: torch.Tensor, @@ -1870,11 +1889,23 @@ def thd_seq_tweak_above_diagonal( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: assert cp_rank_q < cp_rank_kv - new_seqlens_q, new_seqlens_kv, new_cu_seqlens_q_padded, new_cu_seqlens_kv_padded = tex.thd_seq_tweak_above_diag( - cu_seqlens_q_halfs, cu_seqlens_kv, cu_seqlens_padded, - cp_rank_q, cp_rank_kv, cp_size, chunk_size + new_seqlens_q, new_seqlens_kv, new_cu_seqlens_q_padded, new_cu_seqlens_kv_padded = ( + tex.thd_seq_tweak_above_diag( + cu_seqlens_q_halfs, + cu_seqlens_kv, + cu_seqlens_padded, + cp_rank_q, + cp_rank_kv, + cp_size, + chunk_size, + ) ) new_cu_seqlens_q_halfs = torch.cumsum(new_seqlens_q, dim=0, dtype=torch.int32) new_cu_seqlens_kv = torch.cumsum(new_seqlens_kv, dim=0, dtype=torch.int32) - return new_cu_seqlens_q_halfs, new_cu_seqlens_kv, new_cu_seqlens_q_padded, new_cu_seqlens_kv_padded \ No newline at end of file + return ( + new_cu_seqlens_q_halfs, + new_cu_seqlens_kv, + new_cu_seqlens_q_padded, + new_cu_seqlens_kv_padded, + ) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 3eca9dc51b..b9810bf861 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -304,7 +304,6 @@ def fused_attn_fwd( rng_elts_per_thread, ) - # out, aux_ctx_tensors return output_tensors[0], output_tensors[1:] diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 7be4dc0435..cadb8914fd 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -301,18 +301,21 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens, int world_size, int rank); -std::vector thd_chunkify(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded, - int 
total_seq_len, int chunk_size) ;
+std::vector<at::Tensor> thd_chunkify(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded,
+                                     int total_seq_len, int chunk_size);
 
-std::vector<at::Tensor> thd_chunkify_p2p(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded,
-                                     int total_seq_len, int chunk_size, int cp_rank, int cp_size) ;
+std::vector<at::Tensor> thd_chunkify_p2p(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded,
+                                         int total_seq_len, int chunk_size, int cp_rank,
+                                         int cp_size);
 
-std::vector<at::Tensor> thd_seq_tweak_below_diag(at::Tensor cu_seqlens_q, at::Tensor cu_seqlens_kv_halfs,
-                                     at::Tensor cu_seqlens_padded, int cp_rank_q,
+std::vector<at::Tensor> thd_seq_tweak_below_diag(at::Tensor cu_seqlens_q,
+                                                 at::Tensor cu_seqlens_kv_halfs,
+                                                 at::Tensor cu_seqlens_padded, int cp_rank_q,
                                                  int cp_rank_kv, int cp_size, int chunk_size);
 
-std::vector<at::Tensor> thd_seq_tweak_above_diag(at::Tensor cu_seqlens_q_halfs, at::Tensor cu_seqlens_kv,
-                                     at::Tensor cu_seqlens_padded, int cp_rank_q,
+std::vector<at::Tensor> thd_seq_tweak_above_diag(at::Tensor cu_seqlens_q_halfs,
+                                                 at::Tensor cu_seqlens_kv,
+                                                 at::Tensor cu_seqlens_padded, int cp_rank_q,
                                                  int cp_rank_kv, int cp_size, int chunk_size);
 
 /***************************************************************************************************
diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp
index a4c5f6bab7..e18c88fc37 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp
@@ -739,12 +739,11 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
   return output;
 }
 
-
 /***************************************************************************************************
  * Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal.
  **************************************************************************************************/
 
-std::vector<at::Tensor> thd_chunkify(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded,
+std::vector<at::Tensor> thd_chunkify(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded,
                                      int total_seq_len, int chunk_size) {
   NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
   NVTE_CHECK(cu_seqlens.dim() == 1);
@@ -767,11 +766,9 @@ std::vector<at::Tensor> thd_chunkify(at::Tensor cu_seqlens, at::Tensor cu_seqlen
   auto te_out_cu_seqlens = makeTransformerEngineTensor(out_cu_seqlens);
   auto te_out_cu_seqlens_padded = makeTransformerEngineTensor(out_cu_seqlens_padded);
 
-  nvte_cp_thd_chunkify(
-      te_cu_seqlens.data(), te_cu_seqlens_padded.data(),
-      te_out_cu_seqlens.data(), te_out_cu_seqlens_padded.data(),
-      batch, output_len, chunk_size, at::cuda::getCurrentCUDAStream()
-  );
+  nvte_cp_thd_chunkify(te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_out_cu_seqlens.data(),
+                       te_out_cu_seqlens_padded.data(), batch, output_len, chunk_size,
+                       at::cuda::getCurrentCUDAStream());
 
   return {out_cu_seqlens, out_cu_seqlens_padded};
 }
@@ -780,8 +777,9 @@ std::vector<at::Tensor> thd_chunkify(at::Tensor cu_seqlens, at::Tensor cu_seqlen
  * Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal.
  **************************************************************************************************/
 
-std::vector<at::Tensor> thd_chunkify_p2p(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded,
-                                         int total_seq_len, int chunk_size, int cp_rank, int cp_size) {
+std::vector<at::Tensor> thd_chunkify_p2p(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded,
+                                         int total_seq_len, int chunk_size, int cp_rank,
+                                         int cp_size) {
   NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
   NVTE_CHECK(cu_seqlens.dim() == 1);
   NVTE_CHECK(cu_seqlens.size(0) >= 2);
@@ -798,11 +796,10 @@ std::vector<at::Tensor> thd_chunkify_p2p(at::Tensor cu_seqlens, at::Tensor cu_se
   auto te_out_cu_seqlens = makeTransformerEngineTensor(out_cu_seqlens);
   auto te_out_cu_seqlens_padded = makeTransformerEngineTensor(out_cu_seqlens_padded);
 
-  nvte_cp_thd_chunkify_p2p(
-      te_cu_seqlens.data(), te_cu_seqlens_padded.data(),
-      te_out_cu_seqlens.data(), te_out_cu_seqlens_padded.data(),
-      batch, output_len, chunk_size, cp_rank, cp_size, at::cuda::getCurrentCUDAStream()
-  );
+  nvte_cp_thd_chunkify_p2p(te_cu_seqlens.data(), te_cu_seqlens_padded.data(),
+                           te_out_cu_seqlens.data(), te_out_cu_seqlens_padded.data(), batch,
+                           output_len, chunk_size, cp_rank, cp_size,
+                           at::cuda::getCurrentCUDAStream());
 
   return {out_cu_seqlens, out_cu_seqlens_padded};
 }
@@ -811,8 +808,9 @@ std::vector<at::Tensor> thd_chunkify_p2p(at::Tensor cu_seqlens, at::Tensor cu_se
  * Support THD format for Context Parallel: Split sequence into chunks for one P2P part below diagonal.
  **************************************************************************************************/
 
-std::vector<at::Tensor> thd_seq_tweak_below_diag(at::Tensor cu_seqlens_q, at::Tensor cu_seqlens_kv_halfs,
-                                     at::Tensor cu_seqlens_padded, int cp_rank_q,
+std::vector<at::Tensor> thd_seq_tweak_below_diag(at::Tensor cu_seqlens_q,
+                                                 at::Tensor cu_seqlens_kv_halfs,
+                                                 at::Tensor cu_seqlens_padded, int cp_rank_q,
                                                  int cp_rank_kv, int cp_size, int chunk_size) {
   NVTE_CHECK(cu_seqlens_q.scalar_type() == at::ScalarType::Int);
   NVTE_CHECK(cu_seqlens_kv_halfs.scalar_type() == at::ScalarType::Int);
@@ -845,23 +843,21 @@ std::vector<at::Tensor> thd_seq_tweak_below_diag(at::Tensor cu_seqlens_q, at::Te
   auto te_out_cu_seqlens_kv_padded = makeTransformerEngineTensor(out_cu_seqlens_kv_padded);
 
   nvte_cp_thd_seq_tweak_below_diag(
-      te_cu_seqlens_q.data(), te_cu_seqlens_kv_halfs.data(),
-      te_cu_seqlens_padded.data(), te_out_seqlens_q.data(), te_out_seqlens_kv.data(),
-      te_out_cu_seqlens_q_padded.data(), te_out_cu_seqlens_kv_padded.data(),
-      cp_rank_q, cp_rank_kv, cp_size, batch, output_len, chunk_size, at::cuda::getCurrentCUDAStream()
-  );
+      te_cu_seqlens_q.data(), te_cu_seqlens_kv_halfs.data(), te_cu_seqlens_padded.data(),
+      te_out_seqlens_q.data(), te_out_seqlens_kv.data(), te_out_cu_seqlens_q_padded.data(),
+      te_out_cu_seqlens_kv_padded.data(), cp_rank_q, cp_rank_kv, cp_size, batch, output_len,
+      chunk_size, at::cuda::getCurrentCUDAStream());
 
   return {out_seqlens_q, out_seqlens_kv, out_cu_seqlens_q_padded, out_cu_seqlens_kv_padded};
 }
 
-
-
 /***************************************************************************************************
  * Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal.
  **************************************************************************************************/
 
-std::vector<at::Tensor> thd_seq_tweak_above_diag(at::Tensor cu_seqlens_q_halfs, at::Tensor cu_seqlens_kv,
-                                     at::Tensor cu_seqlens_padded, int cp_rank_q,
+std::vector<at::Tensor> thd_seq_tweak_above_diag(at::Tensor cu_seqlens_q_halfs,
+                                                 at::Tensor cu_seqlens_kv,
+                                                 at::Tensor cu_seqlens_padded, int cp_rank_q,
                                                  int cp_rank_kv, int cp_size, int chunk_size) {
   NVTE_CHECK(cu_seqlens_q_halfs.scalar_type() == at::ScalarType::Int);
   NVTE_CHECK(cu_seqlens_kv.scalar_type() == at::ScalarType::Int);
@@ -884,16 +880,14 @@ std::vector<at::Tensor> thd_seq_tweak_above_diag(at::Tensor cu_seqlens_q_halfs,
   auto te_out_cu_seqlens_kv_padded = makeTransformerEngineTensor(out_cu_seqlens_kv_padded);
 
   nvte_cp_thd_seq_tweak_above_diag(
-      te_cu_seqlens_q_halfs.data(), te_cu_seqlens_kv.data(),
-      te_cu_seqlens_padded.data(), te_out_seqlens_q_halfs.data(), te_out_seqlens_kv.data(),
-      te_out_cu_seqlens_q_padded.data(), te_out_cu_seqlens_kv_padded.data(),
-      cp_rank_q, cp_rank_kv, cp_size, batch, output_len, chunk_size, at::cuda::getCurrentCUDAStream()
-  );
+      te_cu_seqlens_q_halfs.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_padded.data(),
+      te_out_seqlens_q_halfs.data(), te_out_seqlens_kv.data(), te_out_cu_seqlens_q_padded.data(),
+      te_out_cu_seqlens_kv_padded.data(), cp_rank_q, cp_rank_kv, cp_size, batch, output_len,
+      chunk_size, at::cuda::getCurrentCUDAStream());
 
   return {out_seqlens_q_halfs, out_seqlens_kv, out_cu_seqlens_q_padded, out_cu_seqlens_kv_padded};
 }
 
-
 /***************************************************************************************************
  * KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
  **************************************************************************************************/
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 03c90e3e7c..3430c710ad 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -281,13 +281,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Generate partitioned indices for inputs in THD format",
         py::call_guard<py::gil_scoped_release>());
   m.def("thd_chunkify", &transformer_engine::pytorch::thd_chunkify, "Chunkify THD tensor",
-        py::arg("cu_seqlens"), py::arg("cu_seqlens_padded"), py::arg("total_seq_len"), py::arg("chunk_size"),
-        py::call_guard<py::gil_scoped_release>());
-  m.def("thd_chunkify_p2p", &transformer_engine::pytorch::thd_chunkify_p2p, "Chunkify THD tensor for P2P communication",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("thd_seq_tweak_below_diag", &transformer_engine::pytorch::thd_seq_tweak_below_diag, "Tweak the sequence below the diagonal for THD tensor",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("thd_seq_tweak_above_diag", &transformer_engine::pytorch::thd_seq_tweak_above_diag, "Tweak the sequence above the diagonal for THD tensor",
+        py::arg("cu_seqlens"), py::arg("cu_seqlens_padded"), py::arg("total_seq_len"),
+        py::arg("chunk_size"), py::call_guard<py::gil_scoped_release>());
+  m.def("thd_chunkify_p2p", &transformer_engine::pytorch::thd_chunkify_p2p,
+        "Chunkify THD tensor for P2P communication", py::call_guard<py::gil_scoped_release>());
+  m.def("thd_seq_tweak_below_diag", &transformer_engine::pytorch::thd_seq_tweak_below_diag,
+        "Tweak the sequence below the diagonal for THD tensor",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("thd_seq_tweak_above_diag", &transformer_engine::pytorch::thd_seq_tweak_above_diag,
+        "Tweak the sequence above the diagonal for THD tensor",
         py::call_guard<py::gil_scoped_release>());
 
   // nvshmem functions
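
Note on the chunking helpers in this patch: the Python wrappers added in utils.py (thd_chunkify, thd_chunkify_p2p, thd_seq_tweak_below_diagonal, thd_seq_tweak_above_diagonal) only rewrite cu_seqlens bookkeeping; the actual work happens in the CUDA kernels exposed through tex. As a rough aid for reviewers, the sketch below is a pure-Python illustration of the kind of transformation thd_chunkify is expected to perform on cu_seqlens / cu_seqlens_padded. It is an assumption for illustration only: the function name thd_chunkify_reference and the exact padding behaviour are hypothetical, and the real kernel additionally pads its output to a fixed output_len.

    import torch

    def thd_chunkify_reference(cu_seqlens, cu_seqlens_padded, chunk_size):
        # Hypothetical sketch: split every (padded) sequence into chunk_size
        # pieces, so the cumulative-length tensors gain one boundary per chunk.
        new_cu, new_cu_padded = [0], [0]
        for i in range(cu_seqlens.numel() - 1):
            seq_len = int(cu_seqlens[i + 1] - cu_seqlens[i])
            pad_len = int(cu_seqlens_padded[i + 1] - cu_seqlens_padded[i])
            base, base_padded = new_cu[-1], new_cu_padded[-1]
            done = 0
            while done < pad_len:
                done = min(done + chunk_size, pad_len)
                new_cu_padded.append(base_padded + done)
                new_cu.append(base + min(done, seq_len))
        return (torch.tensor(new_cu, dtype=torch.int32),
                torch.tensor(new_cu_padded, dtype=torch.int32))

    # Example: two sequences of 6 and 4 tokens with chunk_size=4:
    # cu_seqlens [0, 6, 10] -> [0, 4, 6, 10] (the 6-token sequence splits into 4 + 2).
    cu = torch.tensor([0, 6, 10], dtype=torch.int32)
    print(thd_chunkify_reference(cu, cu.clone(), chunk_size=4))

Under this reading, the per-step variants (thd_chunkify_p2p and the below/above-diagonal tweaks) apply the same idea to the cu_seqlens slices each context-parallel rank sees for a given P2P step, which is why they also take cp_rank and cp_size.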