From cf1774c06cb463a38aafc17e7d75eb1322c23ee5 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 4 Jun 2025 21:30:27 +0000 Subject: [PATCH 01/30] added XLA FFI custom op for TE/common nvte_cublas_gemm Signed-off-by: Alp Dener started GemmPrimitive, abstract done Signed-off-by: Alp Dener gemm custom op working with BF16, needs testing for FP8/MXFP8 Signed-off-by: Alp Dener converted TE GEMM API to use ScaledTensor and added os ENV flag to use TE GEMM under general gemm() call Signed-off-by: Alp Dener BF16 tests passing, FP8 tests should be passing but contracting_dims has a scoping issue Signed-off-by: Alp Dener fp8 tests passing for E4M3, getting CUBLAS_STATUS_NOT_SUPPORTED for E5M2 Signed-off-by: Alp Dener updated GEMM API to use separate LHS and RHS quantizers instead of a QuantizerSet Signed-off-by: Alp Dener new GemmPrimitive passing all Dense tests Signed-off-by: Alp Dener import cleanup and reverted code chunk movement Signed-off-by: Alp Dener removed unused .transpose() implementations from ScaledTensors Signed-off-by: Alp Dener all custom call tests passing on Hopper, GEMM-related tests cover both GemmPrimitive and native JAX impl Signed-off-by: Alp Dener removed direct calls to GemmPrimitive.enabled() from outside of cpp_extensions Signed-off-by: Alp Dener removed unused changes to ScaledTensor classes and debug prints Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 188 +++-- tests/jax/test_layer.py | 15 +- .../jax/cpp_extensions/activation.py | 75 +- transformer_engine/jax/cpp_extensions/gemm.py | 696 ++++++++++++++++-- transformer_engine/jax/cpp_extensions/misc.py | 9 +- .../jax/cpp_extensions/normalization.py | 7 + .../jax/cpp_extensions/quantization.py | 65 +- transformer_engine/jax/csrc/extensions.h | 3 + .../jax/csrc/extensions/gemm.cpp | 173 +++++ transformer_engine/jax/csrc/extensions/misc.h | 9 + .../jax/csrc/extensions/pybind.cpp | 6 + transformer_engine/jax/dense.py | 104 +-- transformer_engine/jax/layernorm_dense.py | 115 +-- transformer_engine/jax/layernorm_mlp.py | 208 +++--- .../jax/quantize/dequantizer.py | 17 + transformer_engine/jax/quantize/tensor.py | 63 +- 16 files changed, 1349 insertions(+), 404 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index f689bce6a5..5a59d113d5 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -2,13 +2,15 @@ # # See LICENSE for license information. 
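+# NOTE: These tests cover both the TE cuBLAS GEMM custom call (GemmPrimitive) and the native
+# JAX GEMM path, toggled per test through the NVTE_JAX_CUSTOM_CALLS_RE environment variable
+# (see use_jax_dot_for_gemm below).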
+import os +import operator +from functools import reduce +from typing import Union + +import pytest import jax import jax.numpy as jnp -import pytest from jax import jit, value_and_grad -from functools import reduce -from typing import Union -import operator from utils import ( assert_allclose, @@ -30,7 +32,6 @@ from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax.quantize import ( - DelayedScaleQuantizer, ScaledTensor, ScaledTensor1x, ScaledTensor2x, @@ -44,6 +45,9 @@ from transformer_engine.jax.activation import activation from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense +from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x + +from transformer_engine_jax import is_non_nt_fp8_gemm_supported GEMM_CASES = [ (256, 256, 512), @@ -59,10 +63,14 @@ is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) supported_scaling_modes = [] +supported_fp8_gemm_layouts = [] """ Find supported scaling modes""" if is_fp8_supported: supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING) supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING) + supported_fp8_gemm_layouts.append("NT") + if is_non_nt_fp8_gemm_supported(): + supported_fp8_gemm_layouts += ["TT", "TN", "NN"] if is_mxfp8_supported: supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING) @@ -73,7 +81,7 @@ def is_shape_supported_by_mxfp8(input_shape): input_shape = input_shape.values[0] ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) return True - except: + except AssertionError: # get_scale_shapes will raise an exception if the shape is not supported return False @@ -147,6 +155,13 @@ def assert_dequantized_grouped_scaled_tensor( pytest.fail("a must be a GroupedScaledTensor object") +def use_jax_dot_for_gemm(enabled=False): + if enabled: + os.environ['NVTE_JAX_CUSTOM_CALLS_RE']='^(?!GemmPrimitive$).+$' + elif 'NVTE_JAX_CUSTOM_CALLS_RE' in os.environ: + os.environ.pop('NVTE_JAX_CUSTOM_CALLS_RE') + + ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)] ALL_ACTIVATION_TYPES = [ ("gelu",), @@ -624,15 +639,15 @@ def test_quantize_bitwise( ): key = jax.random.PRNGKey(0) - input = jax.random.uniform(key, input_shape, in_dtype) + inp = jax.random.uniform(key, input_shape, in_dtype) te_quantizer, jax_quantizer = QuantizerFactory.create( n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout ) - jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) + jax_output = _jax_quantize(inp, quantizer=jax_quantizer, flatten_axis=flatten_axis) - te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis) + te_output = tex.quantize(inp, quantizer=te_quantizer, flatten_axis=flatten_axis) assert_bitwise_scaled_tensors(te_output, jax_output) @@ -710,7 +725,7 @@ def test_quantize_dbias( pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") key = jax.random.PRNGKey(0) - input = jax.random.uniform(key, input_shape, in_dtype) + inp = jax.random.uniform(key, input_shape, in_dtype) jax_quantizer, te_quantizer = QuantizerFactory.create( n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout @@ -718,15 +733,15 @@ def test_quantize_dbias( te_output, te_dbias = jit( lambda input: tex.quantize_dbias( - input, quantizer=te_quantizer, flatten_axis=flatten_axis + inp, 
quantizer=te_quantizer, flatten_axis=flatten_axis ) - )(input) + )(inp) jax_output, jax_dbias = jit( lambda input: _jax_quantize_dbias( - input, quantizer=jax_quantizer, flatten_axis=flatten_axis + inp, quantizer=jax_quantizer, flatten_axis=flatten_axis ) - )(input) + )(inp) assert_bitwise_scaled_tensors(te_output, jax_output) @@ -880,7 +895,10 @@ def _generate_gemm_input(self, m, n, k, data_layout): @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) - def test_gemm_bf16(self, m, n, k, data_layout): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_gemm_bf16(self, m, n, k, data_layout, with_jax_gemm): + use_jax_dot_for_gemm(enabled=with_jax_gemm) + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) primitive_out = tex.gemm(x, w, contracting_dims) @@ -890,23 +908,56 @@ def test_gemm_bf16(self, m, n, k, data_layout): @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) + @pytest_parametrize_wrapper( + "lhs_q_dtype,rhs_q_dtype", + [ + (jnp.float8_e4m3fn, jnp.float8_e4m3fn), # fprop GEMM + (jnp.float8_e4m3fn, jnp.float8_e5m2), # wgrad GEMM + (jnp.float8_e5m2, jnp.float8_e4m3fn), # dgrad GEMM + ] + ) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) - def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout): - x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) + @pytest_parametrize_wrapper("data_layout", supported_fp8_gemm_layouts) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_gemm_fp8(self, m, n, k, lhs_q_dtype, rhs_q_dtype, scaling_mode, data_layout, + with_jax_gemm): + use_jax_dot_for_gemm(enabled=with_jax_gemm) + + lhs, rhs, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) + quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False + scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, bwd_dtype=jnp.float8_e5m2, + is_2x2x=False + ) + lhs_quantizer = ( + quantizer_set.x + if lhs_q_dtype == jnp.float8_e4m3fn + else quantizer_set.dgrad + ) + rhs_quantizer = ( + quantizer_set.kernel + if rhs_q_dtype == jnp.float8_e4m3fn + else quantizer_set.dgrad ) + primitive_out = tex.gemm( - x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set + lhs, rhs, lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, + contracting_dims=contracting_dims ) - ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) + ref_out = self._ref_gemm_with_jnp_dot(lhs, rhs, data_layout) - assert_allclose(primitive_out, ref_out, dtype=q_dtype) + test_q_dtype = ( + jnp.float8_e5m2 + if jnp.float8_e5m2 in (lhs_q_dtype, rhs_q_dtype) + else jnp.float8_e4m3fn + ) + assert_allclose(primitive_out, ref_out, dtype=test_q_dtype) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - def test_dense_grad_bf16(self, m, n, k): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_dense_grad_bf16(self, m, n, k, with_jax_gemm): + use_jax_dot_for_gemm(enabled=with_jax_gemm) + data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -932,9 +983,11 @@ def ref_func(x, w, data_layout): @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) 
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm): + use_jax_dot_for_gemm(enabled=with_jax_gemm) + data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -956,7 +1009,8 @@ def ref_func(x, w, bias, data_layout): value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True + scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, bwd_dtype=jnp.float8_e5m2, + is_2x2x=True ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 @@ -969,10 +1023,10 @@ def ref_func(x, w, bias, data_layout): x, w, bias, data_layout ) - assert_allclose(primitive_out, ref_out, dtype=q_dtype) - assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype) - assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype) - assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype) + assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn) + assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) + assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2) + assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2) @pytest.fixture(name="random_inputs") @@ -996,19 +1050,14 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan class TestFusedDense: @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) - @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) - def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm): """ Test layernorm_dense VJP Rule """ - # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode in ( - ScalingMode.DELAYED_TENSOR_SCALING, - ScalingMode.CURRENT_TENSOR_SCALING, - ): - pytest.skip("E5M2 is not supported in normalization with TE Backend!") + use_jax_dot_for_gemm(enabled=with_jax_gemm) # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False @@ -1025,8 +1074,8 @@ def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): quantizer_set = QuantizerFactory.create_set( scaling_mode=scaling_mode, - fwd_dtype=q_dtype, - bwd_dtype=q_dtype, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, is_2x2x=True, ) @@ -1072,32 +1121,27 @@ def ref_func(x, w, gamma, beta): prim_beta_grad, ) = value_n_grad_prim_func(x, w, gamma, beta) - assert_allclose(prim_out, ref_out, dtype=q_dtype) - assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) - assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype) - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype) + assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) + assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_w_grad, ref_w_grad, 
dtype=jnp.float8_e5m2) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2) if beta is not None: - assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype) + assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @pytest.mark.parametrize("use_bias", [True, False]) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( - self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias + self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm ): """ Test layernorm_mlp VJP Rule """ - # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode in ( - ScalingMode.DELAYED_TENSOR_SCALING, - ScalingMode.CURRENT_TENSOR_SCALING, - ): - pytest.skip("E5M2 is not supported in normalization with TE Backend!") + # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False @@ -1123,8 +1167,8 @@ def test_layernorm_mlp_grad( quantizer_sets = QuantizerFactory.create_set( n_quantizer_sets=2, scaling_mode=scaling_mode, - fwd_dtype=q_dtype, - bwd_dtype=q_dtype, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, is_2x2x=True, ) @@ -1149,26 +1193,26 @@ def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ) ) - def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2): + def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): + dim_nums = ((1,), (0,)), ((), ()) + ln_out = _ref_jax_norm_impl( x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None ) - # TODO: replace gemm with jnp.dot - linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,))) + + linear_1_out = jax.lax.dot_general(ln_out, kernel_1, dim_nums) if use_bias: bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape linear_1_out += jnp.reshape(bias_1, bias_1_shape) - x = _jax_act_lu(linear_1_out, activation_type) - linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,))) + act_out = _jax_act_lu(linear_1_out, activation_type) + + linear_2_out = jax.lax.dot_general(act_out, kernel_2, dim_nums) if use_bias: bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape linear_2_out += jnp.reshape(bias_2, bias_2_shape) - return linear_2_out - - def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): - return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2)) + return jnp.mean(linear_2_out) value_n_grad_prim_func = value_and_grad(prim_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6)) @@ -1193,18 +1237,18 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ref_bias_2_grad, ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) - assert_allclose(prim_out, ref_out, dtype=q_dtype) + assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) - assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype) + assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2) if use_bias: - assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype) + assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2) - assert_allclose(prim_kernel_1_grad, 
ref_kernel_1_grad, dtype=q_dtype) + assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2) if use_bias: - assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype) + assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2) - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype) - assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) # E5M2 * E5M2 is not supported diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index d59e130530..ab79e2eae4 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -487,6 +487,13 @@ class BaseTester: runner = BaseRunner + def use_jax_dot_for_gemm(self, enabled=False): + """Enable/disable TE custom cuBLAS GEMM primitive.""" + if enabled: + os.environ['NVTE_JAX_CUSTOM_CALLS_RE']='^(?!GemmPrimitive$).+$' + elif 'NVTE_JAX_CUSTOM_CALLS_RE' in os.environ: + os.environ.pop('NVTE_JAX_CUSTOM_CALLS_RE') + def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" QuantizeConfig.finalize() # Ensure FP8 disabled. @@ -499,16 +506,20 @@ def test_backward(self, data_shape, dtype, attrs): @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) - def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): + @pytest.mark.parametrize("with_jax_gemm", (False, True)) + def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe, with_jax_gemm): """Test forward with fp8 enabled""" + self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.initialize(fp8_recipe=fp8_recipe) self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) QuantizeConfig.finalize() @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) - def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): + @pytest.mark.parametrize("with_jax_gemm", (False, True)) + def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe, with_jax_gemm): """Test backward with fp8 enabled""" + self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.initialize(fp8_recipe=fp8_recipe) self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) QuantizeConfig.finalize() diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index ce66bba3cf..4d0a4c1bbd 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -985,6 +985,7 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + noop_scaled_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -993,6 +994,7 @@ def act_lu( Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. 
Returns: If quantizer is None: @@ -1037,6 +1039,10 @@ def act_lu( is_outer=True, ) out = out.reshape(output_shape) + if noop_scaled_tensor: + return ScaledTensorFactory.create_2x( + out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype + ) return out if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -1090,6 +1096,7 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1100,6 +1107,7 @@ def quantize_dact_dbias( activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). is_dbias: If True, compute bias gradient. Defaults to True. quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: Tuple[ScaledTensor, jnp.ndarray]: A tuple containing: @@ -1113,13 +1121,42 @@ def quantize_dact_dbias( f" {x.shape} and act_len {act_len}" ) + scale = jnp.empty((), jnp.float32) + act_type_id = ActivationEnum[activation_type] PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive - if not PrimitiveClass.enabled(): + if ( + not PrimitiveClass.enabled() + or (quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE) + ): return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) - # TE/common does not support colwise-only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: - return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) + if quantizer is None: + output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( + dz, + x, + scale, + # outputs float32 for dbias accumulation + out_dtype=(jnp.float32 if is_dbias else x.dtype), + # default value for no scaling, TE/common ignore this value when scale is unset + scaling_mode=ScalingMode.NO_SCALING.value, + is_2x=False, # unused + scale_dtype=jnp.float32, # unused + is_dbias=False, + act_enum=act_type_id, + act_len=act_len, + is_outer=True, + ) + output = output.astype(x.dtype) + dbias = None + if is_dbias: + dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) + + if noop_scaled_tensor: + return ScaledTensorFactory.create_2x( + output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype, + ), dbias + + return output, dbias # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): @@ -1145,31 +1182,6 @@ def quantize_dact_dbias( if war_output is not None: return war_output - scale = jnp.empty((), jnp.float32) - - act_type_id = ActivationEnum[activation_type] - - if quantizer is None: - output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( - dz, - x, - scale, - # outputs float32 for dbias accumulation - out_dtype=(jnp.float32 if is_dbias else x.dtype), - # default value for no scaling, TE/common ignore this value when scale is unset - scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, # unused - scale_dtype=jnp.float32, # unused - is_dbias=False, - act_enum=act_type_id, - act_len=act_len, - is_outer=True, - ) - dbias = None - if is_dbias: - dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) - return output.astype(x.dtype), dbias - if quantizer.scaling_mode == 
ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = dact_lu( @@ -1183,7 +1195,7 @@ def quantize_dact_dbias( ) return out, dbias - if isinstance(quantizer, DelayedScaleQuantizer): + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale # TE/common dact_dbias_quantize does not support gated act yet @@ -1243,6 +1255,7 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + noop_scale_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1252,6 +1265,7 @@ def dact_lu( x: Input tensor that was used in forward pass. activation_type: Type of activation function that was applied. quantizer: Optional quantizer for FP8 quantization of the output gradient. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: The gradient of the activation with respect to the input. @@ -1262,5 +1276,6 @@ def dact_lu( activation_type=activation_type, is_dbias=False, quantizer=quantizer, + noop_scaled_tensor=noop_scale_tensor, ) return output diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d3c23015c1..d553da10f3 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -3,19 +3,25 @@ # See LICENSE for license information. """JAX te modules""" -from typing import Tuple, Sequence, Union, Dict -from functools import partial, reduce -import operator import math +import operator +from collections.abc import Iterable +from typing import Tuple, Sequence, Union +from functools import partial, reduce, lru_cache + import jax import jax.numpy as jnp +from jax import dtypes +from jax.sharding import NamedSharding, PartitionSpec + +import transformer_engine_jax as tex from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize - from ..quantize import ( ScaledTensor, + ScaledTensor2x, GroupedScaledTensor1x, ScalingMode, Quantizer, @@ -25,9 +31,18 @@ QuantizeLayout, noop_quantizer_set, ) +from ..sharding import get_padded_spec -__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"] +__all__ = [ + "gemm", + "grouped_gemm", + "gemm_uses_jax_dot", + "sanitize_dims", + "get_non_contracting_dims", + "transpose_contracting_dims", + "is_gemm_with_all_layouts_supported", +] num_cublas_streams = get_num_compute_streams() @@ -35,16 +50,517 @@ def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if get_device_compute_capability(0) >= 90: + if tex.get_device_compute_capability(0) >= 90: return 33_554_432 return 4_194_304 -def is_gemm_with_all_layouts_supported() -> False: +def is_gemm_with_all_layouts_supported() -> bool: """Return True if using blackwell, False otherwise.""" return get_device_compute_capability(0) >= 100 +def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]: + """Convert relative (negative) indexes to absolute dimension numbers.""" + dims_ = dims if isinstance(dims, Iterable) else (dims, ) + if len(dims_) == 0: + return dims_ + return tuple( ndim + dim if dim < 0 else dim for dim in dims_ ) + + +def get_non_contracting_dims(ndim, contracting_dims): + """Return a 
tuple of dimensions not included in the contracting dimensions.""" + contracting_dims = sanitize_dims(ndim, contracting_dims) + return tuple(dim for dim in range(ndim) if dim not in contracting_dims) + + +def transpose_contracting_dims(ndim, contracting_dims): + """Compute the new dimension numbers for contracting dimensions after a transpose.""" + contracting_dims = sanitize_dims(ndim, contracting_dims) + return tuple(ndim - i - 1 for i in contracting_dims)[::-1] + + +def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool: + lhs, rhs, e4m3, e5m2, e8m0 = map( + dtypes.canonicalize_dtype, + ( + lhs_dtype, + rhs_dtype, + jnp.float8_e4m3fn, + jnp.float8_e5m2, + jnp.uint8 # replace with jnp.float8_e8m0 when JAX/XLA merges support + ) + ) + + # MXFP8 GEMM needs both operands to be MXFP8 (uint8 for now until JAX merges float8_e8m0) + if lhs is e8m0 and rhs is e8m0: + return True + + # FP8 GEMM supports (e4m3 x e4m3), (e4m3 x e5m2) and (e5m2 x e4m3) + if (lhs is e4m3 and rhs in (e4m3, e5m2)) or (lhs in (e4m3, e5m2) and rhs is e4m3): + return True + + # Any other combination of data types is not supported + return False + + +def _get_gemm_layout( + operand_ndims: Tuple[int, int], + contracting_dims: Tuple[Sequence[int], Sequence[int]] +) -> Tuple[bool, bool]: + lhs_contracting, rhs_contracting = map(sanitize_dims, operand_ndims, contracting_dims) + lhs_is_transposed = operand_ndims[0] - 1 not in lhs_contracting + rhs_is_transposed = operand_ndims[1] - 1 in rhs_contracting + return lhs_is_transposed, rhs_is_transposed + + +def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims): + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) + lhs_contracting_dims, rhs_contracting_dims = map( + sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims + ) + + lhs_q = lhs + rhs_q = rhs + if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: + lhs_q = lhs_quantizer.quantize( + lhs, + is_rowwise=True, + is_colwise=False, + flatten_axis=( + max(lhs_contracting_dims) + 1 + if lhs_is_transposed + else min(lhs_contracting_dims) + ), + ) + if lhs_is_transposed: + # Manually update data layout and columnwise flag to avoid transposing already + # transposed data + lhs_q.data_layout = "T" + lhs_q.is_colwise = True + + if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: + rhs_q = rhs_quantizer.quantize( + rhs, + is_rowwise=True, + is_colwise=False, + flatten_axis=( + min(rhs_contracting_dims) + if rhs_is_transposed + else max(rhs_contracting_dims) + 1 + ), + ) + if not rhs_is_transposed: + # Manually update data layout and columnwise flag to avoid transposing already + # transposed data + rhs_q.data_layout = "T" + rhs_q.is_colwise = True + + return lhs_q, rhs_q + + +class GemmPrimitive(BasePrimitive): + """ + Primitive for cuBLAS GEMM + """ + + name = "te_gemm_ffi" + multiple_results = True + impl_static_args = (6, 7, 8, 9, 10, 11, 12) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype, + contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad): + # Sanity-check operand layouts and types + operand_ndims = (lhs.ndim, rhs.ndim) + ( + lhs_contracting_dims, + rhs_contracting_dims, + ) = map(sanitize_dims, operand_ndims, contracting_dims) + lhs_contracting_size, rhs_contracting_size = map( + lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), + (lhs.shape, rhs.shape), + (lhs_contracting_dims, 
rhs_contracting_dims) + ) + assert lhs_contracting_size == rhs_contracting_size, ( + "cuBLAS GEMM operands have incompatible contracting dimensions: " + f"{lhs.shape} @ idx {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}." + ) + + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims) + if scaling_mode != ScalingMode.NO_SCALING: + assert _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype), ( + "cuBLAS GEMM quantized operands have incompatible data types: " + f"{lhs.dtype} x {rhs.dtype}." + ) + assert lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0, ( + "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." + ) + if ( + scaling_mode != ScalingMode.MXFP8_1D_SCALING + and not tex.is_non_nt_fp8_gemm_supported() + ): + assert not lhs_is_transposed and rhs_is_transposed, ( + "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) " + "require non-transposed LHS and transposed RHS operands " + "(`contracting_dims=((-1, ), (-1, ))`)." + ) + + # Determine output shape and dtype + assert dtypes.canonicalize_dtype(out_dtype).itemsize > 1, ( + "cuBLAS GEMM custom op does not support 8-bit quantized output types." + ) + lhs_non_contracting_shape, rhs_non_contracting_shape = map( + lambda shape, dims: [ shape[dim] for dim in range(len(shape)) if dim not in dims ], + (lhs.shape, rhs.shape), + (lhs_contracting_dims, rhs_contracting_dims) + ) + out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape) + output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + + # Validate bias + bias_shape = (0, ) + bias_dtype = out_dtype + if fuse_bias: + expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape) + if not grad: + assert bias.size == expected_bias_size, ( + "cuBLAS GEMM bias tensor has incorrect shape, " + f"expected ({expected_bias_size}, ) but found {bias.shape}." + ) + assert bias.dtype == out_dtype, ( + "cuBLAS GEMM bias tensor has incorrect data type, " + f"expected {bias_dtype} but found {bias.dtype}." + ) + bias_shape = bias.shape + else: + bias_shape = rhs_non_contracting_shape + bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype) + + # Validate pre-GeLU + pre_gelu_shape = (0, ) + pre_gelu_dtype = out_dtype + if fuse_gelu: + pre_gelu_shape = out_shape + if grad: + pre_gelu_ndim = len(pre_gelu_shape) + assert ( + gelu_input.ndim == pre_gelu_shape + and all(gelu_input.shape[i] == pre_gelu_shape[i] for i in range(pre_gelu_ndim)) + ), ( + "cuBLAS GEMM pre-GeLU tensor has incorrect shape, " + f"expected {pre_gelu_shape} but found {gelu_input.shape}." + ) + assert gelu_input.dtype == out_dtype, ( + "cuBLAS GEMM pre-GeLU tensor has incorrect data type, " + f"expected {pre_gelu_dtype} but found {gelu_input.dtype}." 
+ ) + pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) + + # Need extra workspace for swizzled scale factors + lhs_swizzle_size = 0 + rhs_swizzle_size = 0 + swizzle_dtype = jnp.uint8 # replace with jnp.float8_e8m0 when JAX merges support + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + lhs_swizzle_size = lhs_scale_inv.size + rhs_swizzle_size = rhs_scale_inv.size + lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size, ), dtype=swizzle_dtype) + rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size, ), dtype=swizzle_dtype) + + # Declare cuBLAS workspace + workspace = jax.core.ShapedArray(shape=(get_cublas_workspace_size_bytes(), ), + dtype=jnp.uint8) + + return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace + + @staticmethod + def outer_abstract(*args, **kwargs): + outputs = GemmPrimitive.abstract(*args, **kwargs) + return outputs[:-3] # discard workspace arrays + + @staticmethod + def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype, + contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad): + del out_dtype + lhs_aval, _, rhs_aval, *_ = ctx.avals_in + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) + lhs_transposed, rhs_transposed = _get_gemm_layout((lhs_aval.ndim, rhs_aval.ndim), + (lhs_cdims, rhs_cdims)) + + args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input) + kwargs = { + "scaling_mode" : int(scaling_mode.value), + "lhs_axis_boundary" : max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), + "rhs_axis_boundary" : min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, + "lhs_transposed" : lhs_transposed, + "rhs_transposed" : rhs_transposed, + "fuse_bias" : fuse_bias, + "fuse_gelu" : fuse_gelu, + "grad" : grad, + } + + operand_output_aliases = {} + if fuse_bias and not grad: + operand_output_aliases.update({ 4 : 1 }) # bias <-> bias_grad + if fuse_gelu and grad: + operand_output_aliases.update({ 5 : 2 }) # gelu_input <-> pre_gelu_out + + return jax.ffi.ffi_lowering( + GemmPrimitive.name, + operand_output_aliases=operand_output_aliases, + )(ctx, *args, **kwargs) + + @staticmethod + def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype, contracting_dims, + scaling_mode, fuse_bias, fuse_gelu, grad): + outputs = GemmPrimitive.inner_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + ) + return outputs[:-3] # discard workspace arrays + + @staticmethod + def batcher(batched_args, batch_dims, out_dtype, contracting_dims, scaling_mode, fuse_bias, + fuse_gelu, grad): + assert GemmPrimitive.outer_primitive is not None + lhs, _, rhs, *_ = batched_args + lhs_bdims, *_ = batch_dims + + # Output is batched like LHS only if LHS is batched and RHS is not + out_bdims = lhs_bdims if lhs.ndim > 2 and rhs.ndim == 2 else (None, ) + bias_bdims = (None, ) # Bias is never batched + pre_gelu_bdims = (None, ) # Pre-GeLU output, if exists, is batched like GEMM output + if fuse_gelu and not grad: + pre_gelu_bdims = out_bdims + + return ( + GemmPrimitive.outer_primitive.bind( + *batched_args, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + ), + (out_bdims, bias_bdims, pre_gelu_bdims) + ) + + @staticmethod + def 
infer_sharding_from_operands(out_dtype, contracting_dims, scaling_mode, fuse_bias, + fuse_gelu, grad, mesh, arg_infos, result_infos): + del out_dtype, scaling_mode, result_infos + + # Check contracting dimensions + lhs_spec, _, rhs_spec, *_ = map(get_padded_spec, arg_infos) + operand_ndims = (len(lhs_spec), len(rhs_spec)) + lhs_contracting_dims, rhs_contracting_dims = map( + sanitize_dims, operand_ndims, contracting_dims + ) + lhs_contracting_specs, rhs_contracting_specs = map( + lambda specs, dims: [ specs[dim] for dim in dims if specs[dim] is not None], + (lhs_spec, rhs_spec), + (lhs_contracting_dims, rhs_contracting_dims) + ) + assert len(lhs_contracting_specs) <= 1 and len(rhs_contracting_specs) <= 1, ( + "cuBLAS GEMM operands can have only one sharded contracting dimension." + ) + lhs_contracting_spec, rhs_contracting_spec = map( + lambda spec: None if len(spec) == 0 else spec[0], + (lhs_contracting_specs, rhs_contracting_specs) + ) + assert lhs_contracting_spec == rhs_contracting_spec, ( + "cuBLAS GEMM operands must have the same sharding in contracting dimensions." + ) + + # Sanity check leading dimensions, allow for simultaneous batch and sequence sharding + lhs_leading_dims, rhs_leading_dims = map( + get_non_contracting_dims, operand_ndims, (lhs_contracting_dims, rhs_contracting_dims) + ) + lhs_leading_specs, rhs_leading_specs = map( + lambda specs, dims: [ specs[dim] for dim in dims if specs[dim] is not None ], + (lhs_spec, rhs_spec), + (lhs_leading_dims, rhs_leading_dims) + ) + assert len(lhs_leading_specs) <= 1 and len(rhs_leading_specs) <= 1, ( + "cuBLAS GEMM operands cannot have more than one sharded leading dimensions. This error " + "usually means a sequence-parallel operand was not all-gathered before the GEMM op." + ) + + # Determine output sharding + lhs_leading_spec, rhs_leading_spec = map( + lambda spec: None if len(spec) == 0 else spec[0], + (lhs_leading_specs, rhs_leading_specs) + ) + out_spec = (lhs_leading_spec, rhs_leading_spec) + if operand_ndims[0] > 2 and operand_ndims[1] == 2: + # Restore batch dimensions/sharding to the output + out_spec = (*lhs_leading_specs, rhs_leading_spec) + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) + + # Bias gradient sharding inherits the RHS contracting spec + bias_spec = (None, ) + if fuse_bias and grad: + bias_spec = (rhs_contracting_spec, ) + bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + + # Pre-GeLU sharding matches output sharding + pre_gelu_spec = (None, ) + if fuse_gelu and not grad: + pre_gelu_spec = out_spec + pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_spec)) + + return (out_sharding, bias_sharding, pre_gelu_sharding) + + @staticmethod + def partition(out_dtype, contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad, + mesh, arg_infos, result_infos): + out_shardings = GemmPrimitive.infer_sharding_from_operands( + out_dtype, contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad, + mesh, arg_infos, result_infos + ) + output_spec = out_shardings[0].spec + + # Operand shardings are already guarded with asserts so leave them unchanged here + lhs_spec, _, rhs_spec, *_ = map(get_padded_spec, arg_infos) + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec)) + rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec)) + + # Any distributed scales (e.g. 
MXFP8) need to be gathered + scale_sharding = NamedSharding(mesh, PartitionSpec(None)) + + # Bias has to be sharded same as the trailing dimension of the GEMM output + bias_spec = (None, ) + if fuse_bias and not grad: + bias_spec = (output_spec[-1], ) + bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + + # Pre-GeLU output has to be sharded same as the GEMM output + pre_gelu_spec = (None, ) + if fuse_gelu and grad: + pre_gelu_spec = output_spec + pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_spec)) + + arg_shardings = ( + lhs_sharding, + scale_sharding, + rhs_sharding, + scale_sharding, + bias_sharding, + pre_gelu_sharding, + ) + + return mesh, GemmPrimitive.impl, out_shardings, arg_shardings + + +register_primitive(GemmPrimitive) + + +@lru_cache(maxsize=1) +def gemm_uses_jax_dot() -> bool: + """Check if the GEMM call directs to the TE custom cuBLAS call or native JAX dot.""" + return not GemmPrimitive.enabled() + + +def _te_gemm( + lhs: Union[jax.Array, ScaledTensor], + rhs: Union[jax.Array, ScaledTensor], + bias: jax.Array = None, + gelu_input: jax.Array = None, + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1, ), (-2, )), + fuse_bias: bool = False, + fuse_gelu: bool = False, + grad: bool = False, +) -> Tuple[jax.Array, ...]: + # Prepare non-quantized GEMM operands + lhs_data = lhs + rhs_data = rhs + lhs_scale_inv = jnp.empty(0, dtype=jnp.float32) + rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) + scaling_mode = ScalingMode.NO_SCALING + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) + lhs_contracting_dims, rhs_contracting_dims = map( + sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims + ) + + # Quantize operands (if necessary) + lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) + + # Extract GEMM custom op inputs from quantized operands + if isinstance(lhs_q, ScaledTensor): + assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, ( + "cuBLAS GEMM with quantized LHS and non-quantized RHS operands requires a valid " + "`Quantizer` object to quantize the RHS operand." + ) + if isinstance(lhs_q, ScaledTensor2x): + # Contracting dimensions for a ScaledTensor2x is interpreted relative to the row-wise + # shape. Since we have access to both row-wise and column-wise tensors, we always + # choose the one that avoids transposing LHS in the GEMM kernel to comply with the + # NT-layout restriction for FP8 GEMM on Hopper. + lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() + scaling_mode = lhs_q.scaling_mode + lhs_data = lhs_q.data + lhs_scale_inv = lhs_q.scale_inv + if lhs_q.data_layout == "T": + lhs_contracting_dims = transpose_contracting_dims(lhs_q.ndim, lhs_contracting_dims) + + if isinstance(rhs_q, ScaledTensor): + assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( + "cuBLAS GEMM with non-quantized LHS and quantized RHS operands requires a valid " + "`Quantizer` object to quantize the LHS operand." + ) + if isinstance(rhs_q, ScaledTensor2x): + # Contracting dimensions for a ScaledTensor2x is interpreted relative to the row-wise + # shape. Since we have access to both row-wise and column-wise tensors, we always + # choose the one that avoids transposing LHS in the GEMM kernel to comply with the + # NT-layout restriction for FP8 GEMM on Hopper. 
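+ # For the RHS this means keeping the row-wise copy when RHS already contracts on its
+ # last dimension, and otherwise using the column-wise (transposed) copy, so that
+ # tensor-scaled FP8 GEMMs still receive a 'T'-layout RHS operand.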
+ rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() + assert rhs_q.scaling_mode == lhs_q.scaling_mode, ( + "cuBLAS GEMM quantized operands have mismatched scaling types, " + f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." + ) + rhs_data = rhs_q.data + rhs_scale_inv = rhs_q.scale_inv + if rhs_q.data_layout == "T": + rhs_contracting_dims = transpose_contracting_dims(rhs_q.ndim, rhs_contracting_dims) + + # Dummy empties for bias and gelu + out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype + if bias is None or not (fuse_bias and not grad): + bias = jnp.empty(0, dtype=out_dtype) + if gelu_input is None or not (fuse_gelu and grad): + gelu_input = jnp.empty(0, dtype=out_dtype) + + return GemmPrimitive.outer_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + gelu_input, + out_dtype=out_dtype, + contracting_dims=(lhs_contracting_dims, rhs_contracting_dims), + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + ) + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM @@ -221,11 +737,8 @@ def _shape_normalization(x, dimension_numbers, already_transposed: bool = False) def _calculate_remaining_shape(shape, contracting_dims): - return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims) - - -def _transpose_contract_dims(ndim, contracting_dims): - return tuple(ndim - i - 1 for i in contracting_dims)[::-1] + contracting_dims_ = sanitize_dims(len(shape), contracting_dims) + return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims_) # Apply jit to guarantee correctness of FP8 GEMM. @@ -233,9 +746,9 @@ def _transpose_contract_dims(ndim, contracting_dims): def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums if lhs.data_layout == "T": - lhs_contract = _transpose_contract_dims(lhs.data.ndim, lhs_contract) + lhs_contract = transpose_contracting_dims(lhs.data.ndim, lhs_contract) if rhs.data_layout == "T": - rhs_contract = _transpose_contract_dims(rhs.data.ndim, rhs_contract) + rhs_contract = transpose_contracting_dims(rhs.data.ndim, rhs_contract) dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) @@ -306,12 +819,12 @@ def _jax_gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - quantizer_set: Dict["str", Quantizer] = noop_quantizer_set, + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, ) -> jnp.ndarray: """ FP8 GEMM via JAX """ - dim_nums = (contracting_dims, ((), ())) def _jax_gemm_fp8_impl(lhs, rhs): @@ -331,65 +844,116 @@ def _jax_gemm_fp8_impl(lhs, rhs): raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") - if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): - return _jax_gemm_fp8_impl(lhs, rhs) - - if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor): - if quantizer_set != noop_quantizer_set: - assert type(quantizer_set.x) is type(quantizer_set.kernel) - (((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums - lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1 - rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1 - # Call JAX quantization so that XLA can do pattern matching (QDQ --> FP8 gemm) - lhs_q = quantizer_set.x.quantize( - lhs, - is_rowwise=lhs_is_rowwise, - is_colwise=not lhs_is_rowwise, - ) - rhs_q = 
quantizer_set.kernel.quantize( - rhs, - is_rowwise=rhs_is_rowwise, - is_colwise=not rhs_is_rowwise, - ) - return _jax_gemm_fp8_impl(lhs_q, rhs_q) + # Quantize operands (if necessary) + lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) - if ( - isinstance(lhs, jnp.ndarray) - and isinstance(rhs, jnp.ndarray) - and quantizer_set == noop_quantizer_set - ): - return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) + if isinstance(lhs_q, ScaledTensor) or isinstance(rhs_q, ScaledTensor): + assert isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor), ( + "Both LHS and RHS must be quantized (or have valid quantizers) for FP8 GEMM." + ) + return _jax_gemm_fp8_impl(lhs_q, rhs_q) - raise NotImplementedError("Not supporting multiplication of ScaledTensor and jnp.array") + return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) def gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - quantizer_set: QuantizerSet = noop_quantizer_set, -) -> jnp.ndarray: - """General matrix multiplication with optional quantization. - - Args: - lhs: First input matrix. - rhs: Second input matrix. - contracting_dims: Tuple of two sequences representing the contracting dimensions. - The first sequence represents the contracting dimensions of the first matrix, - and the second sequence represents the contracting dimensions of the second matrix. - quantizer_set: Set of quantizers for FP8 quantization of the output. - If None, no quantization is applied and the output has the same dtype as the inputs. - - Returns: - If quantizer_set is None: - The matrix multiplication result. - Shape: (M, N) - Dtype: Same as input dtype - If quantizer_set is provided: - A ScaledTensor containing the quantized matrix multiplication result. + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (-2,)), + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, + **kwargs, +) -> Tuple[jnp.ndarray, ...]: + r"""General matrix multiplication with optional quantization. + + Parameters + ---------- + lhs: Union[jax.Array, ScaledTensor] + Left-hand side operand in the matrix multiplication. + rhs: Union[jax.Array, ScaledTensor] + Right-hand side operand in the matrix multiplication. + lhs_quantizer: Quantizer, default = None + Object for down-casting the LHS operand for quantized GEMM. + rhs_quantizer: Quantizer, default = None + Object for down-casting the RHS operand for quantized GEMM. + contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (-2, )) + Tuple of two sequences representing the contracting dimensions. The first sequence + represents the contracting dimensions of the LHS operand, and the second sequence + represents the contracting dimensions of the RHS operand. + bias: jax.Array, default = None + Optional additive bias term, required for forward GEMM with bias fusion. Only supported + with TE's custom call to cuBLAS GEMM. + gelu_input: jax.Array, default = None + Pre-GeLU output from forward GEMM, required for backward/grad GEMM with dGeLU fusion. Only + supported with TE's custom call to cuBLAS GEMM. + fuse_bias: bool, default = False + Enable bias addition in forward GEMM or bias gradient in backward GEMM. Only supported with + TE's custom call to cuBLAS GEMM. + fuse_gelu: bool, default = False + Enable GeLU activation in forward GEMM or GeLU gradient in backward GEMM. 
Only supported + with TE's custom call to cuBLAS GEMM. + grad: bool, default = False + Flag for switching bias and GeLU fusions from forward to backward mode. Only supported with + TE's custom call to cuBLAS GEMM. + + Returns + ------- + jax.Array: + Result of the operation. For TE's custom call to cuBLAS GEMM, this result can include the + GeLU application when `fuse_gelu=True` and `grad=False`, the GeLU gradient contribution + when `fuse_gelu=True` and `grad=True`, and the additive bias when `fuse_bias=True` and + `grad=False`. + Optional[jax.Array]: + Bias gradient when `fuse_bias=True` and `grad=True`. Only supported with TE's custom call + to cuBLAS GEMM. + Optional[jax.Array]: + Pre-GeLU GEMM output when `fuse_gelu=True` and `grad=False`. This is required as an input + to `_te_gemm()` with `fuse_gelu=True` and `grad=True` in the backward pass in order to + compute the GeLU contribution to the gradient. Only supported with TE's custom call to + cuBLAS GEMM. """ + # Try to get LHS and RHS quantizers from a quantizer set for backward compatibility + if lhs_quantizer is None or rhs_quantizer is None: + quantizer_set = kwargs.get("quantizer_set", None) + if quantizer_set is not None: + lhs_quantizer = quantizer_set.x + rhs_quantizer = quantizer_set.kernel + + # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled + if gemm_uses_jax_dot(): + assert kwargs.get("bias", None) is None and not kwargs.get("fuse_bias", False), ( + "TE GEMM was invoked with bias fusion options that are not supported by the " + "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " + "GEMM primitive is disabled." + ) + assert kwargs.get("gelu_input", None) is None and not kwargs.get("fuse_gelu", False), ( + "TE GEMM was invoked with GeLU fusion options that are not supported by the " + "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " + "GEMM primitive is disabled." + ) + return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) + + outputs = _te_gemm( + lhs, + rhs, + lhs_quantizer=lhs_quantizer, + rhs_quantizer=rhs_quantizer, + contracting_dims=contracting_dims, + **kwargs + ) - return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set) + # Discard empty outputs + fuse_bias = kwargs.get("fuse_bias", False) + fuse_gelu = kwargs.get("fuse_gelu", False) + grad = kwargs.get("grad", False) + clean_outputs = outputs[0] # first output is the final result and is never empty + if (fuse_bias and grad) or (fuse_gelu and not grad): + clean_outputs = (outputs[0], ) + if fuse_bias and grad: # only return bias gradient if it exists + clean_outputs += (outputs[1], ) + if fuse_gelu and not grad: # only return pre-GeLU output if it exists + clean_outputs += (outputs[2], ) + return clean_outputs def grouped_gemm( diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 87f1c1913a..94dfaa45a4 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -198,14 +198,19 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to calculate dbias separately. This function checks if the workaround should be applied. 
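+ Returns False when no quantizer is provided. The workaround also triggers when a 2x
+ tensor-scaling quantizer gets forced down to 1x row-wise quantization in
+ _quantize_dbias_impl, not only when the quantizer itself is row-wise.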
""" + if quantizer is None: + return False + arch_l_100 = False for local_gpu_id in range(len(jax.local_devices())): if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100: arch_l_100 = True break + # _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE, + # but this fails when bias fusion is turned on with arch < 100. + force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() return ( - quantizer is not None - and quantizer.q_layout == QuantizeLayout.ROWWISE + (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE) and arch_l_100 and is_dbias ) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 07ebb33114..f2829c3a58 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -1276,6 +1276,7 @@ def normalization_fwd( epsilon: float, norm_type: str, quantizer: Optional[Quantizer], + noop_scaled_tensor: bool = False, ): """Common wrapper for normalization forward pass. @@ -1292,6 +1293,7 @@ def normalization_fwd( - 'layernorm': Layer normalization - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: A tuple containing: @@ -1319,6 +1321,11 @@ def normalization_fwd( else: raise ValueError(f"{norm_type=} is not supported.") + if quantizer is None and noop_scaled_tensor: + return ScaledTensorFactory.create_2x( + output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype + ), mu, rsigma + return output, mu, rsigma diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 11b3cdc2a3..e84fb3a976 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -538,11 +538,12 @@ def _jax_quantize( def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1): - assert flatten_axis < 0 + sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis + assert sum_axis < dx.ndim, "Flatten axis out of bounds!" dtype = dtype or dx.dtype dbias = jnp.sum( dx.astype(jnp.float32), - axis=tuple(range(dx.ndim + flatten_axis)), + axis=tuple(range(sum_axis)), keepdims=False, ) return dbias.astype(dtype) @@ -568,6 +569,7 @@ def _quantize_dbias_impl( is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -577,24 +579,28 @@ def _quantize_dbias_impl( quantizer is not None ), "quantizer must be provided if dq_dtype is provided" + # Early-exit for non-quantized call dq_dtype = dq_dtype or x.dtype - - PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive - if not PrimitiveClass.enabled(): + if quantizer is None: + dbias = None if is_dbias: - return _jax_quantize_dbias( - x, - quantizer=quantizer, - dq_dtype=dq_dtype, + dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) + if noop_scaled_tensor: + # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor() + # always works. 
+ return ScaledTensorFactory.create_2x( + x, None, x, None, ScalingMode.NO_SCALING, dq_dtype=x.dtype, data_layout="NN", flatten_axis=flatten_axis, - ) - return ( - _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), - None, - ) + ), dbias + return x, dbias - # TE/common doesn't support colwise only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: + # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, + # fall back on the native-JAX quantize implementation + PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive + if ( + quantizer.q_layout == QuantizeLayout.COLWISE + or not PrimitiveClass.enabled() + ): if is_dbias: return _jax_quantize_dbias( x, @@ -606,9 +612,8 @@ def _quantize_dbias_impl( _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), None, ) - scale = jnp.empty((), jnp.float32) - # TE/common dbias_quantize does not support 1x on arch < 100 + # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100 if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): out, _ = _quantize_dbias_impl( x=x, @@ -620,29 +625,23 @@ def _quantize_dbias_impl( dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias - if quantizer is None: - if is_dbias: - return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) - return x, None - + scale = jnp.empty((), jnp.float32) if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Globally reduce amax across all devices for current scaling so we have a single global scale. # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this # until the tensor is dequantized (e.g. in the GEMM). amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32) scale = compute_scale_from_amax(amax, quantizer.q_dtype) - - if isinstance(quantizer, DelayedScaleQuantizer): + elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale - is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) # It is faster to use 1x quantization for tensor scaling + is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) force_1x_quantization = ( quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() and is_1x_kernel_supported ) - q_layout = quantizer.q_layout if force_1x_quantization: q_layout = QuantizeLayout.ROWWISE @@ -666,7 +665,7 @@ def _quantize_dbias_impl( is_outer=True, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise - if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): + if force_1x_quantization: colwise_scale_inv = rowwise_scale_inv if q_layout == QuantizeLayout.ROWWISE: @@ -698,6 +697,7 @@ def quantize( x: jnp.ndarray, quantizer: Quantizer, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -707,6 +707,8 @@ def quantize( quantizer: Quantizer for FP8 quantization of the output. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. + noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer + is None. Returns: A ScaledTensor containing the quantized input tensor. 
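The noop_scaled_tensor flag added above lets downstream callers treat quantized and unquantized operands through a single code path. A minimal sketch of that calling pattern, assuming a 2x quantizer or no quantizer at all (the cast_operand helper is hypothetical and not part of this change):

from transformer_engine.jax import cpp_extensions as tex

def cast_operand(x, quantizer=None):
    # With noop_scaled_tensor=True the result always exposes rowwise/colwise views,
    # even when quantizer is None and the data stays in its original precision.
    casted = tex.quantize(x, quantizer=quantizer, noop_scaled_tensor=True)
    return casted.get_rowwise_tensor(), casted.get_colwise_tensor()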
@@ -715,6 +717,7 @@ def quantize( x, quantizer=quantizer, flatten_axis=flatten_axis, + noop_scaled_tensor=noop_scaled_tensor, ) return out @@ -724,6 +727,7 @@ def quantize_dbias( quantizer: Quantizer, is_dbias: bool = True, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -734,6 +738,8 @@ def quantize_dbias( is_dbias: If True, compute bias gradient. Defaults to True. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. + noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when + quantizer is None. Returns: A tuple containing: @@ -743,7 +749,8 @@ def quantize_dbias( Shape: (K,) or empty if is_dbias is False. """ return _quantize_dbias_impl( - dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis + dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis, + noop_scaled_tensor=noop_scaled_tensor, ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index aa257abe95..553534a205 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -119,6 +119,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right); +// GEMM +XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); + // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index d57d4682ca..02822a3f0a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -4,12 +4,16 @@ * See LICENSE for license information. ************************************************************************/ #include "transformer_engine/gemm.h" +#include "transformer_engine/swizzle.h" #include +#include +#include #include "../extensions.h" #include "common/util/cuda_runtime.h" #include "common/util/system.h" +#include "common/util/string.h" #include "xla/ffi/api/c_api.h" #define MXFP8_BLOCK_SIZE 32 @@ -17,6 +21,175 @@ namespace transformer_engine { namespace jax { +std::tuple> xla_buffer_to_nvte_gemm_operand( + cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv, + JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) { + // Set tensor data with collapsed 2D shape + auto buffer_dims = buffer.dimensions(); + std::vector input_shape = {product(buffer_dims, 0, axis_boundary), + product(buffer_dims, axis_boundary, buffer_dims.size())}; + auto input_dtype = convert_ffi_datatype_to_te_dtype(buffer.element_type()); + TensorWrapper input(get_nvte_scaling_mode(scaling_mode)); + + if (rowwise) { + input.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); + } else { + input.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); + } + + // Set scaling factor for quantized tensors + if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { + NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands."); + NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM."); + + std::vector scale_shape(scale_inv.dimensions().begin(), scale_inv.dimensions().end()); + auto scale_dtype = (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) + ? 
DType::kFloat8E8M0 + : convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + if (rowwise) { + input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } else { + input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } + + // Swizzle scaling factors for MXFP8 + if (is_block_scaling(scaling_mode)) { + // Get the swizzle buffer + NVTE_CHECK(swizzled_scale_inv->element_count() > 0, + "Missing swizzled inverse scale buffer in the JAX primitive."); + auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + auto swizzled_scale_inv_dtype = convert_ffi_datatype_to_te_dtype( + swizzled_scale_inv->element_type()); + NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1, + "Inverse scale factors need to have an 8-bit data type."); + + // Create tensor to hold swizzled scale factor + TensorWrapper output(NVTE_MXFP8_1D_SCALING); + if (rowwise) { + output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + scale_shape); + } else { + output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + scale_shape); + } + + // Launch swizzle kernel + nvte_swizzle_scaling_factors(input.data(), output.data(), stream); + + // Set swizzled scales into the input tensor + if (rowwise) { + input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + scale_shape); + } else { + input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + scale_shape); + } + } + } + + return std::make_tuple(std::move(input), input_shape); +} + +Error_Type GemmFFI( + cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Result_Type output, + Result_Type bias_grad, Result_Type pre_gelu_out, Result_Type lhs_swizzle, + Result_Type rhs_swizzle, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, + int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, + bool fuse_bias, bool fuse_gelu, bool grad) { + // Operands (this includes swizzling MXFP8 scaling factors) + // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when + // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) + bool always_rowwise = ( + scaling_mode == JAXX_Scaling_Mode::NO_SCALING || ( + is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); + bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; + bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed; + auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( + stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise); + auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand( + stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); + + // Output tensor + std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], + (rhs_transposed) ? 
rhs_shape[0] : rhs_shape[1]}; + auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), "cuBLAS GEMM output buffer size is incorrect, " + "expected ", out_.numel(), " elements ", to_string_like(out_shape), " but got ", + output->element_count(), " elements ", to_string_like(output->dimensions())); + + // Bias input to forward pass or bias gradient output from backward pass + void* bias_ptr = nullptr; + std::vector bias_shape = {0}; + DType bias_dtype = out_dtype; + if (fuse_bias) { + if (!grad) { + NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(), + "Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad"); + } + bias_ptr = bias_grad->untyped_data(); + bias_shape.at(0) = bias_grad->dimensions().front(); + bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type()); + } + auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + + // Pre-GeLU output from forward pass or input to backward pass + void* pre_gelu_ptr = nullptr; + std::vector pre_gelu_shape = {0}; + DType pre_gelu_dtype = out_dtype; + if (gelu_input.element_count() > 0) { + if (grad) { + NVTE_CHECK(pre_gelu_out->untyped_data() == gelu_input.untyped_data(), + "Missing operand-output aliasing in GemmPrimitive: gelu_input <-> pre_gelu_out"); + } + pre_gelu_ptr = pre_gelu_out->untyped_data(); + pre_gelu_shape = {product(pre_gelu_out->dimensions(), 0, pre_gelu_out->dimensions().size() - 1), + static_cast(pre_gelu_out->dimensions().back())}; + pre_gelu_dtype = convert_ffi_datatype_to_te_dtype(pre_gelu_out->element_type()); + } + auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype); + + // cuBLAS workspace + std::vector workspace_shape = {static_cast(workspace->element_count())}; + auto workspace_ = TensorWrapper(workspace->untyped_data(), workspace_shape, DType::kByte); + + // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), + rhs_transposed, lhs_transposed, grad, workspace_.data(), false, false, + num_math_sm, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Ret() // output + .Ret() // bias_grad + .Ret() // pre_gelu_out + .Ret() // lhs_swizzled + .Ret() // rhs_swizzled + .Ret() // workspace + .Attr("scaling_mode") + .Attr("lhs_axis_boundary") + .Attr("rhs_axis_boundary") + .Attr("lhs_transposed") + .Attr("rhs_transposed") + .Attr("fuse_bias") + .Attr("fuse_gelu") + .Attr("grad"), + FFI_CudaGraph_Traits); + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 03194e9d72..b4ed3399d9 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -47,6 +47,15 @@ enum class JAXX_Scaling_Mode : int64_t { CURRENT_TENSOR_SCALING = 3, }; +inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) { + 
return (mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING + || mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING); +} + +inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) { + return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING); +} + static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { switch (mode) { case JAXX_Scaling_Mode::NO_SCALING: diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 2d7801cc20..afbeb644c1 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -55,6 +55,11 @@ pybind11::dict Registrations() { pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler)); + // GEMM + dict["te_gemm_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); + // Grouped GEMM dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), @@ -78,6 +83,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); + m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 8834f4f73c..0fe8d1bb9a 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -48,7 +48,7 @@ def dense( Transformed output tensor """ # Remove when tex.quantize() can handle quantizer=None - if quantizer_set == noop_quantizer_set: + if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot(): x = with_sharding_constraint_by_logical_axes(x, input_axes) output = tex.gemm(x, kernel, contracting_dims) if bias is not None: @@ -90,39 +90,50 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, Returns: Tuple of (output, context) for backward pass """ - x_contracting_dims, k_contracting_dims = contracting_dims - - flatten_axis_x = -len(x_contracting_dims) - flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + x_cdims, k_cdims = map( + tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims + ) + x_is_transposed = x.ndim - 1 not in x_cdims + k_is_transposed = kernel.ndim - 1 in k_cdims + assert not x_is_transposed and not k_is_transposed, ( + "Forward-mode Dense layer implementation only supports 'NN' layout inputs." 
+ ) - casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x) + casted_x = tex.quantize(x, flatten_axis=min(x_cdims), quantizer=quantizer_set.x, + noop_scaled_tensor=True) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) + rowwise_x = casted_x.get_rowwise_tensor() + colwise_x = casted_x.get_colwise_tensor() casted_kernel = tex.quantize( - kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel + kernel, flatten_axis=max(k_cdims) + 1, quantizer=quantizer_set.kernel, + noop_scaled_tensor=True ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) + colwise_k = casted_kernel.get_colwise_tensor() + rowwise_k = casted_kernel.get_rowwise_tensor() - # GEMM NN + # FPROP GEMM: (batch..., hidden_in) x (hidden_in, hidden_out) = (batch..., hidden_out) + # FPROP FP8 GEMM: (batch..., hidden_in) x (hidden_out, hidden_in)^T = (batch..., hidden_out) + use_bias = bias is not None output = tex.gemm( - casted_x.get_rowwise_tensor(), - casted_kernel.get_colwise_tensor(), - (x_contracting_dims, k_contracting_dims), + rowwise_x, + colwise_k, + bias=bias if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + contracting_dims=(x_cdims, k_cdims), + grad=False, ) - use_bias = bias is not None - if use_bias: + if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) ctx = ( - casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None, - casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None, - x.shape, - kernel.shape, + colwise_x, + rowwise_k, use_bias, quantizer_set, - flatten_axis_k, ) return output, ctx @@ -135,46 +146,49 @@ def _dense_bwd_rule( Returns: Tuple of gradients with respect to inputs """ - fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims - ( - colwise_casted_x, - rowwise_casted_kernel, - x_shape, - kernel_shape, + colwise_x, + rowwise_k, use_bias, quantizer_set, - flatten_axis_k, ) = ctx + # Original non-contracting dimensions in the forward pass are contracting dimensions for the + # backward pass. 
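As a concrete illustration of this dimension bookkeeping, for a 3D input contracted on its last axis with a 2D kernel the helpers used here give the following (a sketch only, assuming sanitize_dims and get_non_contracting_dims are re-exported by cpp_extensions exactly as this module uses them):

import transformer_engine.jax.cpp_extensions as tex

fwd_x_cdims = tex.sanitize_dims(3, (-1,))                       # -> (2,)
fwd_x_non_cdims = tex.get_non_contracting_dims(3, fwd_x_cdims)  # -> (0, 1)
flatten_axis_grad = len(fwd_x_non_cdims)                        # -> 2: grad (B, S, H_out) flattens to (B*S, H_out)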
+ fwd_x_cdims, fwd_k_cdims = map( + tex.sanitize_dims, (colwise_x.ndim, rowwise_k.ndim), contracting_dims + ) + fwd_x_non_cdims = tex.get_non_contracting_dims(colwise_x.ndim, fwd_x_cdims) + fwd_k_non_cdims = tex.get_non_contracting_dims(rowwise_k.ndim, fwd_k_cdims) + # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input + flatten_axis_grad = len(fwd_x_non_cdims) casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad + grad, is_dbias=use_bias, flatten_axis=flatten_axis_grad, quantizer=quantizer_set.dgrad, + noop_scaled_tensor=True, ) - # GEMM NT - # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim - g_contracting_dim = tuple( - range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) - ) - # k_non_contracting_dims - k_contracting_dim = tuple( - dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims - ) + # Prepare DGRAD and WGRAD operands and contracting dims + rowwise_g = casted_grad.get_rowwise_tensor() + rowwise_g_cdims = tuple(range(flatten_axis_grad, grad.ndim)) + colwise_g = casted_grad.get_colwise_tensor() + colwise_g_cdims = tex.get_non_contracting_dims(grad.ndim, rowwise_g_cdims) + + # DGRAD GEMM: (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) dgrad = tex.gemm( - casted_grad.get_rowwise_tensor(), - rowwise_casted_kernel, - (g_contracting_dim, k_contracting_dim), + rowwise_g, + rowwise_k, + contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), + grad=True ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) - # GEMM TN - # x_non_contracting_dims - g_contracting_dim = x_contracting_dim = tuple( - range(0, len(x_shape) - len(fwd_x_contracting_dims)) - ) - + # WGRAD GEMM: (batch..., hidden_in)^T x (batch..., hidden_out) = (hidden_in, hidden_out) + # WGRAD FP8 GEMM: (hidden_in, batch...) x (hidden_out, batch...)^T = (hidden_in, hidden_out) wgrad = tex.gemm( - colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim) + colwise_x, + colwise_g, + contracting_dims=(fwd_x_non_cdims, colwise_g_cdims), + grad=True, ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 727ff78c2d..26b9d8fad5 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -173,12 +173,12 @@ def _layernorm_dense_fwd_rule( Returns: Tuple of (output, context) for automatic differentiation """ - x_contracting_dims = (len(x.shape) - 1,) - k_contracting_dims = (0,) + x_cdims = (x.ndim - 1,) + k_cdims = (0,) assert x.shape[-1] == kernel.shape[0] + # Apply layernorm with quantized output if quantizer_set is given x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) - casted_ln_out, mu, rsigma = tex.normalization_fwd( x, gamma, @@ -187,42 +187,50 @@ def _layernorm_dense_fwd_rule( epsilon, norm_type, quantizer_set.x, + noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) - # Kernel in (hidden_in, hidden_out...) 
- flatten_axis = 1 - len(kernel.shape) - casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel) + # Layernorm output (batch..., hidden_in) + rowwise_ln_out = casted_ln_out.get_rowwise_tensor() + colwise_ln_out = casted_ln_out.get_colwise_tensor() + + # Kernel (hidden_in, hidden_out) + flatten_axis = 1 - kernel.ndim + casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, + noop_scaled_tensor=True) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) - # NN GEMM - # (batch..., hidden_in) x (hidden_in, hidden_out...) - output = tex.gemm( - casted_ln_out.get_rowwise_tensor(), - casted_kernel.get_colwise_tensor(), - (x_contracting_dims, k_contracting_dims), - ) + rowwise_kernel = casted_kernel.get_rowwise_tensor() + colwise_kernel = casted_kernel.get_colwise_tensor() + # FPROP GEMM: (batch..., hidden_in) x (hidden_in, hidden_out) = (batch..., hidden_out) + # FPROP FP8 GEMM: (batch..., hidden_in) x (hidden_out, hidden_in)^T = (batch..., hidden_out) use_bias = bias is not None - if use_bias: + output = tex.gemm( + rowwise_ln_out, + colwise_kernel, + bias=bias if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + contracting_dims=(x_cdims, k_cdims), + grad=False, + ) + if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) ctx = ( - casted_ln_out.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None, - casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None, - x.shape, - kernel.shape, + colwise_ln_out, + rowwise_kernel, mu, rsigma, x, gamma, beta, - x_contracting_dims, - k_contracting_dims, + x_cdims, + k_cdims, use_bias, quantizer_set, - flatten_axis, ) return output, ctx @@ -233,7 +241,7 @@ def _layernorm_dense_bwd_rule( zero_centered_gamma, epsilon, layernorm_input_axes, - dot_input_axes, # pylint: disable=unused-argument + dot_input_axes, kernel_axes, ctx, grad, @@ -250,57 +258,57 @@ def _layernorm_dense_bwd_rule( Tuple of gradients for all input parameters """ ( - colwise_casted_ln_out, - rowwise_casted_kernel, - x_shape, - kernel_shape, + colwise_ln_out, + rowwise_kernel, mu, rsigma, x, gamma, beta, - x_contracting_dims_in_fwd, - k_contracting_dims_in_fwd, + fwd_x_cdims, + fwd_k_cdims, use_bias, quantizer_set, - flatten_axis, ) = ctx + # Original non-contracting dimensions in the forward pass are contracting dimensions for the + # backward pass. 
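For readers who prefer jax.lax.dot_general notation: in the unquantized fallback, the DGRAD and WGRAD contractions assembled below reduce to the following minimal, self-contained sketch (shapes are illustrative only):

import jax
import jax.numpy as jnp

g = jnp.zeros((2, 4, 8), jnp.bfloat16)    # incoming grad: (batch, seq, hidden_out)
k = jnp.zeros((16, 8), jnp.bfloat16)      # kernel: (hidden_in, hidden_out)
x = jnp.zeros((2, 4, 16), jnp.bfloat16)   # layernorm output: (batch, seq, hidden_in)

# DGRAD: contract grad's last axis with the kernel's non-contracting axis -> (2, 4, 16)
dgrad = jax.lax.dot_general(g, k, (((2,), (1,)), ((), ())))
# WGRAD: contract the shared batch/sequence axes of the two activations -> (16, 8)
wgrad = jax.lax.dot_general(x, g, (((0, 1), (0, 1)), ((), ())))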
+ fwd_x_non_cdims = tex.get_non_contracting_dims(colwise_ln_out.ndim, fwd_x_cdims) + fwd_k_non_cdims = tex.get_non_contracting_dims(rowwise_kernel.ndim, fwd_k_cdims) + # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input + flatten_axis_grad = len(fwd_x_non_cdims) casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad + grad, is_dbias=use_bias, flatten_axis=flatten_axis_grad, quantizer=quantizer_set.dgrad, + noop_scaled_tensor=True, ) - # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim - g_constracting_dim = tuple( - range(grad.ndim - len(kernel_shape) + len(k_contracting_dims_in_fwd), grad.ndim) - ) - # k_non_contracting_dims - k_constracting_dim = tuple( - dim for dim in range(len(kernel_shape)) if dim not in k_contracting_dims_in_fwd - ) + # Prepare DGRAD and WGRAD operands and contracting dims + rowwise_g = casted_grad.get_rowwise_tensor() + rowwise_g_cdims = tuple(range(flatten_axis_grad, grad.ndim)) + colwise_g = casted_grad.get_colwise_tensor() + colwise_ln_out_cdims = fwd_x_non_cdims + colwise_g_cdims = tex.get_non_contracting_dims(grad.ndim, rowwise_g_cdims) - # NT GEMM + # DGRAD GEMM: (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) dgrad = tex.gemm( - casted_grad.get_rowwise_tensor(), - rowwise_casted_kernel, - (g_constracting_dim, k_constracting_dim), - ) - - dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) - - g_constracting_dim = x_constracting_dim = tuple( - range(0, len(x_shape) - len(x_contracting_dims_in_fwd)) + rowwise_g, + rowwise_kernel, + contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), + grad=True ) + dgrad = with_sharding_constraint_by_logical_axes(dgrad, dot_input_axes) - # TN GEMM + # WGRAD GEMM: (batch..., hidden_in)^T x (batch..., hidden_out) = (hidden_in, hidden_out) + # WGRAD FP8 GEMM: (hidden_in, batch...) x (hidden_out, batch...)^T = (hidden_in, hidden_out) wgrad = tex.gemm( - colwise_casted_ln_out, - casted_grad.get_colwise_tensor(), - (x_constracting_dim, g_constracting_dim), + colwise_ln_out, + colwise_g, + contracting_dims=(colwise_ln_out_cdims, colwise_g_cdims), + grad=True, ) - wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) + # Layernorm gradient dx, dgamma, dbeta = tex.normalization_bwd( dgrad, x, @@ -312,6 +320,7 @@ def _layernorm_dense_bwd_rule( epsilon=epsilon, norm_type=norm_type, ) + dx = with_sharding_constraint_by_logical_axes(dx, layernorm_input_axes) return dx, wgrad, dgamma, dbeta, dbias, quantizer_set diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index e04b930233..8a42555a5e 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -22,7 +22,11 @@ from . 
import cpp_extensions as tex from .layernorm import canonicalize_norm_type -from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set +from .quantize import ( + with_sharding_constraint_by_logical_axes, + QuantizerSet, + noop_quantizer_set +) from .sharding import get_non_contracting_logical_axes @@ -244,16 +248,16 @@ def _layernorm_mlp_fwd_rule( assert len(kernel_2.shape) == 2 assert kernel_1.shape[-2] == len(activation_type) - x_contracting_dims = (len(x.shape) - 1,) - k_contracting_dims = (0,) + x_cdims = (x.ndim - 1,) + k_cdims = (0,) - assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] + assert x.shape[x_cdims[0]] == kernel_1.shape[k_cdims[0]] use_bias_1 = bias_1 is not None use_bias_2 = bias_1 is not None + # Apply layernorm with quantized output if quantizer_set is given x = with_sharding_constraint_by_logical_axes(x, norm_input_axes) - casted_ln_out, mu, rsigma = tex.normalization_fwd( x, gamma, @@ -262,49 +266,77 @@ def _layernorm_mlp_fwd_rule( epsilon, norm_type, quantizer=ffn1_quantizer_set.x, + noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) - casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel) + # FC1 kernel (hidden_in, act_len, hidden_out) + casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, + noop_scaled_tensor=True) + + # Prepare FC1 FPROP operands and layouts + rowwise_ln_out = casted_ln_out.get_rowwise_tensor() + rowwise_kernel_1 = casted_kernel_1.get_rowwise_tensor() + colwise_ln_out = casted_ln_out.get_colwise_tensor() + colwise_kernel_1 = casted_kernel_1.get_colwise_tensor() - # NN GEMM - # (batch..., hidden_in) x (hidden_in, hidden_out) + # FC1 GEMM: + # (batch..., hidden_in) x (hidden_in, act_len, hidden_out) = (batch..., act_len, hidden_out) + # FC1 FP8 GEMM: + # (batch..., hidden_in) x (hidden_out, act_len, hidden_in)^T = (batch..., act_len, hidden_out) + use_bias_1 = bias_1 is not None dot_1_output = tex.gemm( - casted_ln_out.get_rowwise_tensor(), - casted_kernel_1.get_colwise_tensor(), - (x_contracting_dims, k_contracting_dims), + rowwise_ln_out, + colwise_kernel_1, + bias=bias_1 if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, + contracting_dims=(x_cdims, k_cdims), + grad=False, ) if dot_1_input_axes is not None and kernel_1_axes is not None: dot_1_output_axes = ( - *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims), - *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims), + *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_cdims), + *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_cdims), ) dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) - if use_bias_1: + if use_bias_1 and tex.gemm_uses_jax_dot(): bias_1_shape = bias_1.shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape dot_1_output += jnp.reshape(bias_1, bias_1_new_shape) dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) - # (batch..., hidden_in) -> (batch..., hidden) - casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x) - + # Activation (batch..., act_len, hidden_out) -> (batch..., hidden_out) + casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, + noop_scaled_tensor=True) casted_act_out = 
with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) - casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel) + # FC2 kernel (hidden_out, hidden_in) + casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel, + noop_scaled_tensor=True) - # NN GEMM - # (batch..., hidden_in) x (hidden_out, hidden_in) + # Prepare FC2 FPROP operands and layouts + rowwise_act_out = casted_act_out.get_rowwise_tensor() + rowwise_kernel_2 = casted_kernel_2.get_rowwise_tensor() + colwise_act_out = casted_act_out.get_colwise_tensor() + colwise_kernel_2 = casted_kernel_2.get_colwise_tensor() + + # FC2 GEMM: + # (batch..., hidden_out) x (hidden_out, hidden_in) = (batch..., hidden_in) + # FC2 FP8 GEMM: + # (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) dot_2_output = tex.gemm( - casted_act_out.get_rowwise_tensor(), - casted_kernel_2.get_colwise_tensor(), - (x_contracting_dims, k_contracting_dims), + rowwise_act_out, + colwise_kernel_2, + bias=bias_2 if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, + contracting_dims=(x_cdims, k_cdims), + grad=False ) - if use_bias_2: + if use_bias_2 and tex.gemm_uses_jax_dot(): bias_2_shape = bias_2.shape bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape dot_2_output += jnp.reshape(bias_2, bias_2_new_shape) @@ -317,15 +349,13 @@ def _layernorm_mlp_fwd_rule( rsigma, gamma, beta, - casted_ln_out.get_colwise_tensor(), - casted_kernel_1.get_rowwise_tensor(), + colwise_ln_out, + rowwise_kernel_1, dot_1_output, - casted_act_out.get_colwise_tensor(), - casted_kernel_2.get_rowwise_tensor(), - x_contracting_dims, - k_contracting_dims, - kernel_1.shape, - kernel_2.shape, + colwise_act_out, + rowwise_kernel_2, + x_cdims, + k_cdims, use_bias_1, use_bias_2, quantizer_sets, @@ -362,22 +392,20 @@ def _layernorm_mlp_bwd_rule( Returns: Tuple of gradients for all input parameters """ - del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name + del ffn1_ckpt_name, ffn2_ckpt_name ( x, mu, rsigma, gamma, beta, - colwise_casted_ln_out, - rowwise_casted_kernel_1, + colwise_ln_out, + rowwise_kernel_1, dot_1_output, - colwise_casted_act_out, - rowwise_casted_kernel_2, - x_contracting_dims_in_fwd, - k_contracting_dims_in_fwd, - kernel_1_shape, - kernel_2_shape, + colwise_act_out, + rowwise_kernel_2, + fwd_x_cdims, + fwd_k_cdims, use_bias_1, use_bias_2, quantizer_sets, @@ -385,82 +413,85 @@ def _layernorm_mlp_bwd_rule( ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets - # Since the sharding of outputs should be the same as dot_1's input + # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input + fwd_x_non_cdims = tex.get_non_contracting_dims(colwise_ln_out.ndim, fwd_x_cdims) + flatten_axis_grad = len(fwd_x_non_cdims) grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) - casted_grad, dbias_2 = tex.quantize_dbias( - grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad + grad, is_dbias=use_bias_2, flatten_axis=flatten_axis_grad, + quantizer=ffn2_quantizer_set.dgrad, noop_scaled_tensor=True, ) - # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim - g_contracting_dims_2 = tuple( - range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim) - ) - # k_non_contracting_dims - k_contracting_dims_2 = tuple( - dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd - ) + # Prepare FC2 DGRAD and 
WGRAD operands and contracting dims + rowwise_g = casted_grad.get_rowwise_tensor() + rowwise_g_cdims = tuple(range(flatten_axis_grad, grad.ndim)) + fwd_k2_non_cdims = tex.get_non_contracting_dims(rowwise_kernel_2.ndim, fwd_k_cdims) - # NT GEMM - # (batch..., hidden_out) x (hidden_in, hidden_out) + colwise_g = casted_grad.get_colwise_tensor() + colwise_g_cdims = tex.get_non_contracting_dims(grad.ndim, rowwise_g_cdims) + colwise_act_out_cdims = tex.get_non_contracting_dims(colwise_act_out.ndim, fwd_x_cdims) + + # FC2 DGRAD GEMM: (batch..., hidden_in) x (hidden_out, hidden_in)^T = (batch..., hidden_out) dgrad_2 = tex.gemm( - casted_grad.get_rowwise_tensor(), - rowwise_casted_kernel_2, - (g_contracting_dims_2, k_contracting_dims_2), + rowwise_g, + rowwise_kernel_2, + contracting_dims=(rowwise_g_cdims, fwd_k2_non_cdims), + grad=True ) - dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) - x_contracting_dims = g_contracting_dims = tuple( - range(0, len(x.shape) - len(x_contracting_dims_in_fwd)) - ) - - # TN GEMM - # (hidden, batch...,) x (hidden, batch...) + # FC2 WGRAD GEMM: + # (batch..., hidden_out)^T x (batch..., hidden_in) = (hidden_out, hidden_in) + # FC2 WGRAD FP8 GEMM: + # (hidden_out, batch...) x (hidden_in, batch...)^T = (hidden_out, hidden_in) wgrad_2 = tex.gemm( - colwise_casted_act_out, - casted_grad.get_colwise_tensor(), - (x_contracting_dims, g_contracting_dims), + colwise_act_out, + colwise_g, + contracting_dims=(colwise_act_out_cdims, colwise_g_cdims), + grad=True, ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) + # Activation gradient w/ bias fusion (batch..., hidden_out) -> (batch.., act_len, hidden_out) casted_dact_out, dbias_1 = tex.quantize_dact_dbias( dgrad_2, dot_1_output, activation_type=activation_type, is_dbias=use_bias_1, - quantizer=ffn2_quantizer_set.dgrad, + quantizer=ffn1_quantizer_set.dgrad, + noop_scaled_tensor=True, ) - # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim - dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim - g_contracting_dims_1 = tuple( - range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim) - ) - # k_non_contracting_dims - k_contracting_dims_1 = tuple( - dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd - ) + # Prepare FC1 DGRAD and WGRAD operands and contracting dims + rowwise_dact_out = casted_dact_out.get_rowwise_tensor() + rowwise_dact_out_cdims = tuple(range(flatten_axis_grad, rowwise_dact_out.ndim)) + colwise_dact_out = casted_dact_out.get_colwise_tensor() + colwise_dact_out_cdims = tex.get_non_contracting_dims(casted_dact_out.ndim, rowwise_dact_out_cdims) + fwd_k1_non_cdims = tex.get_non_contracting_dims(rowwise_kernel_1.ndim, fwd_k_cdims) - # NT GEMM + # FC1 DGRAD GEMM: + # (batch..., act_len, hidden_out) x (hidden_in, act_len, hidden_out)^T = (batch..., hidden_in) dgrad_1 = tex.gemm( - casted_dact_out.get_rowwise_tensor(), - rowwise_casted_kernel_1, - (g_contracting_dims_1, k_contracting_dims_1), + rowwise_dact_out, + rowwise_kernel_1, + contracting_dims=(rowwise_dact_out_cdims, fwd_k1_non_cdims), + grad=True ) - dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) - # TN GEMM - # (hidden, batch...) x (hidden, batch...) + # FC1 WGRAD GEMM: + # (batch..., hidden_in)^T x (batch..., act_len, hidden_out) = (hidden_in, act_len, hidden_out) + # FC1 WGRAD FP8 GEMM: + # (hidden_in, batch...) 
x (hidden_out, act_len, batch...)^T = (hidden_in, act_len, hidden_out) wgrad_1 = tex.gemm( - colwise_casted_ln_out, - casted_dact_out.get_colwise_tensor(), - (x_contracting_dims, g_contracting_dims), + colwise_ln_out, + colwise_dact_out, + contracting_dims=(fwd_x_non_cdims, colwise_dact_out_cdims), + grad=True, ) - wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) + # Layernorm gradient dx, dgamma, dbeta = tex.normalization_bwd( dgrad_1, x, @@ -472,6 +503,7 @@ def _layernorm_mlp_bwd_rule( epsilon=epsilon, norm_type=norm_type, ) + dx = with_sharding_constraint_by_logical_axes(dx, norm_input_axes) return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets) diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 06a2562fb1..2459190f1a 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -36,6 +36,22 @@ def dequantize(scaled_tensor): """Dequantizing given tensor to higher precision.""" +@dataclass +class NoopDequantizer(Dequantizer): + """No-op Dequantizer Class""" + + @staticmethod + def _dequantize_func(data, *args, **kwargs): + """A no-op dequantize function that returns the data without any changes.""" + del args, kwargs + return data + + @staticmethod + def dequantize(scaled_tensor): + """A no-op dequantize function that simply returns the data array in the ScaledTensor.""" + return scaled_tensor.data + + class TensorScaleDequantizer(Dequantizer): """ TensorScaling Dequantizer Class @@ -152,6 +168,7 @@ def dequantize(scaled_tensor): ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer, + ScalingMode.NO_SCALING: NoopDequantizer, } diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 02b1a1a99e..4761816220 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -55,6 +55,11 @@ def tree_unflatten(cls, aux_data, children): """ return cls(*children, *aux_data) + @property + @abstractmethod + def ndim(self): + """Number of dimensions of the underlying quantized array.""" + @abstractmethod def dequantize(self): """Dequantizes the tensor back to its original precision. @@ -136,25 +141,40 @@ def __post_init__(self): 0 < self.flatten_axis < len(self.data.shape) ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" - expected_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis - ) - expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis - ) - if self.scale_inv.shape != expected_scale_shape: - assert self.scale_inv.shape == expected_unpadded_scale_shape, ( - f"Unexpected scale_inv shape! 
\nExpect {expected_scale_shape} for padded" - f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" - f" {self.scale_inv.shape}" - ) - pad_width = tuple( - (0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) + if self.scaling_mode == ScalingMode.NO_SCALING: + self.scale_inv = jnp.empty((1, ), dtype=jnp.float32) + + else: + expected_scale_shape = self.scaling_mode.get_scale_shape( + self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis ) - # This actually pad scale_inv with nan, should we pad it with 127 directly instead? - self.scale_inv = jnp.pad( - self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0 + expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( + self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis ) + if self.scale_inv.shape != expected_scale_shape: + assert self.scale_inv.shape == expected_unpadded_scale_shape, ( + f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" + f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" + f" {self.scale_inv.shape}" + ) + expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( + self.data.shape, self.is_colwise, is_padded=False, + flatten_axis=self.flatten_axis + ) + if self.scale_inv.shape != expected_scale_shape: + assert self.scale_inv.shape == expected_unpadded_scale_shape, ( + f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" + f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" + f" {self.scale_inv.shape}" + ) + pad_width = tuple( + (0, a - b) for a, b in zip(expected_scale_shape, + expected_unpadded_scale_shape) + ) + # This actually pad scale_inv with nan, should we pad it with 127 directly instead? + self.scale_inv = jnp.pad( + self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0 + ) def tree_flatten(self): """Flattens the tensor for JAX tree operations. @@ -173,6 +193,10 @@ def tree_flatten(self): ) return (children, aux_data) + @property + def ndim(self): + return self.data.ndim + def dequantize(self): """Dequantizes the tensor using the stored dequantization function. @@ -370,6 +394,11 @@ def tree_flatten(self): aux_data = () return (children, aux_data) + @property + def ndim(self): + """Number of dimensions of the underlying row-wise tensor.""" + return self.rowwise_tensor.ndim + def dequantize(self): """Dequantizes the tensor using the row-wise component's dequantization. 
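End to end, the updated tex.gemm() entry point introduced in this change can be exercised as in the following minimal sketch, assuming an FP8-capable GPU; the shapes, dtypes, and NT operand layout mirror the unit tests rather than anything mandated by the API:

import jax
import jax.numpy as jnp
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode

key_lhs, key_rhs = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(key_lhs, (32, 128), jnp.bfloat16)   # (batch, hidden_in)
w = jax.random.uniform(key_rhs, (256, 128), jnp.bfloat16)  # (hidden_out, hidden_in)

q_set = QuantizerFactory.create_set(
    scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
    fwd_dtype=jnp.float8_e4m3fn,
    bwd_dtype=jnp.float8_e5m2,
    is_2x2x=False,
)
# 'NT' layout: both operands contract over their last axis, the one quantized layout
# supported by cuBLAS FP8 GEMM on all FP8-capable architectures.
out = tex.gemm(
    x,
    w,
    lhs_quantizer=q_set.x,
    rhs_quantizer=q_set.kernel,
    contracting_dims=((1,), (1,)),
)  # -> (32, 256)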
From da0709afa8fa727a0c2451a2d35d5ac7aa2f9559 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 13 Jun 2025 05:02:42 +0000 Subject: [PATCH 02/30] minor unit test cleanup Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 5 ++--- transformer_engine/jax/cpp_extensions/quantization.py | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 5a59d113d5..a87d12bec1 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -45,7 +45,6 @@ from transformer_engine.jax.activation import activation from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense -from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x from transformer_engine_jax import is_non_nt_fp8_gemm_supported @@ -1133,7 +1132,7 @@ def ref_func(x, w, gamma, beta): @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) - @pytest.mark.parametrize("use_bias", [True, False]) + @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm @@ -1141,7 +1140,7 @@ def test_layernorm_mlp_grad( """ Test layernorm_mlp VJP Rule """ - + use_jax_dot_for_gemm(enabled=with_jax_gemm) # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index e84fb3a976..4bb4803fcb 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -36,7 +36,6 @@ Quantizer, GroupedQuantizer, QuantizeLayout, - DelayedScaleQuantizer, ScalingMode, compute_scale_from_amax, ) From e5b933c064466ef197558ffe7c31b2dd5f8348b9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Jun 2025 05:03:15 +0000 Subject: [PATCH 03/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 54 ++-- tests/jax/test_layer.py | 6 +- .../jax/cpp_extensions/activation.py | 19 +- transformer_engine/jax/cpp_extensions/gemm.py | 239 +++++++++++------- .../jax/cpp_extensions/normalization.py | 10 +- .../jax/cpp_extensions/quantization.py | 27 +- .../jax/csrc/extensions/gemm.cpp | 65 ++--- transformer_engine/jax/csrc/extensions/misc.h | 4 +- transformer_engine/jax/dense.py | 31 +-- transformer_engine/jax/layernorm_dense.py | 15 +- transformer_engine/jax/layernorm_mlp.py | 41 +-- transformer_engine/jax/quantize/tensor.py | 12 +- 12 files changed, 306 insertions(+), 217 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index a87d12bec1..002a61c715 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -156,9 +156,9 @@ def assert_dequantized_grouped_scaled_tensor( def use_jax_dot_for_gemm(enabled=False): if enabled: - os.environ['NVTE_JAX_CUSTOM_CALLS_RE']='^(?!GemmPrimitive$).+$' - elif 'NVTE_JAX_CUSTOM_CALLS_RE' in os.environ: - os.environ.pop('NVTE_JAX_CUSTOM_CALLS_RE') + 
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)] @@ -731,9 +731,7 @@ def test_quantize_dbias( ) te_output, te_dbias = jit( - lambda input: tex.quantize_dbias( - inp, quantizer=te_quantizer, flatten_axis=flatten_axis - ) + lambda input: tex.quantize_dbias(inp, quantizer=te_quantizer, flatten_axis=flatten_axis) )(inp) jax_output, jax_dbias = jit( @@ -911,44 +909,42 @@ def test_gemm_bf16(self, m, n, k, data_layout, with_jax_gemm): "lhs_q_dtype,rhs_q_dtype", [ (jnp.float8_e4m3fn, jnp.float8_e4m3fn), # fprop GEMM - (jnp.float8_e4m3fn, jnp.float8_e5m2), # wgrad GEMM - (jnp.float8_e5m2, jnp.float8_e4m3fn), # dgrad GEMM - ] + (jnp.float8_e4m3fn, jnp.float8_e5m2), # wgrad GEMM + (jnp.float8_e5m2, jnp.float8_e4m3fn), # dgrad GEMM + ], ) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("data_layout", supported_fp8_gemm_layouts) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_gemm_fp8(self, m, n, k, lhs_q_dtype, rhs_q_dtype, scaling_mode, data_layout, - with_jax_gemm): + def test_gemm_fp8( + self, m, n, k, lhs_q_dtype, rhs_q_dtype, scaling_mode, data_layout, with_jax_gemm + ): use_jax_dot_for_gemm(enabled=with_jax_gemm) lhs, rhs, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, bwd_dtype=jnp.float8_e5m2, - is_2x2x=False - ) - lhs_quantizer = ( - quantizer_set.x - if lhs_q_dtype == jnp.float8_e4m3fn - else quantizer_set.dgrad + scaling_mode=scaling_mode, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, + is_2x2x=False, ) + lhs_quantizer = quantizer_set.x if lhs_q_dtype == jnp.float8_e4m3fn else quantizer_set.dgrad rhs_quantizer = ( - quantizer_set.kernel - if rhs_q_dtype == jnp.float8_e4m3fn - else quantizer_set.dgrad + quantizer_set.kernel if rhs_q_dtype == jnp.float8_e4m3fn else quantizer_set.dgrad ) primitive_out = tex.gemm( - lhs, rhs, lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, - contracting_dims=contracting_dims + lhs, + rhs, + lhs_quantizer=lhs_quantizer, + rhs_quantizer=rhs_quantizer, + contracting_dims=contracting_dims, ) ref_out = self._ref_gemm_with_jnp_dot(lhs, rhs, data_layout) test_q_dtype = ( - jnp.float8_e5m2 - if jnp.float8_e5m2 in (lhs_q_dtype, rhs_q_dtype) - else jnp.float8_e4m3fn + jnp.float8_e5m2 if jnp.float8_e5m2 in (lhs_q_dtype, rhs_q_dtype) else jnp.float8_e4m3fn ) assert_allclose(primitive_out, ref_out, dtype=test_q_dtype) @@ -1008,8 +1004,10 @@ def ref_func(x, w, bias, data_layout): value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, bwd_dtype=jnp.float8_e5m2, - is_2x2x=True + scaling_mode=scaling_mode, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, + is_2x2x=True, ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index ab79e2eae4..389148415a 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -490,9 +490,9 @@ class BaseTester: def use_jax_dot_for_gemm(self, enabled=False): """Enable/disable TE custom cuBLAS GEMM primitive.""" if enabled: - os.environ['NVTE_JAX_CUSTOM_CALLS_RE']='^(?!GemmPrimitive$).+$' - elif 'NVTE_JAX_CUSTOM_CALLS_RE' in os.environ: - 
os.environ.pop('NVTE_JAX_CUSTOM_CALLS_RE') + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 4d0a4c1bbd..341dcb0c8c 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1124,9 +1124,8 @@ def quantize_dact_dbias( scale = jnp.empty((), jnp.float32) act_type_id = ActivationEnum[activation_type] PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive - if ( - not PrimitiveClass.enabled() - or (quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE) + if not PrimitiveClass.enabled() or ( + quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE ): return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) @@ -1152,9 +1151,17 @@ def quantize_dact_dbias( dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) if noop_scaled_tensor: - return ScaledTensorFactory.create_2x( - output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype, - ), dbias + return ( + ScaledTensorFactory.create_2x( + output, + None, + output, + None, + ScalingMode.NO_SCALING, + dq_dtype=output.dtype, + ), + dbias, + ) return output, dbias diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d553da10f3..5f1218872a 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -62,10 +62,10 @@ def is_gemm_with_all_layouts_supported() -> bool: def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]: """Convert relative (negative) indexes to absolute dimension numbers.""" - dims_ = dims if isinstance(dims, Iterable) else (dims, ) + dims_ = dims if isinstance(dims, Iterable) else (dims,) if len(dims_) == 0: return dims_ - return tuple( ndim + dim if dim < 0 else dim for dim in dims_ ) + return tuple(ndim + dim if dim < 0 else dim for dim in dims_) def get_non_contracting_dims(ndim, contracting_dims): @@ -88,8 +88,8 @@ def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool: rhs_dtype, jnp.float8_e4m3fn, jnp.float8_e5m2, - jnp.uint8 # replace with jnp.float8_e8m0 when JAX/XLA merges support - ) + jnp.uint8, # replace with jnp.float8_e8m0 when JAX/XLA merges support + ), ) # MXFP8 GEMM needs both operands to be MXFP8 (uint8 for now until JAX merges float8_e8m0) @@ -105,8 +105,7 @@ def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool: def _get_gemm_layout( - operand_ndims: Tuple[int, int], - contracting_dims: Tuple[Sequence[int], Sequence[int]] + operand_ndims: Tuple[int, int], contracting_dims: Tuple[Sequence[int], Sequence[int]] ) -> Tuple[bool, bool]: lhs_contracting, rhs_contracting = map(sanitize_dims, operand_ndims, contracting_dims) lhs_is_transposed = operand_ndims[0] - 1 not in lhs_contracting @@ -128,9 +127,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ is_rowwise=True, is_colwise=False, flatten_axis=( - max(lhs_contracting_dims) + 1 - if lhs_is_transposed - else min(lhs_contracting_dims) + max(lhs_contracting_dims) + 1 if lhs_is_transposed else min(lhs_contracting_dims) ), ) if lhs_is_transposed: @@ -145,9 +142,7 @@ def _quantize_gemm_operands(lhs, rhs, 
lhs_quantizer, rhs_quantizer, contracting_ is_rowwise=True, is_colwise=False, flatten_axis=( - min(rhs_contracting_dims) - if rhs_is_transposed - else max(rhs_contracting_dims) + 1 + min(rhs_contracting_dims) if rhs_is_transposed else max(rhs_contracting_dims) + 1 ), ) if not rhs_is_transposed: @@ -171,8 +166,20 @@ class GemmPrimitive(BasePrimitive): outer_primitive = None @staticmethod - def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype, - contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad): + def abstract( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + ): # Sanity-check operand layouts and types operand_ndims = (lhs.ndim, rhs.ndim) ( @@ -182,7 +189,7 @@ def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype lhs_contracting_size, rhs_contracting_size = map( lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), (lhs.shape, rhs.shape), - (lhs_contracting_dims, rhs_contracting_dims) + (lhs_contracting_dims, rhs_contracting_dims), ) assert lhs_contracting_size == rhs_contracting_size, ( "cuBLAS GEMM operands have incompatible contracting dimensions: " @@ -195,9 +202,9 @@ def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype "cuBLAS GEMM quantized operands have incompatible data types: " f"{lhs.dtype} x {rhs.dtype}." ) - assert lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0, ( - "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." - ) + assert ( + lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0 + ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." if ( scaling_mode != ScalingMode.MXFP8_1D_SCALING and not tex.is_non_nt_fp8_gemm_supported() @@ -209,19 +216,19 @@ def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype ) # Determine output shape and dtype - assert dtypes.canonicalize_dtype(out_dtype).itemsize > 1, ( - "cuBLAS GEMM custom op does not support 8-bit quantized output types." - ) + assert ( + dtypes.canonicalize_dtype(out_dtype).itemsize > 1 + ), "cuBLAS GEMM custom op does not support 8-bit quantized output types." 
lhs_non_contracting_shape, rhs_non_contracting_shape = map( - lambda shape, dims: [ shape[dim] for dim in range(len(shape)) if dim not in dims ], + lambda shape, dims: [shape[dim] for dim in range(len(shape)) if dim not in dims], (lhs.shape, rhs.shape), - (lhs_contracting_dims, rhs_contracting_dims) + (lhs_contracting_dims, rhs_contracting_dims), ) out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape) output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) # Validate bias - bias_shape = (0, ) + bias_shape = (0,) bias_dtype = out_dtype if fuse_bias: expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape) @@ -240,15 +247,14 @@ def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype) # Validate pre-GeLU - pre_gelu_shape = (0, ) + pre_gelu_shape = (0,) pre_gelu_dtype = out_dtype if fuse_gelu: pre_gelu_shape = out_shape if grad: pre_gelu_ndim = len(pre_gelu_shape) - assert ( - gelu_input.ndim == pre_gelu_shape - and all(gelu_input.shape[i] == pre_gelu_shape[i] for i in range(pre_gelu_ndim)) + assert gelu_input.ndim == pre_gelu_shape and all( + gelu_input.shape[i] == pre_gelu_shape[i] for i in range(pre_gelu_ndim) ), ( "cuBLAS GEMM pre-GeLU tensor has incorrect shape, " f"expected {pre_gelu_shape} but found {gelu_input.shape}." @@ -266,12 +272,13 @@ def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype if scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_swizzle_size = lhs_scale_inv.size rhs_swizzle_size = rhs_scale_inv.size - lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size, ), dtype=swizzle_dtype) - rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size, ), dtype=swizzle_dtype) + lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype) + rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype) # Declare cuBLAS workspace - workspace = jax.core.ShapedArray(shape=(get_cublas_workspace_size_bytes(), ), - dtype=jnp.uint8) + workspace = jax.core.ShapedArray( + shape=(get_cublas_workspace_size_bytes(),), dtype=jnp.uint8 + ) return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace @@ -281,31 +288,45 @@ def outer_abstract(*args, **kwargs): return outputs[:-3] # discard workspace arrays @staticmethod - def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype, - contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad): + def lowering( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + ): del out_dtype lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) - lhs_transposed, rhs_transposed = _get_gemm_layout((lhs_aval.ndim, rhs_aval.ndim), - (lhs_cdims, rhs_cdims)) + lhs_transposed, rhs_transposed = _get_gemm_layout( + (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) + ) args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input) kwargs = { - "scaling_mode" : int(scaling_mode.value), - "lhs_axis_boundary" : max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), - "rhs_axis_boundary" : min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, - "lhs_transposed" : lhs_transposed, - "rhs_transposed" : rhs_transposed, - "fuse_bias" : fuse_bias, - "fuse_gelu" : fuse_gelu, - "grad" : grad, + "scaling_mode": int(scaling_mode.value), + 
"lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), + "rhs_axis_boundary": min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, + "lhs_transposed": lhs_transposed, + "rhs_transposed": rhs_transposed, + "fuse_bias": fuse_bias, + "fuse_gelu": fuse_gelu, + "grad": grad, } operand_output_aliases = {} if fuse_bias and not grad: - operand_output_aliases.update({ 4 : 1 }) # bias <-> bias_grad + operand_output_aliases.update({4: 1}) # bias <-> bias_grad if fuse_gelu and grad: - operand_output_aliases.update({ 5 : 2 }) # gelu_input <-> pre_gelu_out + operand_output_aliases.update({5: 2}) # gelu_input <-> pre_gelu_out return jax.ffi.ffi_lowering( GemmPrimitive.name, @@ -313,8 +334,20 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ )(ctx, *args, **kwargs) @staticmethod - def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype, contracting_dims, - scaling_mode, fuse_bias, fuse_gelu, grad): + def impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + ): outputs = GemmPrimitive.inner_primitive.bind( lhs, lhs_scale_inv, @@ -332,16 +365,24 @@ def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype, co return outputs[:-3] # discard workspace arrays @staticmethod - def batcher(batched_args, batch_dims, out_dtype, contracting_dims, scaling_mode, fuse_bias, - fuse_gelu, grad): + def batcher( + batched_args, + batch_dims, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + ): assert GemmPrimitive.outer_primitive is not None lhs, _, rhs, *_ = batched_args lhs_bdims, *_ = batch_dims # Output is batched like LHS only if LHS is batched and RHS is not - out_bdims = lhs_bdims if lhs.ndim > 2 and rhs.ndim == 2 else (None, ) - bias_bdims = (None, ) # Bias is never batched - pre_gelu_bdims = (None, ) # Pre-GeLU output, if exists, is batched like GEMM output + out_bdims = lhs_bdims if lhs.ndim > 2 and rhs.ndim == 2 else (None,) + bias_bdims = (None,) # Bias is never batched + pre_gelu_bdims = (None,) # Pre-GeLU output, if exists, is batched like GEMM output if fuse_gelu and not grad: pre_gelu_bdims = out_bdims @@ -355,12 +396,21 @@ def batcher(batched_args, batch_dims, out_dtype, contracting_dims, scaling_mode, fuse_gelu=fuse_gelu, grad=grad, ), - (out_bdims, bias_bdims, pre_gelu_bdims) + (out_bdims, bias_bdims, pre_gelu_bdims), ) @staticmethod - def infer_sharding_from_operands(out_dtype, contracting_dims, scaling_mode, fuse_bias, - fuse_gelu, grad, mesh, arg_infos, result_infos): + def infer_sharding_from_operands( + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + mesh, + arg_infos, + result_infos, + ): del out_dtype, scaling_mode, result_infos # Check contracting dimensions @@ -370,29 +420,29 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, scaling_mode, fuse sanitize_dims, operand_ndims, contracting_dims ) lhs_contracting_specs, rhs_contracting_specs = map( - lambda specs, dims: [ specs[dim] for dim in dims if specs[dim] is not None], + lambda specs, dims: [specs[dim] for dim in dims if specs[dim] is not None], (lhs_spec, rhs_spec), - (lhs_contracting_dims, rhs_contracting_dims) - ) - assert len(lhs_contracting_specs) <= 1 and len(rhs_contracting_specs) <= 1, ( - "cuBLAS GEMM operands can have only one sharded contracting dimension." 
+ (lhs_contracting_dims, rhs_contracting_dims), ) + assert ( + len(lhs_contracting_specs) <= 1 and len(rhs_contracting_specs) <= 1 + ), "cuBLAS GEMM operands can have only one sharded contracting dimension." lhs_contracting_spec, rhs_contracting_spec = map( lambda spec: None if len(spec) == 0 else spec[0], - (lhs_contracting_specs, rhs_contracting_specs) - ) - assert lhs_contracting_spec == rhs_contracting_spec, ( - "cuBLAS GEMM operands must have the same sharding in contracting dimensions." + (lhs_contracting_specs, rhs_contracting_specs), ) + assert ( + lhs_contracting_spec == rhs_contracting_spec + ), "cuBLAS GEMM operands must have the same sharding in contracting dimensions." # Sanity check leading dimensions, allow for simultaneous batch and sequence sharding lhs_leading_dims, rhs_leading_dims = map( get_non_contracting_dims, operand_ndims, (lhs_contracting_dims, rhs_contracting_dims) ) lhs_leading_specs, rhs_leading_specs = map( - lambda specs, dims: [ specs[dim] for dim in dims if specs[dim] is not None ], + lambda specs, dims: [specs[dim] for dim in dims if specs[dim] is not None], (lhs_spec, rhs_spec), - (lhs_leading_dims, rhs_leading_dims) + (lhs_leading_dims, rhs_leading_dims), ) assert len(lhs_leading_specs) <= 1 and len(rhs_leading_specs) <= 1, ( "cuBLAS GEMM operands cannot have more than one sharded leading dimensions. This error " @@ -401,8 +451,7 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, scaling_mode, fuse # Determine output sharding lhs_leading_spec, rhs_leading_spec = map( - lambda spec: None if len(spec) == 0 else spec[0], - (lhs_leading_specs, rhs_leading_specs) + lambda spec: None if len(spec) == 0 else spec[0], (lhs_leading_specs, rhs_leading_specs) ) out_spec = (lhs_leading_spec, rhs_leading_spec) if operand_ndims[0] > 2 and operand_ndims[1] == 2: @@ -411,13 +460,13 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, scaling_mode, fuse out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) # Bias gradient sharding inherits the RHS contracting spec - bias_spec = (None, ) + bias_spec = (None,) if fuse_bias and grad: - bias_spec = (rhs_contracting_spec, ) + bias_spec = (rhs_contracting_spec,) bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) # Pre-GeLU sharding matches output sharding - pre_gelu_spec = (None, ) + pre_gelu_spec = (None,) if fuse_gelu and not grad: pre_gelu_spec = out_spec pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_spec)) @@ -425,11 +474,27 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, scaling_mode, fuse return (out_sharding, bias_sharding, pre_gelu_sharding) @staticmethod - def partition(out_dtype, contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad, - mesh, arg_infos, result_infos): + def partition( + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + mesh, + arg_infos, + result_infos, + ): out_shardings = GemmPrimitive.infer_sharding_from_operands( - out_dtype, contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad, - mesh, arg_infos, result_infos + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + mesh, + arg_infos, + result_infos, ) output_spec = out_shardings[0].spec @@ -442,13 +507,13 @@ def partition(out_dtype, contracting_dims, scaling_mode, fuse_bias, fuse_gelu, g scale_sharding = NamedSharding(mesh, PartitionSpec(None)) # Bias has to be sharded same as the trailing dimension of the GEMM output - bias_spec = (None, ) + bias_spec = (None,) if fuse_bias and not grad: - 
bias_spec = (output_spec[-1], ) + bias_spec = (output_spec[-1],) bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) # Pre-GeLU output has to be sharded same as the GEMM output - pre_gelu_spec = (None, ) + pre_gelu_spec = (None,) if fuse_gelu and grad: pre_gelu_spec = output_spec pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_spec)) @@ -481,7 +546,7 @@ def _te_gemm( gelu_input: jax.Array = None, lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1, ), (-2, )), + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (-2,)), fuse_bias: bool = False, fuse_gelu: bool = False, grad: bool = False, @@ -848,9 +913,9 @@ def _jax_gemm_fp8_impl(lhs, rhs): lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) if isinstance(lhs_q, ScaledTensor) or isinstance(rhs_q, ScaledTensor): - assert isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor), ( - "Both LHS and RHS must be quantized (or have valid quantizers) for FP8 GEMM." - ) + assert isinstance(lhs_q, ScaledTensor) and isinstance( + rhs_q, ScaledTensor + ), "Both LHS and RHS must be quantized (or have valid quantizers) for FP8 GEMM." return _jax_gemm_fp8_impl(lhs_q, rhs_q) return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) @@ -939,7 +1004,7 @@ def gemm( lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, contracting_dims=contracting_dims, - **kwargs + **kwargs, ) # Discard empty outputs @@ -948,11 +1013,11 @@ def gemm( grad = kwargs.get("grad", False) clean_outputs = outputs[0] # first output is the final result and is never empty if (fuse_bias and grad) or (fuse_gelu and not grad): - clean_outputs = (outputs[0], ) + clean_outputs = (outputs[0],) if fuse_bias and grad: # only return bias gradient if it exists - clean_outputs += (outputs[1], ) + clean_outputs += (outputs[1],) if fuse_gelu and not grad: # only return pre-GeLU output if it exists - clean_outputs += (outputs[2], ) + clean_outputs += (outputs[2],) return clean_outputs diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index f2829c3a58..bf5c257d7b 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -1322,9 +1322,13 @@ def normalization_fwd( raise ValueError(f"{norm_type=} is not supported.") if quantizer is None and noop_scaled_tensor: - return ScaledTensorFactory.create_2x( - output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype - ), mu, rsigma + return ( + ScaledTensorFactory.create_2x( + output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype + ), + mu, + rsigma, + ) return output, mu, rsigma diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 4bb4803fcb..3cb0e1cdfb 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -587,19 +587,25 @@ def _quantize_dbias_impl( if noop_scaled_tensor: # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor() # always works. 
- return ScaledTensorFactory.create_2x( - x, None, x, None, ScalingMode.NO_SCALING, dq_dtype=x.dtype, data_layout="NN", - flatten_axis=flatten_axis, - ), dbias + return ( + ScaledTensorFactory.create_2x( + x, + None, + x, + None, + ScalingMode.NO_SCALING, + dq_dtype=x.dtype, + data_layout="NN", + flatten_axis=flatten_axis, + ), + dbias, + ) return x, dbias # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, # fall back on the native-JAX quantize implementation PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive - if ( - quantizer.q_layout == QuantizeLayout.COLWISE - or not PrimitiveClass.enabled() - ): + if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled(): if is_dbias: return _jax_quantize_dbias( x, @@ -748,7 +754,10 @@ def quantize_dbias( Shape: (K,) or empty if is_dbias is False. """ return _quantize_dbias_impl( - dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis, + dz, + quantizer=quantizer, + is_dbias=is_dbias, + flatten_axis=flatten_axis, noop_scaled_tensor=noop_scaled_tensor, ) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 02822a3f0a..7736e0ec50 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -4,7 +4,6 @@ * See LICENSE for license information. ************************************************************************/ #include "transformer_engine/gemm.h" -#include "transformer_engine/swizzle.h" #include #include @@ -12,8 +11,9 @@ #include "../extensions.h" #include "common/util/cuda_runtime.h" -#include "common/util/system.h" #include "common/util/string.h" +#include "common/util/system.h" +#include "transformer_engine/swizzle.h" #include "xla/ffi/api/c_api.h" #define MXFP8_BLOCK_SIZE 32 @@ -44,8 +44,8 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( std::vector scale_shape(scale_inv.dimensions().begin(), scale_inv.dimensions().end()); auto scale_dtype = (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) - ? DType::kFloat8E8M0 - : convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + ? 
DType::kFloat8E8M0 + : convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); if (rowwise) { input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } else { @@ -58,8 +58,8 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( NVTE_CHECK(swizzled_scale_inv->element_count() > 0, "Missing swizzled inverse scale buffer in the JAX primitive."); auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); - auto swizzled_scale_inv_dtype = convert_ffi_datatype_to_te_dtype( - swizzled_scale_inv->element_type()); + auto swizzled_scale_inv_dtype = + convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type()); NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1, "Inverse scale factors need to have an 8-bit data type."); @@ -92,19 +92,18 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( return std::make_tuple(std::move(input), input_shape); } -Error_Type GemmFFI( - cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, - Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Result_Type output, - Result_Type bias_grad, Result_Type pre_gelu_out, Result_Type lhs_swizzle, - Result_Type rhs_swizzle, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, - int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, - bool fuse_bias, bool fuse_gelu, bool grad) { +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, + Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace, + JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, + int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, + bool fuse_bias, bool fuse_gelu, bool grad) { // Operands (this includes swizzling MXFP8 scaling factors) // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) - bool always_rowwise = ( - scaling_mode == JAXX_Scaling_Mode::NO_SCALING || ( - is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); + bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed; auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( @@ -117,12 +116,14 @@ Error_Type GemmFFI( (rhs_transposed) ? 
rhs_shape[0] : rhs_shape[1]}; auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); - NVTE_CHECK(out_.numel() == output->element_count(), "cuBLAS GEMM output buffer size is incorrect, " - "expected ", out_.numel(), " elements ", to_string_like(out_shape), " but got ", + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, " + "expected ", + out_.numel(), " elements ", to_string_like(out_shape), " but got ", output->element_count(), " elements ", to_string_like(output->dimensions())); // Bias input to forward pass or bias gradient output from backward pass - void* bias_ptr = nullptr; + void *bias_ptr = nullptr; std::vector bias_shape = {0}; DType bias_dtype = out_dtype; if (fuse_bias) { @@ -137,7 +138,7 @@ Error_Type GemmFFI( auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); // Pre-GeLU output from forward pass or input to backward pass - void* pre_gelu_ptr = nullptr; + void *pre_gelu_ptr = nullptr; std::vector pre_gelu_shape = {0}; DType pre_gelu_dtype = out_dtype; if (gelu_input.element_count() > 0) { @@ -168,18 +169,18 @@ Error_Type GemmFFI( XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs - .Arg() // lhs_scale_inv - .Arg() // rhs - .Arg() // rhs_scale_inv - .Arg() // bias - .Arg() // gelu_input - .Ret() // output - .Ret() // bias_grad - .Ret() // pre_gelu_out - .Ret() // lhs_swizzled - .Ret() // rhs_swizzled - .Ret() // workspace + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Ret() // output + .Ret() // bias_grad + .Ret() // pre_gelu_out + .Ret() // lhs_swizzled + .Ret() // rhs_swizzled + .Ret() // workspace .Attr("scaling_mode") .Attr("lhs_axis_boundary") .Attr("rhs_axis_boundary") diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index b4ed3399d9..af7f54feb6 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -48,8 +48,8 @@ enum class JAXX_Scaling_Mode : int64_t { }; inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) { - return (mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING - || mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING); + return (mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING || + mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING); } inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) { diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 0fe8d1bb9a..6f94620d99 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -90,24 +90,25 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, Returns: Tuple of (output, context) for backward pass """ - x_cdims, k_cdims = map( - tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims - ) + x_cdims, k_cdims = map(tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims) x_is_transposed = x.ndim - 1 not in x_cdims k_is_transposed = kernel.ndim - 1 in k_cdims - assert not x_is_transposed and not k_is_transposed, ( - "Forward-mode Dense layer implementation only supports 'NN' layout inputs." - ) + assert ( + not x_is_transposed and not k_is_transposed + ), "Forward-mode Dense layer implementation only supports 'NN' layout inputs." 
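As a concrete illustration of the 'NN' requirement asserted above (the shapes are hypothetical, not taken from the patch):

    import jax.numpy as jnp

    x = jnp.zeros((4, 16, 128), dtype=jnp.bfloat16)     # (batch, seq, hidden_in)
    kernel = jnp.zeros((128, 256), dtype=jnp.bfloat16)  # (hidden_in, hidden_out)
    contracting_dims = ((x.ndim - 1,), (0,))             # contract hidden_in on both operands

    # The last dim of x IS contracted and the last dim of kernel is NOT, so neither
    # operand is treated as transposed and the assertion passes; the GEMM output has
    # shape (*batch dims, hidden_out).
    out_shape = x.shape[:-1] + kernel.shape[1:]
    assert out_shape == (4, 16, 256)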
- casted_x = tex.quantize(x, flatten_axis=min(x_cdims), quantizer=quantizer_set.x, - noop_scaled_tensor=True) + casted_x = tex.quantize( + x, flatten_axis=min(x_cdims), quantizer=quantizer_set.x, noop_scaled_tensor=True + ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) rowwise_x = casted_x.get_rowwise_tensor() colwise_x = casted_x.get_colwise_tensor() casted_kernel = tex.quantize( - kernel, flatten_axis=max(k_cdims) + 1, quantizer=quantizer_set.kernel, - noop_scaled_tensor=True + kernel, + flatten_axis=max(k_cdims) + 1, + quantizer=quantizer_set.kernel, + noop_scaled_tensor=True, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) colwise_k = casted_kernel.get_colwise_tensor() @@ -163,7 +164,10 @@ def _dense_bwd_rule( # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input flatten_axis_grad = len(fwd_x_non_cdims) casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis_grad, quantizer=quantizer_set.dgrad, + grad, + is_dbias=use_bias, + flatten_axis=flatten_axis_grad, + quantizer=quantizer_set.dgrad, noop_scaled_tensor=True, ) @@ -175,10 +179,7 @@ def _dense_bwd_rule( # DGRAD GEMM: (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) dgrad = tex.gemm( - rowwise_g, - rowwise_k, - contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), - grad=True + rowwise_g, rowwise_k, contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), grad=True ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 26b9d8fad5..bd8984ecec 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -197,8 +197,9 @@ def _layernorm_dense_fwd_rule( # Kernel (hidden_in, hidden_out) flatten_axis = 1 - kernel.ndim - casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, - noop_scaled_tensor=True) + casted_kernel = tex.quantize( + kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True + ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) rowwise_kernel = casted_kernel.get_rowwise_tensor() @@ -278,7 +279,10 @@ def _layernorm_dense_bwd_rule( # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input flatten_axis_grad = len(fwd_x_non_cdims) casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis_grad, quantizer=quantizer_set.dgrad, + grad, + is_dbias=use_bias, + flatten_axis=flatten_axis_grad, + quantizer=quantizer_set.dgrad, noop_scaled_tensor=True, ) @@ -291,10 +295,7 @@ def _layernorm_dense_bwd_rule( # DGRAD GEMM: (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) dgrad = tex.gemm( - rowwise_g, - rowwise_kernel, - contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), - grad=True + rowwise_g, rowwise_kernel, contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), grad=True ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, dot_input_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 8a42555a5e..f76205d9f7 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -22,11 +22,7 @@ from . 
import cpp_extensions as tex from .layernorm import canonicalize_norm_type -from .quantize import ( - with_sharding_constraint_by_logical_axes, - QuantizerSet, - noop_quantizer_set -) +from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set from .sharding import get_non_contracting_logical_axes @@ -271,8 +267,9 @@ def _layernorm_mlp_fwd_rule( casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) # FC1 kernel (hidden_in, act_len, hidden_out) - casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, - noop_scaled_tensor=True) + casted_kernel_1 = tex.quantize( + kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True + ) # Prepare FC1 FPROP operands and layouts rowwise_ln_out = casted_ln_out.get_rowwise_tensor() @@ -309,13 +306,15 @@ def _layernorm_mlp_fwd_rule( dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) # Activation (batch..., act_len, hidden_out) -> (batch..., hidden_out) - casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, - noop_scaled_tensor=True) + casted_act_out = tex.act_lu( + dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True + ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) # FC2 kernel (hidden_out, hidden_in) - casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel, - noop_scaled_tensor=True) + casted_kernel_2 = tex.quantize( + kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True + ) # Prepare FC2 FPROP operands and layouts rowwise_act_out = casted_act_out.get_rowwise_tensor() @@ -333,7 +332,7 @@ def _layernorm_mlp_fwd_rule( bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, contracting_dims=(x_cdims, k_cdims), - grad=False + grad=False, ) if use_bias_2 and tex.gemm_uses_jax_dot(): @@ -418,8 +417,11 @@ def _layernorm_mlp_bwd_rule( flatten_axis_grad = len(fwd_x_non_cdims) grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) casted_grad, dbias_2 = tex.quantize_dbias( - grad, is_dbias=use_bias_2, flatten_axis=flatten_axis_grad, - quantizer=ffn2_quantizer_set.dgrad, noop_scaled_tensor=True, + grad, + is_dbias=use_bias_2, + flatten_axis=flatten_axis_grad, + quantizer=ffn2_quantizer_set.dgrad, + noop_scaled_tensor=True, ) # Prepare FC2 DGRAD and WGRAD operands and contracting dims @@ -433,10 +435,7 @@ def _layernorm_mlp_bwd_rule( # FC2 DGRAD GEMM: (batch..., hidden_in) x (hidden_out, hidden_in)^T = (batch..., hidden_out) dgrad_2 = tex.gemm( - rowwise_g, - rowwise_kernel_2, - contracting_dims=(rowwise_g_cdims, fwd_k2_non_cdims), - grad=True + rowwise_g, rowwise_kernel_2, contracting_dims=(rowwise_g_cdims, fwd_k2_non_cdims), grad=True ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) @@ -466,7 +465,9 @@ def _layernorm_mlp_bwd_rule( rowwise_dact_out = casted_dact_out.get_rowwise_tensor() rowwise_dact_out_cdims = tuple(range(flatten_axis_grad, rowwise_dact_out.ndim)) colwise_dact_out = casted_dact_out.get_colwise_tensor() - colwise_dact_out_cdims = tex.get_non_contracting_dims(casted_dact_out.ndim, rowwise_dact_out_cdims) + colwise_dact_out_cdims = tex.get_non_contracting_dims( + casted_dact_out.ndim, rowwise_dact_out_cdims + ) fwd_k1_non_cdims = tex.get_non_contracting_dims(rowwise_kernel_1.ndim, fwd_k_cdims) # FC1 DGRAD GEMM: @@ -475,7 +476,7 @@ def 
_layernorm_mlp_bwd_rule( rowwise_dact_out, rowwise_kernel_1, contracting_dims=(rowwise_dact_out_cdims, fwd_k1_non_cdims), - grad=True + grad=True, ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 4761816220..1c5d77c05b 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -142,7 +142,7 @@ def __post_init__(self): ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" if self.scaling_mode == ScalingMode.NO_SCALING: - self.scale_inv = jnp.empty((1, ), dtype=jnp.float32) + self.scale_inv = jnp.empty((1,), dtype=jnp.float32) else: expected_scale_shape = self.scaling_mode.get_scale_shape( @@ -158,8 +158,10 @@ def __post_init__(self): f" {self.scale_inv.shape}" ) expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=False, - flatten_axis=self.flatten_axis + self.data.shape, + self.is_colwise, + is_padded=False, + flatten_axis=self.flatten_axis, ) if self.scale_inv.shape != expected_scale_shape: assert self.scale_inv.shape == expected_unpadded_scale_shape, ( @@ -168,8 +170,8 @@ def __post_init__(self): f" {self.scale_inv.shape}" ) pad_width = tuple( - (0, a - b) for a, b in zip(expected_scale_shape, - expected_unpadded_scale_shape) + (0, a - b) + for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) ) # This actually pad scale_inv with nan, should we pad it with 127 directly instead? self.scale_inv = jnp.pad( From 92dec51a5bb293e265ae5315f3d4865e8246479e Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 13 Jun 2025 06:55:20 +0000 Subject: [PATCH 04/30] FP8 tests passing on Blackwell but MXFP8 outputs NaN Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d553da10f3..81c64359b1 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -123,34 +123,36 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ lhs_q = lhs rhs_q = rhs if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: + scaling_mode = lhs_quantizer.scaling_mode lhs_q = lhs_quantizer.quantize( lhs, - is_rowwise=True, - is_colwise=False, + is_rowwise=True if scaling_mode.is_tensor_scaling() else not lhs_is_transposed, + is_colwise=False if scaling_mode.is_tensor_scaling() else lhs_is_transposed, flatten_axis=( max(lhs_contracting_dims) + 1 if lhs_is_transposed else min(lhs_contracting_dims) ), ) - if lhs_is_transposed: + if lhs_is_transposed and scaling_mode.is_tensor_scaling(): # Manually update data layout and columnwise flag to avoid transposing already # transposed data lhs_q.data_layout = "T" lhs_q.is_colwise = True if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: + scaling_mode = rhs_quantizer.scaling_mode rhs_q = rhs_quantizer.quantize( rhs, - is_rowwise=True, - is_colwise=False, + is_rowwise=True if scaling_mode.is_tensor_scaling() else rhs_is_transposed, + is_colwise=False if scaling_mode.is_tensor_scaling() else not rhs_is_transposed, flatten_axis=( min(rhs_contracting_dims) if rhs_is_transposed else max(rhs_contracting_dims) + 1 ), ) - if not rhs_is_transposed: + if not rhs_is_transposed and 
scaling_mode.is_tensor_scaling(): # Manually update data layout and columnwise flag to avoid transposing already # transposed data rhs_q.data_layout = "T" @@ -515,7 +517,7 @@ def _te_gemm( scaling_mode = lhs_q.scaling_mode lhs_data = lhs_q.data lhs_scale_inv = lhs_q.scale_inv - if lhs_q.data_layout == "T": + if lhs_is_transposed and lhs_q.data_layout == "T" and scaling_mode.is_tensor_scaling(): lhs_contracting_dims = transpose_contracting_dims(lhs_q.ndim, lhs_contracting_dims) if isinstance(rhs_q, ScaledTensor): @@ -535,7 +537,7 @@ def _te_gemm( ) rhs_data = rhs_q.data rhs_scale_inv = rhs_q.scale_inv - if rhs_q.data_layout == "T": + if rhs_is_transposed and rhs_q.data_layout == "T" and scaling_mode.is_tensor_scaling(): rhs_contracting_dims = transpose_contracting_dims(rhs_q.ndim, rhs_contracting_dims) # Dummy empties for bias and gelu From 9eba586c3ee1d0453ad34bb183b6b1bef5edfcce Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Sat, 14 Jun 2025 06:21:22 +0000 Subject: [PATCH 05/30] reverted dense and fuseddense changes, FP8 test passing on Hopper and Blackwell, MXFP8 has issues with E5M2 Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 150 ++++++------- transformer_engine/jax/cpp_extensions/gemm.py | 102 ++++----- .../jax/csrc/extensions/gemm.cpp | 6 +- transformer_engine/jax/dense.py | 100 ++++----- transformer_engine/jax/layernorm_dense.py | 115 +++++----- transformer_engine/jax/layernorm_mlp.py | 201 ++++++++---------- transformer_engine/jax/quantize/tensor.py | 2 +- 7 files changed, 303 insertions(+), 373 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 002a61c715..ccb667b8f2 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -2,15 +2,13 @@ # # See LICENSE for license information. 
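One note on the FP8 pairings exercised by the test changes below: cuBLAS accepts E4M3 x E4M3 and mixed E4M3/E5M2 operands, but not E5M2 x E5M2. This is the rule `_compatible_fp8_gemm_dtypes` implements in gemm.py, and the `valid_fp8_gemm_operand_types` list added further down in this file mirrors it. A minimal sketch of the check (the helper name is illustrative, not part of the patch):

    import jax.numpy as jnp

    def example_valid_fp8_pair(lhs_dtype, rhs_dtype):
        e4m3, e5m2 = jnp.float8_e4m3fn, jnp.float8_e5m2
        # At least one operand must be E4M3.
        return (lhs_dtype == e4m3 and rhs_dtype in (e4m3, e5m2)) or (
            lhs_dtype in (e4m3, e5m2) and rhs_dtype == e4m3
        )

    assert example_valid_fp8_pair(jnp.float8_e4m3fn, jnp.float8_e5m2)
    assert not example_valid_fp8_pair(jnp.float8_e5m2, jnp.float8_e5m2)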
-import os -import operator -from functools import reduce -from typing import Union - -import pytest import jax import jax.numpy as jnp +import pytest from jax import jit, value_and_grad +from functools import reduce +from typing import Union +import operator from utils import ( assert_allclose, @@ -46,8 +44,6 @@ from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense -from transformer_engine_jax import is_non_nt_fp8_gemm_supported - GEMM_CASES = [ (256, 256, 512), (32, 32, 32), @@ -62,14 +58,10 @@ is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) supported_scaling_modes = [] -supported_fp8_gemm_layouts = [] """ Find supported scaling modes""" if is_fp8_supported: supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING) supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING) - supported_fp8_gemm_layouts.append("NT") - if is_non_nt_fp8_gemm_supported(): - supported_fp8_gemm_layouts += ["TT", "TN", "NN"] if is_mxfp8_supported: supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING) @@ -80,7 +72,7 @@ def is_shape_supported_by_mxfp8(input_shape): input_shape = input_shape.values[0] ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) return True - except AssertionError: + except: # get_scale_shapes will raise an exception if the shape is not supported return False @@ -154,13 +146,6 @@ def assert_dequantized_grouped_scaled_tensor( pytest.fail("a must be a GroupedScaledTensor object") -def use_jax_dot_for_gemm(enabled=False): - if enabled: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" - elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: - os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") - - ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)] ALL_ACTIVATION_TYPES = [ ("gelu",), @@ -638,15 +623,15 @@ def test_quantize_bitwise( ): key = jax.random.PRNGKey(0) - inp = jax.random.uniform(key, input_shape, in_dtype) + input = jax.random.uniform(key, input_shape, in_dtype) te_quantizer, jax_quantizer = QuantizerFactory.create( n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout ) - jax_output = _jax_quantize(inp, quantizer=jax_quantizer, flatten_axis=flatten_axis) + jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) - te_output = tex.quantize(inp, quantizer=te_quantizer, flatten_axis=flatten_axis) + te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis) assert_bitwise_scaled_tensors(te_output, jax_output) @@ -724,21 +709,23 @@ def test_quantize_dbias( pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") key = jax.random.PRNGKey(0) - inp = jax.random.uniform(key, input_shape, in_dtype) + input = jax.random.uniform(key, input_shape, in_dtype) jax_quantizer, te_quantizer = QuantizerFactory.create( n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout ) te_output, te_dbias = jit( - lambda input: tex.quantize_dbias(inp, quantizer=te_quantizer, flatten_axis=flatten_axis) - )(inp) + lambda input: tex.quantize_dbias( + input, quantizer=te_quantizer, flatten_axis=flatten_axis + ) + )(input) jax_output, jax_dbias = jit( lambda input: _jax_quantize_dbias( - inp, quantizer=jax_quantizer, flatten_axis=flatten_axis + input, quantizer=jax_quantizer, flatten_axis=flatten_axis ) - )(inp) + )(input) assert_bitwise_scaled_tensors(te_output, jax_output) @@ -863,6 +850,21 @@ def test_quantize_dact_dbias_mxfp8_scaling( ) 
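The `_use_jax_fp8_gemm` helper added below drives the fallback path through the `NVTE_JAX_CUSTOM_CALLS_RE` filter: the negative lookahead matches every primitive name except `GemmPrimitive`, so only the GEMM custom call is disabled and `gemm()` routes through `jax.lax.dot_general`, while the other TE custom calls stay active. A quick standalone illustration of the pattern (not part of the patch):

    import re

    pattern = re.compile(r"^(?!GemmPrimitive$).+$")

    assert pattern.match("GemmPrimitive") is None               # GEMM custom call filtered out
    assert pattern.match("DBiasQuantizePrimitive") is not None  # other TE primitives still match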
+valid_fp8_gemm_operand_types = [ + (jnp.float8_e4m3fn, jnp.float8_e4m3fn), + (jnp.float8_e5m2, jnp.float8_e4m3fn), + (jnp.float8_e4m3fn, jnp.float8_e5m2), +] + + +def _use_jax_fp8_gemm(enabled=False): + import os + if enabled: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") + + class TestDense: def _ref_gemm_with_jnp_dot(self, a, b, data_layout): if data_layout[0] == "T": @@ -892,10 +894,7 @@ def _generate_gemm_input(self, m, n, k, data_layout): @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) - @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_gemm_bf16(self, m, n, k, data_layout, with_jax_gemm): - use_jax_dot_for_gemm(enabled=with_jax_gemm) - + def test_gemm_bf16(self, m, n, k, data_layout): x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) primitive_out = tex.gemm(x, w, contracting_dims) @@ -904,55 +903,36 @@ def test_gemm_bf16(self, m, n, k, data_layout, with_jax_gemm): assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper( - "lhs_q_dtype,rhs_q_dtype", - [ - (jnp.float8_e4m3fn, jnp.float8_e4m3fn), # fprop GEMM - (jnp.float8_e4m3fn, jnp.float8_e5m2), # wgrad GEMM - (jnp.float8_e5m2, jnp.float8_e4m3fn), # dgrad GEMM - ], - ) + @pytest_parametrize_wrapper("m,n,k", [(256, 512, 1024)]) + @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - @pytest_parametrize_wrapper("data_layout", supported_fp8_gemm_layouts) + @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_gemm_fp8( - self, m, n, k, lhs_q_dtype, rhs_q_dtype, scaling_mode, data_layout, with_jax_gemm - ): - use_jax_dot_for_gemm(enabled=with_jax_gemm) - - lhs, rhs, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) + def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm): + _use_jax_fp8_gemm(enabled=with_jax_gemm) + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) quantizer_set = QuantizerFactory.create_set( scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, bwd_dtype=jnp.float8_e5m2, - is_2x2x=False, - ) - lhs_quantizer = quantizer_set.x if lhs_q_dtype == jnp.float8_e4m3fn else quantizer_set.dgrad - rhs_quantizer = ( - quantizer_set.kernel if rhs_q_dtype == jnp.float8_e4m3fn else quantizer_set.dgrad + is_2x2x=False ) - primitive_out = tex.gemm( - lhs, - rhs, - lhs_quantizer=lhs_quantizer, - rhs_quantizer=rhs_quantizer, + x, + w, contracting_dims=contracting_dims, + lhs_quantizer=quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad, + rhs_quantizer=( + quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad + ), ) - ref_out = self._ref_gemm_with_jnp_dot(lhs, rhs, data_layout) + ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) - test_q_dtype = ( - jnp.float8_e5m2 if jnp.float8_e5m2 in (lhs_q_dtype, rhs_q_dtype) else jnp.float8_e4m3fn - ) - assert_allclose(primitive_out, ref_out, dtype=test_q_dtype) + assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - 
@pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_dense_grad_bf16(self, m, n, k, with_jax_gemm): - use_jax_dot_for_gemm(enabled=with_jax_gemm) - + def test_dense_grad_bf16(self, m, n, k): data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -981,7 +961,7 @@ def ref_func(x, w, data_layout): @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm): - use_jax_dot_for_gemm(enabled=with_jax_gemm) + _use_jax_fp8_gemm(enabled=with_jax_gemm) data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -1004,10 +984,9 @@ def ref_func(x, w, bias, data_layout): value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2, - is_2x2x=True, + scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, + is_2x2x=True ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 @@ -1054,7 +1033,7 @@ def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_g """ Test layernorm_dense VJP Rule """ - use_jax_dot_for_gemm(enabled=with_jax_gemm) + _use_jax_fp8_gemm(enabled=with_jax_gemm) # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False @@ -1072,7 +1051,7 @@ def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_g quantizer_set = QuantizerFactory.create_set( scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2, + bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, is_2x2x=True, ) @@ -1130,7 +1109,7 @@ def ref_func(x, w, gamma, beta): @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) - @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest.mark.parametrize("use_bias", [True, False]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm @@ -1138,7 +1117,7 @@ def test_layernorm_mlp_grad( """ Test layernorm_mlp VJP Rule """ - use_jax_dot_for_gemm(enabled=with_jax_gemm) + _use_jax_fp8_gemm(enabled=with_jax_gemm) # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False @@ -1165,7 +1144,7 @@ def test_layernorm_mlp_grad( n_quantizer_sets=2, scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2, + bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, is_2x2x=True, ) @@ -1190,26 +1169,25 @@ def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ) ) - def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): - dim_nums = ((1,), (0,)), ((), ()) - + def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ln_out = _ref_jax_norm_impl( x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None ) - - linear_1_out = jax.lax.dot_general(ln_out, kernel_1, dim_nums) + linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ()))) if use_bias: bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape linear_1_out += 
jnp.reshape(bias_1, bias_1_shape) - act_out = _jax_act_lu(linear_1_out, activation_type) - - linear_2_out = jax.lax.dot_general(act_out, kernel_2, dim_nums) + x = _jax_act_lu(linear_1_out, activation_type) + linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ()))) if use_bias: bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape linear_2_out += jnp.reshape(bias_2, bias_2_shape) - return jnp.mean(linear_2_out) + return linear_2_out + + def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): + return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2)) value_n_grad_prim_func = value_and_grad(prim_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6)) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 2c6794e6bc..2475264c7c 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -18,7 +18,7 @@ from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams from .base import BasePrimitive, register_primitive -from .quantization import grouped_quantize +from .quantization import quantize, grouped_quantize from ..quantize import ( ScaledTensor, ScaledTensor2x, @@ -81,21 +81,16 @@ def transpose_contracting_dims(ndim, contracting_dims): def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool: - lhs, rhs, e4m3, e5m2, e8m0 = map( + lhs, rhs, e4m3, e5m2 = map( dtypes.canonicalize_dtype, ( lhs_dtype, rhs_dtype, jnp.float8_e4m3fn, jnp.float8_e5m2, - jnp.uint8, # replace with jnp.float8_e8m0 when JAX/XLA merges support ), ) - # MXFP8 GEMM needs both operands to be MXFP8 (uint8 for now until JAX merges float8_e8m0) - if lhs is e8m0 and rhs is e8m0: - return True - # FP8 GEMM supports (e4m3 x e4m3), (e4m3 x e5m2) and (e5m2 x e4m3) if (lhs is e4m3 and rhs in (e4m3, e5m2)) or (lhs in (e4m3, e5m2) and rhs is e4m3): return True @@ -114,44 +109,37 @@ def _get_gemm_layout( def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims): - lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) - lhs_contracting_dims, rhs_contracting_dims = map( - sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims - ) - lhs_q = lhs rhs_q = rhs if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: - scaling_mode = lhs_quantizer.scaling_mode + lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0]) + lhs_is_rowwise = lhs.ndim - 1 in lhs_cdims + flatten_axis = ( + min(lhs_cdims) if lhs_is_rowwise else max(lhs_cdims) + 1 + ) lhs_q = lhs_quantizer.quantize( lhs, - is_rowwise=True if scaling_mode.is_tensor_scaling() else not lhs_is_transposed, - is_colwise=False if scaling_mode.is_tensor_scaling() else lhs_is_transposed, - flatten_axis=( - max(lhs_contracting_dims) + 1 if lhs_is_transposed else min(lhs_contracting_dims) - ), + is_rowwise=lhs_is_rowwise, + is_colwise=not lhs_is_rowwise, + flatten_axis=flatten_axis, ) - if lhs_is_transposed and scaling_mode.is_tensor_scaling(): - # Manually update data layout and columnwise flag to avoid transposing already - # transposed data - lhs_q.data_layout = "T" - lhs_q.is_colwise = True + if isinstance(lhs_q, ScaledTensor2x): + lhs_q = lhs_q.get_rowwise_tensor() if lhs_is_rowwise else lhs_q.get_colwise_tensor() if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: - scaling_mode = rhs_quantizer.scaling_mode + rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) + 
rhs_is_rowwise = rhs.ndim - 1 in rhs_cdims + flatten_axis = ( + min(rhs_cdims) if rhs_is_rowwise else max(rhs_cdims) + 1 + ) rhs_q = rhs_quantizer.quantize( rhs, - is_rowwise=True if scaling_mode.is_tensor_scaling() else rhs_is_transposed, - is_colwise=False if scaling_mode.is_tensor_scaling() else not rhs_is_transposed, - flatten_axis=( - min(rhs_contracting_dims) if rhs_is_transposed else max(rhs_contracting_dims) + 1 - ), + is_rowwise=rhs_is_rowwise, + is_colwise=not rhs_is_rowwise, + flatten_axis=flatten_axis, ) - if not rhs_is_transposed and scaling_mode.is_tensor_scaling(): - # Manually update data layout and columnwise flag to avoid transposing already - # transposed data - rhs_q.data_layout = "T" - rhs_q.is_colwise = True + if isinstance(rhs_q, ScaledTensor2x): + rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_rowwise else rhs_q.get_colwise_tensor() return lhs_q, rhs_q @@ -270,7 +258,7 @@ def abstract( # Need extra workspace for swizzled scale factors lhs_swizzle_size = 0 rhs_swizzle_size = 0 - swizzle_dtype = jnp.uint8 # replace with jnp.float8_e8m0 when JAX merges support + swizzle_dtype = jnp.uint8 if scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_swizzle_size = lhs_scale_inv.size rhs_swizzle_size = rhs_scale_inv.size @@ -535,7 +523,6 @@ def partition( register_primitive(GemmPrimitive) -@lru_cache(maxsize=1) def gemm_uses_jax_dot() -> bool: """Check if the GEMM call directs to the TE custom cuBLAS call or native JAX dot.""" return not GemmPrimitive.enabled() @@ -560,7 +547,7 @@ def _te_gemm( rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) scaling_mode = ScalingMode.NO_SCALING lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) - lhs_contracting_dims, rhs_contracting_dims = map( + lhs_cdims, rhs_cdims = map( sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims ) @@ -574,16 +561,13 @@ def _te_gemm( "`Quantizer` object to quantize the RHS operand." ) if isinstance(lhs_q, ScaledTensor2x): - # Contracting dimensions for a ScaledTensor2x is interpreted relative to the row-wise - # shape. Since we have access to both row-wise and column-wise tensors, we always - # choose the one that avoids transposing LHS in the GEMM kernel to comply with the - # NT-layout restriction for FP8 GEMM on Hopper. + # Choose the quantization of the contracting dimension(s) lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() scaling_mode = lhs_q.scaling_mode lhs_data = lhs_q.data lhs_scale_inv = lhs_q.scale_inv - if lhs_is_transposed and lhs_q.data_layout == "T" and scaling_mode.is_tensor_scaling(): - lhs_contracting_dims = transpose_contracting_dims(lhs_q.ndim, lhs_contracting_dims) + if lhs_q.data_layout == "T": + lhs_cdims = transpose_contracting_dims(lhs_q.ndim, lhs_cdims) if isinstance(rhs_q, ScaledTensor): assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( @@ -591,10 +575,7 @@ def _te_gemm( "`Quantizer` object to quantize the LHS operand." ) if isinstance(rhs_q, ScaledTensor2x): - # Contracting dimensions for a ScaledTensor2x is interpreted relative to the row-wise - # shape. Since we have access to both row-wise and column-wise tensors, we always - # choose the one that avoids transposing LHS in the GEMM kernel to comply with the - # NT-layout restriction for FP8 GEMM on Hopper. 
+ # Choose the quantization of the contracting dimension(s) rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() assert rhs_q.scaling_mode == lhs_q.scaling_mode, ( "cuBLAS GEMM quantized operands have mismatched scaling types, " @@ -602,8 +583,8 @@ def _te_gemm( ) rhs_data = rhs_q.data rhs_scale_inv = rhs_q.scale_inv - if rhs_is_transposed and rhs_q.data_layout == "T" and scaling_mode.is_tensor_scaling(): - rhs_contracting_dims = transpose_contracting_dims(rhs_q.ndim, rhs_contracting_dims) + if rhs_q.data_layout == "T": + rhs_cdims = transpose_contracting_dims(rhs_q.ndim, rhs_cdims) # Dummy empties for bias and gelu out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype @@ -620,7 +601,7 @@ def _te_gemm( bias, gelu_input, out_dtype=out_dtype, - contracting_dims=(lhs_contracting_dims, rhs_contracting_dims), + contracting_dims=(lhs_cdims, rhs_cdims), scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -911,16 +892,21 @@ def _jax_gemm_fp8_impl(lhs, rhs): raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") - # Quantize operands (if necessary) - lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) + lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, + contracting_dims) - if isinstance(lhs_q, ScaledTensor) or isinstance(rhs_q, ScaledTensor): - assert isinstance(lhs_q, ScaledTensor) and isinstance( - rhs_q, ScaledTensor - ), "Both LHS and RHS must be quantized (or have valid quantizers) for FP8 GEMM." + if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor): return _jax_gemm_fp8_impl(lhs_q, rhs_q) - return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) + if ( + isinstance(lhs, jnp.ndarray) + and isinstance(rhs, jnp.ndarray) + and lhs_quantizer is None + and rhs_quantizer is None + ): + return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) + + raise NotImplementedError("Not supporting multiplication of ScaledTensor and jnp.array") def gemm( @@ -987,7 +973,7 @@ def gemm( rhs_quantizer = quantizer_set.kernel # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled - if gemm_uses_jax_dot(): + if not GemmPrimitive.enabled(): assert kwargs.get("bias", None) is None and not kwargs.get("fuse_bias", False), ( "TE GEMM was invoked with bias fusion options that are not supported by the " "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 7736e0ec50..987b0e817d 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -42,7 +42,9 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands."); NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM."); - std::vector scale_shape(scale_inv.dimensions().begin(), scale_inv.dimensions().end()); + auto scale_dims = scale_inv.dimensions(); + std::vector scale_shape = {product(scale_dims, 0, axis_boundary), + product(scale_dims, axis_boundary, scale_dims.size())}; auto scale_dtype = (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) ? 
DType::kFloat8E8M0 : convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); @@ -53,7 +55,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( } // Swizzle scaling factors for MXFP8 - if (is_block_scaling(scaling_mode)) { + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { // Get the swizzle buffer NVTE_CHECK(swizzled_scale_inv->element_count() > 0, "Missing swizzled inverse scale buffer in the JAX primitive."); diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 6f94620d99..2aeb46594b 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -90,40 +90,29 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, Returns: Tuple of (output, context) for backward pass """ - x_cdims, k_cdims = map(tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims) - x_is_transposed = x.ndim - 1 not in x_cdims - k_is_transposed = kernel.ndim - 1 in k_cdims - assert ( - not x_is_transposed and not k_is_transposed - ), "Forward-mode Dense layer implementation only supports 'NN' layout inputs." - - casted_x = tex.quantize( - x, flatten_axis=min(x_cdims), quantizer=quantizer_set.x, noop_scaled_tensor=True - ) + x_contracting_dims, k_contracting_dims = contracting_dims + + flatten_axis_x = -len(x_contracting_dims) + flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + + casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, + noop_scaled_tensor=True) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) - rowwise_x = casted_x.get_rowwise_tensor() - colwise_x = casted_x.get_colwise_tensor() casted_kernel = tex.quantize( - kernel, - flatten_axis=max(k_cdims) + 1, - quantizer=quantizer_set.kernel, + kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel, noop_scaled_tensor=True, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) - colwise_k = casted_kernel.get_colwise_tensor() - rowwise_k = casted_kernel.get_rowwise_tensor() - # FPROP GEMM: (batch..., hidden_in) x (hidden_in, hidden_out) = (batch..., hidden_out) - # FPROP FP8 GEMM: (batch..., hidden_in) x (hidden_out, hidden_in)^T = (batch..., hidden_out) + # GEMM NN use_bias = bias is not None output = tex.gemm( - rowwise_x, - colwise_k, + casted_x.get_rowwise_tensor(), + casted_kernel.get_colwise_tensor(), + contracting_dims=(x_contracting_dims, k_contracting_dims), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, - contracting_dims=(x_cdims, k_cdims), - grad=False, ) if use_bias and tex.gemm_uses_jax_dot(): @@ -131,10 +120,13 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, output += jnp.reshape(bias, bias_new_shape) ctx = ( - colwise_x, - rowwise_k, + casted_x.get_colwise_tensor(), + casted_kernel.get_rowwise_tensor(), + x.shape, + kernel.shape, use_bias, quantizer_set, + flatten_axis_k, ) return output, ctx @@ -147,49 +139,49 @@ def _dense_bwd_rule( Returns: Tuple of gradients with respect to inputs """ + fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims + ( - colwise_x, - rowwise_k, + colwise_casted_x, + rowwise_casted_kernel, + x_shape, + kernel_shape, use_bias, quantizer_set, + flatten_axis_k, ) = ctx - # Original non-contracting dimensions in the forward pass are contracting dimensions for the - # backward pass. 
- fwd_x_cdims, fwd_k_cdims = map( - tex.sanitize_dims, (colwise_x.ndim, rowwise_k.ndim), contracting_dims - ) - fwd_x_non_cdims = tex.get_non_contracting_dims(colwise_x.ndim, fwd_x_cdims) - fwd_k_non_cdims = tex.get_non_contracting_dims(rowwise_k.ndim, fwd_k_cdims) - # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input - flatten_axis_grad = len(fwd_x_non_cdims) casted_grad, dbias = tex.quantize_dbias( - grad, - is_dbias=use_bias, - flatten_axis=flatten_axis_grad, - quantizer=quantizer_set.dgrad, + grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, noop_scaled_tensor=True, ) - # Prepare DGRAD and WGRAD operands and contracting dims - rowwise_g = casted_grad.get_rowwise_tensor() - rowwise_g_cdims = tuple(range(flatten_axis_grad, grad.ndim)) - colwise_g = casted_grad.get_colwise_tensor() - colwise_g_cdims = tex.get_non_contracting_dims(grad.ndim, rowwise_g_cdims) - - # DGRAD GEMM: (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) + # GEMM NT + # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim + g_contracting_dim = tuple( + range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) + ) + # k_non_contracting_dims + k_contracting_dim = tuple( + dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims + ) dgrad = tex.gemm( - rowwise_g, rowwise_k, contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), grad=True + casted_grad.get_rowwise_tensor(), + rowwise_casted_kernel, + contracting_dims=(g_contracting_dim, k_contracting_dim), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) - # WGRAD GEMM: (batch..., hidden_in)^T x (batch..., hidden_out) = (hidden_in, hidden_out) - # WGRAD FP8 GEMM: (hidden_in, batch...) 
x (hidden_out, batch...)^T = (hidden_in, hidden_out) + # GEMM TN + # x_non_contracting_dims + g_contracting_dim = x_contracting_dim = tuple( + range(0, len(x_shape) - len(fwd_x_contracting_dims)) + ) + wgrad = tex.gemm( - colwise_x, - colwise_g, - contracting_dims=(fwd_x_non_cdims, colwise_g_cdims), - grad=True, + colwise_casted_x, + casted_grad.get_colwise_tensor(), + contracting_dims=(x_contracting_dim, g_contracting_dim) ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index bd8984ecec..96e66ff946 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -173,12 +173,12 @@ def _layernorm_dense_fwd_rule( Returns: Tuple of (output, context) for automatic differentiation """ - x_cdims = (x.ndim - 1,) - k_cdims = (0,) + x_contracting_dims = (len(x.shape) - 1,) + k_contracting_dims = (0,) assert x.shape[-1] == kernel.shape[0] - # Apply layernorm with quantized output if quantizer_set is given x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) + casted_ln_out, mu, rsigma = tex.normalization_fwd( x, gamma, @@ -186,52 +186,47 @@ def _layernorm_dense_fwd_rule( zero_centered_gamma, epsilon, norm_type, - quantizer_set.x, - noop_scaled_tensor=True, + quantizer=quantizer_set.x, + noop_scaled_tensor=True ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) - # Layernorm output (batch..., hidden_in) - rowwise_ln_out = casted_ln_out.get_rowwise_tensor() - colwise_ln_out = casted_ln_out.get_colwise_tensor() - - # Kernel (hidden_in, hidden_out) - flatten_axis = 1 - kernel.ndim - casted_kernel = tex.quantize( - kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True - ) + # Kernel in (hidden_in, hidden_out...) + flatten_axis = 1 - len(kernel.shape) + casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, + noop_scaled_tensor=True) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) - rowwise_kernel = casted_kernel.get_rowwise_tensor() - colwise_kernel = casted_kernel.get_colwise_tensor() - - # FPROP GEMM: (batch..., hidden_in) x (hidden_in, hidden_out) = (batch..., hidden_out) - # FPROP FP8 GEMM: (batch..., hidden_in) x (hidden_out, hidden_in)^T = (batch..., hidden_out) + # NN GEMM + # (batch..., hidden_in) x (hidden_in, hidden_out...) 
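The flatten_axis values chosen in these forward rules (-len(x_contracting_dims) for the activation input, 1 - kernel.ndim for the kernel) assume the quantizer treats the tensor as a 2D matrix split at that axis, so the contracting dims land entirely on one side of the split. A rough, self-contained illustration of that 2D view under this assumption (shapes are made up for the example):

from math import prod

x_shape      = (2, 4, 8)    # (batch..., hidden_in); contracting dim is the last one
kernel_shape = (8, 3, 16)   # (hidden_in, hidden_out...); contracting dim is the first one

flatten_axis_x = -1                       # -len(x_contracting_dims)
flatten_axis_k = 1 - len(kernel_shape)    # -2, i.e. positive axis 1

def as_2d(shape, flatten_axis):
    """Collapse a shape into the 2D matrix view split at flatten_axis."""
    axis = flatten_axis % len(shape)
    return (prod(shape[:axis]), prod(shape[axis:]))

print(as_2d(x_shape, flatten_axis_x))       # (8, 8):  (batch...) x (hidden_in)
print(as_2d(kernel_shape, flatten_axis_k))  # (8, 48): (hidden_in) x (hidden_out...)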
use_bias = bias is not None output = tex.gemm( - rowwise_ln_out, - colwise_kernel, + casted_ln_out.get_rowwise_tensor(), + casted_kernel.get_colwise_tensor(), + contracting_dims=(x_contracting_dims, k_contracting_dims), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, - contracting_dims=(x_cdims, k_cdims), - grad=False, ) + if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) ctx = ( - colwise_ln_out, - rowwise_kernel, + casted_ln_out.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None, + casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None, + x.shape, + kernel.shape, mu, rsigma, x, gamma, beta, - x_cdims, - k_cdims, + x_contracting_dims, + k_contracting_dims, use_bias, quantizer_set, + flatten_axis, ) return output, ctx @@ -242,7 +237,7 @@ def _layernorm_dense_bwd_rule( zero_centered_gamma, epsilon, layernorm_input_axes, - dot_input_axes, + dot_input_axes, # pylint: disable=unused-argument kernel_axes, ctx, grad, @@ -259,57 +254,58 @@ def _layernorm_dense_bwd_rule( Tuple of gradients for all input parameters """ ( - colwise_ln_out, - rowwise_kernel, + colwise_casted_ln_out, + rowwise_casted_kernel, + x_shape, + kernel_shape, mu, rsigma, x, gamma, beta, - fwd_x_cdims, - fwd_k_cdims, + x_contracting_dims_in_fwd, + k_contracting_dims_in_fwd, use_bias, quantizer_set, + flatten_axis, ) = ctx - # Original non-contracting dimensions in the forward pass are contracting dimensions for the - # backward pass. - fwd_x_non_cdims = tex.get_non_contracting_dims(colwise_ln_out.ndim, fwd_x_cdims) - fwd_k_non_cdims = tex.get_non_contracting_dims(rowwise_kernel.ndim, fwd_k_cdims) - # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input - flatten_axis_grad = len(fwd_x_non_cdims) casted_grad, dbias = tex.quantize_dbias( - grad, - is_dbias=use_bias, - flatten_axis=flatten_axis_grad, - quantizer=quantizer_set.dgrad, - noop_scaled_tensor=True, + grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad, + noop_scaled_tensor=True ) - # Prepare DGRAD and WGRAD operands and contracting dims - rowwise_g = casted_grad.get_rowwise_tensor() - rowwise_g_cdims = tuple(range(flatten_axis_grad, grad.ndim)) - colwise_g = casted_grad.get_colwise_tensor() - colwise_ln_out_cdims = fwd_x_non_cdims - colwise_g_cdims = tex.get_non_contracting_dims(grad.ndim, rowwise_g_cdims) + # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim + g_constracting_dim = tuple( + range(grad.ndim - len(kernel_shape) + len(k_contracting_dims_in_fwd), grad.ndim) + ) + # k_non_contracting_dims + k_constracting_dim = tuple( + dim for dim in range(len(kernel_shape)) if dim not in k_contracting_dims_in_fwd + ) - # DGRAD GEMM: (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) + # NT GEMM dgrad = tex.gemm( - rowwise_g, rowwise_kernel, contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), grad=True + casted_grad.get_rowwise_tensor(), + rowwise_casted_kernel, + contracting_dims=(g_constracting_dim, k_constracting_dim), + ) + + dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) + + g_constracting_dim = x_constracting_dim = tuple( + range(0, len(x_shape) - len(x_contracting_dims_in_fwd)) ) - dgrad = with_sharding_constraint_by_logical_axes(dgrad, dot_input_axes) - # WGRAD GEMM: (batch..., hidden_in)^T x (batch..., 
hidden_out) = (hidden_in, hidden_out) - # WGRAD FP8 GEMM: (hidden_in, batch...) x (hidden_out, batch...)^T = (hidden_in, hidden_out) + # TN GEMM wgrad = tex.gemm( - colwise_ln_out, - colwise_g, - contracting_dims=(colwise_ln_out_cdims, colwise_g_cdims), - grad=True, + colwise_casted_ln_out, + casted_grad.get_colwise_tensor(), + contracting_dims=(x_constracting_dim, g_constracting_dim), ) + wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) - # Layernorm gradient dx, dgamma, dbeta = tex.normalization_bwd( dgrad, x, @@ -321,7 +317,6 @@ def _layernorm_dense_bwd_rule( epsilon=epsilon, norm_type=norm_type, ) - dx = with_sharding_constraint_by_logical_axes(dx, layernorm_input_axes) return dx, wgrad, dgamma, dbeta, dbias, quantizer_set diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index f76205d9f7..8deb83d87a 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -244,16 +244,16 @@ def _layernorm_mlp_fwd_rule( assert len(kernel_2.shape) == 2 assert kernel_1.shape[-2] == len(activation_type) - x_cdims = (x.ndim - 1,) - k_cdims = (0,) + x_contracting_dims = (len(x.shape) - 1,) + k_contracting_dims = (0,) - assert x.shape[x_cdims[0]] == kernel_1.shape[k_cdims[0]] + assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] use_bias_1 = bias_1 is not None use_bias_2 = bias_1 is not None - # Apply layernorm with quantized output if quantizer_set is given x = with_sharding_constraint_by_logical_axes(x, norm_input_axes) + casted_ln_out, mu, rsigma = tex.normalization_fwd( x, gamma, @@ -266,35 +266,23 @@ def _layernorm_mlp_fwd_rule( ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) - # FC1 kernel (hidden_in, act_len, hidden_out) - casted_kernel_1 = tex.quantize( - kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True - ) - - # Prepare FC1 FPROP operands and layouts - rowwise_ln_out = casted_ln_out.get_rowwise_tensor() - rowwise_kernel_1 = casted_kernel_1.get_rowwise_tensor() - colwise_ln_out = casted_ln_out.get_colwise_tensor() - colwise_kernel_1 = casted_kernel_1.get_colwise_tensor() + casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, + noop_scaled_tensor=True) - # FC1 GEMM: - # (batch..., hidden_in) x (hidden_in, act_len, hidden_out) = (batch..., act_len, hidden_out) - # FC1 FP8 GEMM: - # (batch..., hidden_in) x (hidden_out, act_len, hidden_in)^T = (batch..., act_len, hidden_out) - use_bias_1 = bias_1 is not None + # NN GEMM + # (batch..., hidden_in) x (hidden_in, hidden_out) dot_1_output = tex.gemm( - rowwise_ln_out, - colwise_kernel_1, + casted_ln_out.get_rowwise_tensor(), + casted_kernel_1.get_colwise_tensor(), + contracting_dims=(x_contracting_dims, k_contracting_dims), bias=bias_1 if not tex.gemm_uses_jax_dot() else None, - fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, - contracting_dims=(x_cdims, k_cdims), - grad=False, + fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False ) if dot_1_input_axes is not None and kernel_1_axes is not None: dot_1_output_axes = ( - *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_cdims), - *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_cdims), + *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims), + *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims), ) dot_1_output = 
with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) @@ -305,34 +293,23 @@ def _layernorm_mlp_fwd_rule( dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) - # Activation (batch..., act_len, hidden_out) -> (batch..., hidden_out) - casted_act_out = tex.act_lu( - dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True - ) - casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) + # (batch..., hidden_in) -> (batch..., hidden) + casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, + noop_scaled_tensor=True) - # FC2 kernel (hidden_out, hidden_in) - casted_kernel_2 = tex.quantize( - kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True - ) + casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) - # Prepare FC2 FPROP operands and layouts - rowwise_act_out = casted_act_out.get_rowwise_tensor() - rowwise_kernel_2 = casted_kernel_2.get_rowwise_tensor() - colwise_act_out = casted_act_out.get_colwise_tensor() - colwise_kernel_2 = casted_kernel_2.get_colwise_tensor() + casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel, + noop_scaled_tensor=True) - # FC2 GEMM: - # (batch..., hidden_out) x (hidden_out, hidden_in) = (batch..., hidden_in) - # FC2 FP8 GEMM: - # (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) + # NN GEMM + # (batch..., hidden_in) x (hidden_out, hidden_in) dot_2_output = tex.gemm( - rowwise_act_out, - colwise_kernel_2, + casted_act_out.get_rowwise_tensor(), + casted_kernel_2.get_colwise_tensor(), + contracting_dims=(x_contracting_dims, k_contracting_dims), bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, - contracting_dims=(x_cdims, k_cdims), - grad=False, ) if use_bias_2 and tex.gemm_uses_jax_dot(): @@ -348,13 +325,15 @@ def _layernorm_mlp_fwd_rule( rsigma, gamma, beta, - colwise_ln_out, - rowwise_kernel_1, + casted_ln_out.get_colwise_tensor(), + casted_kernel_1.get_rowwise_tensor(), dot_1_output, - colwise_act_out, - rowwise_kernel_2, - x_cdims, - k_cdims, + casted_act_out.get_colwise_tensor(), + casted_kernel_2.get_rowwise_tensor(), + x_contracting_dims, + k_contracting_dims, + kernel_1.shape, + kernel_2.shape, use_bias_1, use_bias_2, quantizer_sets, @@ -391,20 +370,22 @@ def _layernorm_mlp_bwd_rule( Returns: Tuple of gradients for all input parameters """ - del ffn1_ckpt_name, ffn2_ckpt_name + del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name ( x, mu, rsigma, gamma, beta, - colwise_ln_out, - rowwise_kernel_1, + colwise_casted_ln_out, + rowwise_casted_kernel_1, dot_1_output, - colwise_act_out, - rowwise_kernel_2, - fwd_x_cdims, - fwd_k_cdims, + colwise_casted_act_out, + rowwise_casted_kernel_2, + x_contracting_dims_in_fwd, + k_contracting_dims_in_fwd, + kernel_1_shape, + kernel_2_shape, use_bias_1, use_bias_2, quantizer_sets, @@ -412,87 +393,84 @@ def _layernorm_mlp_bwd_rule( ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets - # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input - fwd_x_non_cdims = tex.get_non_contracting_dims(colwise_ln_out.ndim, fwd_x_cdims) - flatten_axis_grad = len(fwd_x_non_cdims) + # Since the sharding of outputs should be the same as dot_1's input grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) + casted_grad, dbias_2 = tex.quantize_dbias( - grad, - is_dbias=use_bias_2, - 
flatten_axis=flatten_axis_grad, - quantizer=ffn2_quantizer_set.dgrad, - noop_scaled_tensor=True, + grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, + noop_scaled_tensor=True ) - # Prepare FC2 DGRAD and WGRAD operands and contracting dims - rowwise_g = casted_grad.get_rowwise_tensor() - rowwise_g_cdims = tuple(range(flatten_axis_grad, grad.ndim)) - fwd_k2_non_cdims = tex.get_non_contracting_dims(rowwise_kernel_2.ndim, fwd_k_cdims) - - colwise_g = casted_grad.get_colwise_tensor() - colwise_g_cdims = tex.get_non_contracting_dims(grad.ndim, rowwise_g_cdims) - colwise_act_out_cdims = tex.get_non_contracting_dims(colwise_act_out.ndim, fwd_x_cdims) + # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim + g_contracting_dims_2 = tuple( + range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim) + ) + # k_non_contracting_dims + k_contracting_dims_2 = tuple( + dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd + ) - # FC2 DGRAD GEMM: (batch..., hidden_in) x (hidden_out, hidden_in)^T = (batch..., hidden_out) + # NT GEMM + # (batch..., hidden_out) x (hidden_in, hidden_out) dgrad_2 = tex.gemm( - rowwise_g, rowwise_kernel_2, contracting_dims=(rowwise_g_cdims, fwd_k2_non_cdims), grad=True + casted_grad.get_rowwise_tensor(), + rowwise_casted_kernel_2, + contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), ) + dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) - # FC2 WGRAD GEMM: - # (batch..., hidden_out)^T x (batch..., hidden_in) = (hidden_out, hidden_in) - # FC2 WGRAD FP8 GEMM: - # (hidden_out, batch...) x (hidden_in, batch...)^T = (hidden_out, hidden_in) + x_contracting_dims = g_contracting_dims = tuple( + range(0, len(x.shape) - len(x_contracting_dims_in_fwd)) + ) + + # TN GEMM + # (hidden, batch...,) x (hidden, batch...) 
wgrad_2 = tex.gemm( - colwise_act_out, - colwise_g, - contracting_dims=(colwise_act_out_cdims, colwise_g_cdims), - grad=True, + colwise_casted_act_out, + casted_grad.get_colwise_tensor(), + contracting_dims=(x_contracting_dims, g_contracting_dims), ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) - # Activation gradient w/ bias fusion (batch..., hidden_out) -> (batch.., act_len, hidden_out) casted_dact_out, dbias_1 = tex.quantize_dact_dbias( dgrad_2, dot_1_output, activation_type=activation_type, is_dbias=use_bias_1, - quantizer=ffn1_quantizer_set.dgrad, + quantizer=ffn2_quantizer_set.dgrad, noop_scaled_tensor=True, ) - # Prepare FC1 DGRAD and WGRAD operands and contracting dims - rowwise_dact_out = casted_dact_out.get_rowwise_tensor() - rowwise_dact_out_cdims = tuple(range(flatten_axis_grad, rowwise_dact_out.ndim)) - colwise_dact_out = casted_dact_out.get_colwise_tensor() - colwise_dact_out_cdims = tex.get_non_contracting_dims( - casted_dact_out.ndim, rowwise_dact_out_cdims + # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim + dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim + g_contracting_dims_1 = tuple( + range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim) + ) + # k_non_contracting_dims + k_contracting_dims_1 = tuple( + dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd ) - fwd_k1_non_cdims = tex.get_non_contracting_dims(rowwise_kernel_1.ndim, fwd_k_cdims) - # FC1 DGRAD GEMM: - # (batch..., act_len, hidden_out) x (hidden_in, act_len, hidden_out)^T = (batch..., hidden_in) + # NT GEMM dgrad_1 = tex.gemm( - rowwise_dact_out, - rowwise_kernel_1, - contracting_dims=(rowwise_dact_out_cdims, fwd_k1_non_cdims), - grad=True, + casted_dact_out.get_rowwise_tensor(), + rowwise_casted_kernel_1, + contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), ) + dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) - # FC1 WGRAD GEMM: - # (batch..., hidden_in)^T x (batch..., act_len, hidden_out) = (hidden_in, act_len, hidden_out) - # FC1 WGRAD FP8 GEMM: - # (hidden_in, batch...) x (hidden_out, act_len, batch...)^T = (hidden_in, act_len, hidden_out) + # TN GEMM + # (hidden, batch...) x (hidden, batch...) 
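The FC1 gradient GEMMs around this point follow the same index recipe as the dense backward rule, except the kernel is 3D and carries the activation axis. A short sketch (illustrative shapes, checked with jax.lax.dot_general) showing that the derived contracting dims still recover the expected dgrad/wgrad shapes:

import jax.numpy as jnp
from jax import lax

B, S, H_in, A, H_out = 2, 4, 8, 2, 16
ln_out   = jnp.zeros((B, S, H_in))       # FC1 input, contracted on its last dim in the forward
kernel_1 = jnp.zeros((H_in, A, H_out))   # FC1 kernel, contracted on dim 0 in the forward
dact_out = jnp.zeros((B, S, A, H_out))   # gradient coming out of the activation backward

fwd_x_cdims, fwd_k_cdims = (2,), (0,)

# FC1 DGRAD: contract dact_out's (act, hidden_out) dims with the kernel's non-contracting dims.
g_cdims_1 = tuple(range(dact_out.ndim - kernel_1.ndim + len(fwd_k_cdims), dact_out.ndim))  # (2, 3)
k_cdims_1 = tuple(d for d in range(kernel_1.ndim) if d not in fwd_k_cdims)                 # (1, 2)
dgrad_1 = lax.dot_general(dact_out, kernel_1, ((g_cdims_1, k_cdims_1), ((), ())))
assert dgrad_1.shape == ln_out.shape     # (B, S, H_in)

# FC1 WGRAD: contract the shared batch dims of ln_out and dact_out.
x_cdims = g_cdims = tuple(range(ln_out.ndim - len(fwd_x_cdims)))                           # (0, 1)
wgrad_1 = lax.dot_general(ln_out, dact_out, ((x_cdims, g_cdims), ((), ())))
assert wgrad_1.shape == kernel_1.shape   # (H_in, A, H_out)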
wgrad_1 = tex.gemm( - colwise_ln_out, - colwise_dact_out, - contracting_dims=(fwd_x_non_cdims, colwise_dact_out_cdims), - grad=True, + colwise_casted_ln_out, + casted_dact_out.get_colwise_tensor(), + contracting_dims=(x_contracting_dims, g_contracting_dims), ) + wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) - # Layernorm gradient dx, dgamma, dbeta = tex.normalization_bwd( dgrad_1, x, @@ -504,7 +482,6 @@ def _layernorm_mlp_bwd_rule( epsilon=epsilon, norm_type=norm_type, ) - dx = with_sharding_constraint_by_logical_axes(dx, norm_input_axes) return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 1c5d77c05b..98dc38de24 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -142,7 +142,7 @@ def __post_init__(self): ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" if self.scaling_mode == ScalingMode.NO_SCALING: - self.scale_inv = jnp.empty((1,), dtype=jnp.float32) + self.scale_inv = jnp.empty((0,), dtype=jnp.float32) else: expected_scale_shape = self.scaling_mode.get_scale_shape( From b80e2843ad1750821295fa3d16a998f6ba18f4e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 14 Jun 2025 06:21:52 +0000 Subject: [PATCH 06/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 8 +++++--- transformer_engine/jax/cpp_extensions/gemm.py | 15 ++++---------- transformer_engine/jax/dense.py | 16 ++++++++++----- transformer_engine/jax/layernorm_dense.py | 14 ++++++++----- transformer_engine/jax/layernorm_mlp.py | 20 ++++++++++--------- 5 files changed, 40 insertions(+), 33 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index ccb667b8f2..cd3e09b64c 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -859,6 +859,7 @@ def test_quantize_dact_dbias_mxfp8_scaling( def _use_jax_fp8_gemm(enabled=False): import os + if enabled: os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: @@ -916,7 +917,7 @@ def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, wi scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, bwd_dtype=jnp.float8_e5m2, - is_2x2x=False + is_2x2x=False, ) primitive_out = tex.gemm( x, @@ -984,9 +985,10 @@ def ref_func(x, w, bias, data_layout): value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, + scaling_mode=scaling_mode, + fwd_dtype=jnp.float8_e4m3fn, bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, - is_2x2x=True + is_2x2x=True, ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 2475264c7c..124703eaec 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -114,9 +114,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0]) 
lhs_is_rowwise = lhs.ndim - 1 in lhs_cdims - flatten_axis = ( - min(lhs_cdims) if lhs_is_rowwise else max(lhs_cdims) + 1 - ) + flatten_axis = min(lhs_cdims) if lhs_is_rowwise else max(lhs_cdims) + 1 lhs_q = lhs_quantizer.quantize( lhs, is_rowwise=lhs_is_rowwise, @@ -129,9 +127,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) rhs_is_rowwise = rhs.ndim - 1 in rhs_cdims - flatten_axis = ( - min(rhs_cdims) if rhs_is_rowwise else max(rhs_cdims) + 1 - ) + flatten_axis = min(rhs_cdims) if rhs_is_rowwise else max(rhs_cdims) + 1 rhs_q = rhs_quantizer.quantize( rhs, is_rowwise=rhs_is_rowwise, @@ -547,9 +543,7 @@ def _te_gemm( rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) scaling_mode = ScalingMode.NO_SCALING lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) - lhs_cdims, rhs_cdims = map( - sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims - ) + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) # Quantize operands (if necessary) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) @@ -892,8 +886,7 @@ def _jax_gemm_fp8_impl(lhs, rhs): raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") - lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, - contracting_dims) + lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor): return _jax_gemm_fp8_impl(lhs_q, rhs_q) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 2aeb46594b..8c5da04371 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -95,12 +95,15 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) - casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, - noop_scaled_tensor=True) + casted_x = tex.quantize( + x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True + ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) casted_kernel = tex.quantize( - kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel, + kernel, + flatten_axis=flatten_axis_k, + quantizer=quantizer_set.kernel, noop_scaled_tensor=True, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -152,7 +155,10 @@ def _dense_bwd_rule( ) = ctx casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, + grad, + is_dbias=use_bias, + flatten_axis=flatten_axis_k, + quantizer=quantizer_set.dgrad, noop_scaled_tensor=True, ) @@ -181,7 +187,7 @@ def _dense_bwd_rule( wgrad = tex.gemm( colwise_casted_x, casted_grad.get_colwise_tensor(), - contracting_dims=(x_contracting_dim, g_contracting_dim) + contracting_dims=(x_contracting_dim, g_contracting_dim), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 96e66ff946..6be7954862 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -187,14 +187,15 @@ def 
_layernorm_dense_fwd_rule( epsilon, norm_type, quantizer=quantizer_set.x, - noop_scaled_tensor=True + noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) # Kernel in (hidden_in, hidden_out...) flatten_axis = 1 - len(kernel.shape) - casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, - noop_scaled_tensor=True) + casted_kernel = tex.quantize( + kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True + ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) # NN GEMM @@ -271,8 +272,11 @@ def _layernorm_dense_bwd_rule( ) = ctx casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad, - noop_scaled_tensor=True + grad, + is_dbias=use_bias, + flatten_axis=flatten_axis, + quantizer=quantizer_set.dgrad, + noop_scaled_tensor=True, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 8deb83d87a..a9aef69ab0 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -266,8 +266,9 @@ def _layernorm_mlp_fwd_rule( ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) - casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, - noop_scaled_tensor=True) + casted_kernel_1 = tex.quantize( + kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True + ) # NN GEMM # (batch..., hidden_in) x (hidden_in, hidden_out) @@ -276,7 +277,7 @@ def _layernorm_mlp_fwd_rule( casted_kernel_1.get_colwise_tensor(), contracting_dims=(x_contracting_dims, k_contracting_dims), bias=bias_1 if not tex.gemm_uses_jax_dot() else None, - fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False + fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, ) if dot_1_input_axes is not None and kernel_1_axes is not None: @@ -294,13 +295,15 @@ def _layernorm_mlp_fwd_rule( dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) # (batch..., hidden_in) -> (batch..., hidden) - casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, - noop_scaled_tensor=True) + casted_act_out = tex.act_lu( + dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True + ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) - casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel, - noop_scaled_tensor=True) + casted_kernel_2 = tex.quantize( + kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True + ) # NN GEMM # (batch..., hidden_in) x (hidden_out, hidden_in) @@ -397,8 +400,7 @@ def _layernorm_mlp_bwd_rule( grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) casted_grad, dbias_2 = tex.quantize_dbias( - grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, - noop_scaled_tensor=True + grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim From a7aa2f4f1a112a17a37ae356ccbbc3aa1c754fb8 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 17 Jun 2025 09:16:49 +0000 Subject: [PATCH 07/30] MXFP8 issue traced to scale factor padding with NaNs instead 
of zeros Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 9 +++- transformer_engine/jax/cpp_extensions/gemm.py | 46 ++++++++++++++++--- .../jax/csrc/extensions/gemm.cpp | 9 ++-- transformer_engine/jax/quantize/tensor.py | 27 ++++------- 4 files changed, 60 insertions(+), 31 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index cd3e09b64c..9ea44bb1f7 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -904,12 +904,19 @@ def test_gemm_bf16(self, m, n, k, data_layout): assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest_parametrize_wrapper("m,n,k", [(256, 512, 1024)]) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm): + if ( + not with_jax_gemm + and scaling_mode.is_1d_block_scaling() + and jnp.float8_e5m2 in (x_qtype, w_qtype) + ): + pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.") + _use_jax_fp8_gemm(enabled=with_jax_gemm) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 124703eaec..c54314e283 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -7,7 +7,7 @@ import operator from collections.abc import Iterable from typing import Tuple, Sequence, Union -from functools import partial, reduce, lru_cache +from functools import partial, reduce import jax import jax.numpy as jnp @@ -18,13 +18,14 @@ from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams from .base import BasePrimitive, register_primitive -from .quantization import quantize, grouped_quantize +from .quantization import grouped_quantize from ..quantize import ( ScaledTensor, ScaledTensor2x, GroupedScaledTensor1x, ScalingMode, Quantizer, + BlockScaleQuantizer, GroupedQuantizer, QuantizeConfig, QuantizerSet, @@ -123,6 +124,10 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ ) if isinstance(lhs_q, ScaledTensor2x): lhs_q = lhs_q.get_rowwise_tensor() if lhs_is_rowwise else lhs_q.get_colwise_tensor() + if jnp.any(jnp.isnan(lhs_q.data)): + print("Found NaNs in quantized LHS data.") + if jnp.any(jnp.isnan(lhs_q.scale_inv)): + print("Found NaNs in quantized LHS scale_inv.") if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) @@ -136,6 +141,10 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ ) if isinstance(rhs_q, ScaledTensor2x): rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_rowwise else rhs_q.get_colwise_tensor() + if jnp.any(jnp.isnan(rhs_q.data)): + print("Found NaNs in quantized RHS data.") + if jnp.any(jnp.isnan(rhs_q.scale_inv)): + print("Found NaNs in quantized RHS scale_inv.") return lhs_q, rhs_q @@ -165,7 +174,10 @@ def abstract( fuse_bias, fuse_gelu, grad, + use_split_accumulator, ): + del use_split_accumulator + # Sanity-check operand layouts and types operand_ndims = 
(lhs.ndim, rhs.ndim) ( @@ -288,6 +300,7 @@ def lowering( fuse_bias, fuse_gelu, grad, + use_split_accumulator, ): del out_dtype lhs_aval, _, rhs_aval, *_ = ctx.avals_in @@ -306,6 +319,7 @@ def lowering( "fuse_bias": fuse_bias, "fuse_gelu": fuse_gelu, "grad": grad, + "use_split_accumulator": use_split_accumulator, } operand_output_aliases = {} @@ -333,6 +347,7 @@ def impl( fuse_bias, fuse_gelu, grad, + use_split_accumulator, ): outputs = GemmPrimitive.inner_primitive.bind( lhs, @@ -347,6 +362,7 @@ def impl( fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, grad=grad, + use_split_accumulator=use_split_accumulator, ) return outputs[:-3] # discard workspace arrays @@ -360,6 +376,7 @@ def batcher( fuse_bias, fuse_gelu, grad, + use_split_accumulator, ): assert GemmPrimitive.outer_primitive is not None lhs, _, rhs, *_ = batched_args @@ -381,6 +398,7 @@ def batcher( fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, grad=grad, + use_split_accumulator=use_split_accumulator, ), (out_bdims, bias_bdims, pre_gelu_bdims), ) @@ -393,11 +411,12 @@ def infer_sharding_from_operands( fuse_bias, fuse_gelu, grad, + use_split_accumulator, mesh, arg_infos, result_infos, ): - del out_dtype, scaling_mode, result_infos + del out_dtype, scaling_mode, use_split_accumulator, result_infos # Check contracting dimensions lhs_spec, _, rhs_spec, *_ = map(get_padded_spec, arg_infos) @@ -467,6 +486,7 @@ def partition( fuse_bias, fuse_gelu, grad, + use_split_accumulator, mesh, arg_infos, result_infos, @@ -478,6 +498,7 @@ def partition( fuse_bias, fuse_gelu, grad, + use_split_accumulator, mesh, arg_infos, result_infos, @@ -535,6 +556,7 @@ def _te_gemm( fuse_bias: bool = False, fuse_gelu: bool = False, grad: bool = False, + use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, ) -> Tuple[jax.Array, ...]: # Prepare non-quantized GEMM operands lhs_data = lhs @@ -600,6 +622,7 @@ def _te_gemm( fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, grad=grad, + use_split_accumulator=use_split_accumulator, ) @@ -889,6 +912,12 @@ def _jax_gemm_fp8_impl(lhs, rhs): lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor): + if isinstance(lhs_q, ScaledTensor2x): + lhs_is_transposed = lhs.ndim - 1 not in sanitize_dims(lhs.ndim, contracting_dims[0]) + lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() + if isinstance(rhs_q, ScaledTensor2x): + rhs_is_transposed = rhs.ndim - 1 in sanitize_dims(rhs.ndim, contracting_dims[1]) + rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() return _jax_gemm_fp8_impl(lhs_q, rhs_q) if ( @@ -941,6 +970,9 @@ def gemm( grad: bool, default = False Flag for switching bias and GeLU fusions from forward to backward mode. Only supported with TE's custom call to cuBLAS GEMM. + use_split_accumulator: bool, default = True + Enable promoting some intermediate sums to higher precision when accumulating the result in + the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. 
Returns ------- @@ -966,13 +998,15 @@ def gemm( rhs_quantizer = quantizer_set.kernel # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled + fuse_bias = kwargs.get("fuse_bias", False) + fuse_gelu = kwargs.get("fuse_gelu", False) if not GemmPrimitive.enabled(): - assert kwargs.get("bias", None) is None and not kwargs.get("fuse_bias", False), ( + assert kwargs.get("bias", None) is None and not fuse_gelu, ( "TE GEMM was invoked with bias fusion options that are not supported by the " "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." ) - assert kwargs.get("gelu_input", None) is None and not kwargs.get("fuse_gelu", False), ( + assert kwargs.get("gelu_input", None) is None and not fuse_bias, ( "TE GEMM was invoked with GeLU fusion options that are not supported by the " "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." @@ -989,8 +1023,6 @@ def gemm( ) # Discard empty outputs - fuse_bias = kwargs.get("fuse_bias", False) - fuse_gelu = kwargs.get("fuse_gelu", False) grad = kwargs.get("grad", False) clean_outputs = outputs[0] # first output is the final result and is never empty if (fuse_bias and grad) or (fuse_gelu and not grad): diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 987b0e817d..f7a84e41c9 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -100,7 +100,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, - bool fuse_bias, bool fuse_gelu, bool grad) { + bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { // Operands (this includes swizzling MXFP8 scaling factors) // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) @@ -162,8 +162,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), - rhs_transposed, lhs_transposed, grad, workspace_.data(), false, false, - num_math_sm, stream); + rhs_transposed, lhs_transposed, grad, workspace_.data(), false, + use_split_accumulator, num_math_sm, stream); return ffi_with_cuda_error_check(); } @@ -190,7 +190,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("rhs_transposed") .Attr("fuse_bias") .Attr("fuse_gelu") - .Attr("grad"), + .Attr("grad") + .Attr("use_split_accumulator"), FFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 98dc38de24..b4e809106e 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -157,26 +157,15 @@ def __post_init__(self): f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" f" {self.scale_inv.shape}" ) - expected_unpadded_scale_shape = 
self.scaling_mode.get_scale_shape( - self.data.shape, - self.is_colwise, - is_padded=False, - flatten_axis=self.flatten_axis, + pad_width = tuple( + (0, a - b) + for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) + ) + + # This actually pad scale_inv with nan, should we pad it with 127 directly instead? + self.scale_inv = jnp.pad( + self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0 ) - if self.scale_inv.shape != expected_scale_shape: - assert self.scale_inv.shape == expected_unpadded_scale_shape, ( - f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" - f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" - f" {self.scale_inv.shape}" - ) - pad_width = tuple( - (0, a - b) - for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) - ) - # This actually pad scale_inv with nan, should we pad it with 127 directly instead? - self.scale_inv = jnp.pad( - self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0 - ) def tree_flatten(self): """Flattens the tensor for JAX tree operations. From 1be8773ae58e7d71a0c98ec07085c8b1af47f056 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 17 Jun 2025 07:41:22 -0700 Subject: [PATCH 08/30] padding scale with 2^-127 instead of nans Signed-off-by: Phuong Nguyen --- transformer_engine/jax/quantize/tensor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index b4e809106e..19a8658b8c 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -162,11 +162,13 @@ def __post_init__(self): for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) ) - # This actually pad scale_inv with nan, should we pad it with 127 directly instead? + # padding with the smallest number it can present self.scale_inv = jnp.pad( - self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0 + self.scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127 ) + + def tree_flatten(self): """Flattens the tensor for JAX tree operations. From 75008de5893c3a768bab2b4f73786e1c6bbc97d0 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 17 Jun 2025 07:41:44 -0700 Subject: [PATCH 09/30] fix bug on rhs_scale_inv usage Signed-off-by: Phuong Nguyen --- transformer_engine/jax/csrc/extensions/gemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f7a84e41c9..249d229de4 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -111,7 +111,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise); auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand( - stream, rhs, rhs_scale_inv, lhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); + stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); // Output tensor std::vector out_shape = {(lhs_transposed) ? 
lhs_shape[1] : lhs_shape[0], From 5b0c1f57d9c16840c04d2ca523bd63bf66f206ab Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 17 Jun 2025 08:28:09 -0700 Subject: [PATCH 10/30] cleanup E8M0 type converter use it in gemm.cpp Signed-off-by: Phuong Nguyen --- transformer_engine/jax/csrc/extensions/ffi.cpp | 7 +++---- transformer_engine/jax/csrc/extensions/gemm.cpp | 14 ++++++-------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/ffi.cpp b/transformer_engine/jax/csrc/extensions/ffi.cpp index a760df4a79..e77c38e990 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.cpp +++ b/transformer_engine/jax/csrc/extensions/ffi.cpp @@ -38,12 +38,11 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { case xla::ffi::DataType::F8E4M3FN: return DType::kFloat8E4M3; break; - // case xla::ffi::DataType::F8E8M0FNU: - // return DType::kFloat8E8M0; - // break; + case xla::ffi::DataType::F8E8M0FNU: + return DType::kFloat8E8M0; + break; default: auto type_num = static_cast(type); - if (type_num == 33) return DType::kFloat8E8M0; NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d", static_cast(type_num)); break; diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 249d229de4..a3a1899b14 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -45,9 +45,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( auto scale_dims = scale_inv.dimensions(); std::vector scale_shape = {product(scale_dims, 0, axis_boundary), product(scale_dims, axis_boundary, scale_dims.size())}; - auto scale_dtype = (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) - ? DType::kFloat8E8M0 - : convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); if (rowwise) { input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } else { @@ -66,14 +64,14 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( "Inverse scale factors need to have an 8-bit data type."); // Create tensor to hold swizzled scale factor - TensorWrapper output(NVTE_MXFP8_1D_SCALING); + TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); if (rowwise) { output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); } else { output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); } @@ -82,10 +80,10 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( // Set swizzled scales into the input tensor if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); } else { - input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); } } From b49d586325b5af2afd0fdfb5e028675c5dd7520a Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 18 Jun 2025 06:12:09 +0000 Subject: [PATCH 11/30] segfault fixed, passing all unittests on 
Blackwell Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 2 +- tests/jax/test_layer.py | 14 +++--- transformer_engine/jax/cpp_extensions/gemm.py | 43 +++++++++++-------- .../jax/csrc/extensions/gemm.cpp | 10 +++-- 4 files changed, 40 insertions(+), 29 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9ea44bb1f7..3de6453f5e 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1118,7 +1118,7 @@ def ref_func(x, w, gamma, beta): @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) - @pytest.mark.parametrize("use_bias", [True, False]) + @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 389148415a..058dbf5bb4 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -479,6 +479,7 @@ def generate_inputs(self, data_shape, dtype): @pytest.mark.parametrize("data_shape", DATA_SHAPE) @pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("with_jax_gemm", (False, True)) @pytest.mark.parametrize("attrs", ATTRS) class BaseTester: """ @@ -494,20 +495,22 @@ def use_jax_dot_for_gemm(self, enabled=False): elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") - def test_forward(self, data_shape, dtype, attrs): + def test_forward(self, data_shape, dtype, with_jax_gemm, attrs): """Test normal datatype forward""" + self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.finalize() # Ensure FP8 disabled. self.runner(attrs).test_forward(data_shape, dtype) - def test_backward(self, data_shape, dtype, attrs): + + def test_backward(self, data_shape, dtype, with_jax_gemm, attrs): """Test normal datatype backward""" + self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.finalize() # Ensure FP8 disabled. 
self.runner(attrs).test_backward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) - @pytest.mark.parametrize("with_jax_gemm", (False, True)) - def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe, with_jax_gemm): + def test_forward_with_fp8(self, data_shape, dtype, with_jax_gemm, attrs, fp8_recipe): """Test forward with fp8 enabled""" self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.initialize(fp8_recipe=fp8_recipe) @@ -516,8 +519,7 @@ def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe, with_jax_g @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) - @pytest.mark.parametrize("with_jax_gemm", (False, True)) - def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe, with_jax_gemm): + def test_backward_with_fp8(self, data_shape, dtype, with_jax_gemm, attrs, fp8_recipe): """Test backward with fp8 enabled""" self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.initialize(fp8_recipe=fp8_recipe) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index c54314e283..a8cbe75702 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -25,7 +25,6 @@ GroupedScaledTensor1x, ScalingMode, Quantizer, - BlockScaleQuantizer, GroupedQuantizer, QuantizeConfig, QuantizerSet, @@ -114,37 +113,43 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ rhs_q = rhs if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0]) - lhs_is_rowwise = lhs.ndim - 1 in lhs_cdims - flatten_axis = min(lhs_cdims) if lhs_is_rowwise else max(lhs_cdims) + 1 + lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims + need_lhs_colwise = ( + lhs_is_transposed + and ( + lhs_quantizer.scaling_mode.is_1d_block_scaling() + or not tex.is_non_nt_fp8_gemm_supported() + ) + ) + flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims) lhs_q = lhs_quantizer.quantize( lhs, - is_rowwise=lhs_is_rowwise, - is_colwise=not lhs_is_rowwise, + is_rowwise=not need_lhs_colwise, + is_colwise=need_lhs_colwise, flatten_axis=flatten_axis, ) if isinstance(lhs_q, ScaledTensor2x): - lhs_q = lhs_q.get_rowwise_tensor() if lhs_is_rowwise else lhs_q.get_colwise_tensor() - if jnp.any(jnp.isnan(lhs_q.data)): - print("Found NaNs in quantized LHS data.") - if jnp.any(jnp.isnan(lhs_q.scale_inv)): - print("Found NaNs in quantized LHS scale_inv.") + lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) - rhs_is_rowwise = rhs.ndim - 1 in rhs_cdims - flatten_axis = min(rhs_cdims) if rhs_is_rowwise else max(rhs_cdims) + 1 + rhs_is_transposed = rhs.ndim - 1 in rhs_cdims + need_rhs_colwise = ( + not rhs_is_transposed + and ( + rhs_quantizer.scaling_mode.is_1d_block_scaling() + or not tex.is_non_nt_fp8_gemm_supported() + ) + ) + flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1 rhs_q = rhs_quantizer.quantize( rhs, - is_rowwise=rhs_is_rowwise, - is_colwise=not rhs_is_rowwise, + is_rowwise=not need_rhs_colwise, + is_colwise=need_rhs_colwise, flatten_axis=flatten_axis, ) if isinstance(rhs_q, ScaledTensor2x): - rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_rowwise else 
rhs_q.get_colwise_tensor() - if jnp.any(jnp.isnan(rhs_q.data)): - print("Found NaNs in quantized RHS data.") - if jnp.any(jnp.isnan(rhs_q.scale_inv)): - print("Found NaNs in quantized RHS scale_inv.") + rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() return lhs_q, rhs_q diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index a3a1899b14..c88aac3ff6 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -42,9 +42,13 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands."); NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM."); - auto scale_dims = scale_inv.dimensions(); - std::vector scale_shape = {product(scale_dims, 0, axis_boundary), - product(scale_dims, axis_boundary, scale_dims.size())}; + std::vector scale_shape = {1}; + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + // Block scaling also needs to be collapsed to match 2D data + scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary), + product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())}; + } + auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); if (rowwise) { input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); From bd9bca34cedbd9e36e6956c620cfaa1f2b1bdbf1 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 18 Jun 2025 06:07:35 -0700 Subject: [PATCH 12/30] fix for fuseddense tests Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 2 -- transformer_engine/jax/layernorm_dense.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0b7f68d79e..e4408f74c1 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -629,8 +629,6 @@ def _te_gemm( ) -======= ->>>>>>> main class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 3cdbeecd5a..816436a00e 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -293,7 +293,6 @@ def _layernorm_dense_bwd_rule( dgrad = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, - (g_constracting_dim, k_constracting_dim), contracting_dims=(g_constracting_dim, k_constracting_dim), ) @@ -307,7 +306,6 @@ def _layernorm_dense_bwd_rule( wgrad = tex.gemm( casted_ln_out, casted_grad.get_tensor(TensorUsage.RHS), - (x_constracting_dim, g_constracting_dim), contracting_dims=(x_constracting_dim, g_constracting_dim), ) From 8fcb1bb0f34de33c452c308170b29e9743284978 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 18 Jun 2025 08:02:55 -0700 Subject: [PATCH 13/30] fix workspace alignment Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 9 ++++---- .../jax/csrc/extensions/gemm.cpp | 21 +++++++++++-------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index e4408f74c1..d5946e1718 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -32,7 +32,7 @@ noop_quantizer_set, 
is_fp8_gemm_with_all_layouts_supported, ) -from ..sharding import get_padded_spec +from .misc import get_padded_spec __all__ = [ @@ -277,9 +277,10 @@ def abstract( rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype) # Declare cuBLAS workspace - workspace = jax.core.ShapedArray( - shape=(get_cublas_workspace_size_bytes(),), dtype=jnp.uint8 - ) + # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not + # necessarily 256 bytes aligned, we add some padding to ensure alignment. + workspace_size = get_cublas_workspace_size_bytes() + 256 + workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 0fb8734998..eed2ed2a4f 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -21,6 +21,13 @@ namespace transformer_engine { namespace jax { + +static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { + // Move the pointer to the next 256B aligned address + return reinterpret_cast((reinterpret_cast(ptr) + 255) & + ~static_cast(255)); +} + std::tuple> xla_buffer_to_nvte_gemm_operand( cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv, JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) { @@ -157,9 +164,11 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i } auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype); - // cuBLAS workspace - std::vector workspace_shape = {static_cast(workspace->element_count())}; - auto workspace_ = TensorWrapper(workspace->untyped_data(), workspace_shape, DType::kByte); + // cuBLAS workspace + 256 alignment enforcement + auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); + workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); + std::vector workspace_shape = {static_cast(workspace->element_count()) - 256}; + auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte); // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); @@ -197,12 +206,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, FFI_CudaGraph_Traits); -static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { - // Move the pointer to the next 256B aligned address - return reinterpret_cast((reinterpret_cast(ptr) + 255) & - ~static_cast(255)); -} - Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, From b2b41592b444ae3ac0b67b44ce7146ac6e25fc74 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Jun 2025 15:05:07 +0000 Subject: [PATCH 14/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_layer.py | 1 - transformer_engine/jax/cpp_extensions/gemm.py | 32 ++++++++----------- .../jax/csrc/extensions/gemm.cpp | 8 ++--- transformer_engine/jax/quantize/tensor.py | 5 +-- 4 files changed, 16 insertions(+), 30 deletions(-) diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 058dbf5bb4..9f3da5094f 100644 --- 
a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -501,7 +501,6 @@ def test_forward(self, data_shape, dtype, with_jax_gemm, attrs): QuantizeConfig.finalize() # Ensure FP8 disabled. self.runner(attrs).test_forward(data_shape, dtype) - def test_backward(self, data_shape, dtype, with_jax_gemm, attrs): """Test normal datatype backward""" self.use_jax_dot_for_gemm(enabled=with_jax_gemm) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d5946e1718..9cd3c85619 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -36,13 +36,13 @@ __all__ = [ - "gemm", - "grouped_gemm", - "gemm_uses_jax_dot", - "sanitize_dims", - "get_non_contracting_dims", - "transpose_contracting_dims", - ] + "gemm", + "grouped_gemm", + "gemm_uses_jax_dot", + "sanitize_dims", + "get_non_contracting_dims", + "transpose_contracting_dims", +] num_cublas_streams = get_num_compute_streams() @@ -109,12 +109,9 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0]) lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims - need_lhs_colwise = ( - lhs_is_transposed - and ( - lhs_quantizer.scaling_mode.is_1d_block_scaling() - or not is_fp8_gemm_with_all_layouts_supported() - ) + need_lhs_colwise = lhs_is_transposed and ( + lhs_quantizer.scaling_mode.is_1d_block_scaling() + or not is_fp8_gemm_with_all_layouts_supported() ) flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims) lhs_q = lhs_quantizer.quantize( @@ -130,12 +127,9 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) rhs_is_transposed = rhs.ndim - 1 in rhs_cdims - need_rhs_colwise = ( - not rhs_is_transposed - and ( - rhs_quantizer.scaling_mode.is_1d_block_scaling() - or not is_fp8_gemm_with_all_layouts_supported() - ) + need_rhs_colwise = not rhs_is_transposed and ( + rhs_quantizer.scaling_mode.is_1d_block_scaling() + or not is_fp8_gemm_with_all_layouts_supported() ) flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1 rhs_q = rhs_quantizer.quantize( diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index eed2ed2a4f..2c5f027ba1 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -21,7 +21,6 @@ namespace transformer_engine { namespace jax { - static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { // Move the pointer to the next 256B aligned address return reinterpret_cast((reinterpret_cast(ptr) + 255) & @@ -78,8 +77,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); if (rowwise) { output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, - scale_shape); + output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); } else { output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, @@ -91,8 +89,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( // Set swizzled scales into the input tensor if (rowwise) { 
- input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, - scale_shape); + input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); } else { input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); @@ -205,7 +202,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("use_split_accumulator"), FFI_CudaGraph_Traits); - Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 04c9f2774f..a87326e9fe 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -149,8 +149,7 @@ def __post_init__(self): f" {self.scale_inv.shape}" ) pad_width = tuple( - (0, a - b) - for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) + (0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) ) # padding with the smallest number it can present @@ -158,8 +157,6 @@ def __post_init__(self): self.scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127 ) - - def tree_flatten(self): """Flattens the tensor for JAX tree operations. From ae4828c8df4904423da3bbf0edbc9475cba6412d Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 18 Jun 2025 22:51:20 +0000 Subject: [PATCH 15/30] fixed GemmPrimitive custom partitioning to match jax.nn.scaled_matmul Signed-off-by: Alp Dener all unit tests passing on H100x8 node Signed-off-by: Alp Dener [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci linting fixes Signed-off-by: Alp Dener fixed batch dimension numbers Signed-off-by: Alp Dener fixed FP8 scale sharding rule when there are no FP8 scales Signed-off-by: Alp Dener added error message for unsupported Shardy partitioner Signed-off-by: Alp Dener fixed test tolerances for FP8 cases Signed-off-by: Alp Dener fixed shardy test skip cases Signed-off-by: Alp Dener --- .../encoder/test_model_parallel_encoder.py | 23 +- examples/jax/encoder/test_multigpu_encoder.py | 23 +- .../encoder/test_multiprocessing_encoder.py | 24 +- tests/jax/test_custom_call_compute.py | 4 +- tests/jax/test_distributed_layernorm_mlp.py | 116 ++++- tests/jax/test_layer.py | 20 +- transformer_engine/jax/cpp_extensions/gemm.py | 458 ++++++++++++------ transformer_engine/jax/dense.py | 57 ++- transformer_engine/jax/layernorm_dense.py | 21 +- transformer_engine/jax/layernorm_mlp.py | 30 +- 10 files changed, 567 insertions(+), 209 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index b2bd18205f..d733fa7097 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -25,6 +25,7 @@ assert_params_sufficiently_sharded, ) import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_fp8_available, ScalingMode @@ -307,7 +308,9 @@ def train_and_evaluate(args): key: params_sharding[PARAMS_KEY] if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + jit_encoder_init = jax.jit(encoder.init, + in_shardings=in_shardings, + 
out_shardings=out_shardings) var_collect = jit_encoder_init(init_rngs, inputs, masks) # Check if params are sufficiently sharded after initialization @@ -344,11 +347,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, in_shardings, out_shardings) + jit_train_step = jax.jit(train_step, + in_shardings=in_shardings, + out_shardings=out_shardings) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit(eval_step, + in_shardings=in_shardings, + out_shardings=out_shardings) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -459,8 +466,8 @@ class TestEncoder(unittest.TestCase): is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) def setUp(self): - """Run 3 epochs for testing""" - self.args = encoder_parser(["--epochs", "3"]) + """Run 5 epochs for testing""" + self.args = encoder_parser(["--epochs", "5"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): @@ -510,6 +517,8 @@ def test_te_mxfp8_with_sp(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True @@ -517,6 +526,8 @@ def test_te_bf16_shardy(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_fp8_supported, fp8_reason) + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" self.args.enable_shardy = True @@ -526,6 +537,8 @@ def test_te_delayed_scaling_fp8_shardy(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_fp8_supported, fp8_reason) + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_delayed_scaling_fp8_with_sp_shardy(self): """Test Transformer Engine with DelayedScaling FP8 + SP""" self.args.enable_shardy = True diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index b6f4db1084..8c9751c620 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -21,6 +21,7 @@ from common import is_bf16_supported, get_fp8_recipe_from_name_string import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_fp8_available, ScalingMode @@ -288,7 +289,9 @@ def train_and_evaluate(args): out_shardings = { key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + jit_encoder_init = jax.jit(encoder.init, + in_shardings=in_shardings, + out_shardings=out_shardings) var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) @@ -312,11 +315,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, in_shardings, out_shardings) + 
jit_train_step = jax.jit(train_step, + in_shardings=in_shardings, + out_shardings=out_shardings) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit(eval_step, + in_shardings=in_shardings, + out_shardings=out_shardings) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -424,8 +431,8 @@ class TestEncoder(unittest.TestCase): is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) def setUp(self): - """Run 3 epochs for testing""" - self.args = encoder_parser(["--epochs", "3"]) + """Run 5 epochs for testing""" + self.args = encoder_parser(["--epochs", "5"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): @@ -458,6 +465,8 @@ def test_te_mxfp8(self): assert actual[0] < 0.535 and actual[1] > 0.73 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True @@ -465,6 +474,8 @@ def test_te_bf16_shardy(self): assert actual[0] < 0.535 and actual[1] > 0.73 @unittest.skipIf(not is_fp8_supported, fp8_reason) + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" self.args.enable_shardy = True @@ -476,6 +487,8 @@ def test_te_delayed_scaling_fp8_shardy(self): # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. 
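The encoder examples above switch jax.jit from positional to keyword sharding arguments. A minimal single-device sketch of that form (mesh and axis names are illustrative):

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(jax.devices()[:1], ("dp",))
    replicated = NamedSharding(mesh, PartitionSpec())

    jit_add = jax.jit(
        lambda a, b: a + b,
        in_shardings=(replicated, replicated),
        out_shardings=replicated,
    )
    print(jit_add(jnp.ones(4), jnp.ones(4)))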
@unittest.skipIf(not is_fp8_supported, fp8_reason) + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" self.args.enable_shardy = True diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index c7606c3ab0..35a0e766a4 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -28,8 +28,8 @@ get_fp8_recipe_from_name_string, ) import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -412,7 +412,9 @@ def train_and_evaluate(args): out_shardings = { key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + jit_encoder_init = jax.jit(encoder.init, + in_shardings=in_shardings, + out_shardings=out_shardings) var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) @@ -432,11 +434,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, in_shardings, out_shardings) + jit_train_step = jax.jit(train_step, + in_shardings=in_shardings, + out_shardings=out_shardings) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit(eval_step, + in_shardings=in_shardings, + out_shardings=out_shardings) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -578,8 +584,8 @@ class TestEncoder(unittest.TestCase): """Encoder unittests""" def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): - """Run 3 epochs for testing""" - args = encoder_parser([]) + """Run 5 epochs for testing""" + args = encoder_parser(["--epochs", "5"]) num_gpu = self.num_process tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1 @@ -628,6 +634,8 @@ def test_te_mxfp8(self): assert result[0] < 0.505 and result[1] > 0.754 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" result = self.exec(False, None, enable_shardy=True) @@ -636,6 +644,8 @@ def test_te_bf16_shardy(self): @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" ) + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling", enable_shardy=True) @@ -646,6 +656,8 @@ def test_te_delayed_scaling_fp8_shardy(self): @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" ) + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling", 
enable_shardy=True) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 4876aef99a..a50d5363ae 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -898,7 +898,7 @@ def _generate_gemm_input(self, m, n, k, data_layout): def test_gemm_bf16(self, m, n, k, data_layout): x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) - primitive_out = tex.gemm(x, w, contracting_dims) + primitive_out = tex.gemm(x, w, dimension_numbers=(contracting_dims, ((), ()))) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) @@ -929,7 +929,7 @@ def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, wi primitive_out = tex.gemm( x, w, - contracting_dims=contracting_dims, + dimension_numbers=(contracting_dims, ((), ())), lhs_quantizer=quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad, rhs_quantizer=( quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index f16c84094d..3c27f0d472 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -44,6 +44,7 @@ SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling")) if is_mxfp8_supported: SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) +SUPPORTED_RECIPES_WITH_SHARDY = SUPPORTED_RECIPES[:-1] if is_mxfp8_supported else SUPPORTED_RECIPES DTYPES = [jnp.bfloat16, jnp.float16] INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in] @@ -74,6 +75,15 @@ def generate_fsdp_and_tp_configs(): return configs +def use_jax_fp8_gemm(enabled=False): + import os + + if enabled: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") + + class TestDistributedLayernormMLP: def generate_inputs(self, input_shape, activation_type, use_bias, dtype): @@ -146,8 +156,17 @@ def layernorm_fp8_mlp_prim_func( ) def _test_layernorm_mlp_grad( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + use_shardy, + with_jax_gemm, ): + use_jax_fp8_gemm(enabled=with_jax_gemm) jax.config.update("jax_use_shardy_partitioner", use_shardy) device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config layernorm_type = "rmsnorm" @@ -208,20 +227,25 @@ def _test_layernorm_mlp_grad( multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) - assert_allclose(multi_fwd, single_fwd, dtype=dtype) + fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn + bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2 + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) for i in range(len(inputs)): if multi_grads[i] is not None: if isinstance(multi_grads[i], list): assert isinstance(single_grads[i], list) for m_grad, s_grad in zip(multi_grads[i], single_grads[i]): assert_allclose( - m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close" + m_grad, + s_grad, + dtype=bwd_test_type, + err_msg=f"multi_grads[{i}] is not close" ) else: assert_allclose( multi_grads[i], single_grads[i], - dtype=dtype, + dtype=bwd_test_type, err_msg=f"multi_grads[{i}] is not close", ) @@ -232,8 
+256,16 @@ def _test_layernorm_mlp_grad( @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + with_jax_gemm, ): self._test_layernorm_mlp_grad( mesh_config, @@ -243,6 +275,7 @@ def test_layernorm_mlp_grad( dtype, fp8_recipe, use_shardy=False, + with_jax_gemm=with_jax_gemm, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -251,19 +284,22 @@ def test_layernorm_mlp_grad( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES_WITH_SHARDY) def test_layernorm_mlp_grad_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe ): - # We don't test block scaling with Shardy because at the time of writing, - # it is not supported in JAX's scaled_matmul_stablehlo. + # TE cuBLAS GEMM custom op does not implement shardy rules so we test shardy only with + # native JAX FP8 dot_general. We don't test block scaling with Shardy because at the + # time of writing, it is not supported in JAX's scaled_matmul_stablehlo. self._test_layernorm_mlp_grad( mesh_config, activation_type, use_bias, input_shape, dtype, - fp8_recipe=recipe.DelayedScaling(), + fp8_recipe=fp8_recipe, use_shardy=True, + with_jax_gemm=True, ) def _test_layernorm_mlp( @@ -276,7 +312,9 @@ def _test_layernorm_mlp( use_fp8, fp8_recipe, use_shardy, + with_jax_gemm, ): + use_jax_fp8_gemm(enabled=with_jax_gemm) jax.config.update("jax_use_shardy_partitioner", use_shardy) batch, seqlen, hidden_in = input_shape layernorm_type = "rmsnorm" @@ -340,9 +378,9 @@ def _test_layernorm_mlp( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("use_shardy", [False, True]) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer( - self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy + self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm ): self._test_layernorm_mlp( mesh_config, @@ -352,7 +390,8 @@ def test_layernorm_mlp_layer( dtype, use_fp8=False, fp8_recipe=None, - use_shardy=use_shardy, + use_shardy=False, + with_jax_gemm=with_jax_gemm, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -361,9 +400,10 @@ def test_layernorm_mlp_layer( @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES[:-1] if is_mxfp8_supported else SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm ): self._test_layernorm_mlp( 
mesh_config, @@ -374,4 +414,52 @@ def test_layernorm_mlp_layer_fp8( use_fp8=True, fp8_recipe=fp8_recipe, use_shardy=False, + with_jax_gemm=with_jax_gemm, + ) + + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("use_bias", [True, False]) + def test_layernorm_mlp_layer_shardy( + self, mesh_config, activation_type, use_bias, input_shape, dtype + ): + # TE cuBLAS GEMM custom op does not implement shardy rules so we test shardy only with + # native JAX dot_general. + self._test_layernorm_mlp( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8=False, + fp8_recipe=None, + use_shardy=True, + with_jax_gemm=True, + ) + + @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) + @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES_WITH_SHARDY) + def test_layernorm_mlp_layer_fp8_shardy( + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + ): + # TE cuBLAS GEMM custom op does not implement shardy rules so we test shardy only with + # native JAX FP8 dot_general. We don't test block scaling with Shardy because at the + # time of writing, it is not supported in JAX's scaled_matmul_stablehlo. + self._test_layernorm_mlp( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8=True, + fp8_recipe=fp8_recipe, + use_shardy=True, + with_jax_gemm=True, ) diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 9f3da5094f..d59e130530 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -479,7 +479,6 @@ def generate_inputs(self, data_shape, dtype): @pytest.mark.parametrize("data_shape", DATA_SHAPE) @pytest.mark.parametrize("dtype", DTYPE) -@pytest.mark.parametrize("with_jax_gemm", (False, True)) @pytest.mark.parametrize("attrs", ATTRS) class BaseTester: """ @@ -488,39 +487,28 @@ class BaseTester: runner = BaseRunner - def use_jax_dot_for_gemm(self, enabled=False): - """Enable/disable TE custom cuBLAS GEMM primitive.""" - if enabled: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" - elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: - os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") - - def test_forward(self, data_shape, dtype, with_jax_gemm, attrs): + def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" - self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.finalize() # Ensure FP8 disabled. self.runner(attrs).test_forward(data_shape, dtype) - def test_backward(self, data_shape, dtype, with_jax_gemm, attrs): + def test_backward(self, data_shape, dtype, attrs): """Test normal datatype backward""" - self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.finalize() # Ensure FP8 disabled. 
self.runner(attrs).test_backward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) - def test_forward_with_fp8(self, data_shape, dtype, with_jax_gemm, attrs, fp8_recipe): + def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test forward with fp8 enabled""" - self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.initialize(fp8_recipe=fp8_recipe) self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) QuantizeConfig.finalize() @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) - def test_backward_with_fp8(self, data_shape, dtype, with_jax_gemm, attrs, fp8_recipe): + def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test backward with fp8 enabled""" - self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.initialize(fp8_recipe=fp8_recipe) self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) QuantizeConfig.finalize() diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 9cd3c85619..807ca62713 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -15,7 +15,7 @@ from jax.sharding import NamedSharding, PartitionSpec import transformer_engine_jax as tex -from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams +from transformer_engine_jax import get_num_compute_streams from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize @@ -41,7 +41,7 @@ "gemm_uses_jax_dot", "sanitize_dims", "get_non_contracting_dims", - "transpose_contracting_dims", + "transpose_dims", ] @@ -60,7 +60,7 @@ def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]: dims_ = dims if isinstance(dims, Iterable) else (dims,) if len(dims_) == 0: return dims_ - return tuple(ndim + dim if dim < 0 else dim for dim in dims_) + return tuple(ndim + dim if dim < 0 else dim for dim in dims_ if dim is not None) def get_non_contracting_dims(ndim, contracting_dims): @@ -69,10 +69,13 @@ def get_non_contracting_dims(ndim, contracting_dims): return tuple(dim for dim in range(ndim) if dim not in contracting_dims) -def transpose_contracting_dims(ndim, contracting_dims): - """Compute the new dimension numbers for contracting dimensions after a transpose.""" - contracting_dims = sanitize_dims(ndim, contracting_dims) - return tuple(ndim - i - 1 for i in contracting_dims)[::-1] +def transpose_dims(ndim, dims_to_transpose, flatten_axis=-1): + """Compute the new dimension numbers after transpose.""" + if len(dims_to_transpose) == 0: + return dims_to_transpose + flatten_axis = ndim - flatten_axis if flatten_axis > 0 else flatten_axis + transposed_dims = (*range(flatten_axis, ndim), *range(flatten_axis)) + return tuple(transposed_dims.index(dim) for dim in dims_to_transpose) def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool: @@ -106,6 +109,7 @@ def _get_gemm_layout( def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims): lhs_q = lhs rhs_q = rhs + if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0]) lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims @@ -120,9 +124,6 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ is_colwise=need_lhs_colwise, 
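The dimension helpers above normalize negative axis indices before any layout logic runs. A standalone restatement of sanitize_dims (re-implemented here so the snippet runs on its own):

    def sanitize_dims(ndim, dims):
        dims = dims if isinstance(dims, (tuple, list)) else (dims,)
        return tuple(ndim + d if d < 0 else d for d in dims if d is not None)

    # The default gemm() contraction (((-1,), (0,)), ((), ())) resolves to the trailing
    # dimension of a 3D LHS and the leading dimension of a 2D RHS.
    assert sanitize_dims(3, (-1,)) == (2,)
    assert sanitize_dims(2, (0,)) == (0,)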
flatten_axis=flatten_axis, ) - # TODO: remove - # if isinstance(lhs_q, ScaledTensor2x): - # lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) @@ -138,8 +139,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ is_colwise=need_rhs_colwise, flatten_axis=flatten_axis, ) - # if isinstance(rhs_q, ScaledTensor2x): - # rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() + assert not isinstance(lhs_q, ScaledTensor2x) assert not isinstance(rhs_q, ScaledTensor2x) @@ -166,7 +166,7 @@ def abstract( bias, gelu_input, out_dtype, - contracting_dims, + dimension_numbers, scaling_mode, fuse_bias, fuse_gelu, @@ -177,6 +177,7 @@ def abstract( # Sanity-check operand layouts and types operand_ndims = (lhs.ndim, rhs.ndim) + contracting_dims, _ = dimension_numbers ( lhs_contracting_dims, rhs_contracting_dims, @@ -293,7 +294,7 @@ def lowering( bias, gelu_input, out_dtype, - contracting_dims, + dimension_numbers, scaling_mode, fuse_bias, fuse_gelu, @@ -301,6 +302,7 @@ def lowering( use_split_accumulator, ): del out_dtype + contracting_dims, _ = dimension_numbers lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) lhs_transposed, rhs_transposed = _get_gemm_layout( @@ -340,7 +342,7 @@ def impl( bias, gelu_input, out_dtype, - contracting_dims, + dimension_numbers, scaling_mode, fuse_bias, fuse_gelu, @@ -355,7 +357,7 @@ def impl( bias, gelu_input, out_dtype=out_dtype, - contracting_dims=contracting_dims, + dimension_numbers=dimension_numbers, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -369,7 +371,7 @@ def batcher( batched_args, batch_dims, out_dtype, - contracting_dims, + dimension_numbers, scaling_mode, fuse_bias, fuse_gelu, @@ -378,12 +380,30 @@ def batcher( ): assert GemmPrimitive.outer_primitive is not None lhs, _, rhs, *_ = batched_args - lhs_bdims, *_ = batch_dims + lhs_bdims, _, rhs_bdims, *_ = batch_dims + contracting_dims, batch_dims = dimension_numbers + arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batch_dims) + arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims + assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), ( + "User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch " + f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}." + ) + arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims + assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), ( + "User-specified batch dimension(s) for cuBLAS GEMM RHS operand does not match batch " + f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}." 
+ ) - # Output is batched like LHS only if LHS is batched and RHS is not - out_bdims = lhs_bdims if lhs.ndim > 2 and rhs.ndim == 2 else (None,) - bias_bdims = (None,) # Bias is never batched - pre_gelu_bdims = (None,) # Pre-GeLU output, if exists, is batched like GEMM output + # Output is batched like the non-contracting batch dimensions of the LHS operand + lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims) + lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims) + out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims + + # Bias gradient is never batched + bias_bdims = (None,) + + # Pre-GeLU output, if exists, is batched like GEMM output + pre_gelu_bdims = (None,) if fuse_gelu and not grad: pre_gelu_bdims = out_bdims @@ -391,7 +411,7 @@ def batcher( GemmPrimitive.outer_primitive.bind( *batched_args, out_dtype=out_dtype, - contracting_dims=contracting_dims, + dimension_numbers=dimension_numbers, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -401,10 +421,171 @@ def batcher( (out_bdims, bias_bdims, pre_gelu_bdims), ) + @staticmethod + def _decompose_operand_specs(specs, contracting_dims, batch_dims): + ndims = len(specs) + cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims)) + + # Batch specs + bspecs = tuple(specs[i] for i in bdims) + + # Non-batch leading dimension specs + lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims) + + # Non-batch contracting dimension specs + cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims) + + return bspecs, lspecs, cspecs + + @staticmethod + def _parse_operand_output_specs(arg_infos, dimension_numbers): + lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) + contracting_dims, batch_dims = dimension_numbers + lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) + lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( + sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batch_dims + ) + ( + (lhs_bspecs, lhs_lspecs, lhs_cspecs), + (rhs_bspecs, rhs_lspecs, rhs_cspecs), + ) = map( + GemmPrimitive._decompose_operand_specs, + (lhs_specs, rhs_specs), + (lhs_cdims, rhs_cdims), + (lhs_bdims, rhs_bdims), + ) + + # Batched dimensions must have the same sharding + if len(lhs_bdims) > 0 and len(rhs_bdims) > 0: + assert all( + lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs) + ), ( + "cuBLAS GEMM operand batch dimensions must have the same sharding: " + f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}." + ) + + # Only one each of the non-batched leading dimensions and non-batched contracting + # dimensions can be sharded + lhs_ldims, rhs_ldims = map( + lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude), + (lhs_ndim, rhs_ndim), + (lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims), + ) + (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map( + lambda specs: tuple(spec for spec in specs if spec is not None), + (lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs), + ) + assert ( + len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1 + ), ( + "cuBLAS GEMM operands can have only one sharded non-batched leading dimension: " + f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}." 
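The spec decomposition above can be restated compactly; decompose_specs below is an illustrative stand-in for _decompose_operand_specs, using plain sets for the dimension groups:

    def decompose_specs(specs, cdims, bdims):
        batch = tuple(specs[i] for i in bdims)
        leading = tuple(s for i, s in enumerate(specs) if i not in cdims and i not in bdims)
        contracting = tuple(s for i, s in enumerate(specs) if i in cdims and i not in bdims)
        return batch, leading, contracting

    # LHS (B, M, K) sharded as ('dp', 'tp', None), batch dim 0, contracting dim 2:
    assert decompose_specs(("dp", "tp", None), {2}, {0}) == (("dp",), ("tp",), (None,))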
+ ) + assert ( + len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1 + ), ( + "cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: " + f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}." + ) + + # Extract single leading and contracting dimension specs + (lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map( + lambda specs: None if len(specs) == 0 else specs[0], + (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none), + ) + + # Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts + # with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands. + # 1. K1 == K2 != None and N == None + # LHS: (B, M, K) + # RHS: (B, None, K) + # OUT: (B, M, None) --(AR)-> (B, M, None) + # 2. K1 == K2 != None and M == N != None + # LHS: (B, M, K) + # RHS: (B, N, K)--(AG)->(B, None, K) + # OUT: (B, M, None) --(RS)--> (B, M, N) + # 3. M == N + # LHS: (B, M, K)--(AG)->(B, M, None) + # RHS: (B, M, K)--(AG)->(B, None, None) + # OUT: (B, M, None) + # 4. M != N + # LHS: (B, M, K)--(AG)->(B, M, None) + # RHS: (B, N, K)--(AG)->(B, N, None) + # OUT: (B, M, N) + reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec + all_reduce_output = reduce_flag and rhs_lspec is None + reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec + all_reduce_spec = reduce_scatter_spec = scatter_dim = None + + lhs_non_contracting_specs, rhs_non_contracting_specs = map( + lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims), + (lhs_specs, rhs_specs), + (lhs_cdims, rhs_cdims), + ) + out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs) + if reduce_scatter_output: + # All-gather (if necessary) the non-batch non-contracting dimension of RHS + # (B, N, K) --(AG)-> (B, None, K) + # (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N) + rhs_spec = tuple( + rhs_spec[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim) + ) + reduce_scatter_spec = lhs_cspec + scatter_dim = out_specs.index(rhs_lspec) + + elif all_reduce_output: + # Set all output trailing dimensions to zero + out_specs = ( + *lhs_non_contracting_specs, + *[None for _ in range(len(rhs_non_contracting_specs))], + ) + all_reduce_spec = lhs_cspec + else: + # All-gather (if necessary) the non-batch contracting dimensions + # (B, M, K) --(AG)-> (B, M, None) + # (B, N, K) --(AG)-> (B, N, None) + # (B, M, None) x (B, N, None)^T = (B, M, N) + lhs_specs = tuple( + None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i] + for i in range(lhs_ndim) + ) + rhs_specs = tuple( + None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i] + for i in range(rhs_ndim) + ) + # Check if RHS non-contracting spec also appears in the LHS non-contracting specs + if rhs_lspec is not None and rhs_lspec in tuple( + lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims + ): + # All-gather (if necessary) the non-batch non-contracting dimensions of RHS + # (B, N, None) --(AG)-> (B, None, None) + # (B, M, None) x (B, None, None)^T = (B, M, None) + rhs_specs = tuple( + None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i] + for i in range(rhs_ndim) + ) + # Set all output trailing dimensions to zero + out_specs = ( + *lhs_non_contracting_specs, + *[None for _ in range(len(rhs_non_contracting_specs))], + ) + + # Bias and Pre-GeLU sharding is based on GEMM output + bias_specs = out_specs[len(lhs_non_contracting_specs) :] + gelu_specs = out_specs + + return ( + (lhs_specs, 
rhs_specs, bias_specs, gelu_specs), + (out_specs, bias_specs, gelu_specs), + all_reduce_spec, + reduce_scatter_spec, + scatter_dim, + ) + @staticmethod def infer_sharding_from_operands( out_dtype, - contracting_dims, + dimension_numbers, scaling_mode, fuse_bias, fuse_gelu, @@ -414,72 +595,29 @@ def infer_sharding_from_operands( arg_infos, result_infos, ): - del out_dtype, scaling_mode, use_split_accumulator, result_infos + del out_dtype, scaling_mode, grad, use_split_accumulator, result_infos - # Check contracting dimensions - lhs_spec, _, rhs_spec, *_ = map(get_padded_spec, arg_infos) - operand_ndims = (len(lhs_spec), len(rhs_spec)) - lhs_contracting_dims, rhs_contracting_dims = map( - sanitize_dims, operand_ndims, contracting_dims - ) - lhs_contracting_specs, rhs_contracting_specs = map( - lambda specs, dims: [specs[dim] for dim in dims if specs[dim] is not None], - (lhs_spec, rhs_spec), - (lhs_contracting_dims, rhs_contracting_dims), - ) - assert ( - len(lhs_contracting_specs) <= 1 and len(rhs_contracting_specs) <= 1 - ), "cuBLAS GEMM operands can have only one sharded contracting dimension." - lhs_contracting_spec, rhs_contracting_spec = map( - lambda spec: None if len(spec) == 0 else spec[0], - (lhs_contracting_specs, rhs_contracting_specs), + (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( + GemmPrimitive._parse_operand_output_specs(arg_infos, dimension_numbers) ) - assert ( - lhs_contracting_spec == rhs_contracting_spec - ), "cuBLAS GEMM operands must have the same sharding in contracting dimensions." + out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) - # Sanity check leading dimensions, allow for simultaneous batch and sequence sharding - lhs_leading_dims, rhs_leading_dims = map( - get_non_contracting_dims, operand_ndims, (lhs_contracting_dims, rhs_contracting_dims) - ) - lhs_leading_specs, rhs_leading_specs = map( - lambda specs, dims: [specs[dim] for dim in dims if specs[dim] is not None], - (lhs_spec, rhs_spec), - (lhs_leading_dims, rhs_leading_dims), - ) - assert len(lhs_leading_specs) <= 1 and len(rhs_leading_specs) <= 1, ( - "cuBLAS GEMM operands cannot have more than one sharded leading dimensions. This error " - "usually means a sequence-parallel operand was not all-gathered before the GEMM op." 
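The inferred output sharding above is simply the LHS non-contracting specs followed by the RHS non-contracting specs, with the trailing specs cleared whenever the contraction forces an all-reduce. A toy restatement of the gather-everything case (infer_out_specs is an illustrative helper, not the TE implementation):

    def infer_out_specs(lhs_specs, rhs_specs, lhs_cdims, rhs_cdims):
        keep = lambda specs, cdims: tuple(s for i, s in enumerate(specs) if i not in cdims)
        return keep(lhs_specs, lhs_cdims) + keep(rhs_specs, rhs_cdims)

    # LHS (B, M, K) = ('dp', 'tp', None), RHS (K, N) = (None, None), contracting on K:
    assert infer_out_specs(("dp", "tp", None), (None, None), {2}, {0}) == ("dp", "tp", None)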
- ) + # Discard bias gradient spec if there is no bias fusion + if not fuse_bias: + dbias_specs = (None,) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs)) - # Determine output sharding - lhs_leading_spec, rhs_leading_spec = map( - lambda spec: None if len(spec) == 0 else spec[0], (lhs_leading_specs, rhs_leading_specs) - ) - out_spec = (lhs_leading_spec, rhs_leading_spec) - if operand_ndims[0] > 2 and operand_ndims[1] == 2: - # Restore batch dimensions/sharding to the output - out_spec = (*lhs_leading_specs, rhs_leading_spec) - out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) - - # Bias gradient sharding inherits the RHS contracting spec - bias_spec = (None,) - if fuse_bias and grad: - bias_spec = (rhs_contracting_spec,) - bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - - # Pre-GeLU sharding matches output sharding - pre_gelu_spec = (None,) - if fuse_gelu and not grad: - pre_gelu_spec = out_spec - pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_spec)) + # Discard pre-GeLU output spec if there is no GeLU fusion + if not fuse_gelu: + pre_gelu_specs = (None,) + pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)) - return (out_sharding, bias_sharding, pre_gelu_sharding) + return [out_sharding, dbias_sharding, pre_gelu_sharding] @staticmethod def partition( out_dtype, - contracting_dims, + dimension_numbers, scaling_mode, fuse_bias, fuse_gelu, @@ -489,51 +627,96 @@ def partition( arg_infos, result_infos, ): - out_shardings = GemmPrimitive.infer_sharding_from_operands( - out_dtype, - contracting_dims, - scaling_mode, - fuse_bias, - fuse_gelu, - grad, - use_split_accumulator, - mesh, - arg_infos, - result_infos, + del result_infos + + ( + (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), + (out_specs, dbias_specs, pre_gelu_specs), + all_reduce_spec, + reduce_scatter_spec, + scatter_dim, + ) = GemmPrimitive._parse_operand_output_specs(arg_infos, dimension_numbers) + + # Assemble argument shardings + # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. 
+ none_sharding = NamedSharding(mesh, PartitionSpec(None)) + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs)) + rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs)) + arg_shardings = ( + lhs_sharding, + lhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + rhs_sharding, + rhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, ) - output_spec = out_shardings[0].spec - # Operand shardings are already guarded with asserts so leave them unchanged here - lhs_spec, _, rhs_spec, *_ = map(get_padded_spec, arg_infos) - lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec)) - rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec)) + # Discard bias input spec if there is no bias fusion + if not fuse_bias: + bias_input_specs = (None,) + arg_shardings += (NamedSharding(mesh, PartitionSpec(*bias_input_specs)),) + + # Discard pre-GeLU input spec if there is no GeLU fusion + if not fuse_gelu: + gelu_input_specs = (None,) + arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),) + + # Assemble output shardings + out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))] + + # Discard bias gradient spec if there is no bias fusion + if not fuse_bias: + dbias_specs = (None,) + out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs))) + + # Discard pre-GeLU output spec if there is no GeLU fusion + if not fuse_gelu: + pre_gelu_specs = (None,) + out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))) + + def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): + outputs = GemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype=out_dtype, + dimension_numbers=dimension_numbers, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + ) - # Any distributed scales (e.g. MXFP8) need to be gathered - scale_sharding = NamedSharding(mesh, PartitionSpec(None)) + # All-Reduce/Reduce-Scatter GEMM output + if all_reduce_spec is not None: + outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec) + if fuse_gelu and not grad: + outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec) + elif reduce_scatter_spec is not None: + outputs[0] = jax.lax.psum_scatter( + outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True + ) + if fuse_gelu and not grad: + outputs[2] = jax.lax.psum_scatter( + outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True + ) - # Bias has to be sharded same as the trailing dimension of the GEMM output - bias_spec = (None,) - if fuse_bias and not grad: - bias_spec = (output_spec[-1],) - bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + return outputs - # Pre-GeLU output has to be sharded same as the GEMM output - pre_gelu_spec = (None,) - if fuse_gelu and grad: - pre_gelu_spec = output_spec - pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_spec)) + return mesh, _sharded_impl, out_shardings, arg_shardings - arg_shardings = ( - lhs_sharding, - scale_sharding, - rhs_sharding, - scale_sharding, - bias_sharding, - pre_gelu_sharding, + @staticmethod + def shardy_sharding_rule(*args, **kwargs): + del args, kwargs + raise NotImplementedError( + "TE cuBLAS GEMM custom op does not support the Shardy partitioner. 
You can disable the " + "custom op by setting `NVTE_JAX_CUSTOM_CALLS_RE=\"^(?!GemmPrimitive$).+$\"` in the " + "environment, which will make GEMM operations in TE will execute with native " + "`jax.lax.dot_general` and `jax.nn.scaled_matmul` calls." ) - return mesh, GemmPrimitive.impl, out_shardings, arg_shardings - register_primitive(GemmPrimitive) @@ -550,7 +733,7 @@ def _te_gemm( gelu_input: jax.Array = None, lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (-2,)), + dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]] = (((-1,), (0,)), ((), ())), fuse_bias: bool = False, fuse_gelu: bool = False, grad: bool = False, @@ -562,8 +745,10 @@ def _te_gemm( lhs_scale_inv = jnp.empty(0, dtype=jnp.float32) rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) scaling_mode = ScalingMode.NO_SCALING + contracting_dims, batch_dims = dimension_numbers lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) + lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batch_dims) # Quantize operands (if necessary) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) @@ -581,7 +766,8 @@ def _te_gemm( lhs_data = lhs_q.data lhs_scale_inv = lhs_q.scale_inv if lhs_q.data_layout == "T": - lhs_cdims = transpose_contracting_dims(lhs_q.ndim, lhs_cdims) + lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) + lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis) if isinstance(rhs_q, ScaledTensor): assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( @@ -598,7 +784,8 @@ def _te_gemm( rhs_data = rhs_q.data rhs_scale_inv = rhs_q.scale_inv if rhs_q.data_layout == "T": - rhs_cdims = transpose_contracting_dims(rhs_q.ndim, rhs_cdims) + rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) + rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis) # Dummy empties for bias and gelu out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype @@ -615,7 +802,7 @@ def _te_gemm( bias, gelu_input, out_dtype=out_dtype, - contracting_dims=(lhs_cdims, rhs_cdims), + dimension_numbers=((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)), scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -810,9 +997,11 @@ def _calculate_remaining_shape(shape, contracting_dims): def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums if lhs.data_layout == "T": - lhs_contract = transpose_contracting_dims(lhs.data.ndim, lhs_contract) + lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis) + lhs_batch = transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis) if rhs.data_layout == "T": - rhs_contract = transpose_contracting_dims(rhs.data.ndim, rhs_contract) + rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis) + rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis) dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) @@ -911,12 +1100,6 @@ def _jax_gemm_fp8_impl(lhs, rhs): lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, 
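For context on the collectives issued by the sharded implementation above: when both operands contract over a sharded K dimension, each device computes a partial GEMM and the full result is recovered with an all-reduce. A self-contained sketch using shard_map (the 'kp' axis name and shapes are illustrative; assumes the K extent is a multiple of the local device count):

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental.shard_map import shard_map

    devices = jax.devices()
    mesh = Mesh(devices, ("kp",))

    def partial_gemm(lhs, rhs):
        out = jnp.dot(lhs, rhs)          # local GEMM over this device's K shard
        return jax.lax.psum(out, "kp")   # all-reduce the partial results

    matmul = shard_map(
        partial_gemm,
        mesh=mesh,
        in_specs=(P(None, "kp"), P("kp", None)),
        out_specs=P(None, None),
    )

    k = 8 * len(devices)
    x, w = jnp.ones((4, k)), jnp.ones((k, 2))
    assert jnp.allclose(matmul(x, w), x @ w)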
ScaledTensor): - # if isinstance(lhs_q, ScaledTensor2x): - # lhs_is_transposed = lhs.ndim - 1 not in sanitize_dims(lhs.ndim, contracting_dims[0]) - # lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() - # if isinstance(rhs_q, ScaledTensor2x): - # rhs_is_transposed = rhs.ndim - 1 in sanitize_dims(rhs.ndim, contracting_dims[1]) - # rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() return _jax_gemm_fp8_impl(lhs_q, rhs_q) if ( @@ -933,7 +1116,7 @@ def _jax_gemm_fp8_impl(lhs, rhs): def gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (-2,)), + dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]] = (((-1,), (0,)), ((), ())), lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, **kwargs, @@ -950,10 +1133,11 @@ def gemm( Object for down-casting the LHS operand for quantized GEMM. rhs_quantizer: Quantizer, default = None Object for down-casting the RHS operand for quantized GEMM. - contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (-2, )) - Tuple of two sequences representing the contracting dimensions. The first sequence - represents the contracting dimensions of the LHS operand, and the second sequence - represents the contracting dimensions of the RHS operand. + dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]], default = (((-1, ), (0, )), ((), ())) + Tuple of two tuples of sequences representing the contracting and batched dimensions, + respectively. The first sequence in each tuple represents the contracting/batched + dimensions of the LHS operand, and the second sequence represents the contracting/batched + dimensions of the RHS operand. bias: jax.Array, default = None Optional additive bias term, required for forward GEMM with bias fusion. Only supported with TE's custom call to cuBLAS GEMM. @@ -1010,14 +1194,14 @@ def gemm( "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." ) - return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) + return _jax_gemm(lhs, rhs, dimension_numbers[0], lhs_quantizer, rhs_quantizer) outputs = _te_gemm( lhs, rhs, lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, - contracting_dims=contracting_dims, + dimension_numbers=dimension_numbers, **kwargs, ) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index b2721915a1..6ed91f4319 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -30,6 +30,7 @@ def dense( contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, + batch_first: bool = True, quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. @@ -43,6 +44,7 @@ def dense( kernel: Weight matrix for the dense layer transformation bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract + batch_first: Assume that X is batched in the first dimension. 
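A short usage example of the new dimension_numbers argument, mirroring the updated unit tests (shapes are illustrative; requires a CUDA build of TE unless the custom call is disabled via NVTE_JAX_CUSTOM_CALLS_RE):

    import jax.numpy as jnp
    from transformer_engine.jax import cpp_extensions as tex

    x = jnp.zeros((4, 128, 256), dtype=jnp.bfloat16)   # (batch, seqlen, hidden_in)
    w = jnp.zeros((256, 512), dtype=jnp.bfloat16)      # (hidden_in, hidden_out)

    # Contract x's last dim against w's first dim; no explicitly batched dims.
    y = tex.gemm(x, w, dimension_numbers=(((-1,), (0,)), ((), ())))
    assert y.shape == (4, 128, 512)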
quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: @@ -51,17 +53,19 @@ def dense( # Remove when tex.quantize() can handle quantizer=None if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot(): x = with_sharding_constraint_by_logical_axes(x, input_axes) - output = tex.gemm(x, kernel, contracting_dims) + output = tex.gemm(x, kernel, dimension_numbers=(contracting_dims, ((), ()))) if bias is not None: bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) else: - output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set) + output = _dense( + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set + ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) -def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) +def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set): """Internal implementation of dense layer transformation with custom VJP. This function implements the core dense layer transformation logic with support @@ -75,23 +79,46 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix quantizer_set: QuantizerSet which contains quantizers for different tensor types + batch_first: Assume that X is batched in the first dimension. Returns: Transformed output tensor """ output, _ = _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set ) return output -def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): +def _dense_fwd_rule( + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set +): """Forward pass rule for dense layer transformation. Returns: Tuple of (output, context) for backward pass """ - x_contracting_dims, k_contracting_dims = contracting_dims + x_contracting_dims, k_contracting_dims = map( + tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims + ) + + # Check supported input layout + x_is_transposed = x.ndim - 1 not in x_contracting_dims + k_is_transposed = kernel.ndim - 1 in k_contracting_dims + assert ( + not x_is_transposed and not k_is_transposed + ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel." + + # Determine X batch dimension + # - If `batch_first=True` -> (batch, leading..., contracting...) + # - Otherwise -> (leading..., batch, contracting...) + # NOTE: Always assume a single batch dimension + x_bdim = None + num_cdims = len(x_contracting_dims) + if x.ndim >= num_cdims + 2: + # Assume X is batched if it has at least +2 dimensions more than the number of contracting + # dimensions. 
+ x_bdim = 0 if batch_first else x.ndim - num_cdims - 1 flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) @@ -114,7 +141,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, output = tex.gemm( casted_x.get_tensor(usage=TensorUsage.LHS), casted_kernel.get_tensor(usage=TensorUsage.RHS), - (x_contracting_dims, k_contracting_dims), + dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) @@ -131,20 +158,19 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, use_bias, quantizer_set, flatten_axis_k, + x_bdim, ) return output, ctx def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, ctx, grad + contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad ): # pylint: disable=unused-argument """Backward pass rule for dense layer transformation. Returns: Tuple of gradients with respect to inputs """ - fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims - ( casted_x_lhs, casted_kernel_rhs, @@ -153,8 +179,13 @@ def _dense_bwd_rule( use_bias, quantizer_set, flatten_axis_k, + x_bdim, ) = ctx + fwd_x_contracting_dims, fwd_k_contracting_dims = map( + tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims + ) + casted_grad, dbias = tex.quantize_dbias( grad, is_dbias=use_bias, @@ -175,7 +206,7 @@ def _dense_bwd_rule( dgrad = tex.gemm( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, - contracting_dims=(g_contracting_dim, k_contracting_dim), + dimension_numbers=((g_contracting_dim, k_contracting_dim), ((x_bdim,), ())), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) @@ -188,7 +219,7 @@ def _dense_bwd_rule( wgrad = tex.gemm( casted_x_lhs, casted_grad.get_tensor(usage=TensorUsage.RHS), - contracting_dims=(x_contracting_dim, g_contracting_dim), + dimension_numbers=((x_contracting_dim, g_contracting_dim), ((x_bdim,), (x_bdim, ))), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 816436a00e..5608b24c5f 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -37,6 +37,7 @@ def layernorm_dense( layernorm_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, + batch_first: bool = True, quantizer_set: QuantizerSet = noop_quantizer_set, ) -> jnp.ndarray: """Apply layer normalization followed by dense layer transformation. @@ -57,6 +58,7 @@ def layernorm_dense( layernorm_input_axes: Logical axes for sharding the layernorm input dot_input_axes: Logical axes for sharding the matrix multiplication input kernel_axes: Logical axes for sharding the weight matrix + batch_first: Assume that X is batched in the first dimension. 
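The batch-dimension inference in `_dense_fwd_rule` above can be traced with a small example (values are illustrative): an activation qualifies as batched when it has at least two more dimensions than it has contracting dimensions, and `batch_first` then decides whether the batch index is 0 or the dimension just before the contracting ones.

    # Assumed example: x.shape == (8, 128, 256) with contracting dims ((2,), (0,)) against the kernel.
    x_shape = (8, 128, 256)
    x_contracting_dims = (2,)
    num_cdims = len(x_contracting_dims)
    batch_first = True

    x_bdim = None
    if len(x_shape) >= num_cdims + 2:
        x_bdim = 0 if batch_first else len(x_shape) - num_cdims - 1
    assert x_bdim == 0  # with batch_first=False this would be 1 instead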
quantizer_set: Set of quantizers for different tensor types Returns: @@ -80,6 +82,7 @@ def layernorm_dense( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, quantizer_set, ) return output @@ -94,6 +97,7 @@ def layernorm_dense( 8, 9, 10, + 11, ), ) def _layernorm_dense( @@ -108,6 +112,7 @@ def _layernorm_dense( layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...], + batch_first: bool, quantizer_set, ): """Internal implementation of layernorm_dense with custom VJP. @@ -127,6 +132,7 @@ def _layernorm_dense( epsilon: Small constant for numerical stability layernorm_input_axes: Logical axes for layernorm sharding dot_input_axes: Logical axes for matrix multiplication sharding + batch_first: Assume that X is batched in the first dimension. quantizer_set: Set of quantizers Returns: @@ -144,6 +150,7 @@ def _layernorm_dense( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, quantizer_set, ) return output @@ -161,6 +168,7 @@ def _layernorm_dense_fwd_rule( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, quantizer_set, ): """Forward pass rule for layernorm_dense. @@ -178,6 +186,10 @@ def _layernorm_dense_fwd_rule( k_contracting_dims = (0,) assert x.shape[-1] == kernel.shape[0] + x_bdim = None + if x.ndim > 2: + x_bdim = 0 if batch_first else x.ndim - 2 + x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) casted_ln_out, mu, rsigma = tex.normalization_fwd( @@ -205,7 +217,7 @@ def _layernorm_dense_fwd_rule( output = tex.gemm( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel.get_tensor(TensorUsage.RHS), - contracting_dims=(x_contracting_dims, k_contracting_dims), + dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) @@ -229,6 +241,7 @@ def _layernorm_dense_fwd_rule( use_bias, quantizer_set, flatten_axis, + x_bdim, ) return output, ctx @@ -241,6 +254,7 @@ def _layernorm_dense_bwd_rule( layernorm_input_axes, dot_input_axes, # pylint: disable=unused-argument kernel_axes, + batch_first, # pylint: disable=unused-argument ctx, grad, ): @@ -270,6 +284,7 @@ def _layernorm_dense_bwd_rule( use_bias, quantizer_set, flatten_axis, + x_bdim, ) = ctx casted_grad, dbias = tex.quantize_dbias( @@ -293,7 +308,7 @@ def _layernorm_dense_bwd_rule( dgrad = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, - contracting_dims=(g_constracting_dim, k_constracting_dim), + dimension_numbers=((g_constracting_dim, k_constracting_dim), ((x_bdim,), ())), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) @@ -306,7 +321,7 @@ def _layernorm_dense_bwd_rule( wgrad = tex.gemm( casted_ln_out, casted_grad.get_tensor(TensorUsage.RHS), - contracting_dims=(x_constracting_dim, g_constracting_dim), + dimension_numbers=((x_constracting_dim, g_constracting_dim), ((x_bdim,), (x_bdim, ))), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index fb9a3163b9..53decc3c2a 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -48,6 +48,7 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + batch_first: bool = True, quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> 
jnp.ndarray: """Apply layer normalization followed by MLP block. @@ -79,6 +80,7 @@ def layernorm_mlp( ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation + batch_first: Assume that X is batched in the first dimension. quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations Returns: @@ -124,12 +126,13 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -149,6 +152,7 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + batch_first: bool, quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -174,6 +178,7 @@ def _layernorm_mlp( ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing activation_type: Activation function(s) + batch_first: Assume that X is batched in the first dimension. quantizer_sets: Tuple of quantizer sets Returns: @@ -198,6 +203,7 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, quantizer_sets, ) return output @@ -222,6 +228,7 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, quantizer_sets, ): """Forward pass rule for layernorm_mlp. @@ -254,6 +261,10 @@ def _layernorm_mlp_fwd_rule( assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] + x_bdim = None + if x.ndim > 2: + x_bdim = 0 if batch_first else x.ndim - 2 + use_bias_1 = bias_1 is not None use_bias_2 = bias_1 is not None @@ -280,7 +291,7 @@ def _layernorm_mlp_fwd_rule( dot_1_output = tex.gemm( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel_1.get_tensor(TensorUsage.RHS), - contracting_dims=(x_contracting_dims, k_contracting_dims), + dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), bias=bias_1 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, ) @@ -315,7 +326,7 @@ def _layernorm_mlp_fwd_rule( dot_2_output = tex.gemm( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), - contracting_dims=(x_contracting_dims, k_contracting_dims), + dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, ) @@ -345,6 +356,7 @@ def _layernorm_mlp_fwd_rule( use_bias_1, use_bias_2, quantizer_sets, + x_bdim, ) return dot_2_output, ctx @@ -362,6 +374,7 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, ctx, grad, ): @@ -378,7 +391,7 @@ def _layernorm_mlp_bwd_rule( Returns: Tuple of gradients for all input parameters """ - del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name + del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first ( x, mu, @@ -397,6 +410,7 @@ def _layernorm_mlp_bwd_rule( use_bias_1, use_bias_2, quantizer_sets, + x_bdim, ) = ctx ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets @@ -422,7 +436,7 @@ def _layernorm_mlp_bwd_rule( dgrad_2 
= tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel_2, - contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), + dimension_numbers=((g_contracting_dims_2, k_contracting_dims_2), ((x_bdim,), ())), ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) @@ -436,7 +450,7 @@ def _layernorm_mlp_bwd_rule( wgrad_2 = tex.gemm( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), - contracting_dims=(x_contracting_dims, g_contracting_dims), + dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim, ), (x_bdim, ))), ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -463,7 +477,7 @@ def _layernorm_mlp_bwd_rule( dgrad_1 = tex.gemm( casted_dact_out.get_tensor(TensorUsage.LHS), casted_kernel_1, - contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), + dimension_numbers=((g_contracting_dims_1, k_contracting_dims_1), ((x_bdim,), ())), ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) @@ -473,7 +487,7 @@ def _layernorm_mlp_bwd_rule( wgrad_1 = tex.gemm( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), - contracting_dims=(x_contracting_dims, g_contracting_dims), + dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim, ))), ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) From ddaaab95dc32f7f9e38da389c4ccc1c957cb49af Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Jun 2025 19:40:05 +0000 Subject: [PATCH 16/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../encoder/test_model_parallel_encoder.py | 33 ++++++++++--------- examples/jax/encoder/test_multigpu_encoder.py | 33 ++++++++++--------- .../encoder/test_multiprocessing_encoder.py | 33 ++++++++++--------- tests/jax/test_distributed_layernorm_mlp.py | 6 ++-- transformer_engine/jax/cpp_extensions/gemm.py | 10 ++---- transformer_engine/jax/dense.py | 2 +- transformer_engine/jax/layernorm_dense.py | 2 +- transformer_engine/jax/layernorm_mlp.py | 4 +-- 8 files changed, 65 insertions(+), 58 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index d733fa7097..00203a4537 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -308,9 +308,9 @@ def train_and_evaluate(args): key: params_sharding[PARAMS_KEY] if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_encoder_init = jax.jit( + encoder.init, in_shardings=in_shardings, out_shardings=out_shardings + ) var_collect = jit_encoder_init(init_rngs, inputs, masks) # Check if params are sufficiently sharded after initialization @@ -347,15 +347,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_train_step = jax.jit( + train_step, in_shardings=in_shardings, out_shardings=out_shardings + ) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_eval_step = jax.jit( + eval_step, in_shardings=in_shardings, 
out_shardings=out_shardings + ) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -517,8 +517,9 @@ def test_te_mxfp8_with_sp(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True @@ -526,8 +527,9 @@ def test_te_bf16_shardy(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_fp8_supported, fp8_reason) - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" self.args.enable_shardy = True @@ -537,8 +539,9 @@ def test_te_delayed_scaling_fp8_shardy(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_fp8_supported, fp8_reason) - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_delayed_scaling_fp8_with_sp_shardy(self): """Test Transformer Engine with DelayedScaling FP8 + SP""" self.args.enable_shardy = True diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 8c9751c620..3d7fab0115 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -289,9 +289,9 @@ def train_and_evaluate(args): out_shardings = { key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_encoder_init = jax.jit( + encoder.init, in_shardings=in_shardings, out_shardings=out_shardings + ) var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) @@ -315,15 +315,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_train_step = jax.jit( + train_step, in_shardings=in_shardings, out_shardings=out_shardings + ) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_eval_step = jax.jit( + eval_step, in_shardings=in_shardings, out_shardings=out_shardings + ) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -465,8 +465,9 @@ def test_te_mxfp8(self): assert actual[0] < 0.535 and actual[1] > 0.73 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True @@ -474,8 +475,9 @@ def test_te_bf16_shardy(self): assert 
actual[0] < 0.535 and actual[1] > 0.73 @unittest.skipIf(not is_fp8_supported, fp8_reason) - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" self.args.enable_shardy = True @@ -487,8 +489,9 @@ def test_te_delayed_scaling_fp8_shardy(self): # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. @unittest.skipIf(not is_fp8_supported, fp8_reason) - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" self.args.enable_shardy = True diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 35a0e766a4..b5d03c0796 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -412,9 +412,9 @@ def train_and_evaluate(args): out_shardings = { key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_encoder_init = jax.jit( + encoder.init, in_shardings=in_shardings, out_shardings=out_shardings + ) var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) @@ -434,15 +434,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_train_step = jax.jit( + train_step, in_shardings=in_shardings, out_shardings=out_shardings + ) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_eval_step = jax.jit( + eval_step, in_shardings=in_shardings, out_shardings=out_shardings + ) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -634,8 +634,9 @@ def test_te_mxfp8(self): assert result[0] < 0.505 and result[1] > 0.754 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" result = self.exec(False, None, enable_shardy=True) @@ -644,8 +645,9 @@ def test_te_bf16_shardy(self): @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" ) - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling", enable_shardy=True) @@ -656,8 +658,9 @@ def test_te_delayed_scaling_fp8_shardy(self): @unittest.skipIf( not is_fp8_supported(), "Device compute capability 
9.0+ is required for CurrentScaling FP8" ) - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling", enable_shardy=True) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 3c27f0d472..a093ff5d91 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -239,7 +239,7 @@ def _test_layernorm_mlp_grad( m_grad, s_grad, dtype=bwd_test_type, - err_msg=f"multi_grads[{i}] is not close" + err_msg=f"multi_grads[{i}] is not close", ) else: assert_allclose( @@ -400,7 +400,9 @@ def test_layernorm_mlp_layer( @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES[:-1] if is_mxfp8_supported else SUPPORTED_RECIPES) + @pytest_parametrize_wrapper( + "fp8_recipe", SUPPORTED_RECIPES[:-1] if is_mxfp8_supported else SUPPORTED_RECIPES + ) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8( self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 807ca62713..fc84e21947 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -475,15 +475,11 @@ def _parse_operand_output_specs(arg_infos, dimension_numbers): lambda specs: tuple(spec for spec in specs if spec is not None), (lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs), ) - assert ( - len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1 - ), ( + assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, ( "cuBLAS GEMM operands can have only one sharded non-batched leading dimension: " f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}." ) - assert ( - len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1 - ), ( + assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, ( "cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: " f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}." ) @@ -712,7 +708,7 @@ def shardy_sharding_rule(*args, **kwargs): del args, kwargs raise NotImplementedError( "TE cuBLAS GEMM custom op does not support the Shardy partitioner. You can disable the " - "custom op by setting `NVTE_JAX_CUSTOM_CALLS_RE=\"^(?!GemmPrimitive$).+$\"` in the " + 'custom op by setting `NVTE_JAX_CUSTOM_CALLS_RE="^(?!GemmPrimitive$).+$"` in the ' "environment, which will make GEMM operations in TE will execute with native " "`jax.lax.dot_general` and `jax.nn.scaled_matmul` calls." 
) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 6ed91f4319..0d4a1b7524 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -219,7 +219,7 @@ def _dense_bwd_rule( wgrad = tex.gemm( casted_x_lhs, casted_grad.get_tensor(usage=TensorUsage.RHS), - dimension_numbers=((x_contracting_dim, g_contracting_dim), ((x_bdim,), (x_bdim, ))), + dimension_numbers=((x_contracting_dim, g_contracting_dim), ((x_bdim,), (x_bdim,))), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 5608b24c5f..62fa2cfcd2 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -321,7 +321,7 @@ def _layernorm_dense_bwd_rule( wgrad = tex.gemm( casted_ln_out, casted_grad.get_tensor(TensorUsage.RHS), - dimension_numbers=((x_constracting_dim, g_constracting_dim), ((x_bdim,), (x_bdim, ))), + dimension_numbers=((x_constracting_dim, g_constracting_dim), ((x_bdim,), (x_bdim,))), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 53decc3c2a..5d129aa54d 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -450,7 +450,7 @@ def _layernorm_mlp_bwd_rule( wgrad_2 = tex.gemm( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), - dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim, ), (x_bdim, ))), + dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim,))), ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -487,7 +487,7 @@ def _layernorm_mlp_bwd_rule( wgrad_1 = tex.gemm( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), - dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim, ))), + dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim,))), ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) From 44e5b81cfe2e20a74fa7dc6df0aadc8ad1f8c79d Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 25 Jun 2025 21:48:30 +0000 Subject: [PATCH 17/30] moved reshape of encoder output in encoder examples to make custom partitioning rules work correctly Signed-off-by: Alp Dener --- examples/jax/encoder/test_model_parallel_encoder.py | 6 ++---- examples/jax/encoder/test_multigpu_encoder.py | 4 +--- examples/jax/encoder/test_multiprocessing_encoder.py | 4 +--- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 00203a4537..490d062fbc 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -66,12 +66,10 @@ def __call__(self, x, mask, disable_dropout=False): ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) - x = x.reshape(x.shape[0], -1) - if self.enable_seq_paral: # Trigger all-gather to collect a complete tensor alone sequence on each device. 
x = jax.lax.with_sharding_constraint( - x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) + x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None) ) x = te_flax.DenseGeneral( @@ -86,7 +84,7 @@ def __call__(self, x, mask, disable_dropout=False): bias_axes=(NAMED_BROADCAST_AXIS,), )(x) - x = nn.Dense(features=2)(x) + x = nn.Dense(features=2)(x.reshape(x.shape[0], -1)) return x diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 3d7fab0115..be2e9795ad 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -56,13 +56,11 @@ def __call__(self, x, mask, disable_dropout=False): ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) - x = x.reshape(x.shape[0], -1) - x = te_flax.DenseGeneral(features=256)(x) x = te_flax.DenseGeneral(features=256)(x) - x = nn.Dense(features=2)(x) + x = nn.Dense(features=2)(x.reshape(x.shape[0], -1)) return x diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index b5d03c0796..5b7f2e2c89 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -66,8 +66,6 @@ def __call__(self, x, mask, disable_dropout=False): ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) - x = x.reshape(x.shape[0], -1) - x = te_flax.DenseGeneral( features=256, kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), @@ -80,7 +78,7 @@ def __call__(self, x, mask, disable_dropout=False): bias_axes=(NAMED_BROADCAST_AXIS,), )(x) - x = nn.Dense(features=2)(x) + x = nn.Dense(features=2)(x.reshape(x.shape[0], -1)) return x From b8ca0b1dd2715d5580039171a705256c67f7d2a6 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 27 Jun 2025 16:37:46 +0000 Subject: [PATCH 18/30] added helper functions for padding and unpadding block scales, changed GemmPrimitive to accept unpadded scales and pad them after sharding Signed-off-by: Alp Dener --- .../encoder/test_model_parallel_encoder.py | 6 +- examples/jax/encoder/test_multigpu_encoder.py | 18 +-- .../encoder/test_multiprocessing_encoder.py | 4 +- transformer_engine/jax/cpp_extensions/gemm.py | 66 +++++++++- transformer_engine/jax/quantize/helper.py | 121 +++++++++++++++++- transformer_engine/jax/quantize/tensor.py | 26 +--- 6 files changed, 204 insertions(+), 37 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 490d062fbc..00203a4537 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -66,10 +66,12 @@ def __call__(self, x, mask, disable_dropout=False): ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) + x = x.reshape(x.shape[0], -1) + if self.enable_seq_paral: # Trigger all-gather to collect a complete tensor alone sequence on each device. 
x = jax.lax.with_sharding_constraint( - x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None) + x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) ) x = te_flax.DenseGeneral( @@ -84,7 +86,7 @@ def __call__(self, x, mask, disable_dropout=False): bias_axes=(NAMED_BROADCAST_AXIS,), )(x) - x = nn.Dense(features=2)(x.reshape(x.shape[0], -1)) + x = nn.Dense(features=2)(x) return x diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index be2e9795ad..44cafa7396 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -56,11 +56,13 @@ def __call__(self, x, mask, disable_dropout=False): ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) + x = x.reshape(x.shape[0], -1) + x = te_flax.DenseGeneral(features=256)(x) x = te_flax.DenseGeneral(features=256)(x) - x = nn.Dense(features=2)(x.reshape(x.shape[0], -1)) + x = nn.Dense(features=2)(x) return x @@ -436,7 +438,7 @@ def setUp(self): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -444,7 +446,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8(self): @@ -452,7 +454,7 @@ def test_te_current_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -460,7 +462,7 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @unittest.skipIf( @@ -470,7 +472,7 @@ def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_fp8_supported, fp8_reason) @unittest.skipIf( @@ -482,7 +484,7 @@ def test_te_delayed_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. 
@@ -496,7 +498,7 @@ def test_te_current_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 5b7f2e2c89..b5d03c0796 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -66,6 +66,8 @@ def __call__(self, x, mask, disable_dropout=False): ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) + x = x.reshape(x.shape[0], -1) + x = te_flax.DenseGeneral( features=256, kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), @@ -78,7 +80,7 @@ def __call__(self, x, mask, disable_dropout=False): bias_axes=(NAMED_BROADCAST_AXIS,), )(x) - x = nn.Dense(features=2)(x.reshape(x.shape[0], -1)) + x = nn.Dense(features=2)(x) return x diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index fc84e21947..16b8553706 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -31,6 +31,8 @@ QuantizeLayout, noop_quantizer_set, is_fp8_gemm_with_all_layouts_supported, + apply_padding_to_scale_inv, + remove_padding_from_scale_inv, ) from .misc import get_padded_spec @@ -153,7 +155,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9, 10, 11, 12) + impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14) inner_primitive = None outer_primitive = None @@ -167,13 +169,15 @@ def abstract( gelu_input, out_dtype, dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, ): - del use_split_accumulator + del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator # Sanity-check operand layouts and types operand_ndims = (lhs.ndim, rhs.ndim) @@ -295,13 +299,15 @@ def lowering( gelu_input, out_dtype, dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, ): - del out_dtype + del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype contracting_dims, _ = dimension_numbers lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) @@ -343,12 +349,36 @@ def impl( gelu_input, out_dtype, dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, ): + lhs_cdims, rhs_cdims = map( + sanitize_dims, (lhs.ndim, rhs.ndim), dimension_numbers[0] + ) + lhs_transposed, rhs_transposed = _get_gemm_layout( + (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) + ) + + lhs_scale_inv = apply_padding_to_scale_inv( + lhs_scale_inv, + scaling_mode, + lhs.shape, + is_colwise=lhs_quantized_colwise, + flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), + ) + rhs_scale_inv = apply_padding_to_scale_inv( + rhs_scale_inv, + scaling_mode, + rhs.shape, + is_colwise=rhs_quantized_colwise, + flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, + ) + outputs = GemmPrimitive.inner_primitive.bind( lhs, lhs_scale_inv, @@ -358,6 +388,8 @@ def impl( gelu_input, out_dtype=out_dtype, dimension_numbers=dimension_numbers, + 
lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -372,6 +404,8 @@ def batcher( batch_dims, out_dtype, dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, @@ -412,6 +446,8 @@ def batcher( *batched_args, out_dtype=out_dtype, dimension_numbers=dimension_numbers, + lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -582,6 +618,8 @@ def _parse_operand_output_specs(arg_infos, dimension_numbers): def infer_sharding_from_operands( out_dtype, dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, @@ -591,7 +629,8 @@ def infer_sharding_from_operands( arg_infos, result_infos, ): - del out_dtype, scaling_mode, grad, use_split_accumulator, result_infos + del out_dtype, lhs_quantized_colwise, rhs_quantized_colwise, scaling_mode, grad, + del use_split_accumulator, result_infos (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( GemmPrimitive._parse_operand_output_specs(arg_infos, dimension_numbers) @@ -614,6 +653,8 @@ def infer_sharding_from_operands( def partition( out_dtype, dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, @@ -678,6 +719,8 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): gelu_input, out_dtype=out_dtype, dimension_numbers=dimension_numbers, + lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -722,6 +765,15 @@ def gemm_uses_jax_dot() -> bool: return not GemmPrimitive.enabled() +def _get_scale_inv_without_padding(scaled_tensor): + return remove_padding_from_scale_inv( + scaled_tensor.scale_inv, + scaled_tensor.scaling_mode, + scaled_tensor.data.shape, + is_colwise=scaled_tensor.is_colwise, + flatten_axis=scaled_tensor.flatten_axis, + ) + def _te_gemm( lhs: Union[jax.Array, ScaledTensor], rhs: Union[jax.Array, ScaledTensor], @@ -760,7 +812,7 @@ def _te_gemm( lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() scaling_mode = lhs_q.scaling_mode lhs_data = lhs_q.data - lhs_scale_inv = lhs_q.scale_inv + lhs_scale_inv = _get_scale_inv_without_padding(lhs_q) if lhs_q.data_layout == "T": lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis) @@ -778,7 +830,7 @@ def _te_gemm( f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." 
) rhs_data = rhs_q.data - rhs_scale_inv = rhs_q.scale_inv + rhs_scale_inv = _get_scale_inv_without_padding(rhs_q) if rhs_q.data_layout == "T": rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis) @@ -799,6 +851,8 @@ def _te_gemm( gelu_input, out_dtype=out_dtype, dimension_numbers=((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)), + lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False, + rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index c0617eafbb..3c217d65fb 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -9,7 +9,9 @@ """ from contextlib import contextmanager from enum import Enum -from typing import Optional, Tuple, Dict, Union +from typing import Optional, Tuple, Dict, Union, Sequence +from functools import reduce +import operator import jax import jax.numpy as jnp @@ -29,6 +31,8 @@ "is_fp8_available", "update_collections", "get_delayed_scaling", + "apply_padding_to_scale_inv", + "remove_padding_from_scale_inv", "NVTE_FP8_COLLECTION_NAME", ] @@ -471,4 +475,119 @@ def update_collections(new: Collection, original: Collection) -> Collection: return new_coll +def remove_padding_from_scale_inv( + scale_inv: jax.Array, + scaling_mode: ScalingMode, + data_shape: Sequence[int], + is_colwise: bool = False, + flatten_axis: int = -1, +): + """ + Slice padding out of padded inverse scale factors. + + Args: + scale_inv: Inverse scale factor. + data_shape: Shape of the quantized data the inverse scale belongs to. + scaling_mode: ScalingMode representing the quantization method. + is_colwise: Whether the data was quantized column-wise. + flatten_axis: The axis along with the data could be flattened to 2D. + + Returns: + Inverse scale factor without padding. + """ + # Get expected unpadded scale shape and check if inverse scale already matches + unpadded_scale_shape = scaling_mode.get_scale_shape( + data_shape, + is_colwise=is_colwise, + is_padded=False, + flatten_axis=flatten_axis + ) + if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == unpadded_scale_shape: + return scale_inv + + # Get the padded scale shape and make sure inverse scale matches + padded_scale_shape = scaling_mode.get_scale_shape( + data_shape, + is_colwise=is_colwise, + is_padded=True, + flatten_axis=flatten_axis, + ) + assert scale_inv.shape == padded_scale_shape, ( + f"Padded inverse scale factor has wrong shape, expected {padded_scale_shape} but got " + f"{scale_inv.shape} instead." 
+ ) + + # Reshape scale inverse to 2D in two stages to preserve the flatten axis + padded_scale_shape_2d = ( + reduce(operator.mul, padded_scale_shape[ : flatten_axis]), + reduce(operator.mul, padded_scale_shape[flatten_axis : ]) + ) + scale_inv_2d = jnp.reshape( + jnp.reshape(scale_inv, (padded_scale_shape_2d[0], *scale_inv.shape[flatten_axis : ])), + padded_scale_shape_2d + ) + + # Slice reshaped 2D scale inverse using collapsed 2D unpadded_scale_shape + unpadded_scale_shape_2d = ( + reduce(operator.mul, unpadded_scale_shape[ : flatten_axis]), + reduce(operator.mul, unpadded_scale_shape[flatten_axis : ]) + ) + scale_inv_2d_unpadded = jnp.asarray( + scale_inv_2d[ : unpadded_scale_shape_2d[0], : unpadded_scale_shape_2d[1]] + ) + + # Reshape 2D scale inverse back in two stages in order to preserve the flatten axis + scale_inv_unpadded = jnp.reshape( + jnp.reshape( + scale_inv_2d_unpadded, + (*unpadded_scale_shape[: flatten_axis], scale_inv_2d_unpadded.shape[1]) + ), + unpadded_scale_shape + ) + return scale_inv_unpadded + + +def apply_padding_to_scale_inv( + scale_inv: jax.Array, + scaling_mode: ScalingMode, + data_shape: Sequence[int], + is_colwise: bool = False, + flatten_axis: int = -1, +): + """ + Pad the scale inverse with zeros to match the necessary padded shape for this scaling + mode. + + Args: + scale_inv: Inverse scale factor. + data_shape: Shape of the quantized data the inverse scale belongs to. + scaling_mode: ScalingMode representing the quantization method. + is_colwise: Whether the data was quantized column-wise. + flatten_axis: The axis along with the data could be flattened to 2D. + + Returns: + Padded inverse scale factor. + """ + # Get the expected padded scale shape and check if inverse scale already matches + padded_scale_shape = scaling_mode.get_scale_shape( + data_shape, is_colwise=is_colwise, is_padded=True, flatten_axis=flatten_axis + ) + if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == padded_scale_shape: + return scale_inv + + # Get the expected unpadded scale shape and make sure inverse scales match + unpadded_scale_shape = scaling_mode.get_scale_shape( + data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis + ) + assert scale_inv.shape == unpadded_scale_shape, ( + f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " + f"{scale_inv.shape}." 
+ ) + + # Pad the scales with the lowest representable value (2^-127) and return + pad_width = tuple( + (0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape) + ) + return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127) + NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index a87326e9fe..d454f8f43a 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -17,6 +17,7 @@ from transformer_engine_jax import QuantizeLayout +from .helper import apply_padding_to_scale_inv from .scaling_modes import ScalingMode, TensorUsage from .dequantizer import ScalingModeToDequantizerMap from ..sharding import ( @@ -136,26 +137,13 @@ def __post_init__(self): self.scale_inv = jnp.empty((0,), dtype=jnp.float32) else: - expected_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis + self.scale_inv = apply_padding_to_scale_inv( + self.scale_inv, + self.scaling_mode, + self.data.shape, + is_colwise=self.is_colwise, + flatten_axis=self.flatten_axis, ) - expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis - ) - if self.scale_inv.shape != expected_scale_shape: - assert self.scale_inv.shape == expected_unpadded_scale_shape, ( - f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" - f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" - f" {self.scale_inv.shape}" - ) - pad_width = tuple( - (0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) - ) - - # padding with the smallest number it can present - self.scale_inv = jnp.pad( - self.scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127 - ) def tree_flatten(self): """Flattens the tensor for JAX tree operations. 
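The helpers above are used in two places in this commit: the `ScaledTensor` constructor in tensor.py pads incoming scales with `apply_padding_to_scale_inv`, while `_te_gemm` strips the padding so that `GemmPrimitive.impl` can re-apply it after sharding. A minimal round-trip sketch, assuming MXFP8 1D scaling and assuming `ScalingMode` and both helpers are re-exported from `transformer_engine.jax.quantize` (the gemm.py hunk earlier imports them from `..quantize`); the data shape is illustrative.

    import jax.numpy as jnp
    from transformer_engine.jax.quantize import (
        ScalingMode,
        apply_padding_to_scale_inv,
        remove_padding_from_scale_inv,
    )

    data_shape = (64, 1024)  # shape of the (hypothetical) quantized data
    unpadded_shape = ScalingMode.MXFP8_1D_SCALING.get_scale_shape(
        data_shape, is_colwise=False, is_padded=False, flatten_axis=-1
    )
    scale_inv = jnp.ones(unpadded_shape, dtype=jnp.float32)

    # Pad up to the cuBLAS-required scale shape; padded entries are 2^-127, the smallest
    # representable scale, not zero.
    padded = apply_padding_to_scale_inv(
        scale_inv, ScalingMode.MXFP8_1D_SCALING, data_shape, is_colwise=False, flatten_axis=-1
    )

    # Slice the padding back out, as _te_gemm does before binding GemmPrimitive.
    restored = remove_padding_from_scale_inv(
        padded, ScalingMode.MXFP8_1D_SCALING, data_shape, is_colwise=False, flatten_axis=-1
    )
    assert restored.shape == tuple(unpadded_shape)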
From 718758259df2bb0e4fa0ac83ebb31132d0021244 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Jun 2025 16:40:13 +0000 Subject: [PATCH 19/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 13 ++++++--- transformer_engine/jax/quantize/helper.py | 28 ++++++++----------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 16b8553706..0f43ea47b8 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -357,9 +357,7 @@ def impl( grad, use_split_accumulator, ): - lhs_cdims, rhs_cdims = map( - sanitize_dims, (lhs.ndim, rhs.ndim), dimension_numbers[0] - ) + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), dimension_numbers[0]) lhs_transposed, rhs_transposed = _get_gemm_layout( (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) ) @@ -629,7 +627,13 @@ def infer_sharding_from_operands( arg_infos, result_infos, ): - del out_dtype, lhs_quantized_colwise, rhs_quantized_colwise, scaling_mode, grad, + del ( + out_dtype, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + grad, + ) del use_split_accumulator, result_infos (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( @@ -774,6 +778,7 @@ def _get_scale_inv_without_padding(scaled_tensor): flatten_axis=scaled_tensor.flatten_axis, ) + def _te_gemm( lhs: Union[jax.Array, ScaledTensor], rhs: Union[jax.Array, ScaledTensor], diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 3c217d65fb..122265ea27 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -497,10 +497,7 @@ def remove_padding_from_scale_inv( """ # Get expected unpadded scale shape and check if inverse scale already matches unpadded_scale_shape = scaling_mode.get_scale_shape( - data_shape, - is_colwise=is_colwise, - is_padded=False, - flatten_axis=flatten_axis + data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis ) if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == unpadded_scale_shape: return scale_inv @@ -519,30 +516,30 @@ def remove_padding_from_scale_inv( # Reshape scale inverse to 2D in two stages to preserve the flatten axis padded_scale_shape_2d = ( - reduce(operator.mul, padded_scale_shape[ : flatten_axis]), - reduce(operator.mul, padded_scale_shape[flatten_axis : ]) + reduce(operator.mul, padded_scale_shape[:flatten_axis]), + reduce(operator.mul, padded_scale_shape[flatten_axis:]), ) scale_inv_2d = jnp.reshape( - jnp.reshape(scale_inv, (padded_scale_shape_2d[0], *scale_inv.shape[flatten_axis : ])), - padded_scale_shape_2d + jnp.reshape(scale_inv, (padded_scale_shape_2d[0], *scale_inv.shape[flatten_axis:])), + padded_scale_shape_2d, ) # Slice reshaped 2D scale inverse using collapsed 2D unpadded_scale_shape unpadded_scale_shape_2d = ( - reduce(operator.mul, unpadded_scale_shape[ : flatten_axis]), - reduce(operator.mul, unpadded_scale_shape[flatten_axis : ]) + reduce(operator.mul, unpadded_scale_shape[:flatten_axis]), + reduce(operator.mul, unpadded_scale_shape[flatten_axis:]), ) scale_inv_2d_unpadded = jnp.asarray( - scale_inv_2d[ : unpadded_scale_shape_2d[0], : unpadded_scale_shape_2d[1]] + scale_inv_2d[: unpadded_scale_shape_2d[0], : unpadded_scale_shape_2d[1]] ) # Reshape 2D 
scale inverse back in two stages in order to preserve the flatten axis scale_inv_unpadded = jnp.reshape( jnp.reshape( scale_inv_2d_unpadded, - (*unpadded_scale_shape[: flatten_axis], scale_inv_2d_unpadded.shape[1]) + (*unpadded_scale_shape[:flatten_axis], scale_inv_2d_unpadded.shape[1]), ), - unpadded_scale_shape + unpadded_scale_shape, ) return scale_inv_unpadded @@ -585,9 +582,8 @@ def apply_padding_to_scale_inv( ) # Pad the scales with the lowest representable value (2^-127) and return - pad_width = tuple( - (0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape) - ) + pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape)) return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127) + NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME From 0230a5e8089c5d1c3a2d8e666cdbcb7dcec13778 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 8 Jul 2025 22:35:46 +0000 Subject: [PATCH 20/30] updated shardy rules for all custom ops to decouple block scale rules from their tensors Signed-off-by: Alp Dener --- .../run_test_multiprocessing_encoder.sh | 2 +- .../encoder/test_model_parallel_encoder.py | 50 ++++--- examples/jax/encoder/test_multigpu_encoder.py | 36 ++--- .../encoder/test_multiprocessing_encoder.py | 35 +++-- tests/jax/test_distributed_layernorm_mlp.py | 36 +++-- .../jax/cpp_extensions/activation.py | 36 +++-- transformer_engine/jax/cpp_extensions/gemm.py | 126 ++++++++++++++++-- .../jax/cpp_extensions/normalization.py | 12 +- .../jax/cpp_extensions/quantization.py | 12 +- .../jax/quantize/scaling_modes.py | 61 +++++---- 10 files changed, 265 insertions(+), 141 deletions(-) diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index a21d5ecb57..2a1ac0f8fa 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -30,7 +30,7 @@ for TEST_CASE in "${TEST_CASES[@]}"; do LOG_FILE="${TEST_CASE}_gpu_${i}.log" # Run pytest and redirect stdout and stderr to the log file - pytest -c "$TE_PATH/tests/jax/pytest.ini" \ + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ --num-process=$NUM_GPUS \ --process-id=$i > "$LOG_FILE" 2>&1 & diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 00203a4537..17e0e53175 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -473,7 +473,7 @@ def setUp(self): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.8 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -481,7 +481,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.8 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -489,14 +489,14 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and 
actual[1] > 0.8 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_with_sp(self): """Test Transformer Engine with BF16 + SP""" self.args.enable_sp = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.8 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp(self): @@ -505,7 +505,7 @@ def test_te_delayed_scaling_fp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.8 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8_with_sp(self): @@ -514,34 +514,25 @@ def test_te_mxfp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.8 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - @unittest.skipIf( - not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" - ) def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.8 @unittest.skipIf(not is_fp8_supported, fp8_reason) - @unittest.skipIf( - not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" - ) def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" self.args.enable_shardy = True self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.8 @unittest.skipIf(not is_fp8_supported, fp8_reason) - @unittest.skipIf( - not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" - ) def test_te_delayed_scaling_fp8_with_sp_shardy(self): """Test Transformer Engine with DelayedScaling FP8 + SP""" self.args.enable_shardy = True @@ -549,9 +540,30 @@ def test_te_delayed_scaling_fp8_with_sp_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.8 - # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. 
+ @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) + @unittest.skipIf(tex.gemm_uses_jax_dot(), + "`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + def test_te_mxfp8_shardy(self): + """Test Transformer Engine with MXFP8""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "MXFP8BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.43 and actual[1] > 0.8 + + @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) + @unittest.skipIf(tex.gemm_uses_jax_dot(), + "`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + def test_te_mxfp8_with_sp_shardy(self): + """Test Transformer Engine with MXFP8 + SP""" + self.args.enable_shardy = True + self.args.enable_sp = True + self.args.use_fp8 = True + self.args.fp8_recipe = "MXFP8BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.43 and actual[1] > 0.8 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 44cafa7396..08399e41d6 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -438,7 +438,7 @@ def setUp(self): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.536 and actual[1] > 0.73 + assert actual[0] < 0.53 and actual[1] > 0.74 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -446,7 +446,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.536 and actual[1] > 0.73 + assert actual[0] < 0.53 and actual[1] > 0.74 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8(self): @@ -454,7 +454,7 @@ def test_te_current_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.536 and actual[1] > 0.73 + assert actual[0] < 0.53 and actual[1] > 0.74 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -462,43 +462,43 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.536 and actual[1] > 0.73 + assert actual[0] < 0.53 and actual[1] > 0.74 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - @unittest.skipIf( - not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" - ) def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.536 and actual[1] > 0.73 + assert actual[0] < 0.53 and actual[1] > 0.74 @unittest.skipIf(not is_fp8_supported, fp8_reason) - @unittest.skipIf( - not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" - ) def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" self.args.enable_shardy = True self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.536 and actual[1] > 0.73 - - # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. 
+ assert actual[0] < 0.53 and actual[1] > 0.74 @unittest.skipIf(not is_fp8_supported, fp8_reason) - @unittest.skipIf( - not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" - ) def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" self.args.enable_shardy = True self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.536 and actual[1] > 0.73 + assert actual[0] < 0.53 and actual[1] > 0.74 + + @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) + @unittest.skipIf(tex.gemm_uses_jax_dot(), + "`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + def test_te_mxfp8_shardy(self): + """Test Transformer Engine with MXFP8""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "MXFP8BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.53 and actual[1] > 0.74 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index b5d03c0796..4af9a89d7d 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -607,7 +607,7 @@ def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): def test_te_bf16(self): """Test Transformer Engine with BF16""" result = self.exec(False, None) - assert result[0] < 0.505 and result[1] > 0.755 + assert result[0] < 0.43 and result[1] > 0.8 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" @@ -615,7 +615,7 @@ def test_te_bf16(self): def test_te_delayed_scaling_fp8(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling") - assert result[0] < 0.506 and result[1] > 0.753 + assert result[0] < 0.43 and result[1] > 0.8 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" @@ -623,7 +623,7 @@ def test_te_delayed_scaling_fp8(self): def test_te_current_scaling_fp8(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling") - assert result[0] < 0.507 and result[1] > 0.753 + assert result[0] < 0.43 and result[1] > 0.8 @unittest.skipIf( not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" @@ -631,40 +631,39 @@ def test_te_current_scaling_fp8(self): def test_te_mxfp8(self): """Test Transformer Engine with MXFP8""" result = self.exec(True, "MXFP8BlockScaling") - assert result[0] < 0.505 and result[1] > 0.754 + assert result[0] < 0.43 and result[1] > 0.8 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - @unittest.skipIf( - not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" - ) def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" result = self.exec(False, None, enable_shardy=True) - assert result[0] < 0.505 and result[1] > 0.755 + assert result[0] < 0.43 and result[1] > 0.8 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" ) - @unittest.skipIf( - not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" - ) def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling", enable_shardy=True) - assert result[0] < 0.506 and 
result[1] > 0.753 - - # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. + assert result[0] < 0.43 and result[1] > 0.8 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" ) - @unittest.skipIf( - not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" - ) def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling", enable_shardy=True) - assert result[0] < 0.507 and result[1] > 0.753 + assert result[0] < 0.43 and result[1] > 0.8 + + @unittest.skipIf( + not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" + ) + @unittest.skipIf(tex.gemm_uses_jax_dot(), + "`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + def test_te_mxfp8_shardy(self): + """Test Transformer Engine with MXFP8""" + result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True) + assert result[0] < 0.43 and result[1] > 0.8 if __name__ == "__main__": diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index a093ff5d91..c958326472 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -44,7 +44,6 @@ SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling")) if is_mxfp8_supported: SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) -SUPPORTED_RECIPES_WITH_SHARDY = SUPPORTED_RECIPES[:-1] if is_mxfp8_supported else SUPPORTED_RECIPES DTYPES = [jnp.bfloat16, jnp.float16] INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in] @@ -284,13 +283,13 @@ def test_layernorm_mlp_grad( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES_WITH_SHARDY) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm, ): - # TE cuBLAS GEMM custom op does not implement shardy rules so we test shardy only with - # native JAX FP8 dot_general. We don't test block scaling with Shardy because at the - # time of writing, it is not supported in JAX's scaled_matmul_stablehlo. 
+ if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") self._test_layernorm_mlp_grad( mesh_config, activation_type, @@ -299,7 +298,7 @@ def test_layernorm_mlp_grad_shardy( dtype, fp8_recipe=fp8_recipe, use_shardy=True, - with_jax_gemm=True, + with_jax_gemm=with_jax_gemm, ) def _test_layernorm_mlp( @@ -400,9 +399,7 @@ def test_layernorm_mlp_layer( @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper( - "fp8_recipe", SUPPORTED_RECIPES[:-1] if is_mxfp8_supported else SUPPORTED_RECIPES - ) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8( self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm @@ -424,11 +421,10 @@ def test_layernorm_mlp_layer_fp8( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype + self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm ): - # TE cuBLAS GEMM custom op does not implement shardy rules so we test shardy only with - # native JAX dot_general. self._test_layernorm_mlp( mesh_config, activation_type, @@ -438,7 +434,7 @@ def test_layernorm_mlp_layer_shardy( use_fp8=False, fp8_recipe=None, use_shardy=True, - with_jax_gemm=True, + with_jax_gemm=with_jax_gemm, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -447,13 +443,13 @@ def test_layernorm_mlp_layer_shardy( @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES_WITH_SHARDY) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm ): - # TE cuBLAS GEMM custom op does not implement shardy rules so we test shardy only with - # native JAX FP8 dot_general. We don't test block scaling with Shardy because at the - # time of writing, it is not supported in JAX's scaled_matmul_stablehlo. 
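The hunks that follow rewrite the Shardy sharding rules of the activation, GEMM, normalization, and quantization primitives around plain, per-primitive factor names (the CompoundFactor-based block-scaling rules are disabled further down in scaling_modes.py). For orientation, a minimal sketch of an `SdyShardingRule` for a toy GEMM-style op; the factor names here are illustrative, while the real `GemmPrimitive` rule below derives them from the contracting and batch dimensions:

from jax.experimental.custom_partitioning import SdyShardingRule

# Toy rule for lhs (m, k) @ rhs (k, n) -> out (m, n), with two auxiliary
# operands (e.g. scales) left unconstrained via ellipsis factors.
rule = SdyShardingRule(
    operand_mappings=(("m", "k"), ("…1",), ("k", "n"), ("…2",)),
    result_mappings=(("m", "n"),),
)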
+ if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") self._test_layernorm_mlp( mesh_config, activation_type, @@ -463,5 +459,5 @@ def test_layernorm_mlp_layer_fp8_shardy( use_fp8=True, fp8_recipe=fp8_recipe, use_shardy=True, - with_jax_gemm=True, + with_jax_gemm=with_jax_gemm, ) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 341dcb0c8c..7c201e84f3 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -415,37 +415,35 @@ def shardy_sharding_rule( result_types, ): del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types - + prefix = "ActLuPrimitive_" x_rank = len(value_types[0].shape) scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - x_rank - 1, unique_var="ActLuPrimitive_i", flatten_axis=-2 + x_rank - 1, unique_var=prefix + "x", flatten_axis=-2 ) - x_axes = scale_rules.input_spec + (f"x{x_rank-1}",) + x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}", ) out = (*x_axes[:-2], x_axes[-1]) scale_inv = scale_rules.rowwise_rule - colwise_scale_inv = scale_rules.colwise_rule + colwise_out = (prefix + "out_colwise", ) + colwise_scale_inv = (prefix + "scale_inv_colwise", ) if is_2x: + colwise_scale_inv = scale_rules.colwise_rule if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out = tuple( multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2) ) else: colwise_out = out - else: - colwise_out = ("j",) - colwise_scale_inv = ("k",) # amax is always a unit tensor. - amax = ("l",) + amax = (prefix + "amax", ) return SdyShardingRule( ( x_axes, - "…1", + ("…1", ), ), (out, colwise_out, scale_inv, colwise_scale_inv, amax), - **scale_rules.factor_sizes, ) @@ -890,28 +888,26 @@ def shardy_sharding_rule( result_types, ): del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types - - x_rank = len(value_types[1].shape) + prefix = "BaseDActLuDBiasQuantizePrimitive_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - x_rank, unique_var="BaseDActLuDBiasQuantizePrimitive_i", flatten_axis=-2 + len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2 ) x_axes = scale_rules.input_spec + dz_axes = (*x_axes[:-2], x_axes[-1]) out = x_axes + colwise_out = (prefix + "out_colwise", ) if is_2x: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) else: - colwise_out = tuple(x_axes) - else: - colwise_out = ("j",) + colwise_out = out - dbias = x_axes[-2:] if is_dbias else ("k",) - amax = ("…4",) + dbias = x_axes[-2:] if is_dbias else (prefix + "dbias", ) + amax = (prefix + "amax", ) return SdyShardingRule( - (("…0",), tuple(x_axes), ("…2",)), + (dz_axes, x_axes, ("…2", )), (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), - **scale_rules.factor_sizes, ) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0f43ea47b8..e12e38f5e6 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -13,6 +13,7 @@ import jax.numpy as jnp from jax import dtypes from jax.sharding import NamedSharding, PartitionSpec +from jax.experimental.custom_partitioning import SdyShardingRule import transformer_engine_jax as tex from transformer_engine_jax 
import get_num_compute_streams @@ -179,13 +180,52 @@ def abstract( ): del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator + def _dims_are_consecutive(dims): + if len(dims) <= 1: + return True + return sorted(dims) == list(range(min(dims), max(dims) + 1)) + # Sanity-check operand layouts and types operand_ndims = (lhs.ndim, rhs.ndim) - contracting_dims, _ = dimension_numbers + contracting_dims, batch_dims = dimension_numbers + ( lhs_contracting_dims, rhs_contracting_dims, ) = map(sanitize_dims, operand_ndims, contracting_dims) + assert _dims_are_consecutive(lhs_contracting_dims), ( + "cuBLAS GEMM expected consecutive contracting dimensions for LHS operand, but got " + f"{lhs_contracting_dims}." + ) + assert _dims_are_consecutive(rhs_contracting_dims), ( + "cuBLAS GEMM expected consecutive contracting dimensions for RHS operand, but got " + f"{rhs_contracting_dims}." + ) + + ( + lhs_batch_dims, + rhs_batch_dims, + ) = map(sanitize_dims, operand_ndims, batch_dims) + assert _dims_are_consecutive(lhs_batch_dims), ( + "cuBLAS GEMM expected consecutive batch dimensions for LHS operand, but got " + f"{lhs_batch_dims}." + ) + assert _dims_are_consecutive(rhs_batch_dims), ( + "cuBLAS GEMM expected consecutive batch dimensions for RHS operand, but got " + f"{rhs_batch_dims}." + ) + if len(lhs_batch_dims) == 0: + assert len(rhs_batch_dims) == 0, ( + "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched." + ) + elif len(rhs_batch_dims) != 0: + assert ( + all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) + and all(bdim in rhs_contracting_dims for bdim in rhs_batch_dims) + ), ( + "cuBLAS GEMM batched dimensions must be contracting when both operands are batched." + ) + lhs_contracting_size, rhs_contracting_size = map( lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), (lhs.shape, rhs.shape), @@ -751,15 +791,85 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): return mesh, _sharded_impl, out_shardings, arg_shardings @staticmethod - def shardy_sharding_rule(*args, **kwargs): - del args, kwargs - raise NotImplementedError( - "TE cuBLAS GEMM custom op does not support the Shardy partitioner. You can disable the " - 'custom op by setting `NVTE_JAX_CUSTOM_CALLS_RE="^(?!GemmPrimitive$).+$"` in the ' - "environment, which will make GEMM operations in TE will execute with native " - "`jax.lax.dot_general` and `jax.nn.scaled_matmul` calls." 
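The layout checks added to `GemmPrimitive.abstract` above reject non-consecutive contracting or batch dimensions before dispatching to cuBLAS. A small standalone sketch of the same logic; `sanitize_dims` here is a stand-in for the helper imported from `cpp_extensions/misc.py` and is assumed to simply wrap negative axes:

def sanitize_dims(ndim, dims):
    # Stand-in: wrap negative axis indices into [0, ndim).
    return tuple(d % ndim for d in dims)

def dims_are_consecutive(dims):
    if len(dims) <= 1:
        return True
    return sorted(dims) == list(range(min(dims), max(dims) + 1))

# lhs (B, M, K) @ rhs (B, K, N): contract lhs axis -1 with rhs axis 1.
lhs_ndim, rhs_ndim = 3, 3
lhs_cdims = sanitize_dims(lhs_ndim, (-1,))  # -> (2,)
rhs_cdims = sanitize_dims(rhs_ndim, (1,))   # -> (1,)
assert dims_are_consecutive(lhs_cdims) and dims_are_consecutive(rhs_cdims)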
+ def shardy_sharding_rule( + out_dtype, + dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + mesh, + operand_types, + result_types, + ): + del out_dtype, grad, use_split_accumulator, mesh, result_types + prefix = "GemmPrimitive_" + + def _generate_operand_rules(name, ndim, cdims, bdims): + specs = [] + ldims = tuple(i for i in range(ndim) if i not in bdims + cdims) + for i in range(ndim): + dim_name = None + if i in bdims: + dim_idx = bdims.index(i) if len(bdims) > 1 else "" + dim_name = f"b{dim_idx}" + elif i in cdims: + dim_idx = cdims.index(i) if len(cdims) > 1 else "" + dim_name = f"k{dim_idx}" + else: + dim_idx = ldims.index(i) if len(ldims) > 1 else "" + dim_name = f"{name}_l{dim_idx}" + specs.append(prefix + dim_name) + return specs + + lhs, _, rhs, *_ = operand_types + operand_ndims = (len(lhs.shape), len(rhs.shape)) + (lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = map( + lambda dims: map(sanitize_dims, operand_ndims, dims), + dimension_numbers, + ) + lhs_specs, rhs_specs = map( + _generate_operand_rules, + ("lhs", "rhs"), + operand_ndims, + (lhs_cdims, rhs_cdims), + (lhs_bdims, rhs_bdims), ) + lhs_scale_specs = ("…1", ) + rhs_scale_specs = ("…2", ) + if scaling_mode.is_1d_block_scaling(): + # Shardy rules for MXFP8 scales cannot be related to the operands because of the + # global-unpadding and local-padding workflow. This can potentially insert expensive + # re-shards in the partition call later if the scales are not already sharded correctly. + lhs_scale_specs, rhs_scale_specs = map( + lambda specs : tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs), + (lhs_specs, rhs_specs) + ) + lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims) + rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) + out_spec = (*lhs_non_cspec, *rhs_non_cspec) + bias_spec = rhs_non_cspec if fuse_bias else ("…4", ) + gelu_spec = out_spec if fuse_gelu else ("…5", ) + + return SdyShardingRule( + operand_mappings=( + lhs_specs, + lhs_scale_specs, + rhs_specs, + rhs_scale_specs, + bias_spec, + gelu_spec, + ), + result_mappings=( + out_spec, + bias_spec, + gelu_spec, + ), + ) register_primitive(GemmPrimitive) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index bf5c257d7b..9d7351ac99 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -587,16 +587,17 @@ def shardy_sharding_rule( result_types, ) + prefix = "NormFwdPrimitive_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[0].shape), unique_var="NormFwdPrimitive_i", flatten_axis=-1 + len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1 ) x_axes = scale_rules.input_spec - out = x_axes[:-1] + ("k",) - colwise_out = out if is_2x else ("…4",) + out = x_axes + colwise_out = out if is_2x else (prefix + "out_colwise", ) rsigma = x_axes[:-1] - mu = ("…5",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma - amax = ("…6",) + mu = (prefix + "mu", ) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma + amax = (prefix + "amax",) return SdyShardingRule( (x_axes, ("…1",), ("…2",), ("…3",)), @@ -609,7 +610,6 @@ def shardy_sharding_rule( mu, rsigma, ), - **scale_rules.factor_sizes, ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py 
b/transformer_engine/jax/cpp_extensions/quantization.py index 3cb0e1cdfb..73c1f6bc58 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -488,9 +488,10 @@ def shardy_sharding_rule( ): del out_dtype, scale_dtype, is_outer, mesh, result_types + prefix = "BaseDBiasQuantizePrimitive_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( len(value_types[0].shape), - unique_var="BaseDBiasQuantizePrimitive_i", + unique_var=prefix + "x", flatten_axis=flatten_axis, ) @@ -498,22 +499,19 @@ def shardy_sharding_rule( colwise_scale_inv = scale_rules.colwise_rule out = x_axes + colwise_out = (prefix + "out_colwise", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if ScalingMode(scaling_mode).is_tensor_scaling(): colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis)) else: colwise_out = x_axes - else: - colwise_out = ("j",) - colwise_scale_inv = ("k",) - dbias = x_axes[flatten_axis:] if is_dbias else ("l",) - amax = ("m",) + dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias", ) + amax = (prefix + "amax", ) return SdyShardingRule( (x_axes, ("…1",)), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), - **scale_rules.factor_sizes, ) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index f45a05a399..9862f0d119 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -17,7 +17,7 @@ import operator import numpy as np -from jax.experimental.custom_partitioning import CompoundFactor +from jax.experimental.custom_partitioning import BATCHING from jax.tree_util import register_pytree_node_class import jax.numpy as jnp @@ -252,8 +252,9 @@ def get_shardy_sharding_rules( The Shardy rules for the scaling mode """ del flatten_axis - input_spec = tuple(f"x{i}" for i in range(input_rank)) - return QuantizeShardyRules(input_spec, (unique_var,), (unique_var,), {}) + input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) + scale_var = BATCHING + unique_var + "_scale_inv" + return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl): @@ -488,31 +489,41 @@ def get_shardy_sharding_rules( Returns: The Shardy rules for the scaling mode """ - input_spec = [f"x{i}" for i in range(input_rank)] - - # We have to use two different factors in the two CompoundFactors because of Shardy - # verifier requirements, even though they are the same. - rowwise_var = unique_var - colwise_var = f"{unique_var}_" - input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise") - input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise") - - # The rowwise and colwise scale tensors should be sharded the same way as the input. - # However, we need to adjust the dimensions where the block scaling factor applies. - rowwise = input_spec.copy() - rowwise[-1] = rowwise_var - - colwise = input_spec.copy() - colwise[flatten_axis - 1] = colwise_var - - # This implementation needs to be updated for different block dims. 
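With the tensor-scaling change above, the scale factor is published as a standalone batching factor instead of being tied to the input axes. A short sketch of what `get_shardy_sharding_rules` now computes for a rank-2 input (the prefix is illustrative; `BATCHING` is the batching-factor prefix exported by `jax.experimental.custom_partitioning`):

from jax.experimental.custom_partitioning import BATCHING

unique_var = "MyPrimitive_x"  # illustrative prefix passed in by the caller
input_rank = 2

input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
# -> ("MyPrimitive_x0", "MyPrimitive_x1")
scale_var = BATCHING + unique_var + "_scale_inv"
# Both the row-wise and col-wise scale rules reuse this single factor, so the
# scale's sharding stays decoupled from the data tensor's factors.
rowwise_rule = colwise_rule = (scale_var,)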
- assert self._block_dims == (1, 32) + del flatten_axis + input_spec = [f"{unique_var}{i}" for i in range(input_rank)] + rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)] + colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)] + + # TODO (Alp): Padding the scales breaks the size relationship in CompoundFactors. + # Unfortunately, because Shardy rules are applied to the inner primitive, the + # only way to preserve the relationship is to lower unpadded scales to the + # underlying custom call and pad them in C++. Until that's implemented, the + # Shardy rules for block scales have to be completely disconnected from the + # Shardy rules for the tensor they belong to. + + # # We have to use two different factors in the two CompoundFactors because of Shardy + # # verifier requirements, even though they are the same. + # rowwise_var = unique_var + # colwise_var = f"{unique_var}_" + # input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise") + # input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise") + + # # The rowwise and colwise scale tensors should be sharded the same way as the input. + # # However, we need to adjust the dimensions where the block scaling factor applies. + # rowwise = input_spec.copy() + # rowwise[-1] = rowwise_var + + # colwise = input_spec.copy() + # colwise[flatten_axis - 1] = colwise_var + + # # This implementation needs to be updated for different block dims. + # assert self._block_dims == (1, 32) return QuantizeShardyRules( tuple(input_spec), tuple(rowwise), tuple(colwise), - {"block_size_rowwise": 32, "block_size_colwise": 32}, + {}, # {"block_size_rowwise": 32, "block_size_colwise": 32}, ) @@ -613,7 +624,9 @@ def get_shardy_sharding_rules( Returns: The Shardy rules for the scaling mode """ - return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis) + return self._get_impl().get_shardy_sharding_rules( + input_rank, unique_var, flatten_axis + ) def get_grouped_scale_shape_2x( self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1 From 875f4017a5986711f1f3838636b0131133110e82 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Jul 2025 22:38:29 +0000 Subject: [PATCH 21/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../encoder/test_model_parallel_encoder.py | 10 ++++--- examples/jax/encoder/test_multigpu_encoder.py | 5 ++-- .../encoder/test_multiprocessing_encoder.py | 5 ++-- tests/jax/test_distributed_layernorm_mlp.py | 9 +++++- .../jax/cpp_extensions/activation.py | 18 ++++++------ transformer_engine/jax/cpp_extensions/gemm.py | 28 +++++++++---------- .../jax/cpp_extensions/normalization.py | 4 +-- .../jax/cpp_extensions/quantization.py | 6 ++-- .../jax/quantize/scaling_modes.py | 6 ++-- 9 files changed, 49 insertions(+), 42 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 17e0e53175..8896fc4251 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -543,8 +543,9 @@ def test_te_delayed_scaling_fp8_with_sp_shardy(self): assert actual[0] < 0.43 and actual[1] > 0.8 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - @unittest.skipIf(tex.gemm_uses_jax_dot(), - "`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + @unittest.skipIf( + 
tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." + ) def test_te_mxfp8_shardy(self): """Test Transformer Engine with MXFP8""" self.args.enable_shardy = True @@ -554,8 +555,9 @@ def test_te_mxfp8_shardy(self): assert actual[0] < 0.43 and actual[1] > 0.8 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - @unittest.skipIf(tex.gemm_uses_jax_dot(), - "`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + @unittest.skipIf( + tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." + ) def test_te_mxfp8_with_sp_shardy(self): """Test Transformer Engine with MXFP8 + SP""" self.args.enable_shardy = True diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 08399e41d6..b2fed379e0 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -490,8 +490,9 @@ def test_te_current_scaling_fp8_shardy(self): assert actual[0] < 0.53 and actual[1] > 0.74 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - @unittest.skipIf(tex.gemm_uses_jax_dot(), - "`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + @unittest.skipIf( + tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." + ) def test_te_mxfp8_shardy(self): """Test Transformer Engine with MXFP8""" self.args.enable_shardy = True diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 4af9a89d7d..0ac9a58c5a 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -658,8 +658,9 @@ def test_te_current_scaling_fp8_shardy(self): @unittest.skipIf( not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" ) - @unittest.skipIf(tex.gemm_uses_jax_dot(), - "`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + @unittest.skipIf( + tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." 
+ ) def test_te_mxfp8_shardy(self): """Test Transformer Engine with MXFP8""" result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index c958326472..048a4aedf8 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -286,7 +286,14 @@ def test_layernorm_mlp_grad( @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm, + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + with_jax_gemm, ): if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 7c201e84f3..57133f48aa 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -420,12 +420,12 @@ def shardy_sharding_rule( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( x_rank - 1, unique_var=prefix + "x", flatten_axis=-2 ) - x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}", ) + x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",) out = (*x_axes[:-2], x_axes[-1]) scale_inv = scale_rules.rowwise_rule - colwise_out = (prefix + "out_colwise", ) - colwise_scale_inv = (prefix + "scale_inv_colwise", ) + colwise_out = (prefix + "out_colwise",) + colwise_scale_inv = (prefix + "scale_inv_colwise",) if is_2x: colwise_scale_inv = scale_rules.colwise_rule if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: @@ -436,12 +436,12 @@ def shardy_sharding_rule( colwise_out = out # amax is always a unit tensor. - amax = (prefix + "amax", ) + amax = (prefix + "amax",) return SdyShardingRule( ( x_axes, - ("…1", ), + ("…1",), ), (out, colwise_out, scale_inv, colwise_scale_inv, amax), ) @@ -895,18 +895,18 @@ def shardy_sharding_rule( x_axes = scale_rules.input_spec dz_axes = (*x_axes[:-2], x_axes[-1]) out = x_axes - colwise_out = (prefix + "out_colwise", ) + colwise_out = (prefix + "out_colwise",) if is_2x: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) else: colwise_out = out - dbias = x_axes[-2:] if is_dbias else (prefix + "dbias", ) - amax = (prefix + "amax", ) + dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) + amax = (prefix + "amax",) return SdyShardingRule( - (dz_axes, x_axes, ("…2", )), + (dz_axes, x_axes, ("…2",)), (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), ) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index e12e38f5e6..47ec2ac346 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -215,16 +215,13 @@ def _dims_are_consecutive(dims): f"{rhs_batch_dims}." ) if len(lhs_batch_dims) == 0: - assert len(rhs_batch_dims) == 0, ( - "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched." 
- ) - elif len(rhs_batch_dims) != 0: assert ( - all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) - and all(bdim in rhs_contracting_dims for bdim in rhs_batch_dims) - ), ( - "cuBLAS GEMM batched dimensions must be contracting when both operands are batched." - ) + len(rhs_batch_dims) == 0 + ), "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched." + elif len(rhs_batch_dims) != 0: + assert all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) and all( + bdim in rhs_contracting_dims for bdim in rhs_batch_dims + ), "cuBLAS GEMM batched dimensions must be contracting when both operands are batched." lhs_contracting_size, rhs_contracting_size = map( lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), @@ -838,22 +835,22 @@ def _generate_operand_rules(name, ndim, cdims, bdims): (lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims), ) - lhs_scale_specs = ("…1", ) - rhs_scale_specs = ("…2", ) + lhs_scale_specs = ("…1",) + rhs_scale_specs = ("…2",) if scaling_mode.is_1d_block_scaling(): # Shardy rules for MXFP8 scales cannot be related to the operands because of the # global-unpadding and local-padding workflow. This can potentially insert expensive # re-shards in the partition call later if the scales are not already sharded correctly. lhs_scale_specs, rhs_scale_specs = map( - lambda specs : tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs), - (lhs_specs, rhs_specs) + lambda specs: tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs), + (lhs_specs, rhs_specs), ) lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims) rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) out_spec = (*lhs_non_cspec, *rhs_non_cspec) - bias_spec = rhs_non_cspec if fuse_bias else ("…4", ) - gelu_spec = out_spec if fuse_gelu else ("…5", ) + bias_spec = rhs_non_cspec if fuse_bias else ("…4",) + gelu_spec = out_spec if fuse_gelu else ("…5",) return SdyShardingRule( operand_mappings=( @@ -871,6 +868,7 @@ def _generate_operand_rules(name, ndim, cdims, bdims): ), ) + register_primitive(GemmPrimitive) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 9d7351ac99..3b563efbd0 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -594,9 +594,9 @@ def shardy_sharding_rule( x_axes = scale_rules.input_spec out = x_axes - colwise_out = out if is_2x else (prefix + "out_colwise", ) + colwise_out = out if is_2x else (prefix + "out_colwise",) rsigma = x_axes[:-1] - mu = (prefix + "mu", ) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma + mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma amax = (prefix + "amax",) return SdyShardingRule( diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 73c1f6bc58..3d08f00df0 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -499,15 +499,15 @@ def shardy_sharding_rule( colwise_scale_inv = scale_rules.colwise_rule out = x_axes - colwise_out = (prefix + "out_colwise", ) + colwise_out = (prefix + "out_colwise",) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if ScalingMode(scaling_mode).is_tensor_scaling(): colwise_out = tuple(multidim_transpose(x_axes, 
transpose_axis=flatten_axis)) else: colwise_out = x_axes - dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias", ) - amax = (prefix + "amax", ) + dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",) + amax = (prefix + "amax",) return SdyShardingRule( (x_axes, ("…1",)), diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 9862f0d119..7d9b67ced9 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -523,7 +523,7 @@ def get_shardy_sharding_rules( tuple(input_spec), tuple(rowwise), tuple(colwise), - {}, # {"block_size_rowwise": 32, "block_size_colwise": 32}, + {}, # {"block_size_rowwise": 32, "block_size_colwise": 32}, ) @@ -624,9 +624,7 @@ def get_shardy_sharding_rules( Returns: The Shardy rules for the scaling mode """ - return self._get_impl().get_shardy_sharding_rules( - input_rank, unique_var, flatten_axis - ) + return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis) def get_grouped_scale_shape_2x( self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1 From 995fb1144c7adde8eca6f43dd9c7f1431a821ba2 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 9 Jul 2025 15:25:07 +0000 Subject: [PATCH 22/30] fixed linting errors Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 4 +++- transformer_engine/jax/quantize/scaling_modes.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 47ec2ac346..c5b1e19b07 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -802,7 +802,9 @@ def shardy_sharding_rule( operand_types, result_types, ): - del out_dtype, grad, use_split_accumulator, mesh, result_types + del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator + del mesh, result_types + prefix = "GemmPrimitive_" def _generate_operand_rules(name, ndim, cdims, bdims): diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 7d9b67ced9..fc4fd13531 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -494,7 +494,7 @@ def get_shardy_sharding_rules( rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)] colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)] - # TODO (Alp): Padding the scales breaks the size relationship in CompoundFactors. + # NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors. # Unfortunately, because Shardy rules are applied to the inner primitive, the # only way to preserve the relationship is to lower unpadded scales to the # underlying custom call and pad them in C++. 
Until that's implemented, the From d9e55a9f301a2d5a060660ffba6f484c909f5447 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 9 Jul 2025 19:37:41 +0000 Subject: [PATCH 23/30] changed unit test use_jax_gemm option to be a context to preserve external custom op settings, tightened multi-GPU encoder test tolerances, changed gemm() API to use contracting_dims and batched_dims separately instead of dimension_numbers Signed-off-by: Alp Dener --- examples/jax/encoder/requirements.txt | 2 +- .../encoder/test_model_parallel_encoder.py | 22 +- examples/jax/encoder/test_multigpu_encoder.py | 18 +- .../encoder/test_multiprocessing_encoder.py | 16 +- tests/jax/test_custom_call_compute.py | 71 +++---- tests/jax/test_distributed_layernorm_mlp.py | 194 +++++++++--------- tests/jax/utils.py | 26 ++- transformer_engine/jax/cpp_extensions/gemm.py | 87 ++++---- transformer_engine/jax/dense.py | 13 +- transformer_engine/jax/layernorm_dense.py | 11 +- transformer_engine/jax/layernorm_mlp.py | 20 +- 11 files changed, 253 insertions(+), 227 deletions(-) diff --git a/examples/jax/encoder/requirements.txt b/examples/jax/encoder/requirements.txt index 26af82aa49..aba6a0e929 100644 --- a/examples/jax/encoder/requirements.txt +++ b/examples/jax/encoder/requirements.txt @@ -1,4 +1,4 @@ -datasets +datasets<=3.6 flax>=0.7.1 nltk>=3.8.2 optax diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 8896fc4251..7d5fefddaa 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -473,7 +473,7 @@ def setUp(self): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.43 and actual[1] > 0.8 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -481,7 +481,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.43 and actual[1] > 0.8 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -489,14 +489,14 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.43 and actual[1] > 0.8 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_with_sp(self): """Test Transformer Engine with BF16 + SP""" self.args.enable_sp = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.43 and actual[1] > 0.8 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp(self): @@ -505,7 +505,7 @@ def test_te_delayed_scaling_fp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.43 and actual[1] > 0.8 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8_with_sp(self): @@ -514,14 +514,14 @@ def test_te_mxfp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.43 and actual[1] > 0.8 + assert actual[0] < 0.43 and actual[1] > 0.80 
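The `use_jax_gemm` context manager mentioned in the commit message above is added to tests/jax/utils.py further down in this patch. A condensed usage sketch mirroring the updated unit tests (shapes and the no-quantizer call are illustrative; gemm() argument defaults are assumed):

import jax
import jax.numpy as jnp
from transformer_engine.jax import cpp_extensions as tex
from utils import use_jax_gemm  # tests/jax/utils.py

x = jax.random.normal(jax.random.PRNGKey(0), (32, 64), jnp.bfloat16)
w = jax.random.normal(jax.random.PRNGKey(1), (64, 16), jnp.bfloat16)

with use_jax_gemm(enabled=True):
    # NVTE_JAX_CUSTOM_CALLS_RE excludes GemmPrimitive inside this block, so the
    # TE gemm() falls back to native jax.lax.dot_general / jax.nn.scaled_matmul.
    out = tex.gemm(x, w, dimension_numbers=(((1,), (0,)), ((), ())))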
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.43 and actual[1] > 0.8 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_shardy(self): @@ -530,7 +530,7 @@ def test_te_delayed_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.43 and actual[1] > 0.8 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp_shardy(self): @@ -540,7 +540,7 @@ def test_te_delayed_scaling_fp8_with_sp_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.43 and actual[1] > 0.8 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf( @@ -552,7 +552,7 @@ def test_te_mxfp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.43 and actual[1] > 0.8 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf( @@ -565,7 +565,7 @@ def test_te_mxfp8_with_sp_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.43 and actual[1] > 0.8 + assert actual[0] < 0.43 and actual[1] > 0.80 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index b2fed379e0..40431b8cdd 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -432,13 +432,13 @@ class TestEncoder(unittest.TestCase): def setUp(self): """Run 5 epochs for testing""" - self.args = encoder_parser(["--epochs", "5"]) + self.args = encoder_parser(["--epochs", "6"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.53 and actual[1] > 0.74 + assert actual[0] < 0.50 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -446,7 +446,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.53 and actual[1] > 0.74 + assert actual[0] < 0.50 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8(self): @@ -454,7 +454,7 @@ def test_te_current_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.53 and actual[1] > 0.74 + assert actual[0] < 0.50 and actual[1] > 0.75 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -462,14 +462,14 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.53 and actual[1] > 0.74 + assert actual[0] < 0.50 and actual[1] > 0.75 @unittest.skipIf(not 
is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.53 and actual[1] > 0.74 + assert actual[0] < 0.50 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_shardy(self): @@ -478,7 +478,7 @@ def test_te_delayed_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.53 and actual[1] > 0.74 + assert actual[0] < 0.50 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8_shardy(self): @@ -487,7 +487,7 @@ def test_te_current_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.53 and actual[1] > 0.74 + assert actual[0] < 0.50 and actual[1] > 0.75 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf( @@ -499,7 +499,7 @@ def test_te_mxfp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.53 and actual[1] > 0.74 + assert actual[0] < 0.50 and actual[1] > 0.75 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 0ac9a58c5a..c1da4db4a9 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -607,7 +607,7 @@ def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): def test_te_bf16(self): """Test Transformer Engine with BF16""" result = self.exec(False, None) - assert result[0] < 0.43 and result[1] > 0.8 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" @@ -615,7 +615,7 @@ def test_te_bf16(self): def test_te_delayed_scaling_fp8(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling") - assert result[0] < 0.43 and result[1] > 0.8 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" @@ -623,7 +623,7 @@ def test_te_delayed_scaling_fp8(self): def test_te_current_scaling_fp8(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling") - assert result[0] < 0.43 and result[1] > 0.8 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" @@ -631,13 +631,13 @@ def test_te_current_scaling_fp8(self): def test_te_mxfp8(self): """Test Transformer Engine with MXFP8""" result = self.exec(True, "MXFP8BlockScaling") - assert result[0] < 0.43 and result[1] > 0.8 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" result = self.exec(False, None, enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.8 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" @@ -645,7 +645,7 @@ def 
test_te_bf16_shardy(self): def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling", enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.8 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" @@ -653,7 +653,7 @@ def test_te_delayed_scaling_fp8_shardy(self): def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling", enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.8 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" @@ -664,7 +664,7 @@ def test_te_current_scaling_fp8_shardy(self): def test_te_mxfp8_shardy(self): """Test Transformer Engine with MXFP8""" result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.8 + assert result[0] < 0.43 and result[1] > 0.80 if __name__ == "__main__": diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index a50d5363ae..40f35e2697 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -13,6 +13,7 @@ from utils import ( assert_allclose, pytest_parametrize_wrapper, + use_jax_gemm, ) from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm_mlp import layernorm_mlp @@ -917,8 +918,6 @@ def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, wi ): pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.") - _use_jax_fp8_gemm(enabled=with_jax_gemm) - x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) quantizer_set = QuantizerFactory.create_set( scaling_mode=scaling_mode, @@ -926,15 +925,16 @@ def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, wi bwd_dtype=jnp.float8_e5m2, is_2x2x=False, ) - primitive_out = tex.gemm( - x, - w, - dimension_numbers=(contracting_dims, ((), ())), - lhs_quantizer=quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad, - rhs_quantizer=( - quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad - ), - ) + with use_jax_gemm(enabled=with_jax_gemm): + primitive_out = tex.gemm( + x, + w, + dimension_numbers=(contracting_dims, ((), ())), + lhs_quantizer=quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad, + rhs_quantizer=( + quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad + ), + ) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn) @@ -969,8 +969,6 @@ def ref_func(x, w, data_layout): @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm): - _use_jax_fp8_gemm(enabled=with_jax_gemm) - data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -999,10 +997,11 @@ def ref_func(x, w, bias, data_layout): ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 - for _ in range(n_iterations): - primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( - value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) - ) + with 
use_jax_gemm(enabled=with_jax_gemm): + for _ in range(n_iterations): + primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( + value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) + ) ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func( x, w, bias, data_layout @@ -1042,8 +1041,6 @@ def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_g """ Test layernorm_dense VJP Rule """ - _use_jax_fp8_gemm(enabled=with_jax_gemm) - # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False eps = 1e-6 @@ -1098,13 +1095,14 @@ def ref_func(x, w, gamma, beta): ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 - for _ in range(n_iterations): - prim_out, ( - prim_x_grad, - prim_w_grad, - prim_gamma_grad, - prim_beta_grad, - ) = value_n_grad_prim_func(x, w, gamma, beta) + with use_jax_gemm(enabled=with_jax_gemm): + for _ in range(n_iterations): + prim_out, ( + prim_x_grad, + prim_w_grad, + prim_gamma_grad, + prim_beta_grad, + ) = value_n_grad_prim_func(x, w, gamma, beta) assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) @@ -1126,8 +1124,6 @@ def test_layernorm_mlp_grad( """ Test layernorm_mlp VJP Rule """ - _use_jax_fp8_gemm(enabled=with_jax_gemm) - # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False eps = 1e-6 @@ -1202,15 +1198,16 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): value_n_grad_ref_func = value_and_grad(ref_func, range(6)) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 - for _ in range(n_iterations): - prim_out, ( - prim_x_grad, - prim_gamma_grad, - prim_kernel_1_grad, - prim_kernel_2_grad, - prim_bias_1_grad, - prim_bias_2_grad, - ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) + with use_jax_gemm(enabled=with_jax_gemm): + for _ in range(n_iterations): + prim_out, ( + prim_x_grad, + prim_gamma_grad, + prim_kernel_1_grad, + prim_kernel_2_grad, + prim_bias_1_grad, + prim_bias_2_grad, + ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) ref_out, ( ref_x_grad, diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 048a4aedf8..8c377cdd22 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -13,6 +13,7 @@ assert_tree_like_allclose, is_devices_enough, pytest_parametrize_wrapper, + use_jax_gemm, ) from transformer_engine.common import recipe @@ -74,15 +75,6 @@ def generate_fsdp_and_tp_configs(): return configs -def use_jax_fp8_gemm(enabled=False): - import os - - if enabled: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" - elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: - os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") - - class TestDistributedLayernormMLP: def generate_inputs(self, input_shape, activation_type, use_bias, dtype): @@ -165,7 +157,6 @@ def _test_layernorm_mlp_grad( use_shardy, with_jax_gemm, ): - use_jax_fp8_gemm(enabled=with_jax_gemm) jax.config.update("jax_use_shardy_partitioner", use_shardy) device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config layernorm_type = "rmsnorm" @@ -178,53 +169,56 @@ def _test_layernorm_mlp_grad( self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs)) ) - # Single GPU - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - single_jitter = jax.jit( - value_and_grad_func, - 
static_argnums=range(len(inputs), len(static_inputs) + len(inputs)), - ) - single_fwd, single_grads = single_jitter(*inputs, *static_inputs) - - # Multi GPUs - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): - k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp")) - k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) - k1_ = jax.device_put(k1, k1_sharding) - k2_ = jax.device_put(k2, k2_sharding) - if use_bias: - b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp")) - b1_ = jax.device_put(b1, b1_sharding) - else: - b1_sharding = b1_ = None - multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]] - - # Position ref for sharding pspec lists - # x, gamma, k1, k2, b1, - # b2 - in_shardings = ( - None, - None, - k1_sharding, - k2_sharding, - b1_sharding, - None, - ) - out_shardings = ( - None, - (None, None, k1_sharding, k2_sharding, b1_sharding, None), - ) - - multi_jitter = jax.jit( - value_and_grad_func, - in_shardings=in_shardings, - out_shardings=out_shardings, - static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1), - ) # +1 for multi_gpus - - multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) + with use_jax_gemm(enabled=with_jax_gemm): + # Single GPU + with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + single_jitter = jax.jit( + value_and_grad_func, + static_argnums=range(len(inputs), len(static_inputs) + len(inputs)), + ) + single_fwd, single_grads = single_jitter(*inputs, *static_inputs) + + # Multi GPUs + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + with mesh, fp8_autocast( + enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + ): + k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp")) + k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) + k1_ = jax.device_put(k1, k1_sharding) + k2_ = jax.device_put(k2, k2_sharding) + if use_bias: + b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp")) + b1_ = jax.device_put(b1, b1_sharding) + else: + b1_sharding = b1_ = None + multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]] + + # Position ref for sharding pspec lists + # x, gamma, k1, k2, b1, + # b2 + in_shardings = ( + None, + None, + k1_sharding, + k2_sharding, + b1_sharding, + None, + ) + out_shardings = ( + None, + (None, None, k1_sharding, k2_sharding, b1_sharding, None), + ) + + multi_jitter = jax.jit( + value_and_grad_func, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1), + ) # +1 for multi_gpus + + multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2 @@ -320,7 +314,6 @@ def _test_layernorm_mlp( use_shardy, with_jax_gemm, ): - use_jax_fp8_gemm(enabled=with_jax_gemm) jax.config.update("jax_use_shardy_partitioner", use_shardy) batch, seqlen, hidden_in = input_shape layernorm_type = "rmsnorm" @@ -331,48 +324,49 @@ def _test_layernorm_mlp( x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) init_rngs = {"params": subkeys[1]} - # Single GPUs - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): - ln_mlp_single = LayerNormMLP( - layernorm_type=layernorm_type, - 
transpose_batch_sequence=False, # input: [batch, seqlen, hidden] - intermediate_dim=INTERMEDIATE, - activations=activation_type, - use_bias=use_bias, - ) - params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) - mlp_out_single, ln_out_single = ln_mlp_single.apply( - params_single, x, deterministic=True - ) - - # Multi GPUs - device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast( - enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource - ): - ln_mlp_sharded = LayerNormMLP( - layernorm_type=layernorm_type, - transpose_batch_sequence=False, - intermediate_dim=INTERMEDIATE, - activations=activation_type, - scale_axes=LN_SCALE_AXES, - ln_bias_axes=LN_BIAS_AXES, - kernel_axes_1=KERNEL_1_AXES, - kernel_axes_2=KERNEL_2_AXES, - use_bias=use_bias, - bias_axes_1=BIAS_1_AXES, - bias_axes_2=BIAS_2_AXES, - layernorm_input_axes=LAYERNORM_INPUT_AXES, - dot_1_input_axes=DOT_1_INPUT_AXES, - dot_2_input_axes=DOT_2_INPUT_AXES, - name="mlp", - ) - params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True) - mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( - params_sharded, x, deterministic=True - ) + with use_jax_gemm(enabled=with_jax_gemm): + # Single GPUs + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + ln_mlp_single = LayerNormMLP( + layernorm_type=layernorm_type, + transpose_batch_sequence=False, # input: [batch, seqlen, hidden] + intermediate_dim=INTERMEDIATE, + activations=activation_type, + use_bias=use_bias, + ) + params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) + mlp_out_single, ln_out_single = ln_mlp_single.apply( + params_single, x, deterministic=True + ) + + # Multi GPUs + device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + with mesh, fp8_autocast( + enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + ): + ln_mlp_sharded = LayerNormMLP( + layernorm_type=layernorm_type, + transpose_batch_sequence=False, + intermediate_dim=INTERMEDIATE, + activations=activation_type, + scale_axes=LN_SCALE_AXES, + ln_bias_axes=LN_BIAS_AXES, + kernel_axes_1=KERNEL_1_AXES, + kernel_axes_2=KERNEL_2_AXES, + use_bias=use_bias, + bias_axes_1=BIAS_1_AXES, + bias_axes_2=BIAS_2_AXES, + layernorm_input_axes=LAYERNORM_INPUT_AXES, + dot_1_input_axes=DOT_1_INPUT_AXES, + dot_2_input_axes=DOT_2_INPUT_AXES, + name="mlp", + ) + params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True) + mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( + params_sharded, x, deterministic=True + ) # Make sure params values are the same assert_tree_like_allclose(params_sharded["params"], params_single["params"]) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 39eb368375..f3bb10bd2f 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -3,11 +3,12 @@ # See LICENSE for license information. """Utility for the TE layer tests""" +import os import functools import math import operator from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional -import os +from contextlib import contextmanager import jax import jax.numpy as jnp @@ -28,7 +29,6 @@ PRNGKey = Any Shape = Tuple[int, ...] 
-DType = jnp.dtype Array = Any PrecisionLike = Union[ None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] @@ -160,7 +160,7 @@ class DotProductAttention(nn.Module): transpose_batch_sequence: bool = True scale_attn_logits: bool = True dropout_rate: float = 0.0 - dtype: DType = jnp.float32 + dtype: jnp.dtype = jnp.float32 float32_logits: bool = False """Computes dot-product attention given query, key, and value. @@ -283,7 +283,7 @@ class DenseGeneral(nn.Module): features: Union[Iterable[int], int] axis: Union[Iterable[int], int] = -1 - dtype: DType = jnp.float32 + dtype: jnp.dtype = jnp.float32 kernel_init: Initializer = None kernel_axes: Tuple[str, ...] = () use_bias: bool = False @@ -525,7 +525,7 @@ class MultiHeadAttention(nn.Module): num_gqa_groups: int | None = None head_dim: int = 64 transpose_batch_sequence: bool = True - dtype: DType = jnp.float32 + dtype: jnp.dtype = jnp.float32 dropout_rate: float = 0.0 kernel_init: Initializer = None float32_logits: bool = False # computes logits in float32 for stability. @@ -1424,7 +1424,7 @@ def assert_allclose( desired: Array, rtol: Optional[float] = None, atol: Optional[float] = None, - dtype: Optional[Union[DType, TEDType, np.dtype, str]] = None, + dtype: Optional[Union[jnp.dtype, TEDType, np.dtype, str]] = None, **kwargs, ) -> None: """Check if two tensors are close. @@ -1484,7 +1484,7 @@ def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08): def dtype_tols( - dtype: Union[DType, TEDType, np.dtype], + dtype: Union[jnp.dtype, TEDType, np.dtype], reference_value: float = 1.0, rtol: Optional[float] = None, atol: Optional[float] = None, @@ -1600,3 +1600,15 @@ def print_debug_tensor_stats(prefix, tensor, hist=False): fmt = fmt + "\n {}\n {}" jax.debug.print(fmt, *args) + + +@contextmanager +def use_jax_gemm(enabled=False): + orig_custom_calls = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None) + try: + if enabled: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + yield + finally: + if enabled and orig_custom_calls is not None: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index c5b1e19b07..4de4fcfb9c 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -156,7 +156,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14) + impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15) inner_primitive = None outer_primitive = None @@ -169,7 +169,8 @@ def abstract( bias, gelu_input, out_dtype, - dimension_numbers, + contracting_dims, + batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, scaling_mode, @@ -187,7 +188,6 @@ def _dims_are_consecutive(dims): # Sanity-check operand layouts and types operand_ndims = (lhs.ndim, rhs.ndim) - contracting_dims, batch_dims = dimension_numbers ( lhs_contracting_dims, @@ -205,7 +205,7 @@ def _dims_are_consecutive(dims): ( lhs_batch_dims, rhs_batch_dims, - ) = map(sanitize_dims, operand_ndims, batch_dims) + ) = map(sanitize_dims, operand_ndims, batched_dims) assert _dims_are_consecutive(lhs_batch_dims), ( "cuBLAS GEMM expected consecutive batch dimensions for LHS operand, but got " f"{lhs_batch_dims}." 
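
As a rough usage sketch of this argument split: replacing the packed dimension_numbers with separate contracting_dims and batched_dims static arguments changes how callers of the public tex.gemm() wrapper describe a GEMM. The shapes, dtypes, and axis choices below are illustrative assumptions; only tex.gemm() and its contracting_dims/batched_dims keywords are taken from the changes in this series.

    import jax.numpy as jnp
    from transformer_engine.jax import cpp_extensions as tex

    # Illustrative shapes (assumed, not from the patch): batched activations and a 2D kernel.
    x = jnp.ones((8, 128, 512), jnp.bfloat16)    # (batch, seq, hidden_in)
    w = jnp.ones((512, 1024), jnp.bfloat16)      # (hidden_in, hidden_out)
    g = jnp.ones((8, 128, 1024), jnp.bfloat16)   # incoming gradient w.r.t. the output

    # Forward GEMM: contract x's hidden_in axis with w's first axis; the batch axis is
    # declared as batched but is not contracted.
    out = tex.gemm(x, w, contracting_dims=((2,), (0,)), batched_dims=((0,), ()))

    # Weight-gradient-style GEMM: contract over (batch, seq) on both operands and mark the
    # batch axis as batched on both, so a sharded batch dimension is not reduced away.
    wgrad = tex.gemm(x, g, contracting_dims=((0, 1), (0, 1)), batched_dims=((0,), (0,)))

This mirrors how the dense() forward and backward rules call tex.gemm() later in this diff, where the forward pass uses batched_dims=((x_bdim,), ()) and the weight-gradient pass uses batched_dims=((x_bdim,), (x_bdim,)).
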
@@ -335,7 +335,8 @@ def lowering( bias, gelu_input, out_dtype, - dimension_numbers, + contracting_dims, + batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, scaling_mode, @@ -344,8 +345,7 @@ def lowering( grad, use_split_accumulator, ): - del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype - contracting_dims, _ = dimension_numbers + del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) lhs_transposed, rhs_transposed = _get_gemm_layout( @@ -385,7 +385,8 @@ def impl( bias, gelu_input, out_dtype, - dimension_numbers, + contracting_dims, + batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, scaling_mode, @@ -394,7 +395,7 @@ def impl( grad, use_split_accumulator, ): - lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), dimension_numbers[0]) + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) lhs_transposed, rhs_transposed = _get_gemm_layout( (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) ) @@ -422,7 +423,8 @@ def impl( bias, gelu_input, out_dtype=out_dtype, - dimension_numbers=dimension_numbers, + contracting_dims=contracting_dims, + batched_dims=batched_dims, lhs_quantized_colwise=lhs_quantized_colwise, rhs_quantized_colwise=rhs_quantized_colwise, scaling_mode=scaling_mode, @@ -436,9 +438,10 @@ def impl( @staticmethod def batcher( batched_args, - batch_dims, + jax_batch_dims, out_dtype, - dimension_numbers, + contracting_dims, + batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, scaling_mode, @@ -449,9 +452,8 @@ def batcher( ): assert GemmPrimitive.outer_primitive is not None lhs, _, rhs, *_ = batched_args - lhs_bdims, _, rhs_bdims, *_ = batch_dims - contracting_dims, batch_dims = dimension_numbers - arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batch_dims) + lhs_bdims, _, rhs_bdims, *_ = jax_batch_dims + arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims) arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), ( "User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch " @@ -480,7 +482,8 @@ def batcher( GemmPrimitive.outer_primitive.bind( *batched_args, out_dtype=out_dtype, - dimension_numbers=dimension_numbers, + contracting_dims=contracting_dims, + batched_dims=batched_dims, lhs_quantized_colwise=lhs_quantized_colwise, rhs_quantized_colwise=rhs_quantized_colwise, scaling_mode=scaling_mode, @@ -509,12 +512,11 @@ def _decompose_operand_specs(specs, contracting_dims, batch_dims): return bspecs, lspecs, cspecs @staticmethod - def _parse_operand_output_specs(arg_infos, dimension_numbers): + def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims): lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) - contracting_dims, batch_dims = dimension_numbers lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( - sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batch_dims + sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batched_dims ) ( (lhs_bspecs, lhs_lspecs, lhs_cspecs), @@ -652,7 +654,8 @@ def _parse_operand_output_specs(arg_infos, dimension_numbers): @staticmethod def infer_sharding_from_operands( out_dtype, - dimension_numbers, + contracting_dims, + batched_dims, lhs_quantized_colwise, 
rhs_quantized_colwise, scaling_mode, @@ -674,7 +677,7 @@ def infer_sharding_from_operands( del use_split_accumulator, result_infos (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( - GemmPrimitive._parse_operand_output_specs(arg_infos, dimension_numbers) + GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) @@ -693,7 +696,8 @@ def infer_sharding_from_operands( @staticmethod def partition( out_dtype, - dimension_numbers, + contracting_dims, + batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, scaling_mode, @@ -713,7 +717,7 @@ def partition( all_reduce_spec, reduce_scatter_spec, scatter_dim, - ) = GemmPrimitive._parse_operand_output_specs(arg_infos, dimension_numbers) + ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) # Assemble argument shardings # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. @@ -759,7 +763,8 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): bias, gelu_input, out_dtype=out_dtype, - dimension_numbers=dimension_numbers, + contracting_dims=contracting_dims, + batched_dims=batched_dims, lhs_quantized_colwise=lhs_quantized_colwise, rhs_quantized_colwise=rhs_quantized_colwise, scaling_mode=scaling_mode, @@ -790,7 +795,8 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): @staticmethod def shardy_sharding_rule( out_dtype, - dimension_numbers, + contracting_dims, + batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, scaling_mode, @@ -828,7 +834,7 @@ def _generate_operand_rules(name, ndim, cdims, bdims): operand_ndims = (len(lhs.shape), len(rhs.shape)) (lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = map( lambda dims: map(sanitize_dims, operand_ndims, dims), - dimension_numbers, + (contracting_dims, batched_dims), ) lhs_specs, rhs_specs = map( _generate_operand_rules, @@ -896,7 +902,8 @@ def _te_gemm( gelu_input: jax.Array = None, lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, - dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]] = (((-1,), (0,)), ((), ())), + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), + batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()), fuse_bias: bool = False, fuse_gelu: bool = False, grad: bool = False, @@ -908,10 +915,9 @@ def _te_gemm( lhs_scale_inv = jnp.empty(0, dtype=jnp.float32) rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) scaling_mode = ScalingMode.NO_SCALING - contracting_dims, batch_dims = dimension_numbers lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) - lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batch_dims) + lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims) # Quantize operands (if necessary) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) @@ -965,7 +971,8 @@ def _te_gemm( bias, gelu_input, out_dtype=out_dtype, - dimension_numbers=((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)), + contracting_dims=(lhs_cdims, rhs_cdims), + batched_dims=(lhs_bdims, rhs_bdims), lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False, rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False, scaling_mode=scaling_mode, @@ -1281,7 +1288,8 @@ def 
_jax_gemm_fp8_impl(lhs, rhs): def gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], - dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]] = (((-1,), (0,)), ((), ())), + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), + batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()), lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, **kwargs, @@ -1298,11 +1306,13 @@ def gemm( Object for down-casting the LHS operand for quantized GEMM. rhs_quantizer: Quantizer, default = None Object for down-casting the RHS operand for quantized GEMM. - dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]], default = (((-1, ), (0, )), ((), ())) - Tuple of two tuples of sequences representing the contracting and batched dimensions, - respectively. The first sequence in each tuple represents the contracting/batched - dimensions of the LHS operand, and the second sequence represents the contracting/batched - dimensions of the RHS operand. + contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, )) + Tuple of sequences representing the contracting dimensions of the operands. + batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()), + Tuple of sequences representing the batched dimensions of the operands. This is *not* used + to perform a batched matrix multiplication, but it is required to avoid a potentially + undesirable reduction in any batched contracting dimensions when invoked with sharded + operands (e.g. when computing weight gradients in a Flax module). bias: jax.Array, default = None Optional additive bias term, required for forward GEMM with bias fusion. Only supported with TE's custom call to cuBLAS GEMM. @@ -1359,14 +1369,15 @@ def gemm( "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." ) - return _jax_gemm(lhs, rhs, dimension_numbers[0], lhs_quantizer, rhs_quantizer) + return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) outputs = _te_gemm( lhs, rhs, lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, - dimension_numbers=dimension_numbers, + contracting_dims=contracting_dims, + batched_dims=batched_dims, **kwargs, ) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 0d4a1b7524..42c91be648 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -53,7 +53,7 @@ def dense( # Remove when tex.quantize() can handle quantizer=None if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot(): x = with_sharding_constraint_by_logical_axes(x, input_axes) - output = tex.gemm(x, kernel, dimension_numbers=(contracting_dims, ((), ()))) + output = tex.gemm(x, kernel, contracting_dims=contracting_dims) if bias is not None: bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) @@ -79,7 +79,7 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_fir input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix quantizer_set: QuantizerSet which contains quantizers for different tensor types - batch_first: Assume that X is batched in the first dimension. + batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. 
Returns: Transformed output tensor @@ -141,7 +141,8 @@ def _dense_fwd_rule( output = tex.gemm( casted_x.get_tensor(usage=TensorUsage.LHS), casted_kernel.get_tensor(usage=TensorUsage.RHS), - dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), + contracting_dims=(x_contracting_dims, k_contracting_dims), + batched_dims=((x_bdim,), ()), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) @@ -206,7 +207,8 @@ def _dense_bwd_rule( dgrad = tex.gemm( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, - dimension_numbers=((g_contracting_dim, k_contracting_dim), ((x_bdim,), ())), + contracting_dims=(g_contracting_dim, k_contracting_dim), + batched_dims=((x_bdim,), ()), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) @@ -219,7 +221,8 @@ def _dense_bwd_rule( wgrad = tex.gemm( casted_x_lhs, casted_grad.get_tensor(usage=TensorUsage.RHS), - dimension_numbers=((x_contracting_dim, g_contracting_dim), ((x_bdim,), (x_bdim,))), + contracting_dims=(x_contracting_dim, g_contracting_dim), + batched_dims=((x_bdim,), (x_bdim,)), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 62fa2cfcd2..d0106b35ec 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -58,7 +58,7 @@ def layernorm_dense( layernorm_input_axes: Logical axes for sharding the layernorm input dot_input_axes: Logical axes for sharding the matrix multiplication input kernel_axes: Logical axes for sharding the weight matrix - batch_first: Assume that X is batched in the first dimension. + batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. 
quantizer_set: Set of quantizers for different tensor types Returns: @@ -217,7 +217,8 @@ def _layernorm_dense_fwd_rule( output = tex.gemm( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel.get_tensor(TensorUsage.RHS), - dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), + contracting_dims=(x_contracting_dims, k_contracting_dims), + batched_dims=((x_bdim,), ()), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) @@ -308,7 +309,8 @@ def _layernorm_dense_bwd_rule( dgrad = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, - dimension_numbers=((g_constracting_dim, k_constracting_dim), ((x_bdim,), ())), + contracting_dims=(g_constracting_dim, k_constracting_dim), + batched_dims=((x_bdim,), ()), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) @@ -321,7 +323,8 @@ def _layernorm_dense_bwd_rule( wgrad = tex.gemm( casted_ln_out, casted_grad.get_tensor(TensorUsage.RHS), - dimension_numbers=((x_constracting_dim, g_constracting_dim), ((x_bdim,), (x_bdim,))), + contracting_dims=(x_constracting_dim, g_constracting_dim), + batched_dims=((x_bdim,), (x_bdim,)), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 5d129aa54d..5ca94eb635 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -80,7 +80,7 @@ def layernorm_mlp( ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation - batch_first: Assume that X is batched in the first dimension. + batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. 
quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations Returns: @@ -291,7 +291,8 @@ def _layernorm_mlp_fwd_rule( dot_1_output = tex.gemm( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel_1.get_tensor(TensorUsage.RHS), - dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), + contracting_dims=(x_contracting_dims, k_contracting_dims), + batched_dims=((x_bdim,), ()), bias=bias_1 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, ) @@ -326,7 +327,8 @@ def _layernorm_mlp_fwd_rule( dot_2_output = tex.gemm( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), - dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), + contracting_dims=(x_contracting_dims, k_contracting_dims), + batched_dims=((x_bdim,), ()), bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, ) @@ -436,7 +438,8 @@ def _layernorm_mlp_bwd_rule( dgrad_2 = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel_2, - dimension_numbers=((g_contracting_dims_2, k_contracting_dims_2), ((x_bdim,), ())), + contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), + batched_dims=((x_bdim,), ()), ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) @@ -450,7 +453,8 @@ def _layernorm_mlp_bwd_rule( wgrad_2 = tex.gemm( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), - dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim,))), + contracting_dims=(x_contracting_dims, g_contracting_dims), + batched_dims=((x_bdim,), (x_bdim,)), ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -477,7 +481,8 @@ def _layernorm_mlp_bwd_rule( dgrad_1 = tex.gemm( casted_dact_out.get_tensor(TensorUsage.LHS), casted_kernel_1, - dimension_numbers=((g_contracting_dims_1, k_contracting_dims_1), ((x_bdim,), ())), + contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), + batched_dims=((x_bdim,), ()), ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) @@ -487,7 +492,8 @@ def _layernorm_mlp_bwd_rule( wgrad_1 = tex.gemm( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), - dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim,))), + contracting_dims=(x_contracting_dims, g_contracting_dims), + batched_dims=((x_bdim,), (x_bdim,)), ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) From e850ab5f4d633465cbd67e4a0ea26b1a08753fbb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Jul 2025 19:40:57 +0000 Subject: [PATCH 24/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 4 +++- tests/jax/test_distributed_layernorm_mlp.py | 4 +++- tests/jax/utils.py | 6 +++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 290b392b37..5d6b62f613 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -930,7 +930,9 @@ def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, wi x, w, dimension_numbers=(contracting_dims, ((), ())), - lhs_quantizer=quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad, + lhs_quantizer=( + 
quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad + ), rhs_quantizer=( quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad ), diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 8c377cdd22..39d9e24402 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -215,7 +215,9 @@ def _test_layernorm_mlp_grad( value_and_grad_func, in_shardings=in_shardings, out_shardings=out_shardings, - static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1), + static_argnums=range( + len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1 + ), ) # +1 for multi_gpus multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index f3bb10bd2f..37e7a4f7d3 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -160,7 +160,7 @@ class DotProductAttention(nn.Module): transpose_batch_sequence: bool = True scale_attn_logits: bool = True dropout_rate: float = 0.0 - dtype: jnp.dtype = jnp.float32 + dtype: jnp.dtype = jnp.float32 float32_logits: bool = False """Computes dot-product attention given query, key, and value. @@ -283,7 +283,7 @@ class DenseGeneral(nn.Module): features: Union[Iterable[int], int] axis: Union[Iterable[int], int] = -1 - dtype: jnp.dtype = jnp.float32 + dtype: jnp.dtype = jnp.float32 kernel_init: Initializer = None kernel_axes: Tuple[str, ...] = () use_bias: bool = False @@ -525,7 +525,7 @@ class MultiHeadAttention(nn.Module): num_gqa_groups: int | None = None head_dim: int = 64 transpose_batch_sequence: bool = True - dtype: jnp.dtype = jnp.float32 + dtype: jnp.dtype = jnp.float32 dropout_rate: float = 0.0 kernel_init: Initializer = None float32_logits: bool = False # computes logits in float32 for stability. From 4db82a040aac37b319798445484b16ef46213322 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 9 Jul 2025 19:46:49 +0000 Subject: [PATCH 25/30] fixed typo in test utils Signed-off-by: Alp Dener --- tests/jax/test_distributed_layernorm_mlp.py | 7 ++++--- tests/jax/utils.py | 13 +++++++++---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 39d9e24402..b3e77b7dc6 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -165,11 +165,12 @@ def _test_layernorm_mlp_grad( input_shape, activation_type, use_bias, dtype ) static_inputs = [layernorm_type, activation_type] - value_and_grad_func = jax.value_and_grad( - self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs)) - ) with use_jax_gemm(enabled=with_jax_gemm): + value_and_grad_func = jax.value_and_grad( + self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs)) + ) + # Single GPU with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): single_jitter = jax.jit( diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 37e7a4f7d3..85cc9ea0cb 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -33,7 +33,7 @@ PrecisionLike = Union[ None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] ] -Initializer = Callable[[PRNGKey, Shape, DType], Array] +Initializer = Callable[[PRNGKey, Shape, jnp.dtype], Array] # Enables verbose printing of tensor numerics for debug. 
NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0))) @@ -1604,11 +1604,16 @@ def print_debug_tensor_stats(prefix, tensor, hist=False): @contextmanager def use_jax_gemm(enabled=False): - orig_custom_calls = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None) + orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None) + try: if enabled: os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" yield + finally: - if enabled and orig_custom_calls is not None: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls + if enabled: + if orig_custom_calls_filter is None: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") + else: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter From 14da5c84f30a821bc25824b06b887b39f38edfaa Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 9 Jul 2025 21:06:06 +0000 Subject: [PATCH 26/30] added sequence-first input warnings Signed-off-by: Alp Dener --- tests/jax/utils.py | 20 +++++++-------- transformer_engine/jax/dense.py | 18 +++++++++++++- transformer_engine/jax/flax/module.py | 30 ++++++++++++++++++----- transformer_engine/jax/layernorm_dense.py | 18 ++++++++++++++ transformer_engine/jax/layernorm_mlp.py | 18 ++++++++++++++ 5 files changed, 87 insertions(+), 17 deletions(-) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 85cc9ea0cb..3dbf993b97 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -7,7 +7,7 @@ import functools import math import operator -from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional +from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional, NewType from contextlib import contextmanager import jax @@ -21,7 +21,6 @@ import pytest from transformer_engine.jax.attention import ( - AttnMaskType, canonicalize_attn_mask_type, make_swa_mask, ) @@ -29,11 +28,12 @@ PRNGKey = Any Shape = Tuple[int, ...] -Array = Any +DType = NewType('DType', jnp.dtype) +Array = NewType('Array', jnp.ndarray) PrecisionLike = Union[ None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] ] -Initializer = Callable[[PRNGKey, Shape, jnp.dtype], Array] +Initializer = Callable[[PRNGKey, Shape, DType], Array] # Enables verbose printing of tensor numerics for debug. NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0))) @@ -160,7 +160,7 @@ class DotProductAttention(nn.Module): transpose_batch_sequence: bool = True scale_attn_logits: bool = True dropout_rate: float = 0.0 - dtype: jnp.dtype = jnp.float32 + dtype: DType = jnp.float32 float32_logits: bool = False """Computes dot-product attention given query, key, and value. @@ -283,7 +283,7 @@ class DenseGeneral(nn.Module): features: Union[Iterable[int], int] axis: Union[Iterable[int], int] = -1 - dtype: jnp.dtype = jnp.float32 + dtype: DType = jnp.float32 kernel_init: Initializer = None kernel_axes: Tuple[str, ...] = () use_bias: bool = False @@ -525,7 +525,7 @@ class MultiHeadAttention(nn.Module): num_gqa_groups: int | None = None head_dim: int = 64 transpose_batch_sequence: bool = True - dtype: jnp.dtype = jnp.float32 + dtype: DType = jnp.float32 dropout_rate: float = 0.0 kernel_init: Initializer = None float32_logits: bool = False # computes logits in float32 for stability. 
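
A brief sketch of how the use_jax_gemm() helper above is meant to be used in the tests: enabling it excludes GemmPrimitive through the NVTE_JAX_CUSTOM_CALLS_RE filter, so the same tex.gemm() call dispatches to the native JAX dot path instead of the cuBLAS custom call. The tensor shapes and values below are illustrative assumptions; use_jax_gemm, tex.gemm, and assert_allclose are the helpers defined in this test suite.

    import jax.numpy as jnp
    from transformer_engine.jax import cpp_extensions as tex
    from utils import assert_allclose, use_jax_gemm   # test-local helpers

    x = jnp.ones((64, 256), jnp.bfloat16)
    w = jnp.ones((256, 128), jnp.bfloat16)

    with use_jax_gemm(enabled=True):
        # GemmPrimitive is filtered out here, so this call lowers to jax.lax.dot_general.
        jax_out = tex.gemm(x, w, contracting_dims=((1,), (0,)))

    # Outside the context the original filter is restored and the cuBLAS custom call is
    # used again (when it is available on the current device and build).
    te_out = tex.gemm(x, w, contracting_dims=((1,), (0,)))
    assert_allclose(te_out, jax_out, dtype=jnp.bfloat16)
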
@@ -1424,7 +1424,7 @@ def assert_allclose( desired: Array, rtol: Optional[float] = None, atol: Optional[float] = None, - dtype: Optional[Union[jnp.dtype, TEDType, np.dtype, str]] = None, + dtype: Optional[Union[DType, TEDType, np.dtype, str]] = None, **kwargs, ) -> None: """Check if two tensors are close. @@ -1484,7 +1484,7 @@ def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08): def dtype_tols( - dtype: Union[jnp.dtype, TEDType, np.dtype], + dtype: Union[DType, TEDType, np.dtype], reference_value: float = 1.0, rtol: Optional[float] = None, atol: Optional[float] = None, @@ -1519,7 +1519,7 @@ def dtype_tols( TEDType.kFloat8E5M2: jnp.float8_e5m2, }[dtype] elif isinstance(dtype, np.dtype): - dtype = jnp.dtype(dtype) + dtype = DType(dtype) # Expect bit-wise accuracy for integer dtypes if not jnp.issubdtype(dtype, jnp.floating): diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 42c91be648..a0fc7b7af8 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -8,7 +8,7 @@ It implements matrix multiplication with optional bias addition and supports customizable contracting dimensions for flexible tensor operations. """ - +import warnings from typing import Tuple, Sequence from functools import partial import jax @@ -23,6 +23,16 @@ ) +DENSE_BATCH_FIRST_WARNING_ISSUED = False + + +def _issue_batch_first_warning(msg): + global DENSE_BATCH_FIRST_WARNING_ISSUED + if not DENSE_BATCH_FIRST_WARNING_ISSUED: + warnings.warn(msg, UserWarning) + DENSE_BATCH_FIRST_WARNING_ISSUED = True + + def dense( x: jnp.ndarray, kernel: jnp.ndarray, @@ -118,6 +128,12 @@ def _dense_fwd_rule( if x.ndim >= num_cdims + 2: # Assume X is batched if it has at least +2 dimensions more than the number of contracting # dimensions. + if not batch_first: + _issue_batch_first_warning( + "TE/JAX `dense()` layer implementation does not officially support sequence-first " + "inputs and may produce incorrect results when `batch_first=False`. Use " + "sequence-first inputs at your own discretion.", + ) x_bdim = 0 if batch_first else x.ndim - num_cdims - 1 flatten_axis_x = -len(x_contracting_dims) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index bd311472f0..e88092ce35 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -6,7 +6,7 @@ """ from functools import reduce import operator -from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union +from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType import numpy as np import jax.numpy as jnp @@ -15,12 +15,12 @@ from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name -from ..dense import dense +from ..dense import dense, _issue_batch_first_warning as _dense_warning from ..layernorm import canonicalize_norm_type from ..layernorm import layernorm -from ..layernorm_dense import layernorm_dense -from ..layernorm_mlp import layernorm_mlp +from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning +from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning from ..activation import activation from ..softmax import softmax, SoftmaxType from ..sharding import with_sharding_constraint_by_logical_axes @@ -35,8 +35,8 @@ PRNGKey = Any Shape = Tuple[int, ...] 
-DType = jnp.dtype -Array = jnp.ndarray +DType = NewType('DType', jnp.dtype) +Array = NewType('Array', jnp.ndarray) PrecisionLike = Union[ None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] ] @@ -441,6 +441,12 @@ class DenseGeneral(TransformerEngineBase): input_axes: Tuple[str, ...] = () def __post_init__(self): + if self.transpose_batch_sequence: + _dense_warning( + "TE/JAX DenseGeneral() module does not officially support sequence-first inputs " + "and may produce incorrect results when `transpose_batch_sequence=True`. Use " + "sequence-first inputs at your own discretion." + ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, "fan_in", "truncated_normal", dtype=self.dtype @@ -657,6 +663,12 @@ class LayerNormDenseGeneral(TransformerEngineBase): depth_scaling: float = None def __post_init__(self): + if self.transpose_batch_sequence: + _ln_dense_warning( + "TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first " + "inputs and may produce incorrect results when `transpose_batch_sequence=True`. " + "Use sequence-first inputs at your own discretion." + ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, @@ -967,6 +979,12 @@ class LayerNormMLP(TransformerEngineBase): dot_2_input_axes: Tuple[str, ...] = None def __post_init__(self): + if self.transpose_batch_sequence: + _ln_mlp_warning( + "TE/JAX LayerNormMLP() module does not officially support sequence-first inputs " + "and may produce incorrect results when `transpose_batch_sequence=True`. Use " + "sequence-first inputs at your own discretion." + ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, "fan_in", "truncated_normal", dtype=self.dtype diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index d0106b35ec..5ccfc71c24 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -9,6 +9,7 @@ distributed training through sharding constraints. """ +import warnings from functools import partial from typing import Tuple @@ -25,6 +26,16 @@ ) +LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False + + +def _issue_batch_first_warning(msg): + global LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED + if not LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED: + warnings.warn(msg, UserWarning) + LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = True + + def layernorm_dense( x: jnp.ndarray, kernel: jnp.ndarray, @@ -188,6 +199,13 @@ def _layernorm_dense_fwd_rule( x_bdim = None if x.ndim > 2: + if not batch_first: + _issue_batch_first_warning( + "TE/JAX `layernorm_dense()` fused-layer implementation does not officially " + "support sequence-first inputs and may produce incorrect results when " + "`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first " + "inputs at your own discretion." + ) x_bdim = 0 if batch_first else x.ndim - 2 x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 5ca94eb635..507c49c7e9 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -13,6 +13,7 @@ quantization, and distributed training through sharding constraints. 
""" +import warnings from typing import List, Tuple, Sequence, Union, Callable from functools import partial @@ -31,6 +32,16 @@ from .sharding import get_non_contracting_logical_axes +LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False + + +def _issue_batch_first_warning(msg): + global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED + if not LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED: + warnings.warn(msg, UserWarning) + LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True + + def layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -263,6 +274,13 @@ def _layernorm_mlp_fwd_rule( x_bdim = None if x.ndim > 2: + if not batch_first: + _issue_batch_first_warning( + "TE/JAX `layernorm_mlp()` fused-layer implementation does not officially " + "support sequence-first inputs and may produce incorrect results when " + "`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first " + "inputs at your own discretion." + ) x_bdim = 0 if batch_first else x.ndim - 2 use_bias_1 = bias_1 is not None From d5cb23307156697b8a15fe11d66b90b6e10b1291 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Jul 2025 21:07:56 +0000 Subject: [PATCH 27/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/utils.py | 4 ++-- transformer_engine/jax/flax/module.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 3dbf993b97..13b2b9148f 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -28,8 +28,8 @@ PRNGKey = Any Shape = Tuple[int, ...] -DType = NewType('DType', jnp.dtype) -Array = NewType('Array', jnp.ndarray) +DType = NewType("DType", jnp.dtype) +Array = NewType("Array", jnp.ndarray) PrecisionLike = Union[ None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] ] diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index e88092ce35..5992d36079 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -35,8 +35,8 @@ PRNGKey = Any Shape = Tuple[int, ...] 
-DType = NewType('DType', jnp.dtype) -Array = NewType('Array', jnp.ndarray) +DType = NewType("DType", jnp.dtype) +Array = NewType("Array", jnp.ndarray) PrecisionLike = Union[ None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] ] From 8cbe0e21f407f924c0052cf8933ad2d9966eeece Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 10 Jul 2025 15:03:45 +0000 Subject: [PATCH 28/30] fixed datasets version for JAX examples Signed-off-by: Alp Dener --- examples/jax/mnist/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/jax/mnist/requirements.txt b/examples/jax/mnist/requirements.txt index 5624ca1ead..ee1e543805 100644 --- a/examples/jax/mnist/requirements.txt +++ b/examples/jax/mnist/requirements.txt @@ -1,4 +1,4 @@ -datasets +datasets<4.0.0 flax>=0.7.1 optax Pillow From 66eab769d668d32c308e3e71f3f544a3d4f38c51 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 10 Jul 2025 16:46:45 +0000 Subject: [PATCH 29/30] reverting modification to force_1x_quantization decision Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 3d08f00df0..23e821b1a0 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -668,7 +668,7 @@ def _quantize_dbias_impl( is_outer=True, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise - if force_1x_quantization: + if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): colwise_scale_inv = rowwise_scale_inv if q_layout == QuantizeLayout.ROWWISE: From 9781ebf5e48ff9eb1ee02470d79bbc42e16c0757 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 11 Jul 2025 02:48:36 +0000 Subject: [PATCH 30/30] corrected gemm function syntax in unit tests Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 5d6b62f613..d6a17ac372 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -899,7 +899,7 @@ def _generate_gemm_input(self, m, n, k, data_layout): def test_gemm_bf16(self, m, n, k, data_layout): x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) - primitive_out = tex.gemm(x, w, dimension_numbers=(contracting_dims, ((), ()))) + primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) @@ -929,7 +929,7 @@ def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, wi primitive_out = tex.gemm( x, w, - dimension_numbers=(contracting_dims, ((), ())), + contracting_dims=contracting_dims, lhs_quantizer=( quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad ),
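
One behavioral note on the sequence-first warnings introduced earlier in this series: they are one-time, module-level UserWarnings rather than errors, so sequence-first layouts still execute. A minimal sketch of triggering the warning through the functional dense() API, assuming dense() exposes the batch_first keyword described in its docstring and usable defaults for bias and quantizer_set; the shapes are illustrative.

    import warnings
    import jax.numpy as jnp
    from transformer_engine.jax.dense import dense

    x = jnp.ones((128, 8, 512), jnp.bfloat16)   # (seq, batch, hidden): sequence-first layout
    w = jnp.ones((512, 1024), jnp.bfloat16)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        out = dense(x, w, batch_first=False)

    # Expect a single UserWarning about sequence-first inputs being unsupported; later calls
    # in the same process stay silent because the module-level flag is already set.
    assert any(issubclass(c.category, UserWarning) for c in caught)
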