diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index ad3bc32079..1e46ecb8c8 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -95,6 +95,7 @@ def run_dpa_with_cp( qkv_format=qkv_format, attn_mask_type=config.attn_mask_type, window_size=config.window_size, + chunk_size=config.chunk_size, ) core_attn = core_attn.cuda() @@ -284,6 +285,7 @@ def run_dpa_with_cp( cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) + if fp8_mha: dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) @@ -401,8 +403,12 @@ def _error(a, b): _error(a[0], b[0]) _error(a[1], b[1]) elif qkv_format == "thd": + i = 0 for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): _error(a, b) + str_names = ["out_", "dq_", "dk_", "dv_"] + print(f"{str_names[i]} passed on rank {rank}") + i += 1 else: assert False, f"{qkv_format} is an unsupported qkv_format!" diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index a05e64fca3..87b77b9b11 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -90,6 +90,7 @@ def __init__( window_size: Tuple[int, int] = (-1, -1), total_requests: int = None, max_ctx_len: int = None, + chunk_size: int = None, ): self.batch_size = batch_size self.num_heads = num_heads @@ -110,6 +111,7 @@ def __init__( self.window_size = window_size self.total_requests = total_requests self.max_ctx_len = max_ctx_len + self.chunk_size = chunk_size @contextmanager diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 4ecc54b530..385e39bbfd 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -24,6 +24,9 @@ "cp_1_3": ModelConfig( 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) ), # MHA + "cp_1_4": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", chunk_size=1024 + ), # MHA with chunks "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA "cp_2_2": ModelConfig( @@ -100,6 +103,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): "cp_1_4": ModelConfig( 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) ), # MHA + "cp_1_5": ModelConfig( + 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", chunk_size=1024 + ), # MHA "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA @@ -144,6 +150,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha config = model_configs_fused_attn[model] if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": pytest.skip("THD format does not support post_scale_bias yet!") + if qkv_format != "thd" and config.chunk_size is not None: + pytest.skip("Only THD format supports chunking!") if qkv_format == "thd" and cp_comm_type == "all_gather": pytest.skip("CP implementation with KV all-gather does not support THD format yet!") if qkv_format == "thd" and "a2a" in cp_comm_type: diff --git 
a/transformer_engine/common/fused_attn/context_parallel.cu b/transformer_engine/common/fused_attn/context_parallel.cu index 15708d2d59..e8e2fddbbb 100644 --- a/transformer_engine/common/fused_attn/context_parallel.cu +++ b/transformer_engine/common/fused_attn/context_parallel.cu @@ -302,6 +302,352 @@ __global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, in } } +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks. + **************************************************************************************************/ + +__global__ void thd_chunkify_kernel(const int *__restrict__ d_cu_seqlens, + const int *__restrict__ d_cu_seqlens_padded, + int *__restrict__ d_out_cu_seqlens, + int *__restrict__ d_out_cu_seqlens_padded, + int batch, // = len(cu_seqlens)-1 + int output_len, int chunk_size) { + const int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= output_len) return; + if (i == 0) { + d_out_cu_seqlens[0] = 0; + d_out_cu_seqlens_padded[0] = 0; + return; + } + + int pos_id = 0; + int seq_len = 0; + int pad_len = 0; + int seq_start = d_cu_seqlens[batch]; + int pad_start = d_cu_seqlens_padded[batch]; + int cur_start = 0; + int cur_last = -1; + bool found = false; + + for (int j = 0; j < batch; ++j) { + cur_start = (j == 0) ? 1 : cur_last + 1; + + int d_cu_seqlens_padded_j = d_cu_seqlens_padded[j]; + int d_cu_seqlens_padded_j_1 = d_cu_seqlens_padded[j + 1]; + int d_cu_seqlens_j = d_cu_seqlens[j]; + int d_cu_seqlens_j_1 = d_cu_seqlens[j + 1]; + + int num_chunks = + (d_cu_seqlens_padded_j_1 - d_cu_seqlens_padded_j + (chunk_size - 1)) / chunk_size; + cur_last = cur_start + num_chunks - 1; + + bool match = (i >= cur_start) && (i <= cur_last); + pos_id = match ? (i - cur_start) : pos_id; + seq_len = match ? (d_cu_seqlens_j_1 - d_cu_seqlens_j) : seq_len; + pad_len = match ? (d_cu_seqlens_padded_j_1 - d_cu_seqlens_padded_j) : pad_len; + seq_start = match ? d_cu_seqlens_j : seq_start; + pad_start = match ? d_cu_seqlens_padded_j : pad_start; + found = match || found; + } + + if (!found) { + d_out_cu_seqlens[i] = d_cu_seqlens[batch]; + d_out_cu_seqlens_padded[i] = d_cu_seqlens_padded[batch]; + } else { + int32_t out_seq = + ((pos_id > (seq_len / chunk_size)) ? seq_len : chunk_size * pos_id) + seq_start; + int32_t out_pad = + ((pos_id > (pad_len / chunk_size)) ? pad_len : chunk_size * pos_id) + pad_start; + + d_out_cu_seqlens[i] = out_seq; + d_out_cu_seqlens_padded[i] = out_pad; + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal. 
+ **************************************************************************************************/ + +__global__ void thd_chunkify_p2p_kernel(const int *__restrict__ d_cu_seqlens, + const int *__restrict__ d_cu_seqlens_padded, + int *__restrict__ d_out_cu_seqlens, + int *__restrict__ d_out_cu_seqlens_padded, int batch, + int output_len, int chunk_size, int cp_rank, int cp_size) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= output_len) return; + if (i == 0) { + d_out_cu_seqlens[0] = 0; + d_out_cu_seqlens_padded[0] = 0; + return; + } + + int pos_id = 0; + int seq_len = 0; + int total_seq_len = 0; + int seq_start_offset = d_cu_seqlens[batch]; + int pad_start_offset = d_cu_seqlens_padded[batch]; + int cur_start = 1; + int cur_last = 0; + + for (int j = 0; j < batch; ++j) { + cur_start = (j == 0) ? 1 : cur_last + 1; + int padded_len = d_cu_seqlens_padded[j + 1] - d_cu_seqlens_padded[j]; + int num_chunks = (padded_len + chunk_size - 1) / chunk_size; + cur_last = cur_start + num_chunks + 3; + if (i >= cur_start && (i <= cur_last || j == batch - 1)) { + pos_id = i - cur_start; + seq_len = d_cu_seqlens[j + 1] - d_cu_seqlens[j]; + total_seq_len = padded_len; + seq_start_offset = d_cu_seqlens[j]; + pad_start_offset = d_cu_seqlens_padded[j]; + break; + } + } + + if (total_seq_len == 0) { + d_out_cu_seqlens[i] = d_cu_seqlens[batch]; + d_out_cu_seqlens_padded[i] = d_cu_seqlens_padded[batch]; + return; + } + + int middle = total_seq_len / 2; + + int start_id_1 = (total_seq_len * cp_rank) / 2; + int temp = (chunk_size - start_id_1 - 1) % chunk_size; + int first_chunk_size_1 = ((temp < 0) ? temp + chunk_size : temp) + 1; + first_chunk_size_1 = (first_chunk_size_1 >= middle) ? 0 : first_chunk_size_1; + int num_chunks_1 = (middle - first_chunk_size_1 - 1) / chunk_size; + int last_chunk_size_1 = middle - num_chunks_1 * chunk_size - first_chunk_size_1; + int last_token_id_1 = start_id_1 + middle - 1; + int last_chunk_id_1 = last_token_id_1 / chunk_size; + + int start_id_2 = total_seq_len * cp_size - last_token_id_1 - 1; + int first_chunk_id_2 = start_id_2 / chunk_size; + int temp2 = (chunk_size - start_id_2 - 1) % chunk_size; + int first_chunk_size_2 = ((temp2 < 0) ? temp2 + chunk_size : temp2) + 1; + int num_chunks_2 = (middle - first_chunk_size_2) / chunk_size; + + bool merge_chunks = (last_chunk_id_1 == first_chunk_id_2); + int middle_chunk_1 = merge_chunks ? (last_chunk_size_1 + first_chunk_size_2) : last_chunk_size_1; + int middle_chunk_2 = merge_chunks ? 0 : first_chunk_size_2; + + int out_pad_len = 0; + if (pos_id == 0) { + out_pad_len = first_chunk_size_1; + } else if (pos_id <= num_chunks_1 + 1) { + out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size; + } else if (pos_id == num_chunks_1 + 2) { + out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size + middle_chunk_1; + } else if (pos_id == num_chunks_1 + 3) { + out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size + middle_chunk_1 + middle_chunk_2; + } else if (pos_id <= num_chunks_1 + 3 + num_chunks_2) { + out_pad_len = first_chunk_size_1 + num_chunks_1 * chunk_size + middle_chunk_1 + middle_chunk_2 + + num_chunks_2 * chunk_size; + } else { + out_pad_len = total_seq_len; + } + + int out_seq_len = (out_pad_len < seq_len) ? 
out_pad_len : seq_len; + + d_out_cu_seqlens[i] = seq_start_offset + out_seq_len; + d_out_cu_seqlens_padded[i] = pad_start_offset + out_pad_len; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part below diagonal. + **************************************************************************************************/ + +__global__ void thd_seq_tweak_below_diag_kernel( + const int *__restrict__ cu_seqlens_q, const int *__restrict__ cu_seqlens_kv_halfs, + const int *__restrict__ cu_seqlens_padded, int *__restrict__ q_chunks, + int *__restrict__ kv_chunks, int *__restrict__ q_pads, int *__restrict__ kv_pads, int cp_rank_q, + int cp_rank_kv, int cp_size, int chunk_size, int batch) { + const int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i == 0) { + q_chunks[0] = 0; + kv_chunks[0] = 0; + q_pads[0] = 0; + kv_pads[0] = 0; + } + __syncthreads(); + + if (i >= batch) return; + + // ───────── prefix-sum diffs ──────────────────────────────── + const int32_t q_start = cu_seqlens_q[i]; + const int32_t q_end = cu_seqlens_q[i + 1]; + const int32_t kv_start = cu_seqlens_kv_halfs[i]; + const int32_t kv_end = cu_seqlens_kv_halfs[i + 1]; + const int32_t pad_start = cu_seqlens_padded[i]; + const int32_t pad_end = cu_seqlens_padded[i + 1]; + + const int32_t seq_len_q = q_end - q_start; + const int32_t seq_len_kv_half = kv_end - kv_start; + const int32_t seq_plus_pad = pad_end - pad_start; + const int32_t half_seq_len = seq_plus_pad >> 1; + const int32_t pad_len_q = seq_plus_pad - seq_len_q; + const int32_t pad_len_kv = half_seq_len - seq_len_kv_half; + + // ───────── below-diagonal logic ─────────────────────────── + const int32_t last_kv_id = (cp_rank_kv + 1) * half_seq_len - 1; + const int32_t last_kv_chunk_id = last_kv_id / chunk_size; + const int32_t last_kv_chunk_len = + min(half_seq_len - (last_kv_chunk_id * chunk_size - cp_rank_kv * half_seq_len), half_seq_len); + + const int32_t first_half_id = (cp_rank_q * half_seq_len) / chunk_size; + const int32_t second_half_id = ((2 * cp_size - cp_rank_q - 1) * half_seq_len) / chunk_size; + + const int32_t first_half_len = + min(half_seq_len, (first_half_id + 1) * chunk_size - cp_rank_q * half_seq_len); + + const int32_t second_half_len = min(seq_len_q, half_seq_len + (second_half_id + 1) * chunk_size - + (2 * cp_size - 1 - cp_rank_q) * half_seq_len); + + const bool take_nothing = (last_kv_chunk_id != first_half_id); + const bool take_first_half_q = (!take_nothing) && (last_kv_chunk_id != second_half_id); + const bool take_second_half_q = (!take_nothing) && (!take_first_half_q); + + int32_t q_seq_len = 0; + if (take_first_half_q) + q_seq_len = first_half_len; + else if (take_second_half_q) + q_seq_len = second_half_len; + + int32_t kv_seq_len = half_seq_len; + if (!take_nothing) kv_seq_len = last_kv_chunk_len; + + q_seq_len = min(q_seq_len, max(0, seq_plus_pad - pad_len_q)); + kv_seq_len = max(0, kv_seq_len - pad_len_kv); + + // ───────── flat output (row-major) ───────────────────────── + const int out_base = 3 * i; + + q_chunks[out_base + 0 + 1] = 0; + q_chunks[out_base + 1 + 1] = q_seq_len; + q_chunks[out_base + 2 + 1] = 0; + + q_pads[out_base + 0 + 1] = pad_start; + q_pads[out_base + 1 + 1] = pad_start + q_seq_len; + q_pads[out_base + 2 + 1] = pad_start + seq_plus_pad; + + kv_chunks[out_base + 0 + 1] = 0; + kv_chunks[out_base + 1 + 1] = kv_seq_len; + kv_chunks[out_base + 2 + 1] = 0; + + const int32_t kv_pad_base = 
pad_start >> 1; + kv_pads[out_base + 0 + 1] = kv_pad_base + (half_seq_len - kv_seq_len - pad_len_kv); + kv_pads[out_base + 1 + 1] = kv_pad_base + (half_seq_len - pad_len_kv); + kv_pads[out_base + 2 + 1] = kv_pad_base + half_seq_len; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal. + **************************************************************************************************/ + +__global__ void thd_seq_tweak_above_diag_kernel( + const int *__restrict__ cu_seqlens_q_halfs, // len = batch+1 + const int *__restrict__ cu_seqlens_kv, // len = batch+1 + const int *__restrict__ cu_seqlens_padded, // len = batch+1 (full‑len prefix sums) + int *__restrict__ q_chunks, // len = 3·batch (row‑major) + int *__restrict__ kv_chunks, // len = 3·batch + int *__restrict__ q_pads, // len = 3·batch + int *__restrict__ kv_pads, // len = 3·batch + int cp_rank_q, int cp_rank_kv, int cp_size, int chunk_size, int batch) { + const int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i == 0) { + q_chunks[0] = 0; + kv_chunks[0] = 0; + q_pads[0] = 0; + kv_pads[0] = 0; + } + __syncthreads(); + if (i >= batch) return; + + // ───────── prefix‑sum diffs ─────────────────────────────── + const int32_t q_start = cu_seqlens_q_halfs[i]; + const int32_t q_end = cu_seqlens_q_halfs[i + 1]; + const int32_t kv_start = cu_seqlens_kv[i]; + const int32_t kv_end = cu_seqlens_kv[i + 1]; + const int32_t pad_start = cu_seqlens_padded[i]; + const int32_t pad_end = cu_seqlens_padded[i + 1]; + + const int32_t seq_len_q_half = q_end - q_start; // |Q|/2 (actual) + const int32_t seq_len_kv = kv_end - kv_start; // |KV| (full) + const int32_t seq_plus_pad = pad_end - pad_start; // |KV|+pads (full) + + const int32_t half_seq_len = seq_plus_pad >> 1; // |KV|/2 + pads/2 + + // pad lengths for later clamping + const int32_t pad_len_kv = seq_plus_pad - seq_len_kv; + + // ───────── above‑diagonal core logic ───────────────────── + // 1. Tokens from Q (first half from the *opposite* side) + const int32_t first_q_id = (2 * cp_size - 1 - cp_rank_q) * half_seq_len; + const int32_t first_q_chunk_id = first_q_id / chunk_size; + const int32_t first_q_chunk_len = + min(seq_len_q_half, (first_q_chunk_id + 1) * chunk_size - first_q_id); + + // 2. Tokens from KV (might come from first or second half) + const int32_t first_half_kv_last_el_total_id = ((cp_rank_kv + 1) * half_seq_len) - 1; + const int32_t first_half_kv_last_chunk_id = first_half_kv_last_el_total_id / chunk_size; + + const int32_t second_half_kv_last_el_total_id = ((2 * cp_size - cp_rank_kv) * half_seq_len) - 1; + const int32_t second_half_kv_last_chunk_id = second_half_kv_last_el_total_id / chunk_size; + + // last‑trimmed chunk lengths + const int32_t first_half_kv_last_el_id_in_chunk = first_half_kv_last_el_total_id % chunk_size; + const int32_t first_half_kv_last_chunk_len = + min(seq_plus_pad, half_seq_len + first_half_kv_last_el_id_in_chunk + 1); + + const int32_t second_half_kv_last_el_id_in_chunk = second_half_kv_last_el_total_id % chunk_size; + const int32_t second_half_kv_last_chunk_len = + min(half_seq_len, second_half_kv_last_el_id_in_chunk + 1); + + // 3. 
Decide which half(s) we actually take + const bool take_nothing = (first_q_chunk_id != second_half_kv_last_chunk_id); + const bool take_second_half_kv = + (!take_nothing) && (first_q_chunk_id != first_half_kv_last_chunk_id); + const bool take_first_half_kv = (!take_nothing) && (!take_second_half_kv); + + // ───────── resulting subseq lengths ─────────────────────── + int32_t q_seq_len = take_nothing ? 0 : first_q_chunk_len; + + int32_t kv_seq_len = 0; + if (take_second_half_kv) + kv_seq_len = second_half_kv_last_chunk_len; + else if (take_first_half_kv) + kv_seq_len = first_half_kv_last_chunk_len; + + // clamp against padding + kv_seq_len = max(0, kv_seq_len - pad_len_kv); + + // ───────── flat output (row‑major) ──────────────────────── + const int out_base = 3 * i; + + // Q chunks: [0, sequence, 0] + q_chunks[out_base + 0 + 1] = 0; // beginning of Q half‑sequence + q_chunks[out_base + 1 + 1] = q_seq_len; // after the chunk we keep + q_chunks[out_base + 2 + 1] = 0; // stays flat afterwards + + // Q pads: [0, 0, garbage] (pads live in the *first* half of padded area) + const int32_t half_pad_start = pad_start >> 1; // start of Q‑related pad area + q_pads[out_base + 0 + 1] = half_pad_start; + q_pads[out_base + 1 + 1] = half_pad_start + q_seq_len; + q_pads[out_base + 2 + 1] = half_pad_start + half_seq_len; // complete half padded length + + // KV chunks: [0, sequence, 0] + kv_chunks[out_base + 0 + 1] = 0; + kv_chunks[out_base + 1 + 1] = kv_seq_len; + kv_chunks[out_base + 2 + 1] = 0; + + // KV pads: [garbage, 0, 0] (pads precede KV if from first half) + kv_pads[out_base + 0 + 1] = pad_start + (seq_len_kv - kv_seq_len); + kv_pads[out_base + 1 + 1] = pad_start + seq_len_kv; + kv_pads[out_base + 2 + 1] = pad_start + seq_plus_pad; +} + /*************************************************************************************************** * Support THD format for Context Parallel: Read the half of a THD tensor **************************************************************************************************/ @@ -669,6 +1015,143 @@ void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int to batch, total_tokens, world_size, rank); } +/*************************************************************************************************** + * Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal. 
+ **************************************************************************************************/
+
+void thd_chunkify(const Tensor &cu_seqlens, const Tensor &cu_seqlens_padded, Tensor &out_cu_seqlens,
+                  Tensor &out_cu_seqlens_padded, int batch, int output_len, int chunk_size,
+                  cudaStream_t stream) {
+  using namespace transformer_engine;
+  NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
+  NVTE_CHECK(cu_seqlens_padded.dtype() == DType::kInt32);
+  NVTE_CHECK(out_cu_seqlens.dtype() == DType::kInt32);
+  NVTE_CHECK(out_cu_seqlens_padded.dtype() == DType::kInt32);
+
+  NVTE_CHECK(cu_seqlens.dim() == 1);
+  NVTE_CHECK(cu_seqlens_padded.dim() == 1);
+  NVTE_CHECK(out_cu_seqlens.dim() == 1);
+  NVTE_CHECK(out_cu_seqlens_padded.dim() == 1);
+
+  const unsigned int block = 256;
+  const unsigned int grid = (output_len + block - 1) / block;
+  thd_chunkify_kernel<<<grid, block, 0, stream>>>(
+      reinterpret_cast<int *>(cu_seqlens.data.dptr),
+      reinterpret_cast<int *>(cu_seqlens_padded.data.dptr),
+      reinterpret_cast<int *>(out_cu_seqlens.data.dptr),
+      reinterpret_cast<int *>(out_cu_seqlens_padded.data.dptr), batch, output_len, chunk_size);
+}
+
+/***************************************************************************************************
+ * Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal.
+ **************************************************************************************************/
+
+void thd_chunkify_p2p(const Tensor &cu_seqlens, const Tensor &cu_seqlens_padded,
+                      Tensor &out_cu_seqlens, Tensor &out_cu_seqlens_padded, int batch,
+                      int output_len, int chunk_size, int cp_rank, int cp_size,
+                      cudaStream_t stream) {
+  using namespace transformer_engine;
+  NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
+  NVTE_CHECK(cu_seqlens_padded.dtype() == DType::kInt32);
+  NVTE_CHECK(out_cu_seqlens.dtype() == DType::kInt32);
+  NVTE_CHECK(out_cu_seqlens_padded.dtype() == DType::kInt32);
+
+  // These tensors should be one-dimensional
+  NVTE_CHECK(cu_seqlens.dim() == 1);
+  NVTE_CHECK(cu_seqlens_padded.dim() == 1);
+  NVTE_CHECK(out_cu_seqlens.dim() == 1);
+  NVTE_CHECK(out_cu_seqlens_padded.dim() == 1);
+
+  const unsigned int block = 256;
+  const unsigned int grid = (output_len + block - 1) / block;
+  thd_chunkify_p2p_kernel<<<grid, block, 0, stream>>>(
+      reinterpret_cast<int *>(cu_seqlens.data.dptr),
+      reinterpret_cast<int *>(cu_seqlens_padded.data.dptr),
+      reinterpret_cast<int *>(out_cu_seqlens.data.dptr),
+      reinterpret_cast<int *>(out_cu_seqlens_padded.data.dptr), batch, output_len, chunk_size,
+      cp_rank, cp_size);
+}
+
+/***************************************************************************************************
+ * Support THD format for Context Parallel: Split sequence into chunks for one P2P part below diagonal.
+ **************************************************************************************************/
+
+void thd_seq_tweak_below_diag(const Tensor &cu_seqlens_q, const Tensor &cu_seqlens_kv_halfs,
+                              const Tensor &cu_seqlens_padded, Tensor &out_cu_seqlens_q,
+                              Tensor &out_cu_seqlens_kv_halfs, Tensor &out_cu_seqlens_q_padded,
+                              Tensor &out_cu_seqlens_kv_halfs_padded, int batch, int output_len,
+                              int chunk_size, int cp_rank_q, int cp_rank_kv, int cp_size,
+                              cudaStream_t stream) {
+  using namespace transformer_engine;
+
+  NVTE_CHECK(cu_seqlens_q.dtype() == DType::kInt32);
+  NVTE_CHECK(cu_seqlens_kv_halfs.dtype() == DType::kInt32);
+  NVTE_CHECK(out_cu_seqlens_q.dtype() == DType::kInt32);
+  NVTE_CHECK(out_cu_seqlens_kv_halfs.dtype() == DType::kInt32);
+
+  // These tensors should be one-dimensional
+  NVTE_CHECK(cu_seqlens_q.dim() == 1);
+  NVTE_CHECK(cu_seqlens_kv_halfs.dim() == 1);
+  NVTE_CHECK(out_cu_seqlens_q.dim() == 1);
+  NVTE_CHECK(out_cu_seqlens_kv_halfs.dim() == 1);
+
+  const unsigned int block = 256;
+  const unsigned int grid = (output_len + block - 1) / block;
+  thd_seq_tweak_below_diag_kernel<<<grid, block, 0, stream>>>(
+      reinterpret_cast<int *>(cu_seqlens_q.data.dptr),
+      reinterpret_cast<int *>(cu_seqlens_kv_halfs.data.dptr),
+      reinterpret_cast<int *>(cu_seqlens_padded.data.dptr),
+      reinterpret_cast<int *>(out_cu_seqlens_q.data.dptr),
+      reinterpret_cast<int *>(out_cu_seqlens_kv_halfs.data.dptr),
+      reinterpret_cast<int *>(out_cu_seqlens_q_padded.data.dptr),
+      reinterpret_cast<int *>(out_cu_seqlens_kv_halfs_padded.data.dptr), cp_rank_q, cp_rank_kv,
+      cp_size, chunk_size, batch);
+}
+
+/***************************************************************************************************
+ * Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal.
+ **************************************************************************************************/
+
+void thd_seq_tweak_above_diag(const Tensor &cu_seqlens_q_halfs, const Tensor &cu_seqlens_kv,
+                              const Tensor &cu_seqlens_padded, Tensor &out_cu_seqlens_q_halfs,
+                              Tensor &out_cu_seqlens_kv, Tensor &out_cu_seqlens_q_halfs_padded,
+                              Tensor &out_cu_seqlens_kv_padded, int cp_rank_q, int cp_rank_kv,
+                              int cp_size, int batch, int output_len, int chunk_size,
+                              cudaStream_t stream) {
+  using namespace transformer_engine;
+
+  NVTE_CHECK(cu_seqlens_q_halfs.dtype() == DType::kInt32);
+  NVTE_CHECK(cu_seqlens_kv.dtype() == DType::kInt32);
+  NVTE_CHECK(cu_seqlens_padded.dtype() == DType::kInt32);
+  NVTE_CHECK(out_cu_seqlens_q_halfs.dtype() == DType::kInt32);
+  NVTE_CHECK(out_cu_seqlens_kv.dtype() == DType::kInt32);
+  NVTE_CHECK(out_cu_seqlens_q_halfs_padded.dtype() == DType::kInt32);
+  NVTE_CHECK(out_cu_seqlens_kv_padded.dtype() == DType::kInt32);
+
+  // These tensors should be one-dimensional
+  NVTE_CHECK(cu_seqlens_q_halfs.dim() == 1);
+  NVTE_CHECK(cu_seqlens_kv.dim() == 1);
+  NVTE_CHECK(cu_seqlens_padded.dim() == 1);
+  NVTE_CHECK(out_cu_seqlens_q_halfs.dim() == 1);
+  NVTE_CHECK(out_cu_seqlens_kv.dim() == 1);
+  NVTE_CHECK(out_cu_seqlens_q_halfs_padded.dim() == 1);
+  NVTE_CHECK(out_cu_seqlens_kv_padded.dim() == 1);
+
+  const unsigned int block = 256;
+  const unsigned int grid = (batch + 1 + block - 1) / block;
+  thd_seq_tweak_above_diag_kernel<<<grid, block, 0, stream>>>(
+      reinterpret_cast<int *>(cu_seqlens_q_halfs.data.dptr),
+      reinterpret_cast<int *>(cu_seqlens_kv.data.dptr),
+      reinterpret_cast<int *>(cu_seqlens_padded.data.dptr),
+      reinterpret_cast<int *>(out_cu_seqlens_q_halfs.data.dptr),
+      reinterpret_cast<int *>(out_cu_seqlens_kv.data.dptr),
+      reinterpret_cast<int *>(out_cu_seqlens_q_halfs_padded.data.dptr),
+      reinterpret_cast<int *>(out_cu_seqlens_kv_padded.data.dptr), cp_rank_q, cp_rank_kv, cp_size,
+      chunk_size, batch);
+  // synchronize
+  cudaStreamSynchronize(stream);
+}
+
 }  // namespace context_parallel
 }  // namespace transformer_engine
@@ -741,3 +1224,59 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso
                                           *convertNVTETensorCheck(output), total_tokens, world_size, rank, stream);
 }
+
+void nvte_cp_thd_chunkify(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded,
+                          NVTETensor out_cu_seqlens, NVTETensor out_cu_seqlens_padded, int batch,
+                          int output_len, int chunk_size, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_cp_thd_chunkify);
+  using namespace transformer_engine;
+
+  context_parallel::thd_chunkify(
+      *convertNVTETensorCheck(cu_seqlens), *convertNVTETensorCheck(cu_seqlens_padded),
+      *convertNVTETensorCheck(out_cu_seqlens), *convertNVTETensorCheck(out_cu_seqlens_padded),
+      batch, output_len, chunk_size, stream);
+}
+
+void nvte_cp_thd_chunkify_p2p(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded,
+                              NVTETensor out_cu_seqlens, NVTETensor out_cu_seqlens_padded,
+                              int batch, int output_len, int chunk_size, int cp_rank, int cp_size,
+                              cudaStream_t stream) {
+  NVTE_API_CALL(nvte_cp_thd_chunkify_p2p);
+  using namespace transformer_engine;
+
+  context_parallel::thd_chunkify_p2p(
+      *convertNVTETensorCheck(cu_seqlens), *convertNVTETensorCheck(cu_seqlens_padded),
+      *convertNVTETensorCheck(out_cu_seqlens), *convertNVTETensorCheck(out_cu_seqlens_padded),
+      batch, output_len, chunk_size, cp_rank, cp_size, stream);
+}
+
+void nvte_cp_thd_seq_tweak_below_diag(
+    const NVTETensor &cu_seqlens_q, const NVTETensor &cu_seqlens_kv_halfs,
+    const NVTETensor &cu_seqlens_padded, NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv,
+    NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, int cp_rank_q,
+    int cp_rank_kv, int cp_size, int batch, int output_len, int chunk_size, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_cp_thd_seq_tweak_below_diag);
+  using namespace transformer_engine;
+  context_parallel::thd_seq_tweak_below_diag(
+      *convertNVTETensorCheck(cu_seqlens_q), *convertNVTETensorCheck(cu_seqlens_kv_halfs),
+      *convertNVTETensorCheck(cu_seqlens_padded), *convertNVTETensorCheck(out_cu_seqlens_q),
+      *convertNVTETensorCheck(out_cu_seqlens_kv), *convertNVTETensorCheck(out_cu_seqlens_q_padded),
+      *convertNVTETensorCheck(out_cu_seqlens_kv_padded), batch, output_len, chunk_size, cp_rank_q,
+      cp_rank_kv, cp_size, stream);
+}
+
+void nvte_cp_thd_seq_tweak_above_diag(
+    const NVTETensor &cu_seqlens_q_halfs, const NVTETensor &cu_seqlens_kv,
+    const NVTETensor &cu_seqlens_padded, NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv,
+    NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, int cp_rank_q,
+    int cp_rank_kv, int cp_size, int batch, int output_len, int chunk_size, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_cp_thd_seq_tweak_above_diag);
+  using namespace transformer_engine;
+
+  context_parallel::thd_seq_tweak_above_diag(
+      *convertNVTETensorCheck(cu_seqlens_q_halfs), *convertNVTETensorCheck(cu_seqlens_kv),
+      *convertNVTETensorCheck(cu_seqlens_padded), *convertNVTETensorCheck(out_cu_seqlens_q),
+      *convertNVTETensorCheck(out_cu_seqlens_kv), *convertNVTETensorCheck(out_cu_seqlens_q_padded),
+      *convertNVTETensorCheck(out_cu_seqlens_kv_padded), cp_rank_q, cp_rank_kv, cp_size, batch,
+      output_len, chunk_size, stream);
+}
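Note: the chunkify path above only rewrites cumulative sequence lengths so that every chunk becomes its own pseudo-sequence. Below is a host-side sketch (a hypothetical helper, for illustration only, not part of this patch) of the transformation `nvte_cp_thd_chunkify` performs; the CUDA kernel writes into a fixed-length buffer, so it additionally emits a duplicate of the leading zero and trailing entries clamped to the totals, which are omitted here.

```python
import numpy as np

def chunkify_reference(cu_seqlens, cu_seqlens_padded, chunk_size):
    """Split every sequence of a THD batch into chunk_size-sized pseudo-sequences."""
    new_cu, new_cu_pad = [0], [0]
    for j in range(len(cu_seqlens) - 1):
        seq_len = cu_seqlens[j + 1] - cu_seqlens[j]
        pad_len = cu_seqlens_padded[j + 1] - cu_seqlens_padded[j]
        # Chunks are counted on the padded length, matching the kernel.
        num_chunks = (pad_len + chunk_size - 1) // chunk_size
        for c in range(1, num_chunks + 1):
            new_cu.append(cu_seqlens[j] + min(c * chunk_size, seq_len))
            new_cu_pad.append(cu_seqlens_padded[j] + min(c * chunk_size, pad_len))
    return np.array(new_cu, dtype=np.int32), np.array(new_cu_pad, dtype=np.int32)

# Two sequences of lengths 5 and 7, each padded to 8 tokens, chunk_size = 4:
cu, cu_pad = chunkify_reference([0, 5, 12], [0, 8, 16], chunk_size=4)
print(cu)      # [ 0  4  5  9 12]
print(cu_pad)  # [ 0  4  8 12 16]
```

diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h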
b/transformer_engine/common/include/transformer_engine/fused_attn.h index 44f5791490..b6ee193b1f 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -742,6 +742,90 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso int total_tokens, int world_size, int rank, cudaStream_t stream); +/*! \brief Split sequence into chunks for one P2P part on diagonal. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[in] cu_seqlens_padded Cumulative sequence lengths, [batch_size + 1]. + * \param[out] out_cu_seqlens Output tensor. + * \param[out] out_cu_seqlens_padded Output tensor. + * \param[in] batch Batch size. + * \param[in] output_len Output length. + * \param[in] chunk_size Chunk size. + * \param[in] stream CUDA stream used for this operation. + */ + +void nvte_cp_thd_chunkify(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded, + NVTETensor out_cu_seqlens, NVTETensor out_cu_seqlens_padded, int batch, + int output_len, int chunk_size, cudaStream_t stream); + +/*! \brief Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. + * \param[in] cu_seqlens_padded Cumulative sequence lengths, [batch_size + 1]. + * \param[out] out_cu_seqlens Output tensor. + * \param[out] out_cu_seqlens_padded Output tensor. + * \param[in] batch Batch size. + * \param[in] output_len Output length. + * \param[in] chunk_size Chunk size. + * \param[in] cp_rank Context Parallel rank. + * \param[in] cp_size Context Parallel size. + * \param[in] stream CUDA stream used for this operation. + */ + +void nvte_cp_thd_chunkify_p2p(const NVTETensor &cu_seqlens, const NVTETensor &cu_seqlens_padded, + NVTETensor out_cu_seqlens, NVTETensor out_cu_seqlens_padded, + int batch, int output_len, int chunk_size, int cp_rank, int cp_size, + cudaStream_t stream); + +/*! \brief Support THD format for Context Parallel: Split sequence into chunks for one P2P part below diagonal. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] cu_seqlens_q Cumulative sequence lengths, [batch_size + 1]. + * \param[in] cu_seqlens_kv_halfs Cumulative sequence lengths, [batch_size + 1]. + * \param[in] cu_seqlens_padded Cumulative sequence lengths, [batch_size + 1]. + * \param[out] out_cu_seqlens_q Output tensor. + * \param[out] out_cu_seqlens_kv Output tensor. + * \param[out] out_cu_seqlens_q_padded Output tensor. + * \param[out] out_cu_seqlens_kv_padded Output tensor. + * \param[in] cp_rank_q Context Parallel rank for Q. + * \param[in] cp_rank_kv Context Parallel rank for KV. + * \param[in] cp_size Context Parallel size. + * \param[in] batch Batch size. + * \param[in] output_len Output length. + * \param[in] chunk_size Chunk size. + * \param[in] stream CUDA stream used for this operation. + */ + +void nvte_cp_thd_seq_tweak_below_diag( + const NVTETensor &cu_seqlens_q, const NVTETensor &cu_seqlens_kv_halfs, + const NVTETensor &cu_seqlens_padded, NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv, + NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, int cp_rank_q, + int cp_rank_kv, int cp_size, int batch, int output_len, int chunk_size, cudaStream_t stream); + +/*! 
\brief Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] cu_seqlens_q_halfs Cumulative sequence lengths, [batch_size + 1]. + * \param[in] cu_seqlens_kv_halfs Cumulative sequence lengths, [batch_size + 1]. + * \param[out] out_cu_seqlens_q_halfs Output tensor. + * \param[in] batch Batch size. + * \param[in] output_len Output length. + * \param[in] chunk_size Chunk size. + * \param[in] stream CUDA stream used for this operation. + */ + +void nvte_cp_thd_seq_tweak_above_diag( + const NVTETensor &cu_seqlens_q_halfs, const NVTETensor &cu_seqlens_kv, + const NVTETensor &cu_seqlens_padded, NVTETensor out_cu_seqlens_q, NVTETensor out_cu_seqlens_kv, + NVTETensor out_cu_seqlens_q_padded, NVTETensor out_cu_seqlens_kv_padded, int cp_rank_q, + int cp_rank_kv, int cp_size, int batch, int output_len, int chunk_size, cudaStream_t stream); + /*! \brief Convert tensor from THD to BSHD format. * * \warning This API is **experimental** and subject to change. diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py index 2313484054..43c2f95ee1 100644 --- a/transformer_engine/debug/features/utils/stats_buffer.py +++ b/transformer_engine/debug/features/utils/stats_buffer.py @@ -85,6 +85,11 @@ def feed(self, tensor, iteration): if self.modified[0] and not self.reduce_within_microbatch: return + # We do not feed the tensor with 0 elements, + # we behave the same way as if feed() was not called. + if tensor.numel() == 0: + return + # save stats for tensor to tmp buffer for stat_name in self.stats_to_compute: fn, _ = STATS[stat_name] diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index b3b7630df3..582e61c4ea 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -470,6 +470,7 @@ def forward( max_seqlen_kv: Optional[int] = None, attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, + chunk_size: Optional[int] = None, alibi_slopes: Optional[torch.Tensor] = None, cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, @@ -682,6 +683,7 @@ def forward( attn_mask_type=attn_mask_type, deterministic=self.deterministic, window_size=window_size, + chunk_size=chunk_size, quantizers=quantizers, pad_between_seqs=False, use_flash_attn_3=use_flash_attn_3, @@ -1431,6 +1433,7 @@ def forward( attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, window_size: Optional[Tuple[int, int]] = None, + chunk_size: Optional[int] = None, fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, @@ -1583,6 +1586,7 @@ def forward( cp_stream, cp_comm_type, softmax_scale=self.softmax_scale, + chunk_size=chunk_size, qkv_format=qkv_format, attn_mask_type=attn_mask_type, attn_bias_type=core_attention_bias_type, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 9f4822784e..016885e674 100644 --- 
a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -444,6 +444,7 @@ def forward( cu_seqlens_kv_padded, dropout_p, softmax_scale, + chunk_size, qkv_format, attn_mask_type, attn_bias_type, @@ -521,6 +522,8 @@ def forward( max_seqlen_kv = max_seqlen_kv // cp_size cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] + cu_seqlens_q_padded_per_step = [None for _ in range(cp_size)] + cu_seqlens_kv_padded_per_step = [None for _ in range(cp_size)] fused_attn_backend = None qkv_dtype = q.dtype @@ -737,6 +740,8 @@ def forward( else: cu_seqlens_q_per_step[i] = cu_seqlens_q cu_seqlens_kv_per_step[i] = cu_seqlens_kv + cu_seqlens_q_padded_per_step[i] = cu_seqlens_q_padded + cu_seqlens_kv_padded_per_step[i] = cu_seqlens_kv_padded if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) @@ -762,7 +767,30 @@ def forward( -1, k.shape[2], 2, *k.shape[-2:] ) elif qkv_format == "thd": + thd_total_seq_len = q.shape[0] q_inputs[i % 2] = q + if chunk_size is not None: + cu_seqlens_q_per_step[i], cu_seqlens_q_padded_per_step[i] = ( + dpa_utils.thd_chunkify_p2p( + cu_seqlens_q_per_step[i], + cu_seqlens_q_padded_per_step[i], + chunk_size, + rank, + cp_size, + thd_total_seq_len, + ) + ) + cu_seqlens_kv_per_step[i], cu_seqlens_kv_padded_per_step[i] = ( + dpa_utils.thd_chunkify_p2p( + cu_seqlens_kv_per_step[i], + cu_seqlens_kv_padded_per_step[i], + chunk_size, + rank, + cp_size, + thd_total_seq_len, + ) + ) + if use_fused_attention: if attn_bias is not None: idx = (rank - i) % cp_size @@ -819,10 +847,11 @@ def forward( attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded_per_step[i], + cu_seqlens_kv_padded=cu_seqlens_kv_padded_per_step[i], **fp8_meta_kwargs, ) + if fp8: softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors else: @@ -884,6 +913,13 @@ def forward( else: cu_seqlens_q_per_step[i] = cu_seqlens_q cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half + + cu_seqlens_q_padded_per_step[i] = cu_seqlens_q_padded + cu_seqlens_kv_padded_per_step[i] = ( + cu_seqlens_kv_padded // 2 + if cu_seqlens_kv_padded is not None + else None + ) if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) @@ -906,7 +942,10 @@ def forward( kv_inputs[i % 2] = kv_inputs[i % 2][0] elif qkv_format == "thd": q_inputs[i % 2] = q + # [2, t, np, hn] -> [2, t/2, np, hn] + if enable_mla: + assert chunk_size is None # [t, np, hn] -> [t/2, np, hn] k_part = tex.thd_read_half_tensor( k_part, cu_seqlens_kv_padded, 0 @@ -916,6 +955,21 @@ def forward( ) else: # [2, t, np, hn] -> [2, t/2, np, hn] + if chunk_size is not None: + ( + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded_per_step[i], + cu_seqlens_kv_padded_per_step[i], + ) = dpa_utils.thd_seq_tweak_below_diagonal( + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded, + rank, + rank - i, + cp_size, + chunk_size, + ) kv_inputs[i % 2] = tex.thd_read_half_tensor( kv_inputs[i % 2], cu_seqlens_kv_padded, 0 ) @@ -971,12 +1025,8 @@ def forward( attn_mask_type="padding" if padding else "no_mask", attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i % 2], - 
cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None - if cu_seqlens_kv_padded is None - else cu_seqlens_kv_padded // 2 - ), + cu_seqlens_q_padded=cu_seqlens_q_padded_per_step[i], + cu_seqlens_kv_padded=cu_seqlens_kv_padded_per_step[i], **fp8_meta_kwargs, ) if fp8: @@ -1047,6 +1097,13 @@ def forward( else: cu_seqlens_q_per_step[i] = cu_seqlens_q_half cu_seqlens_kv_per_step[i] = cu_seqlens_kv + + cu_seqlens_q_padded_per_step[i] = ( + cu_seqlens_q_padded // 2 + if cu_seqlens_q_padded is not None + else None + ) + cu_seqlens_kv_padded_per_step[i] = cu_seqlens_kv_padded if qkv_format == "bshd": # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] q_inputs[i % 2] = q[:, 1, ...] @@ -1076,6 +1133,22 @@ def forward( q_inputs[i % 2] = tex.thd_read_half_tensor( q, cu_seqlens_q_padded, 1 ) + if chunk_size is not None: + ( + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded_per_step[i], + cu_seqlens_kv_padded_per_step[i], + ) = dpa_utils.thd_seq_tweak_above_diagonal( + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded, + rank, + rank - i + cp_size, + cp_size, + chunk_size, + ) + if use_fused_attention: q_inputs[i % 2] = q_inputs[i % 2].contiguous() if attn_bias is not None: @@ -1130,12 +1203,8 @@ def forward( attn_mask_type="padding" if padding else "no_mask", attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=( - None - if cu_seqlens_q_padded is None - else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded_per_step[i], + cu_seqlens_kv_padded=cu_seqlens_kv_padded_per_step[i], **fp8_meta_kwargs, ) if fp8: @@ -1466,6 +1535,8 @@ def forward( cu_seqlens_kv_padded, *cu_seqlens_q_per_step, *cu_seqlens_kv_per_step, + *cu_seqlens_q_padded_per_step, + *cu_seqlens_kv_padded_per_step, *rng_states, *attn_biases, ) @@ -1542,8 +1613,10 @@ def backward(ctx, dout): ) cu_seqlens_q_per_step = other_tensors[:cp_size] cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2] - rng_states = other_tensors[cp_size * 2 : cp_size * 3] - attn_biases = other_tensors[cp_size * 3 : cp_size * 4] + cu_seqlens_q_padded_per_step = other_tensors[cp_size * 2 : cp_size * 3] + cu_seqlens_kv_padded_per_step = other_tensors[cp_size * 3 : cp_size * 4] + rng_states = other_tensors[cp_size * 4 : cp_size * 5] + attn_biases = other_tensors[cp_size * 5 : cp_size * 6] causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type @@ -1849,8 +1922,8 @@ def backward(ctx, dout): fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded_per_step[cp_size - i - 1], + cu_seqlens_kv_padded=cu_seqlens_kv_padded_per_step[cp_size - i - 1], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, @@ -1984,6 +2057,7 @@ def backward(ctx, dout): ) fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, @@ -1998,10 +2072,8 @@ def backward(ctx, dout): fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 - ), + cu_seqlens_q_padded=cu_seqlens_q_padded_per_step[cp_size - i - 1], + 
cu_seqlens_kv_padded=cu_seqlens_kv_padded_per_step[cp_size - i - 1], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, @@ -2128,6 +2200,7 @@ def backward(ctx, dout): ) fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, @@ -2142,10 +2215,8 @@ def backward(ctx, dout): fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, - cu_seqlens_q_padded=( - None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, + cu_seqlens_q_padded=cu_seqlens_q_padded_per_step[cp_size - i - 1], + cu_seqlens_kv_padded=cu_seqlens_kv_padded_per_step[cp_size - i - 1], attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, @@ -2259,7 +2330,6 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) - if ctx.fp8: dq_ = dq_._data dk_ = dk_._data @@ -2661,6 +2731,7 @@ def backward(ctx, dout): None, None, None, + None, attn_dbias, None, None, @@ -3769,6 +3840,7 @@ def attn_forward_func_with_cp( deterministic=False, use_fused_attention=False, window_size=None, + chunk_size=None, fp8=False, fp8_meta=None, quantizers=None, @@ -3895,6 +3967,7 @@ def attn_forward_func_with_cp( cu_seqlens_kv_padded, dropout_p, softmax_scale, + chunk_size, qkv_format, attn_mask_type, attn_bias_type, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 7d50b9fa54..9eef6b8b57 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -217,6 +217,7 @@ def __init__( tp_group: Optional[dist_group_type] = None, layer_number: Optional[int] = None, attention_type: str = "self", + chunk_size: Optional[int] = None, cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, @@ -264,6 +265,11 @@ def __init__( num_attention_heads % self.num_gqa_groups == 0 ), "The number of attention heads must be divisible by the number of GQA groups!" + assert ( + chunk_size is None or qkv_format == "thd" + ), "Chunk size is only supported for thd format" + self.chunk_size = chunk_size + self.rng_states_tracker = None if sequence_parallel or get_rng_state_tracker is None: attention_dropout_ctx = nullcontext @@ -747,6 +753,15 @@ def forward( assert ( cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32 ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" 
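For reference, a minimal usage sketch of the new `chunk_size` knob on `DotProductAttention` with THD inputs (not taken from this PR's tests; the shapes, mask type and sequence lengths below are made up, and `chunk_size` must be combined with `qkv_format="thd"`, as the assertion added in `__init__` enforces):

```python
import torch
import transformer_engine.pytorch as te

heads, head_dim, chunk_size = 12, 128, 1024
cu_seqlens = torch.tensor([0, 4096, 8192], dtype=torch.int32, device="cuda")
total_tokens = int(cu_seqlens[-1])

q = torch.randn(total_tokens, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

attn = te.DotProductAttention(
    num_attention_heads=heads,
    kv_channels=head_dim,
    qkv_format="thd",
    attn_mask_type="padding_causal",
    chunk_size=chunk_size,  # new argument; only valid together with qkv_format="thd"
)
out = attn(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    max_seqlen_q=4096,
    max_seqlen_kv=4096,
)
```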
+ + if self.chunk_size is not None and self.cp_group is None: + total_seq_len = query_layer.shape[0] + cu_seqlens_q, cu_seqlens_q_padded = dpa_utils.thd_chunkify( + cu_seqlens_q, cu_seqlens_q_padded, self.chunk_size, total_seq_len + ) + cu_seqlens_kv, cu_seqlens_kv_padded = dpa_utils.thd_chunkify( + cu_seqlens_kv, cu_seqlens_kv_padded, self.chunk_size, total_seq_len + ) batch_size = len(cu_seqlens_q) - 1 if max_seqlen_q is None: if cu_seqlens_q_padded is not None: @@ -947,6 +962,7 @@ def forward( head_dim_v=head_dim_v, attn_mask_type=attn_mask_type, window_size=window_size, + chunk_size=self.chunk_size, alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None, core_attention_bias_type=core_attention_bias_type, core_attention_bias_shape=core_attention_bias_shape, @@ -1080,6 +1096,7 @@ def forward( attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, + chunk_size=self.chunk_size, fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias=fu_core_attention_bias, @@ -1107,6 +1124,7 @@ def forward( max_seqlen_kv=max_seqlen_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, + chunk_size=self.chunk_size, window_size=window_size, fused_attention_backend=fused_attention_backend, core_attention_bias_type=fu_core_attention_bias_type, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index d98dde0159..6b487a98de 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -190,6 +190,8 @@ class AttentionParams: `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} window_size: Tuple[int, int], default = None Sliding window attention size. + chunk_size: int, default = None + Chunk size for context parallelism. alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None` Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`. core_attention_bias_type: str, default = `no_bias` @@ -229,6 +231,7 @@ class AttentionParams: head_dim_v: int = 64 attn_mask_type: str = "no_mask" window_size: Union[Tuple[int, int], None] = None + chunk_size: int = None alibi_slopes_shape: Union[torch.Size, List, None] = None core_attention_bias_type: str = "no_bias" core_attention_bias_shape: str = "1hss" @@ -1802,3 +1805,107 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): ) return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer + + +def thd_chunkify( + cu_seqlens: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + chunk_size: int, + total_seq_len: int = 20, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Chunkify the cu_seqlens tensor. + Returns new cu_seqlens, cu_seqlens_padded tensors + """ + new_cu_seqlens, new_cu_seqlens_padded = tex.thd_chunkify( + cu_seqlens, cu_seqlens_padded, total_seq_len, chunk_size + ) + + return new_cu_seqlens, new_cu_seqlens_padded + + +def thd_chunkify_p2p( + cu_seqlens: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + chunk_size: int, + cp_rank: int, + cp_size: int, + total_seq_len: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Chunkify the cu_seqlens tensor. 
+    Returns new cu_seqlens, cu_seqlens_padded tensors
+    """
+
+    new_cu_seqlens, new_cu_seqlens_padded = tex.thd_chunkify_p2p(
+        cu_seqlens, cu_seqlens_padded, total_seq_len, chunk_size, cp_rank, cp_size
+    )
+
+    return new_cu_seqlens, new_cu_seqlens_padded
+
+
+def thd_seq_tweak_below_diagonal(
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_kv_halfs: torch.Tensor,
+    cu_seqlens_padded: torch.Tensor,
+    cp_rank_q: int,
+    cp_rank_kv: int,
+    cp_size: int,
+    chunk_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Adjust cu_seqlens for a below-diagonal CP step so that attention stays within whole chunks.
+    Returns new cu_seqlens_q, cu_seqlens_kv_halfs and their padded counterparts.
+    """
+    assert cp_rank_q > cp_rank_kv
+
+    new_seqlens_q, new_seqlens_kv_halfs, new_cu_seqlens_q_padded, new_cu_seqlens_kv_padded = (
+        tex.thd_seq_tweak_below_diag(
+            cu_seqlens_q,
+            cu_seqlens_kv_halfs,
+            cu_seqlens_padded,
+            cp_rank_q,
+            cp_rank_kv,
+            cp_size,
+            chunk_size,
+        )
+    )
+
+    new_cu_seqlens_q = torch.cumsum(new_seqlens_q, dim=0, dtype=torch.int32)
+    new_cu_seqlens_kv_halfs = torch.cumsum(new_seqlens_kv_halfs, dim=0, dtype=torch.int32)
+
+    return (
+        new_cu_seqlens_q,
+        new_cu_seqlens_kv_halfs,
+        new_cu_seqlens_q_padded,
+        new_cu_seqlens_kv_padded,
+    )
+
+
+def thd_seq_tweak_above_diagonal(
+    cu_seqlens_q_halfs: torch.Tensor,
+    cu_seqlens_kv: torch.Tensor,
+    cu_seqlens_padded: torch.Tensor,
+    cp_rank_q: int,
+    cp_rank_kv: int,
+    cp_size: int,
+    chunk_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Adjust cu_seqlens for an above-diagonal CP step so that attention stays within whole chunks.
+    Returns new cu_seqlens_q_halfs, cu_seqlens_kv and their padded counterparts.
+    """
+    assert cp_rank_q < cp_rank_kv
+
+    new_seqlens_q, new_seqlens_kv, new_cu_seqlens_q_padded, new_cu_seqlens_kv_padded = (
+        tex.thd_seq_tweak_above_diag(
+            cu_seqlens_q_halfs,
+            cu_seqlens_kv,
+            cu_seqlens_padded,
+            cp_rank_q,
+            cp_rank_kv,
+            cp_size,
+            chunk_size,
+        )
+    )
+    new_cu_seqlens_q_halfs = torch.cumsum(new_seqlens_q, dim=0, dtype=torch.int32)
+    new_cu_seqlens_kv = torch.cumsum(new_seqlens_kv, dim=0, dtype=torch.int32)
+
+    return (
+        new_cu_seqlens_q_halfs,
+        new_cu_seqlens_kv,
+        new_cu_seqlens_q_padded,
+        new_cu_seqlens_kv_padded,
+    )
diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp
index f86b60f612..29e3979952 100644
--- a/transformer_engine/pytorch/csrc/common.cpp
+++ b/transformer_engine/pytorch/csrc/common.cpp
@@ -239,7 +239,10 @@ at::Tensor allocateSpace(const std::vector<size_t>& shape, const transformer_eng
   if (init_to_zeros) {
     return at::zeros(ar_shape, at::CUDA(GetATenDType(type)));
   } else {
-    return at::empty(ar_shape, at::CUDA(GetATenDType(type)));
+    // init to minus inf
+    at::Tensor ret = at::empty(ar_shape, at::CUDA(GetATenDType(type)));
+    ret.fill_(-std::numeric_limits<float>::infinity());
+    return ret;
   }
 }
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 72f6f27596..cadb8914fd 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -301,6 +301,23 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
 at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
                                        int world_size, int rank);
 
+std::vector<at::Tensor> thd_chunkify(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded,
+                                     int total_seq_len, int chunk_size);
+
+std::vector<at::Tensor> thd_chunkify_p2p(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded,
+                                         int total_seq_len, int chunk_size, int cp_rank,
+                                         int cp_size);
+
+std::vector<at::Tensor> thd_seq_tweak_below_diag(at::Tensor cu_seqlens_q,
+                                                 at::Tensor cu_seqlens_kv_halfs,
+                                                 at::Tensor cu_seqlens_padded, int cp_rank_q,
+                                                 int cp_rank_kv, int cp_size, int chunk_size);
+
+std::vector<at::Tensor> thd_seq_tweak_above_diag(at::Tensor cu_seqlens_q_halfs,
+                                                 at::Tensor cu_seqlens_kv,
+                                                 at::Tensor cu_seqlens_padded, int cp_rank_q,
+                                                 int cp_rank_kv, int cp_size, int chunk_size);
+
 /***************************************************************************************************
  * multi_tensor_* kernels
  **************************************************************************************************/
diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp
index 71a8062b1a..e18c88fc37 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp
@@ -739,6 +739,155 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
   return output;
 }
 
+/***************************************************************************************************
+ * Support THD format for Context Parallel: Split sequence into chunks.
+ **************************************************************************************************/
+
+std::vector<at::Tensor> thd_chunkify(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded,
+                                     int total_seq_len, int chunk_size) {
+  NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
+  NVTE_CHECK(cu_seqlens.dim() == 1);
+  NVTE_CHECK(cu_seqlens.size(0) >= 2);
+  NVTE_CHECK(cu_seqlens_padded.scalar_type() == at::ScalarType::Int);
+  NVTE_CHECK(cu_seqlens_padded.dim() == 1);
+  NVTE_CHECK(cu_seqlens_padded.size(0) >= 2);
+  NVTE_CHECK(chunk_size > 0);
+
+  int batch = cu_seqlens.size(0) - 1;
+  int output_len = cu_seqlens_padded.size(0) + total_seq_len / chunk_size + 1;
+
+  // Allocate output tensors
+  at::Tensor out_cu_seqlens = at::empty({output_len}, at::CUDA(at::ScalarType::Int));
+  at::Tensor out_cu_seqlens_padded = at::empty({output_len}, at::CUDA(at::ScalarType::Int));
+
+  // Create tensor wrappers
+  auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
+  auto te_cu_seqlens_padded = makeTransformerEngineTensor(cu_seqlens_padded);
+  auto te_out_cu_seqlens = makeTransformerEngineTensor(out_cu_seqlens);
+  auto te_out_cu_seqlens_padded = makeTransformerEngineTensor(out_cu_seqlens_padded);
+
+  nvte_cp_thd_chunkify(te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_out_cu_seqlens.data(),
+                       te_out_cu_seqlens_padded.data(), batch, output_len, chunk_size,
+                       at::cuda::getCurrentCUDAStream());
+
+  return {out_cu_seqlens, out_cu_seqlens_padded};
+}
+
+/***************************************************************************************************
+ * Support THD format for Context Parallel: Split sequence into chunks for one P2P part on diagonal.
+ **************************************************************************************************/
+
+std::vector<at::Tensor> thd_chunkify_p2p(at::Tensor cu_seqlens, at::Tensor cu_seqlens_padded,
+                                         int total_seq_len, int chunk_size, int cp_rank,
+                                         int cp_size) {
+  NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
+  NVTE_CHECK(cu_seqlens.dim() == 1);
+  NVTE_CHECK(cu_seqlens.size(0) >= 2);
+  NVTE_CHECK(cu_seqlens_padded.scalar_type() == at::ScalarType::Int);
+
+  int batch = cu_seqlens.size(0) - 1;
+  int output_len = 5 * cu_seqlens_padded.size(0) + total_seq_len / chunk_size;
+
+  at::Tensor out_cu_seqlens = at::zeros({output_len}, at::CUDA(at::ScalarType::Int));
+  at::Tensor out_cu_seqlens_padded = at::zeros({output_len}, at::CUDA(at::ScalarType::Int));
+
+  auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
+  auto te_cu_seqlens_padded = makeTransformerEngineTensor(cu_seqlens_padded);
+  auto te_out_cu_seqlens = makeTransformerEngineTensor(out_cu_seqlens);
+  auto te_out_cu_seqlens_padded = makeTransformerEngineTensor(out_cu_seqlens_padded);
+
+  nvte_cp_thd_chunkify_p2p(te_cu_seqlens.data(), te_cu_seqlens_padded.data(),
+                           te_out_cu_seqlens.data(), te_out_cu_seqlens_padded.data(), batch,
+                           output_len, chunk_size, cp_rank, cp_size,
+                           at::cuda::getCurrentCUDAStream());
+
+  return {out_cu_seqlens, out_cu_seqlens_padded};
+}
+
+/***************************************************************************************************
+ * Support THD format for Context Parallel: Split sequence into chunks for one P2P part below diagonal.
+ **************************************************************************************************/
+
+std::vector<at::Tensor> thd_seq_tweak_below_diag(at::Tensor cu_seqlens_q,
+                                                 at::Tensor cu_seqlens_kv_halfs,
+                                                 at::Tensor cu_seqlens_padded, int cp_rank_q,
+                                                 int cp_rank_kv, int cp_size, int chunk_size) {
+  NVTE_CHECK(cu_seqlens_q.scalar_type() == at::ScalarType::Int);
+  NVTE_CHECK(cu_seqlens_kv_halfs.scalar_type() == at::ScalarType::Int);
+  NVTE_CHECK(cu_seqlens_padded.scalar_type() == at::ScalarType::Int);
+  NVTE_CHECK(cu_seqlens_q.dim() == 1);
+  NVTE_CHECK(cu_seqlens_kv_halfs.dim() == 1);
+  NVTE_CHECK(cu_seqlens_padded.dim() == 1);
+  NVTE_CHECK(cu_seqlens_q.size(0) >= 2);
+  NVTE_CHECK(cu_seqlens_kv_halfs.size(0) >= 2);
+  NVTE_CHECK(cu_seqlens_padded.size(0) >= 2);
+  NVTE_CHECK(cp_rank_q >= 0 && cp_rank_q < cp_size);
+  NVTE_CHECK(cp_rank_kv >= 0 && cp_rank_kv < cp_size);
+  NVTE_CHECK(cp_size > 0);
+  NVTE_CHECK(chunk_size > 0);
+
+  int batch = cu_seqlens_q.size(0) - 1;
+  int output_len = 3 * batch;
+
+  at::Tensor out_seqlens_q = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int));
+  at::Tensor out_seqlens_kv = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int));
+  at::Tensor out_cu_seqlens_q_padded = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int));
+  at::Tensor out_cu_seqlens_kv_padded = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int));
+
+  auto te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q);
+  auto te_cu_seqlens_kv_halfs = makeTransformerEngineTensor(cu_seqlens_kv_halfs);
+  auto te_cu_seqlens_padded = makeTransformerEngineTensor(cu_seqlens_padded);
+  auto te_out_seqlens_q = makeTransformerEngineTensor(out_seqlens_q);
+  auto te_out_seqlens_kv = makeTransformerEngineTensor(out_seqlens_kv);
+  auto te_out_cu_seqlens_q_padded = makeTransformerEngineTensor(out_cu_seqlens_q_padded);
+  auto te_out_cu_seqlens_kv_padded = makeTransformerEngineTensor(out_cu_seqlens_kv_padded);
+
+  nvte_cp_thd_seq_tweak_below_diag(
+      te_cu_seqlens_q.data(), te_cu_seqlens_kv_halfs.data(), te_cu_seqlens_padded.data(),
+      te_out_seqlens_q.data(), te_out_seqlens_kv.data(), te_out_cu_seqlens_q_padded.data(),
+      te_out_cu_seqlens_kv_padded.data(), cp_rank_q, cp_rank_kv, cp_size, batch, output_len,
+      chunk_size, at::cuda::getCurrentCUDAStream());
+
+  return {out_seqlens_q, out_seqlens_kv, out_cu_seqlens_q_padded, out_cu_seqlens_kv_padded};
+}
+
+/***************************************************************************************************
+ * Support THD format for Context Parallel: Split sequence into chunks for one P2P part above diagonal.
+ **************************************************************************************************/
+
+std::vector<at::Tensor> thd_seq_tweak_above_diag(at::Tensor cu_seqlens_q_halfs,
+                                                 at::Tensor cu_seqlens_kv,
+                                                 at::Tensor cu_seqlens_padded, int cp_rank_q,
+                                                 int cp_rank_kv, int cp_size, int chunk_size) {
+  NVTE_CHECK(cu_seqlens_q_halfs.scalar_type() == at::ScalarType::Int);
+  NVTE_CHECK(cu_seqlens_kv.scalar_type() == at::ScalarType::Int);
+  NVTE_CHECK(cu_seqlens_padded.scalar_type() == at::ScalarType::Int);
+
+  int batch = cu_seqlens_q_halfs.size(0) - 1;
+  int output_len = 3 * batch;
+
+  at::Tensor out_seqlens_q_halfs = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int));
+  at::Tensor out_seqlens_kv = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int));
+  at::Tensor out_cu_seqlens_q_padded = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int));
+  at::Tensor out_cu_seqlens_kv_padded = at::empty({output_len + 1}, at::CUDA(at::ScalarType::Int));
+
+  auto te_cu_seqlens_q_halfs = makeTransformerEngineTensor(cu_seqlens_q_halfs);
+  auto te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv);
+  auto te_cu_seqlens_padded = makeTransformerEngineTensor(cu_seqlens_padded);
+  auto te_out_seqlens_q_halfs = makeTransformerEngineTensor(out_seqlens_q_halfs);
+  auto te_out_seqlens_kv = makeTransformerEngineTensor(out_seqlens_kv);
+  auto te_out_cu_seqlens_q_padded = makeTransformerEngineTensor(out_cu_seqlens_q_padded);
+  auto te_out_cu_seqlens_kv_padded = makeTransformerEngineTensor(out_cu_seqlens_kv_padded);
+
+  nvte_cp_thd_seq_tweak_above_diag(
+      te_cu_seqlens_q_halfs.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_padded.data(),
+      te_out_seqlens_q_halfs.data(), te_out_seqlens_kv.data(), te_out_cu_seqlens_q_padded.data(),
+      te_out_cu_seqlens_kv_padded.data(), cp_rank_q, cp_rank_kv, cp_size, batch, output_len,
+      chunk_size, at::cuda::getCurrentCUDAStream());
+
+  return {out_seqlens_q_halfs, out_seqlens_kv, out_cu_seqlens_q_padded, out_cu_seqlens_kv_padded};
+}
+
 /***************************************************************************************************
  * KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
  **************************************************************************************************/
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 63c3b434d3..3430c710ad 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -280,6 +280,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("thd_get_partitioned_indices", &transformer_engine::pytorch::thd_get_partitioned_indices,
         "Generate partitioned indices for inputs in THD format",
         py::call_guard<py::gil_scoped_release>());
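The seq-tweak bindings registered below return per-sequence lengths rather than cumulative offsets, which is why the Python wrappers in utils.py run torch.cumsum on them. A small, hypothetical illustration of that step (the kept lengths here are made up; the padded offsets come back from the kernels as offsets already and need no cumsum):

```python
import torch

# Each sequence contributes a [0, kept_len, 0] triple, with a leading 0 prepended,
# so a batch of 2 yields 3 * batch + 1 = 7 entries.
new_seqlens_q = torch.tensor([0, 0, 512, 0, 0, 256, 0], dtype=torch.int32)
cu_seqlens_q = torch.cumsum(new_seqlens_q, dim=0, dtype=torch.int32)
print(cu_seqlens_q)  # tensor([  0,   0, 512, 512, 512, 768, 768], dtype=torch.int32)
```

+  m.def("thd_chunkify", &transformer_engine::pytorch::thd_chunkify, "Chunkify THD tensor",
+        py::arg("cu_seqlens"), py::arg("cu_seqlens_padded"),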
py::arg("total_seq_len"), + py::arg("chunk_size"), py::call_guard()); + m.def("thd_chunkify_p2p", &transformer_engine::pytorch::thd_chunkify_p2p, + "Chunkify THD tensor for P2P communication", py::call_guard()); + m.def("thd_seq_tweak_below_diag", &transformer_engine::pytorch::thd_seq_tweak_below_diag, + "Tweak the sequence below the diagonal for THD tensor", + py::call_guard()); + m.def("thd_seq_tweak_above_diag", &transformer_engine::pytorch::thd_seq_tweak_above_diag, + "Tweak the sequence above the diagonal for THD tensor", + py::call_guard()); // nvshmem functions m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_nvshmem_backend,