From 4be092a8fc3134b089a7ebedf93906795ef57a91 Mon Sep 17 00:00:00 2001
From: tongliu
Date: Thu, 10 Jul 2025 22:41:43 -0700
Subject: [PATCH 1/6] fix nondeterministic problem in CI

Signed-off-by: tongliu
---
 tests/pytorch/test_fused_router.py            | 41 ++++++++++++++-----
 .../common/fused_router/fused_moe_aux_loss.cu |  1 +
 2 files changed, 31 insertions(+), 11 deletions(-)

diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py
index 61a7750b63..e760984014 100644
--- a/tests/pytorch/test_fused_router.py
+++ b/tests/pytorch/test_fused_router.py
@@ -148,11 +148,15 @@ def run_comparison(
 ):
     # Set some parameters
     if score_function == "sigmoid":
-        logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 0.1
+        # Construct the special logits to avoid inf in the sigmoid function
+        logits = 1 + torch.arange(num_experts, device="cuda", dtype=dtype) * 0.01
         logits = logits.unsqueeze(0).repeat(num_tokens, 1)
+        random_values = torch.rand(num_tokens, num_experts, device="cuda")
+        _, indices = torch.sort(random_values, dim=1)
+        logits = torch.gather(logits, 1, indices)
     else:
         logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
-    logits = logits.view(num_tokens, num_experts)
+        logits = logits.view(num_tokens, num_experts)
     logits.requires_grad = True
     if enable_bias and score_function == "sigmoid":
         expert_bias = torch.arange(num_experts, device="cuda") * 0.1
@@ -210,7 +214,7 @@ def run_comparison(
 
 
 @pytest.mark.parametrize("dtype", [torch.float32])
-@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
+@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
 @pytest.mark.parametrize("num_experts", [128, 32])
 @pytest.mark.parametrize("topk", [4, 8])
 @pytest.mark.parametrize("group_topk", [None, 4])
@@ -241,7 +245,7 @@ def test_topk_sigmoid(
 
 
 @pytest.mark.parametrize("dtype", [torch.float32])
-@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
+@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
 @pytest.mark.parametrize("num_experts", [128, 32])
 @pytest.mark.parametrize("topk", [4, 8])
 @pytest.mark.parametrize("use_pre_softmax", [True, False])
@@ -272,12 +276,21 @@ def test_topk_softmax(
 
 
 @pytest.mark.parametrize("dtype", [torch.float32])
-@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
+@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
 @pytest.mark.parametrize("num_experts", [256, 128, 32])
 @pytest.mark.parametrize("topk", [4, 8])
 @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"])
 def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
-    logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
+    if score_function == "sigmoid":
+        # Construct the special logits to avoid inf in the sigmoid function
+        logits = 1 + torch.arange(num_experts, device="cuda", dtype=dtype) * 0.01
+        logits = logits.unsqueeze(0).repeat(num_tokens, 1)
+        random_values = torch.rand(num_tokens, num_experts, device="cuda")
+        _, indices = torch.sort(random_values, dim=1)
+        logits = torch.gather(logits, 1, indices)
+    else:
+        logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
+        logits = logits.view(num_tokens, num_experts)
     logits.requires_grad = True
 
     logits_clone = deepcopy(logits)
@@ -307,11 +320,17 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f
 
 
 @pytest.mark.parametrize("dtype", [torch.float32])
-@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
+@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
 @pytest.mark.parametrize("num_experts", [256, 128, 32])
 @pytest.mark.parametrize("topk", [4])
 def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
-    probs = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
+    # Construct the special probs to avoid inf in the sigmoid function
+    probs = 1 + torch.arange(num_experts, device="cuda", dtype=dtype) * 0.01
+    probs = probs.unsqueeze(0).repeat(num_tokens, 1)
+    random_values = torch.rand(num_tokens, num_experts, device="cuda")
+    _, indices = torch.sort(random_values, dim=1)
+    probs = torch.gather(probs, 1, indices)
+    probs = probs.view(num_tokens, num_experts)
     probs.requires_grad = True
 
     tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32)
@@ -375,6 +394,6 @@ def profile_topk_softmax(
     test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4)
     test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4)
     test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4)
-    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=32, topk=4)
-    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=128, topk=4)
-    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=256, topk=4)
+    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4)
+    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4)
+    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4)
diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu
index ddb09f270a..fae37226c7 100644
--- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu
+++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu
@@ -174,6 +174,7 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
     config.gridDim = cluster_size;
     config.blockDim = 1024;
     config.dynamicSmemBytes = sizeof(CompType) * num_experts;
+    config.stream = stream;
 
     // Update the max cluster size based on the device
     cudaOccupancyMaxPotentialClusterSize(

From 1f6d0ade576300af53348c6e4b409d98a13dffb7 Mon Sep 17 00:00:00 2001
From: tongliu
Date: Fri, 11 Jul 2025 00:58:47 -0700
Subject: [PATCH 2/6] fix bug on mbs>1

Signed-off-by: tongliu
---
 .../common/fused_router/fused_moe_aux_loss.cu | 85 ++++++++++---------
 .../include/transformer_engine/fused_router.h | 15 ++--
 transformer_engine/pytorch/csrc/extensions.h  |  7 +-
 .../pytorch/csrc/extensions/pybind.cpp        |  7 +-
 .../pytorch/csrc/extensions/router.cpp        | 15 ++--
 transformer_engine/pytorch/router.py          | 52 +++++++++---
 6 files changed, 110 insertions(+), 71 deletions(-)

diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu
index fae37226c7..b28cf934e6 100644
--- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu
+++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu
@@ -23,8 +23,10 @@ using CompType = double;
 template <typename DataType, typename IndexType>
 __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
                                                   const IndexType* tokens_per_expert,
-                                                  int total_num_tokens, int num_tokens,
-                                                  int num_experts, int topk, float coeff,
+                                                  int total_num_tokens,
+                                                  int num_experts,
+                                                  int num_rows, int num_cols,
+                                                  int topk, float coeff,
                                                   DataType* aux_loss, float* Const_buf) {
 #if __CUDA_ARCH__ >= 900
   // Using cooperative_groups to manage the cluster
@@ -43,7 +45,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
   extern __shared__ float shmem_aux_loss[];
   CompType* aggregated_probs_per_expert = reinterpret_cast<CompType*>(shmem_aux_loss);
   // Clear the shmem
-  for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
+  for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
     aggregated_probs_per_expert[i] = CompType(0);
   }
   __syncthreads();
@@ -54,11 +56,11 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
    * 2. reduce on the cluster
    */
   // Loop: for all positions in each row
-  for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
+  for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
     CompType tmp = CompType(0);
     // Loop: for all rows that this warp is responsible for
-    for (int j = warp_id; j < num_tokens; j += warp_num) {
-      tmp += CompType(probs[j * num_experts + i]);
+    for (int j = warp_id; j < num_rows; j += warp_num) {
+      tmp += CompType(probs[j * num_cols + i]);
     }
     atomicAdd(&aggregated_probs_per_expert[i], tmp);
   }
@@ -68,7 +70,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
     for (int i = 1; i < block_num; i++) {
       // Map the shared memory of the block i to the current block
      CompType* dst_smem = reinterpret_cast<CompType*>(cluster.map_shared_rank(shmem_aux_loss, i));
-      for (int j = threadIdx.x; j < num_experts; j += blockDim.x) {
+      for (int j = threadIdx.x; j < num_cols; j += blockDim.x) {
        atomicAdd(&aggregated_probs_per_expert[j], dst_smem[j]);
       }
     }
@@ -80,7 +82,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
    * In-place update on shmem
    */
   if (block_id == 0) {
-    for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
+    for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
       aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]);
     }
     __syncthreads();
@@ -90,7 +92,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
      * Section: Reduce to get the sum of aggregated_probs_per_expert
      */
     CompType intermediate_result =
-        warp_reduce_on_shmem(aggregated_probs_per_expert, num_experts, sum, lane_id);
+        warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
     __syncwarp();
 
     if (lane_id == 0) {
@@ -113,7 +115,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
   CompType* aggregated_probs_per_expert = reinterpret_cast<CompType*>(shmem_aux_loss);
 
   // Clear the shmem
-  for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
+  for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
     aggregated_probs_per_expert[i] = CompType(0);
   }
   __syncthreads();
@@ -122,11 +124,11 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
    * Section: Reduce the probs to the aggregated_probs_per_expert
    */
   // Loop: for all positions in each row
-  for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
+  for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
     CompType tmp = CompType(0);
     // Loop: for all rows that this warp is responsible for
-    for (int j = warp_id; j < num_tokens; j += warp_num) {
-      tmp += CompType(probs[j * num_experts + i]);
+    for (int j = warp_id; j < num_rows; j += warp_num) {
+      tmp += CompType(probs[j * num_cols + i]);
     }
     atomicAdd(&aggregated_probs_per_expert[i], tmp);
   }
@@ -136,7 +138,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
    * Section: aggregated_probs_per_expert * tokens_per_expert
    * In-place update on shmem
    */
-  for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
+  for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
     aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]);
   }
   __syncthreads();
@@ -146,7 +148,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
    * Section: Reduce to get the sum of aggregated_probs_per_expert
    */
   CompType intermediate_result =
-      warp_reduce_on_shmem(aggregated_probs_per_expert, num_experts, sum, lane_id);
+      warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
   __syncwarp();
 
   if (lane_id == 0) {
@@ -164,8 +166,10 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
 template <typename DataType, typename IndexType>
 void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
                                                 const IndexType* tokens_per_expert,
-                                                int total_num_tokens, int num_tokens,
-                                                int num_experts, int topk, float coeff,
+                                                int total_num_tokens,
+                                                int num_experts,
+                                                int num_rows, int num_cols,
+                                                int topk, float coeff,
                                                 DataType* aux_loss, float* Const_buf,
                                                 cudaStream_t stream) {
   if (cuda::sm_arch(cuda::current_device()) >= 900) {
@@ -173,7 +177,7 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
     int cluster_size = 8;
     config.gridDim = cluster_size;
     config.blockDim = 1024;
-    config.dynamicSmemBytes = sizeof(CompType) * num_experts;
+    config.dynamicSmemBytes = sizeof(CompType) * num_cols;
     config.stream = stream;
 
     // Update the max cluster size based on the device
@@ -190,19 +194,19 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
     config.attrs = attribute;
 
     cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs,
-                       tokens_per_expert, total_num_tokens, num_tokens, num_experts, topk, coeff,
+                       tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk, coeff,
                        aux_loss, Const_buf);
   } else {
-    size_t smem_size = sizeof(CompType) * num_experts;
+    size_t smem_size = sizeof(CompType) * num_cols;
     fused_moe_aux_loss_forward_kernel
-        <<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_tokens,
-                                         num_experts, topk, coeff, aux_loss, Const_buf);
+        <<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts,
+                                         num_rows, num_cols, topk, coeff, aux_loss, Const_buf);
   }
 }
 
 void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert,
-                                int total_num_tokens, int num_tokens, int num_experts, int topk,
-                                float coeff, Tensor& aux_loss, Tensor& Const_buf,
+                                int total_num_tokens, int num_experts, int num_rows, int num_cols,
+                                int topk, float coeff, Tensor& aux_loss, Tensor& Const_buf,
                                 cudaStream_t stream) {
   TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
       probs.data.dtype, DataType,
@@ -211,45 +215,46 @@ void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_ex
           fused_moe_aux_loss_forward_kernel_launcher(
               reinterpret_cast<DataType*>(probs.data.dptr),
              reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), total_num_tokens,
-              num_tokens, num_experts, topk, coeff, reinterpret_cast<DataType*>(aux_loss.data.dptr),
+              num_experts, num_rows, num_cols, topk, coeff, reinterpret_cast<DataType*>(aux_loss.data.dptr),
              reinterpret_cast<float*>(Const_buf.data.dptr), stream);););
 }
 
 template <typename DataType, typename IndexType>
 __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf,
                                                    const IndexType* tokens_per_expert,
-                                                   int num_tokens, int num_experts,
+                                                   int num_rows, int num_cols,
                                                    DataType* grad_aux_loss, DataType* grad_probs) {
   int global_warp_num = gridDim.x * blockDim.x / kThreadsPerWarp;
   int global_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kThreadsPerWarp;
   int lane_id = threadIdx.x % kThreadsPerWarp;
 
   // Loop: for all positions in each row
-  for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
+  for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
     float C_coeff = Const_buf[0];
     IndexType tokens_per_expert_i = tokens_per_expert[i];
     double grad_aux_loss_value = static_cast<double>(grad_aux_loss[0]);
     // Loop: for all rows
-    for (int j = global_warp_id; j < num_tokens; j += global_warp_num) {
-      grad_probs[j * num_experts + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value;
+    for (int j = global_warp_id; j < num_rows; j += global_warp_num) {
+      grad_probs[j * num_cols + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value;
     }
   }
 }
 
 template <typename DataType, typename IndexType>
 void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf,
-                                                 const IndexType* tokens_per_expert, int num_tokens,
-                                                 int num_experts, DataType* grad_aux_loss,
+                                                 const IndexType* tokens_per_expert,
+                                                 int num_rows, int num_cols,
+                                                 DataType* grad_aux_loss,
                                                  DataType* grad_probs, cudaStream_t stream) {
   // Meta data for the kernel
   int block_size = 256;
-  int grid_size = (num_tokens + block_size - 1) / block_size;
+  int grid_size = (num_rows + block_size - 1) / block_size;
   fused_moe_aux_loss_backward_kernel<<<grid_size, block_size, 0, stream>>>(
-      Const_buf, tokens_per_expert, num_tokens, num_experts, grad_aux_loss, grad_probs);
+      Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs);
 }
 
 void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_expert,
-                                 int num_tokens, int num_experts, Tensor& grad_aux_loss,
+                                 int num_rows, int num_cols, Tensor& grad_aux_loss,
                                  Tensor& grad_probs, cudaStream_t stream) {
   TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
       grad_aux_loss.data.dtype, DataType,
@@ -257,7 +262,7 @@ void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_p
           tokens_per_expert.data.dtype, IndexType,
           fused_moe_aux_loss_backward_kernel_launcher(
              reinterpret_cast<float*>(Const_buf.data.dptr),
-              reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), num_tokens, num_experts,
+              reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), num_rows, num_cols,
              reinterpret_cast<DataType*>(grad_aux_loss.data.dptr),
              reinterpret_cast<DataType*>(grad_probs.data.dptr), stream);););
 }
@@ -265,25 +270,25 @@ void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_p
 }  // namespace transformer_engine
 
 void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
-                                     int total_num_tokens, int num_tokens, int num_experts,
+                                     int total_num_tokens, int num_experts, int num_rows, int num_cols,
                                      int topk, float coeff, NVTETensor aux_loss,
                                      NVTETensor Const_buf, cudaStream_t stream) {
   NVTE_API_CALL(nvte_fused_moe_aux_loss_forward);
   using namespace transformer_engine;
   fused_moe_aux_loss_forward(
      *convertNVTETensorCheck(probs), *convertNVTETensorCheck(tokens_per_expert), total_num_tokens,
-      num_tokens, num_experts, topk, coeff, *convertNVTETensorCheck(aux_loss),
+      num_experts, num_rows, num_cols, topk, coeff, *convertNVTETensorCheck(aux_loss),
      *convertNVTETensorCheck(Const_buf), stream);
 }
 
 void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf,
-                                      const NVTETensor tokens_per_expert, int num_tokens,
-                                      int num_experts, NVTETensor grad_aux_loss,
+                                      const NVTETensor tokens_per_expert, int num_rows, int num_cols,
+                                      NVTETensor grad_aux_loss,
                                       NVTETensor grad_probs, cudaStream_t stream) {
   NVTE_API_CALL(nvte_fused_moe_aux_loss_backward);
   using namespace transformer_engine;
   fused_moe_aux_loss_backward(*convertNVTETensorCheck(Const_buf),
-                              *convertNVTETensorCheck(tokens_per_expert), num_tokens, num_experts,
+                              *convertNVTETensorCheck(tokens_per_expert), num_rows, num_cols,
                               *convertNVTETensorCheck(grad_aux_loss),
                               *convertNVTETensorCheck(grad_probs), stream);
 }
diff --git a/transformer_engine/common/include/transformer_engine/fused_router.h b/transformer_engine/common/include/transformer_engine/fused_router.h
index 7a3421a4e6..b385e7a076 100644
--- a/transformer_engine/common/include/transformer_engine/fused_router.h
+++ b/transformer_engine/common/include/transformer_engine/fused_router.h
@@ -96,8 +96,9 @@ void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_ou
 *  \param[in]     probs               Probabilities from the forward pass.
 *  \param[in]     tokens_per_expert   Number of tokens per expert.
 *  \param[in]     total_num_tokens    Number of total tokens. Will be used in seq/global aux loss.
- *  \param[in]     num_tokens          Number of tokens.
 *  \param[in]     num_experts         Number of experts.
+ *  \param[in]     num_rows            Number of rows of probs.
+ *  \param[in]     num_cols            Number of columns of probs.
 *  \param[in]     topk                Topk value.
 *  \param[in]     coeff               Coefficient.
 *  \param[out]    aux_loss            Output GPU scalar for auxiliary loss.
@@ -105,7 +106,8 @@ void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_ou
 *  \param[in]     stream              CUDA stream used for the operation.
 */
 void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
-                                     int total_num_tokens, int num_tokens, int num_experts,
+                                     int total_num_tokens, int num_experts,
+                                     int num_rows, int num_cols,
                                      int topk, float coeff, NVTETensor aux_loss,
                                      NVTETensor Const_buf, cudaStream_t stream);
 
@@ -113,15 +115,16 @@ void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor to
 *
 *  \param[in]     Const_buf           Constant buffer from the forward pass.
 *  \param[in]     tokens_per_expert   Number of tokens per expert.
- *  \param[in]     num_tokens          Number of total tokens.
- *  \param[in]     num_experts         Number of experts.
+ *  \param[in]     num_rows            Number of rows of probs.
+ *  \param[in]     num_cols            Number of columns of probs.
 *  \param[in]     grad_aux_loss       Gradient of auxiliary loss.
 *  \param[out]    grad_probs          Gradient of probs.
 *  \param[in]     stream              CUDA stream used for the operation.
 */
 void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf,
-                                      const NVTETensor tokens_per_expert, int num_tokens,
-                                      int num_experts, NVTETensor grad_aux_loss,
+                                      const NVTETensor tokens_per_expert,
+                                      int num_rows, int num_cols,
+                                      NVTETensor grad_aux_loss,
                                       NVTETensor grad_probs, cudaStream_t stream);
 
 #ifdef __cplusplus
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 61b22fc34a..f6c09d45f3 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -38,11 +38,12 @@ at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts,
 
 std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
                                                           at::Tensor tokens_per_expert,
-                                                          int total_num_tokens, int num_tokens,
-                                                          int num_experts, int topk, float coeff);
+                                                          int total_num_tokens,
+                                                          int num_experts, int num_rows, int num_cols,
+                                                          int topk, float coeff);
 
 at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert,
-                                  int num_tokens, int num_experts, at::Tensor grad_aux_loss);
+                                  int num_rows, int num_cols, at::Tensor grad_aux_loss);
 
 /***************************************************************************************************
 * Permutation
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 92ae618f2c..f20c513cf8 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -278,11 +278,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("topk"), py::arg("score_function"), "Fused topk softmax bwd");
   m.def("fused_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_moe_aux_loss_fwd,
         py::arg("probs"), py::arg("tokens_per_expert"), py::arg("total_num_tokens"),
-        py::arg("num_tokens"), py::arg("num_experts"), py::arg("topk"), py::arg("coeff"),
+        py::arg("num_experts"), py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"),
+        py::arg("coeff"),
         "Fused aux loss fwd");
   m.def("fused_moe_aux_loss_bwd", &transformer_engine::pytorch::fused_moe_aux_loss_bwd,
-        py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_tokens"),
-        py::arg("num_experts"), py::arg("grad_aux_loss"), "Fused aux loss bwd");
+        py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"), py::arg("num_cols"),
+        py::arg("grad_aux_loss"), "Fused aux loss bwd");
 
   // Misc
   m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp
index 4aa10b203a..3a973ce028 100644
--- a/transformer_engine/pytorch/csrc/extensions/router.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/router.cpp
@@ -145,8 +145,9 @@ at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts,
 
 std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
                                                           at::Tensor tokens_per_expert,
-                                                          int total_num_tokens, int num_tokens,
-                                                          int num_experts, int topk, float coeff) {
+                                                          int total_num_tokens, int num_experts,
+                                                          int num_rows, int num_cols, int topk,
+                                                          float coeff) {
   TORCH_CHECK(topk > 0, "topk must be greater than 0");
   TORCH_CHECK(total_num_tokens > 0, "total_num_tokens must be greater than 0");
   TORCH_CHECK(num_experts > 0, "num_experts must be greater than 0");
@@ -161,16 +162,16 @@ std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
   auto Const_buf_cu = makeTransformerEngineTensor(Const_buf);
 
   nvte_fused_moe_aux_loss_forward(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens,
-                                  num_tokens, num_experts, topk, coeff, aux_loss_cu.data(),
+                                  num_experts, num_rows, num_cols, topk, coeff, aux_loss_cu.data(),
                                   Const_buf_cu.data(), at::cuda::getCurrentCUDAStream());
 
   return std::make_tuple(aux_loss, Const_buf);
 }
 
 at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert,
-                                  int num_tokens, int num_experts, at::Tensor grad_aux_loss) {
+                                  int num_rows, int num_cols, at::Tensor grad_aux_loss) {
   // Create the output tensor
-  at::Tensor grad_probs = at::empty({num_tokens, num_experts},
+  at::Tensor grad_probs = at::empty({num_rows, num_cols},
                                     at::dtype(grad_aux_loss.scalar_type()).device(at::kCUDA));
 
   auto Const_buf_cu = makeTransformerEngineTensor(Const_buf);
@@ -179,8 +180,8 @@ at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_ex
   auto grad_probs_cu = makeTransformerEngineTensor(grad_probs);
 
   // Meta data for the kernel
-  nvte_fused_moe_aux_loss_backward(Const_buf_cu.data(), tokens_per_expert_cu.data(), num_tokens,
-                                   num_experts, grad_aux_loss_cu.data(), grad_probs_cu.data(),
+  nvte_fused_moe_aux_loss_backward(Const_buf_cu.data(), tokens_per_expert_cu.data(), num_rows,
+                                   num_cols, grad_aux_loss_cu.data(), grad_probs_cu.data(),
                                    at::cuda::getCurrentCUDAStream());
 
   return grad_probs;
diff --git a/transformer_engine/pytorch/router.py b/transformer_engine/pytorch/router.py
index 8683a80964..db5114ae04 100644
--- a/transformer_engine/pytorch/router.py
+++ b/transformer_engine/pytorch/router.py
@@ -27,6 +27,12 @@ def forward(
         expert_bias: torch.Tensor,
     ):
         # pylint: disable=missing-function-docstring
+        # Save the shape of the logits
+        tensor_shape = logits.shape
+        logits = logits.view(-1, tensor_shape[-1])
+        # Get the metadata of the viewed logits
+        num_tokens = logits.size(0)
+        num_experts = logits.size(1)
         probs, routing_map, intermediate_output = tex.fused_topk_with_score_function_fwd(
             logits,
             topk,
@@ -37,9 +43,11 @@ def forward(
             score_function,
             expert_bias,
         )
+        # Restore the shape
+        probs = probs.view(tensor_shape)
         ctx.save_for_backward(routing_map, intermediate_output)
-        ctx.num_tokens = logits.size(0)
-        ctx.num_experts = logits.size(1)
+        ctx.num_tokens = num_tokens
+        ctx.num_experts = num_experts
         ctx.use_pre_softmax = use_pre_softmax
         ctx.topk = topk
         ctx.scaling_factor = scaling_factor
@@ -50,17 +58,23 @@ def forward(
     def backward(ctx, grad_probs, _):
         # pylint: disable=missing-function-docstring
         routing_map, intermediate_output = ctx.saved_tensors
+        # Save the shape of the grad_probs
+        tensor_shape = grad_probs.shape
+        # Adjust the shape of the grad_probs to 2D shape
+        grad_probs = grad_probs.contiguous().view(-1, tensor_shape[-1])
         grad_logits = tex.fused_topk_with_score_function_bwd(
             ctx.num_tokens,
             ctx.num_experts,
             routing_map,
             intermediate_output,
-            grad_probs.contiguous(),
+            grad_probs,
             ctx.topk,
             ctx.use_pre_softmax,
             ctx.scaling_factor,
             ctx.score_function,
         )
+        # Restore the shape
+        grad_logits = grad_logits.view(tensor_shape)
         return grad_logits, None, None, None, None, None, None, None
 
 
@@ -124,6 +138,12 @@ def forward(
         score_function: str,
     ):
         # pylint: disable=missing-function-docstring
+        # Save the shape of the logits
+        tensor_shape = logits.shape
+        logits = logits.view(-1, tensor_shape[-1])
+        # Get the metadata of the viewed logits
+        num_tokens = logits.size(0)
+        num_experts = logits.size(1)
         scores, routing_map, intermediate_output = tex.fused_score_for_moe_aux_loss_fwd(
             logits=logits,
             topk=topk,
@@ -132,22 +152,28 @@ def forward(
         )
         ctx.save_for_backward(intermediate_output)
         ctx.topk = topk
         ctx.score_function = score_function
-        ctx.num_tokens = logits.size(0)
-        ctx.num_experts = logits.size(1)
+        ctx.num_tokens = num_tokens
+        ctx.num_experts = num_experts
         return routing_map, scores
 
     @staticmethod
     def backward(ctx, _, grad_scores):
         # pylint: disable=missing-function-docstring
         intermediate_output = ctx.saved_tensors[0]
+        # Save the shape of the grad_scores
+        tensor_shape = grad_scores.shape
+        # Adjust the shape of the grad_scores to 2D shape
+        grad_scores = grad_scores.contiguous().view(-1, tensor_shape[-1])
         grad_logits = tex.fused_score_for_moe_aux_loss_bwd(
             num_tokens=ctx.num_tokens,
             num_experts=ctx.num_experts,
             intermediate_output=intermediate_output,
-            grad_scores=grad_scores.contiguous(),
+            grad_scores=grad_scores,
             topk=ctx.topk,
             score_function=ctx.score_function,
         )
+        # Restore the shape
+        grad_logits = grad_logits.view(tensor_shape)
         return grad_logits, None, None
 
 
@@ -189,19 +215,21 @@ def forward(
         coeff: float,
     ):
         # pylint: disable=missing-function-docstring
-        num_tokens = probs.size(0)
+        num_rows = probs.size(0)
+        num_cols = probs.size(1)
         aux_loss, Const_buf = tex.fused_moe_aux_loss_fwd(
             probs=probs,
             tokens_per_expert=tokens_per_expert,
             total_num_tokens=total_num_tokens,
-            num_tokens=num_tokens,
             num_experts=num_experts,
+            num_rows=num_rows,
+            num_cols=num_cols,
             topk=topk,
             coeff=coeff,
         )
         ctx.save_for_backward(Const_buf, tokens_per_expert)
-        ctx.num_tokens = num_tokens
-        ctx.num_experts = num_experts
+        ctx.num_rows = num_rows
+        ctx.num_cols = num_cols
         return aux_loss
 
     @staticmethod
@@ -211,8 +239,8 @@ def backward(ctx, grad_aux_loss):
         grad_probs = tex.fused_moe_aux_loss_bwd(
             Const_buf=Const_buf,
             tokens_per_expert=tokens_per_expert,
-            num_tokens=ctx.num_tokens,
-            num_experts=ctx.num_experts,
+            num_rows=ctx.num_rows,
+            num_cols=ctx.num_cols,
             grad_aux_loss=grad_aux_loss,
         )
         return grad_probs, None, None, None, None, None

From 2633320baf7fea9b654dad16f77c34597d5f952d Mon Sep 17 00:00:00 2001
From: tongliu
Date: Fri, 11 Jul 2025 03:59:46 -0700
Subject: [PATCH 3/6] fix bug on sm dispatcher

Signed-off-by: tongliu
---
 transformer_engine/common/fused_router/fused_moe_aux_loss.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu
index b28cf934e6..8f97828f4b 100644
--- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu
+++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu
@@ -172,7 +172,7 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
                                                 int topk, float coeff,
                                                 DataType* aux_loss, float* Const_buf,
                                                 cudaStream_t stream) {
-  if (cuda::sm_arch(cuda::current_device()) >= 900) {
+  if (cuda::sm_arch(cuda::current_device()) >= 90) {
     cudaLaunchConfig_t config = {0};
     int cluster_size = 8;
     config.gridDim = cluster_size;

From 4decb72dc9d795634c490802c704c7af34d0ed62 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 11 Jul 2025 11:04:11 +0000
Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../common/fused_router/fused_moe_aux_loss.cu | 40 +++++++++----------
 .../include/transformer_engine/fused_router.h | 12 +++---
 transformer_engine/pytorch/csrc/extensions.h  | 10 ++---
 .../pytorch/csrc/extensions/pybind.cpp        |  7 ++--
 .../pytorch/csrc/extensions/router.cpp        |  8 ++--
 5 files changed, 35 insertions(+), 42 deletions(-)

diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu
index 8f97828f4b..221963b11b 100644
--- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu
+++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu
@@ -23,10 +23,8 @@ using CompType = double;
 template <typename DataType, typename IndexType>
 __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
                                                   const IndexType* tokens_per_expert,
-                                                  int total_num_tokens,
-                                                  int num_experts,
-                                                  int num_rows, int num_cols,
-                                                  int topk, float coeff,
+                                                  int total_num_tokens, int num_experts,
+                                                  int num_rows, int num_cols, int topk, float coeff,
                                                   DataType* aux_loss, float* Const_buf) {
 #if __CUDA_ARCH__ >= 900
   // Using cooperative_groups to manage the cluster
@@ -166,10 +164,8 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
 template <typename DataType, typename IndexType>
 void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
                                                 const IndexType* tokens_per_expert,
-                                                int total_num_tokens,
-                                                int num_experts,
-                                                int num_rows, int num_cols,
-                                                int topk, float coeff,
+                                                int total_num_tokens, int num_experts, int num_rows,
+                                                int num_cols, int topk, float coeff,
                                                 DataType* aux_loss, float* Const_buf,
                                                 cudaStream_t stream) {
   if (cuda::sm_arch(cuda::current_device()) >= 90) {
@@ -194,8 +190,8 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
     config.attrs = attribute;
 
     cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs,
-                       tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk, coeff,
-                       aux_loss, Const_buf);
+                       tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk,
+                       coeff, aux_loss, Const_buf);
   } else {
     size_t smem_size = sizeof(CompType) * num_cols;
     fused_moe_aux_loss_forward_kernel
@@ -215,15 +211,16 @@ void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_ex
           fused_moe_aux_loss_forward_kernel_launcher(
              reinterpret_cast<DataType*>(probs.data.dptr),
              reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), total_num_tokens,
-              num_experts, num_rows, num_cols, topk, coeff, reinterpret_cast<DataType*>(aux_loss.data.dptr),
+              num_experts, num_rows, num_cols, topk, coeff,
+              reinterpret_cast<DataType*>(aux_loss.data.dptr),
              reinterpret_cast<float*>(Const_buf.data.dptr), stream);););
 }
 
 template <typename DataType, typename IndexType>
 __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf,
-                                                   const IndexType* tokens_per_expert,
-                                                   int num_rows, int num_cols,
-                                                   DataType* grad_aux_loss, DataType* grad_probs) {
+                                                   const IndexType* tokens_per_expert, int num_rows,
+                                                   int num_cols, DataType* grad_aux_loss,
+                                                   DataType* grad_probs) {
   int global_warp_num = gridDim.x * blockDim.x / kThreadsPerWarp;
   int global_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kThreadsPerWarp;
   int lane_id = threadIdx.x % kThreadsPerWarp;
@@ -239,9 +236,8 @@ __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf,
 
 template <typename DataType, typename IndexType>
 void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf,
-                                                 const IndexType* tokens_per_expert,
-                                                 int num_rows, int num_cols,
-                                                 DataType* grad_aux_loss,
+                                                 const IndexType* tokens_per_expert, int num_rows,
+                                                 int num_cols, DataType* grad_aux_loss,
                                                  DataType* grad_probs, cudaStream_t stream) {
   // Meta data for the kernel
   int block_size = 256;
@@ -270,8 +266,8 @@ void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_p
 }  // namespace transformer_engine
 
 void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
-                                     int total_num_tokens, int num_experts, int num_rows, int num_cols,
-                                     int topk, float coeff, NVTETensor aux_loss,
+                                     int total_num_tokens, int num_experts, int num_rows,
+                                     int num_cols, int topk, float coeff, NVTETensor aux_loss,
                                      NVTETensor Const_buf, cudaStream_t stream) {
   NVTE_API_CALL(nvte_fused_moe_aux_loss_forward);
   using namespace transformer_engine;
@@ -282,9 +278,9 @@ void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor to
 }
 
 void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf,
-                                      const NVTETensor tokens_per_expert, int num_rows, int num_cols,
-                                      NVTETensor grad_aux_loss,
-                                      NVTETensor grad_probs, cudaStream_t stream) {
+                                      const NVTETensor tokens_per_expert, int num_rows,
+                                      int num_cols, NVTETensor grad_aux_loss, NVTETensor grad_probs,
+                                      cudaStream_t stream) {
   NVTE_API_CALL(nvte_fused_moe_aux_loss_backward);
   using namespace transformer_engine;
   fused_moe_aux_loss_backward(*convertNVTETensorCheck(Const_buf),
diff --git a/transformer_engine/common/include/transformer_engine/fused_router.h b/transformer_engine/common/include/transformer_engine/fused_router.h
index b385e7a076..8cf4b222a5 100644
--- a/transformer_engine/common/include/transformer_engine/fused_router.h
+++ b/transformer_engine/common/include/transformer_engine/fused_router.h
@@ -106,9 +106,8 @@ void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_ou
 *  \param[in]     stream              CUDA stream used for the operation.
 */
 void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
-                                     int total_num_tokens, int num_experts,
-                                     int num_rows, int num_cols,
-                                     int topk, float coeff, NVTETensor aux_loss,
+                                     int total_num_tokens, int num_experts, int num_rows,
+                                     int num_cols, int topk, float coeff, NVTETensor aux_loss,
                                      NVTETensor Const_buf, cudaStream_t stream);
 
 /*! \brief Backward pass for auxiliary loss.
@@ -122,10 +121,9 @@ void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor to
 *  \param[in]     stream              CUDA stream used for the operation.
 */
 void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf,
-                                      const NVTETensor tokens_per_expert,
-                                      int num_rows, int num_cols,
-                                      NVTETensor grad_aux_loss,
-                                      NVTETensor grad_probs, cudaStream_t stream);
+                                      const NVTETensor tokens_per_expert, int num_rows,
+                                      int num_cols, NVTETensor grad_aux_loss, NVTETensor grad_probs,
+                                      cudaStream_t stream);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index f6c09d45f3..25e8582220 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -38,12 +38,12 @@ at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts,
 
 std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
                                                           at::Tensor tokens_per_expert,
-                                                          int total_num_tokens,
-                                                          int num_experts, int num_rows, int num_cols,
-                                                          int topk, float coeff);
+                                                          int total_num_tokens, int num_experts,
+                                                          int num_rows, int num_cols, int topk,
+                                                          float coeff);
 
-at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert,
-                                  int num_rows, int num_cols, at::Tensor grad_aux_loss);
+at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows,
+                                  int num_cols, at::Tensor grad_aux_loss);
 
 /***************************************************************************************************
 * Permutation
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index f20c513cf8..c9b5a67a78 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -279,11 +279,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("fused_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_moe_aux_loss_fwd,
         py::arg("probs"), py::arg("tokens_per_expert"), py::arg("total_num_tokens"),
         py::arg("num_experts"), py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"),
-        py::arg("coeff"),
-        "Fused aux loss fwd");
+        py::arg("coeff"), "Fused aux loss fwd");
   m.def("fused_moe_aux_loss_bwd", &transformer_engine::pytorch::fused_moe_aux_loss_bwd,
-        py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"), py::arg("num_cols"),
-        py::arg("grad_aux_loss"), "Fused aux loss bwd");
+        py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"),
+        py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd");
 
   // Misc
   m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp
index 3a973ce028..9befe14f88 100644
--- a/transformer_engine/pytorch/csrc/extensions/router.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/router.cpp
@@ -168,11 +168,11 @@ std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
   return std::make_tuple(aux_loss, Const_buf);
 }
 
-at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert,
-                                  int num_rows, int num_cols, at::Tensor grad_aux_loss) {
+at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows,
+                                  int num_cols, at::Tensor grad_aux_loss) {
   // Create the output tensor
-  at::Tensor grad_probs = at::empty({num_rows, num_cols},
-                                    at::dtype(grad_aux_loss.scalar_type()).device(at::kCUDA));
+  at::Tensor grad_probs =
+      at::empty({num_rows, num_cols}, at::dtype(grad_aux_loss.scalar_type()).device(at::kCUDA));
 
   auto Const_buf_cu = makeTransformerEngineTensor(Const_buf);
   auto tokens_per_expert_cu = makeTransformerEngineTensor(tokens_per_expert);

From 54851b60f4781674ea7c7ee534eefd585f89622a Mon Sep 17 00:00:00 2001
From: tongliu
Date: Sun, 13 Jul 2025 20:05:26 -0700
Subject: [PATCH 5/6] fix CI initial values

Signed-off-by: tongliu
---
 tests/pytorch/test_fused_router.py | 27 +++++++++++----------------
 1 file changed, 11 insertions(+), 16 deletions(-)

diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py
index e760984014..2a7be859b7 100644
--- a/tests/pytorch/test_fused_router.py
+++ b/tests/pytorch/test_fused_router.py
@@ -149,17 +149,16 @@ def run_comparison(
     # Set some parameters
     if score_function == "sigmoid":
         # Construct the special logits to avoid inf in the sigmoid function
-        logits = 1 + torch.arange(num_experts, device="cuda", dtype=dtype) * 0.01
-        logits = logits.unsqueeze(0).repeat(num_tokens, 1)
-        random_values = torch.rand(num_tokens, num_experts, device="cuda")
-        _, indices = torch.sort(random_values, dim=1)
-        logits = torch.gather(logits, 1, indices)
+        offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
+        logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
+        logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
     else:
         logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
         logits = logits.view(num_tokens, num_experts)
     logits.requires_grad = True
     if enable_bias and score_function == "sigmoid":
         expert_bias = torch.arange(num_experts, device="cuda") * 0.1
+        expert_bias = torch.flip(expert_bias, dims=[0])
         expert_bias.requires_grad = True
     else:
         expert_bias = None
@@ -214,7 +213,7 @@ def run_comparison(
 
 
 @pytest.mark.parametrize("dtype", [torch.float32])
-@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
+@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992])
 @pytest.mark.parametrize("num_experts", [128, 32])
 @pytest.mark.parametrize("topk", [4, 8])
 @pytest.mark.parametrize("group_topk", [None, 4])
@@ -283,11 +282,9 @@ def test_topk_softmax(
 def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
     if score_function == "sigmoid":
         # Construct the special logits to avoid inf in the sigmoid function
-        logits = 1 + torch.arange(num_experts, device="cuda", dtype=dtype) * 0.01
-        logits = logits.unsqueeze(0).repeat(num_tokens, 1)
-        random_values = torch.rand(num_tokens, num_experts, device="cuda")
-        _, indices = torch.sort(random_values, dim=1)
-        logits = torch.gather(logits, 1, indices)
+        offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
+        logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
+        logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
     else:
         logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
         logits = logits.view(num_tokens, num_experts)
@@ -325,11 +322,9 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f
 @pytest.mark.parametrize("topk", [4])
 def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
     # Construct the special probs to avoid inf in the sigmoid function
-    probs = 1 + torch.arange(num_experts, device="cuda", dtype=dtype) * 0.01
-    probs = probs.unsqueeze(0).repeat(num_tokens, 1)
-    random_values = torch.rand(num_tokens, num_experts, device="cuda")
-    _, indices = torch.sort(random_values, dim=1)
-    probs = torch.gather(probs, 1, indices)
+    offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 
+    probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
+    probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
     probs = probs.view(num_tokens, num_experts)
     probs.requires_grad = True

From a610720b8f77800ec74e4aed701a5703b15f9bb0 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 14 Jul 2025 03:06:57 +0000
Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/pytorch/test_fused_router.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py
index 2a7be859b7..d2cb85dd37 100644
--- a/tests/pytorch/test_fused_router.py
+++ b/tests/pytorch/test_fused_router.py
@@ -322,7 +322,7 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f
 @pytest.mark.parametrize("topk", [4])
 def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
     # Construct the special probs to avoid inf in the sigmoid function
-    offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 
+    offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
     probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
     probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
     probs = probs.view(num_tokens, num_experts)
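
For reference, the deterministic test-input construction that patches 1 and 5 converge on can be exercised in isolation. This is a minimal sketch, not part of the patches; it assumes only PyTorch, runs on CPU, and uses illustrative sizes rather than the CI values:

    import torch

    num_tokens, num_experts, topk = 8, 4, 2

    # Per-token offset (1e-4 steps) plus per-expert ramp (1e-2 steps): every
    # entry of the matrix is distinct, so top-k selection has no ties and the
    # routing map is reproducible from run to run.
    offset = torch.arange(0, num_tokens, dtype=torch.float32) * 1e-4
    logits = torch.arange(num_experts, dtype=torch.float32) * 1e-2
    logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)

    # No ties -> torch.topk has a unique, deterministic answer; values stay
    # near zero, so sigmoid cannot saturate to inf/0 in float32.
    assert logits.flatten().unique().numel() == num_tokens * num_experts
    assert torch.isfinite(torch.sigmoid(logits)).all()
    topk_values, routing_map = torch.topk(logits, k=topk, dim=1)

This is why the final revision drops torch.randn/torch.gather-of-shuffled-rows entirely: random inputs can produce duplicate top-k candidates, and tie-breaking order is what made the CI comparison against the reference implementation nondeterministic.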
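Patch 2's router.py changes all follow one shape-handling pattern: the fused kernels operate on 2D [num_tokens, num_experts] tensors, so inputs with a leading micro-batch dimension (mbs > 1) are flattened before the kernel call and outputs are viewed back afterwards. A standalone sketch of that pattern (helper names here are illustrative, not from the patch):

    import torch

    def flatten_for_kernel(logits: torch.Tensor):
        # Collapse all leading dimensions into one token dimension,
        # mirroring logits.view(-1, tensor_shape[-1]) in the autograd wrappers.
        orig_shape = logits.shape
        return logits.contiguous().view(-1, orig_shape[-1]), orig_shape

    def restore_shape(out: torch.Tensor, orig_shape) -> torch.Tensor:
        # Give the kernel output back its original shape, mirroring
        # probs.view(tensor_shape) / grad_logits.view(tensor_shape).
        return out.view(orig_shape)

    x = torch.randn(2, 5, 8)             # [mbs, seq, num_experts] with mbs > 1
    flat, shape = flatten_for_kernel(x)  # [10, 8], as the fused kernel expects
    assert restore_shape(flat, shape).shape == x.shape

This is also why the C++/CUDA side switches from a single num_tokens to the num_rows/num_cols pair: after flattening, the row count is mbs * seq rather than the original num_tokens.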
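The reduction the forward kernel performs on shared memory, as visible in the hunks above, can be written as a CPU reference. Note this sketch is an assumption-laden approximation: the final scaling by coeff, topk, and total_num_tokens (and the value stored in Const_buf) happens in kernel code outside the hunks shown here, so only the visible parts are mirrored:

    import torch

    def aux_loss_partial(probs: torch.Tensor, tokens_per_expert: torch.Tensor) -> torch.Tensor:
        # Column-sum probs per expert (the kernel accumulates in double,
        # hence CompType == double), weight by the per-expert token count,
        # then reduce to a scalar -- the "intermediate_result" in the kernel.
        aggregated = probs.double().sum(dim=0)
        return (aggregated * tokens_per_expert.double()).sum()

    def aux_loss_grad(Const_buf: torch.Tensor, tokens_per_expert: torch.Tensor,
                      grad_aux_loss: torch.Tensor, num_rows: int) -> torch.Tensor:
        # Mirrors the backward kernel: every row receives the same per-expert
        # gradient, grad_probs[j, i] = Const_buf[0] * tokens_per_expert[i]
        # * grad_aux_loss[0]. Const_buf and grad_aux_loss are 1-element tensors.
        per_expert = Const_buf[0] * tokens_per_expert.double() * grad_aux_loss[0]
        return per_expert.unsqueeze(0).expand(num_rows, -1)

Because the gradient is constant across rows, the backward kernel only needs num_rows to size its grid ((num_rows + block_size - 1) / block_size) and num_cols to index experts, which is exactly the parameter pair the refactor threads through the nvte_* API.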