
[BUG fix] solve the CI bug of router fusion #1944

Open
wants to merge 5 commits into main

41 changes: 30 additions & 11 deletions tests/pytorch/test_fused_router.py
@@ -148,11 +148,15 @@ def run_comparison(
):
# Set some parameters
if score_function == "sigmoid":
logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 0.1
# Construct the special logits to avoid inf in the sigmoid function
logits = 1 + torch.arange(num_experts, device="cuda", dtype=dtype) * 0.01
logits = logits.unsqueeze(0).repeat(num_tokens, 1)
random_values = torch.rand(num_tokens, num_experts, device="cuda")
_, indices = torch.sort(random_values, dim=1)
logits = torch.gather(logits, 1, indices)
Comment on lines +151 to +156
Collaborator
Can you explain how this helps avoid infs? Was the problem that logits became too large, causing sigmoid to output 1.0 in multiple entries? Also, I don't understand the benefit from forcing logits >= 1. Previously we had logits >= 0 and sigmoid(0) = 0.5, so I wouldn't expect us to have had divide-by-zero problems.
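
For context, a minimal sketch (not part of this PR) of what the new construction appears to guarantee, assuming the intent is to give every row distinct, bounded logits so the sigmoid scores contain no duplicates and the normalized scores stay finite; the tensor names here are illustrative only:

```python
import torch

num_tokens, num_experts = 4, 8
# Strictly increasing base values in [1.0, 1.07], far from sigmoid saturation.
base = 1 + torch.arange(num_experts, dtype=torch.float32) * 0.01
logits = base.unsqueeze(0).repeat(num_tokens, 1)
# Shuffle each row independently, mirroring the rand/sort/gather in the test.
perm = torch.rand(num_tokens, num_experts).argsort(dim=1)
logits = torch.gather(logits, 1, perm)

scores = torch.sigmoid(logits)
# Sigmoid is strictly monotonic, so each row keeps num_experts distinct scores:
# top-k has a unique answer and the per-row sum is bounded away from zero.
assert all(row.unique().numel() == num_experts for row in scores)
norm = scores / scores.sum(dim=-1, keepdim=True)
assert torch.isfinite(norm).all()
```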

else:
logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
logits = logits.view(num_tokens, num_experts)
logits = logits.view(num_tokens, num_experts)
logits.requires_grad = True
if enable_bias and score_function == "sigmoid":
expert_bias = torch.arange(num_experts, device="cuda") * 0.1
@@ -210,7 +214,7 @@ def run_comparison(


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("group_topk", [None, 4])
@@ -241,7 +245,7 @@ def test_topk_sigmoid(


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("use_pre_softmax", [True, False])
@@ -272,12 +276,21 @@ def test_topk_softmax(


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"])
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
if score_function == "sigmoid":
# Construct the special logits to avoid inf in the sigmoid function
logits = 1 + torch.arange(num_experts, device="cuda", dtype=dtype) * 0.01
logits = logits.unsqueeze(0).repeat(num_tokens, 1)
random_values = torch.rand(num_tokens, num_experts, device="cuda")
_, indices = torch.sort(random_values, dim=1)
logits = torch.gather(logits, 1, indices)
else:
logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
logits = logits.view(num_tokens, num_experts)
logits.requires_grad = True

logits_clone = deepcopy(logits)
@@ -307,11 +320,17 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 32111])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4])
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
probs = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
# Construct the special probs to avoid inf in the sigmoid function
probs = 1 + torch.arange(num_experts, device="cuda", dtype=dtype) * 0.01
probs = probs.unsqueeze(0).repeat(num_tokens, 1)
random_values = torch.rand(num_tokens, num_experts, device="cuda")
_, indices = torch.sort(random_values, dim=1)
probs = torch.gather(probs, 1, indices)
probs = probs.view(num_tokens, num_experts)
probs.requires_grad = True

tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32)
@@ -375,6 +394,6 @@ def profile_topk_softmax(
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=32111, num_experts=256, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4)
test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4)
94 changes: 48 additions & 46 deletions transformer_engine/common/fused_router/fused_moe_aux_loss.cu
@@ -23,8 +23,8 @@ using CompType = double;
template <typename DataType, typename IndexType>
__global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
const IndexType* tokens_per_expert,
int total_num_tokens, int num_tokens,
int num_experts, int topk, float coeff,
int total_num_tokens, int num_experts,
int num_rows, int num_cols, int topk, float coeff,
DataType* aux_loss, float* Const_buf) {
#if __CUDA_ARCH__ >= 900
// Using cooperative_groups to manage the cluster
@@ -43,7 +43,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
extern __shared__ float shmem_aux_loss[];
CompType* aggregated_probs_per_expert = reinterpret_cast<CompType*>(shmem_aux_loss);
// Clear the shmem
for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
aggregated_probs_per_expert[i] = CompType(0);
}
__syncthreads();
@@ -54,11 +54,11 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* 2. reduce on the cluster
*/
// Loop: for all positions in each row
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
CompType tmp = CompType(0);
// Loop: for all rows that this warp is responsible for
for (int j = warp_id; j < num_tokens; j += warp_num) {
tmp += CompType(probs[j * num_experts + i]);
for (int j = warp_id; j < num_rows; j += warp_num) {
tmp += CompType(probs[j * num_cols + i]);
}
atomicAdd(&aggregated_probs_per_expert[i], tmp);
}
@@ -68,7 +68,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
for (int i = 1; i < block_num; i++) {
// Map the shared memory of the block i to the current block
CompType* dst_smem = reinterpret_cast<CompType*>(cluster.map_shared_rank(shmem_aux_loss, i));
for (int j = threadIdx.x; j < num_experts; j += blockDim.x) {
for (int j = threadIdx.x; j < num_cols; j += blockDim.x) {
atomicAdd(&aggregated_probs_per_expert[j], dst_smem[j]);
}
}
@@ -80,7 +80,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* In-place update on shmem
*/
if (block_id == 0) {
for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]);
}
__syncthreads();
@@ -90,7 +90,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* Section: Reduce to get the sum of aggregated_probs_per_expert
*/
CompType intermediate_result =
warp_reduce_on_shmem(aggregated_probs_per_expert, num_experts, sum, lane_id);
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
__syncwarp();

if (lane_id == 0) {
@@ -113,7 +113,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
CompType* aggregated_probs_per_expert = reinterpret_cast<CompType*>(shmem_aux_loss);

// Clear the shmem
for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
aggregated_probs_per_expert[i] = CompType(0);
}
__syncthreads();
@@ -122,11 +122,11 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* Section: Reduce the probs to the aggregated_probs_per_expert
*/
// Loop: for all positions in each row
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
CompType tmp = CompType(0);
// Loop: for all rows that this warp is responsible for
for (int j = warp_id; j < num_tokens; j += warp_num) {
tmp += CompType(probs[j * num_experts + i]);
for (int j = warp_id; j < num_rows; j += warp_num) {
tmp += CompType(probs[j * num_cols + i]);
}
atomicAdd(&aggregated_probs_per_expert[i], tmp);
}
@@ -136,7 +136,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* Section: aggregated_probs_per_expert * tokens_per_expert
* In-place update on shmem
*/
for (int i = threadIdx.x; i < num_experts; i += blockDim.x) {
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]);
}
__syncthreads();
@@ -146,7 +146,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
* Section: Reduce to get the sum of aggregated_probs_per_expert
*/
CompType intermediate_result =
warp_reduce_on_shmem(aggregated_probs_per_expert, num_experts, sum, lane_id);
warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id);
__syncwarp();

if (lane_id == 0) {
@@ -164,16 +164,17 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
template <typename DataType, typename IndexType>
void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
const IndexType* tokens_per_expert,
int total_num_tokens, int num_tokens,
int num_experts, int topk, float coeff,
int total_num_tokens, int num_experts, int num_rows,
int num_cols, int topk, float coeff,
DataType* aux_loss, float* Const_buf,
cudaStream_t stream) {
if (cuda::sm_arch(cuda::current_device()) >= 900) {
if (cuda::sm_arch(cuda::current_device()) >= 90) {
cudaLaunchConfig_t config = {0};
int cluster_size = 8;
config.gridDim = cluster_size;
config.blockDim = 1024;
config.dynamicSmemBytes = sizeof(CompType) * num_experts;
config.dynamicSmemBytes = sizeof(CompType) * num_cols;
config.stream = stream;

// Update the max cluster size based on the device
cudaOccupancyMaxPotentialClusterSize(
@@ -189,19 +190,19 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs,
config.attrs = attribute;

cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel<DataType, IndexType>, probs,
tokens_per_expert, total_num_tokens, num_tokens, num_experts, topk, coeff,
aux_loss, Const_buf);
tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk,
coeff, aux_loss, Const_buf);
} else {
size_t smem_size = sizeof(CompType) * num_experts;
size_t smem_size = sizeof(CompType) * num_cols;
fused_moe_aux_loss_forward_kernel<DataType, IndexType>
<<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_tokens,
num_experts, topk, coeff, aux_loss, Const_buf);
<<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts,
num_rows, num_cols, topk, coeff, aux_loss, Const_buf);
}
}

void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert,
int total_num_tokens, int num_tokens, int num_experts, int topk,
float coeff, Tensor& aux_loss, Tensor& Const_buf,
int total_num_tokens, int num_experts, int num_rows, int num_cols,
int topk, float coeff, Tensor& aux_loss, Tensor& Const_buf,
cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
probs.data.dtype, DataType,
@@ -210,79 +211,80 @@ void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_ex
fused_moe_aux_loss_forward_kernel_launcher<DataType, IndexType>(
reinterpret_cast<DataType*>(probs.data.dptr),
reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), total_num_tokens,
num_tokens, num_experts, topk, coeff, reinterpret_cast<DataType*>(aux_loss.data.dptr),
num_experts, num_rows, num_cols, topk, coeff,
reinterpret_cast<DataType*>(aux_loss.data.dptr),
reinterpret_cast<float*>(Const_buf.data.dptr), stream);););
}

template <typename DataType, typename IndexType>
__global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf,
const IndexType* tokens_per_expert,
int num_tokens, int num_experts,
DataType* grad_aux_loss, DataType* grad_probs) {
const IndexType* tokens_per_expert, int num_rows,
int num_cols, DataType* grad_aux_loss,
DataType* grad_probs) {
int global_warp_num = gridDim.x * blockDim.x / kThreadsPerWarp;
int global_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / kThreadsPerWarp;
int lane_id = threadIdx.x % kThreadsPerWarp;

// Loop: for all positions in each row
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
float C_coeff = Const_buf[0];
IndexType tokens_per_expert_i = tokens_per_expert[i];
double grad_aux_loss_value = static_cast<double>(grad_aux_loss[0]);
// Loop: for all rows
for (int j = global_warp_id; j < num_tokens; j += global_warp_num) {
grad_probs[j * num_experts + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value;
for (int j = global_warp_id; j < num_rows; j += global_warp_num) {
grad_probs[j * num_cols + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value;
}
}
}

template <typename DataType, typename IndexType>
void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf,
const IndexType* tokens_per_expert, int num_tokens,
int num_experts, DataType* grad_aux_loss,
const IndexType* tokens_per_expert, int num_rows,
int num_cols, DataType* grad_aux_loss,
DataType* grad_probs, cudaStream_t stream) {
// Meta data for the kernel
int block_size = 256;
int grid_size = (num_tokens + block_size - 1) / block_size;
int grid_size = (num_rows + block_size - 1) / block_size;
fused_moe_aux_loss_backward_kernel<DataType, IndexType><<<grid_size, block_size, 0, stream>>>(
Const_buf, tokens_per_expert, num_tokens, num_experts, grad_aux_loss, grad_probs);
Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs);
}

void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_expert,
int num_tokens, int num_experts, Tensor& grad_aux_loss,
int num_rows, int num_cols, Tensor& grad_aux_loss,
Tensor& grad_probs, cudaStream_t stream) {
TE_ROUTER_PROBS_TYPE_SWITCH_ALL(
grad_aux_loss.data.dtype, DataType,
TE_ROUTER_INDEX_TYPE_SWITCH_ALL(
tokens_per_expert.data.dtype, IndexType,
fused_moe_aux_loss_backward_kernel_launcher<DataType, IndexType>(
reinterpret_cast<float*>(Const_buf.data.dptr),
reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), num_tokens, num_experts,
reinterpret_cast<IndexType*>(tokens_per_expert.data.dptr), num_rows, num_cols,
reinterpret_cast<DataType*>(grad_aux_loss.data.dptr),
reinterpret_cast<DataType*>(grad_probs.data.dptr), stream);););
}

} // namespace transformer_engine

void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert,
int total_num_tokens, int num_tokens, int num_experts,
int topk, float coeff, NVTETensor aux_loss,
int total_num_tokens, int num_experts, int num_rows,
int num_cols, int topk, float coeff, NVTETensor aux_loss,
NVTETensor Const_buf, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_moe_aux_loss_forward);
using namespace transformer_engine;
fused_moe_aux_loss_forward(
*convertNVTETensorCheck(probs), *convertNVTETensorCheck(tokens_per_expert), total_num_tokens,
num_tokens, num_experts, topk, coeff, *convertNVTETensorCheck(aux_loss),
num_experts, num_rows, num_cols, topk, coeff, *convertNVTETensorCheck(aux_loss),
*convertNVTETensorCheck(Const_buf), stream);
}

void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf,
const NVTETensor tokens_per_expert, int num_tokens,
int num_experts, NVTETensor grad_aux_loss,
NVTETensor grad_probs, cudaStream_t stream) {
const NVTETensor tokens_per_expert, int num_rows,
int num_cols, NVTETensor grad_aux_loss, NVTETensor grad_probs,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_moe_aux_loss_backward);
using namespace transformer_engine;
fused_moe_aux_loss_backward(*convertNVTETensorCheck(Const_buf),
*convertNVTETensorCheck(tokens_per_expert), num_tokens, num_experts,
*convertNVTETensorCheck(tokens_per_expert), num_rows, num_cols,
*convertNVTETensorCheck(grad_aux_loss),
*convertNVTETensorCheck(grad_probs), stream);
}