From b66f82dd575aff9f0add2f27ff5582634064642e Mon Sep 17 00:00:00 2001
From: Phuong Nguyen
Date: Fri, 11 Jul 2025 07:00:33 -0700
Subject: [PATCH 1/5] select cuDNN backend for normalization by default

Signed-off-by: Phuong Nguyen
---
 .../include/transformer_engine/normalization.h     |  4 ++++
 transformer_engine/common/normalization/common.cpp | 10 ++++++++++
 .../jax/cpp_extensions/normalization.py            |  2 +-
 .../jax/csrc/extensions/normalization.cpp          | 12 ++++++++++++
 transformer_engine/jax/csrc/extensions/pybind.cpp  |  1 +
 5 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/common/include/transformer_engine/normalization.h b/transformer_engine/common/include/transformer_engine/normalization.h
index 9c194e9da2..ce06eb57cd 100644
--- a/transformer_engine/common/include/transformer_engine/normalization.h
+++ b/transformer_engine/common/include/transformer_engine/normalization.h
@@ -149,6 +149,10 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
 void nvte_enable_cudnn_norm_fwd(bool enable);
 void nvte_enable_cudnn_norm_bwd(bool enable);
 
+/*! \brief Helper to query if cuDNN backend for normalization is enabled */
+int nvte_use_cudnn_norm_fwd();
+int nvte_use_cudnn_norm_bwd();
+
 /*! \brief Control whether norm computes `gamma += 1.0` for zero-centered gamma
  *  in weight dtype. If set to false, it will compute in compute dtype.
  *
diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp
index 9df81a917f..ba4a3e6a7c 100644
--- a/transformer_engine/common/normalization/common.cpp
+++ b/transformer_engine/common/normalization/common.cpp
@@ -539,3 +539,13 @@ void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) {
   NVTE_API_CALL(nvte_enable_zero_centered_gamma_in_weight_dtype);
   transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable;
 }
+
+int nvte_use_cudnn_norm_fwd() {
+  NVTE_API_CALL(nvte_use_cudnn_norm_fwd);
+  return transformer_engine::normalization::use_cudnn_norm_fwd();
+}
+
+int nvte_use_cudnn_norm_bwd() {
+  NVTE_API_CALL(nvte_use_cudnn_norm_bwd);
+  return transformer_engine::normalization::use_cudnn_norm_bwd();
+}
diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py
index 3b563efbd0..5f9f5375b2 100644
--- a/transformer_engine/jax/cpp_extensions/normalization.py
+++ b/transformer_engine/jax/cpp_extensions/normalization.py
@@ -71,7 +71,7 @@ def is_norm_fwd_cudnn_enabled(scaling_mode: ScalingMode) -> bool:
     """Retrieves whether CuDNN norm fwd is enabled."""
     # MXFP8_1D_SCALING always uses CuDNN currently
     return (
-        int(os.getenv("NVTE_NORM_FWD_USE_CUDNN", "0")) == 1
+        transformer_engine_jax.use_cudnn_norm_fwd()
         or scaling_mode == ScalingMode.MXFP8_1D_SCALING
     )
diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp
index b07404eb78..fc5602247d 100644
--- a/transformer_engine/jax/csrc/extensions/normalization.cpp
+++ b/transformer_engine/jax/csrc/extensions/normalization.cpp
@@ -36,6 +36,9 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
     output_tensor.set_columnwise_data(static_cast<void *>(&temp), out_dtype, input_shape);
   }
 
+  // Enable norm with cuDNN instead of TE kernels
+  nvte_enable_cudnn_norm_fwd(true);
+
   // dummy tensor wrappers that will carry workspace size info later
   TensorWrapper dummy_work_tensor;
   auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
@@ -138,6 +141,9 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
                                     colwise_scale_inv_buf->dimensions().back()});
   }
 
+  // Enable norm with cuDNN instead of TE kernels
+  nvte_enable_cudnn_norm_fwd(true);
+
   if (_norm_type == NVTE_Norm_Type::LayerNorm) {
     NVTE_CHECK(w_dtype == convert_ffi_datatype_to_te_dtype(beta_buf.element_type()),
                "gamma and beta must have the same data type.");
@@ -200,6 +206,9 @@ pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_s
   TensorWrapper dummy_work_tensor;
   auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
 
+  // Enable norm with cuDNN instead of TE kernels
+  nvte_enable_cudnn_norm_bwd(true);
+
   if (norm_type == NVTE_Norm_Type::LayerNorm) {
     auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
     auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
@@ -270,6 +279,9 @@ Error_Type NormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type
   auto workspace_shape = std::vector<size_t>{wkspace_size};
   auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
 
+  // Enable norm with cuDNN instead of TE kernels
+  nvte_enable_cudnn_norm_bwd(true);
+
   if (static_cast<NVTE_Norm_Type>(norm_type) == NVTE_Norm_Type::LayerNorm) {
     auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
     auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp
index afbeb644c1..2f8681d785 100644
--- a/transformer_engine/jax/csrc/extensions/pybind.cpp
+++ b/transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -84,6 +84,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
   m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
   m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
   m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);
+  m.def("use_cudnn_norm_fwd", &nvte_use_cudnn_norm_fwd);
 
   pybind11::enum_<DType>(m, "DType", pybind11::module_local())
       .value("kByte", DType::kByte)
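A minimal C++ sketch of the round trip this patch completes: the two
pre-existing enable toggles plus the new query helpers declared above. The
wrapper name prefer_cudnn_norms is made up for illustration; the four nvte_*
calls are the API from normalization.h.

    #include <transformer_engine/normalization.h>

    // Hypothetical wrapper: turn the cuDNN backend on for both passes,
    // then confirm the setting through the new query helpers.
    void prefer_cudnn_norms() {
      nvte_enable_cudnn_norm_fwd(true);
      nvte_enable_cudnn_norm_bwd(true);
      if (nvte_use_cudnn_norm_fwd() && nvte_use_cudnn_norm_bwd()) {
        // Subsequent LayerNorm/RMSNorm forward/backward calls will
        // dispatch to cuDNN kernels.
      }
    }

On the Python side, the forward query is exposed as
transformer_engine_jax.use_cudnn_norm_fwd(), which is_norm_fwd_cudnn_enabled()
now consults instead of reading NVTE_NORM_FWD_USE_CUDNN from the environment.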
From da94b8bbef3be67fbf03bea3f4654c3e2c337e1d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 11 Jul 2025 19:14:35 +0000
Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/jax/cpp_extensions/normalization.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py
index 5f9f5375b2..f959c77e57 100644
--- a/transformer_engine/jax/cpp_extensions/normalization.py
+++ b/transformer_engine/jax/cpp_extensions/normalization.py
@@ -71,8 +71,7 @@ def is_norm_fwd_cudnn_enabled(scaling_mode: ScalingMode) -> bool:
     """Retrieves whether CuDNN norm fwd is enabled."""
     # MXFP8_1D_SCALING always uses CuDNN currently
     return (
-        transformer_engine_jax.use_cudnn_norm_fwd()
-        or scaling_mode == ScalingMode.MXFP8_1D_SCALING
+        transformer_engine_jax.use_cudnn_norm_fwd() or scaling_mode == ScalingMode.MXFP8_1D_SCALING
     )

From 2e739532e43ce8ee5dcb6095c9445778d0312070 Mon Sep 17 00:00:00 2001
From: Phuong Nguyen
Date: Wed, 16 Jul 2025 08:05:57 -0700
Subject: [PATCH 3/5] TE/PyT to use nvte_use_cudnn_norm_fwd

Signed-off-by: Phuong Nguyen
---
 transformer_engine/pytorch/csrc/extensions/normalization.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp
index 88404a2e1a..a8dedcdbcd 100644
--- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp
@@ -102,7 +102,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
     // Always used fused kernel for FP8 delayed scaling
     force_unfused_kernel = false;
   } else if (IsMXFP8Quantizers(quantizer.ptr())) {
-    if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) {
+    if (nvte_use_cudnn_norm_fwd()){
       // cuDNN MXFP8 kernel requires full tile
       force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
     }
@@ -263,7 +263,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
     // Always used fused kernel for FP8 delayed scaling
     force_unfused_kernel = false;
   } else if (IsMXFP8Quantizers(quantizer.ptr())) {
-    if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) {
+    if (nvte_use_cudnn_norm_fwd()){
       // cuDNN MXFP8 kernel requires full tile
       force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
     }
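The condition patch 3 swaps in encodes a real constraint: the cuDNN MXFP8
normalization kernel only handles shapes that tile evenly into 128x128 blocks,
so other shapes must fall back to the unfused path. A standalone restatement
of that rule follows, with a hypothetical helper name and signature; in TE
the logic is inlined in layernorm_fwd/rmsnorm_fwd.

    #include <cstddef>
    #include <transformer_engine/normalization.h>

    // Mirror of the dispatch rule above: when the cuDNN backend is active,
    // force the unfused kernel unless both dimensions are multiples of 128;
    // otherwise leave the flag as previously computed (passed in as `prior`).
    bool force_unfused_mxfp8_norm(std::size_t N, std::size_t H, bool prior) {
      if (nvte_use_cudnn_norm_fwd()) {
        return (N % 128 != 0) || (H % 128 != 0);
      }
      return prior;
    }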
From aa903b64a75ee34698c36809c1d1fd59f015034f Mon Sep 17 00:00:00 2001
From: Phuong Nguyen
Date: Wed, 16 Jul 2025 08:09:21 -0700
Subject: [PATCH 4/5] update docs

Signed-off-by: Phuong Nguyen
---
 .../include/transformer_engine/normalization.h | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/common/include/transformer_engine/normalization.h b/transformer_engine/common/include/transformer_engine/normalization.h
index ce06eb57cd..d3efb36279 100644
--- a/transformer_engine/common/include/transformer_engine/normalization.h
+++ b/transformer_engine/common/include/transformer_engine/normalization.h
@@ -142,15 +142,22 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
                       NVTETensor workspace, const int multiprocessorCount,
                       const bool zero_centered_gamma, cudaStream_t stream);
 
-/*! \brief Helper to enable cuDNN backend for normalization
+/*! \brief Helper to enable cuDNN backend for LayerNorm/RMSNorm forward pass
  *
  *  \param[in] bool Enable if True
  */
 void nvte_enable_cudnn_norm_fwd(bool enable);
+
+/*! \brief Helper to enable cuDNN backend for LayerNorm/RMSNorm backward pass
+ *
+ *  \param[in] bool Enable if True
+ */
 void nvte_enable_cudnn_norm_bwd(bool enable);
 
-/*! \brief Helper to query if cuDNN backend for normalization is enabled */
+/*! \brief Whether cuDNN backend is enabled for LayerNorm/RMSNorm forward pass */
 int nvte_use_cudnn_norm_fwd();
+
+/*! \brief Whether cuDNN backend is enabled for LayerNorm/RMSNorm backward pass */
 int nvte_use_cudnn_norm_bwd();
 
 /*! \brief Control whether norm computes `gamma += 1.0` for zero-centered gamma
@@ -163,6 +170,8 @@ int nvte_use_cudnn_norm_bwd();
  */
 void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable);
 
+/*! \brief Enumeration of normalization types used in TE
+ */
 enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
 
 #ifdef __cplusplus
From 73dc07e58985b86e537070f443b425e2064ba808 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 16 Jul 2025 15:11:30 +0000
Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/pytorch/csrc/extensions/normalization.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp
index a8dedcdbcd..28c1ec84b6 100644
--- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp
@@ -102,7 +102,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
     // Always used fused kernel for FP8 delayed scaling
     force_unfused_kernel = false;
   } else if (IsMXFP8Quantizers(quantizer.ptr())) {
-    if (nvte_use_cudnn_norm_fwd()){
+    if (nvte_use_cudnn_norm_fwd()) {
       // cuDNN MXFP8 kernel requires full tile
       force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
     }
@@ -263,7 +263,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
     // Always used fused kernel for FP8 delayed scaling
     force_unfused_kernel = false;
   } else if (IsMXFP8Quantizers(quantizer.ptr())) {
-    if (nvte_use_cudnn_norm_fwd()){
+    if (nvte_use_cudnn_norm_fwd()) {
       // cuDNN MXFP8 kernel requires full tile
       force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
     }