From d504628fa449365f674496ef9f41b166fa61f6e0 Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Sun, 2 Jun 2024 16:13:26 -0500 Subject: [PATCH] [Kernel][ROCm][AMD] enable fused topk_softmax kernel for moe layer (#4927) This PR enables the fused topk_softmax kernel used in moe layer for HIP --- CMakeLists.txt | 8 ++-- Dockerfile.rocm | 1 + csrc/cuda_compat.h | 4 ++ csrc/moe/topk_softmax_kernels.cu | 27 +++++++---- setup.py | 2 +- .../layers/fused_moe/fused_moe.py | 46 ++++++++----------- 6 files changed, 45 insertions(+), 43 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f991af61d9b..a197063f3360 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -311,6 +311,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") add_dependencies(default _C) + message(STATUS "Enabling moe extension.") + add_dependencies(default _moe_C) + # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and # there are supported target arches. @@ -320,8 +323,3 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") add_dependencies(default _punica_C) endif() endif() - -if(VLLM_GPU_LANG STREQUAL "CUDA") - message(STATUS "Enabling moe extension.") - add_dependencies(default _moe_C) -endif() diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 9bfe8446a519..e30a2aaf3020 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -108,6 +108,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ && python3 setup.py install \ && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \ && cd .. diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index 5909e5eaf5e6..82e55613d915 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -19,8 +19,12 @@ #ifndef USE_ROCM #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) #else #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor(var, lane_mask, width) #endif #ifndef USE_ROCM diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 8c65f40fe836..6ba4fcdb3a3f 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -19,15 +19,22 @@ #include #include #include +#include "../cuda_compat.h" -#include -#include +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) namespace vllm { namespace moe { -static constexpr int WARP_SIZE = 32; - /// Aligned array type template < typename T, @@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); + thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW)); } // From this point, thread max in all the threads have the max within the row. @@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); + row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW); } // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables @@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); - int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); + float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); + int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); // We want lower indices to "win" in every thread so we break ties this way if (other_max > max_val || (other_max == max_val && other_expert < expert)) @@ -383,7 +390,7 @@ struct TopkConstants { static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); - static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; static constexpr int THREADS_PER_ROW = EXPERTS / VPT; static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; @@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f { static constexpr std::size_t MAX_BYTES_PER_LDG = 16; - static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; diff --git a/setup.py b/setup.py index d99fc050f6d8..f7d465b60c15 100644 --- a/setup.py +++ b/setup.py @@ -382,7 +382,7 @@ def _read_requirements(filename: str) -> List[str]: ext_modules = [] -if _is_cuda(): +if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) if not _is_neuron(): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bb7938b3715b..20a3c9f6f893 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -8,9 +8,9 @@ import triton import triton.language as tl +import vllm._moe_C as moe_kernels from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.utils import is_hip logger = init_logger(__name__) @@ -319,34 +319,26 @@ def fused_topk( M, _ = hidden_states.shape - if is_hip(): - # The MoE kernels are not yet supported on ROCm. - routing_weights = torch.softmax(gating_output, - dim=-1, - dtype=torch.float32) - topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) - else: - import vllm._moe_C as moe_kernels - - topk_weights = torch.empty(M, - topk, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, + topk_weights = torch.empty(M, topk, - dtype=torch.int32, + dtype=torch.float32, device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - moe_kernels.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), # TODO(woosuk): Optimize this. - ) - del token_expert_indicies # Not used. Will be used in the future. + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + moe_kernels.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + ) + del token_expert_indicies # Not used. Will be used in the future. + if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids