diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py
index 07b5f7c529..e1cd48e181 100644
--- a/tests/pytorch/test_permutation.py
+++ b/tests/pytorch/test_permutation.py
@@ -326,33 +326,37 @@ def _test_permutation_index_map(
     te_unpermute_output_ = te_unpermute_output.float()
     te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()

-    torch.testing.assert_close(
-        pytorch_permute_output.float(),
-        te_permute_output_,
-        msg=f"Mismatch in te_permute fwd",
-    )
-    torch.testing.assert_close(
-        pytorch_permute_fwd_input.grad.float(),
-        te_permute_fwd_input_grad,
-        msg=f"Mismatch in te_permute bwd",
-        **tols,
-    )
-    torch.testing.assert_close(
-        pytorch_unpermute_output.float(),
-        te_unpermute_output_,
-        msg=f"Mismatch in te_unpermute fwd",
-        **tols,
-    )
-    torch.testing.assert_close(
-        pytorch_unpermute_fwd_input.grad.float(),
-        te_unpermute_fwd_input_grad,
-        msg=f"Mismatch in te_unpermute bwd",
-        **tols,
-    )
-    if with_probs:
+    if not BENCHMARK:
         torch.testing.assert_close(
-            probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols
+            pytorch_permute_output.float(),
+            te_permute_output_,
+            msg=f"Mismatch in te_permute fwd",
         )
+        torch.testing.assert_close(
+            pytorch_permute_fwd_input.grad.float(),
+            te_permute_fwd_input_grad,
+            msg=f"Mismatch in te_permute bwd",
+            **tols,
+        )
+        torch.testing.assert_close(
+            pytorch_unpermute_output.float(),
+            te_unpermute_output_,
+            msg=f"Mismatch in te_unpermute fwd",
+            **tols,
+        )
+        torch.testing.assert_close(
+            pytorch_unpermute_fwd_input.grad.float(),
+            te_unpermute_fwd_input_grad,
+            msg=f"Mismatch in te_unpermute bwd",
+            **tols,
+        )
+        if with_probs:
+            torch.testing.assert_close(
+                probs.grad.float(),
+                te_probs.grad.float(),
+                msg=f"Mismatch in te_unpermute bwd",
+                **tols,
+            )

     if not pytorch_permute_fwd_input.numel():
         print("Empty pytorch_permute_fwd_input activation test passed.")
@@ -538,34 +542,38 @@ def _test_permutation_mask_map(
     te_unpermute_output_ = te_unpermute_output.float()
     te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()

-    torch.testing.assert_close(
-        pytorch_permute_output.float(),
-        te_permute_output_,
-        msg=f"Mismatch in te_permute fwd",
-        **tols,
-    )
-    torch.testing.assert_close(
-        pytorch_permute_fwd_input.grad.float(),
-        te_permute_fwd_input_grad,
-        msg=f"Mismatch in te_permute bwd",
-        **tols,
-    )
-    torch.testing.assert_close(
-        pytorch_unpermute_output.float(),
-        te_unpermute_output_,
-        msg=f"Mismatch in te_unpermute fwd",
-        **tols,
-    )
-    torch.testing.assert_close(
-        pytorch_unpermute_fwd_input.grad.float(),
-        te_unpermute_fwd_input_grad,
-        msg=f"Mismatch in te_unpermute bwd",
-        **tols,
-    )
-    if with_probs:
+    if not BENCHMARK:
+        torch.testing.assert_close(
+            pytorch_permute_output.float(),
+            te_permute_output_,
+            msg=f"Mismatch in te_permute fwd",
+            **tols,
+        )
         torch.testing.assert_close(
-            probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols
+            pytorch_permute_fwd_input.grad.float(),
+            te_permute_fwd_input_grad,
+            msg=f"Mismatch in te_permute bwd",
+            **tols,
         )
+        torch.testing.assert_close(
+            pytorch_unpermute_output.float(),
+            te_unpermute_output_,
+            msg=f"Mismatch in te_unpermute fwd",
+            **tols,
+        )
+        torch.testing.assert_close(
+            pytorch_unpermute_fwd_input.grad.float(),
+            te_unpermute_fwd_input_grad,
+            msg=f"Mismatch in te_unpermute bwd",
+            **tols,
+        )
+        if with_probs:
+            torch.testing.assert_close(
+                probs.grad.float(),
+                te_probs.grad.float(),
+                msg=f"Mismatch in te_unpermute bwd",
+                **tols,
+            )

     if not pytorch_permute_fwd_input.numel():
         print("Empty pytorch_permute_fwd_input activation test passed.")
@@ -827,18 +835,19 @@ def _test_moe_chunk_sort(
     te_output_ = te_output.float()
     te_fwd_input_grad = te_fwd_input.grad.float()

-    torch.testing.assert_close(
-        pytorch_output.float(),
-        te_output_,
-        msg=f"Mismatch in te_permute fwd",
-        **tols,
-    )
-    torch.testing.assert_close(
-        pytorch_fwd_input.grad.float(),
-        te_fwd_input_grad,
-        msg=f"Mismatch in te_permute bwd",
-        **tols,
-    )
+    if not BENCHMARK:
+        torch.testing.assert_close(
+            pytorch_output.float(),
+            te_output_,
+            msg=f"Mismatch in te_permute fwd",
+            **tols,
+        )
+        torch.testing.assert_close(
+            pytorch_fwd_input.grad.float(),
+            te_fwd_input_grad,
+            msg=f"Mismatch in te_permute bwd",
+            **tols,
+        )

     if not pytorch_fwd_input.numel():
         print("Empty pytorch_fwd_input activation test passed.")
@@ -887,6 +896,7 @@ def _test_permutation_mask_map_alongside_probs(
     topK,
     num_out_tokens,
     tp_size,
+    BENCHMARK=False,
 ):
     if topK > num_expert:
         pytest.skip("topK should be smaller than the number of experts.")
@@ -1016,21 +1026,73 @@ def _test_permutation_mask_map_alongside_probs(
     te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
     te_unpermute_output_ = te_unpermute_output.float()

-    torch.testing.assert_close(
-        pytorch_unpermute_output.float(),
-        te_unpermute_output_,
-        msg=f"Mismatch in fused_unpermute fwd",
-        **tols,
-    )
-    torch.testing.assert_close(
-        pytorch_permute_fwd_input.grad.float(),
-        te_permute_fwd_input_grad,
-        msg=f"Mismatch in fused_permute bwd",
-        **tols,
-    )
-    torch.testing.assert_close(
-        probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols
-    )
+    if not BENCHMARK:
+        torch.testing.assert_close(
+            pytorch_unpermute_output.float(),
+            te_unpermute_output_,
+            msg=f"Mismatch in fused_unpermute fwd",
+            **tols,
+        )
+        torch.testing.assert_close(
+            pytorch_permute_fwd_input.grad.float(),
+            te_permute_fwd_input_grad,
+            msg=f"Mismatch in fused_permute bwd",
+            **tols,
+        )
+        torch.testing.assert_close(
+            probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols
+        )
+
+    if BENCHMARK:
+        t1 = perf_test_cuda_kernel(
+            lambda: te_permute_with_probs(
+                te_permute_fwd_input, te_probs, routing_map, num_out_tokens=num_out_tokens
+            )
+        )
+        print(f"permute\t\tfwd: TE: {t1:.3f} ms")
+
+        te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs(
+            te_permute_fwd_input,
+            te_probs,
+            routing_map,
+            num_out_tokens=num_out_tokens,
+        )
+        te_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
+        t2 = perf_test_cuda_kernel(
+            lambda: backward_wrapper(
+                te_permute_output,
+                te_permute_bwd_input,
+                forward_input=[te_permute_fwd_input],
+                retain_graph=True,
+                accumulate_grad=False,
+            )
+        )
+        print(f"permute\t\tbwd: TE: {t2:.3f} ms")
+
+        chunk_sort_fwd_input = te_permute_output.detach()
+        chunk_sort_fwd_input.requires_grad_(True)
+        chunk_sort_fwd_probs = te_permuted_probs.detach()
+        chunk_sort_fwd_probs.requires_grad_(True)
+        t1 = perf_test_cuda_kernel(
+            lambda: te_sort_chunks_by_index_with_probs(
+                chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda
+            )
+        )
+        print(f"chunk sort\t\tfwd: TE: {t1:.3f} ms")
+
+        chunk_sort_output, _ = te_sort_chunks_by_index_with_probs(
+            chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda
+        )
+        t2 = perf_test_cuda_kernel(
+            lambda: backward_wrapper(
+                chunk_sort_output,
+                te_permute_bwd_input,
+                forward_input=[chunk_sort_fwd_input],
+                retain_graph=True,
+                accumulate_grad=False,
+            )
+        )
+        print(f"chunk sort\t\tbwd: TE: {t2:.3f} ms")


 def perf_test_cuda_kernel(cuda_kernel_fn):
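The benchmark branches above funnel every op through the file's existing `perf_test_cuda_kernel` helper, which this patch leaves untouched and therefore never shows. For orientation, a CUDA-event-based timing helper of this kind typically looks like the sketch below; the function name suffix and the warmup/iteration counts are assumptions, not the file's actual values.

```python
import torch

def perf_test_cuda_kernel_sketch(cuda_kernel_fn, warmup=10, iters=100):
    """Time a CUDA callable with events; a sketch, not the file's helper."""
    for _ in range(warmup):
        cuda_kernel_fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        cuda_kernel_fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call
```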
print(f"chunk sort\t\tbwd: TE: {t2:.3f} ms") def perf_test_cuda_kernel(cuda_kernel_fn): @@ -1063,7 +1125,7 @@ def perf_test_cuda_kernel(cuda_kernel_fn): @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) @@ -1092,7 +1154,7 @@ def test_permutation_index_map( @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) @@ -1138,7 +1200,7 @@ def test_permutation_mask_map_empty_input(te_dtype): @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) @@ -1193,7 +1255,7 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("num_tokens", [2048]) -@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) @@ -1225,7 +1287,7 @@ def test_permutation_mask_map_fp8( @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) def test_permutation_index_map_topk1_no_probs( te_dtype, @@ -1252,7 +1314,7 @@ def test_permutation_index_map_topk1_no_probs( @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) def test_permutation_mask_map_topk1_no_probs( te_dtype, @@ -1279,7 +1341,7 @@ def test_permutation_mask_map_topk1_no_probs( @pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("tp_size", [1, 2, 8]) @pytest.mark.parametrize("hidden_size", [4096]) def test_chunk_permutation( @@ -1372,5 +1434,108 @@ def test_permutation_single_case(): ) +def benchmark_single_case( + te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size +): + torch.cuda.nvtx.range_push( + f"{num_tokens}-{num_expert}-{hidden_size}-{topK}-{ep_size}-{tp_size}" + ) + + torch.cuda.nvtx.range_push("permutation_index_map_with_probs") + _test_permutation_index_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=True, + BENCHMARK=True, + ) + torch.cuda.nvtx.range_pop() 
+
+    torch.cuda.nvtx.range_push("permutation_mask_map_with_probs")
+    _test_permutation_mask_map(
+        te_dtype=te_dtype,
+        num_tokens=num_tokens,
+        num_expert=num_expert,
+        hidden_size=hidden_size,
+        topK=topK,
+        num_out_tokens=num_out_tokens,
+        with_probs=True,
+        BENCHMARK=True,
+    )
+    torch.cuda.nvtx.range_pop()
+
+    torch.cuda.nvtx.range_push("permutation_mask_map_without_probs")
+    _test_permutation_mask_map(
+        te_dtype=te_dtype,
+        num_tokens=num_tokens,
+        num_expert=num_expert,
+        hidden_size=hidden_size,
+        topK=topK,
+        num_out_tokens=num_out_tokens,
+        with_probs=False,
+        BENCHMARK=True,
+    )
+    torch.cuda.nvtx.range_pop()
+
+    torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs")
+    _test_permutation_mask_map_alongside_probs(
+        te_dtype=te_dtype,
+        num_tokens=num_tokens,
+        num_expert=num_expert,
+        hidden_size=hidden_size,
+        topK=topK,
+        num_out_tokens=num_out_tokens,
+        tp_size=tp_size,
+        BENCHMARK=True,
+    )
+    torch.cuda.nvtx.range_pop()
+
+    torch.cuda.nvtx.range_pop()
+
+
+def benchmark_multiple_cases():
+    print("GPU:", torch.cuda.get_device_name(0))
+
+    # te_dtype = tex.DType.kFloat32
+    # te_dtype = tex.DType.kFloat16
+    te_dtype = tex.DType.kBFloat16
+
+    ep_size = 64
+    tp_size = 2
+    num_tokens = 4096
+    num_expert = 256
+    hidden_size = 7168
+    topK = 8
+    num_out_tokens = num_tokens * topK
+    benchmark_single_case(
+        te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
+    )
+
+    ep_size = 8
+    tp_size = 1
+    num_tokens = 8192 * 2
+    num_expert = 128
+    hidden_size = 4096
+    topK = 6
+    num_out_tokens = num_tokens * topK
+    benchmark_single_case(
+        te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
+    )
+
+    ep_size = 64
+    tp_size = 2
+    num_tokens = 16384
+    num_expert = 4
+    hidden_size = 7168
+    topK = 1
+    num_out_tokens = num_tokens * topK
+    benchmark_single_case(
+        te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
+    )
+
+
 if __name__ == "__main__":
-    test_permutation_single_case()
+    benchmark_multiple_cases()
diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py
index d88047a012..ea3e67a57c 100644
--- a/transformer_engine/pytorch/permutation.py
+++ b/transformer_engine/pytorch/permutation.py
@@ -349,7 +349,7 @@ def forward(
         if restore_shape is None:
             restore_shape = inp.shape
         num_tokens, hidden_size = restore_shape
-        num_experts = row_id_map.size(0)
+        num_experts = (row_id_map.size(1) - 1) // 2

         with_probs = merging_probs is not None
         if with_probs:
@@ -651,14 +651,20 @@ def forward(
             fp8_scale_inv = inp._scale_inv
             fake_dtype = inp.dtype
             inp = inp._data
-        output, row_id_map, permuted_probs = triton_permutation.sort_chunks_by_idx(
-            inp,
+
+        row_id_map = triton_permutation.make_chunk_sort_map(
             split_sizes,
             sorted_idxs,
+            num_tokens,
+            num_splits,
+        )
+        output, permuted_probs = triton_permutation.sort_chunks_by_map(
+            inp,
+            row_id_map,
             probs,
             num_tokens,
             hidden_size,
-            num_splits,
+            is_forward=True,
         )
         if fp8:
             output = Float8Tensor(
@@ -700,6 +706,7 @@ def backward(
                 permuted_probs_grad,
                 ctx.num_tokens,
                 ctx.hidden_size,
+                is_forward=False,
             )
         if fp8:
             act_grad = Float8Tensor(
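The one-line change to `num_experts` in permutation.py above is pure arithmetic on the new map layout introduced in triton/permutation.py below: each row of the map now carries `num_experts` destination-row slots, `num_experts` expert-index slots, and one `n_routed` counter. A quick sanity check of that inversion, with illustrative values only:

```python
num_tokens, num_experts = 5, 3
# dense map layout per token: [dst_rows | expert_indices | n_routed]
cols = num_experts * 2 + 1          # 7 columns per token row
assert (cols - 1) // 2 == num_experts
```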
diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py
index ebf8dd551e..9ce01362f7 100644
--- a/transformer_engine/pytorch/triton/permutation.py
+++ b/transformer_engine/pytorch/triton/permutation.py
@@ -10,6 +10,72 @@
 import triton
 import triton.language as tl
+from triton.language import core
+from triton.language.standard import _log2
+
+
+# The following three argsort related kernels are adapted from
+# the issue https://github.com/triton-lang/triton/issues/3698
+
+
+@triton.jit
+def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr):
+    n_outer: tl.constexpr = x.numel >> n_dims
+    shape: tl.constexpr = [n_outer * (2**i), 2, 2 ** (n_dims - i - 1)]
+    y = tl.reshape(x, shape)
+    z = tl.reshape(indices, shape)
+
+    mask = tl.arange(0, 2)[None, :, None]
+
+    l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to(
+        x.dtype
+    )
+    r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to(
+        x.dtype
+    )
+
+    l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape)
+    r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape)
+
+    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+
+    il_value = l_value.to(idtype, bitcast=True)
+    ir_value = r_value.to(idtype, bitcast=True)
+    ix = x.to(idtype, bitcast=True)
+
+    flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix))
+    ret = ix ^ flag1
+    flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix))
+    ind = indices ^ flag2
+
+    return ret.to(x.dtype, bitcast=True), ind
+
+
+@triton.jit
+def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr):
+    n_outer: tl.constexpr = x.numel >> n_dims
+    tl.static_assert(stage <= n_dims)
+    """
+    order_type 0 == ascending
+    order_type 1 == descending
+    order_type 2 == alternating
+    """
+    if order == 2:
+        shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 2**stage]
+        flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
+    else:
+        flip = tl.full(x.shape, value=order, dtype=tl.int32)
+    for i in tl.static_range(stage):
+        x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims)
+    return x, indices
+
+
+@triton.jit
+def _argsort(x, indices, n_dims: tl.constexpr):
+    for i in tl.static_range(1, n_dims + 1):
+        x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims)
+    return x, indices


 @triton.jit
 def _row_id_map_pass_1_kernel(
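These `_compare_and_swap`/`_bitonic_merge`/`_argsort` kernels run a power-of-two bitonic sorting network in registers, which is why the pass-3 kernel further down pads its load to `triton.next_power_of_2(num_experts)` with `other=-1`. For intuition, a plain-Python reference of the same network (a hypothetical sketch, not part of the patch) is shown below; the Triton `_argsort` uses order 1 (descending) for its final merge, which is what pushes the `-1` padding entries behind the real destination rows.

```python
def bitonic_argsort(values, indices, descending=True):
    """Pure-Python bitonic argsort; len(values) must be a power of two."""
    n = len(values)
    k = 2
    while k <= n:                      # merge width doubles per stage
        j = k // 2
        while j >= 1:                  # compare stride halves per substage
            for i in range(n):
                p = i ^ j              # network partner of element i
                if p > i:
                    descending_pair = ((i & k) != 0) ^ descending
                    a, b = values[i], values[p]
                    if (a < b) if descending_pair else (a > b):
                        values[i], values[p] = values[p], values[i]
                        indices[i], indices[p] = indices[p], indices[i]
            j //= 2
        k *= 2
    return values, indices

# e.g. one token's pass-2 column padded to a power of two:
# bitonic_argsort([0, 3, -1, -1], [0, 1, 2, 3])
# -> values [3, 0, -1, -1]; indices begin with [1, 0]
```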
@@ -22,6 +88,8 @@
     # strides
     stride_routing_map_token,
     stride_routing_map_expert,
+    stride_row_id_map_token,
+    stride_row_id_map_expert,
     # metas
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -32,10 +100,10 @@
         routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token,
         mask=(offset < num_tokens),
         other=0,
-    ).to(tl.int64)
+    ).to(tl.int32)
     row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask
     tl.store(
-        row_id_map_ptr + pid_m * num_tokens + offset,
+        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
         row_id_within_token_block,
         mask=offset < num_tokens,
     )
@@ -50,6 +118,9 @@ def _row_id_map_pass_2_kernel(
     workspace_ptr,
     # sizes
     num_tokens,
+    # strides
+    stride_row_id_map_token,
+    stride_row_id_map_expert,
     # metas
     WORKSPACE_LOAD_WIDTH: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
@@ -59,7 +130,9 @@ def _row_id_map_pass_2_kernel(
     chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n
     offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     row_id_within_token_block = tl.load(
-        row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0
+        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
+        mask=(offset < num_tokens),
+        other=0,
     )

     workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
@@ -70,23 +143,102 @@ def _row_id_map_pass_2_kernel(
         row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1,
     )
     tl.store(
-        row_id_map_ptr + pid_m * num_tokens + offset,
+        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
         row_id,
         mask=(offset < num_tokens),
     )


+@triton.jit
+def _row_id_map_pass_3_kernel(
+    # pointers
+    row_id_map_ptr,
+    # sizes
+    num_experts: tl.constexpr,
+    # strides
+    stride_row_id_map_token,
+    stride_row_id_map_expert,
+    # metas
+    LOAD_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    n_dims: tl.constexpr = _log2(LOAD_SIZE)
+    off = tl.arange(0, LOAD_SIZE)
+    row_id_map = tl.load(
+        row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off,
+        mask=off < num_experts,
+        other=-1,
+    )
+    n_routed = tl.sum(tl.where(row_id_map != -1, 1, 0))
+    indices = off
+    sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims)
+    tl.store(
+        row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert,
+        sorted_map,
+        mask=off < n_routed,
+    )
+    tl.store(
+        row_id_map_ptr
+        + pid * stride_row_id_map_token
+        + (num_experts + off) * stride_row_id_map_expert,
+        indices,
+        mask=off < n_routed,
+    )
+    tl.store(
+        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert,
+        n_routed,
+    )
+
+
 def make_row_id_map(
     routing_map: torch.Tensor,
     num_tokens: int,
     num_experts: int,
 ):
-    # pylint: disable=missing-function-docstring
-    row_id_map = torch.empty((num_experts, num_tokens), dtype=torch.int64, device="cuda")
-    block_size = 256
+    """
+    Prepare the row_id_map for the permutation.
+
+    Parameters
+    ----------
+    routing_map: torch.Tensor
+        Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates
+        which experts are routed to which tokens. The values in it: 1 means the token is routed to
+        this expert and 0 means not.
+    num_tokens: int
+        Number of tokens in the input tensor.
+    num_experts: int
+        Number of experts in the input tensor.
+
+    Returns
+    -------
+    row_id_map: torch.Tensor
+        The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`.
+        For each token, the last item is the number of experts that are routed (n_routed).
+        The first n_routed items are the destination row indices in the permuted tokens.
+        The [num_experts, num_experts + n_routed) items are the indices of the experts
+        corresponding to the first n_routed row indices above.
+    """
+ """ + row_id_map = torch.empty((num_tokens, num_experts * 2 + 1), dtype=torch.int32, device="cuda") + block_size = 1024 grid = (num_experts, triton.cdiv(num_tokens, block_size)) - workspace_tensor = torch.empty(grid, dtype=torch.int64, device="cuda") - # block cumsum + workspace_tensor = torch.empty(grid, dtype=torch.int32, device="cuda") + + # supposing num_tokens == 5, num_experts == 3, block_size == 3 + # and we have a routing_map like this: + # [[1, 1, 0], + # [1, 0, 1], + # [0, 0, 1], + # [1, 1, 0], + # [0, 0, 0]] + + # pass 1: block cumsum + # for each expert, compute the cumsum of every block_size tokens + # the row_id_map will be like this after pass 1 (r means useless values): + # [[1, 1, 0, r, r, r, r], + # [2, 0, 1, r, r, r, r], + # [0, 0, 2, r, r, r, r], + # [1, 1, 0, r, r, r, r], + # [0, 0, 0, r, r, r, r]] _row_id_map_pass_1_kernel[grid]( routing_map, row_id_map, @@ -94,16 +246,44 @@ def make_row_id_map( num_tokens, routing_map.stride(0), routing_map.stride(1), + row_id_map.stride(0), + row_id_map.stride(1), block_size, ) - # cumsum all and process the mask + + # pass 2: cumsum all and process the mask + # process the block cumsum into the global cumsum and then into the dst row indices + # the row_id_map will be like this after pass 2 (r means useless value): + # [[ 0, 3, -1, r, r, r, r], + # [ 1, -1, 5, r, r, r, r], + # [-1, -1, 6, r, r, r, r], + # [ 2, 4, -1, r, r, r, r], + # [-1, -1, -1, r, r, r, r]] _row_id_map_pass_2_kernel[grid]( row_id_map, workspace_tensor, num_tokens, + row_id_map.stride(0), + row_id_map.stride(1), triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)), block_size, ) + + # pass 3: make the row_id_map from the sparse structure to the dense structure + # the row_id_map will be like this after pass 3 (r means useless value): + # [[3, 0, r, 1, 0, r, 2], + # [5, 1, r, 2, 0, r, 2], + # [6, r, r, 2, r, r, 1], + # [4, 2, r, 1, 0, r, 2], + # [r, r, r, r, r, r, 0]] + grid = (num_tokens,) + _row_id_map_pass_3_kernel[grid]( + row_id_map, + num_experts, + row_id_map.stride(0), + row_id_map.stride(1), + triton.next_power_of_2(num_experts), + ) return row_id_map @@ -118,11 +298,12 @@ def _permute_kernel( permuted_probs_ptr, permuted_scale_ptr, # sizes - num_tokens, - num_experts, - hidden_size, + num_experts: tl.constexpr, + hidden_size: tl.constexpr, scale_hidden_dim, # strides + stride_row_id_map_token, + stride_row_id_map_expert, stride_input_token, stride_input_hidden, stride_output_token, @@ -139,35 +320,50 @@ def _permute_kernel( PERMUTE_SCALE: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - pid = tl.program_id(0) - cur_pos = 0 - while cur_pos < hidden_size: - cur_off = cur_pos + tl.arange(0, BLOCK_SIZE) - mask = cur_off < hidden_size - input_off = pid * stride_input_token + cur_off * stride_input_hidden - inp = tl.load(input_ptr + input_off, mask=mask) + pid_t = tl.program_id(0) + pid_h = tl.program_id(1) + cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = cur_off < hidden_size + input_off = pid_t * stride_input_token + cur_off * stride_input_hidden + inp = tl.load(input_ptr + input_off, mask=mask) + if PERMUTE_SCALE: + mask_scale = cur_off < scale_hidden_dim + scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden + scale = tl.load(scale_ptr + scale_off, mask=mask_scale) + n_routed = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + num_experts * 2 * stride_row_id_map_expert + ) + for idx in tl.range(n_routed): + dst_row = tl.load( + row_id_map_ptr + pid_t * stride_row_id_map_token + idx * 
@@ -118,11 +298,12 @@ def _permute_kernel(
     permuted_probs_ptr,
     permuted_scale_ptr,
     # sizes
-    num_tokens,
-    num_experts,
-    hidden_size,
+    num_experts: tl.constexpr,
+    hidden_size: tl.constexpr,
     scale_hidden_dim,
     # strides
+    stride_row_id_map_token,
+    stride_row_id_map_expert,
     stride_input_token,
     stride_input_hidden,
     stride_output_token,
@@ -139,35 +320,50 @@ def _permute_kernel(
     PERMUTE_SCALE: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
-    pid = tl.program_id(0)
-    cur_pos = 0
-    while cur_pos < hidden_size:
-        cur_off = cur_pos + tl.arange(0, BLOCK_SIZE)
-        mask = cur_off < hidden_size
-        input_off = pid * stride_input_token + cur_off * stride_input_hidden
-        inp = tl.load(input_ptr + input_off, mask=mask)
+    pid_t = tl.program_id(0)
+    pid_h = tl.program_id(1)
+    cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = cur_off < hidden_size
+    input_off = pid_t * stride_input_token + cur_off * stride_input_hidden
+    inp = tl.load(input_ptr + input_off, mask=mask)
+    if PERMUTE_SCALE:
+        mask_scale = cur_off < scale_hidden_dim
+        scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden
+        scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
+    n_routed = tl.load(
+        row_id_map_ptr
+        + pid_t * stride_row_id_map_token
+        + num_experts * 2 * stride_row_id_map_expert
+    )
+    for idx in tl.range(n_routed):
+        dst_row = tl.load(
+            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
+        )
+        output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
         if PERMUTE_SCALE:
-            mask_scale = cur_off < scale_hidden_dim
-            scale_off = pid * stride_scale_token + cur_off * stride_scale_hidden
-            scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
-        for expert_idx in range(num_experts):
-            dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
-            if dst_row != -1:
-                output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
+            permuted_scale_off = (
+                dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden
+            )
+            tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
+        if PERMUTE_PROBS:
+            expert_idx = tl.load(
+                row_id_map_ptr
+                + pid_t * stride_row_id_map_token
+                + (num_experts + idx) * stride_row_id_map_expert
+            )
+            prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
+            prob = tl.load(probs_ptr + prob_off)
+            if pid_h == 0:
+                permuted_prob_off = dst_row * stride_permuted_probs_token
+                tl.store(permuted_probs_ptr + permuted_prob_off, prob)
+            if prob == 0.0:
+                # for routing_map padding
+                # dst_row != -1 and prob == 0.0 means that this slot is padded
+                tl.store(output_ptr + output_off, 0, mask=mask)
+            else:
                 tl.store(output_ptr + output_off, inp, mask=mask)
-                if PERMUTE_SCALE:
-                    permuted_scale_off = (
-                        dst_row * stride_permuted_scale_token
-                        + cur_off * stride_permuted_scale_hidden
-                    )
-                    tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
-                if PERMUTE_PROBS:
-                    if cur_pos == 0:
-                        prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert
-                        prob = tl.load(probs_ptr + prob_off)
-                        permuted_prob_off = dst_row * stride_permuted_probs_token
-                        tl.store(permuted_probs_ptr + permuted_prob_off, prob)
-        cur_pos += BLOCK_SIZE
+        else:
+            tl.store(output_ptr + output_off, inp, mask=mask)


 try:
@@ -178,6 +374,8 @@
         triton.Config({"BLOCK_SIZE": 256}),
         triton.Config({"BLOCK_SIZE": 512}),
         triton.Config({"BLOCK_SIZE": 1024}),
+        triton.Config({"BLOCK_SIZE": 2048}),
+        triton.Config({"BLOCK_SIZE": 4096}),
     ],
     key=["hidden_size"],
 )(_permute_kernel)
@@ -196,7 +394,30 @@ def permute_with_mask_map(
     hidden_size: int,
     scale_hidden_dim: int,
 ):
-    # pylint: disable=missing-function-docstring
+    """
+    Permute the input tensor based on the row_id_map.
+
+    Parameters
+    ----------
+    inp: torch.Tensor
+        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
+    row_id_map: torch.Tensor
+        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
+    probs: torch.Tensor
+        The probabilities of the input tensor. If it is not None, it will be permuted.
+    scale: torch.Tensor
+        The scale of the input tensor. If it is not None, it will be permuted.
+    num_tokens: int
+        Number of tokens in the input tensor.
+    num_experts: int
+        Number of experts in the input tensor.
+    num_out_tokens: int
+        Number of tokens in the permuted tensor.
+    hidden_size: int
+        Hidden size of the input tensor.
+    scale_hidden_dim: int
+        Hidden size of the scale tensor.
+    """
+ """ output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") if probs is not None: permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda") @@ -209,8 +430,8 @@ def permute_with_mask_map( ) else: permuted_scale = None - - grid = (num_tokens,) + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"])) _permute_kernel[grid]( inp, output, @@ -219,10 +440,11 @@ def permute_with_mask_map( scale, permuted_probs, permuted_scale, - num_tokens, num_experts, hidden_size, scale_hidden_dim, + row_id_map.stride(0), + row_id_map.stride(1), inp.stride(0), inp.stride(1), output.stride(0), @@ -250,10 +472,11 @@ def _unpermute_kernel( permuted_probs_ptr, unpermuted_probs_ptr, # sizes - num_tokens, - num_experts, - hidden_size, + num_experts: tl.constexpr, + hidden_size: tl.constexpr, # strides + stride_row_id_map_token, + stride_row_id_map_expert, stride_input_token, stride_input_hidden, stride_output_token, @@ -264,6 +487,7 @@ def _unpermute_kernel( stride_unpermuted_probs_token, stride_unpermuted_probs_expert, # metas + PROBS_LOAD_WIDTH: tl.constexpr, WITH_MERGING_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -271,41 +495,63 @@ def _unpermute_kernel( data_type = input_ptr.dtype.element_ty compute_type = tl.float32 - pid = tl.program_id(0) - current_start = 0 - while current_start < hidden_size: - current_offset = current_start + tl.arange(0, BLOCK_SIZE) - mask = current_offset < hidden_size - accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type) - for expert_idx in range(num_experts): - src_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) - if src_row != -1: - input_off = src_row * stride_input_token + current_offset * stride_input_hidden - inp = tl.load(input_ptr + input_off, mask=mask) - inp = inp.to(compute_type) - if WITH_MERGING_PROBS: - merging_prob_off = ( - pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert - ) - merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) - inp *= merging_prob - accumulator += inp - if PERMUTE_PROBS: - if current_start == 0: - unpermuted_prob_off = ( - pid * stride_unpermuted_probs_token - + expert_idx * stride_unpermuted_probs_expert - ) - if src_row != -1: - permuted_prob_off = src_row * stride_permuted_probs_token - prob = tl.load(permuted_probs_ptr + permuted_prob_off) - tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) - else: - tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0) - accumulator = accumulator.to(data_type) - output_off = pid * stride_output_token + current_offset * stride_output_hidden - tl.store(output_ptr + output_off, accumulator, mask=mask) - current_start += BLOCK_SIZE + pid_t = tl.program_id(0) + pid_h = tl.program_id(1) + current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + if PERMUTE_PROBS: + # write 0.0 to probs_grad that are not routed + if pid_h == 0: + map_load_off = tl.arange(0, PROBS_LOAD_WIDTH) + unpermuted_prob_off = ( + pid_t * stride_unpermuted_probs_token + + stride_unpermuted_probs_expert * map_load_off + ) + tl.store( + unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts + ) + accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type) + n_routed = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + num_experts * 2 * stride_row_id_map_expert + ) + for idx in tl.range(n_routed): + src_row = tl.load( + 
@@ -250,10 +472,11 @@ def _unpermute_kernel(
     permuted_probs_ptr,
     unpermuted_probs_ptr,
     # sizes
-    num_tokens,
-    num_experts,
-    hidden_size,
+    num_experts: tl.constexpr,
+    hidden_size: tl.constexpr,
     # strides
+    stride_row_id_map_token,
+    stride_row_id_map_expert,
     stride_input_token,
     stride_input_hidden,
     stride_output_token,
@@ -264,6 +487,7 @@ def _unpermute_kernel(
     stride_unpermuted_probs_token,
     stride_unpermuted_probs_expert,
     # metas
+    PROBS_LOAD_WIDTH: tl.constexpr,
     WITH_MERGING_PROBS: tl.constexpr,
     PERMUTE_PROBS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
@@ -271,41 +495,63 @@ def _unpermute_kernel(
     data_type = input_ptr.dtype.element_ty
     compute_type = tl.float32

-    pid = tl.program_id(0)
-    current_start = 0
-    while current_start < hidden_size:
-        current_offset = current_start + tl.arange(0, BLOCK_SIZE)
-        mask = current_offset < hidden_size
-        accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
-        for expert_idx in range(num_experts):
-            src_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
-            if src_row != -1:
-                input_off = src_row * stride_input_token + current_offset * stride_input_hidden
-                inp = tl.load(input_ptr + input_off, mask=mask)
-                inp = inp.to(compute_type)
-                if WITH_MERGING_PROBS:
-                    merging_prob_off = (
-                        pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
-                    )
-                    merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
-                    inp *= merging_prob
-                accumulator += inp
-            if PERMUTE_PROBS:
-                if current_start == 0:
-                    unpermuted_prob_off = (
-                        pid * stride_unpermuted_probs_token
-                        + expert_idx * stride_unpermuted_probs_expert
-                    )
-                    if src_row != -1:
-                        permuted_prob_off = src_row * stride_permuted_probs_token
-                        prob = tl.load(permuted_probs_ptr + permuted_prob_off)
-                        tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
-                    else:
-                        tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0)
-        accumulator = accumulator.to(data_type)
-        output_off = pid * stride_output_token + current_offset * stride_output_hidden
-        tl.store(output_ptr + output_off, accumulator, mask=mask)
-        current_start += BLOCK_SIZE
+    pid_t = tl.program_id(0)
+    pid_h = tl.program_id(1)
+    current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = current_offset < hidden_size
+    if PERMUTE_PROBS:
+        # write 0.0 to probs_grad that are not routed
+        if pid_h == 0:
+            map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
+            unpermuted_prob_off = (
+                pid_t * stride_unpermuted_probs_token
+                + stride_unpermuted_probs_expert * map_load_off
+            )
+            tl.store(
+                unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts
+            )
+    accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
+    n_routed = tl.load(
+        row_id_map_ptr
+        + pid_t * stride_row_id_map_token
+        + num_experts * 2 * stride_row_id_map_expert
+    )
+    for idx in tl.range(n_routed):
+        src_row = tl.load(
+            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
+        )
+        input_off = src_row * stride_input_token + current_offset * stride_input_hidden
+        inp = tl.load(input_ptr + input_off, mask=mask)
+        inp = inp.to(compute_type)
+        if WITH_MERGING_PROBS:
+            expert_idx = tl.load(
+                row_id_map_ptr
+                + pid_t * stride_row_id_map_token
+                + (num_experts + idx) * stride_row_id_map_expert
+            )
+            merging_prob_off = (
+                pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
+            )
+            merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
+            inp *= merging_prob
+        accumulator += inp
+        if PERMUTE_PROBS:
+            if pid_h == 0:
+                expert_idx = tl.load(
+                    row_id_map_ptr
+                    + pid_t * stride_row_id_map_token
+                    + (num_experts + idx) * stride_row_id_map_expert
+                )
+                unpermuted_prob_off = (
+                    pid_t * stride_unpermuted_probs_token
+                    + expert_idx * stride_unpermuted_probs_expert
+                )
+                permuted_prob_off = src_row * stride_permuted_probs_token
+                prob = tl.load(permuted_probs_ptr + permuted_prob_off)
+                tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
+    accumulator = accumulator.to(data_type)
+    output_off = pid_t * stride_output_token + current_offset * stride_output_hidden
+    tl.store(output_ptr + output_off, accumulator, mask=mask)


 try:
@@ -316,6 +562,8 @@
         triton.Config({"BLOCK_SIZE": 256}),
         triton.Config({"BLOCK_SIZE": 512}),
         triton.Config({"BLOCK_SIZE": 1024}),
+        triton.Config({"BLOCK_SIZE": 2048}),
+        triton.Config({"BLOCK_SIZE": 4096}),
     ],
     key=["hidden_size"],
 )(_unpermute_kernel)
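For reference, the rewritten unpermute forward computes, per token, a sum of the rows that token was scattered to, optionally weighted by `merging_probs`, accumulating in fp32 exactly as `compute_type` does in the kernel. A dense PyTorch sketch of the same semantics (hypothetical `unpermute_ref`, using the dense map layout documented above):

```python
import torch

def unpermute_ref(inp, row_id_map, merging_probs=None):
    """Sketch: out[t] = sum over routed slots of (prob *) inp[src_row]."""
    num_tokens = row_id_map.size(0)
    num_experts = (row_id_map.size(1) - 1) // 2
    out = torch.zeros(num_tokens, inp.size(1), dtype=torch.float32, device=inp.device)
    for t in range(num_tokens):
        n_routed = int(row_id_map[t, num_experts * 2])
        for k in range(n_routed):
            src_row = int(row_id_map[t, k])
            expert = int(row_id_map[t, num_experts + k])
            w = merging_probs[t, expert] if merging_probs is not None else 1.0
            out[t] += w * inp[src_row].float()  # kernel accumulates in fp32
    return out.to(inp.dtype)
```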
+ """ output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") if permuted_probs is not None: unpermuted_probs = torch.empty( @@ -340,7 +608,8 @@ def unpermute_with_mask_map( ) else: unpermuted_probs = None - grid = (num_tokens,) + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"])) _unpermute_kernel[grid]( inp, output, @@ -348,9 +617,10 @@ def unpermute_with_mask_map( merging_probs, permuted_probs, unpermuted_probs, - num_tokens, num_experts, hidden_size, + row_id_map.stride(0), + row_id_map.stride(1), inp.stride(0), inp.stride(1), output.stride(0), @@ -360,6 +630,7 @@ def unpermute_with_mask_map( permuted_probs.stride(0) if permuted_probs is not None else None, unpermuted_probs.stride(0) if unpermuted_probs is not None else None, unpermuted_probs.stride(1) if unpermuted_probs is not None else None, + PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), WITH_MERGING_PROBS=merging_probs is not None, PERMUTE_PROBS=permuted_probs is not None, ) @@ -376,10 +647,11 @@ def _unpermute_bwd_with_merging_probs_kernel( merging_probs_grad_ptr, row_id_map_ptr, # sizes - num_tokens, - num_experts, - hidden_size, + num_experts: tl.constexpr, + hidden_size: tl.constexpr, # strides + stride_row_id_map_token, + stride_row_id_map_expert, stride_fwd_output_grad_token, stride_fwd_output_grad_hidden, stride_fwd_input_grad_token, @@ -391,56 +663,63 @@ def _unpermute_bwd_with_merging_probs_kernel( stride_merging_probs_grad_token, stride_merging_probs_grad_expert, # metas + PROBS_LOAD_WIDTH: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): data_type = fwd_output_grad_ptr.dtype.element_ty compute_type = tl.float32 pid = tl.program_id(0) - for expert_idx in range(num_experts): - dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) - if dst_row != -1: - prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) - current_start = 0 - while current_start < hidden_size: - current_offset = current_start + tl.arange(0, BLOCK_SIZE) - mask = current_offset < hidden_size - input_off = ( - pid * stride_fwd_output_grad_token - + current_offset * stride_fwd_output_grad_hidden - ) - inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) - inp = inp.to(compute_type) - merging_prob_off = ( - pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert - ) - merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) - output = inp * merging_prob - output = output.to(data_type) - output_off = ( - dst_row * stride_fwd_input_grad_token - + current_offset * stride_fwd_input_grad_hidden - ) - tl.store(fwd_input_grad_ptr + output_off, output, mask=mask) - - fwd_input_off = ( - dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden - ) - fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) - prob_grad_accum += fwd_input.to(compute_type) * inp - current_start += BLOCK_SIZE - probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty) - probs_grad_off = ( - pid * stride_merging_probs_grad_token - + expert_idx * stride_merging_probs_grad_expert + map_load_off = tl.arange(0, PROBS_LOAD_WIDTH) + token_probs_grad_off = ( + pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off + ) + tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts) + n_routed = tl.load( + row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert + ) + for idx in 
@@ -391,56 +663,63 @@ def _unpermute_bwd_with_merging_probs_kernel(
     stride_merging_probs_grad_token,
     stride_merging_probs_grad_expert,
     # metas
+    PROBS_LOAD_WIDTH: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     data_type = fwd_output_grad_ptr.dtype.element_ty
     compute_type = tl.float32

     pid = tl.program_id(0)
-    for expert_idx in range(num_experts):
-        dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
-        if dst_row != -1:
-            prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
-            current_start = 0
-            while current_start < hidden_size:
-                current_offset = current_start + tl.arange(0, BLOCK_SIZE)
-                mask = current_offset < hidden_size
-                input_off = (
-                    pid * stride_fwd_output_grad_token
-                    + current_offset * stride_fwd_output_grad_hidden
-                )
-                inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
-                inp = inp.to(compute_type)
-                merging_prob_off = (
-                    pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
-                )
-                merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
-                output = inp * merging_prob
-                output = output.to(data_type)
-                output_off = (
-                    dst_row * stride_fwd_input_grad_token
-                    + current_offset * stride_fwd_input_grad_hidden
-                )
-                tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)
-
-                fwd_input_off = (
-                    dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
-                )
-                fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
-                prob_grad_accum += fwd_input.to(compute_type) * inp
-                current_start += BLOCK_SIZE
-            probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
-            probs_grad_off = (
-                pid * stride_merging_probs_grad_token
-                + expert_idx * stride_merging_probs_grad_expert
+    map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
+    token_probs_grad_off = (
+        pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off
+    )
+    tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts)
+    n_routed = tl.load(
+        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert
+    )
+    for idx in tl.range(n_routed):
+        dst_row = tl.load(
+            row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert
+        )
+        expert_idx = tl.load(
+            row_id_map_ptr
+            + pid * stride_row_id_map_token
+            + (num_experts + idx) * stride_row_id_map_expert
+        )
+        prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
+        current_start = 0
+        while current_start < hidden_size:
+            current_offset = current_start + tl.arange(0, BLOCK_SIZE)
+            mask = current_offset < hidden_size
+            input_off = (
+                pid * stride_fwd_output_grad_token + current_offset * stride_fwd_output_grad_hidden
             )
-            tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)
-        else:
-            probs_grad_off = (
-                pid * stride_merging_probs_grad_token
-                + expert_idx * stride_merging_probs_grad_expert
+            inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
+            inp = inp.to(compute_type)
+            merging_prob_off = (
+                pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
+            )
+            merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
+            output = inp * merging_prob
+            output = output.to(data_type)
+            output_off = (
+                dst_row * stride_fwd_input_grad_token
+                + current_offset * stride_fwd_input_grad_hidden
             )
-            tl.store(merging_probs_grad_ptr + probs_grad_off, 0.0)
+            tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)
+
+            fwd_input_off = (
+                dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
+            )
+            fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
+            prob_grad_accum += fwd_input.to(compute_type) * inp
+            current_start += BLOCK_SIZE
+        probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
+        probs_grad_off = (
+            pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert
+        )
+        tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)


 try:
@@ -451,6 +730,8 @@
         triton.Config({"BLOCK_SIZE": 256}),
         triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
+        triton.Config({"BLOCK_SIZE": 2048}),
+        triton.Config({"BLOCK_SIZE": 4096}),
     ],
     key=["hidden_size"],
 )(_unpermute_bwd_with_merging_probs_kernel)
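The restructured backward keeps the deleted loop's math; it only swaps the expert-major scan for the per-token dense map and hoists the zero-fill of unrouted prob-grad slots in front of the loop. In summary form (a restatement of what the kernel above stores, using its own names):

```python
# per token t and routed slot k, with dst_row = row_id_map[t, k]
# and expert e = row_id_map[t, num_experts + k]:
#
#   fwd_input_grad[dst_row, :] = fwd_output_grad[t, :] * merging_probs[t, e]
#   merging_probs_grad[t, e]   = dot(fwd_input[dst_row, :], fwd_output_grad[t, :])
#
# non-routed (t, e) slots keep the 0.0 written across the whole
# probs-grad row before the loop.
```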
+ """ act_grad = torch.empty( (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" ) @@ -483,9 +785,10 @@ def unpermute_with_mask_map_bwd_with_merging_probs( merging_probs, merging_probs_grad, row_id_map, - num_tokens, num_experts, hidden_size, + row_id_map.stride(0), + row_id_map.stride(1), fwd_output_grad.stride(0), fwd_output_grad.stride(1), act_grad.stride(0), @@ -496,34 +799,21 @@ def unpermute_with_mask_map_bwd_with_merging_probs( merging_probs.stride(1), merging_probs_grad.stride(0), merging_probs_grad.stride(1), + PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), ) return act_grad, merging_probs_grad @triton.jit -def _sort_chunks_by_idxs_kernel( +def _make_chunk_sort_map_kernel( # pointers - input_ptr, split_sizes_ptr, sorted_indices_ptr, - output_ptr, dst_rows_ptr, - probs_ptr, - permuted_probs_ptr, # sizes - num_splits, - hidden_size, - # strides - stride_input_token, - stride_input_hidden, - stride_output_token, - stride_output_hidden, - stride_probs_token, - stride_permuted_probs_token, + num_splits: tl.constexpr, # metas - PERMUTE_PROBS: tl.constexpr, IDX_LOAD_WIDTH: tl.constexpr, - BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(0) @@ -533,104 +823,58 @@ def _sort_chunks_by_idxs_kernel( ) # get chunk idx of the current token in the input tensor - input_chunk_idx = -1 - in_chunk_offset = tl.zeros([], dtype=tl.int64) - acc_chunk_sizes = tl.zeros([], dtype=tl.int64) - cursor = 0 - while cursor < num_splits: - cur_chunk_size = tl.load(split_sizes_ptr + cursor).to(tl.int64) - acc_chunk_sizes += cur_chunk_size - if input_chunk_idx == -1 and acc_chunk_sizes > pid: - input_chunk_idx = cursor - in_chunk_offset = pid - (acc_chunk_sizes - cur_chunk_size) - cursor += 1 + input_split_sizes = tl.load( + split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0 + ).to(tl.int32) + input_split_sizes_cumsum = tl.cumsum(input_split_sizes) + input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0) + input_chunk_idx = tl.sum(input_split_sizes_mask) + input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask) + in_chunk_offset = pid - input_split_sizes_presum # get chunk idx of the current token in the output tensor - output_chunk_idx = 0 - cursor = 0 - while cursor < num_splits: - cur_input_idx = tl.load(sorted_indices_ptr + cursor) - if cur_input_idx == input_chunk_idx: - output_chunk_idx = cursor - cursor += 1 + output_chunk_mask = tl.where(sorted_indices == input_chunk_idx, 1, 0) + output_chunk_idx = tl.argmax(output_chunk_mask, axis=-1) # make row_id_map output_split_sizes = tl.load( split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits - ).to(tl.int64) + ).to(tl.int32) output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0) dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset tl.store(dst_rows_ptr + pid, dst_row) - current_start = 0 - while current_start < hidden_size: - current_offset = current_start + tl.arange(0, BLOCK_SIZE) - mask = current_offset < hidden_size - input_offsets = pid * stride_input_token + current_offset * stride_input_hidden - output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden - inp = tl.load(input_ptr + input_offsets, mask=mask) - tl.store(output_ptr + output_offsets, inp, mask=mask) - current_start += BLOCK_SIZE - if PERMUTE_PROBS: - prob_off = pid * stride_probs_token - prob = tl.load(probs_ptr + prob_off) - permuted_prob_off = dst_row * stride_permuted_probs_token - 
-        tl.store(permuted_probs_ptr + permuted_prob_off, prob)
-
-
-try:
-    _sort_chunks_by_idxs_kernel = triton.autotune(
-        configs=[
-            triton.Config({"BLOCK_SIZE": 64}),
-            triton.Config({"BLOCK_SIZE": 128}),
-            triton.Config({"BLOCK_SIZE": 256}),
-            triton.Config({"BLOCK_SIZE": 512}),
-            triton.Config({"BLOCK_SIZE": 1024}),
-        ],
-        key=["hidden_size"],
-    )(_sort_chunks_by_idxs_kernel)
-except RuntimeError:
-    pass
-
-
-def sort_chunks_by_idx(
-    inp: torch.Tensor,
+def make_chunk_sort_map(
     split_sizes: torch.Tensor,
     sorted_indices: torch.Tensor,
-    probs: torch.Tensor,
     num_tokens: int,
-    hidden_size: int,
     num_splits: int,
 ):
-    # pylint: disable=missing-function-docstring
-    row_id_map = torch.empty((num_tokens,), dtype=torch.int64, device="cuda")
-    output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
-    if probs is not None:
-        permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda")
-    else:
-        permuted_probs = None
+    """
+    Make a row_id_map for chunk sort.
+
+    Parameters
+    ----------
+    split_sizes: torch.Tensor
+        The sizes of the chunks of shape `[num_splits,]`.
+    sorted_indices: torch.Tensor
+        The indices of the sorted chunks of shape `[num_splits,]`.
+    num_tokens: int
+        Number of tokens in the input tensor.
+    num_splits: int
+        Number of splits of split_sizes and sorted_indices.
+    """
+    row_id_map = torch.empty((num_tokens,), dtype=torch.int32, device="cuda")
     grid = (num_tokens,)
-    _sort_chunks_by_idxs_kernel[grid](
-        inp,
+    _make_chunk_sort_map_kernel[grid](
         split_sizes,
         sorted_indices,
-        output,
         row_id_map,
-        probs,
-        permuted_probs,
         num_splits,
-        hidden_size,
-        inp.stride(0),
-        inp.stride(1),
-        output.stride(0),
-        output.stride(1),
-        probs.stride(0) if probs is not None else None,
-        permuted_probs.stride(0) if permuted_probs is not None else None,
-        PERMUTE_PROBS=probs is not None,
         IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits),
     )
-    return output, row_id_map, permuted_probs
+    return row_id_map


 @triton.jit
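The replacement map construction trades the two O(num_splits) scalar while-loops for a vectorized cumsum (input chunk lookup) and argmax (output position lookup). A PyTorch sketch of the mapping it produces (hypothetical `make_chunk_sort_map_ref`; assumes int64 `sorted_indices` for indexing):

```python
import torch

def make_chunk_sort_map_ref(split_sizes, sorted_indices):
    """Sketch: destination row of every input row under the chunk permutation."""
    out_sizes = split_sizes[sorted_indices]              # chunk sizes in output order
    out_starts = torch.cumsum(out_sizes, 0) - out_sizes  # start row of each output chunk
    # position of input chunk i in the output ordering
    inv = torch.empty_like(sorted_indices)
    inv[sorted_indices] = torch.arange(len(sorted_indices))
    dst_rows = []
    for chunk, size in enumerate(split_sizes.tolist()):
        start = int(out_starts[inv[chunk]])
        dst_rows += [start + off for off in range(size)]
    return torch.tensor(dst_rows, dtype=torch.int32)

# e.g. split_sizes [2, 1, 3] reordered as [chunk2, chunk0, chunk1]:
# make_chunk_sort_map_ref(torch.tensor([2, 1, 3]), torch.tensor([2, 0, 1]))
# -> [3, 4, 5, 0, 1, 2]
```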
@@ -642,7 +886,7 @@ def _sort_chunks_by_map_kernel(
     probs_ptr,
     permuted_probs_ptr,
     # sizes
-    hidden_size,
+    hidden_size: tl.constexpr,
     # strides
     stride_input_token,
     stride_input_hidden,
     stride_output_token,
@@ -653,23 +897,28 @@ def _sort_chunks_by_map_kernel(
     # metas
     PERMUTE_PROBS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    FORWARD: tl.constexpr,
 ):
-    pid = tl.program_id(0)
-    dst_row = tl.load(row_id_map_ptr + pid)
-    current_start = 0
-    while current_start < hidden_size:
-        current_offset = current_start + tl.arange(0, BLOCK_SIZE)
-        mask = current_offset < hidden_size
-        input_offsets = dst_row * stride_input_token + current_offset * stride_input_hidden
-        output_offsets = pid * stride_output_token + current_offset * stride_output_hidden
-        inp = tl.load(input_ptr + input_offsets, mask=mask)
-        tl.store(output_ptr + output_offsets, inp, mask=mask)
-        current_start += BLOCK_SIZE
+    pid_t = tl.program_id(0)
+    pid_h = tl.program_id(1)
+    if FORWARD:
+        src_row = pid_t
+        dst_row = tl.load(row_id_map_ptr + pid_t)
+    else:
+        src_row = tl.load(row_id_map_ptr + pid_t)
+        dst_row = pid_t
+    current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = current_offset < hidden_size
+    input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden
+    output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden
+    inp = tl.load(input_ptr + input_offsets, mask=mask)
+    tl.store(output_ptr + output_offsets, inp, mask=mask)
     if PERMUTE_PROBS:
-        prob_off = dst_row * stride_probs_token
-        prob = tl.load(probs_ptr + prob_off)
-        permuted_prob_off = pid * stride_permuted_probs_token
-        tl.store(permuted_probs_ptr + permuted_prob_off, prob)
+        if pid_h == 0:
+            prob_off = src_row * stride_probs_token
+            prob = tl.load(probs_ptr + prob_off)
+            permuted_prob_off = dst_row * stride_permuted_probs_token
+            tl.store(permuted_probs_ptr + permuted_prob_off, prob)


 try:
@@ -680,6 +929,8 @@
         triton.Config({"BLOCK_SIZE": 256}),
         triton.Config({"BLOCK_SIZE": 512}),
         triton.Config({"BLOCK_SIZE": 1024}),
+        triton.Config({"BLOCK_SIZE": 2048}),
+        triton.Config({"BLOCK_SIZE": 4096}),
     ],
     key=["hidden_size"],
 )(_sort_chunks_by_map_kernel)
@@ -693,14 +944,33 @@ def sort_chunks_by_map(
     probs: torch.Tensor,
     num_tokens: int,
     hidden_size: int,
+    is_forward: bool,
 ):
-    # pylint: disable=missing-function-docstring
+    """
+    Sort chunks with row_id_map.
+
+    Parameters
+    ----------
+    inp: torch.Tensor
+        Input tensor of shape `[num_tokens, hidden_size]`.
+    row_id_map: torch.Tensor
+        The row mapping tensor of shape `[num_tokens,]` produced by make_chunk_sort_map.
+    probs: torch.Tensor
+        The probabilities of the input tensor. If it is not None, it will be permuted.
+    num_tokens: int
+        Number of tokens in the input tensor.
+    hidden_size: int
+        Hidden size of the input tensor.
+    is_forward: bool
+        Whether the sort is for forward or backward.
+    """
     output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
     if probs is not None:
         permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda")
     else:
         permuted_probs = None
-    grid = (num_tokens,)
+    # pylint: disable=unnecessary-lambda-assignment
+    grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
     _sort_chunks_by_map_kernel[grid](
         inp,
         output,
@@ -715,5 +985,6 @@ def sort_chunks_by_map(
         probs.stride(0) if probs is not None else None,
         permuted_probs.stride(0) if permuted_probs is not None else None,
         PERMUTE_PROBS=probs is not None,
+        FORWARD=is_forward,
     )
     return output, permuted_probs
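One property worth noting about the `FORWARD` flag at the end of the diff: the same map drives both directions, scattering through it in the forward pass and gathering through it in the backward pass, so no inverse map is ever materialized. A toy check of that symmetry in plain PyTorch (row values taken from the chunk-sort sketch above):

```python
import torch

row_map = torch.tensor([3, 4, 5, 0, 1, 2])  # from make_chunk_sort_map_ref above
x = torch.randn(6, 4)
scattered = torch.empty_like(x)
scattered[row_map] = x                      # FORWARD=True path: write through the map
gathered = scattered[row_map]               # FORWARD=False path: read through the map
assert torch.equal(gathered, x)             # round trip restores the original order
```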