Skip to content

[JAX] Select cuDNN backend for normalization #1946

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
void nvte_enable_cudnn_norm_fwd(bool enable);
void nvte_enable_cudnn_norm_bwd(bool enable);

/*! \brief Helper to query if cuDNN backend for normalization is enabled */
int nvte_use_cudnn_norm_fwd();
int nvte_use_cudnn_norm_bwd();

/*! \brief Control whether norm computes `gamma += 1.0` for zero-centered gamma
* in weight dtype. If set to false, it will compute in compute dtype.
*
Expand Down
10 changes: 10 additions & 0 deletions transformer_engine/common/normalization/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,13 @@ void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) {
NVTE_API_CALL(nvte_enable_zero_centered_gamma_in_weight_dtype);
transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable;
}

int nvte_use_cudnn_norm_fwd() {
NVTE_API_CALL(nvte_use_cudnn_norm_fwd);
return transformer_engine::normalization::use_cudnn_norm_fwd();
}

int nvte_use_cudnn_norm_bwd() {
NVTE_API_CALL(nvte_use_cudnn_norm_bwd);
return transformer_engine::normalization::use_cudnn_norm_bwd();
}
3 changes: 1 addition & 2 deletions transformer_engine/jax/cpp_extensions/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ def is_norm_fwd_cudnn_enabled(scaling_mode: ScalingMode) -> bool:
"""Retrieves whether CuDNN norm fwd is enabled."""
# MXFP8_1D_SCALING always uses CuDNN currently
return (
int(os.getenv("NVTE_NORM_FWD_USE_CUDNN", "0")) == 1
or scaling_mode == ScalingMode.MXFP8_1D_SCALING
transformer_engine_jax.use_cudnn_norm_fwd() or scaling_mode == ScalingMode.MXFP8_1D_SCALING
)


Expand Down
12 changes: 12 additions & 0 deletions transformer_engine/jax/csrc/extensions/normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
output_tensor.set_columnwise_data(static_cast<void *>(&temp), out_dtype, input_shape);
}

// Enable norm with cuDNN instead of TE kernels
nvte_enable_cudnn_norm_fwd(true);

// dummy tensor wrappers that will carry workspace size info later
TensorWrapper dummy_work_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
Expand Down Expand Up @@ -138,6 +141,9 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
colwise_scale_inv_buf->dimensions().back()});
}

// Enable norm with cuDNN instead of TE kernels
nvte_enable_cudnn_norm_fwd(true);

if (_norm_type == NVTE_Norm_Type::LayerNorm) {
NVTE_CHECK(w_dtype == convert_ffi_datatype_to_te_dtype(beta_buf.element_type()),
"gamma and beta must have the same data type.");
Expand Down Expand Up @@ -200,6 +206,9 @@ pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_s
TensorWrapper dummy_work_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;

// Enable norm with cuDNN instead of TE kernels
nvte_enable_cudnn_norm_bwd(true);

if (norm_type == NVTE_Norm_Type::LayerNorm) {
auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
Expand Down Expand Up @@ -270,6 +279,9 @@ Error_Type NormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type
auto workspace_shape = std::vector<size_t>{wkspace_size};
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);

// Enable norm with cuDNN instead of TE kernels
nvte_enable_cudnn_norm_bwd(true);

if (static_cast<NVTE_Norm_Type>(norm_type) == NVTE_Norm_Type::LayerNorm) {
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
Expand Down
1 change: 1 addition & 0 deletions transformer_engine/jax/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
m.def("use_cudnn_norm_fwd", &nvte_use_cudnn_norm_fwd);

pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
Expand Down