diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index 61a7750b63..d2cb85dd37 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -148,14 +148,17 @@ def run_comparison( ): # Set some parameters if score_function == "sigmoid": - logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 0.1 - logits = logits.unsqueeze(0).repeat(num_tokens, 1) + # Construct the special logits to avoid inf in the sigmoid function + offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 + logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 + logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) else: logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4 - logits = logits.view(num_tokens, num_experts) + logits = logits.view(num_tokens, num_experts) logits.requires_grad = True if enable_bias and score_function == "sigmoid": expert_bias = torch.arange(num_experts, device="cuda") * 0.1 + expert_bias = torch.flip(expert_bias, dims=[0]) expert_bias.requires_grad = True else: expert_bias = None @@ -210,7 +213,7 @@ def run_comparison( @pytest.mark.parametrize("dtype", [torch.float32]) -@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111]) +@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992]) @pytest.mark.parametrize("num_experts", [128, 32]) @pytest.mark.parametrize("topk", [4, 8]) @pytest.mark.parametrize("group_topk", [None, 4]) @@ -241,7 +244,7 @@ def test_topk_sigmoid( @pytest.mark.parametrize("dtype", [torch.float32]) -@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111]) +@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) @pytest.mark.parametrize("num_experts", [128, 32]) @pytest.mark.parametrize("topk", [4, 8]) @pytest.mark.parametrize("use_pre_softmax", [True, False]) @@ -272,12 +275,19 @@ def test_topk_softmax( @pytest.mark.parametrize("dtype", [torch.float32]) -@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111]) +@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) @pytest.mark.parametrize("num_experts", [256, 128, 32]) @pytest.mark.parametrize("topk", [4, 8]) @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): - logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype) + if score_function == "sigmoid": + # Construct the special logits to avoid inf in the sigmoid function + offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 + logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 + logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) + else: + logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4 + logits = logits.view(num_tokens, num_experts) logits.requires_grad = True logits_clone = deepcopy(logits) @@ -307,11 +317,15 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f @pytest.mark.parametrize("dtype", [torch.float32]) -@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111]) +@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) @pytest.mark.parametrize("num_experts", [256, 128, 32]) @pytest.mark.parametrize("topk", [4]) def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): - probs = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype) + # Construct the special probs to avoid inf in the sigmoid function + offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 + probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 + probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) + probs = probs.view(num_tokens, num_experts) probs.requires_grad = True tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32) @@ -375,6 +389,6 @@ def profile_topk_softmax( test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4) test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4) test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=32, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=128, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=256, topk=4) + test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4) + test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4) + test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4) diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index ddb09f270a..221963b11b 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -23,8 +23,8 @@ using CompType = double; template __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, const IndexType* tokens_per_expert, - int total_num_tokens, int num_tokens, - int num_experts, int topk, float coeff, + int total_num_tokens, int num_experts, + int num_rows, int num_cols, int topk, float coeff, DataType* aux_loss, float* Const_buf) { #if __CUDA_ARCH__ >= 900 // Using cooperative_groups to manage the cluster @@ -43,7 +43,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, extern __shared__ float shmem_aux_loss[]; CompType* aggregated_probs_per_expert = reinterpret_cast(shmem_aux_loss); // Clear the shmem - for (int i = threadIdx.x; i < num_experts; i += blockDim.x) { + for (int i = threadIdx.x; i < num_cols; i += blockDim.x) { aggregated_probs_per_expert[i] = CompType(0); } __syncthreads(); @@ -54,11 +54,11 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, * 2. reduce on the cluster */ // Loop: for all positions in each row - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { CompType tmp = CompType(0); // Loop: for all rows that this warp is responsible for - for (int j = warp_id; j < num_tokens; j += warp_num) { - tmp += CompType(probs[j * num_experts + i]); + for (int j = warp_id; j < num_rows; j += warp_num) { + tmp += CompType(probs[j * num_cols + i]); } atomicAdd(&aggregated_probs_per_expert[i], tmp); } @@ -68,7 +68,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, for (int i = 1; i < block_num; i++) { // Map the shared memory of the block i to the current block CompType* dst_smem = reinterpret_cast(cluster.map_shared_rank(shmem_aux_loss, i)); - for (int j = threadIdx.x; j < num_experts; j += blockDim.x) { + for (int j = threadIdx.x; j < num_cols; j += blockDim.x) { atomicAdd(&aggregated_probs_per_expert[j], dst_smem[j]); } } @@ -80,7 +80,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, * In-place update on shmem */ if (block_id == 0) { - for (int i = threadIdx.x; i < num_experts; i += blockDim.x) { + for (int i = threadIdx.x; i < num_cols; i += blockDim.x) { aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]); } __syncthreads(); @@ -90,7 +90,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, * Section: Reduce to get the sum of aggregated_probs_per_expert */ CompType intermediate_result = - warp_reduce_on_shmem(aggregated_probs_per_expert, num_experts, sum, lane_id); + warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id); __syncwarp(); if (lane_id == 0) { @@ -113,7 +113,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, CompType* aggregated_probs_per_expert = reinterpret_cast(shmem_aux_loss); // Clear the shmem - for (int i = threadIdx.x; i < num_experts; i += blockDim.x) { + for (int i = threadIdx.x; i < num_cols; i += blockDim.x) { aggregated_probs_per_expert[i] = CompType(0); } __syncthreads(); @@ -122,11 +122,11 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, * Section: Reduce the probs to the aggregated_probs_per_expert */ // Loop: for all positions in each row - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { CompType tmp = CompType(0); // Loop: for all rows that this warp is responsible for - for (int j = warp_id; j < num_tokens; j += warp_num) { - tmp += CompType(probs[j * num_experts + i]); + for (int j = warp_id; j < num_rows; j += warp_num) { + tmp += CompType(probs[j * num_cols + i]); } atomicAdd(&aggregated_probs_per_expert[i], tmp); } @@ -136,7 +136,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, * Section: aggregated_probs_per_expert * tokens_per_expert * In-place update on shmem */ - for (int i = threadIdx.x; i < num_experts; i += blockDim.x) { + for (int i = threadIdx.x; i < num_cols; i += blockDim.x) { aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]); } __syncthreads(); @@ -146,7 +146,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, * Section: Reduce to get the sum of aggregated_probs_per_expert */ CompType intermediate_result = - warp_reduce_on_shmem(aggregated_probs_per_expert, num_experts, sum, lane_id); + warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id); __syncwarp(); if (lane_id == 0) { @@ -164,16 +164,17 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, template void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, const IndexType* tokens_per_expert, - int total_num_tokens, int num_tokens, - int num_experts, int topk, float coeff, + int total_num_tokens, int num_experts, int num_rows, + int num_cols, int topk, float coeff, DataType* aux_loss, float* Const_buf, cudaStream_t stream) { - if (cuda::sm_arch(cuda::current_device()) >= 900) { + if (cuda::sm_arch(cuda::current_device()) >= 90) { cudaLaunchConfig_t config = {0}; int cluster_size = 8; config.gridDim = cluster_size; config.blockDim = 1024; - config.dynamicSmemBytes = sizeof(CompType) * num_experts; + config.dynamicSmemBytes = sizeof(CompType) * num_cols; + config.stream = stream; // Update the max cluster size based on the device cudaOccupancyMaxPotentialClusterSize( @@ -189,19 +190,19 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, config.attrs = attribute; cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel, probs, - tokens_per_expert, total_num_tokens, num_tokens, num_experts, topk, coeff, - aux_loss, Const_buf); + tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk, + coeff, aux_loss, Const_buf); } else { - size_t smem_size = sizeof(CompType) * num_experts; + size_t smem_size = sizeof(CompType) * num_cols; fused_moe_aux_loss_forward_kernel - <<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_tokens, - num_experts, topk, coeff, aux_loss, Const_buf); + <<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts, + num_rows, num_cols, topk, coeff, aux_loss, Const_buf); } } void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert, - int total_num_tokens, int num_tokens, int num_experts, int topk, - float coeff, Tensor& aux_loss, Tensor& Const_buf, + int total_num_tokens, int num_experts, int num_rows, int num_cols, + int topk, float coeff, Tensor& aux_loss, Tensor& Const_buf, cudaStream_t stream) { TE_ROUTER_PROBS_TYPE_SWITCH_ALL( probs.data.dtype, DataType, @@ -210,45 +211,46 @@ void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_ex fused_moe_aux_loss_forward_kernel_launcher( reinterpret_cast(probs.data.dptr), reinterpret_cast(tokens_per_expert.data.dptr), total_num_tokens, - num_tokens, num_experts, topk, coeff, reinterpret_cast(aux_loss.data.dptr), + num_experts, num_rows, num_cols, topk, coeff, + reinterpret_cast(aux_loss.data.dptr), reinterpret_cast(Const_buf.data.dptr), stream););); } template __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf, - const IndexType* tokens_per_expert, - int num_tokens, int num_experts, - DataType* grad_aux_loss, DataType* grad_probs) { + const IndexType* tokens_per_expert, int num_rows, + int num_cols, DataType* grad_aux_loss, + DataType* grad_probs) { int global_warp_num = gridDim.x * blockDim.x / kThreadsPerWarp; int global_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; // Loop: for all positions in each row - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { float C_coeff = Const_buf[0]; IndexType tokens_per_expert_i = tokens_per_expert[i]; double grad_aux_loss_value = static_cast(grad_aux_loss[0]); // Loop: for all rows - for (int j = global_warp_id; j < num_tokens; j += global_warp_num) { - grad_probs[j * num_experts + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value; + for (int j = global_warp_id; j < num_rows; j += global_warp_num) { + grad_probs[j * num_cols + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value; } } } template void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf, - const IndexType* tokens_per_expert, int num_tokens, - int num_experts, DataType* grad_aux_loss, + const IndexType* tokens_per_expert, int num_rows, + int num_cols, DataType* grad_aux_loss, DataType* grad_probs, cudaStream_t stream) { // Meta data for the kernel int block_size = 256; - int grid_size = (num_tokens + block_size - 1) / block_size; + int grid_size = (num_rows + block_size - 1) / block_size; fused_moe_aux_loss_backward_kernel<<>>( - Const_buf, tokens_per_expert, num_tokens, num_experts, grad_aux_loss, grad_probs); + Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs); } void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_expert, - int num_tokens, int num_experts, Tensor& grad_aux_loss, + int num_rows, int num_cols, Tensor& grad_aux_loss, Tensor& grad_probs, cudaStream_t stream) { TE_ROUTER_PROBS_TYPE_SWITCH_ALL( grad_aux_loss.data.dtype, DataType, @@ -256,7 +258,7 @@ void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_p tokens_per_expert.data.dtype, IndexType, fused_moe_aux_loss_backward_kernel_launcher( reinterpret_cast(Const_buf.data.dptr), - reinterpret_cast(tokens_per_expert.data.dptr), num_tokens, num_experts, + reinterpret_cast(tokens_per_expert.data.dptr), num_rows, num_cols, reinterpret_cast(grad_aux_loss.data.dptr), reinterpret_cast(grad_probs.data.dptr), stream););); } @@ -264,25 +266,25 @@ void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_p } // namespace transformer_engine void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert, - int total_num_tokens, int num_tokens, int num_experts, - int topk, float coeff, NVTETensor aux_loss, + int total_num_tokens, int num_experts, int num_rows, + int num_cols, int topk, float coeff, NVTETensor aux_loss, NVTETensor Const_buf, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_moe_aux_loss_forward); using namespace transformer_engine; fused_moe_aux_loss_forward( *convertNVTETensorCheck(probs), *convertNVTETensorCheck(tokens_per_expert), total_num_tokens, - num_tokens, num_experts, topk, coeff, *convertNVTETensorCheck(aux_loss), + num_experts, num_rows, num_cols, topk, coeff, *convertNVTETensorCheck(aux_loss), *convertNVTETensorCheck(Const_buf), stream); } void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf, - const NVTETensor tokens_per_expert, int num_tokens, - int num_experts, NVTETensor grad_aux_loss, - NVTETensor grad_probs, cudaStream_t stream) { + const NVTETensor tokens_per_expert, int num_rows, + int num_cols, NVTETensor grad_aux_loss, NVTETensor grad_probs, + cudaStream_t stream) { NVTE_API_CALL(nvte_fused_moe_aux_loss_backward); using namespace transformer_engine; fused_moe_aux_loss_backward(*convertNVTETensorCheck(Const_buf), - *convertNVTETensorCheck(tokens_per_expert), num_tokens, num_experts, + *convertNVTETensorCheck(tokens_per_expert), num_rows, num_cols, *convertNVTETensorCheck(grad_aux_loss), *convertNVTETensorCheck(grad_probs), stream); } diff --git a/transformer_engine/common/include/transformer_engine/fused_router.h b/transformer_engine/common/include/transformer_engine/fused_router.h index 7a3421a4e6..8cf4b222a5 100644 --- a/transformer_engine/common/include/transformer_engine/fused_router.h +++ b/transformer_engine/common/include/transformer_engine/fused_router.h @@ -96,8 +96,9 @@ void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_ou * \param[in] probs Probabilities from the forward pass. * \param[in] tokens_per_expert Number of tokens per expert. * \param[in] total_num_tokens Number of total tokens. Will be used in seq/global aux loss. - * \param[in] num_tokens Number of tokens. * \param[in] num_experts Number of experts. + * \param[in] num_rows Number of rows of probs. + * \param[in] num_cols Number of columns of probs. * \param[in] topk Topk value. * \param[in] coeff Coefficient. * \param[out] aux_loss Output GPU scalar for auxiliary loss. @@ -105,24 +106,24 @@ void nvte_fused_score_for_moe_aux_loss_backward(const NVTETensor intermediate_ou * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert, - int total_num_tokens, int num_tokens, int num_experts, - int topk, float coeff, NVTETensor aux_loss, + int total_num_tokens, int num_experts, int num_rows, + int num_cols, int topk, float coeff, NVTETensor aux_loss, NVTETensor Const_buf, cudaStream_t stream); /*! \brief Backward pass for auxiliary loss. * * \param[in] Const_buf Constant buffer from the forward pass. * \param[in] tokens_per_expert Number of tokens per expert. - * \param[in] num_tokens Number of total tokens. - * \param[in] num_experts Number of experts. + * \param[in] num_rows Number of rows of probs. + * \param[in] num_cols Number of columns of probs. * \param[in] grad_aux_loss Gradient of auxiliary loss. * \param[out] grad_probs Gradient of probs. * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf, - const NVTETensor tokens_per_expert, int num_tokens, - int num_experts, NVTETensor grad_aux_loss, - NVTETensor grad_probs, cudaStream_t stream); + const NVTETensor tokens_per_expert, int num_rows, + int num_cols, NVTETensor grad_aux_loss, NVTETensor grad_probs, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 61b22fc34a..25e8582220 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -38,11 +38,12 @@ at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, std::tuple fused_moe_aux_loss_fwd(at::Tensor probs, at::Tensor tokens_per_expert, - int total_num_tokens, int num_tokens, - int num_experts, int topk, float coeff); + int total_num_tokens, int num_experts, + int num_rows, int num_cols, int topk, + float coeff); -at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, - int num_tokens, int num_experts, at::Tensor grad_aux_loss); +at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows, + int num_cols, at::Tensor grad_aux_loss); /*************************************************************************************************** * Permutation diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 92ae618f2c..c9b5a67a78 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -278,11 +278,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("topk"), py::arg("score_function"), "Fused topk softmax bwd"); m.def("fused_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_moe_aux_loss_fwd, py::arg("probs"), py::arg("tokens_per_expert"), py::arg("total_num_tokens"), - py::arg("num_tokens"), py::arg("num_experts"), py::arg("topk"), py::arg("coeff"), - "Fused aux loss fwd"); + py::arg("num_experts"), py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"), + py::arg("coeff"), "Fused aux loss fwd"); m.def("fused_moe_aux_loss_bwd", &transformer_engine::pytorch::fused_moe_aux_loss_bwd, - py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_tokens"), - py::arg("num_experts"), py::arg("grad_aux_loss"), "Fused aux loss bwd"); + py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"), + py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd"); // Misc m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version, diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index 4aa10b203a..9befe14f88 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -145,8 +145,9 @@ at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, std::tuple fused_moe_aux_loss_fwd(at::Tensor probs, at::Tensor tokens_per_expert, - int total_num_tokens, int num_tokens, - int num_experts, int topk, float coeff) { + int total_num_tokens, int num_experts, + int num_rows, int num_cols, int topk, + float coeff) { TORCH_CHECK(topk > 0, "topk must be greater than 0"); TORCH_CHECK(total_num_tokens > 0, "total_num_tokens must be greater than 0"); TORCH_CHECK(num_experts > 0, "num_experts must be greater than 0"); @@ -161,17 +162,17 @@ std::tuple fused_moe_aux_loss_fwd(at::Tensor probs, auto Const_buf_cu = makeTransformerEngineTensor(Const_buf); nvte_fused_moe_aux_loss_forward(probs_cu.data(), tokens_per_expert_cu.data(), total_num_tokens, - num_tokens, num_experts, topk, coeff, aux_loss_cu.data(), + num_experts, num_rows, num_cols, topk, coeff, aux_loss_cu.data(), Const_buf_cu.data(), at::cuda::getCurrentCUDAStream()); return std::make_tuple(aux_loss, Const_buf); } -at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, - int num_tokens, int num_experts, at::Tensor grad_aux_loss) { +at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows, + int num_cols, at::Tensor grad_aux_loss) { // Create the output tensor - at::Tensor grad_probs = at::empty({num_tokens, num_experts}, - at::dtype(grad_aux_loss.scalar_type()).device(at::kCUDA)); + at::Tensor grad_probs = + at::empty({num_rows, num_cols}, at::dtype(grad_aux_loss.scalar_type()).device(at::kCUDA)); auto Const_buf_cu = makeTransformerEngineTensor(Const_buf); auto tokens_per_expert_cu = makeTransformerEngineTensor(tokens_per_expert); @@ -179,8 +180,8 @@ at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_ex auto grad_probs_cu = makeTransformerEngineTensor(grad_probs); // Meta data for the kernel - nvte_fused_moe_aux_loss_backward(Const_buf_cu.data(), tokens_per_expert_cu.data(), num_tokens, - num_experts, grad_aux_loss_cu.data(), grad_probs_cu.data(), + nvte_fused_moe_aux_loss_backward(Const_buf_cu.data(), tokens_per_expert_cu.data(), num_rows, + num_cols, grad_aux_loss_cu.data(), grad_probs_cu.data(), at::cuda::getCurrentCUDAStream()); return grad_probs; diff --git a/transformer_engine/pytorch/router.py b/transformer_engine/pytorch/router.py index 8683a80964..db5114ae04 100644 --- a/transformer_engine/pytorch/router.py +++ b/transformer_engine/pytorch/router.py @@ -27,6 +27,12 @@ def forward( expert_bias: torch.Tensor, ): # pylint: disable=missing-function-docstring + # Save the shape of the logits + tensor_shape = logits.shape + logits = logits.view(-1, tensor_shape[-1]) + # Get the metadata of the viewed logits + num_tokens = logits.size(0) + num_experts = logits.size(1) probs, routing_map, intermediate_output = tex.fused_topk_with_score_function_fwd( logits, topk, @@ -37,9 +43,11 @@ def forward( score_function, expert_bias, ) + # Restore the shape + probs = probs.view(tensor_shape) ctx.save_for_backward(routing_map, intermediate_output) - ctx.num_tokens = logits.size(0) - ctx.num_experts = logits.size(1) + ctx.num_tokens = num_tokens + ctx.num_experts = num_experts ctx.use_pre_softmax = use_pre_softmax ctx.topk = topk ctx.scaling_factor = scaling_factor @@ -50,17 +58,23 @@ def forward( def backward(ctx, grad_probs, _): # pylint: disable=missing-function-docstring routing_map, intermediate_output = ctx.saved_tensors + # Save the shape of the grad_probs + tensor_shape = grad_probs.shape + # Adjust the shape of the grad_probs to 2D shape + grad_probs = grad_probs.contiguous().view(-1, tensor_shape[-1]) grad_logits = tex.fused_topk_with_score_function_bwd( ctx.num_tokens, ctx.num_experts, routing_map, intermediate_output, - grad_probs.contiguous(), + grad_probs, ctx.topk, ctx.use_pre_softmax, ctx.scaling_factor, ctx.score_function, ) + # Restore the shape + grad_logits = grad_logits.view(tensor_shape) return grad_logits, None, None, None, None, None, None, None @@ -124,6 +138,12 @@ def forward( score_function: str, ): # pylint: disable=missing-function-docstring + # Save the shape of the logits + tensor_shape = logits.shape + logits = logits.view(-1, tensor_shape[-1]) + # Get the metadata of the viewed logits + num_tokens = logits.size(0) + num_experts = logits.size(1) scores, routing_map, intermediate_output = tex.fused_score_for_moe_aux_loss_fwd( logits=logits, topk=topk, @@ -132,22 +152,28 @@ def forward( ctx.save_for_backward(intermediate_output) ctx.topk = topk ctx.score_function = score_function - ctx.num_tokens = logits.size(0) - ctx.num_experts = logits.size(1) + ctx.num_tokens = num_tokens + ctx.num_experts = num_experts return routing_map, scores @staticmethod def backward(ctx, _, grad_scores): # pylint: disable=missing-function-docstring intermediate_output = ctx.saved_tensors[0] + # Save the shape of the grad_scores + tensor_shape = grad_scores.shape + # Adjust the shape of the grad_scores to 2D shape + grad_scores = grad_scores.contiguous().view(-1, tensor_shape[-1]) grad_logits = tex.fused_score_for_moe_aux_loss_bwd( num_tokens=ctx.num_tokens, num_experts=ctx.num_experts, intermediate_output=intermediate_output, - grad_scores=grad_scores.contiguous(), + grad_scores=grad_scores, topk=ctx.topk, score_function=ctx.score_function, ) + # Restore the shape + grad_logits = grad_logits.view(tensor_shape) return grad_logits, None, None @@ -189,19 +215,21 @@ def forward( coeff: float, ): # pylint: disable=missing-function-docstring - num_tokens = probs.size(0) + num_rows = probs.size(0) + num_cols = probs.size(1) aux_loss, Const_buf = tex.fused_moe_aux_loss_fwd( probs=probs, tokens_per_expert=tokens_per_expert, total_num_tokens=total_num_tokens, - num_tokens=num_tokens, num_experts=num_experts, + num_rows=num_rows, + num_cols=num_cols, topk=topk, coeff=coeff, ) ctx.save_for_backward(Const_buf, tokens_per_expert) - ctx.num_tokens = num_tokens - ctx.num_experts = num_experts + ctx.num_rows = num_rows + ctx.num_cols = num_cols return aux_loss @staticmethod @@ -211,8 +239,8 @@ def backward(ctx, grad_aux_loss): grad_probs = tex.fused_moe_aux_loss_bwd( Const_buf=Const_buf, tokens_per_expert=tokens_per_expert, - num_tokens=ctx.num_tokens, - num_experts=ctx.num_experts, + num_rows=ctx.num_rows, + num_cols=ctx.num_cols, grad_aux_loss=grad_aux_loss, ) return grad_probs, None, None, None, None, None