From cf1774c06cb463a38aafc17e7d75eb1322c23ee5 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 4 Jun 2025 21:30:27 +0000 Subject: [PATCH 01/27] added XLA FFI custom op for TE/common nvte_cublas_gemm Signed-off-by: Alp Dener started GemmPrimitive, abstract done Signed-off-by: Alp Dener gemm custom op working with BF16, needs testing for FP8/MXFP8 Signed-off-by: Alp Dener converted TE GEMM API to use ScaledTensor and added os ENV flag to use TE GEMM under general gemm() call Signed-off-by: Alp Dener BF16 tests passing, FP8 tests should be passing but contracting_dims has a scoping issue Signed-off-by: Alp Dener fp8 tests passing for E4M3, getting CUBLAS_STATUS_NOT_SUPPORTED for E5M2 Signed-off-by: Alp Dener updated GEMM API to use separate LHS and RHS quantizers instead of a QuantizerSet Signed-off-by: Alp Dener new GemmPrimitive passing all Dense tests Signed-off-by: Alp Dener import cleanup and reverted code chunk movement Signed-off-by: Alp Dener removed unused .transpose() implementations from ScaledTensors Signed-off-by: Alp Dener all custom call tests passing on Hopper, GEMM-related tests cover both GemmPrimitive and native JAX impl Signed-off-by: Alp Dener removed direct calls to GemmPrimitive.enabled() from outside of cpp_extensions Signed-off-by: Alp Dener removed unused changes to ScaledTensor classes and debug prints Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 188 +++-- tests/jax/test_layer.py | 15 +- .../jax/cpp_extensions/activation.py | 75 +- transformer_engine/jax/cpp_extensions/gemm.py | 696 ++++++++++++++++-- transformer_engine/jax/cpp_extensions/misc.py | 9 +- .../jax/cpp_extensions/normalization.py | 7 + .../jax/cpp_extensions/quantization.py | 65 +- transformer_engine/jax/csrc/extensions.h | 3 + .../jax/csrc/extensions/gemm.cpp | 173 +++++ transformer_engine/jax/csrc/extensions/misc.h | 9 + .../jax/csrc/extensions/pybind.cpp | 6 + transformer_engine/jax/dense.py | 104 +-- transformer_engine/jax/layernorm_dense.py | 115 +-- transformer_engine/jax/layernorm_mlp.py | 208 +++--- .../jax/quantize/dequantizer.py | 17 + transformer_engine/jax/quantize/tensor.py | 63 +- 16 files changed, 1349 insertions(+), 404 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index f689bce6a5..5a59d113d5 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -2,13 +2,15 @@ # # See LICENSE for license information. +import os +import operator +from functools import reduce +from typing import Union + +import pytest import jax import jax.numpy as jnp -import pytest from jax import jit, value_and_grad -from functools import reduce -from typing import Union -import operator from utils import ( assert_allclose, @@ -30,7 +32,6 @@ from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax.quantize import ( - DelayedScaleQuantizer, ScaledTensor, ScaledTensor1x, ScaledTensor2x, @@ -44,6 +45,9 @@ from transformer_engine.jax.activation import activation from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense +from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x + +from transformer_engine_jax import is_non_nt_fp8_gemm_supported GEMM_CASES = [ (256, 256, 512), @@ -59,10 +63,14 @@ is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) supported_scaling_modes = [] +supported_fp8_gemm_layouts = [] """ Find supported scaling modes""" if is_fp8_supported: supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING) supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING) + supported_fp8_gemm_layouts.append("NT") + if is_non_nt_fp8_gemm_supported(): + supported_fp8_gemm_layouts += ["TT", "TN", "NN"] if is_mxfp8_supported: supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING) @@ -73,7 +81,7 @@ def is_shape_supported_by_mxfp8(input_shape): input_shape = input_shape.values[0] ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) return True - except: + except AssertionError: # get_scale_shapes will raise an exception if the shape is not supported return False @@ -147,6 +155,13 @@ def assert_dequantized_grouped_scaled_tensor( pytest.fail("a must be a GroupedScaledTensor object") +def use_jax_dot_for_gemm(enabled=False): + if enabled: + os.environ['NVTE_JAX_CUSTOM_CALLS_RE']='^(?!GemmPrimitive$).+$' + elif 'NVTE_JAX_CUSTOM_CALLS_RE' in os.environ: + os.environ.pop('NVTE_JAX_CUSTOM_CALLS_RE') + + ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)] ALL_ACTIVATION_TYPES = [ ("gelu",), @@ -624,15 +639,15 @@ def test_quantize_bitwise( ): key = jax.random.PRNGKey(0) - input = jax.random.uniform(key, input_shape, in_dtype) + inp = jax.random.uniform(key, input_shape, in_dtype) te_quantizer, jax_quantizer = QuantizerFactory.create( n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout ) - jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) + jax_output = _jax_quantize(inp, quantizer=jax_quantizer, flatten_axis=flatten_axis) - te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis) + te_output = tex.quantize(inp, quantizer=te_quantizer, flatten_axis=flatten_axis) assert_bitwise_scaled_tensors(te_output, jax_output) @@ -710,7 +725,7 @@ def test_quantize_dbias( pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") key = jax.random.PRNGKey(0) - input = jax.random.uniform(key, input_shape, in_dtype) + inp = jax.random.uniform(key, input_shape, in_dtype) jax_quantizer, te_quantizer = QuantizerFactory.create( n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout @@ -718,15 +733,15 @@ def test_quantize_dbias( te_output, te_dbias = jit( lambda input: tex.quantize_dbias( - input, quantizer=te_quantizer, flatten_axis=flatten_axis + inp, quantizer=te_quantizer, flatten_axis=flatten_axis ) - )(input) + )(inp) jax_output, jax_dbias = jit( lambda input: _jax_quantize_dbias( - input, quantizer=jax_quantizer, flatten_axis=flatten_axis + inp, quantizer=jax_quantizer, flatten_axis=flatten_axis ) - )(input) + )(inp) assert_bitwise_scaled_tensors(te_output, jax_output) @@ -880,7 +895,10 @@ def _generate_gemm_input(self, m, n, k, data_layout): @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) - def test_gemm_bf16(self, m, n, k, data_layout): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_gemm_bf16(self, m, n, k, data_layout, with_jax_gemm): + use_jax_dot_for_gemm(enabled=with_jax_gemm) + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) primitive_out = tex.gemm(x, w, contracting_dims) @@ -890,23 +908,56 @@ def test_gemm_bf16(self, m, n, k, data_layout): @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) + @pytest_parametrize_wrapper( + "lhs_q_dtype,rhs_q_dtype", + [ + (jnp.float8_e4m3fn, jnp.float8_e4m3fn), # fprop GEMM + (jnp.float8_e4m3fn, jnp.float8_e5m2), # wgrad GEMM + (jnp.float8_e5m2, jnp.float8_e4m3fn), # dgrad GEMM + ] + ) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) - def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout): - x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) + @pytest_parametrize_wrapper("data_layout", supported_fp8_gemm_layouts) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_gemm_fp8(self, m, n, k, lhs_q_dtype, rhs_q_dtype, scaling_mode, data_layout, + with_jax_gemm): + use_jax_dot_for_gemm(enabled=with_jax_gemm) + + lhs, rhs, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) + quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False + scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, bwd_dtype=jnp.float8_e5m2, + is_2x2x=False + ) + lhs_quantizer = ( + quantizer_set.x + if lhs_q_dtype == jnp.float8_e4m3fn + else quantizer_set.dgrad + ) + rhs_quantizer = ( + quantizer_set.kernel + if rhs_q_dtype == jnp.float8_e4m3fn + else quantizer_set.dgrad ) + primitive_out = tex.gemm( - x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set + lhs, rhs, lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, + contracting_dims=contracting_dims ) - ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) + ref_out = self._ref_gemm_with_jnp_dot(lhs, rhs, data_layout) - assert_allclose(primitive_out, ref_out, dtype=q_dtype) + test_q_dtype = ( + jnp.float8_e5m2 + if jnp.float8_e5m2 in (lhs_q_dtype, rhs_q_dtype) + else jnp.float8_e4m3fn + ) + assert_allclose(primitive_out, ref_out, dtype=test_q_dtype) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - def test_dense_grad_bf16(self, m, n, k): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_dense_grad_bf16(self, m, n, k, with_jax_gemm): + use_jax_dot_for_gemm(enabled=with_jax_gemm) + data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -932,9 +983,11 @@ def ref_func(x, w, data_layout): @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm): + use_jax_dot_for_gemm(enabled=with_jax_gemm) + data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -956,7 +1009,8 @@ def ref_func(x, w, bias, data_layout): value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True + scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, bwd_dtype=jnp.float8_e5m2, + is_2x2x=True ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 @@ -969,10 +1023,10 @@ def ref_func(x, w, bias, data_layout): x, w, bias, data_layout ) - assert_allclose(primitive_out, ref_out, dtype=q_dtype) - assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype) - assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype) - assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype) + assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn) + assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) + assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2) + assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2) @pytest.fixture(name="random_inputs") @@ -996,19 +1050,14 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan class TestFusedDense: @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) - @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) - def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm): """ Test layernorm_dense VJP Rule """ - # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode in ( - ScalingMode.DELAYED_TENSOR_SCALING, - ScalingMode.CURRENT_TENSOR_SCALING, - ): - pytest.skip("E5M2 is not supported in normalization with TE Backend!") + use_jax_dot_for_gemm(enabled=with_jax_gemm) # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False @@ -1025,8 +1074,8 @@ def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): quantizer_set = QuantizerFactory.create_set( scaling_mode=scaling_mode, - fwd_dtype=q_dtype, - bwd_dtype=q_dtype, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, is_2x2x=True, ) @@ -1072,32 +1121,27 @@ def ref_func(x, w, gamma, beta): prim_beta_grad, ) = value_n_grad_prim_func(x, w, gamma, beta) - assert_allclose(prim_out, ref_out, dtype=q_dtype) - assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) - assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype) - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype) + assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) + assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2) if beta is not None: - assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype) + assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @pytest.mark.parametrize("use_bias", [True, False]) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( - self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias + self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm ): """ Test layernorm_mlp VJP Rule """ - # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode in ( - ScalingMode.DELAYED_TENSOR_SCALING, - ScalingMode.CURRENT_TENSOR_SCALING, - ): - pytest.skip("E5M2 is not supported in normalization with TE Backend!") + # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False @@ -1123,8 +1167,8 @@ def test_layernorm_mlp_grad( quantizer_sets = QuantizerFactory.create_set( n_quantizer_sets=2, scaling_mode=scaling_mode, - fwd_dtype=q_dtype, - bwd_dtype=q_dtype, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, is_2x2x=True, ) @@ -1149,26 +1193,26 @@ def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ) ) - def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2): + def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): + dim_nums = ((1,), (0,)), ((), ()) + ln_out = _ref_jax_norm_impl( x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None ) - # TODO: replace gemm with jnp.dot - linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,))) + + linear_1_out = jax.lax.dot_general(ln_out, kernel_1, dim_nums) if use_bias: bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape linear_1_out += jnp.reshape(bias_1, bias_1_shape) - x = _jax_act_lu(linear_1_out, activation_type) - linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,))) + act_out = _jax_act_lu(linear_1_out, activation_type) + + linear_2_out = jax.lax.dot_general(act_out, kernel_2, dim_nums) if use_bias: bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape linear_2_out += jnp.reshape(bias_2, bias_2_shape) - return linear_2_out - - def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): - return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2)) + return jnp.mean(linear_2_out) value_n_grad_prim_func = value_and_grad(prim_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6)) @@ -1193,18 +1237,18 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ref_bias_2_grad, ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) - assert_allclose(prim_out, ref_out, dtype=q_dtype) + assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) - assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype) + assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2) if use_bias: - assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype) + assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2) - assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype) + assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2) if use_bias: - assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype) + assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2) - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype) - assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) # E5M2 * E5M2 is not supported diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index d59e130530..ab79e2eae4 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -487,6 +487,13 @@ class BaseTester: runner = BaseRunner + def use_jax_dot_for_gemm(self, enabled=False): + """Enable/disable TE custom cuBLAS GEMM primitive.""" + if enabled: + os.environ['NVTE_JAX_CUSTOM_CALLS_RE']='^(?!GemmPrimitive$).+$' + elif 'NVTE_JAX_CUSTOM_CALLS_RE' in os.environ: + os.environ.pop('NVTE_JAX_CUSTOM_CALLS_RE') + def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" QuantizeConfig.finalize() # Ensure FP8 disabled. @@ -499,16 +506,20 @@ def test_backward(self, data_shape, dtype, attrs): @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) - def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): + @pytest.mark.parametrize("with_jax_gemm", (False, True)) + def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe, with_jax_gemm): """Test forward with fp8 enabled""" + self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.initialize(fp8_recipe=fp8_recipe) self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) QuantizeConfig.finalize() @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) - def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): + @pytest.mark.parametrize("with_jax_gemm", (False, True)) + def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe, with_jax_gemm): """Test backward with fp8 enabled""" + self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.initialize(fp8_recipe=fp8_recipe) self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) QuantizeConfig.finalize() diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index ce66bba3cf..4d0a4c1bbd 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -985,6 +985,7 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + noop_scaled_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -993,6 +994,7 @@ def act_lu( Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: If quantizer is None: @@ -1037,6 +1039,10 @@ def act_lu( is_outer=True, ) out = out.reshape(output_shape) + if noop_scaled_tensor: + return ScaledTensorFactory.create_2x( + out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype + ) return out if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -1090,6 +1096,7 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1100,6 +1107,7 @@ def quantize_dact_dbias( activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). is_dbias: If True, compute bias gradient. Defaults to True. quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: Tuple[ScaledTensor, jnp.ndarray]: A tuple containing: @@ -1113,13 +1121,42 @@ def quantize_dact_dbias( f" {x.shape} and act_len {act_len}" ) + scale = jnp.empty((), jnp.float32) + act_type_id = ActivationEnum[activation_type] PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive - if not PrimitiveClass.enabled(): + if ( + not PrimitiveClass.enabled() + or (quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE) + ): return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) - # TE/common does not support colwise-only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: - return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) + if quantizer is None: + output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( + dz, + x, + scale, + # outputs float32 for dbias accumulation + out_dtype=(jnp.float32 if is_dbias else x.dtype), + # default value for no scaling, TE/common ignore this value when scale is unset + scaling_mode=ScalingMode.NO_SCALING.value, + is_2x=False, # unused + scale_dtype=jnp.float32, # unused + is_dbias=False, + act_enum=act_type_id, + act_len=act_len, + is_outer=True, + ) + output = output.astype(x.dtype) + dbias = None + if is_dbias: + dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) + + if noop_scaled_tensor: + return ScaledTensorFactory.create_2x( + output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype, + ), dbias + + return output, dbias # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): @@ -1145,31 +1182,6 @@ def quantize_dact_dbias( if war_output is not None: return war_output - scale = jnp.empty((), jnp.float32) - - act_type_id = ActivationEnum[activation_type] - - if quantizer is None: - output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( - dz, - x, - scale, - # outputs float32 for dbias accumulation - out_dtype=(jnp.float32 if is_dbias else x.dtype), - # default value for no scaling, TE/common ignore this value when scale is unset - scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, # unused - scale_dtype=jnp.float32, # unused - is_dbias=False, - act_enum=act_type_id, - act_len=act_len, - is_outer=True, - ) - dbias = None - if is_dbias: - dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) - return output.astype(x.dtype), dbias - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = dact_lu( @@ -1183,7 +1195,7 @@ def quantize_dact_dbias( ) return out, dbias - if isinstance(quantizer, DelayedScaleQuantizer): + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale # TE/common dact_dbias_quantize does not support gated act yet @@ -1243,6 +1255,7 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + noop_scale_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1252,6 +1265,7 @@ def dact_lu( x: Input tensor that was used in forward pass. activation_type: Type of activation function that was applied. quantizer: Optional quantizer for FP8 quantization of the output gradient. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: The gradient of the activation with respect to the input. @@ -1262,5 +1276,6 @@ def dact_lu( activation_type=activation_type, is_dbias=False, quantizer=quantizer, + noop_scaled_tensor=noop_scale_tensor, ) return output diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d3c23015c1..d553da10f3 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -3,19 +3,25 @@ # See LICENSE for license information. """JAX te modules""" -from typing import Tuple, Sequence, Union, Dict -from functools import partial, reduce -import operator import math +import operator +from collections.abc import Iterable +from typing import Tuple, Sequence, Union +from functools import partial, reduce, lru_cache + import jax import jax.numpy as jnp +from jax import dtypes +from jax.sharding import NamedSharding, PartitionSpec + +import transformer_engine_jax as tex from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize - from ..quantize import ( ScaledTensor, + ScaledTensor2x, GroupedScaledTensor1x, ScalingMode, Quantizer, @@ -25,9 +31,18 @@ QuantizeLayout, noop_quantizer_set, ) +from ..sharding import get_padded_spec -__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"] +__all__ = [ + "gemm", + "grouped_gemm", + "gemm_uses_jax_dot", + "sanitize_dims", + "get_non_contracting_dims", + "transpose_contracting_dims", + "is_gemm_with_all_layouts_supported", +] num_cublas_streams = get_num_compute_streams() @@ -35,16 +50,517 @@ def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if get_device_compute_capability(0) >= 90: + if tex.get_device_compute_capability(0) >= 90: return 33_554_432 return 4_194_304 -def is_gemm_with_all_layouts_supported() -> False: +def is_gemm_with_all_layouts_supported() -> bool: """Return True if using blackwell, False otherwise.""" return get_device_compute_capability(0) >= 100 +def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]: + """Convert relative (negative) indexes to absolute dimension numbers.""" + dims_ = dims if isinstance(dims, Iterable) else (dims, ) + if len(dims_) == 0: + return dims_ + return tuple( ndim + dim if dim < 0 else dim for dim in dims_ ) + + +def get_non_contracting_dims(ndim, contracting_dims): + """Return a tuple of dimensions not included in the contracting dimensions.""" + contracting_dims = sanitize_dims(ndim, contracting_dims) + return tuple(dim for dim in range(ndim) if dim not in contracting_dims) + + +def transpose_contracting_dims(ndim, contracting_dims): + """Compute the new dimension numbers for contracting dimensions after a transpose.""" + contracting_dims = sanitize_dims(ndim, contracting_dims) + return tuple(ndim - i - 1 for i in contracting_dims)[::-1] + + +def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool: + lhs, rhs, e4m3, e5m2, e8m0 = map( + dtypes.canonicalize_dtype, + ( + lhs_dtype, + rhs_dtype, + jnp.float8_e4m3fn, + jnp.float8_e5m2, + jnp.uint8 # replace with jnp.float8_e8m0 when JAX/XLA merges support + ) + ) + + # MXFP8 GEMM needs both operands to be MXFP8 (uint8 for now until JAX merges float8_e8m0) + if lhs is e8m0 and rhs is e8m0: + return True + + # FP8 GEMM supports (e4m3 x e4m3), (e4m3 x e5m2) and (e5m2 x e4m3) + if (lhs is e4m3 and rhs in (e4m3, e5m2)) or (lhs in (e4m3, e5m2) and rhs is e4m3): + return True + + # Any other combination of data types is not supported + return False + + +def _get_gemm_layout( + operand_ndims: Tuple[int, int], + contracting_dims: Tuple[Sequence[int], Sequence[int]] +) -> Tuple[bool, bool]: + lhs_contracting, rhs_contracting = map(sanitize_dims, operand_ndims, contracting_dims) + lhs_is_transposed = operand_ndims[0] - 1 not in lhs_contracting + rhs_is_transposed = operand_ndims[1] - 1 in rhs_contracting + return lhs_is_transposed, rhs_is_transposed + + +def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims): + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) + lhs_contracting_dims, rhs_contracting_dims = map( + sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims + ) + + lhs_q = lhs + rhs_q = rhs + if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: + lhs_q = lhs_quantizer.quantize( + lhs, + is_rowwise=True, + is_colwise=False, + flatten_axis=( + max(lhs_contracting_dims) + 1 + if lhs_is_transposed + else min(lhs_contracting_dims) + ), + ) + if lhs_is_transposed: + # Manually update data layout and columnwise flag to avoid transposing already + # transposed data + lhs_q.data_layout = "T" + lhs_q.is_colwise = True + + if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: + rhs_q = rhs_quantizer.quantize( + rhs, + is_rowwise=True, + is_colwise=False, + flatten_axis=( + min(rhs_contracting_dims) + if rhs_is_transposed + else max(rhs_contracting_dims) + 1 + ), + ) + if not rhs_is_transposed: + # Manually update data layout and columnwise flag to avoid transposing already + # transposed data + rhs_q.data_layout = "T" + rhs_q.is_colwise = True + + return lhs_q, rhs_q + + +class GemmPrimitive(BasePrimitive): + """ + Primitive for cuBLAS GEMM + """ + + name = "te_gemm_ffi" + multiple_results = True + impl_static_args = (6, 7, 8, 9, 10, 11, 12) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype, + contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad): + # Sanity-check operand layouts and types + operand_ndims = (lhs.ndim, rhs.ndim) + ( + lhs_contracting_dims, + rhs_contracting_dims, + ) = map(sanitize_dims, operand_ndims, contracting_dims) + lhs_contracting_size, rhs_contracting_size = map( + lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), + (lhs.shape, rhs.shape), + (lhs_contracting_dims, rhs_contracting_dims) + ) + assert lhs_contracting_size == rhs_contracting_size, ( + "cuBLAS GEMM operands have incompatible contracting dimensions: " + f"{lhs.shape} @ idx {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}." + ) + + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims) + if scaling_mode != ScalingMode.NO_SCALING: + assert _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype), ( + "cuBLAS GEMM quantized operands have incompatible data types: " + f"{lhs.dtype} x {rhs.dtype}." + ) + assert lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0, ( + "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." + ) + if ( + scaling_mode != ScalingMode.MXFP8_1D_SCALING + and not tex.is_non_nt_fp8_gemm_supported() + ): + assert not lhs_is_transposed and rhs_is_transposed, ( + "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) " + "require non-transposed LHS and transposed RHS operands " + "(`contracting_dims=((-1, ), (-1, ))`)." + ) + + # Determine output shape and dtype + assert dtypes.canonicalize_dtype(out_dtype).itemsize > 1, ( + "cuBLAS GEMM custom op does not support 8-bit quantized output types." + ) + lhs_non_contracting_shape, rhs_non_contracting_shape = map( + lambda shape, dims: [ shape[dim] for dim in range(len(shape)) if dim not in dims ], + (lhs.shape, rhs.shape), + (lhs_contracting_dims, rhs_contracting_dims) + ) + out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape) + output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + + # Validate bias + bias_shape = (0, ) + bias_dtype = out_dtype + if fuse_bias: + expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape) + if not grad: + assert bias.size == expected_bias_size, ( + "cuBLAS GEMM bias tensor has incorrect shape, " + f"expected ({expected_bias_size}, ) but found {bias.shape}." + ) + assert bias.dtype == out_dtype, ( + "cuBLAS GEMM bias tensor has incorrect data type, " + f"expected {bias_dtype} but found {bias.dtype}." + ) + bias_shape = bias.shape + else: + bias_shape = rhs_non_contracting_shape + bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype) + + # Validate pre-GeLU + pre_gelu_shape = (0, ) + pre_gelu_dtype = out_dtype + if fuse_gelu: + pre_gelu_shape = out_shape + if grad: + pre_gelu_ndim = len(pre_gelu_shape) + assert ( + gelu_input.ndim == pre_gelu_shape + and all(gelu_input.shape[i] == pre_gelu_shape[i] for i in range(pre_gelu_ndim)) + ), ( + "cuBLAS GEMM pre-GeLU tensor has incorrect shape, " + f"expected {pre_gelu_shape} but found {gelu_input.shape}." + ) + assert gelu_input.dtype == out_dtype, ( + "cuBLAS GEMM pre-GeLU tensor has incorrect data type, " + f"expected {pre_gelu_dtype} but found {gelu_input.dtype}." + ) + pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) + + # Need extra workspace for swizzled scale factors + lhs_swizzle_size = 0 + rhs_swizzle_size = 0 + swizzle_dtype = jnp.uint8 # replace with jnp.float8_e8m0 when JAX merges support + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + lhs_swizzle_size = lhs_scale_inv.size + rhs_swizzle_size = rhs_scale_inv.size + lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size, ), dtype=swizzle_dtype) + rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size, ), dtype=swizzle_dtype) + + # Declare cuBLAS workspace + workspace = jax.core.ShapedArray(shape=(get_cublas_workspace_size_bytes(), ), + dtype=jnp.uint8) + + return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace + + @staticmethod + def outer_abstract(*args, **kwargs): + outputs = GemmPrimitive.abstract(*args, **kwargs) + return outputs[:-3] # discard workspace arrays + + @staticmethod + def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype, + contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad): + del out_dtype + lhs_aval, _, rhs_aval, *_ = ctx.avals_in + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) + lhs_transposed, rhs_transposed = _get_gemm_layout((lhs_aval.ndim, rhs_aval.ndim), + (lhs_cdims, rhs_cdims)) + + args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input) + kwargs = { + "scaling_mode" : int(scaling_mode.value), + "lhs_axis_boundary" : max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), + "rhs_axis_boundary" : min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, + "lhs_transposed" : lhs_transposed, + "rhs_transposed" : rhs_transposed, + "fuse_bias" : fuse_bias, + "fuse_gelu" : fuse_gelu, + "grad" : grad, + } + + operand_output_aliases = {} + if fuse_bias and not grad: + operand_output_aliases.update({ 4 : 1 }) # bias <-> bias_grad + if fuse_gelu and grad: + operand_output_aliases.update({ 5 : 2 }) # gelu_input <-> pre_gelu_out + + return jax.ffi.ffi_lowering( + GemmPrimitive.name, + operand_output_aliases=operand_output_aliases, + )(ctx, *args, **kwargs) + + @staticmethod + def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype, contracting_dims, + scaling_mode, fuse_bias, fuse_gelu, grad): + outputs = GemmPrimitive.inner_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + ) + return outputs[:-3] # discard workspace arrays + + @staticmethod + def batcher(batched_args, batch_dims, out_dtype, contracting_dims, scaling_mode, fuse_bias, + fuse_gelu, grad): + assert GemmPrimitive.outer_primitive is not None + lhs, _, rhs, *_ = batched_args + lhs_bdims, *_ = batch_dims + + # Output is batched like LHS only if LHS is batched and RHS is not + out_bdims = lhs_bdims if lhs.ndim > 2 and rhs.ndim == 2 else (None, ) + bias_bdims = (None, ) # Bias is never batched + pre_gelu_bdims = (None, ) # Pre-GeLU output, if exists, is batched like GEMM output + if fuse_gelu and not grad: + pre_gelu_bdims = out_bdims + + return ( + GemmPrimitive.outer_primitive.bind( + *batched_args, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + ), + (out_bdims, bias_bdims, pre_gelu_bdims) + ) + + @staticmethod + def infer_sharding_from_operands(out_dtype, contracting_dims, scaling_mode, fuse_bias, + fuse_gelu, grad, mesh, arg_infos, result_infos): + del out_dtype, scaling_mode, result_infos + + # Check contracting dimensions + lhs_spec, _, rhs_spec, *_ = map(get_padded_spec, arg_infos) + operand_ndims = (len(lhs_spec), len(rhs_spec)) + lhs_contracting_dims, rhs_contracting_dims = map( + sanitize_dims, operand_ndims, contracting_dims + ) + lhs_contracting_specs, rhs_contracting_specs = map( + lambda specs, dims: [ specs[dim] for dim in dims if specs[dim] is not None], + (lhs_spec, rhs_spec), + (lhs_contracting_dims, rhs_contracting_dims) + ) + assert len(lhs_contracting_specs) <= 1 and len(rhs_contracting_specs) <= 1, ( + "cuBLAS GEMM operands can have only one sharded contracting dimension." + ) + lhs_contracting_spec, rhs_contracting_spec = map( + lambda spec: None if len(spec) == 0 else spec[0], + (lhs_contracting_specs, rhs_contracting_specs) + ) + assert lhs_contracting_spec == rhs_contracting_spec, ( + "cuBLAS GEMM operands must have the same sharding in contracting dimensions." + ) + + # Sanity check leading dimensions, allow for simultaneous batch and sequence sharding + lhs_leading_dims, rhs_leading_dims = map( + get_non_contracting_dims, operand_ndims, (lhs_contracting_dims, rhs_contracting_dims) + ) + lhs_leading_specs, rhs_leading_specs = map( + lambda specs, dims: [ specs[dim] for dim in dims if specs[dim] is not None ], + (lhs_spec, rhs_spec), + (lhs_leading_dims, rhs_leading_dims) + ) + assert len(lhs_leading_specs) <= 1 and len(rhs_leading_specs) <= 1, ( + "cuBLAS GEMM operands cannot have more than one sharded leading dimensions. This error " + "usually means a sequence-parallel operand was not all-gathered before the GEMM op." + ) + + # Determine output sharding + lhs_leading_spec, rhs_leading_spec = map( + lambda spec: None if len(spec) == 0 else spec[0], + (lhs_leading_specs, rhs_leading_specs) + ) + out_spec = (lhs_leading_spec, rhs_leading_spec) + if operand_ndims[0] > 2 and operand_ndims[1] == 2: + # Restore batch dimensions/sharding to the output + out_spec = (*lhs_leading_specs, rhs_leading_spec) + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) + + # Bias gradient sharding inherits the RHS contracting spec + bias_spec = (None, ) + if fuse_bias and grad: + bias_spec = (rhs_contracting_spec, ) + bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + + # Pre-GeLU sharding matches output sharding + pre_gelu_spec = (None, ) + if fuse_gelu and not grad: + pre_gelu_spec = out_spec + pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_spec)) + + return (out_sharding, bias_sharding, pre_gelu_sharding) + + @staticmethod + def partition(out_dtype, contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad, + mesh, arg_infos, result_infos): + out_shardings = GemmPrimitive.infer_sharding_from_operands( + out_dtype, contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad, + mesh, arg_infos, result_infos + ) + output_spec = out_shardings[0].spec + + # Operand shardings are already guarded with asserts so leave them unchanged here + lhs_spec, _, rhs_spec, *_ = map(get_padded_spec, arg_infos) + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec)) + rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec)) + + # Any distributed scales (e.g. MXFP8) need to be gathered + scale_sharding = NamedSharding(mesh, PartitionSpec(None)) + + # Bias has to be sharded same as the trailing dimension of the GEMM output + bias_spec = (None, ) + if fuse_bias and not grad: + bias_spec = (output_spec[-1], ) + bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + + # Pre-GeLU output has to be sharded same as the GEMM output + pre_gelu_spec = (None, ) + if fuse_gelu and grad: + pre_gelu_spec = output_spec + pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_spec)) + + arg_shardings = ( + lhs_sharding, + scale_sharding, + rhs_sharding, + scale_sharding, + bias_sharding, + pre_gelu_sharding, + ) + + return mesh, GemmPrimitive.impl, out_shardings, arg_shardings + + +register_primitive(GemmPrimitive) + + +@lru_cache(maxsize=1) +def gemm_uses_jax_dot() -> bool: + """Check if the GEMM call directs to the TE custom cuBLAS call or native JAX dot.""" + return not GemmPrimitive.enabled() + + +def _te_gemm( + lhs: Union[jax.Array, ScaledTensor], + rhs: Union[jax.Array, ScaledTensor], + bias: jax.Array = None, + gelu_input: jax.Array = None, + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1, ), (-2, )), + fuse_bias: bool = False, + fuse_gelu: bool = False, + grad: bool = False, +) -> Tuple[jax.Array, ...]: + # Prepare non-quantized GEMM operands + lhs_data = lhs + rhs_data = rhs + lhs_scale_inv = jnp.empty(0, dtype=jnp.float32) + rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) + scaling_mode = ScalingMode.NO_SCALING + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) + lhs_contracting_dims, rhs_contracting_dims = map( + sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims + ) + + # Quantize operands (if necessary) + lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) + + # Extract GEMM custom op inputs from quantized operands + if isinstance(lhs_q, ScaledTensor): + assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, ( + "cuBLAS GEMM with quantized LHS and non-quantized RHS operands requires a valid " + "`Quantizer` object to quantize the RHS operand." + ) + if isinstance(lhs_q, ScaledTensor2x): + # Contracting dimensions for a ScaledTensor2x is interpreted relative to the row-wise + # shape. Since we have access to both row-wise and column-wise tensors, we always + # choose the one that avoids transposing LHS in the GEMM kernel to comply with the + # NT-layout restriction for FP8 GEMM on Hopper. + lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() + scaling_mode = lhs_q.scaling_mode + lhs_data = lhs_q.data + lhs_scale_inv = lhs_q.scale_inv + if lhs_q.data_layout == "T": + lhs_contracting_dims = transpose_contracting_dims(lhs_q.ndim, lhs_contracting_dims) + + if isinstance(rhs_q, ScaledTensor): + assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( + "cuBLAS GEMM with non-quantized LHS and quantized RHS operands requires a valid " + "`Quantizer` object to quantize the LHS operand." + ) + if isinstance(rhs_q, ScaledTensor2x): + # Contracting dimensions for a ScaledTensor2x is interpreted relative to the row-wise + # shape. Since we have access to both row-wise and column-wise tensors, we always + # choose the one that avoids transposing LHS in the GEMM kernel to comply with the + # NT-layout restriction for FP8 GEMM on Hopper. + rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() + assert rhs_q.scaling_mode == lhs_q.scaling_mode, ( + "cuBLAS GEMM quantized operands have mismatched scaling types, " + f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." + ) + rhs_data = rhs_q.data + rhs_scale_inv = rhs_q.scale_inv + if rhs_q.data_layout == "T": + rhs_contracting_dims = transpose_contracting_dims(rhs_q.ndim, rhs_contracting_dims) + + # Dummy empties for bias and gelu + out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype + if bias is None or not (fuse_bias and not grad): + bias = jnp.empty(0, dtype=out_dtype) + if gelu_input is None or not (fuse_gelu and grad): + gelu_input = jnp.empty(0, dtype=out_dtype) + + return GemmPrimitive.outer_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + gelu_input, + out_dtype=out_dtype, + contracting_dims=(lhs_contracting_dims, rhs_contracting_dims), + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + ) + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM @@ -221,11 +737,8 @@ def _shape_normalization(x, dimension_numbers, already_transposed: bool = False) def _calculate_remaining_shape(shape, contracting_dims): - return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims) - - -def _transpose_contract_dims(ndim, contracting_dims): - return tuple(ndim - i - 1 for i in contracting_dims)[::-1] + contracting_dims_ = sanitize_dims(len(shape), contracting_dims) + return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims_) # Apply jit to guarantee correctness of FP8 GEMM. @@ -233,9 +746,9 @@ def _transpose_contract_dims(ndim, contracting_dims): def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums if lhs.data_layout == "T": - lhs_contract = _transpose_contract_dims(lhs.data.ndim, lhs_contract) + lhs_contract = transpose_contracting_dims(lhs.data.ndim, lhs_contract) if rhs.data_layout == "T": - rhs_contract = _transpose_contract_dims(rhs.data.ndim, rhs_contract) + rhs_contract = transpose_contracting_dims(rhs.data.ndim, rhs_contract) dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) @@ -306,12 +819,12 @@ def _jax_gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - quantizer_set: Dict["str", Quantizer] = noop_quantizer_set, + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, ) -> jnp.ndarray: """ FP8 GEMM via JAX """ - dim_nums = (contracting_dims, ((), ())) def _jax_gemm_fp8_impl(lhs, rhs): @@ -331,65 +844,116 @@ def _jax_gemm_fp8_impl(lhs, rhs): raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") - if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): - return _jax_gemm_fp8_impl(lhs, rhs) - - if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor): - if quantizer_set != noop_quantizer_set: - assert type(quantizer_set.x) is type(quantizer_set.kernel) - (((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums - lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1 - rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1 - # Call JAX quantization so that XLA can do pattern matching (QDQ --> FP8 gemm) - lhs_q = quantizer_set.x.quantize( - lhs, - is_rowwise=lhs_is_rowwise, - is_colwise=not lhs_is_rowwise, - ) - rhs_q = quantizer_set.kernel.quantize( - rhs, - is_rowwise=rhs_is_rowwise, - is_colwise=not rhs_is_rowwise, - ) - return _jax_gemm_fp8_impl(lhs_q, rhs_q) + # Quantize operands (if necessary) + lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) - if ( - isinstance(lhs, jnp.ndarray) - and isinstance(rhs, jnp.ndarray) - and quantizer_set == noop_quantizer_set - ): - return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) + if isinstance(lhs_q, ScaledTensor) or isinstance(rhs_q, ScaledTensor): + assert isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor), ( + "Both LHS and RHS must be quantized (or have valid quantizers) for FP8 GEMM." + ) + return _jax_gemm_fp8_impl(lhs_q, rhs_q) - raise NotImplementedError("Not supporting multiplication of ScaledTensor and jnp.array") + return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) def gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - quantizer_set: QuantizerSet = noop_quantizer_set, -) -> jnp.ndarray: - """General matrix multiplication with optional quantization. - - Args: - lhs: First input matrix. - rhs: Second input matrix. - contracting_dims: Tuple of two sequences representing the contracting dimensions. - The first sequence represents the contracting dimensions of the first matrix, - and the second sequence represents the contracting dimensions of the second matrix. - quantizer_set: Set of quantizers for FP8 quantization of the output. - If None, no quantization is applied and the output has the same dtype as the inputs. - - Returns: - If quantizer_set is None: - The matrix multiplication result. - Shape: (M, N) - Dtype: Same as input dtype - If quantizer_set is provided: - A ScaledTensor containing the quantized matrix multiplication result. + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (-2,)), + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, + **kwargs, +) -> Tuple[jnp.ndarray, ...]: + r"""General matrix multiplication with optional quantization. + + Parameters + ---------- + lhs: Union[jax.Array, ScaledTensor] + Left-hand side operand in the matrix multiplication. + rhs: Union[jax.Array, ScaledTensor] + Right-hand side operand in the matrix multiplication. + lhs_quantizer: Quantizer, default = None + Object for down-casting the LHS operand for quantized GEMM. + rhs_quantizer: Quantizer, default = None + Object for down-casting the RHS operand for quantized GEMM. + contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (-2, )) + Tuple of two sequences representing the contracting dimensions. The first sequence + represents the contracting dimensions of the LHS operand, and the second sequence + represents the contracting dimensions of the RHS operand. + bias: jax.Array, default = None + Optional additive bias term, required for forward GEMM with bias fusion. Only supported + with TE's custom call to cuBLAS GEMM. + gelu_input: jax.Array, default = None + Pre-GeLU output from forward GEMM, required for backward/grad GEMM with dGeLU fusion. Only + supported with TE's custom call to cuBLAS GEMM. + fuse_bias: bool, default = False + Enable bias addition in forward GEMM or bias gradient in backward GEMM. Only supported with + TE's custom call to cuBLAS GEMM. + fuse_gelu: bool, default = False + Enable GeLU activation in forward GEMM or GeLU gradient in backward GEMM. Only supported + with TE's custom call to cuBLAS GEMM. + grad: bool, default = False + Flag for switching bias and GeLU fusions from forward to backward mode. Only supported with + TE's custom call to cuBLAS GEMM. + + Returns + ------- + jax.Array: + Result of the operation. For TE's custom call to cuBLAS GEMM, this result can include the + GeLU application when `fuse_gelu=True` and `grad=False`, the GeLU gradient contribution + when `fuse_gelu=True` and `grad=True`, and the additive bias when `fuse_bias=True` and + `grad=False`. + Optional[jax.Array]: + Bias gradient when `fuse_bias=True` and `grad=True`. Only supported with TE's custom call + to cuBLAS GEMM. + Optional[jax.Array]: + Pre-GeLU GEMM output when `fuse_gelu=True` and `grad=False`. This is required as an input + to `_te_gemm()` with `fuse_gelu=True` and `grad=True` in the backward pass in order to + compute the GeLU contribution to the gradient. Only supported with TE's custom call to + cuBLAS GEMM. """ + # Try to get LHS and RHS quantizers from a quantizer set for backward compatibility + if lhs_quantizer is None or rhs_quantizer is None: + quantizer_set = kwargs.get("quantizer_set", None) + if quantizer_set is not None: + lhs_quantizer = quantizer_set.x + rhs_quantizer = quantizer_set.kernel + + # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled + if gemm_uses_jax_dot(): + assert kwargs.get("bias", None) is None and not kwargs.get("fuse_bias", False), ( + "TE GEMM was invoked with bias fusion options that are not supported by the " + "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " + "GEMM primitive is disabled." + ) + assert kwargs.get("gelu_input", None) is None and not kwargs.get("fuse_gelu", False), ( + "TE GEMM was invoked with GeLU fusion options that are not supported by the " + "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " + "GEMM primitive is disabled." + ) + return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) + + outputs = _te_gemm( + lhs, + rhs, + lhs_quantizer=lhs_quantizer, + rhs_quantizer=rhs_quantizer, + contracting_dims=contracting_dims, + **kwargs + ) - return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set) + # Discard empty outputs + fuse_bias = kwargs.get("fuse_bias", False) + fuse_gelu = kwargs.get("fuse_gelu", False) + grad = kwargs.get("grad", False) + clean_outputs = outputs[0] # first output is the final result and is never empty + if (fuse_bias and grad) or (fuse_gelu and not grad): + clean_outputs = (outputs[0], ) + if fuse_bias and grad: # only return bias gradient if it exists + clean_outputs += (outputs[1], ) + if fuse_gelu and not grad: # only return pre-GeLU output if it exists + clean_outputs += (outputs[2], ) + return clean_outputs def grouped_gemm( diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 87f1c1913a..94dfaa45a4 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -198,14 +198,19 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to calculate dbias separately. This function checks if the workaround should be applied. """ + if quantizer is None: + return False + arch_l_100 = False for local_gpu_id in range(len(jax.local_devices())): if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100: arch_l_100 = True break + # _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE, + # but this fails when bias fusion is turned on with arch < 100. + force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() return ( - quantizer is not None - and quantizer.q_layout == QuantizeLayout.ROWWISE + (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE) and arch_l_100 and is_dbias ) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 07ebb33114..f2829c3a58 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -1276,6 +1276,7 @@ def normalization_fwd( epsilon: float, norm_type: str, quantizer: Optional[Quantizer], + noop_scaled_tensor: bool = False, ): """Common wrapper for normalization forward pass. @@ -1292,6 +1293,7 @@ def normalization_fwd( - 'layernorm': Layer normalization - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: A tuple containing: @@ -1319,6 +1321,11 @@ def normalization_fwd( else: raise ValueError(f"{norm_type=} is not supported.") + if quantizer is None and noop_scaled_tensor: + return ScaledTensorFactory.create_2x( + output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype + ), mu, rsigma + return output, mu, rsigma diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 11b3cdc2a3..e84fb3a976 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -538,11 +538,12 @@ def _jax_quantize( def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1): - assert flatten_axis < 0 + sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis + assert sum_axis < dx.ndim, "Flatten axis out of bounds!" dtype = dtype or dx.dtype dbias = jnp.sum( dx.astype(jnp.float32), - axis=tuple(range(dx.ndim + flatten_axis)), + axis=tuple(range(sum_axis)), keepdims=False, ) return dbias.astype(dtype) @@ -568,6 +569,7 @@ def _quantize_dbias_impl( is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -577,24 +579,28 @@ def _quantize_dbias_impl( quantizer is not None ), "quantizer must be provided if dq_dtype is provided" + # Early-exit for non-quantized call dq_dtype = dq_dtype or x.dtype - - PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive - if not PrimitiveClass.enabled(): + if quantizer is None: + dbias = None if is_dbias: - return _jax_quantize_dbias( - x, - quantizer=quantizer, - dq_dtype=dq_dtype, + dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) + if noop_scaled_tensor: + # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor() + # always works. + return ScaledTensorFactory.create_2x( + x, None, x, None, ScalingMode.NO_SCALING, dq_dtype=x.dtype, data_layout="NN", flatten_axis=flatten_axis, - ) - return ( - _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), - None, - ) + ), dbias + return x, dbias - # TE/common doesn't support colwise only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: + # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, + # fall back on the native-JAX quantize implementation + PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive + if ( + quantizer.q_layout == QuantizeLayout.COLWISE + or not PrimitiveClass.enabled() + ): if is_dbias: return _jax_quantize_dbias( x, @@ -606,9 +612,8 @@ def _quantize_dbias_impl( _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), None, ) - scale = jnp.empty((), jnp.float32) - # TE/common dbias_quantize does not support 1x on arch < 100 + # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100 if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): out, _ = _quantize_dbias_impl( x=x, @@ -620,29 +625,23 @@ def _quantize_dbias_impl( dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias - if quantizer is None: - if is_dbias: - return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) - return x, None - + scale = jnp.empty((), jnp.float32) if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Globally reduce amax across all devices for current scaling so we have a single global scale. # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this # until the tensor is dequantized (e.g. in the GEMM). amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32) scale = compute_scale_from_amax(amax, quantizer.q_dtype) - - if isinstance(quantizer, DelayedScaleQuantizer): + elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale - is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) # It is faster to use 1x quantization for tensor scaling + is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) force_1x_quantization = ( quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() and is_1x_kernel_supported ) - q_layout = quantizer.q_layout if force_1x_quantization: q_layout = QuantizeLayout.ROWWISE @@ -666,7 +665,7 @@ def _quantize_dbias_impl( is_outer=True, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise - if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): + if force_1x_quantization: colwise_scale_inv = rowwise_scale_inv if q_layout == QuantizeLayout.ROWWISE: @@ -698,6 +697,7 @@ def quantize( x: jnp.ndarray, quantizer: Quantizer, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -707,6 +707,8 @@ def quantize( quantizer: Quantizer for FP8 quantization of the output. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. + noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer + is None. Returns: A ScaledTensor containing the quantized input tensor. @@ -715,6 +717,7 @@ def quantize( x, quantizer=quantizer, flatten_axis=flatten_axis, + noop_scaled_tensor=noop_scaled_tensor, ) return out @@ -724,6 +727,7 @@ def quantize_dbias( quantizer: Quantizer, is_dbias: bool = True, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -734,6 +738,8 @@ def quantize_dbias( is_dbias: If True, compute bias gradient. Defaults to True. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. + noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when + quantizer is None. Returns: A tuple containing: @@ -743,7 +749,8 @@ def quantize_dbias( Shape: (K,) or empty if is_dbias is False. """ return _quantize_dbias_impl( - dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis + dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis, + noop_scaled_tensor=noop_scaled_tensor, ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index aa257abe95..553534a205 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -119,6 +119,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right); +// GEMM +XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); + // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index d57d4682ca..02822a3f0a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -4,12 +4,16 @@ * See LICENSE for license information. ************************************************************************/ #include "transformer_engine/gemm.h" +#include "transformer_engine/swizzle.h" #include +#include +#include #include "../extensions.h" #include "common/util/cuda_runtime.h" #include "common/util/system.h" +#include "common/util/string.h" #include "xla/ffi/api/c_api.h" #define MXFP8_BLOCK_SIZE 32 @@ -17,6 +21,175 @@ namespace transformer_engine { namespace jax { +std::tuple> xla_buffer_to_nvte_gemm_operand( + cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv, + JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) { + // Set tensor data with collapsed 2D shape + auto buffer_dims = buffer.dimensions(); + std::vector input_shape = {product(buffer_dims, 0, axis_boundary), + product(buffer_dims, axis_boundary, buffer_dims.size())}; + auto input_dtype = convert_ffi_datatype_to_te_dtype(buffer.element_type()); + TensorWrapper input(get_nvte_scaling_mode(scaling_mode)); + + if (rowwise) { + input.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); + } else { + input.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); + } + + // Set scaling factor for quantized tensors + if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { + NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands."); + NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM."); + + std::vector scale_shape(scale_inv.dimensions().begin(), scale_inv.dimensions().end()); + auto scale_dtype = (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) + ? DType::kFloat8E8M0 + : convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + if (rowwise) { + input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } else { + input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } + + // Swizzle scaling factors for MXFP8 + if (is_block_scaling(scaling_mode)) { + // Get the swizzle buffer + NVTE_CHECK(swizzled_scale_inv->element_count() > 0, + "Missing swizzled inverse scale buffer in the JAX primitive."); + auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + auto swizzled_scale_inv_dtype = convert_ffi_datatype_to_te_dtype( + swizzled_scale_inv->element_type()); + NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1, + "Inverse scale factors need to have an 8-bit data type."); + + // Create tensor to hold swizzled scale factor + TensorWrapper output(NVTE_MXFP8_1D_SCALING); + if (rowwise) { + output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + scale_shape); + } else { + output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + scale_shape); + } + + // Launch swizzle kernel + nvte_swizzle_scaling_factors(input.data(), output.data(), stream); + + // Set swizzled scales into the input tensor + if (rowwise) { + input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + scale_shape); + } else { + input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + scale_shape); + } + } + } + + return std::make_tuple(std::move(input), input_shape); +} + +Error_Type GemmFFI( + cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Result_Type output, + Result_Type bias_grad, Result_Type pre_gelu_out, Result_Type lhs_swizzle, + Result_Type rhs_swizzle, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, + int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, + bool fuse_bias, bool fuse_gelu, bool grad) { + // Operands (this includes swizzling MXFP8 scaling factors) + // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when + // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) + bool always_rowwise = ( + scaling_mode == JAXX_Scaling_Mode::NO_SCALING || ( + is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); + bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; + bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed; + auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( + stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise); + auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand( + stream, rhs, rhs_scale_inv, lhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); + + // Output tensor + std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], + (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; + auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), "cuBLAS GEMM output buffer size is incorrect, " + "expected ", out_.numel(), " elements ", to_string_like(out_shape), " but got ", + output->element_count(), " elements ", to_string_like(output->dimensions())); + + // Bias input to forward pass or bias gradient output from backward pass + void* bias_ptr = nullptr; + std::vector bias_shape = {0}; + DType bias_dtype = out_dtype; + if (fuse_bias) { + if (!grad) { + NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(), + "Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad"); + } + bias_ptr = bias_grad->untyped_data(); + bias_shape.at(0) = bias_grad->dimensions().front(); + bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type()); + } + auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + + // Pre-GeLU output from forward pass or input to backward pass + void* pre_gelu_ptr = nullptr; + std::vector pre_gelu_shape = {0}; + DType pre_gelu_dtype = out_dtype; + if (gelu_input.element_count() > 0) { + if (grad) { + NVTE_CHECK(pre_gelu_out->untyped_data() == gelu_input.untyped_data(), + "Missing operand-output aliasing in GemmPrimitive: gelu_input <-> pre_gelu_out"); + } + pre_gelu_ptr = pre_gelu_out->untyped_data(); + pre_gelu_shape = {product(pre_gelu_out->dimensions(), 0, pre_gelu_out->dimensions().size() - 1), + static_cast(pre_gelu_out->dimensions().back())}; + pre_gelu_dtype = convert_ffi_datatype_to_te_dtype(pre_gelu_out->element_type()); + } + auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype); + + // cuBLAS workspace + std::vector workspace_shape = {static_cast(workspace->element_count())}; + auto workspace_ = TensorWrapper(workspace->untyped_data(), workspace_shape, DType::kByte); + + // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), + rhs_transposed, lhs_transposed, grad, workspace_.data(), false, false, + num_math_sm, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Ret() // output + .Ret() // bias_grad + .Ret() // pre_gelu_out + .Ret() // lhs_swizzled + .Ret() // rhs_swizzled + .Ret() // workspace + .Attr("scaling_mode") + .Attr("lhs_axis_boundary") + .Attr("rhs_axis_boundary") + .Attr("lhs_transposed") + .Attr("rhs_transposed") + .Attr("fuse_bias") + .Attr("fuse_gelu") + .Attr("grad"), + FFI_CudaGraph_Traits); + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 03194e9d72..b4ed3399d9 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -47,6 +47,15 @@ enum class JAXX_Scaling_Mode : int64_t { CURRENT_TENSOR_SCALING = 3, }; +inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) { + return (mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING + || mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING); +} + +inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) { + return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING); +} + static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { switch (mode) { case JAXX_Scaling_Mode::NO_SCALING: diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 2d7801cc20..afbeb644c1 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -55,6 +55,11 @@ pybind11::dict Registrations() { pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler)); + // GEMM + dict["te_gemm_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); + // Grouped GEMM dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), @@ -78,6 +83,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); + m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 8834f4f73c..0fe8d1bb9a 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -48,7 +48,7 @@ def dense( Transformed output tensor """ # Remove when tex.quantize() can handle quantizer=None - if quantizer_set == noop_quantizer_set: + if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot(): x = with_sharding_constraint_by_logical_axes(x, input_axes) output = tex.gemm(x, kernel, contracting_dims) if bias is not None: @@ -90,39 +90,50 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, Returns: Tuple of (output, context) for backward pass """ - x_contracting_dims, k_contracting_dims = contracting_dims - - flatten_axis_x = -len(x_contracting_dims) - flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + x_cdims, k_cdims = map( + tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims + ) + x_is_transposed = x.ndim - 1 not in x_cdims + k_is_transposed = kernel.ndim - 1 in k_cdims + assert not x_is_transposed and not k_is_transposed, ( + "Forward-mode Dense layer implementation only supports 'NN' layout inputs." + ) - casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x) + casted_x = tex.quantize(x, flatten_axis=min(x_cdims), quantizer=quantizer_set.x, + noop_scaled_tensor=True) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) + rowwise_x = casted_x.get_rowwise_tensor() + colwise_x = casted_x.get_colwise_tensor() casted_kernel = tex.quantize( - kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel + kernel, flatten_axis=max(k_cdims) + 1, quantizer=quantizer_set.kernel, + noop_scaled_tensor=True ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) + colwise_k = casted_kernel.get_colwise_tensor() + rowwise_k = casted_kernel.get_rowwise_tensor() - # GEMM NN + # FPROP GEMM: (batch..., hidden_in) x (hidden_in, hidden_out) = (batch..., hidden_out) + # FPROP FP8 GEMM: (batch..., hidden_in) x (hidden_out, hidden_in)^T = (batch..., hidden_out) + use_bias = bias is not None output = tex.gemm( - casted_x.get_rowwise_tensor(), - casted_kernel.get_colwise_tensor(), - (x_contracting_dims, k_contracting_dims), + rowwise_x, + colwise_k, + bias=bias if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + contracting_dims=(x_cdims, k_cdims), + grad=False, ) - use_bias = bias is not None - if use_bias: + if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) ctx = ( - casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None, - casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None, - x.shape, - kernel.shape, + colwise_x, + rowwise_k, use_bias, quantizer_set, - flatten_axis_k, ) return output, ctx @@ -135,46 +146,49 @@ def _dense_bwd_rule( Returns: Tuple of gradients with respect to inputs """ - fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims - ( - colwise_casted_x, - rowwise_casted_kernel, - x_shape, - kernel_shape, + colwise_x, + rowwise_k, use_bias, quantizer_set, - flatten_axis_k, ) = ctx + # Original non-contracting dimensions in the forward pass are contracting dimensions for the + # backward pass. + fwd_x_cdims, fwd_k_cdims = map( + tex.sanitize_dims, (colwise_x.ndim, rowwise_k.ndim), contracting_dims + ) + fwd_x_non_cdims = tex.get_non_contracting_dims(colwise_x.ndim, fwd_x_cdims) + fwd_k_non_cdims = tex.get_non_contracting_dims(rowwise_k.ndim, fwd_k_cdims) + # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input + flatten_axis_grad = len(fwd_x_non_cdims) casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad + grad, is_dbias=use_bias, flatten_axis=flatten_axis_grad, quantizer=quantizer_set.dgrad, + noop_scaled_tensor=True, ) - # GEMM NT - # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim - g_contracting_dim = tuple( - range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) - ) - # k_non_contracting_dims - k_contracting_dim = tuple( - dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims - ) + # Prepare DGRAD and WGRAD operands and contracting dims + rowwise_g = casted_grad.get_rowwise_tensor() + rowwise_g_cdims = tuple(range(flatten_axis_grad, grad.ndim)) + colwise_g = casted_grad.get_colwise_tensor() + colwise_g_cdims = tex.get_non_contracting_dims(grad.ndim, rowwise_g_cdims) + + # DGRAD GEMM: (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) dgrad = tex.gemm( - casted_grad.get_rowwise_tensor(), - rowwise_casted_kernel, - (g_contracting_dim, k_contracting_dim), + rowwise_g, + rowwise_k, + contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), + grad=True ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) - # GEMM TN - # x_non_contracting_dims - g_contracting_dim = x_contracting_dim = tuple( - range(0, len(x_shape) - len(fwd_x_contracting_dims)) - ) - + # WGRAD GEMM: (batch..., hidden_in)^T x (batch..., hidden_out) = (hidden_in, hidden_out) + # WGRAD FP8 GEMM: (hidden_in, batch...) x (hidden_out, batch...)^T = (hidden_in, hidden_out) wgrad = tex.gemm( - colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim) + colwise_x, + colwise_g, + contracting_dims=(fwd_x_non_cdims, colwise_g_cdims), + grad=True, ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 727ff78c2d..26b9d8fad5 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -173,12 +173,12 @@ def _layernorm_dense_fwd_rule( Returns: Tuple of (output, context) for automatic differentiation """ - x_contracting_dims = (len(x.shape) - 1,) - k_contracting_dims = (0,) + x_cdims = (x.ndim - 1,) + k_cdims = (0,) assert x.shape[-1] == kernel.shape[0] + # Apply layernorm with quantized output if quantizer_set is given x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) - casted_ln_out, mu, rsigma = tex.normalization_fwd( x, gamma, @@ -187,42 +187,50 @@ def _layernorm_dense_fwd_rule( epsilon, norm_type, quantizer_set.x, + noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) - # Kernel in (hidden_in, hidden_out...) - flatten_axis = 1 - len(kernel.shape) - casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel) + # Layernorm output (batch..., hidden_in) + rowwise_ln_out = casted_ln_out.get_rowwise_tensor() + colwise_ln_out = casted_ln_out.get_colwise_tensor() + + # Kernel (hidden_in, hidden_out) + flatten_axis = 1 - kernel.ndim + casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, + noop_scaled_tensor=True) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) - # NN GEMM - # (batch..., hidden_in) x (hidden_in, hidden_out...) - output = tex.gemm( - casted_ln_out.get_rowwise_tensor(), - casted_kernel.get_colwise_tensor(), - (x_contracting_dims, k_contracting_dims), - ) + rowwise_kernel = casted_kernel.get_rowwise_tensor() + colwise_kernel = casted_kernel.get_colwise_tensor() + # FPROP GEMM: (batch..., hidden_in) x (hidden_in, hidden_out) = (batch..., hidden_out) + # FPROP FP8 GEMM: (batch..., hidden_in) x (hidden_out, hidden_in)^T = (batch..., hidden_out) use_bias = bias is not None - if use_bias: + output = tex.gemm( + rowwise_ln_out, + colwise_kernel, + bias=bias if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + contracting_dims=(x_cdims, k_cdims), + grad=False, + ) + if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) ctx = ( - casted_ln_out.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None, - casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None, - x.shape, - kernel.shape, + colwise_ln_out, + rowwise_kernel, mu, rsigma, x, gamma, beta, - x_contracting_dims, - k_contracting_dims, + x_cdims, + k_cdims, use_bias, quantizer_set, - flatten_axis, ) return output, ctx @@ -233,7 +241,7 @@ def _layernorm_dense_bwd_rule( zero_centered_gamma, epsilon, layernorm_input_axes, - dot_input_axes, # pylint: disable=unused-argument + dot_input_axes, kernel_axes, ctx, grad, @@ -250,57 +258,57 @@ def _layernorm_dense_bwd_rule( Tuple of gradients for all input parameters """ ( - colwise_casted_ln_out, - rowwise_casted_kernel, - x_shape, - kernel_shape, + colwise_ln_out, + rowwise_kernel, mu, rsigma, x, gamma, beta, - x_contracting_dims_in_fwd, - k_contracting_dims_in_fwd, + fwd_x_cdims, + fwd_k_cdims, use_bias, quantizer_set, - flatten_axis, ) = ctx + # Original non-contracting dimensions in the forward pass are contracting dimensions for the + # backward pass. + fwd_x_non_cdims = tex.get_non_contracting_dims(colwise_ln_out.ndim, fwd_x_cdims) + fwd_k_non_cdims = tex.get_non_contracting_dims(rowwise_kernel.ndim, fwd_k_cdims) + # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input + flatten_axis_grad = len(fwd_x_non_cdims) casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad + grad, is_dbias=use_bias, flatten_axis=flatten_axis_grad, quantizer=quantizer_set.dgrad, + noop_scaled_tensor=True, ) - # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim - g_constracting_dim = tuple( - range(grad.ndim - len(kernel_shape) + len(k_contracting_dims_in_fwd), grad.ndim) - ) - # k_non_contracting_dims - k_constracting_dim = tuple( - dim for dim in range(len(kernel_shape)) if dim not in k_contracting_dims_in_fwd - ) + # Prepare DGRAD and WGRAD operands and contracting dims + rowwise_g = casted_grad.get_rowwise_tensor() + rowwise_g_cdims = tuple(range(flatten_axis_grad, grad.ndim)) + colwise_g = casted_grad.get_colwise_tensor() + colwise_ln_out_cdims = fwd_x_non_cdims + colwise_g_cdims = tex.get_non_contracting_dims(grad.ndim, rowwise_g_cdims) - # NT GEMM + # DGRAD GEMM: (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) dgrad = tex.gemm( - casted_grad.get_rowwise_tensor(), - rowwise_casted_kernel, - (g_constracting_dim, k_constracting_dim), - ) - - dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) - - g_constracting_dim = x_constracting_dim = tuple( - range(0, len(x_shape) - len(x_contracting_dims_in_fwd)) + rowwise_g, + rowwise_kernel, + contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), + grad=True ) + dgrad = with_sharding_constraint_by_logical_axes(dgrad, dot_input_axes) - # TN GEMM + # WGRAD GEMM: (batch..., hidden_in)^T x (batch..., hidden_out) = (hidden_in, hidden_out) + # WGRAD FP8 GEMM: (hidden_in, batch...) x (hidden_out, batch...)^T = (hidden_in, hidden_out) wgrad = tex.gemm( - colwise_casted_ln_out, - casted_grad.get_colwise_tensor(), - (x_constracting_dim, g_constracting_dim), + colwise_ln_out, + colwise_g, + contracting_dims=(colwise_ln_out_cdims, colwise_g_cdims), + grad=True, ) - wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) + # Layernorm gradient dx, dgamma, dbeta = tex.normalization_bwd( dgrad, x, @@ -312,6 +320,7 @@ def _layernorm_dense_bwd_rule( epsilon=epsilon, norm_type=norm_type, ) + dx = with_sharding_constraint_by_logical_axes(dx, layernorm_input_axes) return dx, wgrad, dgamma, dbeta, dbias, quantizer_set diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index e04b930233..8a42555a5e 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -22,7 +22,11 @@ from . import cpp_extensions as tex from .layernorm import canonicalize_norm_type -from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set +from .quantize import ( + with_sharding_constraint_by_logical_axes, + QuantizerSet, + noop_quantizer_set +) from .sharding import get_non_contracting_logical_axes @@ -244,16 +248,16 @@ def _layernorm_mlp_fwd_rule( assert len(kernel_2.shape) == 2 assert kernel_1.shape[-2] == len(activation_type) - x_contracting_dims = (len(x.shape) - 1,) - k_contracting_dims = (0,) + x_cdims = (x.ndim - 1,) + k_cdims = (0,) - assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] + assert x.shape[x_cdims[0]] == kernel_1.shape[k_cdims[0]] use_bias_1 = bias_1 is not None use_bias_2 = bias_1 is not None + # Apply layernorm with quantized output if quantizer_set is given x = with_sharding_constraint_by_logical_axes(x, norm_input_axes) - casted_ln_out, mu, rsigma = tex.normalization_fwd( x, gamma, @@ -262,49 +266,77 @@ def _layernorm_mlp_fwd_rule( epsilon, norm_type, quantizer=ffn1_quantizer_set.x, + noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) - casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel) + # FC1 kernel (hidden_in, act_len, hidden_out) + casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, + noop_scaled_tensor=True) + + # Prepare FC1 FPROP operands and layouts + rowwise_ln_out = casted_ln_out.get_rowwise_tensor() + rowwise_kernel_1 = casted_kernel_1.get_rowwise_tensor() + colwise_ln_out = casted_ln_out.get_colwise_tensor() + colwise_kernel_1 = casted_kernel_1.get_colwise_tensor() - # NN GEMM - # (batch..., hidden_in) x (hidden_in, hidden_out) + # FC1 GEMM: + # (batch..., hidden_in) x (hidden_in, act_len, hidden_out) = (batch..., act_len, hidden_out) + # FC1 FP8 GEMM: + # (batch..., hidden_in) x (hidden_out, act_len, hidden_in)^T = (batch..., act_len, hidden_out) + use_bias_1 = bias_1 is not None dot_1_output = tex.gemm( - casted_ln_out.get_rowwise_tensor(), - casted_kernel_1.get_colwise_tensor(), - (x_contracting_dims, k_contracting_dims), + rowwise_ln_out, + colwise_kernel_1, + bias=bias_1 if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, + contracting_dims=(x_cdims, k_cdims), + grad=False, ) if dot_1_input_axes is not None and kernel_1_axes is not None: dot_1_output_axes = ( - *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims), - *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims), + *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_cdims), + *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_cdims), ) dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) - if use_bias_1: + if use_bias_1 and tex.gemm_uses_jax_dot(): bias_1_shape = bias_1.shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape dot_1_output += jnp.reshape(bias_1, bias_1_new_shape) dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) - # (batch..., hidden_in) -> (batch..., hidden) - casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x) - + # Activation (batch..., act_len, hidden_out) -> (batch..., hidden_out) + casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, + noop_scaled_tensor=True) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) - casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel) + # FC2 kernel (hidden_out, hidden_in) + casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel, + noop_scaled_tensor=True) - # NN GEMM - # (batch..., hidden_in) x (hidden_out, hidden_in) + # Prepare FC2 FPROP operands and layouts + rowwise_act_out = casted_act_out.get_rowwise_tensor() + rowwise_kernel_2 = casted_kernel_2.get_rowwise_tensor() + colwise_act_out = casted_act_out.get_colwise_tensor() + colwise_kernel_2 = casted_kernel_2.get_colwise_tensor() + + # FC2 GEMM: + # (batch..., hidden_out) x (hidden_out, hidden_in) = (batch..., hidden_in) + # FC2 FP8 GEMM: + # (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) dot_2_output = tex.gemm( - casted_act_out.get_rowwise_tensor(), - casted_kernel_2.get_colwise_tensor(), - (x_contracting_dims, k_contracting_dims), + rowwise_act_out, + colwise_kernel_2, + bias=bias_2 if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, + contracting_dims=(x_cdims, k_cdims), + grad=False ) - if use_bias_2: + if use_bias_2 and tex.gemm_uses_jax_dot(): bias_2_shape = bias_2.shape bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape dot_2_output += jnp.reshape(bias_2, bias_2_new_shape) @@ -317,15 +349,13 @@ def _layernorm_mlp_fwd_rule( rsigma, gamma, beta, - casted_ln_out.get_colwise_tensor(), - casted_kernel_1.get_rowwise_tensor(), + colwise_ln_out, + rowwise_kernel_1, dot_1_output, - casted_act_out.get_colwise_tensor(), - casted_kernel_2.get_rowwise_tensor(), - x_contracting_dims, - k_contracting_dims, - kernel_1.shape, - kernel_2.shape, + colwise_act_out, + rowwise_kernel_2, + x_cdims, + k_cdims, use_bias_1, use_bias_2, quantizer_sets, @@ -362,22 +392,20 @@ def _layernorm_mlp_bwd_rule( Returns: Tuple of gradients for all input parameters """ - del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name + del ffn1_ckpt_name, ffn2_ckpt_name ( x, mu, rsigma, gamma, beta, - colwise_casted_ln_out, - rowwise_casted_kernel_1, + colwise_ln_out, + rowwise_kernel_1, dot_1_output, - colwise_casted_act_out, - rowwise_casted_kernel_2, - x_contracting_dims_in_fwd, - k_contracting_dims_in_fwd, - kernel_1_shape, - kernel_2_shape, + colwise_act_out, + rowwise_kernel_2, + fwd_x_cdims, + fwd_k_cdims, use_bias_1, use_bias_2, quantizer_sets, @@ -385,82 +413,85 @@ def _layernorm_mlp_bwd_rule( ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets - # Since the sharding of outputs should be the same as dot_1's input + # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input + fwd_x_non_cdims = tex.get_non_contracting_dims(colwise_ln_out.ndim, fwd_x_cdims) + flatten_axis_grad = len(fwd_x_non_cdims) grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) - casted_grad, dbias_2 = tex.quantize_dbias( - grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad + grad, is_dbias=use_bias_2, flatten_axis=flatten_axis_grad, + quantizer=ffn2_quantizer_set.dgrad, noop_scaled_tensor=True, ) - # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim - g_contracting_dims_2 = tuple( - range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim) - ) - # k_non_contracting_dims - k_contracting_dims_2 = tuple( - dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd - ) + # Prepare FC2 DGRAD and WGRAD operands and contracting dims + rowwise_g = casted_grad.get_rowwise_tensor() + rowwise_g_cdims = tuple(range(flatten_axis_grad, grad.ndim)) + fwd_k2_non_cdims = tex.get_non_contracting_dims(rowwise_kernel_2.ndim, fwd_k_cdims) - # NT GEMM - # (batch..., hidden_out) x (hidden_in, hidden_out) + colwise_g = casted_grad.get_colwise_tensor() + colwise_g_cdims = tex.get_non_contracting_dims(grad.ndim, rowwise_g_cdims) + colwise_act_out_cdims = tex.get_non_contracting_dims(colwise_act_out.ndim, fwd_x_cdims) + + # FC2 DGRAD GEMM: (batch..., hidden_in) x (hidden_out, hidden_in)^T = (batch..., hidden_out) dgrad_2 = tex.gemm( - casted_grad.get_rowwise_tensor(), - rowwise_casted_kernel_2, - (g_contracting_dims_2, k_contracting_dims_2), + rowwise_g, + rowwise_kernel_2, + contracting_dims=(rowwise_g_cdims, fwd_k2_non_cdims), + grad=True ) - dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) - x_contracting_dims = g_contracting_dims = tuple( - range(0, len(x.shape) - len(x_contracting_dims_in_fwd)) - ) - - # TN GEMM - # (hidden, batch...,) x (hidden, batch...) + # FC2 WGRAD GEMM: + # (batch..., hidden_out)^T x (batch..., hidden_in) = (hidden_out, hidden_in) + # FC2 WGRAD FP8 GEMM: + # (hidden_out, batch...) x (hidden_in, batch...)^T = (hidden_out, hidden_in) wgrad_2 = tex.gemm( - colwise_casted_act_out, - casted_grad.get_colwise_tensor(), - (x_contracting_dims, g_contracting_dims), + colwise_act_out, + colwise_g, + contracting_dims=(colwise_act_out_cdims, colwise_g_cdims), + grad=True, ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) + # Activation gradient w/ bias fusion (batch..., hidden_out) -> (batch.., act_len, hidden_out) casted_dact_out, dbias_1 = tex.quantize_dact_dbias( dgrad_2, dot_1_output, activation_type=activation_type, is_dbias=use_bias_1, - quantizer=ffn2_quantizer_set.dgrad, + quantizer=ffn1_quantizer_set.dgrad, + noop_scaled_tensor=True, ) - # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim - dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim - g_contracting_dims_1 = tuple( - range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim) - ) - # k_non_contracting_dims - k_contracting_dims_1 = tuple( - dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd - ) + # Prepare FC1 DGRAD and WGRAD operands and contracting dims + rowwise_dact_out = casted_dact_out.get_rowwise_tensor() + rowwise_dact_out_cdims = tuple(range(flatten_axis_grad, rowwise_dact_out.ndim)) + colwise_dact_out = casted_dact_out.get_colwise_tensor() + colwise_dact_out_cdims = tex.get_non_contracting_dims(casted_dact_out.ndim, rowwise_dact_out_cdims) + fwd_k1_non_cdims = tex.get_non_contracting_dims(rowwise_kernel_1.ndim, fwd_k_cdims) - # NT GEMM + # FC1 DGRAD GEMM: + # (batch..., act_len, hidden_out) x (hidden_in, act_len, hidden_out)^T = (batch..., hidden_in) dgrad_1 = tex.gemm( - casted_dact_out.get_rowwise_tensor(), - rowwise_casted_kernel_1, - (g_contracting_dims_1, k_contracting_dims_1), + rowwise_dact_out, + rowwise_kernel_1, + contracting_dims=(rowwise_dact_out_cdims, fwd_k1_non_cdims), + grad=True ) - dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) - # TN GEMM - # (hidden, batch...) x (hidden, batch...) + # FC1 WGRAD GEMM: + # (batch..., hidden_in)^T x (batch..., act_len, hidden_out) = (hidden_in, act_len, hidden_out) + # FC1 WGRAD FP8 GEMM: + # (hidden_in, batch...) x (hidden_out, act_len, batch...)^T = (hidden_in, act_len, hidden_out) wgrad_1 = tex.gemm( - colwise_casted_ln_out, - casted_dact_out.get_colwise_tensor(), - (x_contracting_dims, g_contracting_dims), + colwise_ln_out, + colwise_dact_out, + contracting_dims=(fwd_x_non_cdims, colwise_dact_out_cdims), + grad=True, ) - wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) + # Layernorm gradient dx, dgamma, dbeta = tex.normalization_bwd( dgrad_1, x, @@ -472,6 +503,7 @@ def _layernorm_mlp_bwd_rule( epsilon=epsilon, norm_type=norm_type, ) + dx = with_sharding_constraint_by_logical_axes(dx, norm_input_axes) return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets) diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 06a2562fb1..2459190f1a 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -36,6 +36,22 @@ def dequantize(scaled_tensor): """Dequantizing given tensor to higher precision.""" +@dataclass +class NoopDequantizer(Dequantizer): + """No-op Dequantizer Class""" + + @staticmethod + def _dequantize_func(data, *args, **kwargs): + """A no-op dequantize function that returns the data without any changes.""" + del args, kwargs + return data + + @staticmethod + def dequantize(scaled_tensor): + """A no-op dequantize function that simply returns the data array in the ScaledTensor.""" + return scaled_tensor.data + + class TensorScaleDequantizer(Dequantizer): """ TensorScaling Dequantizer Class @@ -152,6 +168,7 @@ def dequantize(scaled_tensor): ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer, + ScalingMode.NO_SCALING: NoopDequantizer, } diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 02b1a1a99e..4761816220 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -55,6 +55,11 @@ def tree_unflatten(cls, aux_data, children): """ return cls(*children, *aux_data) + @property + @abstractmethod + def ndim(self): + """Number of dimensions of the underlying quantized array.""" + @abstractmethod def dequantize(self): """Dequantizes the tensor back to its original precision. @@ -136,25 +141,40 @@ def __post_init__(self): 0 < self.flatten_axis < len(self.data.shape) ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" - expected_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis - ) - expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis - ) - if self.scale_inv.shape != expected_scale_shape: - assert self.scale_inv.shape == expected_unpadded_scale_shape, ( - f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" - f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" - f" {self.scale_inv.shape}" - ) - pad_width = tuple( - (0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) + if self.scaling_mode == ScalingMode.NO_SCALING: + self.scale_inv = jnp.empty((1, ), dtype=jnp.float32) + + else: + expected_scale_shape = self.scaling_mode.get_scale_shape( + self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis ) - # This actually pad scale_inv with nan, should we pad it with 127 directly instead? - self.scale_inv = jnp.pad( - self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0 + expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( + self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis ) + if self.scale_inv.shape != expected_scale_shape: + assert self.scale_inv.shape == expected_unpadded_scale_shape, ( + f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" + f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" + f" {self.scale_inv.shape}" + ) + expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( + self.data.shape, self.is_colwise, is_padded=False, + flatten_axis=self.flatten_axis + ) + if self.scale_inv.shape != expected_scale_shape: + assert self.scale_inv.shape == expected_unpadded_scale_shape, ( + f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" + f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" + f" {self.scale_inv.shape}" + ) + pad_width = tuple( + (0, a - b) for a, b in zip(expected_scale_shape, + expected_unpadded_scale_shape) + ) + # This actually pad scale_inv with nan, should we pad it with 127 directly instead? + self.scale_inv = jnp.pad( + self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0 + ) def tree_flatten(self): """Flattens the tensor for JAX tree operations. @@ -173,6 +193,10 @@ def tree_flatten(self): ) return (children, aux_data) + @property + def ndim(self): + return self.data.ndim + def dequantize(self): """Dequantizes the tensor using the stored dequantization function. @@ -370,6 +394,11 @@ def tree_flatten(self): aux_data = () return (children, aux_data) + @property + def ndim(self): + """Number of dimensions of the underlying row-wise tensor.""" + return self.rowwise_tensor.ndim + def dequantize(self): """Dequantizes the tensor using the row-wise component's dequantization. From da0709afa8fa727a0c2451a2d35d5ac7aa2f9559 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 13 Jun 2025 05:02:42 +0000 Subject: [PATCH 02/27] minor unit test cleanup Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 5 ++--- transformer_engine/jax/cpp_extensions/quantization.py | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 5a59d113d5..a87d12bec1 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -45,7 +45,6 @@ from transformer_engine.jax.activation import activation from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense -from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x from transformer_engine_jax import is_non_nt_fp8_gemm_supported @@ -1133,7 +1132,7 @@ def ref_func(x, w, gamma, beta): @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) - @pytest.mark.parametrize("use_bias", [True, False]) + @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm @@ -1141,7 +1140,7 @@ def test_layernorm_mlp_grad( """ Test layernorm_mlp VJP Rule """ - + use_jax_dot_for_gemm(enabled=with_jax_gemm) # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index e84fb3a976..4bb4803fcb 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -36,7 +36,6 @@ Quantizer, GroupedQuantizer, QuantizeLayout, - DelayedScaleQuantizer, ScalingMode, compute_scale_from_amax, ) From e5b933c064466ef197558ffe7c31b2dd5f8348b9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Jun 2025 05:03:15 +0000 Subject: [PATCH 03/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 54 ++-- tests/jax/test_layer.py | 6 +- .../jax/cpp_extensions/activation.py | 19 +- transformer_engine/jax/cpp_extensions/gemm.py | 239 +++++++++++------- .../jax/cpp_extensions/normalization.py | 10 +- .../jax/cpp_extensions/quantization.py | 27 +- .../jax/csrc/extensions/gemm.cpp | 65 ++--- transformer_engine/jax/csrc/extensions/misc.h | 4 +- transformer_engine/jax/dense.py | 31 +-- transformer_engine/jax/layernorm_dense.py | 15 +- transformer_engine/jax/layernorm_mlp.py | 41 +-- transformer_engine/jax/quantize/tensor.py | 12 +- 12 files changed, 306 insertions(+), 217 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index a87d12bec1..002a61c715 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -156,9 +156,9 @@ def assert_dequantized_grouped_scaled_tensor( def use_jax_dot_for_gemm(enabled=False): if enabled: - os.environ['NVTE_JAX_CUSTOM_CALLS_RE']='^(?!GemmPrimitive$).+$' - elif 'NVTE_JAX_CUSTOM_CALLS_RE' in os.environ: - os.environ.pop('NVTE_JAX_CUSTOM_CALLS_RE') + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)] @@ -731,9 +731,7 @@ def test_quantize_dbias( ) te_output, te_dbias = jit( - lambda input: tex.quantize_dbias( - inp, quantizer=te_quantizer, flatten_axis=flatten_axis - ) + lambda input: tex.quantize_dbias(inp, quantizer=te_quantizer, flatten_axis=flatten_axis) )(inp) jax_output, jax_dbias = jit( @@ -911,44 +909,42 @@ def test_gemm_bf16(self, m, n, k, data_layout, with_jax_gemm): "lhs_q_dtype,rhs_q_dtype", [ (jnp.float8_e4m3fn, jnp.float8_e4m3fn), # fprop GEMM - (jnp.float8_e4m3fn, jnp.float8_e5m2), # wgrad GEMM - (jnp.float8_e5m2, jnp.float8_e4m3fn), # dgrad GEMM - ] + (jnp.float8_e4m3fn, jnp.float8_e5m2), # wgrad GEMM + (jnp.float8_e5m2, jnp.float8_e4m3fn), # dgrad GEMM + ], ) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("data_layout", supported_fp8_gemm_layouts) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_gemm_fp8(self, m, n, k, lhs_q_dtype, rhs_q_dtype, scaling_mode, data_layout, - with_jax_gemm): + def test_gemm_fp8( + self, m, n, k, lhs_q_dtype, rhs_q_dtype, scaling_mode, data_layout, with_jax_gemm + ): use_jax_dot_for_gemm(enabled=with_jax_gemm) lhs, rhs, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, bwd_dtype=jnp.float8_e5m2, - is_2x2x=False - ) - lhs_quantizer = ( - quantizer_set.x - if lhs_q_dtype == jnp.float8_e4m3fn - else quantizer_set.dgrad + scaling_mode=scaling_mode, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, + is_2x2x=False, ) + lhs_quantizer = quantizer_set.x if lhs_q_dtype == jnp.float8_e4m3fn else quantizer_set.dgrad rhs_quantizer = ( - quantizer_set.kernel - if rhs_q_dtype == jnp.float8_e4m3fn - else quantizer_set.dgrad + quantizer_set.kernel if rhs_q_dtype == jnp.float8_e4m3fn else quantizer_set.dgrad ) primitive_out = tex.gemm( - lhs, rhs, lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, - contracting_dims=contracting_dims + lhs, + rhs, + lhs_quantizer=lhs_quantizer, + rhs_quantizer=rhs_quantizer, + contracting_dims=contracting_dims, ) ref_out = self._ref_gemm_with_jnp_dot(lhs, rhs, data_layout) test_q_dtype = ( - jnp.float8_e5m2 - if jnp.float8_e5m2 in (lhs_q_dtype, rhs_q_dtype) - else jnp.float8_e4m3fn + jnp.float8_e5m2 if jnp.float8_e5m2 in (lhs_q_dtype, rhs_q_dtype) else jnp.float8_e4m3fn ) assert_allclose(primitive_out, ref_out, dtype=test_q_dtype) @@ -1008,8 +1004,10 @@ def ref_func(x, w, bias, data_layout): value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, bwd_dtype=jnp.float8_e5m2, - is_2x2x=True + scaling_mode=scaling_mode, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, + is_2x2x=True, ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index ab79e2eae4..389148415a 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -490,9 +490,9 @@ class BaseTester: def use_jax_dot_for_gemm(self, enabled=False): """Enable/disable TE custom cuBLAS GEMM primitive.""" if enabled: - os.environ['NVTE_JAX_CUSTOM_CALLS_RE']='^(?!GemmPrimitive$).+$' - elif 'NVTE_JAX_CUSTOM_CALLS_RE' in os.environ: - os.environ.pop('NVTE_JAX_CUSTOM_CALLS_RE') + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 4d0a4c1bbd..341dcb0c8c 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1124,9 +1124,8 @@ def quantize_dact_dbias( scale = jnp.empty((), jnp.float32) act_type_id = ActivationEnum[activation_type] PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive - if ( - not PrimitiveClass.enabled() - or (quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE) + if not PrimitiveClass.enabled() or ( + quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE ): return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) @@ -1152,9 +1151,17 @@ def quantize_dact_dbias( dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) if noop_scaled_tensor: - return ScaledTensorFactory.create_2x( - output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype, - ), dbias + return ( + ScaledTensorFactory.create_2x( + output, + None, + output, + None, + ScalingMode.NO_SCALING, + dq_dtype=output.dtype, + ), + dbias, + ) return output, dbias diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d553da10f3..5f1218872a 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -62,10 +62,10 @@ def is_gemm_with_all_layouts_supported() -> bool: def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]: """Convert relative (negative) indexes to absolute dimension numbers.""" - dims_ = dims if isinstance(dims, Iterable) else (dims, ) + dims_ = dims if isinstance(dims, Iterable) else (dims,) if len(dims_) == 0: return dims_ - return tuple( ndim + dim if dim < 0 else dim for dim in dims_ ) + return tuple(ndim + dim if dim < 0 else dim for dim in dims_) def get_non_contracting_dims(ndim, contracting_dims): @@ -88,8 +88,8 @@ def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool: rhs_dtype, jnp.float8_e4m3fn, jnp.float8_e5m2, - jnp.uint8 # replace with jnp.float8_e8m0 when JAX/XLA merges support - ) + jnp.uint8, # replace with jnp.float8_e8m0 when JAX/XLA merges support + ), ) # MXFP8 GEMM needs both operands to be MXFP8 (uint8 for now until JAX merges float8_e8m0) @@ -105,8 +105,7 @@ def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool: def _get_gemm_layout( - operand_ndims: Tuple[int, int], - contracting_dims: Tuple[Sequence[int], Sequence[int]] + operand_ndims: Tuple[int, int], contracting_dims: Tuple[Sequence[int], Sequence[int]] ) -> Tuple[bool, bool]: lhs_contracting, rhs_contracting = map(sanitize_dims, operand_ndims, contracting_dims) lhs_is_transposed = operand_ndims[0] - 1 not in lhs_contracting @@ -128,9 +127,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ is_rowwise=True, is_colwise=False, flatten_axis=( - max(lhs_contracting_dims) + 1 - if lhs_is_transposed - else min(lhs_contracting_dims) + max(lhs_contracting_dims) + 1 if lhs_is_transposed else min(lhs_contracting_dims) ), ) if lhs_is_transposed: @@ -145,9 +142,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ is_rowwise=True, is_colwise=False, flatten_axis=( - min(rhs_contracting_dims) - if rhs_is_transposed - else max(rhs_contracting_dims) + 1 + min(rhs_contracting_dims) if rhs_is_transposed else max(rhs_contracting_dims) + 1 ), ) if not rhs_is_transposed: @@ -171,8 +166,20 @@ class GemmPrimitive(BasePrimitive): outer_primitive = None @staticmethod - def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype, - contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad): + def abstract( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + ): # Sanity-check operand layouts and types operand_ndims = (lhs.ndim, rhs.ndim) ( @@ -182,7 +189,7 @@ def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype lhs_contracting_size, rhs_contracting_size = map( lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), (lhs.shape, rhs.shape), - (lhs_contracting_dims, rhs_contracting_dims) + (lhs_contracting_dims, rhs_contracting_dims), ) assert lhs_contracting_size == rhs_contracting_size, ( "cuBLAS GEMM operands have incompatible contracting dimensions: " @@ -195,9 +202,9 @@ def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype "cuBLAS GEMM quantized operands have incompatible data types: " f"{lhs.dtype} x {rhs.dtype}." ) - assert lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0, ( - "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." - ) + assert ( + lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0 + ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." if ( scaling_mode != ScalingMode.MXFP8_1D_SCALING and not tex.is_non_nt_fp8_gemm_supported() @@ -209,19 +216,19 @@ def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype ) # Determine output shape and dtype - assert dtypes.canonicalize_dtype(out_dtype).itemsize > 1, ( - "cuBLAS GEMM custom op does not support 8-bit quantized output types." - ) + assert ( + dtypes.canonicalize_dtype(out_dtype).itemsize > 1 + ), "cuBLAS GEMM custom op does not support 8-bit quantized output types." lhs_non_contracting_shape, rhs_non_contracting_shape = map( - lambda shape, dims: [ shape[dim] for dim in range(len(shape)) if dim not in dims ], + lambda shape, dims: [shape[dim] for dim in range(len(shape)) if dim not in dims], (lhs.shape, rhs.shape), - (lhs_contracting_dims, rhs_contracting_dims) + (lhs_contracting_dims, rhs_contracting_dims), ) out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape) output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) # Validate bias - bias_shape = (0, ) + bias_shape = (0,) bias_dtype = out_dtype if fuse_bias: expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape) @@ -240,15 +247,14 @@ def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype) # Validate pre-GeLU - pre_gelu_shape = (0, ) + pre_gelu_shape = (0,) pre_gelu_dtype = out_dtype if fuse_gelu: pre_gelu_shape = out_shape if grad: pre_gelu_ndim = len(pre_gelu_shape) - assert ( - gelu_input.ndim == pre_gelu_shape - and all(gelu_input.shape[i] == pre_gelu_shape[i] for i in range(pre_gelu_ndim)) + assert gelu_input.ndim == pre_gelu_shape and all( + gelu_input.shape[i] == pre_gelu_shape[i] for i in range(pre_gelu_ndim) ), ( "cuBLAS GEMM pre-GeLU tensor has incorrect shape, " f"expected {pre_gelu_shape} but found {gelu_input.shape}." @@ -266,12 +272,13 @@ def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype if scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_swizzle_size = lhs_scale_inv.size rhs_swizzle_size = rhs_scale_inv.size - lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size, ), dtype=swizzle_dtype) - rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size, ), dtype=swizzle_dtype) + lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype) + rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype) # Declare cuBLAS workspace - workspace = jax.core.ShapedArray(shape=(get_cublas_workspace_size_bytes(), ), - dtype=jnp.uint8) + workspace = jax.core.ShapedArray( + shape=(get_cublas_workspace_size_bytes(),), dtype=jnp.uint8 + ) return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace @@ -281,31 +288,45 @@ def outer_abstract(*args, **kwargs): return outputs[:-3] # discard workspace arrays @staticmethod - def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype, - contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad): + def lowering( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + ): del out_dtype lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) - lhs_transposed, rhs_transposed = _get_gemm_layout((lhs_aval.ndim, rhs_aval.ndim), - (lhs_cdims, rhs_cdims)) + lhs_transposed, rhs_transposed = _get_gemm_layout( + (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) + ) args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input) kwargs = { - "scaling_mode" : int(scaling_mode.value), - "lhs_axis_boundary" : max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), - "rhs_axis_boundary" : min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, - "lhs_transposed" : lhs_transposed, - "rhs_transposed" : rhs_transposed, - "fuse_bias" : fuse_bias, - "fuse_gelu" : fuse_gelu, - "grad" : grad, + "scaling_mode": int(scaling_mode.value), + "lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), + "rhs_axis_boundary": min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, + "lhs_transposed": lhs_transposed, + "rhs_transposed": rhs_transposed, + "fuse_bias": fuse_bias, + "fuse_gelu": fuse_gelu, + "grad": grad, } operand_output_aliases = {} if fuse_bias and not grad: - operand_output_aliases.update({ 4 : 1 }) # bias <-> bias_grad + operand_output_aliases.update({4: 1}) # bias <-> bias_grad if fuse_gelu and grad: - operand_output_aliases.update({ 5 : 2 }) # gelu_input <-> pre_gelu_out + operand_output_aliases.update({5: 2}) # gelu_input <-> pre_gelu_out return jax.ffi.ffi_lowering( GemmPrimitive.name, @@ -313,8 +334,20 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ )(ctx, *args, **kwargs) @staticmethod - def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype, contracting_dims, - scaling_mode, fuse_bias, fuse_gelu, grad): + def impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + ): outputs = GemmPrimitive.inner_primitive.bind( lhs, lhs_scale_inv, @@ -332,16 +365,24 @@ def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype, co return outputs[:-3] # discard workspace arrays @staticmethod - def batcher(batched_args, batch_dims, out_dtype, contracting_dims, scaling_mode, fuse_bias, - fuse_gelu, grad): + def batcher( + batched_args, + batch_dims, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + ): assert GemmPrimitive.outer_primitive is not None lhs, _, rhs, *_ = batched_args lhs_bdims, *_ = batch_dims # Output is batched like LHS only if LHS is batched and RHS is not - out_bdims = lhs_bdims if lhs.ndim > 2 and rhs.ndim == 2 else (None, ) - bias_bdims = (None, ) # Bias is never batched - pre_gelu_bdims = (None, ) # Pre-GeLU output, if exists, is batched like GEMM output + out_bdims = lhs_bdims if lhs.ndim > 2 and rhs.ndim == 2 else (None,) + bias_bdims = (None,) # Bias is never batched + pre_gelu_bdims = (None,) # Pre-GeLU output, if exists, is batched like GEMM output if fuse_gelu and not grad: pre_gelu_bdims = out_bdims @@ -355,12 +396,21 @@ def batcher(batched_args, batch_dims, out_dtype, contracting_dims, scaling_mode, fuse_gelu=fuse_gelu, grad=grad, ), - (out_bdims, bias_bdims, pre_gelu_bdims) + (out_bdims, bias_bdims, pre_gelu_bdims), ) @staticmethod - def infer_sharding_from_operands(out_dtype, contracting_dims, scaling_mode, fuse_bias, - fuse_gelu, grad, mesh, arg_infos, result_infos): + def infer_sharding_from_operands( + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + mesh, + arg_infos, + result_infos, + ): del out_dtype, scaling_mode, result_infos # Check contracting dimensions @@ -370,29 +420,29 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, scaling_mode, fuse sanitize_dims, operand_ndims, contracting_dims ) lhs_contracting_specs, rhs_contracting_specs = map( - lambda specs, dims: [ specs[dim] for dim in dims if specs[dim] is not None], + lambda specs, dims: [specs[dim] for dim in dims if specs[dim] is not None], (lhs_spec, rhs_spec), - (lhs_contracting_dims, rhs_contracting_dims) - ) - assert len(lhs_contracting_specs) <= 1 and len(rhs_contracting_specs) <= 1, ( - "cuBLAS GEMM operands can have only one sharded contracting dimension." + (lhs_contracting_dims, rhs_contracting_dims), ) + assert ( + len(lhs_contracting_specs) <= 1 and len(rhs_contracting_specs) <= 1 + ), "cuBLAS GEMM operands can have only one sharded contracting dimension." lhs_contracting_spec, rhs_contracting_spec = map( lambda spec: None if len(spec) == 0 else spec[0], - (lhs_contracting_specs, rhs_contracting_specs) - ) - assert lhs_contracting_spec == rhs_contracting_spec, ( - "cuBLAS GEMM operands must have the same sharding in contracting dimensions." + (lhs_contracting_specs, rhs_contracting_specs), ) + assert ( + lhs_contracting_spec == rhs_contracting_spec + ), "cuBLAS GEMM operands must have the same sharding in contracting dimensions." # Sanity check leading dimensions, allow for simultaneous batch and sequence sharding lhs_leading_dims, rhs_leading_dims = map( get_non_contracting_dims, operand_ndims, (lhs_contracting_dims, rhs_contracting_dims) ) lhs_leading_specs, rhs_leading_specs = map( - lambda specs, dims: [ specs[dim] for dim in dims if specs[dim] is not None ], + lambda specs, dims: [specs[dim] for dim in dims if specs[dim] is not None], (lhs_spec, rhs_spec), - (lhs_leading_dims, rhs_leading_dims) + (lhs_leading_dims, rhs_leading_dims), ) assert len(lhs_leading_specs) <= 1 and len(rhs_leading_specs) <= 1, ( "cuBLAS GEMM operands cannot have more than one sharded leading dimensions. This error " @@ -401,8 +451,7 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, scaling_mode, fuse # Determine output sharding lhs_leading_spec, rhs_leading_spec = map( - lambda spec: None if len(spec) == 0 else spec[0], - (lhs_leading_specs, rhs_leading_specs) + lambda spec: None if len(spec) == 0 else spec[0], (lhs_leading_specs, rhs_leading_specs) ) out_spec = (lhs_leading_spec, rhs_leading_spec) if operand_ndims[0] > 2 and operand_ndims[1] == 2: @@ -411,13 +460,13 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, scaling_mode, fuse out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) # Bias gradient sharding inherits the RHS contracting spec - bias_spec = (None, ) + bias_spec = (None,) if fuse_bias and grad: - bias_spec = (rhs_contracting_spec, ) + bias_spec = (rhs_contracting_spec,) bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) # Pre-GeLU sharding matches output sharding - pre_gelu_spec = (None, ) + pre_gelu_spec = (None,) if fuse_gelu and not grad: pre_gelu_spec = out_spec pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_spec)) @@ -425,11 +474,27 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, scaling_mode, fuse return (out_sharding, bias_sharding, pre_gelu_sharding) @staticmethod - def partition(out_dtype, contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad, - mesh, arg_infos, result_infos): + def partition( + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + mesh, + arg_infos, + result_infos, + ): out_shardings = GemmPrimitive.infer_sharding_from_operands( - out_dtype, contracting_dims, scaling_mode, fuse_bias, fuse_gelu, grad, - mesh, arg_infos, result_infos + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + mesh, + arg_infos, + result_infos, ) output_spec = out_shardings[0].spec @@ -442,13 +507,13 @@ def partition(out_dtype, contracting_dims, scaling_mode, fuse_bias, fuse_gelu, g scale_sharding = NamedSharding(mesh, PartitionSpec(None)) # Bias has to be sharded same as the trailing dimension of the GEMM output - bias_spec = (None, ) + bias_spec = (None,) if fuse_bias and not grad: - bias_spec = (output_spec[-1], ) + bias_spec = (output_spec[-1],) bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) # Pre-GeLU output has to be sharded same as the GEMM output - pre_gelu_spec = (None, ) + pre_gelu_spec = (None,) if fuse_gelu and grad: pre_gelu_spec = output_spec pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_spec)) @@ -481,7 +546,7 @@ def _te_gemm( gelu_input: jax.Array = None, lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1, ), (-2, )), + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (-2,)), fuse_bias: bool = False, fuse_gelu: bool = False, grad: bool = False, @@ -848,9 +913,9 @@ def _jax_gemm_fp8_impl(lhs, rhs): lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) if isinstance(lhs_q, ScaledTensor) or isinstance(rhs_q, ScaledTensor): - assert isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor), ( - "Both LHS and RHS must be quantized (or have valid quantizers) for FP8 GEMM." - ) + assert isinstance(lhs_q, ScaledTensor) and isinstance( + rhs_q, ScaledTensor + ), "Both LHS and RHS must be quantized (or have valid quantizers) for FP8 GEMM." return _jax_gemm_fp8_impl(lhs_q, rhs_q) return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) @@ -939,7 +1004,7 @@ def gemm( lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, contracting_dims=contracting_dims, - **kwargs + **kwargs, ) # Discard empty outputs @@ -948,11 +1013,11 @@ def gemm( grad = kwargs.get("grad", False) clean_outputs = outputs[0] # first output is the final result and is never empty if (fuse_bias and grad) or (fuse_gelu and not grad): - clean_outputs = (outputs[0], ) + clean_outputs = (outputs[0],) if fuse_bias and grad: # only return bias gradient if it exists - clean_outputs += (outputs[1], ) + clean_outputs += (outputs[1],) if fuse_gelu and not grad: # only return pre-GeLU output if it exists - clean_outputs += (outputs[2], ) + clean_outputs += (outputs[2],) return clean_outputs diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index f2829c3a58..bf5c257d7b 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -1322,9 +1322,13 @@ def normalization_fwd( raise ValueError(f"{norm_type=} is not supported.") if quantizer is None and noop_scaled_tensor: - return ScaledTensorFactory.create_2x( - output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype - ), mu, rsigma + return ( + ScaledTensorFactory.create_2x( + output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype + ), + mu, + rsigma, + ) return output, mu, rsigma diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 4bb4803fcb..3cb0e1cdfb 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -587,19 +587,25 @@ def _quantize_dbias_impl( if noop_scaled_tensor: # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor() # always works. - return ScaledTensorFactory.create_2x( - x, None, x, None, ScalingMode.NO_SCALING, dq_dtype=x.dtype, data_layout="NN", - flatten_axis=flatten_axis, - ), dbias + return ( + ScaledTensorFactory.create_2x( + x, + None, + x, + None, + ScalingMode.NO_SCALING, + dq_dtype=x.dtype, + data_layout="NN", + flatten_axis=flatten_axis, + ), + dbias, + ) return x, dbias # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, # fall back on the native-JAX quantize implementation PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive - if ( - quantizer.q_layout == QuantizeLayout.COLWISE - or not PrimitiveClass.enabled() - ): + if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled(): if is_dbias: return _jax_quantize_dbias( x, @@ -748,7 +754,10 @@ def quantize_dbias( Shape: (K,) or empty if is_dbias is False. """ return _quantize_dbias_impl( - dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis, + dz, + quantizer=quantizer, + is_dbias=is_dbias, + flatten_axis=flatten_axis, noop_scaled_tensor=noop_scaled_tensor, ) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 02822a3f0a..7736e0ec50 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -4,7 +4,6 @@ * See LICENSE for license information. ************************************************************************/ #include "transformer_engine/gemm.h" -#include "transformer_engine/swizzle.h" #include #include @@ -12,8 +11,9 @@ #include "../extensions.h" #include "common/util/cuda_runtime.h" -#include "common/util/system.h" #include "common/util/string.h" +#include "common/util/system.h" +#include "transformer_engine/swizzle.h" #include "xla/ffi/api/c_api.h" #define MXFP8_BLOCK_SIZE 32 @@ -44,8 +44,8 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( std::vector scale_shape(scale_inv.dimensions().begin(), scale_inv.dimensions().end()); auto scale_dtype = (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) - ? DType::kFloat8E8M0 - : convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + ? DType::kFloat8E8M0 + : convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); if (rowwise) { input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } else { @@ -58,8 +58,8 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( NVTE_CHECK(swizzled_scale_inv->element_count() > 0, "Missing swizzled inverse scale buffer in the JAX primitive."); auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); - auto swizzled_scale_inv_dtype = convert_ffi_datatype_to_te_dtype( - swizzled_scale_inv->element_type()); + auto swizzled_scale_inv_dtype = + convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type()); NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1, "Inverse scale factors need to have an 8-bit data type."); @@ -92,19 +92,18 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( return std::make_tuple(std::move(input), input_shape); } -Error_Type GemmFFI( - cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, - Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Result_Type output, - Result_Type bias_grad, Result_Type pre_gelu_out, Result_Type lhs_swizzle, - Result_Type rhs_swizzle, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, - int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, - bool fuse_bias, bool fuse_gelu, bool grad) { +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, + Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace, + JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, + int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, + bool fuse_bias, bool fuse_gelu, bool grad) { // Operands (this includes swizzling MXFP8 scaling factors) // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) - bool always_rowwise = ( - scaling_mode == JAXX_Scaling_Mode::NO_SCALING || ( - is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); + bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed; auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( @@ -117,12 +116,14 @@ Error_Type GemmFFI( (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); - NVTE_CHECK(out_.numel() == output->element_count(), "cuBLAS GEMM output buffer size is incorrect, " - "expected ", out_.numel(), " elements ", to_string_like(out_shape), " but got ", + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, " + "expected ", + out_.numel(), " elements ", to_string_like(out_shape), " but got ", output->element_count(), " elements ", to_string_like(output->dimensions())); // Bias input to forward pass or bias gradient output from backward pass - void* bias_ptr = nullptr; + void *bias_ptr = nullptr; std::vector bias_shape = {0}; DType bias_dtype = out_dtype; if (fuse_bias) { @@ -137,7 +138,7 @@ Error_Type GemmFFI( auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); // Pre-GeLU output from forward pass or input to backward pass - void* pre_gelu_ptr = nullptr; + void *pre_gelu_ptr = nullptr; std::vector pre_gelu_shape = {0}; DType pre_gelu_dtype = out_dtype; if (gelu_input.element_count() > 0) { @@ -168,18 +169,18 @@ Error_Type GemmFFI( XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs - .Arg() // lhs_scale_inv - .Arg() // rhs - .Arg() // rhs_scale_inv - .Arg() // bias - .Arg() // gelu_input - .Ret() // output - .Ret() // bias_grad - .Ret() // pre_gelu_out - .Ret() // lhs_swizzled - .Ret() // rhs_swizzled - .Ret() // workspace + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Ret() // output + .Ret() // bias_grad + .Ret() // pre_gelu_out + .Ret() // lhs_swizzled + .Ret() // rhs_swizzled + .Ret() // workspace .Attr("scaling_mode") .Attr("lhs_axis_boundary") .Attr("rhs_axis_boundary") diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index b4ed3399d9..af7f54feb6 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -48,8 +48,8 @@ enum class JAXX_Scaling_Mode : int64_t { }; inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) { - return (mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING - || mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING); + return (mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING || + mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING); } inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) { diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 0fe8d1bb9a..6f94620d99 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -90,24 +90,25 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, Returns: Tuple of (output, context) for backward pass """ - x_cdims, k_cdims = map( - tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims - ) + x_cdims, k_cdims = map(tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims) x_is_transposed = x.ndim - 1 not in x_cdims k_is_transposed = kernel.ndim - 1 in k_cdims - assert not x_is_transposed and not k_is_transposed, ( - "Forward-mode Dense layer implementation only supports 'NN' layout inputs." - ) + assert ( + not x_is_transposed and not k_is_transposed + ), "Forward-mode Dense layer implementation only supports 'NN' layout inputs." - casted_x = tex.quantize(x, flatten_axis=min(x_cdims), quantizer=quantizer_set.x, - noop_scaled_tensor=True) + casted_x = tex.quantize( + x, flatten_axis=min(x_cdims), quantizer=quantizer_set.x, noop_scaled_tensor=True + ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) rowwise_x = casted_x.get_rowwise_tensor() colwise_x = casted_x.get_colwise_tensor() casted_kernel = tex.quantize( - kernel, flatten_axis=max(k_cdims) + 1, quantizer=quantizer_set.kernel, - noop_scaled_tensor=True + kernel, + flatten_axis=max(k_cdims) + 1, + quantizer=quantizer_set.kernel, + noop_scaled_tensor=True, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) colwise_k = casted_kernel.get_colwise_tensor() @@ -163,7 +164,10 @@ def _dense_bwd_rule( # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input flatten_axis_grad = len(fwd_x_non_cdims) casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis_grad, quantizer=quantizer_set.dgrad, + grad, + is_dbias=use_bias, + flatten_axis=flatten_axis_grad, + quantizer=quantizer_set.dgrad, noop_scaled_tensor=True, ) @@ -175,10 +179,7 @@ def _dense_bwd_rule( # DGRAD GEMM: (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) dgrad = tex.gemm( - rowwise_g, - rowwise_k, - contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), - grad=True + rowwise_g, rowwise_k, contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), grad=True ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 26b9d8fad5..bd8984ecec 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -197,8 +197,9 @@ def _layernorm_dense_fwd_rule( # Kernel (hidden_in, hidden_out) flatten_axis = 1 - kernel.ndim - casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, - noop_scaled_tensor=True) + casted_kernel = tex.quantize( + kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True + ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) rowwise_kernel = casted_kernel.get_rowwise_tensor() @@ -278,7 +279,10 @@ def _layernorm_dense_bwd_rule( # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input flatten_axis_grad = len(fwd_x_non_cdims) casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis_grad, quantizer=quantizer_set.dgrad, + grad, + is_dbias=use_bias, + flatten_axis=flatten_axis_grad, + quantizer=quantizer_set.dgrad, noop_scaled_tensor=True, ) @@ -291,10 +295,7 @@ def _layernorm_dense_bwd_rule( # DGRAD GEMM: (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) dgrad = tex.gemm( - rowwise_g, - rowwise_kernel, - contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), - grad=True + rowwise_g, rowwise_kernel, contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), grad=True ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, dot_input_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 8a42555a5e..f76205d9f7 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -22,11 +22,7 @@ from . import cpp_extensions as tex from .layernorm import canonicalize_norm_type -from .quantize import ( - with_sharding_constraint_by_logical_axes, - QuantizerSet, - noop_quantizer_set -) +from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set from .sharding import get_non_contracting_logical_axes @@ -271,8 +267,9 @@ def _layernorm_mlp_fwd_rule( casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) # FC1 kernel (hidden_in, act_len, hidden_out) - casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, - noop_scaled_tensor=True) + casted_kernel_1 = tex.quantize( + kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True + ) # Prepare FC1 FPROP operands and layouts rowwise_ln_out = casted_ln_out.get_rowwise_tensor() @@ -309,13 +306,15 @@ def _layernorm_mlp_fwd_rule( dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) # Activation (batch..., act_len, hidden_out) -> (batch..., hidden_out) - casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, - noop_scaled_tensor=True) + casted_act_out = tex.act_lu( + dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True + ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) # FC2 kernel (hidden_out, hidden_in) - casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel, - noop_scaled_tensor=True) + casted_kernel_2 = tex.quantize( + kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True + ) # Prepare FC2 FPROP operands and layouts rowwise_act_out = casted_act_out.get_rowwise_tensor() @@ -333,7 +332,7 @@ def _layernorm_mlp_fwd_rule( bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, contracting_dims=(x_cdims, k_cdims), - grad=False + grad=False, ) if use_bias_2 and tex.gemm_uses_jax_dot(): @@ -418,8 +417,11 @@ def _layernorm_mlp_bwd_rule( flatten_axis_grad = len(fwd_x_non_cdims) grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) casted_grad, dbias_2 = tex.quantize_dbias( - grad, is_dbias=use_bias_2, flatten_axis=flatten_axis_grad, - quantizer=ffn2_quantizer_set.dgrad, noop_scaled_tensor=True, + grad, + is_dbias=use_bias_2, + flatten_axis=flatten_axis_grad, + quantizer=ffn2_quantizer_set.dgrad, + noop_scaled_tensor=True, ) # Prepare FC2 DGRAD and WGRAD operands and contracting dims @@ -433,10 +435,7 @@ def _layernorm_mlp_bwd_rule( # FC2 DGRAD GEMM: (batch..., hidden_in) x (hidden_out, hidden_in)^T = (batch..., hidden_out) dgrad_2 = tex.gemm( - rowwise_g, - rowwise_kernel_2, - contracting_dims=(rowwise_g_cdims, fwd_k2_non_cdims), - grad=True + rowwise_g, rowwise_kernel_2, contracting_dims=(rowwise_g_cdims, fwd_k2_non_cdims), grad=True ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) @@ -466,7 +465,9 @@ def _layernorm_mlp_bwd_rule( rowwise_dact_out = casted_dact_out.get_rowwise_tensor() rowwise_dact_out_cdims = tuple(range(flatten_axis_grad, rowwise_dact_out.ndim)) colwise_dact_out = casted_dact_out.get_colwise_tensor() - colwise_dact_out_cdims = tex.get_non_contracting_dims(casted_dact_out.ndim, rowwise_dact_out_cdims) + colwise_dact_out_cdims = tex.get_non_contracting_dims( + casted_dact_out.ndim, rowwise_dact_out_cdims + ) fwd_k1_non_cdims = tex.get_non_contracting_dims(rowwise_kernel_1.ndim, fwd_k_cdims) # FC1 DGRAD GEMM: @@ -475,7 +476,7 @@ def _layernorm_mlp_bwd_rule( rowwise_dact_out, rowwise_kernel_1, contracting_dims=(rowwise_dact_out_cdims, fwd_k1_non_cdims), - grad=True + grad=True, ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 4761816220..1c5d77c05b 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -142,7 +142,7 @@ def __post_init__(self): ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" if self.scaling_mode == ScalingMode.NO_SCALING: - self.scale_inv = jnp.empty((1, ), dtype=jnp.float32) + self.scale_inv = jnp.empty((1,), dtype=jnp.float32) else: expected_scale_shape = self.scaling_mode.get_scale_shape( @@ -158,8 +158,10 @@ def __post_init__(self): f" {self.scale_inv.shape}" ) expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=False, - flatten_axis=self.flatten_axis + self.data.shape, + self.is_colwise, + is_padded=False, + flatten_axis=self.flatten_axis, ) if self.scale_inv.shape != expected_scale_shape: assert self.scale_inv.shape == expected_unpadded_scale_shape, ( @@ -168,8 +170,8 @@ def __post_init__(self): f" {self.scale_inv.shape}" ) pad_width = tuple( - (0, a - b) for a, b in zip(expected_scale_shape, - expected_unpadded_scale_shape) + (0, a - b) + for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) ) # This actually pad scale_inv with nan, should we pad it with 127 directly instead? self.scale_inv = jnp.pad( From 92dec51a5bb293e265ae5315f3d4865e8246479e Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 13 Jun 2025 06:55:20 +0000 Subject: [PATCH 04/27] FP8 tests passing on Blackwell but MXFP8 outputs NaN Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d553da10f3..81c64359b1 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -123,34 +123,36 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ lhs_q = lhs rhs_q = rhs if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: + scaling_mode = lhs_quantizer.scaling_mode lhs_q = lhs_quantizer.quantize( lhs, - is_rowwise=True, - is_colwise=False, + is_rowwise=True if scaling_mode.is_tensor_scaling() else not lhs_is_transposed, + is_colwise=False if scaling_mode.is_tensor_scaling() else lhs_is_transposed, flatten_axis=( max(lhs_contracting_dims) + 1 if lhs_is_transposed else min(lhs_contracting_dims) ), ) - if lhs_is_transposed: + if lhs_is_transposed and scaling_mode.is_tensor_scaling(): # Manually update data layout and columnwise flag to avoid transposing already # transposed data lhs_q.data_layout = "T" lhs_q.is_colwise = True if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: + scaling_mode = rhs_quantizer.scaling_mode rhs_q = rhs_quantizer.quantize( rhs, - is_rowwise=True, - is_colwise=False, + is_rowwise=True if scaling_mode.is_tensor_scaling() else rhs_is_transposed, + is_colwise=False if scaling_mode.is_tensor_scaling() else not rhs_is_transposed, flatten_axis=( min(rhs_contracting_dims) if rhs_is_transposed else max(rhs_contracting_dims) + 1 ), ) - if not rhs_is_transposed: + if not rhs_is_transposed and scaling_mode.is_tensor_scaling(): # Manually update data layout and columnwise flag to avoid transposing already # transposed data rhs_q.data_layout = "T" @@ -515,7 +517,7 @@ def _te_gemm( scaling_mode = lhs_q.scaling_mode lhs_data = lhs_q.data lhs_scale_inv = lhs_q.scale_inv - if lhs_q.data_layout == "T": + if lhs_is_transposed and lhs_q.data_layout == "T" and scaling_mode.is_tensor_scaling(): lhs_contracting_dims = transpose_contracting_dims(lhs_q.ndim, lhs_contracting_dims) if isinstance(rhs_q, ScaledTensor): @@ -535,7 +537,7 @@ def _te_gemm( ) rhs_data = rhs_q.data rhs_scale_inv = rhs_q.scale_inv - if rhs_q.data_layout == "T": + if rhs_is_transposed and rhs_q.data_layout == "T" and scaling_mode.is_tensor_scaling(): rhs_contracting_dims = transpose_contracting_dims(rhs_q.ndim, rhs_contracting_dims) # Dummy empties for bias and gelu From 9eba586c3ee1d0453ad34bb183b6b1bef5edfcce Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Sat, 14 Jun 2025 06:21:22 +0000 Subject: [PATCH 05/27] reverted dense and fuseddense changes, FP8 test passing on Hopper and Blackwell, MXFP8 has issues with E5M2 Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 150 ++++++------- transformer_engine/jax/cpp_extensions/gemm.py | 102 ++++----- .../jax/csrc/extensions/gemm.cpp | 6 +- transformer_engine/jax/dense.py | 100 ++++----- transformer_engine/jax/layernorm_dense.py | 115 +++++----- transformer_engine/jax/layernorm_mlp.py | 201 ++++++++---------- transformer_engine/jax/quantize/tensor.py | 2 +- 7 files changed, 303 insertions(+), 373 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 002a61c715..ccb667b8f2 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -2,15 +2,13 @@ # # See LICENSE for license information. -import os -import operator -from functools import reduce -from typing import Union - -import pytest import jax import jax.numpy as jnp +import pytest from jax import jit, value_and_grad +from functools import reduce +from typing import Union +import operator from utils import ( assert_allclose, @@ -46,8 +44,6 @@ from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense -from transformer_engine_jax import is_non_nt_fp8_gemm_supported - GEMM_CASES = [ (256, 256, 512), (32, 32, 32), @@ -62,14 +58,10 @@ is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) supported_scaling_modes = [] -supported_fp8_gemm_layouts = [] """ Find supported scaling modes""" if is_fp8_supported: supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING) supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING) - supported_fp8_gemm_layouts.append("NT") - if is_non_nt_fp8_gemm_supported(): - supported_fp8_gemm_layouts += ["TT", "TN", "NN"] if is_mxfp8_supported: supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING) @@ -80,7 +72,7 @@ def is_shape_supported_by_mxfp8(input_shape): input_shape = input_shape.values[0] ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape) return True - except AssertionError: + except: # get_scale_shapes will raise an exception if the shape is not supported return False @@ -154,13 +146,6 @@ def assert_dequantized_grouped_scaled_tensor( pytest.fail("a must be a GroupedScaledTensor object") -def use_jax_dot_for_gemm(enabled=False): - if enabled: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" - elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: - os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") - - ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)] ALL_ACTIVATION_TYPES = [ ("gelu",), @@ -638,15 +623,15 @@ def test_quantize_bitwise( ): key = jax.random.PRNGKey(0) - inp = jax.random.uniform(key, input_shape, in_dtype) + input = jax.random.uniform(key, input_shape, in_dtype) te_quantizer, jax_quantizer = QuantizerFactory.create( n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout ) - jax_output = _jax_quantize(inp, quantizer=jax_quantizer, flatten_axis=flatten_axis) + jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) - te_output = tex.quantize(inp, quantizer=te_quantizer, flatten_axis=flatten_axis) + te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis) assert_bitwise_scaled_tensors(te_output, jax_output) @@ -724,21 +709,23 @@ def test_quantize_dbias( pytest.skip(f"Input shape {input_shape} is not supported by MXFP8") key = jax.random.PRNGKey(0) - inp = jax.random.uniform(key, input_shape, in_dtype) + input = jax.random.uniform(key, input_shape, in_dtype) jax_quantizer, te_quantizer = QuantizerFactory.create( n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout ) te_output, te_dbias = jit( - lambda input: tex.quantize_dbias(inp, quantizer=te_quantizer, flatten_axis=flatten_axis) - )(inp) + lambda input: tex.quantize_dbias( + input, quantizer=te_quantizer, flatten_axis=flatten_axis + ) + )(input) jax_output, jax_dbias = jit( lambda input: _jax_quantize_dbias( - inp, quantizer=jax_quantizer, flatten_axis=flatten_axis + input, quantizer=jax_quantizer, flatten_axis=flatten_axis ) - )(inp) + )(input) assert_bitwise_scaled_tensors(te_output, jax_output) @@ -863,6 +850,21 @@ def test_quantize_dact_dbias_mxfp8_scaling( ) +valid_fp8_gemm_operand_types = [ + (jnp.float8_e4m3fn, jnp.float8_e4m3fn), + (jnp.float8_e5m2, jnp.float8_e4m3fn), + (jnp.float8_e4m3fn, jnp.float8_e5m2), +] + + +def _use_jax_fp8_gemm(enabled=False): + import os + if enabled: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") + + class TestDense: def _ref_gemm_with_jnp_dot(self, a, b, data_layout): if data_layout[0] == "T": @@ -892,10 +894,7 @@ def _generate_gemm_input(self, m, n, k, data_layout): @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) - @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_gemm_bf16(self, m, n, k, data_layout, with_jax_gemm): - use_jax_dot_for_gemm(enabled=with_jax_gemm) - + def test_gemm_bf16(self, m, n, k, data_layout): x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) primitive_out = tex.gemm(x, w, contracting_dims) @@ -904,55 +903,36 @@ def test_gemm_bf16(self, m, n, k, data_layout, with_jax_gemm): assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper( - "lhs_q_dtype,rhs_q_dtype", - [ - (jnp.float8_e4m3fn, jnp.float8_e4m3fn), # fprop GEMM - (jnp.float8_e4m3fn, jnp.float8_e5m2), # wgrad GEMM - (jnp.float8_e5m2, jnp.float8_e4m3fn), # dgrad GEMM - ], - ) + @pytest_parametrize_wrapper("m,n,k", [(256, 512, 1024)]) + @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - @pytest_parametrize_wrapper("data_layout", supported_fp8_gemm_layouts) + @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_gemm_fp8( - self, m, n, k, lhs_q_dtype, rhs_q_dtype, scaling_mode, data_layout, with_jax_gemm - ): - use_jax_dot_for_gemm(enabled=with_jax_gemm) - - lhs, rhs, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) + def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm): + _use_jax_fp8_gemm(enabled=with_jax_gemm) + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) quantizer_set = QuantizerFactory.create_set( scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, bwd_dtype=jnp.float8_e5m2, - is_2x2x=False, - ) - lhs_quantizer = quantizer_set.x if lhs_q_dtype == jnp.float8_e4m3fn else quantizer_set.dgrad - rhs_quantizer = ( - quantizer_set.kernel if rhs_q_dtype == jnp.float8_e4m3fn else quantizer_set.dgrad + is_2x2x=False ) - primitive_out = tex.gemm( - lhs, - rhs, - lhs_quantizer=lhs_quantizer, - rhs_quantizer=rhs_quantizer, + x, + w, contracting_dims=contracting_dims, + lhs_quantizer=quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad, + rhs_quantizer=( + quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad + ), ) - ref_out = self._ref_gemm_with_jnp_dot(lhs, rhs, data_layout) + ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) - test_q_dtype = ( - jnp.float8_e5m2 if jnp.float8_e5m2 in (lhs_q_dtype, rhs_q_dtype) else jnp.float8_e4m3fn - ) - assert_allclose(primitive_out, ref_out, dtype=test_q_dtype) + assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_dense_grad_bf16(self, m, n, k, with_jax_gemm): - use_jax_dot_for_gemm(enabled=with_jax_gemm) - + def test_dense_grad_bf16(self, m, n, k): data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -981,7 +961,7 @@ def ref_func(x, w, data_layout): @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm): - use_jax_dot_for_gemm(enabled=with_jax_gemm) + _use_jax_fp8_gemm(enabled=with_jax_gemm) data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -1004,10 +984,9 @@ def ref_func(x, w, bias, data_layout): value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2, - is_2x2x=True, + scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, + is_2x2x=True ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 @@ -1054,7 +1033,7 @@ def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_g """ Test layernorm_dense VJP Rule """ - use_jax_dot_for_gemm(enabled=with_jax_gemm) + _use_jax_fp8_gemm(enabled=with_jax_gemm) # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False @@ -1072,7 +1051,7 @@ def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_g quantizer_set = QuantizerFactory.create_set( scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2, + bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, is_2x2x=True, ) @@ -1130,7 +1109,7 @@ def ref_func(x, w, gamma, beta): @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) - @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest.mark.parametrize("use_bias", [True, False]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm @@ -1138,7 +1117,7 @@ def test_layernorm_mlp_grad( """ Test layernorm_mlp VJP Rule """ - use_jax_dot_for_gemm(enabled=with_jax_gemm) + _use_jax_fp8_gemm(enabled=with_jax_gemm) # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False @@ -1165,7 +1144,7 @@ def test_layernorm_mlp_grad( n_quantizer_sets=2, scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2, + bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, is_2x2x=True, ) @@ -1190,26 +1169,25 @@ def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ) ) - def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): - dim_nums = ((1,), (0,)), ((), ()) - + def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ln_out = _ref_jax_norm_impl( x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None ) - - linear_1_out = jax.lax.dot_general(ln_out, kernel_1, dim_nums) + linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ()))) if use_bias: bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape linear_1_out += jnp.reshape(bias_1, bias_1_shape) - act_out = _jax_act_lu(linear_1_out, activation_type) - - linear_2_out = jax.lax.dot_general(act_out, kernel_2, dim_nums) + x = _jax_act_lu(linear_1_out, activation_type) + linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ()))) if use_bias: bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape linear_2_out += jnp.reshape(bias_2, bias_2_shape) - return jnp.mean(linear_2_out) + return linear_2_out + + def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): + return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2)) value_n_grad_prim_func = value_and_grad(prim_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6)) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 2c6794e6bc..2475264c7c 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -18,7 +18,7 @@ from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams from .base import BasePrimitive, register_primitive -from .quantization import grouped_quantize +from .quantization import quantize, grouped_quantize from ..quantize import ( ScaledTensor, ScaledTensor2x, @@ -81,21 +81,16 @@ def transpose_contracting_dims(ndim, contracting_dims): def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool: - lhs, rhs, e4m3, e5m2, e8m0 = map( + lhs, rhs, e4m3, e5m2 = map( dtypes.canonicalize_dtype, ( lhs_dtype, rhs_dtype, jnp.float8_e4m3fn, jnp.float8_e5m2, - jnp.uint8, # replace with jnp.float8_e8m0 when JAX/XLA merges support ), ) - # MXFP8 GEMM needs both operands to be MXFP8 (uint8 for now until JAX merges float8_e8m0) - if lhs is e8m0 and rhs is e8m0: - return True - # FP8 GEMM supports (e4m3 x e4m3), (e4m3 x e5m2) and (e5m2 x e4m3) if (lhs is e4m3 and rhs in (e4m3, e5m2)) or (lhs in (e4m3, e5m2) and rhs is e4m3): return True @@ -114,44 +109,37 @@ def _get_gemm_layout( def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims): - lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) - lhs_contracting_dims, rhs_contracting_dims = map( - sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims - ) - lhs_q = lhs rhs_q = rhs if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: - scaling_mode = lhs_quantizer.scaling_mode + lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0]) + lhs_is_rowwise = lhs.ndim - 1 in lhs_cdims + flatten_axis = ( + min(lhs_cdims) if lhs_is_rowwise else max(lhs_cdims) + 1 + ) lhs_q = lhs_quantizer.quantize( lhs, - is_rowwise=True if scaling_mode.is_tensor_scaling() else not lhs_is_transposed, - is_colwise=False if scaling_mode.is_tensor_scaling() else lhs_is_transposed, - flatten_axis=( - max(lhs_contracting_dims) + 1 if lhs_is_transposed else min(lhs_contracting_dims) - ), + is_rowwise=lhs_is_rowwise, + is_colwise=not lhs_is_rowwise, + flatten_axis=flatten_axis, ) - if lhs_is_transposed and scaling_mode.is_tensor_scaling(): - # Manually update data layout and columnwise flag to avoid transposing already - # transposed data - lhs_q.data_layout = "T" - lhs_q.is_colwise = True + if isinstance(lhs_q, ScaledTensor2x): + lhs_q = lhs_q.get_rowwise_tensor() if lhs_is_rowwise else lhs_q.get_colwise_tensor() if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: - scaling_mode = rhs_quantizer.scaling_mode + rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) + rhs_is_rowwise = rhs.ndim - 1 in rhs_cdims + flatten_axis = ( + min(rhs_cdims) if rhs_is_rowwise else max(rhs_cdims) + 1 + ) rhs_q = rhs_quantizer.quantize( rhs, - is_rowwise=True if scaling_mode.is_tensor_scaling() else rhs_is_transposed, - is_colwise=False if scaling_mode.is_tensor_scaling() else not rhs_is_transposed, - flatten_axis=( - min(rhs_contracting_dims) if rhs_is_transposed else max(rhs_contracting_dims) + 1 - ), + is_rowwise=rhs_is_rowwise, + is_colwise=not rhs_is_rowwise, + flatten_axis=flatten_axis, ) - if not rhs_is_transposed and scaling_mode.is_tensor_scaling(): - # Manually update data layout and columnwise flag to avoid transposing already - # transposed data - rhs_q.data_layout = "T" - rhs_q.is_colwise = True + if isinstance(rhs_q, ScaledTensor2x): + rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_rowwise else rhs_q.get_colwise_tensor() return lhs_q, rhs_q @@ -270,7 +258,7 @@ def abstract( # Need extra workspace for swizzled scale factors lhs_swizzle_size = 0 rhs_swizzle_size = 0 - swizzle_dtype = jnp.uint8 # replace with jnp.float8_e8m0 when JAX merges support + swizzle_dtype = jnp.uint8 if scaling_mode == ScalingMode.MXFP8_1D_SCALING: lhs_swizzle_size = lhs_scale_inv.size rhs_swizzle_size = rhs_scale_inv.size @@ -535,7 +523,6 @@ def partition( register_primitive(GemmPrimitive) -@lru_cache(maxsize=1) def gemm_uses_jax_dot() -> bool: """Check if the GEMM call directs to the TE custom cuBLAS call or native JAX dot.""" return not GemmPrimitive.enabled() @@ -560,7 +547,7 @@ def _te_gemm( rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) scaling_mode = ScalingMode.NO_SCALING lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) - lhs_contracting_dims, rhs_contracting_dims = map( + lhs_cdims, rhs_cdims = map( sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims ) @@ -574,16 +561,13 @@ def _te_gemm( "`Quantizer` object to quantize the RHS operand." ) if isinstance(lhs_q, ScaledTensor2x): - # Contracting dimensions for a ScaledTensor2x is interpreted relative to the row-wise - # shape. Since we have access to both row-wise and column-wise tensors, we always - # choose the one that avoids transposing LHS in the GEMM kernel to comply with the - # NT-layout restriction for FP8 GEMM on Hopper. + # Choose the quantization of the contracting dimension(s) lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() scaling_mode = lhs_q.scaling_mode lhs_data = lhs_q.data lhs_scale_inv = lhs_q.scale_inv - if lhs_is_transposed and lhs_q.data_layout == "T" and scaling_mode.is_tensor_scaling(): - lhs_contracting_dims = transpose_contracting_dims(lhs_q.ndim, lhs_contracting_dims) + if lhs_q.data_layout == "T": + lhs_cdims = transpose_contracting_dims(lhs_q.ndim, lhs_cdims) if isinstance(rhs_q, ScaledTensor): assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( @@ -591,10 +575,7 @@ def _te_gemm( "`Quantizer` object to quantize the LHS operand." ) if isinstance(rhs_q, ScaledTensor2x): - # Contracting dimensions for a ScaledTensor2x is interpreted relative to the row-wise - # shape. Since we have access to both row-wise and column-wise tensors, we always - # choose the one that avoids transposing LHS in the GEMM kernel to comply with the - # NT-layout restriction for FP8 GEMM on Hopper. + # Choose the quantization of the contracting dimension(s) rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() assert rhs_q.scaling_mode == lhs_q.scaling_mode, ( "cuBLAS GEMM quantized operands have mismatched scaling types, " @@ -602,8 +583,8 @@ def _te_gemm( ) rhs_data = rhs_q.data rhs_scale_inv = rhs_q.scale_inv - if rhs_is_transposed and rhs_q.data_layout == "T" and scaling_mode.is_tensor_scaling(): - rhs_contracting_dims = transpose_contracting_dims(rhs_q.ndim, rhs_contracting_dims) + if rhs_q.data_layout == "T": + rhs_cdims = transpose_contracting_dims(rhs_q.ndim, rhs_cdims) # Dummy empties for bias and gelu out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype @@ -620,7 +601,7 @@ def _te_gemm( bias, gelu_input, out_dtype=out_dtype, - contracting_dims=(lhs_contracting_dims, rhs_contracting_dims), + contracting_dims=(lhs_cdims, rhs_cdims), scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -911,16 +892,21 @@ def _jax_gemm_fp8_impl(lhs, rhs): raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") - # Quantize operands (if necessary) - lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) + lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, + contracting_dims) - if isinstance(lhs_q, ScaledTensor) or isinstance(rhs_q, ScaledTensor): - assert isinstance(lhs_q, ScaledTensor) and isinstance( - rhs_q, ScaledTensor - ), "Both LHS and RHS must be quantized (or have valid quantizers) for FP8 GEMM." + if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor): return _jax_gemm_fp8_impl(lhs_q, rhs_q) - return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) + if ( + isinstance(lhs, jnp.ndarray) + and isinstance(rhs, jnp.ndarray) + and lhs_quantizer is None + and rhs_quantizer is None + ): + return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) + + raise NotImplementedError("Not supporting multiplication of ScaledTensor and jnp.array") def gemm( @@ -987,7 +973,7 @@ def gemm( rhs_quantizer = quantizer_set.kernel # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled - if gemm_uses_jax_dot(): + if not GemmPrimitive.enabled(): assert kwargs.get("bias", None) is None and not kwargs.get("fuse_bias", False), ( "TE GEMM was invoked with bias fusion options that are not supported by the " "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 7736e0ec50..987b0e817d 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -42,7 +42,9 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands."); NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM."); - std::vector scale_shape(scale_inv.dimensions().begin(), scale_inv.dimensions().end()); + auto scale_dims = scale_inv.dimensions(); + std::vector scale_shape = {product(scale_dims, 0, axis_boundary), + product(scale_dims, axis_boundary, scale_dims.size())}; auto scale_dtype = (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); @@ -53,7 +55,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( } // Swizzle scaling factors for MXFP8 - if (is_block_scaling(scaling_mode)) { + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { // Get the swizzle buffer NVTE_CHECK(swizzled_scale_inv->element_count() > 0, "Missing swizzled inverse scale buffer in the JAX primitive."); diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 6f94620d99..2aeb46594b 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -90,40 +90,29 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, Returns: Tuple of (output, context) for backward pass """ - x_cdims, k_cdims = map(tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims) - x_is_transposed = x.ndim - 1 not in x_cdims - k_is_transposed = kernel.ndim - 1 in k_cdims - assert ( - not x_is_transposed and not k_is_transposed - ), "Forward-mode Dense layer implementation only supports 'NN' layout inputs." - - casted_x = tex.quantize( - x, flatten_axis=min(x_cdims), quantizer=quantizer_set.x, noop_scaled_tensor=True - ) + x_contracting_dims, k_contracting_dims = contracting_dims + + flatten_axis_x = -len(x_contracting_dims) + flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + + casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, + noop_scaled_tensor=True) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) - rowwise_x = casted_x.get_rowwise_tensor() - colwise_x = casted_x.get_colwise_tensor() casted_kernel = tex.quantize( - kernel, - flatten_axis=max(k_cdims) + 1, - quantizer=quantizer_set.kernel, + kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel, noop_scaled_tensor=True, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) - colwise_k = casted_kernel.get_colwise_tensor() - rowwise_k = casted_kernel.get_rowwise_tensor() - # FPROP GEMM: (batch..., hidden_in) x (hidden_in, hidden_out) = (batch..., hidden_out) - # FPROP FP8 GEMM: (batch..., hidden_in) x (hidden_out, hidden_in)^T = (batch..., hidden_out) + # GEMM NN use_bias = bias is not None output = tex.gemm( - rowwise_x, - colwise_k, + casted_x.get_rowwise_tensor(), + casted_kernel.get_colwise_tensor(), + contracting_dims=(x_contracting_dims, k_contracting_dims), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, - contracting_dims=(x_cdims, k_cdims), - grad=False, ) if use_bias and tex.gemm_uses_jax_dot(): @@ -131,10 +120,13 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, output += jnp.reshape(bias, bias_new_shape) ctx = ( - colwise_x, - rowwise_k, + casted_x.get_colwise_tensor(), + casted_kernel.get_rowwise_tensor(), + x.shape, + kernel.shape, use_bias, quantizer_set, + flatten_axis_k, ) return output, ctx @@ -147,49 +139,49 @@ def _dense_bwd_rule( Returns: Tuple of gradients with respect to inputs """ + fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims + ( - colwise_x, - rowwise_k, + colwise_casted_x, + rowwise_casted_kernel, + x_shape, + kernel_shape, use_bias, quantizer_set, + flatten_axis_k, ) = ctx - # Original non-contracting dimensions in the forward pass are contracting dimensions for the - # backward pass. - fwd_x_cdims, fwd_k_cdims = map( - tex.sanitize_dims, (colwise_x.ndim, rowwise_k.ndim), contracting_dims - ) - fwd_x_non_cdims = tex.get_non_contracting_dims(colwise_x.ndim, fwd_x_cdims) - fwd_k_non_cdims = tex.get_non_contracting_dims(rowwise_k.ndim, fwd_k_cdims) - # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input - flatten_axis_grad = len(fwd_x_non_cdims) casted_grad, dbias = tex.quantize_dbias( - grad, - is_dbias=use_bias, - flatten_axis=flatten_axis_grad, - quantizer=quantizer_set.dgrad, + grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, noop_scaled_tensor=True, ) - # Prepare DGRAD and WGRAD operands and contracting dims - rowwise_g = casted_grad.get_rowwise_tensor() - rowwise_g_cdims = tuple(range(flatten_axis_grad, grad.ndim)) - colwise_g = casted_grad.get_colwise_tensor() - colwise_g_cdims = tex.get_non_contracting_dims(grad.ndim, rowwise_g_cdims) - - # DGRAD GEMM: (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) + # GEMM NT + # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim + g_contracting_dim = tuple( + range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) + ) + # k_non_contracting_dims + k_contracting_dim = tuple( + dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims + ) dgrad = tex.gemm( - rowwise_g, rowwise_k, contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), grad=True + casted_grad.get_rowwise_tensor(), + rowwise_casted_kernel, + contracting_dims=(g_contracting_dim, k_contracting_dim), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) - # WGRAD GEMM: (batch..., hidden_in)^T x (batch..., hidden_out) = (hidden_in, hidden_out) - # WGRAD FP8 GEMM: (hidden_in, batch...) x (hidden_out, batch...)^T = (hidden_in, hidden_out) + # GEMM TN + # x_non_contracting_dims + g_contracting_dim = x_contracting_dim = tuple( + range(0, len(x_shape) - len(fwd_x_contracting_dims)) + ) + wgrad = tex.gemm( - colwise_x, - colwise_g, - contracting_dims=(fwd_x_non_cdims, colwise_g_cdims), - grad=True, + colwise_casted_x, + casted_grad.get_colwise_tensor(), + contracting_dims=(x_contracting_dim, g_contracting_dim) ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index bd8984ecec..96e66ff946 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -173,12 +173,12 @@ def _layernorm_dense_fwd_rule( Returns: Tuple of (output, context) for automatic differentiation """ - x_cdims = (x.ndim - 1,) - k_cdims = (0,) + x_contracting_dims = (len(x.shape) - 1,) + k_contracting_dims = (0,) assert x.shape[-1] == kernel.shape[0] - # Apply layernorm with quantized output if quantizer_set is given x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) + casted_ln_out, mu, rsigma = tex.normalization_fwd( x, gamma, @@ -186,52 +186,47 @@ def _layernorm_dense_fwd_rule( zero_centered_gamma, epsilon, norm_type, - quantizer_set.x, - noop_scaled_tensor=True, + quantizer=quantizer_set.x, + noop_scaled_tensor=True ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) - # Layernorm output (batch..., hidden_in) - rowwise_ln_out = casted_ln_out.get_rowwise_tensor() - colwise_ln_out = casted_ln_out.get_colwise_tensor() - - # Kernel (hidden_in, hidden_out) - flatten_axis = 1 - kernel.ndim - casted_kernel = tex.quantize( - kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True - ) + # Kernel in (hidden_in, hidden_out...) + flatten_axis = 1 - len(kernel.shape) + casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, + noop_scaled_tensor=True) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) - rowwise_kernel = casted_kernel.get_rowwise_tensor() - colwise_kernel = casted_kernel.get_colwise_tensor() - - # FPROP GEMM: (batch..., hidden_in) x (hidden_in, hidden_out) = (batch..., hidden_out) - # FPROP FP8 GEMM: (batch..., hidden_in) x (hidden_out, hidden_in)^T = (batch..., hidden_out) + # NN GEMM + # (batch..., hidden_in) x (hidden_in, hidden_out...) use_bias = bias is not None output = tex.gemm( - rowwise_ln_out, - colwise_kernel, + casted_ln_out.get_rowwise_tensor(), + casted_kernel.get_colwise_tensor(), + contracting_dims=(x_contracting_dims, k_contracting_dims), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, - contracting_dims=(x_cdims, k_cdims), - grad=False, ) + if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) ctx = ( - colwise_ln_out, - rowwise_kernel, + casted_ln_out.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None, + casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None, + x.shape, + kernel.shape, mu, rsigma, x, gamma, beta, - x_cdims, - k_cdims, + x_contracting_dims, + k_contracting_dims, use_bias, quantizer_set, + flatten_axis, ) return output, ctx @@ -242,7 +237,7 @@ def _layernorm_dense_bwd_rule( zero_centered_gamma, epsilon, layernorm_input_axes, - dot_input_axes, + dot_input_axes, # pylint: disable=unused-argument kernel_axes, ctx, grad, @@ -259,57 +254,58 @@ def _layernorm_dense_bwd_rule( Tuple of gradients for all input parameters """ ( - colwise_ln_out, - rowwise_kernel, + colwise_casted_ln_out, + rowwise_casted_kernel, + x_shape, + kernel_shape, mu, rsigma, x, gamma, beta, - fwd_x_cdims, - fwd_k_cdims, + x_contracting_dims_in_fwd, + k_contracting_dims_in_fwd, use_bias, quantizer_set, + flatten_axis, ) = ctx - # Original non-contracting dimensions in the forward pass are contracting dimensions for the - # backward pass. - fwd_x_non_cdims = tex.get_non_contracting_dims(colwise_ln_out.ndim, fwd_x_cdims) - fwd_k_non_cdims = tex.get_non_contracting_dims(rowwise_kernel.ndim, fwd_k_cdims) - # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input - flatten_axis_grad = len(fwd_x_non_cdims) casted_grad, dbias = tex.quantize_dbias( - grad, - is_dbias=use_bias, - flatten_axis=flatten_axis_grad, - quantizer=quantizer_set.dgrad, - noop_scaled_tensor=True, + grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad, + noop_scaled_tensor=True ) - # Prepare DGRAD and WGRAD operands and contracting dims - rowwise_g = casted_grad.get_rowwise_tensor() - rowwise_g_cdims = tuple(range(flatten_axis_grad, grad.ndim)) - colwise_g = casted_grad.get_colwise_tensor() - colwise_ln_out_cdims = fwd_x_non_cdims - colwise_g_cdims = tex.get_non_contracting_dims(grad.ndim, rowwise_g_cdims) + # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim + g_constracting_dim = tuple( + range(grad.ndim - len(kernel_shape) + len(k_contracting_dims_in_fwd), grad.ndim) + ) + # k_non_contracting_dims + k_constracting_dim = tuple( + dim for dim in range(len(kernel_shape)) if dim not in k_contracting_dims_in_fwd + ) - # DGRAD GEMM: (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) + # NT GEMM dgrad = tex.gemm( - rowwise_g, rowwise_kernel, contracting_dims=(rowwise_g_cdims, fwd_k_non_cdims), grad=True + casted_grad.get_rowwise_tensor(), + rowwise_casted_kernel, + contracting_dims=(g_constracting_dim, k_constracting_dim), + ) + + dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) + + g_constracting_dim = x_constracting_dim = tuple( + range(0, len(x_shape) - len(x_contracting_dims_in_fwd)) ) - dgrad = with_sharding_constraint_by_logical_axes(dgrad, dot_input_axes) - # WGRAD GEMM: (batch..., hidden_in)^T x (batch..., hidden_out) = (hidden_in, hidden_out) - # WGRAD FP8 GEMM: (hidden_in, batch...) x (hidden_out, batch...)^T = (hidden_in, hidden_out) + # TN GEMM wgrad = tex.gemm( - colwise_ln_out, - colwise_g, - contracting_dims=(colwise_ln_out_cdims, colwise_g_cdims), - grad=True, + colwise_casted_ln_out, + casted_grad.get_colwise_tensor(), + contracting_dims=(x_constracting_dim, g_constracting_dim), ) + wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) - # Layernorm gradient dx, dgamma, dbeta = tex.normalization_bwd( dgrad, x, @@ -321,7 +317,6 @@ def _layernorm_dense_bwd_rule( epsilon=epsilon, norm_type=norm_type, ) - dx = with_sharding_constraint_by_logical_axes(dx, layernorm_input_axes) return dx, wgrad, dgamma, dbeta, dbias, quantizer_set diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index f76205d9f7..8deb83d87a 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -244,16 +244,16 @@ def _layernorm_mlp_fwd_rule( assert len(kernel_2.shape) == 2 assert kernel_1.shape[-2] == len(activation_type) - x_cdims = (x.ndim - 1,) - k_cdims = (0,) + x_contracting_dims = (len(x.shape) - 1,) + k_contracting_dims = (0,) - assert x.shape[x_cdims[0]] == kernel_1.shape[k_cdims[0]] + assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] use_bias_1 = bias_1 is not None use_bias_2 = bias_1 is not None - # Apply layernorm with quantized output if quantizer_set is given x = with_sharding_constraint_by_logical_axes(x, norm_input_axes) + casted_ln_out, mu, rsigma = tex.normalization_fwd( x, gamma, @@ -266,35 +266,23 @@ def _layernorm_mlp_fwd_rule( ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) - # FC1 kernel (hidden_in, act_len, hidden_out) - casted_kernel_1 = tex.quantize( - kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True - ) - - # Prepare FC1 FPROP operands and layouts - rowwise_ln_out = casted_ln_out.get_rowwise_tensor() - rowwise_kernel_1 = casted_kernel_1.get_rowwise_tensor() - colwise_ln_out = casted_ln_out.get_colwise_tensor() - colwise_kernel_1 = casted_kernel_1.get_colwise_tensor() + casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, + noop_scaled_tensor=True) - # FC1 GEMM: - # (batch..., hidden_in) x (hidden_in, act_len, hidden_out) = (batch..., act_len, hidden_out) - # FC1 FP8 GEMM: - # (batch..., hidden_in) x (hidden_out, act_len, hidden_in)^T = (batch..., act_len, hidden_out) - use_bias_1 = bias_1 is not None + # NN GEMM + # (batch..., hidden_in) x (hidden_in, hidden_out) dot_1_output = tex.gemm( - rowwise_ln_out, - colwise_kernel_1, + casted_ln_out.get_rowwise_tensor(), + casted_kernel_1.get_colwise_tensor(), + contracting_dims=(x_contracting_dims, k_contracting_dims), bias=bias_1 if not tex.gemm_uses_jax_dot() else None, - fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, - contracting_dims=(x_cdims, k_cdims), - grad=False, + fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False ) if dot_1_input_axes is not None and kernel_1_axes is not None: dot_1_output_axes = ( - *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_cdims), - *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_cdims), + *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims), + *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims), ) dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) @@ -305,34 +293,23 @@ def _layernorm_mlp_fwd_rule( dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) - # Activation (batch..., act_len, hidden_out) -> (batch..., hidden_out) - casted_act_out = tex.act_lu( - dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True - ) - casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) + # (batch..., hidden_in) -> (batch..., hidden) + casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, + noop_scaled_tensor=True) - # FC2 kernel (hidden_out, hidden_in) - casted_kernel_2 = tex.quantize( - kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True - ) + casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) - # Prepare FC2 FPROP operands and layouts - rowwise_act_out = casted_act_out.get_rowwise_tensor() - rowwise_kernel_2 = casted_kernel_2.get_rowwise_tensor() - colwise_act_out = casted_act_out.get_colwise_tensor() - colwise_kernel_2 = casted_kernel_2.get_colwise_tensor() + casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel, + noop_scaled_tensor=True) - # FC2 GEMM: - # (batch..., hidden_out) x (hidden_out, hidden_in) = (batch..., hidden_in) - # FC2 FP8 GEMM: - # (batch..., hidden_out) x (hidden_in, hidden_out)^T = (batch..., hidden_in) + # NN GEMM + # (batch..., hidden_in) x (hidden_out, hidden_in) dot_2_output = tex.gemm( - rowwise_act_out, - colwise_kernel_2, + casted_act_out.get_rowwise_tensor(), + casted_kernel_2.get_colwise_tensor(), + contracting_dims=(x_contracting_dims, k_contracting_dims), bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, - contracting_dims=(x_cdims, k_cdims), - grad=False, ) if use_bias_2 and tex.gemm_uses_jax_dot(): @@ -348,13 +325,15 @@ def _layernorm_mlp_fwd_rule( rsigma, gamma, beta, - colwise_ln_out, - rowwise_kernel_1, + casted_ln_out.get_colwise_tensor(), + casted_kernel_1.get_rowwise_tensor(), dot_1_output, - colwise_act_out, - rowwise_kernel_2, - x_cdims, - k_cdims, + casted_act_out.get_colwise_tensor(), + casted_kernel_2.get_rowwise_tensor(), + x_contracting_dims, + k_contracting_dims, + kernel_1.shape, + kernel_2.shape, use_bias_1, use_bias_2, quantizer_sets, @@ -391,20 +370,22 @@ def _layernorm_mlp_bwd_rule( Returns: Tuple of gradients for all input parameters """ - del ffn1_ckpt_name, ffn2_ckpt_name + del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name ( x, mu, rsigma, gamma, beta, - colwise_ln_out, - rowwise_kernel_1, + colwise_casted_ln_out, + rowwise_casted_kernel_1, dot_1_output, - colwise_act_out, - rowwise_kernel_2, - fwd_x_cdims, - fwd_k_cdims, + colwise_casted_act_out, + rowwise_casted_kernel_2, + x_contracting_dims_in_fwd, + k_contracting_dims_in_fwd, + kernel_1_shape, + kernel_2_shape, use_bias_1, use_bias_2, quantizer_sets, @@ -412,87 +393,84 @@ def _layernorm_mlp_bwd_rule( ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets - # Axis boundary for the gradient is the number of non-contracting dimensions of the FWD input - fwd_x_non_cdims = tex.get_non_contracting_dims(colwise_ln_out.ndim, fwd_x_cdims) - flatten_axis_grad = len(fwd_x_non_cdims) + # Since the sharding of outputs should be the same as dot_1's input grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) + casted_grad, dbias_2 = tex.quantize_dbias( - grad, - is_dbias=use_bias_2, - flatten_axis=flatten_axis_grad, - quantizer=ffn2_quantizer_set.dgrad, - noop_scaled_tensor=True, + grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, + noop_scaled_tensor=True ) - # Prepare FC2 DGRAD and WGRAD operands and contracting dims - rowwise_g = casted_grad.get_rowwise_tensor() - rowwise_g_cdims = tuple(range(flatten_axis_grad, grad.ndim)) - fwd_k2_non_cdims = tex.get_non_contracting_dims(rowwise_kernel_2.ndim, fwd_k_cdims) - - colwise_g = casted_grad.get_colwise_tensor() - colwise_g_cdims = tex.get_non_contracting_dims(grad.ndim, rowwise_g_cdims) - colwise_act_out_cdims = tex.get_non_contracting_dims(colwise_act_out.ndim, fwd_x_cdims) + # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim + g_contracting_dims_2 = tuple( + range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim) + ) + # k_non_contracting_dims + k_contracting_dims_2 = tuple( + dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd + ) - # FC2 DGRAD GEMM: (batch..., hidden_in) x (hidden_out, hidden_in)^T = (batch..., hidden_out) + # NT GEMM + # (batch..., hidden_out) x (hidden_in, hidden_out) dgrad_2 = tex.gemm( - rowwise_g, rowwise_kernel_2, contracting_dims=(rowwise_g_cdims, fwd_k2_non_cdims), grad=True + casted_grad.get_rowwise_tensor(), + rowwise_casted_kernel_2, + contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), ) + dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) - # FC2 WGRAD GEMM: - # (batch..., hidden_out)^T x (batch..., hidden_in) = (hidden_out, hidden_in) - # FC2 WGRAD FP8 GEMM: - # (hidden_out, batch...) x (hidden_in, batch...)^T = (hidden_out, hidden_in) + x_contracting_dims = g_contracting_dims = tuple( + range(0, len(x.shape) - len(x_contracting_dims_in_fwd)) + ) + + # TN GEMM + # (hidden, batch...,) x (hidden, batch...) wgrad_2 = tex.gemm( - colwise_act_out, - colwise_g, - contracting_dims=(colwise_act_out_cdims, colwise_g_cdims), - grad=True, + colwise_casted_act_out, + casted_grad.get_colwise_tensor(), + contracting_dims=(x_contracting_dims, g_contracting_dims), ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) - # Activation gradient w/ bias fusion (batch..., hidden_out) -> (batch.., act_len, hidden_out) casted_dact_out, dbias_1 = tex.quantize_dact_dbias( dgrad_2, dot_1_output, activation_type=activation_type, is_dbias=use_bias_1, - quantizer=ffn1_quantizer_set.dgrad, + quantizer=ffn2_quantizer_set.dgrad, noop_scaled_tensor=True, ) - # Prepare FC1 DGRAD and WGRAD operands and contracting dims - rowwise_dact_out = casted_dact_out.get_rowwise_tensor() - rowwise_dact_out_cdims = tuple(range(flatten_axis_grad, rowwise_dact_out.ndim)) - colwise_dact_out = casted_dact_out.get_colwise_tensor() - colwise_dact_out_cdims = tex.get_non_contracting_dims( - casted_dact_out.ndim, rowwise_dact_out_cdims + # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim + dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim + g_contracting_dims_1 = tuple( + range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim) + ) + # k_non_contracting_dims + k_contracting_dims_1 = tuple( + dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd ) - fwd_k1_non_cdims = tex.get_non_contracting_dims(rowwise_kernel_1.ndim, fwd_k_cdims) - # FC1 DGRAD GEMM: - # (batch..., act_len, hidden_out) x (hidden_in, act_len, hidden_out)^T = (batch..., hidden_in) + # NT GEMM dgrad_1 = tex.gemm( - rowwise_dact_out, - rowwise_kernel_1, - contracting_dims=(rowwise_dact_out_cdims, fwd_k1_non_cdims), - grad=True, + casted_dact_out.get_rowwise_tensor(), + rowwise_casted_kernel_1, + contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), ) + dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) - # FC1 WGRAD GEMM: - # (batch..., hidden_in)^T x (batch..., act_len, hidden_out) = (hidden_in, act_len, hidden_out) - # FC1 WGRAD FP8 GEMM: - # (hidden_in, batch...) x (hidden_out, act_len, batch...)^T = (hidden_in, act_len, hidden_out) + # TN GEMM + # (hidden, batch...) x (hidden, batch...) wgrad_1 = tex.gemm( - colwise_ln_out, - colwise_dact_out, - contracting_dims=(fwd_x_non_cdims, colwise_dact_out_cdims), - grad=True, + colwise_casted_ln_out, + casted_dact_out.get_colwise_tensor(), + contracting_dims=(x_contracting_dims, g_contracting_dims), ) + wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) - # Layernorm gradient dx, dgamma, dbeta = tex.normalization_bwd( dgrad_1, x, @@ -504,7 +482,6 @@ def _layernorm_mlp_bwd_rule( epsilon=epsilon, norm_type=norm_type, ) - dx = with_sharding_constraint_by_logical_axes(dx, norm_input_axes) return (dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, quantizer_sets) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 1c5d77c05b..98dc38de24 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -142,7 +142,7 @@ def __post_init__(self): ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" if self.scaling_mode == ScalingMode.NO_SCALING: - self.scale_inv = jnp.empty((1,), dtype=jnp.float32) + self.scale_inv = jnp.empty((0,), dtype=jnp.float32) else: expected_scale_shape = self.scaling_mode.get_scale_shape( From b80e2843ad1750821295fa3d16a998f6ba18f4e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 14 Jun 2025 06:21:52 +0000 Subject: [PATCH 06/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 8 +++++--- transformer_engine/jax/cpp_extensions/gemm.py | 15 ++++---------- transformer_engine/jax/dense.py | 16 ++++++++++----- transformer_engine/jax/layernorm_dense.py | 14 ++++++++----- transformer_engine/jax/layernorm_mlp.py | 20 ++++++++++--------- 5 files changed, 40 insertions(+), 33 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index ccb667b8f2..cd3e09b64c 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -859,6 +859,7 @@ def test_quantize_dact_dbias_mxfp8_scaling( def _use_jax_fp8_gemm(enabled=False): import os + if enabled: os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: @@ -916,7 +917,7 @@ def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, wi scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, bwd_dtype=jnp.float8_e5m2, - is_2x2x=False + is_2x2x=False, ) primitive_out = tex.gemm( x, @@ -984,9 +985,10 @@ def ref_func(x, w, bias, data_layout): value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=jnp.float8_e4m3fn, + scaling_mode=scaling_mode, + fwd_dtype=jnp.float8_e4m3fn, bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, - is_2x2x=True + is_2x2x=True, ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 2475264c7c..124703eaec 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -114,9 +114,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0]) lhs_is_rowwise = lhs.ndim - 1 in lhs_cdims - flatten_axis = ( - min(lhs_cdims) if lhs_is_rowwise else max(lhs_cdims) + 1 - ) + flatten_axis = min(lhs_cdims) if lhs_is_rowwise else max(lhs_cdims) + 1 lhs_q = lhs_quantizer.quantize( lhs, is_rowwise=lhs_is_rowwise, @@ -129,9 +127,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) rhs_is_rowwise = rhs.ndim - 1 in rhs_cdims - flatten_axis = ( - min(rhs_cdims) if rhs_is_rowwise else max(rhs_cdims) + 1 - ) + flatten_axis = min(rhs_cdims) if rhs_is_rowwise else max(rhs_cdims) + 1 rhs_q = rhs_quantizer.quantize( rhs, is_rowwise=rhs_is_rowwise, @@ -547,9 +543,7 @@ def _te_gemm( rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) scaling_mode = ScalingMode.NO_SCALING lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) - lhs_cdims, rhs_cdims = map( - sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims - ) + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) # Quantize operands (if necessary) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) @@ -892,8 +886,7 @@ def _jax_gemm_fp8_impl(lhs, rhs): raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") - lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, - contracting_dims) + lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor): return _jax_gemm_fp8_impl(lhs_q, rhs_q) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 2aeb46594b..8c5da04371 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -95,12 +95,15 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) - casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, - noop_scaled_tensor=True) + casted_x = tex.quantize( + x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True + ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) casted_kernel = tex.quantize( - kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel, + kernel, + flatten_axis=flatten_axis_k, + quantizer=quantizer_set.kernel, noop_scaled_tensor=True, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -152,7 +155,10 @@ def _dense_bwd_rule( ) = ctx casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, + grad, + is_dbias=use_bias, + flatten_axis=flatten_axis_k, + quantizer=quantizer_set.dgrad, noop_scaled_tensor=True, ) @@ -181,7 +187,7 @@ def _dense_bwd_rule( wgrad = tex.gemm( colwise_casted_x, casted_grad.get_colwise_tensor(), - contracting_dims=(x_contracting_dim, g_contracting_dim) + contracting_dims=(x_contracting_dim, g_contracting_dim), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 96e66ff946..6be7954862 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -187,14 +187,15 @@ def _layernorm_dense_fwd_rule( epsilon, norm_type, quantizer=quantizer_set.x, - noop_scaled_tensor=True + noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) # Kernel in (hidden_in, hidden_out...) flatten_axis = 1 - len(kernel.shape) - casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, - noop_scaled_tensor=True) + casted_kernel = tex.quantize( + kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True + ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) # NN GEMM @@ -271,8 +272,11 @@ def _layernorm_dense_bwd_rule( ) = ctx casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad, - noop_scaled_tensor=True + grad, + is_dbias=use_bias, + flatten_axis=flatten_axis, + quantizer=quantizer_set.dgrad, + noop_scaled_tensor=True, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 8deb83d87a..a9aef69ab0 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -266,8 +266,9 @@ def _layernorm_mlp_fwd_rule( ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) - casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, - noop_scaled_tensor=True) + casted_kernel_1 = tex.quantize( + kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True + ) # NN GEMM # (batch..., hidden_in) x (hidden_in, hidden_out) @@ -276,7 +277,7 @@ def _layernorm_mlp_fwd_rule( casted_kernel_1.get_colwise_tensor(), contracting_dims=(x_contracting_dims, k_contracting_dims), bias=bias_1 if not tex.gemm_uses_jax_dot() else None, - fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False + fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, ) if dot_1_input_axes is not None and kernel_1_axes is not None: @@ -294,13 +295,15 @@ def _layernorm_mlp_fwd_rule( dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) # (batch..., hidden_in) -> (batch..., hidden) - casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, - noop_scaled_tensor=True) + casted_act_out = tex.act_lu( + dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True + ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) - casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel, - noop_scaled_tensor=True) + casted_kernel_2 = tex.quantize( + kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True + ) # NN GEMM # (batch..., hidden_in) x (hidden_out, hidden_in) @@ -397,8 +400,7 @@ def _layernorm_mlp_bwd_rule( grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) casted_grad, dbias_2 = tex.quantize_dbias( - grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, - noop_scaled_tensor=True + grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim From a7aa2f4f1a112a17a37ae356ccbbc3aa1c754fb8 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 17 Jun 2025 09:16:49 +0000 Subject: [PATCH 07/27] MXFP8 issue traced to scale factor padding with NaNs instead of zeros Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 9 +++- transformer_engine/jax/cpp_extensions/gemm.py | 46 ++++++++++++++++--- .../jax/csrc/extensions/gemm.cpp | 9 ++-- transformer_engine/jax/quantize/tensor.py | 27 ++++------- 4 files changed, 60 insertions(+), 31 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index cd3e09b64c..9ea44bb1f7 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -904,12 +904,19 @@ def test_gemm_bf16(self, m, n, k, data_layout): assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest_parametrize_wrapper("m,n,k", [(256, 512, 1024)]) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm): + if ( + not with_jax_gemm + and scaling_mode.is_1d_block_scaling() + and jnp.float8_e5m2 in (x_qtype, w_qtype) + ): + pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.") + _use_jax_fp8_gemm(enabled=with_jax_gemm) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 124703eaec..c54314e283 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -7,7 +7,7 @@ import operator from collections.abc import Iterable from typing import Tuple, Sequence, Union -from functools import partial, reduce, lru_cache +from functools import partial, reduce import jax import jax.numpy as jnp @@ -18,13 +18,14 @@ from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams from .base import BasePrimitive, register_primitive -from .quantization import quantize, grouped_quantize +from .quantization import grouped_quantize from ..quantize import ( ScaledTensor, ScaledTensor2x, GroupedScaledTensor1x, ScalingMode, Quantizer, + BlockScaleQuantizer, GroupedQuantizer, QuantizeConfig, QuantizerSet, @@ -123,6 +124,10 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ ) if isinstance(lhs_q, ScaledTensor2x): lhs_q = lhs_q.get_rowwise_tensor() if lhs_is_rowwise else lhs_q.get_colwise_tensor() + if jnp.any(jnp.isnan(lhs_q.data)): + print("Found NaNs in quantized LHS data.") + if jnp.any(jnp.isnan(lhs_q.scale_inv)): + print("Found NaNs in quantized LHS scale_inv.") if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) @@ -136,6 +141,10 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ ) if isinstance(rhs_q, ScaledTensor2x): rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_rowwise else rhs_q.get_colwise_tensor() + if jnp.any(jnp.isnan(rhs_q.data)): + print("Found NaNs in quantized RHS data.") + if jnp.any(jnp.isnan(rhs_q.scale_inv)): + print("Found NaNs in quantized RHS scale_inv.") return lhs_q, rhs_q @@ -165,7 +174,10 @@ def abstract( fuse_bias, fuse_gelu, grad, + use_split_accumulator, ): + del use_split_accumulator + # Sanity-check operand layouts and types operand_ndims = (lhs.ndim, rhs.ndim) ( @@ -288,6 +300,7 @@ def lowering( fuse_bias, fuse_gelu, grad, + use_split_accumulator, ): del out_dtype lhs_aval, _, rhs_aval, *_ = ctx.avals_in @@ -306,6 +319,7 @@ def lowering( "fuse_bias": fuse_bias, "fuse_gelu": fuse_gelu, "grad": grad, + "use_split_accumulator": use_split_accumulator, } operand_output_aliases = {} @@ -333,6 +347,7 @@ def impl( fuse_bias, fuse_gelu, grad, + use_split_accumulator, ): outputs = GemmPrimitive.inner_primitive.bind( lhs, @@ -347,6 +362,7 @@ def impl( fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, grad=grad, + use_split_accumulator=use_split_accumulator, ) return outputs[:-3] # discard workspace arrays @@ -360,6 +376,7 @@ def batcher( fuse_bias, fuse_gelu, grad, + use_split_accumulator, ): assert GemmPrimitive.outer_primitive is not None lhs, _, rhs, *_ = batched_args @@ -381,6 +398,7 @@ def batcher( fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, grad=grad, + use_split_accumulator=use_split_accumulator, ), (out_bdims, bias_bdims, pre_gelu_bdims), ) @@ -393,11 +411,12 @@ def infer_sharding_from_operands( fuse_bias, fuse_gelu, grad, + use_split_accumulator, mesh, arg_infos, result_infos, ): - del out_dtype, scaling_mode, result_infos + del out_dtype, scaling_mode, use_split_accumulator, result_infos # Check contracting dimensions lhs_spec, _, rhs_spec, *_ = map(get_padded_spec, arg_infos) @@ -467,6 +486,7 @@ def partition( fuse_bias, fuse_gelu, grad, + use_split_accumulator, mesh, arg_infos, result_infos, @@ -478,6 +498,7 @@ def partition( fuse_bias, fuse_gelu, grad, + use_split_accumulator, mesh, arg_infos, result_infos, @@ -535,6 +556,7 @@ def _te_gemm( fuse_bias: bool = False, fuse_gelu: bool = False, grad: bool = False, + use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, ) -> Tuple[jax.Array, ...]: # Prepare non-quantized GEMM operands lhs_data = lhs @@ -600,6 +622,7 @@ def _te_gemm( fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, grad=grad, + use_split_accumulator=use_split_accumulator, ) @@ -889,6 +912,12 @@ def _jax_gemm_fp8_impl(lhs, rhs): lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor): + if isinstance(lhs_q, ScaledTensor2x): + lhs_is_transposed = lhs.ndim - 1 not in sanitize_dims(lhs.ndim, contracting_dims[0]) + lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() + if isinstance(rhs_q, ScaledTensor2x): + rhs_is_transposed = rhs.ndim - 1 in sanitize_dims(rhs.ndim, contracting_dims[1]) + rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() return _jax_gemm_fp8_impl(lhs_q, rhs_q) if ( @@ -941,6 +970,9 @@ def gemm( grad: bool, default = False Flag for switching bias and GeLU fusions from forward to backward mode. Only supported with TE's custom call to cuBLAS GEMM. + use_split_accumulator: bool, default = True + Enable promoting some intermediate sums to higher precision when accumulating the result in + the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Returns ------- @@ -966,13 +998,15 @@ def gemm( rhs_quantizer = quantizer_set.kernel # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled + fuse_bias = kwargs.get("fuse_bias", False) + fuse_gelu = kwargs.get("fuse_gelu", False) if not GemmPrimitive.enabled(): - assert kwargs.get("bias", None) is None and not kwargs.get("fuse_bias", False), ( + assert kwargs.get("bias", None) is None and not fuse_gelu, ( "TE GEMM was invoked with bias fusion options that are not supported by the " "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." ) - assert kwargs.get("gelu_input", None) is None and not kwargs.get("fuse_gelu", False), ( + assert kwargs.get("gelu_input", None) is None and not fuse_bias, ( "TE GEMM was invoked with GeLU fusion options that are not supported by the " "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." @@ -989,8 +1023,6 @@ def gemm( ) # Discard empty outputs - fuse_bias = kwargs.get("fuse_bias", False) - fuse_gelu = kwargs.get("fuse_gelu", False) grad = kwargs.get("grad", False) clean_outputs = outputs[0] # first output is the final result and is never empty if (fuse_bias and grad) or (fuse_gelu and not grad): diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 987b0e817d..f7a84e41c9 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -100,7 +100,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, - bool fuse_bias, bool fuse_gelu, bool grad) { + bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { // Operands (this includes swizzling MXFP8 scaling factors) // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) @@ -162,8 +162,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), - rhs_transposed, lhs_transposed, grad, workspace_.data(), false, false, - num_math_sm, stream); + rhs_transposed, lhs_transposed, grad, workspace_.data(), false, + use_split_accumulator, num_math_sm, stream); return ffi_with_cuda_error_check(); } @@ -190,7 +190,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("rhs_transposed") .Attr("fuse_bias") .Attr("fuse_gelu") - .Attr("grad"), + .Attr("grad") + .Attr("use_split_accumulator"), FFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 98dc38de24..b4e809106e 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -157,26 +157,15 @@ def __post_init__(self): f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" f" {self.scale_inv.shape}" ) - expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, - self.is_colwise, - is_padded=False, - flatten_axis=self.flatten_axis, + pad_width = tuple( + (0, a - b) + for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) + ) + + # This actually pad scale_inv with nan, should we pad it with 127 directly instead? + self.scale_inv = jnp.pad( + self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0 ) - if self.scale_inv.shape != expected_scale_shape: - assert self.scale_inv.shape == expected_unpadded_scale_shape, ( - f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" - f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" - f" {self.scale_inv.shape}" - ) - pad_width = tuple( - (0, a - b) - for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) - ) - # This actually pad scale_inv with nan, should we pad it with 127 directly instead? - self.scale_inv = jnp.pad( - self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0 - ) def tree_flatten(self): """Flattens the tensor for JAX tree operations. From 1be8773ae58e7d71a0c98ec07085c8b1af47f056 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 17 Jun 2025 07:41:22 -0700 Subject: [PATCH 08/27] padding scale with 2^-127 instead of nans Signed-off-by: Phuong Nguyen --- transformer_engine/jax/quantize/tensor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index b4e809106e..19a8658b8c 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -162,11 +162,13 @@ def __post_init__(self): for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) ) - # This actually pad scale_inv with nan, should we pad it with 127 directly instead? + # padding with the smallest number it can present self.scale_inv = jnp.pad( - self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0 + self.scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127 ) + + def tree_flatten(self): """Flattens the tensor for JAX tree operations. From 75008de5893c3a768bab2b4f73786e1c6bbc97d0 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 17 Jun 2025 07:41:44 -0700 Subject: [PATCH 09/27] fix bug on rhs_scale_inv usage Signed-off-by: Phuong Nguyen --- transformer_engine/jax/csrc/extensions/gemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f7a84e41c9..249d229de4 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -111,7 +111,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise); auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand( - stream, rhs, rhs_scale_inv, lhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); + stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); // Output tensor std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], From 5b0c1f57d9c16840c04d2ca523bd63bf66f206ab Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 17 Jun 2025 08:28:09 -0700 Subject: [PATCH 10/27] cleanup E8M0 type converter use it in gemm.cpp Signed-off-by: Phuong Nguyen --- transformer_engine/jax/csrc/extensions/ffi.cpp | 7 +++---- transformer_engine/jax/csrc/extensions/gemm.cpp | 14 ++++++-------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/ffi.cpp b/transformer_engine/jax/csrc/extensions/ffi.cpp index a760df4a79..e77c38e990 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.cpp +++ b/transformer_engine/jax/csrc/extensions/ffi.cpp @@ -38,12 +38,11 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { case xla::ffi::DataType::F8E4M3FN: return DType::kFloat8E4M3; break; - // case xla::ffi::DataType::F8E8M0FNU: - // return DType::kFloat8E8M0; - // break; + case xla::ffi::DataType::F8E8M0FNU: + return DType::kFloat8E8M0; + break; default: auto type_num = static_cast(type); - if (type_num == 33) return DType::kFloat8E8M0; NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d", static_cast(type_num)); break; diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 249d229de4..a3a1899b14 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -45,9 +45,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( auto scale_dims = scale_inv.dimensions(); std::vector scale_shape = {product(scale_dims, 0, axis_boundary), product(scale_dims, axis_boundary, scale_dims.size())}; - auto scale_dtype = (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) - ? DType::kFloat8E8M0 - : convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); if (rowwise) { input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } else { @@ -66,14 +64,14 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( "Inverse scale factors need to have an 8-bit data type."); // Create tensor to hold swizzled scale factor - TensorWrapper output(NVTE_MXFP8_1D_SCALING); + TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); if (rowwise) { output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); } else { output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); } @@ -82,10 +80,10 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( // Set swizzled scales into the input tensor if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); } else { - input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), DType::kFloat8E8M0, + input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); } } From b49d586325b5af2afd0fdfb5e028675c5dd7520a Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 18 Jun 2025 06:12:09 +0000 Subject: [PATCH 11/27] segfault fixed, passing all unittests on Blackwell Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 2 +- tests/jax/test_layer.py | 14 +++--- transformer_engine/jax/cpp_extensions/gemm.py | 43 +++++++++++-------- .../jax/csrc/extensions/gemm.cpp | 10 +++-- 4 files changed, 40 insertions(+), 29 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9ea44bb1f7..3de6453f5e 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1118,7 +1118,7 @@ def ref_func(x, w, gamma, beta): @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) - @pytest.mark.parametrize("use_bias", [True, False]) + @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 389148415a..058dbf5bb4 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -479,6 +479,7 @@ def generate_inputs(self, data_shape, dtype): @pytest.mark.parametrize("data_shape", DATA_SHAPE) @pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("with_jax_gemm", (False, True)) @pytest.mark.parametrize("attrs", ATTRS) class BaseTester: """ @@ -494,20 +495,22 @@ def use_jax_dot_for_gemm(self, enabled=False): elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") - def test_forward(self, data_shape, dtype, attrs): + def test_forward(self, data_shape, dtype, with_jax_gemm, attrs): """Test normal datatype forward""" + self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.finalize() # Ensure FP8 disabled. self.runner(attrs).test_forward(data_shape, dtype) - def test_backward(self, data_shape, dtype, attrs): + + def test_backward(self, data_shape, dtype, with_jax_gemm, attrs): """Test normal datatype backward""" + self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.finalize() # Ensure FP8 disabled. self.runner(attrs).test_backward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) - @pytest.mark.parametrize("with_jax_gemm", (False, True)) - def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe, with_jax_gemm): + def test_forward_with_fp8(self, data_shape, dtype, with_jax_gemm, attrs, fp8_recipe): """Test forward with fp8 enabled""" self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.initialize(fp8_recipe=fp8_recipe) @@ -516,8 +519,7 @@ def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe, with_jax_g @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) - @pytest.mark.parametrize("with_jax_gemm", (False, True)) - def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe, with_jax_gemm): + def test_backward_with_fp8(self, data_shape, dtype, with_jax_gemm, attrs, fp8_recipe): """Test backward with fp8 enabled""" self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.initialize(fp8_recipe=fp8_recipe) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index c54314e283..a8cbe75702 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -25,7 +25,6 @@ GroupedScaledTensor1x, ScalingMode, Quantizer, - BlockScaleQuantizer, GroupedQuantizer, QuantizeConfig, QuantizerSet, @@ -114,37 +113,43 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ rhs_q = rhs if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0]) - lhs_is_rowwise = lhs.ndim - 1 in lhs_cdims - flatten_axis = min(lhs_cdims) if lhs_is_rowwise else max(lhs_cdims) + 1 + lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims + need_lhs_colwise = ( + lhs_is_transposed + and ( + lhs_quantizer.scaling_mode.is_1d_block_scaling() + or not tex.is_non_nt_fp8_gemm_supported() + ) + ) + flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims) lhs_q = lhs_quantizer.quantize( lhs, - is_rowwise=lhs_is_rowwise, - is_colwise=not lhs_is_rowwise, + is_rowwise=not need_lhs_colwise, + is_colwise=need_lhs_colwise, flatten_axis=flatten_axis, ) if isinstance(lhs_q, ScaledTensor2x): - lhs_q = lhs_q.get_rowwise_tensor() if lhs_is_rowwise else lhs_q.get_colwise_tensor() - if jnp.any(jnp.isnan(lhs_q.data)): - print("Found NaNs in quantized LHS data.") - if jnp.any(jnp.isnan(lhs_q.scale_inv)): - print("Found NaNs in quantized LHS scale_inv.") + lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) - rhs_is_rowwise = rhs.ndim - 1 in rhs_cdims - flatten_axis = min(rhs_cdims) if rhs_is_rowwise else max(rhs_cdims) + 1 + rhs_is_transposed = rhs.ndim - 1 in rhs_cdims + need_rhs_colwise = ( + not rhs_is_transposed + and ( + rhs_quantizer.scaling_mode.is_1d_block_scaling() + or not tex.is_non_nt_fp8_gemm_supported() + ) + ) + flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1 rhs_q = rhs_quantizer.quantize( rhs, - is_rowwise=rhs_is_rowwise, - is_colwise=not rhs_is_rowwise, + is_rowwise=not need_rhs_colwise, + is_colwise=need_rhs_colwise, flatten_axis=flatten_axis, ) if isinstance(rhs_q, ScaledTensor2x): - rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_rowwise else rhs_q.get_colwise_tensor() - if jnp.any(jnp.isnan(rhs_q.data)): - print("Found NaNs in quantized RHS data.") - if jnp.any(jnp.isnan(rhs_q.scale_inv)): - print("Found NaNs in quantized RHS scale_inv.") + rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() return lhs_q, rhs_q diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index a3a1899b14..c88aac3ff6 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -42,9 +42,13 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands."); NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM."); - auto scale_dims = scale_inv.dimensions(); - std::vector scale_shape = {product(scale_dims, 0, axis_boundary), - product(scale_dims, axis_boundary, scale_dims.size())}; + std::vector scale_shape = {1}; + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + // Block scaling also needs to be collapsed to match 2D data + scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary), + product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())}; + } + auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); if (rowwise) { input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); From bd9bca34cedbd9e36e6956c620cfaa1f2b1bdbf1 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 18 Jun 2025 06:07:35 -0700 Subject: [PATCH 12/27] fix for fuseddense tests Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 2 -- transformer_engine/jax/layernorm_dense.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0b7f68d79e..e4408f74c1 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -629,8 +629,6 @@ def _te_gemm( ) -======= ->>>>>>> main class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 3cdbeecd5a..816436a00e 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -293,7 +293,6 @@ def _layernorm_dense_bwd_rule( dgrad = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, - (g_constracting_dim, k_constracting_dim), contracting_dims=(g_constracting_dim, k_constracting_dim), ) @@ -307,7 +306,6 @@ def _layernorm_dense_bwd_rule( wgrad = tex.gemm( casted_ln_out, casted_grad.get_tensor(TensorUsage.RHS), - (x_constracting_dim, g_constracting_dim), contracting_dims=(x_constracting_dim, g_constracting_dim), ) From 8fcb1bb0f34de33c452c308170b29e9743284978 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 18 Jun 2025 08:02:55 -0700 Subject: [PATCH 13/27] fix workspace alignment Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 9 ++++---- .../jax/csrc/extensions/gemm.cpp | 21 +++++++++++-------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index e4408f74c1..d5946e1718 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -32,7 +32,7 @@ noop_quantizer_set, is_fp8_gemm_with_all_layouts_supported, ) -from ..sharding import get_padded_spec +from .misc import get_padded_spec __all__ = [ @@ -277,9 +277,10 @@ def abstract( rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype) # Declare cuBLAS workspace - workspace = jax.core.ShapedArray( - shape=(get_cublas_workspace_size_bytes(),), dtype=jnp.uint8 - ) + # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not + # necessarily 256 bytes aligned, we add some padding to ensure alignment. + workspace_size = get_cublas_workspace_size_bytes() + 256 + workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 0fb8734998..eed2ed2a4f 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -21,6 +21,13 @@ namespace transformer_engine { namespace jax { + +static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { + // Move the pointer to the next 256B aligned address + return reinterpret_cast((reinterpret_cast(ptr) + 255) & + ~static_cast(255)); +} + std::tuple> xla_buffer_to_nvte_gemm_operand( cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv, JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) { @@ -157,9 +164,11 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i } auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype); - // cuBLAS workspace - std::vector workspace_shape = {static_cast(workspace->element_count())}; - auto workspace_ = TensorWrapper(workspace->untyped_data(), workspace_shape, DType::kByte); + // cuBLAS workspace + 256 alignment enforcement + auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); + workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); + std::vector workspace_shape = {static_cast(workspace->element_count()) - 256}; + auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte); // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); @@ -197,12 +206,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, FFI_CudaGraph_Traits); -static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { - // Move the pointer to the next 256B aligned address - return reinterpret_cast((reinterpret_cast(ptr) + 255) & - ~static_cast(255)); -} - Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, From b2b41592b444ae3ac0b67b44ce7146ac6e25fc74 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Jun 2025 15:05:07 +0000 Subject: [PATCH 14/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_layer.py | 1 - transformer_engine/jax/cpp_extensions/gemm.py | 32 ++++++++----------- .../jax/csrc/extensions/gemm.cpp | 8 ++--- transformer_engine/jax/quantize/tensor.py | 5 +-- 4 files changed, 16 insertions(+), 30 deletions(-) diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 058dbf5bb4..9f3da5094f 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -501,7 +501,6 @@ def test_forward(self, data_shape, dtype, with_jax_gemm, attrs): QuantizeConfig.finalize() # Ensure FP8 disabled. self.runner(attrs).test_forward(data_shape, dtype) - def test_backward(self, data_shape, dtype, with_jax_gemm, attrs): """Test normal datatype backward""" self.use_jax_dot_for_gemm(enabled=with_jax_gemm) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d5946e1718..9cd3c85619 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -36,13 +36,13 @@ __all__ = [ - "gemm", - "grouped_gemm", - "gemm_uses_jax_dot", - "sanitize_dims", - "get_non_contracting_dims", - "transpose_contracting_dims", - ] + "gemm", + "grouped_gemm", + "gemm_uses_jax_dot", + "sanitize_dims", + "get_non_contracting_dims", + "transpose_contracting_dims", +] num_cublas_streams = get_num_compute_streams() @@ -109,12 +109,9 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0]) lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims - need_lhs_colwise = ( - lhs_is_transposed - and ( - lhs_quantizer.scaling_mode.is_1d_block_scaling() - or not is_fp8_gemm_with_all_layouts_supported() - ) + need_lhs_colwise = lhs_is_transposed and ( + lhs_quantizer.scaling_mode.is_1d_block_scaling() + or not is_fp8_gemm_with_all_layouts_supported() ) flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims) lhs_q = lhs_quantizer.quantize( @@ -130,12 +127,9 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) rhs_is_transposed = rhs.ndim - 1 in rhs_cdims - need_rhs_colwise = ( - not rhs_is_transposed - and ( - rhs_quantizer.scaling_mode.is_1d_block_scaling() - or not is_fp8_gemm_with_all_layouts_supported() - ) + need_rhs_colwise = not rhs_is_transposed and ( + rhs_quantizer.scaling_mode.is_1d_block_scaling() + or not is_fp8_gemm_with_all_layouts_supported() ) flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1 rhs_q = rhs_quantizer.quantize( diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index eed2ed2a4f..2c5f027ba1 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -21,7 +21,6 @@ namespace transformer_engine { namespace jax { - static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { // Move the pointer to the next 256B aligned address return reinterpret_cast((reinterpret_cast(ptr) + 255) & @@ -78,8 +77,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); if (rowwise) { output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, - scale_shape); + output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); } else { output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, @@ -91,8 +89,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( // Set swizzled scales into the input tensor if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, - scale_shape); + input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); } else { input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); @@ -205,7 +202,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("use_split_accumulator"), FFI_CudaGraph_Traits); - Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 04c9f2774f..a87326e9fe 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -149,8 +149,7 @@ def __post_init__(self): f" {self.scale_inv.shape}" ) pad_width = tuple( - (0, a - b) - for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) + (0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) ) # padding with the smallest number it can present @@ -158,8 +157,6 @@ def __post_init__(self): self.scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127 ) - - def tree_flatten(self): """Flattens the tensor for JAX tree operations. From ae4828c8df4904423da3bbf0edbc9475cba6412d Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 18 Jun 2025 22:51:20 +0000 Subject: [PATCH 15/27] fixed GemmPrimitive custom partitioning to match jax.nn.scaled_matmul Signed-off-by: Alp Dener all unit tests passing on H100x8 node Signed-off-by: Alp Dener [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci linting fixes Signed-off-by: Alp Dener fixed batch dimension numbers Signed-off-by: Alp Dener fixed FP8 scale sharding rule when there are no FP8 scales Signed-off-by: Alp Dener added error message for unsupported Shardy partitioner Signed-off-by: Alp Dener fixed test tolerances for FP8 cases Signed-off-by: Alp Dener fixed shardy test skip cases Signed-off-by: Alp Dener --- .../encoder/test_model_parallel_encoder.py | 23 +- examples/jax/encoder/test_multigpu_encoder.py | 23 +- .../encoder/test_multiprocessing_encoder.py | 24 +- tests/jax/test_custom_call_compute.py | 4 +- tests/jax/test_distributed_layernorm_mlp.py | 116 ++++- tests/jax/test_layer.py | 20 +- transformer_engine/jax/cpp_extensions/gemm.py | 458 ++++++++++++------ transformer_engine/jax/dense.py | 57 ++- transformer_engine/jax/layernorm_dense.py | 21 +- transformer_engine/jax/layernorm_mlp.py | 30 +- 10 files changed, 567 insertions(+), 209 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index b2bd18205f..d733fa7097 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -25,6 +25,7 @@ assert_params_sufficiently_sharded, ) import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_fp8_available, ScalingMode @@ -307,7 +308,9 @@ def train_and_evaluate(args): key: params_sharding[PARAMS_KEY] if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + jit_encoder_init = jax.jit(encoder.init, + in_shardings=in_shardings, + out_shardings=out_shardings) var_collect = jit_encoder_init(init_rngs, inputs, masks) # Check if params are sufficiently sharded after initialization @@ -344,11 +347,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, in_shardings, out_shardings) + jit_train_step = jax.jit(train_step, + in_shardings=in_shardings, + out_shardings=out_shardings) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit(eval_step, + in_shardings=in_shardings, + out_shardings=out_shardings) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -459,8 +466,8 @@ class TestEncoder(unittest.TestCase): is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) def setUp(self): - """Run 3 epochs for testing""" - self.args = encoder_parser(["--epochs", "3"]) + """Run 5 epochs for testing""" + self.args = encoder_parser(["--epochs", "5"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): @@ -510,6 +517,8 @@ def test_te_mxfp8_with_sp(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True @@ -517,6 +526,8 @@ def test_te_bf16_shardy(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_fp8_supported, fp8_reason) + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" self.args.enable_shardy = True @@ -526,6 +537,8 @@ def test_te_delayed_scaling_fp8_shardy(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_fp8_supported, fp8_reason) + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_delayed_scaling_fp8_with_sp_shardy(self): """Test Transformer Engine with DelayedScaling FP8 + SP""" self.args.enable_shardy = True diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index b6f4db1084..8c9751c620 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -21,6 +21,7 @@ from common import is_bf16_supported, get_fp8_recipe_from_name_string import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_fp8_available, ScalingMode @@ -288,7 +289,9 @@ def train_and_evaluate(args): out_shardings = { key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + jit_encoder_init = jax.jit(encoder.init, + in_shardings=in_shardings, + out_shardings=out_shardings) var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) @@ -312,11 +315,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, in_shardings, out_shardings) + jit_train_step = jax.jit(train_step, + in_shardings=in_shardings, + out_shardings=out_shardings) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit(eval_step, + in_shardings=in_shardings, + out_shardings=out_shardings) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -424,8 +431,8 @@ class TestEncoder(unittest.TestCase): is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) def setUp(self): - """Run 3 epochs for testing""" - self.args = encoder_parser(["--epochs", "3"]) + """Run 5 epochs for testing""" + self.args = encoder_parser(["--epochs", "5"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): @@ -458,6 +465,8 @@ def test_te_mxfp8(self): assert actual[0] < 0.535 and actual[1] > 0.73 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True @@ -465,6 +474,8 @@ def test_te_bf16_shardy(self): assert actual[0] < 0.535 and actual[1] > 0.73 @unittest.skipIf(not is_fp8_supported, fp8_reason) + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" self.args.enable_shardy = True @@ -476,6 +487,8 @@ def test_te_delayed_scaling_fp8_shardy(self): # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. @unittest.skipIf(not is_fp8_supported, fp8_reason) + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" self.args.enable_shardy = True diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index c7606c3ab0..35a0e766a4 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -28,8 +28,8 @@ get_fp8_recipe_from_name_string, ) import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -412,7 +412,9 @@ def train_and_evaluate(args): out_shardings = { key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings) + jit_encoder_init = jax.jit(encoder.init, + in_shardings=in_shardings, + out_shardings=out_shardings) var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) @@ -432,11 +434,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, in_shardings, out_shardings) + jit_train_step = jax.jit(train_step, + in_shardings=in_shardings, + out_shardings=out_shardings) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings) + jit_eval_step = jax.jit(eval_step, + in_shardings=in_shardings, + out_shardings=out_shardings) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -578,8 +584,8 @@ class TestEncoder(unittest.TestCase): """Encoder unittests""" def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): - """Run 3 epochs for testing""" - args = encoder_parser([]) + """Run 5 epochs for testing""" + args = encoder_parser(["--epochs", "5"]) num_gpu = self.num_process tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1 @@ -628,6 +634,8 @@ def test_te_mxfp8(self): assert result[0] < 0.505 and result[1] > 0.754 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" result = self.exec(False, None, enable_shardy=True) @@ -636,6 +644,8 @@ def test_te_bf16_shardy(self): @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" ) + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling", enable_shardy=True) @@ -646,6 +656,8 @@ def test_te_delayed_scaling_fp8_shardy(self): @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" ) + @unittest.skipIf(not tex.gemm_uses_jax_dot(), + "TE cuBLAS GEMM custom op does not support shardy") def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling", enable_shardy=True) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 4876aef99a..a50d5363ae 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -898,7 +898,7 @@ def _generate_gemm_input(self, m, n, k, data_layout): def test_gemm_bf16(self, m, n, k, data_layout): x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) - primitive_out = tex.gemm(x, w, contracting_dims) + primitive_out = tex.gemm(x, w, dimension_numbers=(contracting_dims, ((), ()))) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) @@ -929,7 +929,7 @@ def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, wi primitive_out = tex.gemm( x, w, - contracting_dims=contracting_dims, + dimension_numbers=(contracting_dims, ((), ())), lhs_quantizer=quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad, rhs_quantizer=( quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index f16c84094d..3c27f0d472 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -44,6 +44,7 @@ SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling")) if is_mxfp8_supported: SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) +SUPPORTED_RECIPES_WITH_SHARDY = SUPPORTED_RECIPES[:-1] if is_mxfp8_supported else SUPPORTED_RECIPES DTYPES = [jnp.bfloat16, jnp.float16] INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in] @@ -74,6 +75,15 @@ def generate_fsdp_and_tp_configs(): return configs +def use_jax_fp8_gemm(enabled=False): + import os + + if enabled: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") + + class TestDistributedLayernormMLP: def generate_inputs(self, input_shape, activation_type, use_bias, dtype): @@ -146,8 +156,17 @@ def layernorm_fp8_mlp_prim_func( ) def _test_layernorm_mlp_grad( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + use_shardy, + with_jax_gemm, ): + use_jax_fp8_gemm(enabled=with_jax_gemm) jax.config.update("jax_use_shardy_partitioner", use_shardy) device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config layernorm_type = "rmsnorm" @@ -208,20 +227,25 @@ def _test_layernorm_mlp_grad( multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) - assert_allclose(multi_fwd, single_fwd, dtype=dtype) + fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn + bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2 + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) for i in range(len(inputs)): if multi_grads[i] is not None: if isinstance(multi_grads[i], list): assert isinstance(single_grads[i], list) for m_grad, s_grad in zip(multi_grads[i], single_grads[i]): assert_allclose( - m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close" + m_grad, + s_grad, + dtype=bwd_test_type, + err_msg=f"multi_grads[{i}] is not close" ) else: assert_allclose( multi_grads[i], single_grads[i], - dtype=dtype, + dtype=bwd_test_type, err_msg=f"multi_grads[{i}] is not close", ) @@ -232,8 +256,16 @@ def _test_layernorm_mlp_grad( @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + with_jax_gemm, ): self._test_layernorm_mlp_grad( mesh_config, @@ -243,6 +275,7 @@ def test_layernorm_mlp_grad( dtype, fp8_recipe, use_shardy=False, + with_jax_gemm=with_jax_gemm, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -251,19 +284,22 @@ def test_layernorm_mlp_grad( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES_WITH_SHARDY) def test_layernorm_mlp_grad_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe ): - # We don't test block scaling with Shardy because at the time of writing, - # it is not supported in JAX's scaled_matmul_stablehlo. + # TE cuBLAS GEMM custom op does not implement shardy rules so we test shardy only with + # native JAX FP8 dot_general. We don't test block scaling with Shardy because at the + # time of writing, it is not supported in JAX's scaled_matmul_stablehlo. self._test_layernorm_mlp_grad( mesh_config, activation_type, use_bias, input_shape, dtype, - fp8_recipe=recipe.DelayedScaling(), + fp8_recipe=fp8_recipe, use_shardy=True, + with_jax_gemm=True, ) def _test_layernorm_mlp( @@ -276,7 +312,9 @@ def _test_layernorm_mlp( use_fp8, fp8_recipe, use_shardy, + with_jax_gemm, ): + use_jax_fp8_gemm(enabled=with_jax_gemm) jax.config.update("jax_use_shardy_partitioner", use_shardy) batch, seqlen, hidden_in = input_shape layernorm_type = "rmsnorm" @@ -340,9 +378,9 @@ def _test_layernorm_mlp( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("use_shardy", [False, True]) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer( - self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy + self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm ): self._test_layernorm_mlp( mesh_config, @@ -352,7 +390,8 @@ def test_layernorm_mlp_layer( dtype, use_fp8=False, fp8_recipe=None, - use_shardy=use_shardy, + use_shardy=False, + with_jax_gemm=with_jax_gemm, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -361,9 +400,10 @@ def test_layernorm_mlp_layer( @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES[:-1] if is_mxfp8_supported else SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm ): self._test_layernorm_mlp( mesh_config, @@ -374,4 +414,52 @@ def test_layernorm_mlp_layer_fp8( use_fp8=True, fp8_recipe=fp8_recipe, use_shardy=False, + with_jax_gemm=with_jax_gemm, + ) + + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("use_bias", [True, False]) + def test_layernorm_mlp_layer_shardy( + self, mesh_config, activation_type, use_bias, input_shape, dtype + ): + # TE cuBLAS GEMM custom op does not implement shardy rules so we test shardy only with + # native JAX dot_general. + self._test_layernorm_mlp( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8=False, + fp8_recipe=None, + use_shardy=True, + with_jax_gemm=True, + ) + + @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) + @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES_WITH_SHARDY) + def test_layernorm_mlp_layer_fp8_shardy( + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + ): + # TE cuBLAS GEMM custom op does not implement shardy rules so we test shardy only with + # native JAX FP8 dot_general. We don't test block scaling with Shardy because at the + # time of writing, it is not supported in JAX's scaled_matmul_stablehlo. + self._test_layernorm_mlp( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8=True, + fp8_recipe=fp8_recipe, + use_shardy=True, + with_jax_gemm=True, ) diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 9f3da5094f..d59e130530 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -479,7 +479,6 @@ def generate_inputs(self, data_shape, dtype): @pytest.mark.parametrize("data_shape", DATA_SHAPE) @pytest.mark.parametrize("dtype", DTYPE) -@pytest.mark.parametrize("with_jax_gemm", (False, True)) @pytest.mark.parametrize("attrs", ATTRS) class BaseTester: """ @@ -488,39 +487,28 @@ class BaseTester: runner = BaseRunner - def use_jax_dot_for_gemm(self, enabled=False): - """Enable/disable TE custom cuBLAS GEMM primitive.""" - if enabled: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" - elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: - os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") - - def test_forward(self, data_shape, dtype, with_jax_gemm, attrs): + def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" - self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.finalize() # Ensure FP8 disabled. self.runner(attrs).test_forward(data_shape, dtype) - def test_backward(self, data_shape, dtype, with_jax_gemm, attrs): + def test_backward(self, data_shape, dtype, attrs): """Test normal datatype backward""" - self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.finalize() # Ensure FP8 disabled. self.runner(attrs).test_backward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) - def test_forward_with_fp8(self, data_shape, dtype, with_jax_gemm, attrs, fp8_recipe): + def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test forward with fp8 enabled""" - self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.initialize(fp8_recipe=fp8_recipe) self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) QuantizeConfig.finalize() @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) - def test_backward_with_fp8(self, data_shape, dtype, with_jax_gemm, attrs, fp8_recipe): + def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test backward with fp8 enabled""" - self.use_jax_dot_for_gemm(enabled=with_jax_gemm) QuantizeConfig.initialize(fp8_recipe=fp8_recipe) self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) QuantizeConfig.finalize() diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 9cd3c85619..807ca62713 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -15,7 +15,7 @@ from jax.sharding import NamedSharding, PartitionSpec import transformer_engine_jax as tex -from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams +from transformer_engine_jax import get_num_compute_streams from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize @@ -41,7 +41,7 @@ "gemm_uses_jax_dot", "sanitize_dims", "get_non_contracting_dims", - "transpose_contracting_dims", + "transpose_dims", ] @@ -60,7 +60,7 @@ def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]: dims_ = dims if isinstance(dims, Iterable) else (dims,) if len(dims_) == 0: return dims_ - return tuple(ndim + dim if dim < 0 else dim for dim in dims_) + return tuple(ndim + dim if dim < 0 else dim for dim in dims_ if dim is not None) def get_non_contracting_dims(ndim, contracting_dims): @@ -69,10 +69,13 @@ def get_non_contracting_dims(ndim, contracting_dims): return tuple(dim for dim in range(ndim) if dim not in contracting_dims) -def transpose_contracting_dims(ndim, contracting_dims): - """Compute the new dimension numbers for contracting dimensions after a transpose.""" - contracting_dims = sanitize_dims(ndim, contracting_dims) - return tuple(ndim - i - 1 for i in contracting_dims)[::-1] +def transpose_dims(ndim, dims_to_transpose, flatten_axis=-1): + """Compute the new dimension numbers after transpose.""" + if len(dims_to_transpose) == 0: + return dims_to_transpose + flatten_axis = ndim - flatten_axis if flatten_axis > 0 else flatten_axis + transposed_dims = (*range(flatten_axis, ndim), *range(flatten_axis)) + return tuple(transposed_dims.index(dim) for dim in dims_to_transpose) def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool: @@ -106,6 +109,7 @@ def _get_gemm_layout( def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims): lhs_q = lhs rhs_q = rhs + if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0]) lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims @@ -120,9 +124,6 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ is_colwise=need_lhs_colwise, flatten_axis=flatten_axis, ) - # TODO: remove - # if isinstance(lhs_q, ScaledTensor2x): - # lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) @@ -138,8 +139,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ is_colwise=need_rhs_colwise, flatten_axis=flatten_axis, ) - # if isinstance(rhs_q, ScaledTensor2x): - # rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() + assert not isinstance(lhs_q, ScaledTensor2x) assert not isinstance(rhs_q, ScaledTensor2x) @@ -166,7 +166,7 @@ def abstract( bias, gelu_input, out_dtype, - contracting_dims, + dimension_numbers, scaling_mode, fuse_bias, fuse_gelu, @@ -177,6 +177,7 @@ def abstract( # Sanity-check operand layouts and types operand_ndims = (lhs.ndim, rhs.ndim) + contracting_dims, _ = dimension_numbers ( lhs_contracting_dims, rhs_contracting_dims, @@ -293,7 +294,7 @@ def lowering( bias, gelu_input, out_dtype, - contracting_dims, + dimension_numbers, scaling_mode, fuse_bias, fuse_gelu, @@ -301,6 +302,7 @@ def lowering( use_split_accumulator, ): del out_dtype + contracting_dims, _ = dimension_numbers lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) lhs_transposed, rhs_transposed = _get_gemm_layout( @@ -340,7 +342,7 @@ def impl( bias, gelu_input, out_dtype, - contracting_dims, + dimension_numbers, scaling_mode, fuse_bias, fuse_gelu, @@ -355,7 +357,7 @@ def impl( bias, gelu_input, out_dtype=out_dtype, - contracting_dims=contracting_dims, + dimension_numbers=dimension_numbers, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -369,7 +371,7 @@ def batcher( batched_args, batch_dims, out_dtype, - contracting_dims, + dimension_numbers, scaling_mode, fuse_bias, fuse_gelu, @@ -378,12 +380,30 @@ def batcher( ): assert GemmPrimitive.outer_primitive is not None lhs, _, rhs, *_ = batched_args - lhs_bdims, *_ = batch_dims + lhs_bdims, _, rhs_bdims, *_ = batch_dims + contracting_dims, batch_dims = dimension_numbers + arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batch_dims) + arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims + assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), ( + "User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch " + f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}." + ) + arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims + assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), ( + "User-specified batch dimension(s) for cuBLAS GEMM RHS operand does not match batch " + f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}." + ) - # Output is batched like LHS only if LHS is batched and RHS is not - out_bdims = lhs_bdims if lhs.ndim > 2 and rhs.ndim == 2 else (None,) - bias_bdims = (None,) # Bias is never batched - pre_gelu_bdims = (None,) # Pre-GeLU output, if exists, is batched like GEMM output + # Output is batched like the non-contracting batch dimensions of the LHS operand + lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims) + lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims) + out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims + + # Bias gradient is never batched + bias_bdims = (None,) + + # Pre-GeLU output, if exists, is batched like GEMM output + pre_gelu_bdims = (None,) if fuse_gelu and not grad: pre_gelu_bdims = out_bdims @@ -391,7 +411,7 @@ def batcher( GemmPrimitive.outer_primitive.bind( *batched_args, out_dtype=out_dtype, - contracting_dims=contracting_dims, + dimension_numbers=dimension_numbers, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -401,10 +421,171 @@ def batcher( (out_bdims, bias_bdims, pre_gelu_bdims), ) + @staticmethod + def _decompose_operand_specs(specs, contracting_dims, batch_dims): + ndims = len(specs) + cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims)) + + # Batch specs + bspecs = tuple(specs[i] for i in bdims) + + # Non-batch leading dimension specs + lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims) + + # Non-batch contracting dimension specs + cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims) + + return bspecs, lspecs, cspecs + + @staticmethod + def _parse_operand_output_specs(arg_infos, dimension_numbers): + lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) + contracting_dims, batch_dims = dimension_numbers + lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) + lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( + sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batch_dims + ) + ( + (lhs_bspecs, lhs_lspecs, lhs_cspecs), + (rhs_bspecs, rhs_lspecs, rhs_cspecs), + ) = map( + GemmPrimitive._decompose_operand_specs, + (lhs_specs, rhs_specs), + (lhs_cdims, rhs_cdims), + (lhs_bdims, rhs_bdims), + ) + + # Batched dimensions must have the same sharding + if len(lhs_bdims) > 0 and len(rhs_bdims) > 0: + assert all( + lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs) + ), ( + "cuBLAS GEMM operand batch dimensions must have the same sharding: " + f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}." + ) + + # Only one each of the non-batched leading dimensions and non-batched contracting + # dimensions can be sharded + lhs_ldims, rhs_ldims = map( + lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude), + (lhs_ndim, rhs_ndim), + (lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims), + ) + (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map( + lambda specs: tuple(spec for spec in specs if spec is not None), + (lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs), + ) + assert ( + len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1 + ), ( + "cuBLAS GEMM operands can have only one sharded non-batched leading dimension: " + f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}." + ) + assert ( + len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1 + ), ( + "cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: " + f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}." + ) + + # Extract single leading and contracting dimension specs + (lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map( + lambda specs: None if len(specs) == 0 else specs[0], + (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none), + ) + + # Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts + # with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands. + # 1. K1 == K2 != None and N == None + # LHS: (B, M, K) + # RHS: (B, None, K) + # OUT: (B, M, None) --(AR)-> (B, M, None) + # 2. K1 == K2 != None and M == N != None + # LHS: (B, M, K) + # RHS: (B, N, K)--(AG)->(B, None, K) + # OUT: (B, M, None) --(RS)--> (B, M, N) + # 3. M == N + # LHS: (B, M, K)--(AG)->(B, M, None) + # RHS: (B, M, K)--(AG)->(B, None, None) + # OUT: (B, M, None) + # 4. M != N + # LHS: (B, M, K)--(AG)->(B, M, None) + # RHS: (B, N, K)--(AG)->(B, N, None) + # OUT: (B, M, N) + reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec + all_reduce_output = reduce_flag and rhs_lspec is None + reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec + all_reduce_spec = reduce_scatter_spec = scatter_dim = None + + lhs_non_contracting_specs, rhs_non_contracting_specs = map( + lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims), + (lhs_specs, rhs_specs), + (lhs_cdims, rhs_cdims), + ) + out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs) + if reduce_scatter_output: + # All-gather (if necessary) the non-batch non-contracting dimension of RHS + # (B, N, K) --(AG)-> (B, None, K) + # (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N) + rhs_spec = tuple( + rhs_spec[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim) + ) + reduce_scatter_spec = lhs_cspec + scatter_dim = out_specs.index(rhs_lspec) + + elif all_reduce_output: + # Set all output trailing dimensions to zero + out_specs = ( + *lhs_non_contracting_specs, + *[None for _ in range(len(rhs_non_contracting_specs))], + ) + all_reduce_spec = lhs_cspec + else: + # All-gather (if necessary) the non-batch contracting dimensions + # (B, M, K) --(AG)-> (B, M, None) + # (B, N, K) --(AG)-> (B, N, None) + # (B, M, None) x (B, N, None)^T = (B, M, N) + lhs_specs = tuple( + None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i] + for i in range(lhs_ndim) + ) + rhs_specs = tuple( + None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i] + for i in range(rhs_ndim) + ) + # Check if RHS non-contracting spec also appears in the LHS non-contracting specs + if rhs_lspec is not None and rhs_lspec in tuple( + lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims + ): + # All-gather (if necessary) the non-batch non-contracting dimensions of RHS + # (B, N, None) --(AG)-> (B, None, None) + # (B, M, None) x (B, None, None)^T = (B, M, None) + rhs_specs = tuple( + None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i] + for i in range(rhs_ndim) + ) + # Set all output trailing dimensions to zero + out_specs = ( + *lhs_non_contracting_specs, + *[None for _ in range(len(rhs_non_contracting_specs))], + ) + + # Bias and Pre-GeLU sharding is based on GEMM output + bias_specs = out_specs[len(lhs_non_contracting_specs) :] + gelu_specs = out_specs + + return ( + (lhs_specs, rhs_specs, bias_specs, gelu_specs), + (out_specs, bias_specs, gelu_specs), + all_reduce_spec, + reduce_scatter_spec, + scatter_dim, + ) + @staticmethod def infer_sharding_from_operands( out_dtype, - contracting_dims, + dimension_numbers, scaling_mode, fuse_bias, fuse_gelu, @@ -414,72 +595,29 @@ def infer_sharding_from_operands( arg_infos, result_infos, ): - del out_dtype, scaling_mode, use_split_accumulator, result_infos + del out_dtype, scaling_mode, grad, use_split_accumulator, result_infos - # Check contracting dimensions - lhs_spec, _, rhs_spec, *_ = map(get_padded_spec, arg_infos) - operand_ndims = (len(lhs_spec), len(rhs_spec)) - lhs_contracting_dims, rhs_contracting_dims = map( - sanitize_dims, operand_ndims, contracting_dims - ) - lhs_contracting_specs, rhs_contracting_specs = map( - lambda specs, dims: [specs[dim] for dim in dims if specs[dim] is not None], - (lhs_spec, rhs_spec), - (lhs_contracting_dims, rhs_contracting_dims), - ) - assert ( - len(lhs_contracting_specs) <= 1 and len(rhs_contracting_specs) <= 1 - ), "cuBLAS GEMM operands can have only one sharded contracting dimension." - lhs_contracting_spec, rhs_contracting_spec = map( - lambda spec: None if len(spec) == 0 else spec[0], - (lhs_contracting_specs, rhs_contracting_specs), + (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( + GemmPrimitive._parse_operand_output_specs(arg_infos, dimension_numbers) ) - assert ( - lhs_contracting_spec == rhs_contracting_spec - ), "cuBLAS GEMM operands must have the same sharding in contracting dimensions." + out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) - # Sanity check leading dimensions, allow for simultaneous batch and sequence sharding - lhs_leading_dims, rhs_leading_dims = map( - get_non_contracting_dims, operand_ndims, (lhs_contracting_dims, rhs_contracting_dims) - ) - lhs_leading_specs, rhs_leading_specs = map( - lambda specs, dims: [specs[dim] for dim in dims if specs[dim] is not None], - (lhs_spec, rhs_spec), - (lhs_leading_dims, rhs_leading_dims), - ) - assert len(lhs_leading_specs) <= 1 and len(rhs_leading_specs) <= 1, ( - "cuBLAS GEMM operands cannot have more than one sharded leading dimensions. This error " - "usually means a sequence-parallel operand was not all-gathered before the GEMM op." - ) + # Discard bias gradient spec if there is no bias fusion + if not fuse_bias: + dbias_specs = (None,) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs)) - # Determine output sharding - lhs_leading_spec, rhs_leading_spec = map( - lambda spec: None if len(spec) == 0 else spec[0], (lhs_leading_specs, rhs_leading_specs) - ) - out_spec = (lhs_leading_spec, rhs_leading_spec) - if operand_ndims[0] > 2 and operand_ndims[1] == 2: - # Restore batch dimensions/sharding to the output - out_spec = (*lhs_leading_specs, rhs_leading_spec) - out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) - - # Bias gradient sharding inherits the RHS contracting spec - bias_spec = (None,) - if fuse_bias and grad: - bias_spec = (rhs_contracting_spec,) - bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - - # Pre-GeLU sharding matches output sharding - pre_gelu_spec = (None,) - if fuse_gelu and not grad: - pre_gelu_spec = out_spec - pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_spec)) + # Discard pre-GeLU output spec if there is no GeLU fusion + if not fuse_gelu: + pre_gelu_specs = (None,) + pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)) - return (out_sharding, bias_sharding, pre_gelu_sharding) + return [out_sharding, dbias_sharding, pre_gelu_sharding] @staticmethod def partition( out_dtype, - contracting_dims, + dimension_numbers, scaling_mode, fuse_bias, fuse_gelu, @@ -489,51 +627,96 @@ def partition( arg_infos, result_infos, ): - out_shardings = GemmPrimitive.infer_sharding_from_operands( - out_dtype, - contracting_dims, - scaling_mode, - fuse_bias, - fuse_gelu, - grad, - use_split_accumulator, - mesh, - arg_infos, - result_infos, + del result_infos + + ( + (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), + (out_specs, dbias_specs, pre_gelu_specs), + all_reduce_spec, + reduce_scatter_spec, + scatter_dim, + ) = GemmPrimitive._parse_operand_output_specs(arg_infos, dimension_numbers) + + # Assemble argument shardings + # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. + none_sharding = NamedSharding(mesh, PartitionSpec(None)) + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs)) + rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs)) + arg_shardings = ( + lhs_sharding, + lhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + rhs_sharding, + rhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, ) - output_spec = out_shardings[0].spec - # Operand shardings are already guarded with asserts so leave them unchanged here - lhs_spec, _, rhs_spec, *_ = map(get_padded_spec, arg_infos) - lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec)) - rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec)) + # Discard bias input spec if there is no bias fusion + if not fuse_bias: + bias_input_specs = (None,) + arg_shardings += (NamedSharding(mesh, PartitionSpec(*bias_input_specs)),) + + # Discard pre-GeLU input spec if there is no GeLU fusion + if not fuse_gelu: + gelu_input_specs = (None,) + arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),) + + # Assemble output shardings + out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))] + + # Discard bias gradient spec if there is no bias fusion + if not fuse_bias: + dbias_specs = (None,) + out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs))) + + # Discard pre-GeLU output spec if there is no GeLU fusion + if not fuse_gelu: + pre_gelu_specs = (None,) + out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))) + + def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): + outputs = GemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype=out_dtype, + dimension_numbers=dimension_numbers, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + ) - # Any distributed scales (e.g. MXFP8) need to be gathered - scale_sharding = NamedSharding(mesh, PartitionSpec(None)) + # All-Reduce/Reduce-Scatter GEMM output + if all_reduce_spec is not None: + outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec) + if fuse_gelu and not grad: + outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec) + elif reduce_scatter_spec is not None: + outputs[0] = jax.lax.psum_scatter( + outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True + ) + if fuse_gelu and not grad: + outputs[2] = jax.lax.psum_scatter( + outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True + ) - # Bias has to be sharded same as the trailing dimension of the GEMM output - bias_spec = (None,) - if fuse_bias and not grad: - bias_spec = (output_spec[-1],) - bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + return outputs - # Pre-GeLU output has to be sharded same as the GEMM output - pre_gelu_spec = (None,) - if fuse_gelu and grad: - pre_gelu_spec = output_spec - pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_spec)) + return mesh, _sharded_impl, out_shardings, arg_shardings - arg_shardings = ( - lhs_sharding, - scale_sharding, - rhs_sharding, - scale_sharding, - bias_sharding, - pre_gelu_sharding, + @staticmethod + def shardy_sharding_rule(*args, **kwargs): + del args, kwargs + raise NotImplementedError( + "TE cuBLAS GEMM custom op does not support the Shardy partitioner. You can disable the " + "custom op by setting `NVTE_JAX_CUSTOM_CALLS_RE=\"^(?!GemmPrimitive$).+$\"` in the " + "environment, which will make GEMM operations in TE will execute with native " + "`jax.lax.dot_general` and `jax.nn.scaled_matmul` calls." ) - return mesh, GemmPrimitive.impl, out_shardings, arg_shardings - register_primitive(GemmPrimitive) @@ -550,7 +733,7 @@ def _te_gemm( gelu_input: jax.Array = None, lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (-2,)), + dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]] = (((-1,), (0,)), ((), ())), fuse_bias: bool = False, fuse_gelu: bool = False, grad: bool = False, @@ -562,8 +745,10 @@ def _te_gemm( lhs_scale_inv = jnp.empty(0, dtype=jnp.float32) rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) scaling_mode = ScalingMode.NO_SCALING + contracting_dims, batch_dims = dimension_numbers lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) + lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batch_dims) # Quantize operands (if necessary) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) @@ -581,7 +766,8 @@ def _te_gemm( lhs_data = lhs_q.data lhs_scale_inv = lhs_q.scale_inv if lhs_q.data_layout == "T": - lhs_cdims = transpose_contracting_dims(lhs_q.ndim, lhs_cdims) + lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) + lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis) if isinstance(rhs_q, ScaledTensor): assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( @@ -598,7 +784,8 @@ def _te_gemm( rhs_data = rhs_q.data rhs_scale_inv = rhs_q.scale_inv if rhs_q.data_layout == "T": - rhs_cdims = transpose_contracting_dims(rhs_q.ndim, rhs_cdims) + rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) + rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis) # Dummy empties for bias and gelu out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype @@ -615,7 +802,7 @@ def _te_gemm( bias, gelu_input, out_dtype=out_dtype, - contracting_dims=(lhs_cdims, rhs_cdims), + dimension_numbers=((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)), scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -810,9 +997,11 @@ def _calculate_remaining_shape(shape, contracting_dims): def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums if lhs.data_layout == "T": - lhs_contract = transpose_contracting_dims(lhs.data.ndim, lhs_contract) + lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis) + lhs_batch = transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis) if rhs.data_layout == "T": - rhs_contract = transpose_contracting_dims(rhs.data.ndim, rhs_contract) + rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis) + rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis) dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) @@ -911,12 +1100,6 @@ def _jax_gemm_fp8_impl(lhs, rhs): lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor): - # if isinstance(lhs_q, ScaledTensor2x): - # lhs_is_transposed = lhs.ndim - 1 not in sanitize_dims(lhs.ndim, contracting_dims[0]) - # lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() - # if isinstance(rhs_q, ScaledTensor2x): - # rhs_is_transposed = rhs.ndim - 1 in sanitize_dims(rhs.ndim, contracting_dims[1]) - # rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() return _jax_gemm_fp8_impl(lhs_q, rhs_q) if ( @@ -933,7 +1116,7 @@ def _jax_gemm_fp8_impl(lhs, rhs): def gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (-2,)), + dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]] = (((-1,), (0,)), ((), ())), lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, **kwargs, @@ -950,10 +1133,11 @@ def gemm( Object for down-casting the LHS operand for quantized GEMM. rhs_quantizer: Quantizer, default = None Object for down-casting the RHS operand for quantized GEMM. - contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (-2, )) - Tuple of two sequences representing the contracting dimensions. The first sequence - represents the contracting dimensions of the LHS operand, and the second sequence - represents the contracting dimensions of the RHS operand. + dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]], default = (((-1, ), (0, )), ((), ())) + Tuple of two tuples of sequences representing the contracting and batched dimensions, + respectively. The first sequence in each tuple represents the contracting/batched + dimensions of the LHS operand, and the second sequence represents the contracting/batched + dimensions of the RHS operand. bias: jax.Array, default = None Optional additive bias term, required for forward GEMM with bias fusion. Only supported with TE's custom call to cuBLAS GEMM. @@ -1010,14 +1194,14 @@ def gemm( "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." ) - return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) + return _jax_gemm(lhs, rhs, dimension_numbers[0], lhs_quantizer, rhs_quantizer) outputs = _te_gemm( lhs, rhs, lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, - contracting_dims=contracting_dims, + dimension_numbers=dimension_numbers, **kwargs, ) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index b2721915a1..6ed91f4319 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -30,6 +30,7 @@ def dense( contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, + batch_first: bool = True, quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. @@ -43,6 +44,7 @@ def dense( kernel: Weight matrix for the dense layer transformation bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract + batch_first: Assume that X is batched in the first dimension. quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: @@ -51,17 +53,19 @@ def dense( # Remove when tex.quantize() can handle quantizer=None if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot(): x = with_sharding_constraint_by_logical_axes(x, input_axes) - output = tex.gemm(x, kernel, contracting_dims) + output = tex.gemm(x, kernel, dimension_numbers=(contracting_dims, ((), ()))) if bias is not None: bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) else: - output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set) + output = _dense( + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set + ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) -def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) +def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set): """Internal implementation of dense layer transformation with custom VJP. This function implements the core dense layer transformation logic with support @@ -75,23 +79,46 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix quantizer_set: QuantizerSet which contains quantizers for different tensor types + batch_first: Assume that X is batched in the first dimension. Returns: Transformed output tensor """ output, _ = _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set ) return output -def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): +def _dense_fwd_rule( + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set +): """Forward pass rule for dense layer transformation. Returns: Tuple of (output, context) for backward pass """ - x_contracting_dims, k_contracting_dims = contracting_dims + x_contracting_dims, k_contracting_dims = map( + tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims + ) + + # Check supported input layout + x_is_transposed = x.ndim - 1 not in x_contracting_dims + k_is_transposed = kernel.ndim - 1 in k_contracting_dims + assert ( + not x_is_transposed and not k_is_transposed + ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel." + + # Determine X batch dimension + # - If `batch_first=True` -> (batch, leading..., contracting...) + # - Otherwise -> (leading..., batch, contracting...) + # NOTE: Always assume a single batch dimension + x_bdim = None + num_cdims = len(x_contracting_dims) + if x.ndim >= num_cdims + 2: + # Assume X is batched if it has at least +2 dimensions more than the number of contracting + # dimensions. + x_bdim = 0 if batch_first else x.ndim - num_cdims - 1 flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) @@ -114,7 +141,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, output = tex.gemm( casted_x.get_tensor(usage=TensorUsage.LHS), casted_kernel.get_tensor(usage=TensorUsage.RHS), - (x_contracting_dims, k_contracting_dims), + dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) @@ -131,20 +158,19 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, use_bias, quantizer_set, flatten_axis_k, + x_bdim, ) return output, ctx def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, ctx, grad + contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad ): # pylint: disable=unused-argument """Backward pass rule for dense layer transformation. Returns: Tuple of gradients with respect to inputs """ - fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims - ( casted_x_lhs, casted_kernel_rhs, @@ -153,8 +179,13 @@ def _dense_bwd_rule( use_bias, quantizer_set, flatten_axis_k, + x_bdim, ) = ctx + fwd_x_contracting_dims, fwd_k_contracting_dims = map( + tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims + ) + casted_grad, dbias = tex.quantize_dbias( grad, is_dbias=use_bias, @@ -175,7 +206,7 @@ def _dense_bwd_rule( dgrad = tex.gemm( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, - contracting_dims=(g_contracting_dim, k_contracting_dim), + dimension_numbers=((g_contracting_dim, k_contracting_dim), ((x_bdim,), ())), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) @@ -188,7 +219,7 @@ def _dense_bwd_rule( wgrad = tex.gemm( casted_x_lhs, casted_grad.get_tensor(usage=TensorUsage.RHS), - contracting_dims=(x_contracting_dim, g_contracting_dim), + dimension_numbers=((x_contracting_dim, g_contracting_dim), ((x_bdim,), (x_bdim, ))), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 816436a00e..5608b24c5f 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -37,6 +37,7 @@ def layernorm_dense( layernorm_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, + batch_first: bool = True, quantizer_set: QuantizerSet = noop_quantizer_set, ) -> jnp.ndarray: """Apply layer normalization followed by dense layer transformation. @@ -57,6 +58,7 @@ def layernorm_dense( layernorm_input_axes: Logical axes for sharding the layernorm input dot_input_axes: Logical axes for sharding the matrix multiplication input kernel_axes: Logical axes for sharding the weight matrix + batch_first: Assume that X is batched in the first dimension. quantizer_set: Set of quantizers for different tensor types Returns: @@ -80,6 +82,7 @@ def layernorm_dense( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, quantizer_set, ) return output @@ -94,6 +97,7 @@ def layernorm_dense( 8, 9, 10, + 11, ), ) def _layernorm_dense( @@ -108,6 +112,7 @@ def _layernorm_dense( layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...], + batch_first: bool, quantizer_set, ): """Internal implementation of layernorm_dense with custom VJP. @@ -127,6 +132,7 @@ def _layernorm_dense( epsilon: Small constant for numerical stability layernorm_input_axes: Logical axes for layernorm sharding dot_input_axes: Logical axes for matrix multiplication sharding + batch_first: Assume that X is batched in the first dimension. quantizer_set: Set of quantizers Returns: @@ -144,6 +150,7 @@ def _layernorm_dense( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, quantizer_set, ) return output @@ -161,6 +168,7 @@ def _layernorm_dense_fwd_rule( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, quantizer_set, ): """Forward pass rule for layernorm_dense. @@ -178,6 +186,10 @@ def _layernorm_dense_fwd_rule( k_contracting_dims = (0,) assert x.shape[-1] == kernel.shape[0] + x_bdim = None + if x.ndim > 2: + x_bdim = 0 if batch_first else x.ndim - 2 + x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) casted_ln_out, mu, rsigma = tex.normalization_fwd( @@ -205,7 +217,7 @@ def _layernorm_dense_fwd_rule( output = tex.gemm( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel.get_tensor(TensorUsage.RHS), - contracting_dims=(x_contracting_dims, k_contracting_dims), + dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) @@ -229,6 +241,7 @@ def _layernorm_dense_fwd_rule( use_bias, quantizer_set, flatten_axis, + x_bdim, ) return output, ctx @@ -241,6 +254,7 @@ def _layernorm_dense_bwd_rule( layernorm_input_axes, dot_input_axes, # pylint: disable=unused-argument kernel_axes, + batch_first, # pylint: disable=unused-argument ctx, grad, ): @@ -270,6 +284,7 @@ def _layernorm_dense_bwd_rule( use_bias, quantizer_set, flatten_axis, + x_bdim, ) = ctx casted_grad, dbias = tex.quantize_dbias( @@ -293,7 +308,7 @@ def _layernorm_dense_bwd_rule( dgrad = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, - contracting_dims=(g_constracting_dim, k_constracting_dim), + dimension_numbers=((g_constracting_dim, k_constracting_dim), ((x_bdim,), ())), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) @@ -306,7 +321,7 @@ def _layernorm_dense_bwd_rule( wgrad = tex.gemm( casted_ln_out, casted_grad.get_tensor(TensorUsage.RHS), - contracting_dims=(x_constracting_dim, g_constracting_dim), + dimension_numbers=((x_constracting_dim, g_constracting_dim), ((x_bdim,), (x_bdim, ))), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index fb9a3163b9..53decc3c2a 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -48,6 +48,7 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + batch_first: bool = True, quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> jnp.ndarray: """Apply layer normalization followed by MLP block. @@ -79,6 +80,7 @@ def layernorm_mlp( ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation + batch_first: Assume that X is batched in the first dimension. quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations Returns: @@ -124,12 +126,13 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -149,6 +152,7 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + batch_first: bool, quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -174,6 +178,7 @@ def _layernorm_mlp( ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing activation_type: Activation function(s) + batch_first: Assume that X is batched in the first dimension. quantizer_sets: Tuple of quantizer sets Returns: @@ -198,6 +203,7 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, quantizer_sets, ) return output @@ -222,6 +228,7 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, quantizer_sets, ): """Forward pass rule for layernorm_mlp. @@ -254,6 +261,10 @@ def _layernorm_mlp_fwd_rule( assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] + x_bdim = None + if x.ndim > 2: + x_bdim = 0 if batch_first else x.ndim - 2 + use_bias_1 = bias_1 is not None use_bias_2 = bias_1 is not None @@ -280,7 +291,7 @@ def _layernorm_mlp_fwd_rule( dot_1_output = tex.gemm( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel_1.get_tensor(TensorUsage.RHS), - contracting_dims=(x_contracting_dims, k_contracting_dims), + dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), bias=bias_1 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, ) @@ -315,7 +326,7 @@ def _layernorm_mlp_fwd_rule( dot_2_output = tex.gemm( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), - contracting_dims=(x_contracting_dims, k_contracting_dims), + dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, ) @@ -345,6 +356,7 @@ def _layernorm_mlp_fwd_rule( use_bias_1, use_bias_2, quantizer_sets, + x_bdim, ) return dot_2_output, ctx @@ -362,6 +374,7 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, ctx, grad, ): @@ -378,7 +391,7 @@ def _layernorm_mlp_bwd_rule( Returns: Tuple of gradients for all input parameters """ - del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name + del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first ( x, mu, @@ -397,6 +410,7 @@ def _layernorm_mlp_bwd_rule( use_bias_1, use_bias_2, quantizer_sets, + x_bdim, ) = ctx ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets @@ -422,7 +436,7 @@ def _layernorm_mlp_bwd_rule( dgrad_2 = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel_2, - contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), + dimension_numbers=((g_contracting_dims_2, k_contracting_dims_2), ((x_bdim,), ())), ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) @@ -436,7 +450,7 @@ def _layernorm_mlp_bwd_rule( wgrad_2 = tex.gemm( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), - contracting_dims=(x_contracting_dims, g_contracting_dims), + dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim, ), (x_bdim, ))), ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -463,7 +477,7 @@ def _layernorm_mlp_bwd_rule( dgrad_1 = tex.gemm( casted_dact_out.get_tensor(TensorUsage.LHS), casted_kernel_1, - contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), + dimension_numbers=((g_contracting_dims_1, k_contracting_dims_1), ((x_bdim,), ())), ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) @@ -473,7 +487,7 @@ def _layernorm_mlp_bwd_rule( wgrad_1 = tex.gemm( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), - contracting_dims=(x_contracting_dims, g_contracting_dims), + dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim, ))), ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) From ddaaab95dc32f7f9e38da389c4ccc1c957cb49af Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Jun 2025 19:40:05 +0000 Subject: [PATCH 16/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../encoder/test_model_parallel_encoder.py | 33 ++++++++++--------- examples/jax/encoder/test_multigpu_encoder.py | 33 ++++++++++--------- .../encoder/test_multiprocessing_encoder.py | 33 ++++++++++--------- tests/jax/test_distributed_layernorm_mlp.py | 6 ++-- transformer_engine/jax/cpp_extensions/gemm.py | 10 ++---- transformer_engine/jax/dense.py | 2 +- transformer_engine/jax/layernorm_dense.py | 2 +- transformer_engine/jax/layernorm_mlp.py | 4 +-- 8 files changed, 65 insertions(+), 58 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index d733fa7097..00203a4537 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -308,9 +308,9 @@ def train_and_evaluate(args): key: params_sharding[PARAMS_KEY] if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_encoder_init = jax.jit( + encoder.init, in_shardings=in_shardings, out_shardings=out_shardings + ) var_collect = jit_encoder_init(init_rngs, inputs, masks) # Check if params are sufficiently sharded after initialization @@ -347,15 +347,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_train_step = jax.jit( + train_step, in_shardings=in_shardings, out_shardings=out_shardings + ) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_eval_step = jax.jit( + eval_step, in_shardings=in_shardings, out_shardings=out_shardings + ) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -517,8 +517,9 @@ def test_te_mxfp8_with_sp(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True @@ -526,8 +527,9 @@ def test_te_bf16_shardy(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_fp8_supported, fp8_reason) - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" self.args.enable_shardy = True @@ -537,8 +539,9 @@ def test_te_delayed_scaling_fp8_shardy(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_fp8_supported, fp8_reason) - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_delayed_scaling_fp8_with_sp_shardy(self): """Test Transformer Engine with DelayedScaling FP8 + SP""" self.args.enable_shardy = True diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 8c9751c620..3d7fab0115 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -289,9 +289,9 @@ def train_and_evaluate(args): out_shardings = { key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_encoder_init = jax.jit( + encoder.init, in_shardings=in_shardings, out_shardings=out_shardings + ) var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) @@ -315,15 +315,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_train_step = jax.jit( + train_step, in_shardings=in_shardings, out_shardings=out_shardings + ) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_eval_step = jax.jit( + eval_step, in_shardings=in_shardings, out_shardings=out_shardings + ) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -465,8 +465,9 @@ def test_te_mxfp8(self): assert actual[0] < 0.535 and actual[1] > 0.73 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True @@ -474,8 +475,9 @@ def test_te_bf16_shardy(self): assert actual[0] < 0.535 and actual[1] > 0.73 @unittest.skipIf(not is_fp8_supported, fp8_reason) - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" self.args.enable_shardy = True @@ -487,8 +489,9 @@ def test_te_delayed_scaling_fp8_shardy(self): # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. @unittest.skipIf(not is_fp8_supported, fp8_reason) - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" self.args.enable_shardy = True diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 35a0e766a4..b5d03c0796 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -412,9 +412,9 @@ def train_and_evaluate(args): out_shardings = { key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect } - jit_encoder_init = jax.jit(encoder.init, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_encoder_init = jax.jit( + encoder.init, in_shardings=in_shardings, out_shardings=out_shardings + ) var_collect = jit_encoder_init(init_rngs, inputs, masks) optimizer = optax.adamw(args.lr) @@ -434,15 +434,15 @@ def train_and_evaluate(args): None, ) out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit(train_step, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_train_step = jax.jit( + train_step, in_shardings=in_shardings, out_shardings=out_shardings + ) in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) out_shardings = (None, None) - jit_eval_step = jax.jit(eval_step, - in_shardings=in_shardings, - out_shardings=out_shardings) + jit_eval_step = jax.jit( + eval_step, in_shardings=in_shardings, out_shardings=out_shardings + ) if args.use_fp8: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) @@ -634,8 +634,9 @@ def test_te_mxfp8(self): assert result[0] < 0.505 and result[1] > 0.754 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" result = self.exec(False, None, enable_shardy=True) @@ -644,8 +645,9 @@ def test_te_bf16_shardy(self): @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" ) - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling", enable_shardy=True) @@ -656,8 +658,9 @@ def test_te_delayed_scaling_fp8_shardy(self): @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" ) - @unittest.skipIf(not tex.gemm_uses_jax_dot(), - "TE cuBLAS GEMM custom op does not support shardy") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling", enable_shardy=True) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 3c27f0d472..a093ff5d91 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -239,7 +239,7 @@ def _test_layernorm_mlp_grad( m_grad, s_grad, dtype=bwd_test_type, - err_msg=f"multi_grads[{i}] is not close" + err_msg=f"multi_grads[{i}] is not close", ) else: assert_allclose( @@ -400,7 +400,9 @@ def test_layernorm_mlp_layer( @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES[:-1] if is_mxfp8_supported else SUPPORTED_RECIPES) + @pytest_parametrize_wrapper( + "fp8_recipe", SUPPORTED_RECIPES[:-1] if is_mxfp8_supported else SUPPORTED_RECIPES + ) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8( self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 807ca62713..fc84e21947 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -475,15 +475,11 @@ def _parse_operand_output_specs(arg_infos, dimension_numbers): lambda specs: tuple(spec for spec in specs if spec is not None), (lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs), ) - assert ( - len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1 - ), ( + assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, ( "cuBLAS GEMM operands can have only one sharded non-batched leading dimension: " f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}." ) - assert ( - len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1 - ), ( + assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, ( "cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: " f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}." ) @@ -712,7 +708,7 @@ def shardy_sharding_rule(*args, **kwargs): del args, kwargs raise NotImplementedError( "TE cuBLAS GEMM custom op does not support the Shardy partitioner. You can disable the " - "custom op by setting `NVTE_JAX_CUSTOM_CALLS_RE=\"^(?!GemmPrimitive$).+$\"` in the " + 'custom op by setting `NVTE_JAX_CUSTOM_CALLS_RE="^(?!GemmPrimitive$).+$"` in the ' "environment, which will make GEMM operations in TE will execute with native " "`jax.lax.dot_general` and `jax.nn.scaled_matmul` calls." ) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 6ed91f4319..0d4a1b7524 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -219,7 +219,7 @@ def _dense_bwd_rule( wgrad = tex.gemm( casted_x_lhs, casted_grad.get_tensor(usage=TensorUsage.RHS), - dimension_numbers=((x_contracting_dim, g_contracting_dim), ((x_bdim,), (x_bdim, ))), + dimension_numbers=((x_contracting_dim, g_contracting_dim), ((x_bdim,), (x_bdim,))), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 5608b24c5f..62fa2cfcd2 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -321,7 +321,7 @@ def _layernorm_dense_bwd_rule( wgrad = tex.gemm( casted_ln_out, casted_grad.get_tensor(TensorUsage.RHS), - dimension_numbers=((x_constracting_dim, g_constracting_dim), ((x_bdim,), (x_bdim, ))), + dimension_numbers=((x_constracting_dim, g_constracting_dim), ((x_bdim,), (x_bdim,))), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 53decc3c2a..5d129aa54d 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -450,7 +450,7 @@ def _layernorm_mlp_bwd_rule( wgrad_2 = tex.gemm( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), - dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim, ), (x_bdim, ))), + dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim,))), ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -487,7 +487,7 @@ def _layernorm_mlp_bwd_rule( wgrad_1 = tex.gemm( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), - dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim, ))), + dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim,))), ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) From 44e5b81cfe2e20a74fa7dc6df0aadc8ad1f8c79d Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 25 Jun 2025 21:48:30 +0000 Subject: [PATCH 17/27] moved reshape of encoder output in encoder examples to make custom partitioning rules work correctly Signed-off-by: Alp Dener --- examples/jax/encoder/test_model_parallel_encoder.py | 6 ++---- examples/jax/encoder/test_multigpu_encoder.py | 4 +--- examples/jax/encoder/test_multiprocessing_encoder.py | 4 +--- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 00203a4537..490d062fbc 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -66,12 +66,10 @@ def __call__(self, x, mask, disable_dropout=False): ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) - x = x.reshape(x.shape[0], -1) - if self.enable_seq_paral: # Trigger all-gather to collect a complete tensor alone sequence on each device. x = jax.lax.with_sharding_constraint( - x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) + x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None) ) x = te_flax.DenseGeneral( @@ -86,7 +84,7 @@ def __call__(self, x, mask, disable_dropout=False): bias_axes=(NAMED_BROADCAST_AXIS,), )(x) - x = nn.Dense(features=2)(x) + x = nn.Dense(features=2)(x.reshape(x.shape[0], -1)) return x diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 3d7fab0115..be2e9795ad 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -56,13 +56,11 @@ def __call__(self, x, mask, disable_dropout=False): ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) - x = x.reshape(x.shape[0], -1) - x = te_flax.DenseGeneral(features=256)(x) x = te_flax.DenseGeneral(features=256)(x) - x = nn.Dense(features=2)(x) + x = nn.Dense(features=2)(x.reshape(x.shape[0], -1)) return x diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index b5d03c0796..5b7f2e2c89 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -66,8 +66,6 @@ def __call__(self, x, mask, disable_dropout=False): ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) - x = x.reshape(x.shape[0], -1) - x = te_flax.DenseGeneral( features=256, kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), @@ -80,7 +78,7 @@ def __call__(self, x, mask, disable_dropout=False): bias_axes=(NAMED_BROADCAST_AXIS,), )(x) - x = nn.Dense(features=2)(x) + x = nn.Dense(features=2)(x.reshape(x.shape[0], -1)) return x From b8ca0b1dd2715d5580039171a705256c67f7d2a6 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 27 Jun 2025 16:37:46 +0000 Subject: [PATCH 18/27] added helper functions for padding and unpadding block scales, changed GemmPrimitive to accept unpadded scales and pad them after sharding Signed-off-by: Alp Dener --- .../encoder/test_model_parallel_encoder.py | 6 +- examples/jax/encoder/test_multigpu_encoder.py | 18 +-- .../encoder/test_multiprocessing_encoder.py | 4 +- transformer_engine/jax/cpp_extensions/gemm.py | 66 +++++++++- transformer_engine/jax/quantize/helper.py | 121 +++++++++++++++++- transformer_engine/jax/quantize/tensor.py | 26 +--- 6 files changed, 204 insertions(+), 37 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 490d062fbc..00203a4537 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -66,10 +66,12 @@ def __call__(self, x, mask, disable_dropout=False): ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) + x = x.reshape(x.shape[0], -1) + if self.enable_seq_paral: # Trigger all-gather to collect a complete tensor alone sequence on each device. x = jax.lax.with_sharding_constraint( - x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None) + x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) ) x = te_flax.DenseGeneral( @@ -84,7 +86,7 @@ def __call__(self, x, mask, disable_dropout=False): bias_axes=(NAMED_BROADCAST_AXIS,), )(x) - x = nn.Dense(features=2)(x.reshape(x.shape[0], -1)) + x = nn.Dense(features=2)(x) return x diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index be2e9795ad..44cafa7396 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -56,11 +56,13 @@ def __call__(self, x, mask, disable_dropout=False): ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) + x = x.reshape(x.shape[0], -1) + x = te_flax.DenseGeneral(features=256)(x) x = te_flax.DenseGeneral(features=256)(x) - x = nn.Dense(features=2)(x.reshape(x.shape[0], -1)) + x = nn.Dense(features=2)(x) return x @@ -436,7 +438,7 @@ def setUp(self): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -444,7 +446,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8(self): @@ -452,7 +454,7 @@ def test_te_current_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -460,7 +462,7 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @unittest.skipIf( @@ -470,7 +472,7 @@ def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_fp8_supported, fp8_reason) @unittest.skipIf( @@ -482,7 +484,7 @@ def test_te_delayed_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. @@ -496,7 +498,7 @@ def test_te_current_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 5b7f2e2c89..b5d03c0796 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -66,6 +66,8 @@ def __call__(self, x, mask, disable_dropout=False): ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) + x = x.reshape(x.shape[0], -1) + x = te_flax.DenseGeneral( features=256, kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), @@ -78,7 +80,7 @@ def __call__(self, x, mask, disable_dropout=False): bias_axes=(NAMED_BROADCAST_AXIS,), )(x) - x = nn.Dense(features=2)(x.reshape(x.shape[0], -1)) + x = nn.Dense(features=2)(x) return x diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index fc84e21947..16b8553706 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -31,6 +31,8 @@ QuantizeLayout, noop_quantizer_set, is_fp8_gemm_with_all_layouts_supported, + apply_padding_to_scale_inv, + remove_padding_from_scale_inv, ) from .misc import get_padded_spec @@ -153,7 +155,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9, 10, 11, 12) + impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14) inner_primitive = None outer_primitive = None @@ -167,13 +169,15 @@ def abstract( gelu_input, out_dtype, dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, ): - del use_split_accumulator + del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator # Sanity-check operand layouts and types operand_ndims = (lhs.ndim, rhs.ndim) @@ -295,13 +299,15 @@ def lowering( gelu_input, out_dtype, dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, ): - del out_dtype + del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype contracting_dims, _ = dimension_numbers lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) @@ -343,12 +349,36 @@ def impl( gelu_input, out_dtype, dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, ): + lhs_cdims, rhs_cdims = map( + sanitize_dims, (lhs.ndim, rhs.ndim), dimension_numbers[0] + ) + lhs_transposed, rhs_transposed = _get_gemm_layout( + (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) + ) + + lhs_scale_inv = apply_padding_to_scale_inv( + lhs_scale_inv, + scaling_mode, + lhs.shape, + is_colwise=lhs_quantized_colwise, + flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), + ) + rhs_scale_inv = apply_padding_to_scale_inv( + rhs_scale_inv, + scaling_mode, + rhs.shape, + is_colwise=rhs_quantized_colwise, + flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, + ) + outputs = GemmPrimitive.inner_primitive.bind( lhs, lhs_scale_inv, @@ -358,6 +388,8 @@ def impl( gelu_input, out_dtype=out_dtype, dimension_numbers=dimension_numbers, + lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -372,6 +404,8 @@ def batcher( batch_dims, out_dtype, dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, @@ -412,6 +446,8 @@ def batcher( *batched_args, out_dtype=out_dtype, dimension_numbers=dimension_numbers, + lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -582,6 +618,8 @@ def _parse_operand_output_specs(arg_infos, dimension_numbers): def infer_sharding_from_operands( out_dtype, dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, @@ -591,7 +629,8 @@ def infer_sharding_from_operands( arg_infos, result_infos, ): - del out_dtype, scaling_mode, grad, use_split_accumulator, result_infos + del out_dtype, lhs_quantized_colwise, rhs_quantized_colwise, scaling_mode, grad, + del use_split_accumulator, result_infos (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( GemmPrimitive._parse_operand_output_specs(arg_infos, dimension_numbers) @@ -614,6 +653,8 @@ def infer_sharding_from_operands( def partition( out_dtype, dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, @@ -678,6 +719,8 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): gelu_input, out_dtype=out_dtype, dimension_numbers=dimension_numbers, + lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, @@ -722,6 +765,15 @@ def gemm_uses_jax_dot() -> bool: return not GemmPrimitive.enabled() +def _get_scale_inv_without_padding(scaled_tensor): + return remove_padding_from_scale_inv( + scaled_tensor.scale_inv, + scaled_tensor.scaling_mode, + scaled_tensor.data.shape, + is_colwise=scaled_tensor.is_colwise, + flatten_axis=scaled_tensor.flatten_axis, + ) + def _te_gemm( lhs: Union[jax.Array, ScaledTensor], rhs: Union[jax.Array, ScaledTensor], @@ -760,7 +812,7 @@ def _te_gemm( lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() scaling_mode = lhs_q.scaling_mode lhs_data = lhs_q.data - lhs_scale_inv = lhs_q.scale_inv + lhs_scale_inv = _get_scale_inv_without_padding(lhs_q) if lhs_q.data_layout == "T": lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis) @@ -778,7 +830,7 @@ def _te_gemm( f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." ) rhs_data = rhs_q.data - rhs_scale_inv = rhs_q.scale_inv + rhs_scale_inv = _get_scale_inv_without_padding(rhs_q) if rhs_q.data_layout == "T": rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis) @@ -799,6 +851,8 @@ def _te_gemm( gelu_input, out_dtype=out_dtype, dimension_numbers=((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)), + lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False, + rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index c0617eafbb..3c217d65fb 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -9,7 +9,9 @@ """ from contextlib import contextmanager from enum import Enum -from typing import Optional, Tuple, Dict, Union +from typing import Optional, Tuple, Dict, Union, Sequence +from functools import reduce +import operator import jax import jax.numpy as jnp @@ -29,6 +31,8 @@ "is_fp8_available", "update_collections", "get_delayed_scaling", + "apply_padding_to_scale_inv", + "remove_padding_from_scale_inv", "NVTE_FP8_COLLECTION_NAME", ] @@ -471,4 +475,119 @@ def update_collections(new: Collection, original: Collection) -> Collection: return new_coll +def remove_padding_from_scale_inv( + scale_inv: jax.Array, + scaling_mode: ScalingMode, + data_shape: Sequence[int], + is_colwise: bool = False, + flatten_axis: int = -1, +): + """ + Slice padding out of padded inverse scale factors. + + Args: + scale_inv: Inverse scale factor. + data_shape: Shape of the quantized data the inverse scale belongs to. + scaling_mode: ScalingMode representing the quantization method. + is_colwise: Whether the data was quantized column-wise. + flatten_axis: The axis along with the data could be flattened to 2D. + + Returns: + Inverse scale factor without padding. + """ + # Get expected unpadded scale shape and check if inverse scale already matches + unpadded_scale_shape = scaling_mode.get_scale_shape( + data_shape, + is_colwise=is_colwise, + is_padded=False, + flatten_axis=flatten_axis + ) + if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == unpadded_scale_shape: + return scale_inv + + # Get the padded scale shape and make sure inverse scale matches + padded_scale_shape = scaling_mode.get_scale_shape( + data_shape, + is_colwise=is_colwise, + is_padded=True, + flatten_axis=flatten_axis, + ) + assert scale_inv.shape == padded_scale_shape, ( + f"Padded inverse scale factor has wrong shape, expected {padded_scale_shape} but got " + f"{scale_inv.shape} instead." + ) + + # Reshape scale inverse to 2D in two stages to preserve the flatten axis + padded_scale_shape_2d = ( + reduce(operator.mul, padded_scale_shape[ : flatten_axis]), + reduce(operator.mul, padded_scale_shape[flatten_axis : ]) + ) + scale_inv_2d = jnp.reshape( + jnp.reshape(scale_inv, (padded_scale_shape_2d[0], *scale_inv.shape[flatten_axis : ])), + padded_scale_shape_2d + ) + + # Slice reshaped 2D scale inverse using collapsed 2D unpadded_scale_shape + unpadded_scale_shape_2d = ( + reduce(operator.mul, unpadded_scale_shape[ : flatten_axis]), + reduce(operator.mul, unpadded_scale_shape[flatten_axis : ]) + ) + scale_inv_2d_unpadded = jnp.asarray( + scale_inv_2d[ : unpadded_scale_shape_2d[0], : unpadded_scale_shape_2d[1]] + ) + + # Reshape 2D scale inverse back in two stages in order to preserve the flatten axis + scale_inv_unpadded = jnp.reshape( + jnp.reshape( + scale_inv_2d_unpadded, + (*unpadded_scale_shape[: flatten_axis], scale_inv_2d_unpadded.shape[1]) + ), + unpadded_scale_shape + ) + return scale_inv_unpadded + + +def apply_padding_to_scale_inv( + scale_inv: jax.Array, + scaling_mode: ScalingMode, + data_shape: Sequence[int], + is_colwise: bool = False, + flatten_axis: int = -1, +): + """ + Pad the scale inverse with zeros to match the necessary padded shape for this scaling + mode. + + Args: + scale_inv: Inverse scale factor. + data_shape: Shape of the quantized data the inverse scale belongs to. + scaling_mode: ScalingMode representing the quantization method. + is_colwise: Whether the data was quantized column-wise. + flatten_axis: The axis along with the data could be flattened to 2D. + + Returns: + Padded inverse scale factor. + """ + # Get the expected padded scale shape and check if inverse scale already matches + padded_scale_shape = scaling_mode.get_scale_shape( + data_shape, is_colwise=is_colwise, is_padded=True, flatten_axis=flatten_axis + ) + if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == padded_scale_shape: + return scale_inv + + # Get the expected unpadded scale shape and make sure inverse scales match + unpadded_scale_shape = scaling_mode.get_scale_shape( + data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis + ) + assert scale_inv.shape == unpadded_scale_shape, ( + f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " + f"{scale_inv.shape}." + ) + + # Pad the scales with the lowest representable value (2^-127) and return + pad_width = tuple( + (0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape) + ) + return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127) + NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index a87326e9fe..d454f8f43a 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -17,6 +17,7 @@ from transformer_engine_jax import QuantizeLayout +from .helper import apply_padding_to_scale_inv from .scaling_modes import ScalingMode, TensorUsage from .dequantizer import ScalingModeToDequantizerMap from ..sharding import ( @@ -136,26 +137,13 @@ def __post_init__(self): self.scale_inv = jnp.empty((0,), dtype=jnp.float32) else: - expected_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis + self.scale_inv = apply_padding_to_scale_inv( + self.scale_inv, + self.scaling_mode, + self.data.shape, + is_colwise=self.is_colwise, + flatten_axis=self.flatten_axis, ) - expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis - ) - if self.scale_inv.shape != expected_scale_shape: - assert self.scale_inv.shape == expected_unpadded_scale_shape, ( - f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" - f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" - f" {self.scale_inv.shape}" - ) - pad_width = tuple( - (0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) - ) - - # padding with the smallest number it can present - self.scale_inv = jnp.pad( - self.scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127 - ) def tree_flatten(self): """Flattens the tensor for JAX tree operations. From 718758259df2bb0e4fa0ac83ebb31132d0021244 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Jun 2025 16:40:13 +0000 Subject: [PATCH 19/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 13 ++++++--- transformer_engine/jax/quantize/helper.py | 28 ++++++++----------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 16b8553706..0f43ea47b8 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -357,9 +357,7 @@ def impl( grad, use_split_accumulator, ): - lhs_cdims, rhs_cdims = map( - sanitize_dims, (lhs.ndim, rhs.ndim), dimension_numbers[0] - ) + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), dimension_numbers[0]) lhs_transposed, rhs_transposed = _get_gemm_layout( (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) ) @@ -629,7 +627,13 @@ def infer_sharding_from_operands( arg_infos, result_infos, ): - del out_dtype, lhs_quantized_colwise, rhs_quantized_colwise, scaling_mode, grad, + del ( + out_dtype, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + grad, + ) del use_split_accumulator, result_infos (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( @@ -774,6 +778,7 @@ def _get_scale_inv_without_padding(scaled_tensor): flatten_axis=scaled_tensor.flatten_axis, ) + def _te_gemm( lhs: Union[jax.Array, ScaledTensor], rhs: Union[jax.Array, ScaledTensor], diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 3c217d65fb..122265ea27 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -497,10 +497,7 @@ def remove_padding_from_scale_inv( """ # Get expected unpadded scale shape and check if inverse scale already matches unpadded_scale_shape = scaling_mode.get_scale_shape( - data_shape, - is_colwise=is_colwise, - is_padded=False, - flatten_axis=flatten_axis + data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis ) if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == unpadded_scale_shape: return scale_inv @@ -519,30 +516,30 @@ def remove_padding_from_scale_inv( # Reshape scale inverse to 2D in two stages to preserve the flatten axis padded_scale_shape_2d = ( - reduce(operator.mul, padded_scale_shape[ : flatten_axis]), - reduce(operator.mul, padded_scale_shape[flatten_axis : ]) + reduce(operator.mul, padded_scale_shape[:flatten_axis]), + reduce(operator.mul, padded_scale_shape[flatten_axis:]), ) scale_inv_2d = jnp.reshape( - jnp.reshape(scale_inv, (padded_scale_shape_2d[0], *scale_inv.shape[flatten_axis : ])), - padded_scale_shape_2d + jnp.reshape(scale_inv, (padded_scale_shape_2d[0], *scale_inv.shape[flatten_axis:])), + padded_scale_shape_2d, ) # Slice reshaped 2D scale inverse using collapsed 2D unpadded_scale_shape unpadded_scale_shape_2d = ( - reduce(operator.mul, unpadded_scale_shape[ : flatten_axis]), - reduce(operator.mul, unpadded_scale_shape[flatten_axis : ]) + reduce(operator.mul, unpadded_scale_shape[:flatten_axis]), + reduce(operator.mul, unpadded_scale_shape[flatten_axis:]), ) scale_inv_2d_unpadded = jnp.asarray( - scale_inv_2d[ : unpadded_scale_shape_2d[0], : unpadded_scale_shape_2d[1]] + scale_inv_2d[: unpadded_scale_shape_2d[0], : unpadded_scale_shape_2d[1]] ) # Reshape 2D scale inverse back in two stages in order to preserve the flatten axis scale_inv_unpadded = jnp.reshape( jnp.reshape( scale_inv_2d_unpadded, - (*unpadded_scale_shape[: flatten_axis], scale_inv_2d_unpadded.shape[1]) + (*unpadded_scale_shape[:flatten_axis], scale_inv_2d_unpadded.shape[1]), ), - unpadded_scale_shape + unpadded_scale_shape, ) return scale_inv_unpadded @@ -585,9 +582,8 @@ def apply_padding_to_scale_inv( ) # Pad the scales with the lowest representable value (2^-127) and return - pad_width = tuple( - (0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape) - ) + pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape)) return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127) + NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME From 0b7692ac25854c4777806d1aedaf1d0961fab327 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 25 Jun 2025 15:58:35 +0000 Subject: [PATCH 20/27] stashing Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 146 ++++++++++++++++-- .../transformer_engine/comm_gemm_overlap.h | 83 ++++++++-- .../common/util/pybind_helper.h | 35 ++++- transformer_engine/jax/cpp_extensions/gemm.py | 52 ++++++- transformer_engine/jax/csrc/extensions.h | 3 + .../jax/csrc/extensions/gemm.cpp | 139 +++++++++++++++-- transformer_engine/jax/csrc/extensions/misc.h | 6 + .../jax/csrc/extensions/pybind.cpp | 73 +-------- 8 files changed, 430 insertions(+), 107 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 40595ea988..86bc05df29 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -53,17 +53,38 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl bool atomic_gemm) { // Initialize userbuf communicator if (!_comm_created) { - if (myrank == 0) { - printf("!!! [UB] Create Userbuffers Communicator\n"); - } -#ifdef NVTE_UB_WITH_MPI - create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); -#else create_communicator_grouped2(&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, allgather_handle, barrier_handle, 1, 1, tp_size, 1); -#endif _comm_created = true; + if (_ub_comm->myrank == 0) { + printf("!!! [UB] Initialized Userbuffers Communicator\n"); + } + } + + initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, + num_comm_sm, set_sm_margin, use_ce, atomic_gemm); +} + +CommOverlapCore::CommOverlapCore(int tp_size, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { + if (!_comm_created) { + create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); + _comm_created = true; + if (_ub_comm->myrank == 0) { + printf("!!! [UB] Initialized Userbuffers Communicator (w/ MPI Boostrapping)\n"); + } } + + initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, + num_comm_sm, set_sm_margin, use_ce, atomic_gemm); +} + +void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { _use_ce = static_cast(use_ce); _num_comm_sm = num_comm_sm; _cga_size = comm_cga_size; @@ -124,6 +145,7 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl } } + CommOverlapCore::~CommOverlapCore() { cudaEventDestroy(_stop_comm); cudaEventDestroy(_start_comm); @@ -262,6 +284,21 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm); +} + +CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, + int tp_size, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool rs_overlap_first_gemm) + : CommOverlapCore(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, + comm_priority, num_comm_sm, set_sm_margin, false, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm); +} + +void CommOverlapBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm) { _rs_overlap_first_gemm = rs_overlap_first_gemm; _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, @@ -285,6 +322,39 @@ CommOverlapBase::~CommOverlapBase() { cudaStreamDestroy(_stream_comm); } +void CommOverlapBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, + bool local_chunk, bool rowwise) { + // Check element size + const size_t element_size = source.element_size(); + NVTE_CHECK(_ubuf.element_size() == element_size, + "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", + "(source dtype has ", element_size, "bytes, UB dtype has ", _ubuf.element_size(), + "bytes)"); + + // Input data + const size_t source_size = source.numel(); + const void *src_ptr = (rowwise) ? source.dptr() : source.columnwise_dptr(); + + // Userbuffers data + const size_t ubuf_size = _ubuf.numel(); + void *dst_ptr = _ubuf.dptr(); + if (local_chunk) { + NVTE_CHECK(source_size * _tp_size == ubuf_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(source_size=", source_size, ", tensor_parallel_size=", _tp_size, + ", ubuf_size=", ubuf_size, ")"); + dst_ptr = (reinterpret_cast(dst_ptr) + (ubuf_size / _tp_size) * _tp_id * element_size); + } else { + NVTE_CHECK(source_size == ubuf_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(source_size=", source_size, ", ubuf_size=", ubuf_size, ")"); + } + + // Copy data + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size, + cudaMemcpyDeviceToDevice, stream)); +} + /* ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf @@ -600,6 +670,21 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, comm_type, aggregate); +} + +CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, + int tp_size, CommOverlapType comm_type, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm, bool aggregate) + : CommOverlapCore(tp_size, tp_size, num_max_streams, comm_cga_size, gemm_priority, + comm_priority, num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, comm_type, aggregate); +} + +void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate) { _is_p2p = true; _is_reduce_scatter = comm_type == CommOverlapType::RS; _aggregate = aggregate; @@ -607,13 +692,13 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, // Create workspace tensor with userbuffer NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); - int buffer_chunk_bytes = buffer_bytes / tp_size; - _num_ubuf_chunks = tp_size; + int buffer_chunk_bytes = buffer_bytes / _tp_size; + _num_ubuf_chunks = _tp_size; if (_is_reduce_scatter) { // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk // outputs for reduction at the end of the pipelining. - buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1); - _num_ubuf_chunks = tp_size * 2 - 1; + buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1); + _num_ubuf_chunks = _tp_size * 2 - 1; } void *buffer_ptr; @@ -621,14 +706,14 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); _ubuf = TensorWrapper( buffer_ptr, - std::vector{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]}, buffer_dtype); // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); for (int i = 0; i < _num_ubuf_chunks; i++) { _ubufs.push_back(TensorWrapper(reinterpret_cast(ubuf_byte_ptr), - std::vector{buffer_shape[0] / tp_size, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size, buffer_shape[1]}, buffer_dtype)); ubuf_byte_ptr += buffer_chunk_bytes; } @@ -651,7 +736,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); } - for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) { + for (int i = 0; i < _stream_compute.size(); i++) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); _stream_send.push_back(std::move(stream)); @@ -662,6 +747,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); } + CommOverlapP2PBase::~CommOverlapP2PBase() { cudaEventDestroy(_stop_recv); cudaEventDestroy(_stop_send); @@ -669,6 +755,38 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]); } +void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, + bool local_chunk, bool rowwise) { + // Check element size + const size_t element_size = source.element_size(); + NVTE_CHECK(_ubuf.element_size() == element_size, + "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", + "(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), + " bytes)"); + + // Input data + const size_t source_size = source.numel(); + const void *src_ptr = (rowwise) ? source.dptr() : source.columnwise_dptr(); + + // Userbuffers data + void *dst_ptr; + if (local_chunk) { + NVTE_CHECK(_ubufs[_tp_id].numel() == source_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + dst_ptr = _ubufs[_tp_id].dptr(); + } else { + NVTE_CHECK(_ubuf.numel() == source_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")"); + dst_ptr = _ubuf.dptr(); + } + + // Copy data + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size, + cudaMemcpyDeviceToDevice, stream)); +} + TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, size_t chunk_id) { // Start with a chunk of the source tensor diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 293c57526d..de18a12427 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -26,17 +26,29 @@ namespace transformer_engine { */ bool ubuf_built_with_mpi(); -enum class CommOverlapType { RS = 0, AG = 1 }; - -enum class CommOverlapAlgo { - BULK_OVERLAP_AG = 0, - BULK_OVERLAP_RS = 1, - SPLIT_PIPELINED_AG_P2P = 2, - SPLIT_PIPELINED_RS = 3, - SPLIT_PIPELINED_RS_P2P = 4, - ATOMIC_GEMM_RS = 5, - ATOMIC_GEMM_AG_P2P = 6, - ATOMIC_GEMM_RS_P2P = 7 +enum class CommOverlapType : int64_t { + NONE = 0, + RS = 1, + AG = 2 +}; + +enum class CommOverlapMethod : int64_t { + NONE = 0, + BULK = 1, + PIPELINE = 2, + RING_EXCHANGE = 3 +}; + +enum class CommOverlapAlgo : int64_t { + NO_OVERLAP = 0, + BULK_OVERLAP_AG = 1, + BULK_OVERLAP_RS = 2, + SPLIT_PIPELINED_AG_P2P = 3, + SPLIT_PIPELINED_RS = 4, + SPLIT_PIPELINED_RS_P2P = 5, + ATOMIC_GEMM_RS = 6, + ATOMIC_GEMM_AG_P2P = 7, + ATOMIC_GEMM_RS_P2P = 8 }; class CommOverlapCore { @@ -66,15 +78,27 @@ class CommOverlapCore { std::vector _stream_compute; cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; + private: + void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, + int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm); + public: CommOverlapCore() {} // dummy constructor for exposing type to Python + // External/framework collectives-based constructor CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, bool atomic_gemm); + // MPI-based constructor + CommOverlapCore(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, + int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm); + + virtual ~CommOverlapCore(); void set_ubuf_scale_inv(float *scale_inv) { @@ -82,12 +106,19 @@ class CommOverlapCore { _ubuf_scale_inv_initialized = true; } + virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) { + NVTE_ERROR("Operation is not implemented."); + } + TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, const std::vector &shape); TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, const std::vector &shape); + int get_tp_size() { return _tp_size; } + bool is_atomic_gemm() { return _atomic_gemm; } bool is_p2p_overlap() { return _is_p2p; } @@ -142,9 +173,14 @@ class CommOverlapBase : public CommOverlapCore { cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm); + public: CommOverlapBase() {} // dummy constructor for exposing type to Python + // External/framework collective-based constructor CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, @@ -153,8 +189,18 @@ class CommOverlapBase : public CommOverlapCore { bool set_sm_margin = true, bool atomic_gemm = false, bool rs_overlap_first_gemm = false); + // MPI-based constructor + CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int tp_size, + int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, + int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, + int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false, + bool rs_overlap_first_gemm = false); + virtual ~CommOverlapBase(); + void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) override; + /* ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf @@ -215,9 +261,14 @@ class CommOverlapP2PBase : public CommOverlapCore { cudaStream_t _stream_recv; cudaEvent_t _stop_send, _stop_recv; + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate); + public: CommOverlapP2PBase() {} // dummy constructor for exposing type to Python + // External/framework collective-based constructor CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, @@ -226,8 +277,18 @@ class CommOverlapP2PBase : public CommOverlapCore { int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); + // MPI-based constructor + CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int tp_size, + CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, + int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0, + int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, + bool atomic_gemm = false, bool aggregate = false); + virtual ~CommOverlapP2PBase(); + void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) override; + TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index a1cd85ba2a..861f04d0b4 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -8,8 +8,11 @@ #define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ #include +#include #include #include +#include +#include #include #include "cuda_runtime.h" @@ -17,12 +20,25 @@ #define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ pybind11::enum_(m, "DType", pybind11::module_local()) \ .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt64", transformer_engine::DType::kInt64) \ .value("kInt32", transformer_engine::DType::kInt32) \ .value("kFloat32", transformer_engine::DType::kFloat32) \ .value("kFloat16", transformer_engine::DType::kFloat16) \ .value("kBFloat16", transformer_engine::DType::kBFloat16) \ .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ + .value("kFloat8E8M0", transformer_engine::DType::kFloat8E8M0); \ + pybind11::enum_(m, "NVTE_Activation_Type", pybind11::module_local()) \ + .value("GELU", NVTE_Activation_Type::GELU) \ + .value("GEGLU", NVTE_Activation_Type::GEGLU) \ + .value("SILU", NVTE_Activation_Type::SILU) \ + .value("SWIGLU", NVTE_Activation_Type::SWIGLU) \ + .value("RELU", NVTE_Activation_Type::RELU) \ + .value("REGLU", NVTE_Activation_Type::REGLU) \ + .value("QGELU", NVTE_Activation_Type::QGELU) \ + .value("QGEGLU", NVTE_Activation_Type::QGEGLU) \ + .value("SRELU", NVTE_Activation_Type::SRELU) \ + .value("SREGLU", NVTE_Activation_Type::SREGLU); \ pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ @@ -75,16 +91,27 @@ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_(m, "NVTE_Norm_Type", pybind11::module_local()) \ + .value("LayerNorm", NVTE_Norm_Type::LayerNorm) \ + .value("RMSNorm", NVTE_Norm_Type::RMSNorm); \ pybind11::enum_( \ m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \ .value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \ .value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \ pybind11::enum_(m, "CommOverlapType", \ pybind11::module_local()) \ + .value("NONE", transformer_engine::CommOverlapType::NONE) \ .value("RS", transformer_engine::CommOverlapType::RS) \ .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapMethod", \ + pybind11::module_local()) \ + .value("NONE", transformer_engine::CommOverlapMethod::NONE) \ + .value("BULK", transformer_engine::CommOverlapMethod::BULK) \ + .value("PIPELINE", transformer_engine::CommOverlapMethod::PIPELINE) \ + .value("RING_EXCHANGE", transformer_engine::CommOverlapMethod::RING_EXCHANGE); \ pybind11::enum_(m, "CommOverlapAlgo", \ pybind11::module_local()) \ + .value("NO_OVERLAP", transformer_engine::CommOverlapAlgo::NO_OVERLAP) \ .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ .value("SPLIT_PIPELINED_AG_P2P", \ @@ -128,6 +155,12 @@ }, \ py::call_guard(), py::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + py::call_guard()); \ + m.dev("get_qkv_format", &transformer_engine::nvte_get_qkv_format, \ + py::call_guard()); \ + m.def("get_num_compute_streams", &nvte_get_num_compute_streams, \ + py::call_guard()); \ + m.def("is_non_nt_fp8_gemm_supported", &transformer_engine::nvte_is_non_tn_fp8_gemm_supported, \ py::call_guard()); #endif diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0f43ea47b8..af82ef9d26 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -5,7 +5,9 @@ import math import operator +from abc import ABC from collections.abc import Iterable +from dataclasses import dataclass, field from typing import Tuple, Sequence, Union from functools import partial, reduce @@ -15,7 +17,6 @@ from jax.sharding import NamedSharding, PartitionSpec import transformer_engine_jax as tex -from transformer_engine_jax import get_num_compute_streams from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize @@ -47,7 +48,9 @@ ] -num_cublas_streams = get_num_compute_streams() +num_cublas_streams = tex.get_num_compute_streams() + +num_comm_overlap_max_streams = 3 def get_cublas_workspace_size_bytes() -> None: @@ -326,6 +329,9 @@ def lowering( "fuse_gelu": fuse_gelu, "grad": grad, "use_split_accumulator": use_split_accumulator, + "comm_overlap_id": -1, + "comm_overlap_method": tex.CommOverlapMethod.NONE, + "comm_type": tex.CommOverlapType.NONE, } operand_output_aliases = {} @@ -764,6 +770,48 @@ def shardy_sharding_rule(*args, **kwargs): register_primitive(GemmPrimitive) +@dataclass +class CommOverlapHelper: + + buffer_shape: Sequence[int] = (0, ) + buffer_dtype: jnp.dtype = jnp.bfloat16 + tp_size: int = 1 + comm_type: tex.CommOverlapType = tex.CommOverlapType.NONE + method: tex.CommOverlapMethod = tex.CommOverlapMethod.NONE + unique_id: int = field(default=-1, init=False) + + def __post_init__(self): + if self.comm_type == tex.CommOverlapType.NONE: + assert self.method == tex.CommOverlapMethod.NONE + else: + assert self.method != tex.CommOverlapMethod.NONE + if self.comm_type == tex.CommOverlapType.AG: + assert self.method != tex.CommOverlapType.PIPELINE, ( + "Comm+GEMM overlap w/ PIPELINE method does not support all-gather." + ) + else: + # Reduce-Scatter overlap always needs an auxiliary output + self.needs_aux_out = True + assert self.tp_size > 1, ( + "Comm+GEMM overlap requires tensor-parallel size larger than 1." + ) + + def is_enabled(self): + return self.method != tex.CommOverlapMethod.NONE + + def is_bulk(self): + return self.method == tex.CommOverlapMethod.BULK + + def is_p2p(self): + return self.method == tex.CommOverlapMethod.RING_EXCHANGE + + def is_all_gather(self): + return self.comm_type == tex.CommOverlapType.AG + + def is_reduce_scatter(self): + return self.comm_type == tex.CommOverlapType.RS + + def gemm_uses_jax_dot() -> bool: """Check if the GEMM call directs to the TE custom cuBLAS call or native JAX dot.""" return not GemmPrimitive.enabled() diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 59079fe3f0..6c2c021116 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -33,6 +34,8 @@ #include "transformer_engine/multi_stream.h" // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::CommOverlapMethod); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::CommOverlapType); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); namespace transformer_engine { diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 2c5f027ba1..35d629f837 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -100,13 +100,47 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( return std::make_tuple(std::move(input), input_shape); } +static std::unordered_map comm_overlaps; + +int64_t CreateCommOverlapBuffer(CommOverlapType comm_type, CommOverlapMethod method, + const std::vector &buffer_shape, DType buffer_dtype, + int tp_size, int num_splits, int num_max_streams, int comm_cga_size, + int gemm_priority, int comm_priority, int num_comm_sm, + int set_sm_margin, bool use_ce, bool atomic_gemm, + bool rs_overlap_first_gemm, bool aggregate_ag) { + int64_t unique_id = 0; + hash_combine(unique_id, static_cast(comm_type), static_cast(method), buffer_shape[0], + buffer_shape[0], static_cast(buffer_dtype), tp_size, num_splits, + num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, + set_sm_margin, use_ce, atomic_gemm, rs_overlap_first_gemm, aggregate_ag); + + auto it = comm_overlaps.find(unique_id); + if (it == comm_overlaps.end()) { + if (method == CommOverlapMethod::RING_EXCHANGE) { + comm_overlaps[unique_id] = reinterpret_cast( + new CommOverlapP2PBase(buffer_shape, buffer_dtype, tp_size, comm_type, num_max_streams, + comm_cga_size, gemm_priority, comm_priority, num_comm_sm, + set_sm_margin, use_ce, atomic_gemm, aggregate_ag)); + } else { + comm_overlaps[unique_id] = reinterpret_cast( + new CommOverlapBase(buffer_shape, buffer_dtype, tp_size, num_splits, num_max_streams, + comm_cga_size, gemm_priority, comm_priority, num_comm_sm, + set_sm_margin, atomic_gemm, rs_overlap_first_gemm)); + } + } + + return unique_id; +} + Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, - Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, - Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace, - JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, - int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, - bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { + Buffer_Type aux_in, Result_Type output, Result_Type bias_grad, + Result_Type pre_gelu_out, Result_Type aux_out, Result_Type lhs_swizzle, + Result_Type rhs_swizzle, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, + int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, + bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, + bool use_split_accumulator, int64_t comm_overlap_id, + CommOverlapMethod comm_overlap_method, CommOverlapType comm_type) { // Operands (this includes swizzling MXFP8 scaling factors) // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) @@ -125,8 +159,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); NVTE_CHECK(out_.numel() == output->element_count(), - "cuBLAS GEMM output buffer size is incorrect, " - "expected ", + "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ", to_string_like(out_shape), " but got ", output->element_count(), " elements ", to_string_like(output->dimensions())); @@ -169,9 +202,90 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); - nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), - rhs_transposed, lhs_transposed, grad, workspace_.data(), false, - use_split_accumulator, num_math_sm, stream); + if (comm_type == CommOverlapType::NONE) { + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), + rhs_transposed, lhs_transposed, grad, workspace_.data(), false, + use_split_accumulator, num_math_sm, stream); + } else { + auto executor = comm_overlaps[comm_overlap_id]; + auto tp_size = executor->get_tp_size(); + if (comm_overlap_method == CommOverlapMethod::BULK) { + // Prepare the auxiliary output tensor + auto aux_out_dims = aux_out->dimensions(); + std::vector aux_out_shape = {0}; + auto aux_out_dtype = convert_ffi_datatype_to_te_dtype(aux_out->element_type()); + if ((comm_type == CommOverlapType::AG && aux_out->element_count() > 0) + || comm_type == CommOverlapType::RS) { + std::vector aux_out_shape = {product(aux_out_dims, 0, aux_out_dims.size() - 1), + static_cast(aux_out_dims.back())}; + } + auto aux_out_ = TensorWrapper(aux_out->untyped_data(), aux_out_shape, aux_out_dtype); + + // Copy the auxiliary data into the communications buffer + auto aux_in_dims = aux_in.dimensions(); + std::vector aux_in_shape = {product(aux_in_dims, 0, aux_in_dims.size() -1 ), + static_cast(aux_in_dims.back())}; + auto aux_in_dtype = convert_ffi_datatype_to_te_dtype(aux_in.element_type()); + auto aux_in_ = TensorWrapper(aux_in.untyped_data(), aux_in_shape, aux_in_dtype); + if (comm_type == CommOverlapType::AG && aux_out->element_count() > 0) { + NVTE_CHECK(aux_in_shape[0] == tp_size * aux_out_shape[0], + "cuBLAS GEMM w/ bulk AG overlap auxiliary output is sized incorrectly, ", + "expected (", aux_in_shape[0] / tp_size, ",", aux_in_shape[1], ") but got ", + to_string_like(aux_out_dims)); + } else if (comm_type == CommOverlapType::RS) { + NVTE_CHECK(tp_size * aux_in_shape[0] == aux_out_shape[0], + "cuBLAS GEMM w/ bulk RS overlap auxiliary output is sized incorrectly, ", + "expected (", aux_in_shape[0] * tp_size, ",", aux_in_shape[1], ") but got ", + to_string_like(aux_out_dims)); + } + executor->copy_into_buffer(stream, aux_in_, (comm_type == CommOverlapType::AG)); + + // Launch GEMM w/ bulk overlap + executor->bulk_overlap(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, + workspace_, grad, false, use_split_accumulator, comm_type, aux_out_, + stream); + } else if (comm_type == CommOverlapType::RS) { + // Prepare the auxiliary buffer for the reduce-scattered GEMM output + auto aux_out_shape = std::vector(out_shape); + aux_out_shape.at(0) /= tp_size; + auto aux_out_dtype = convert_ffi_datatype_to_te_dtype(aux_out->element_type()); + auto aux_out_ = TensorWrapper(aux_out->untyped_data(), aux_out_shape, aux_out_dtype); + NVTE_CHECK(aux_out_.numel() == aux_out->element_count(), + "cuBLAS GEMM->RS overlap auxiliary buffer is sized incorrectly, expected ", + aux_out_.numel(), " elements ", to_string_like(aux_out_shape), " but got ", + aux_out->element_count(), " elements ", to_string_like(aux_out->dimensions())); + + // Launch GEMM+RS + executor->split_overlap_rs(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, + workspace_, grad, false, use_split_accumulator, aux_out_, + stream); + } else if (comm_type == CommOverlapType::AG) { + // Prepare the auxiliary buffer for all-gathered LHS + std::vector aux_out_shape = {0}; + auto aux_out_dtype = convert_ffi_datatype_to_te_dtype(aux_out->element_type()); + if (aux_out->element_count() > 0) { + aux_out_shape = std::vector(lhs_shape); + aux_out_shape.at(0) *= tp_size; + auto aux_out_numel = aux_out_shape[0] * aux_out_shape[1]; + NVTE_CHECK(aux_out_numel == aux_out->element_count(), + "cuBLAS AG->GEMM overlap auxiliary buffer is sized incorrectly, expected ", + aux_out_numel, " elements ", to_string_like(aux_out_shape), " but got ", + aux_out->element_count(), " elements ", to_string_like(aux_out->dimensions())); + } + auto aux_out_ = TensorWrapper(aux_out->untyped_data(), aux_out_shape, aux_out_dtype); + + // Copy the distributed LHS operand into the local chunk of the communication buffer + executor->copy_into_buffer(stream, lhs_, true, make_lhs_rowwise); + + // Launch AG+GEMM + executor->split_overlap_ag(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, + workspace_, grad, false, use_split_accumulator, aux_out_, + stream); + } else { + NVTE_ERROR("cuBLAS GEMM w/ comm. overlap invoked with invalid collective type (", + static_cast(comm_type), ")"); + } + } return ffi_with_cuda_error_check(); } @@ -199,7 +313,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("fuse_bias") .Attr("fuse_gelu") .Attr("grad") - .Attr("use_split_accumulator"), + .Attr("use_split_accumulator") + .Attr("comm_overlap_id") + .Attr("comm_overlap_method") + .Attr("comm_type"), FFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index af7f54feb6..4578a09391 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -87,5 +87,11 @@ constexpr struct Alignment { std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise); +template +void hash_combine(int64_t &seed, const T &v, Rest... rest) { + seed ^= std::hash{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + (hash_combine(seed, rest), ...); +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index afbeb644c1..d0ef724b20 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -6,6 +6,7 @@ #include "../extensions.h" +#include "common/util/pybind_helper.h" namespace transformer_engine { namespace jax { @@ -69,12 +70,13 @@ pybind11::dict Registrations() { } PYBIND11_MODULE(transformer_engine_jax, m) { + NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) + m.def("registrations", &Registrations); m.def("get_fused_attn_backend", &GetFusedAttnBackend); m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cudnn_version", &GetCudnnRuntimeVersion); m.def("get_device_compute_capability", &GetDeviceComputeCapability); - m.def("get_num_compute_streams", &nvte_get_num_compute_streams); m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_dact_dbias_quantize_workspace_sizes", &GetDActDBiasQuantizeWorkspaceSizes); m.def("get_dbias_quantize_workspace_sizes", &GetDBiasQuantizeWorkspaceSizes); @@ -82,83 +84,18 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_norm_bwd_workspace_sizes", &GetNormBackwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); - m.def("nvte_get_qkv_format", &nvte_get_qkv_format); - m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); - - pybind11::enum_(m, "DType", pybind11::module_local()) - .value("kByte", DType::kByte) - .value("kInt32", DType::kInt32) - .value("kInt64", DType::kInt64) - .value("kFloat32", DType::kFloat32) - .value("kFloat16", DType::kFloat16) - .value("kBFloat16", DType::kBFloat16) - .value("kFloat8E4M3", DType::kFloat8E4M3) - .value("kFloat8E5M2", DType::kFloat8E5M2); - - pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - - pybind11::enum_(m, "NVTE_Mask_Type", pybind11::module_local()) - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - - pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); - - pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) - .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) - .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) - .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD); - - pybind11::enum_(m, "NVTE_Activation_Type", pybind11::module_local()) - .value("GELU", NVTE_Activation_Type::GELU) - .value("GEGLU", NVTE_Activation_Type::GEGLU) - .value("SILU", NVTE_Activation_Type::SILU) - .value("SWIGLU", NVTE_Activation_Type::SWIGLU) - .value("RELU", NVTE_Activation_Type::RELU) - .value("REGLU", NVTE_Activation_Type::REGLU) - .value("QGELU", NVTE_Activation_Type::QGELU) - .value("QGEGLU", NVTE_Activation_Type::QGEGLU) - .value("SRELU", NVTE_Activation_Type::SRELU) - .value("SREGLU", NVTE_Activation_Type::SREGLU) - .export_values(); - - pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8); - - pybind11::enum_(m, "NVTE_Norm_Type", pybind11::module_local()) - .value("LayerNorm", NVTE_Norm_Type::LayerNorm) - .value("RMSNorm", NVTE_Norm_Type::RMSNorm) - .export_values(); pybind11::enum_(m, "JAXX_Scaling_Mode", pybind11::module_local()) .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING) .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING) - .value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) - .export_values(); + .value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING); pybind11::enum_(m, "QuantizeLayout", pybind11::module_local()) .value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE) .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE) - .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE) - .export_values(); + .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE); } } // namespace jax From 77eaa63863f558eb8f1844f93b6696004e064254 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 2 Jul 2025 07:23:50 +0000 Subject: [PATCH 21/27] both AG and RS overlaps working Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.py | 233 ++++ .../transformer_engine/comm_gemm_overlap.h | 4 + .../common/util/pybind_helper.h | 55 +- transformer_engine/jax/attention.py | 4 +- transformer_engine/jax/cpp_extensions/gemm.py | 1129 +++++++++++++---- transformer_engine/jax/csrc/extensions.h | 12 + .../jax/csrc/extensions/gemm.cpp | 80 +- .../jax/csrc/extensions/pybind.cpp | 61 +- transformer_engine/jax/dense.py | 71 +- transformer_engine/jax/flax/module.py | 236 +++- transformer_engine/jax/layernorm_dense.py | 71 +- transformer_engine/jax/layernorm_mlp.py | 90 +- 12 files changed, 1662 insertions(+), 384 deletions(-) create mode 100644 examples/jax/comm_gemm_overlap/comm_gemm_overlap.py diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py new file mode 100644 index 0000000000..85744f43c8 --- /dev/null +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -0,0 +1,233 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Comm+GEMM Overlap with TE/JAX""" + +import argparse +from functools import partial +from pprint import pprint + +import numpy as np +from mpi4py import MPI + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils + +import transformer_engine.jax as te +import transformer_engine_jax as tex +from transformer_engine.jax.sharding import get_padded_spec +from transformer_engine.jax.cpp_extensions import ( + gemm, + CommOverlapHelper, +) + +jax.clear_caches() + +# This script needs to be launched via `mpirun` with 1 process per GPU +myrank = MPI.COMM_WORLD.Get_rank() +numranks = MPI.COMM_WORLD.Get_size() +jax.distributed.initialize(cluster_detection_method="mpi4py") + +parser = argparse.ArgumentParser() +parser.add_argument("-dp", "--dp-size", type=int, default=1) +parser.add_argument("-zp", "--fsdp-size", type=int, default=2) +parser.add_argument("-tp", "--tp-size", type=int, default=numranks // 2) +parser.add_argument("-np", "--num-gpus", type=int, default=numranks) +parser.add_argument("--batch-size", type=int, default=2) +parser.add_argument("--seq-length", type=int, default=8192) +parser.add_argument("--hidden-size", type=int, default=16384) +parser.add_argument("--activation-size", type=int, default=53248) +parser.add_argument("--no-batch", action="store_true") +parser.add_argument("--no-fsdp", action="store_true") +parser.add_argument("--comm-type", type=str.upper, default="AG", choices=["AG", "RS"]) +parser.add_argument("--check-result", action="store_true") +args = parser.parse_args() + +# Operand shapes +dtype = jnp.bfloat16 +lhs_shape = ( + [args.seq_length, args.hidden_size] + if args.comm_type == "AG" + else [args.seq_length, args.activation_size] +) +rhs_shape = ( + [args.hidden_size, args.activation_size] + if args.comm_type == "AG" + else [args.activation_size, args.hidden_size] +) + +# Operand partitioning +batched = not args.no_batch +fsdp = not args.no_fsdp +input_specs = [None] * len(lhs_shape) +weight_specs = [None] * len(rhs_shape) +weight_no_fsdp = weight_specs.copy() +if batched: + lhs_shape = [args.batch_size] + lhs_shape + if fsdp: + mesh_shape = {"dp": args.dp_size, "zp": args.fsdp_size, "tp": args.tp_size} + mesh_resource = te.MeshResource( + dp_resource="dp", tp_resource="tp", cp_resource="tp", fsdp_resource="zp" + ) + if args.comm_type == "AG": + input_specs = [("dp", "zp"), "tp", None] + weight_specs = ["zp", "tp"] + weight_no_fsdp = [None, "tp"] + elif args.comm_type == "RS": + input_specs = [("dp", "zp"), None, "tp"] + weight_specs = ["tp", "zp"] + weight_no_fsdp = ["tp", None] + else: + mesh_shape = {"dp": args.dp_size, "tp": args.tp_size} + mesh_resource = te.MeshResource( + dp_resource="dp", + tp_resource="tp", + cp_resource="tp", + ) + if args.comm_type == "AG": + input_specs = ["dp", "tp", None] + weight_specs = [None, "tp"] + elif args.comm_type == "RS": + input_specs = ["dp", None, "tp"] + weight_specs = ["tp", None] + weight_no_fsdp = weight_specs +else: + if fsdp: + mesh_shape = {"zp": args.fsdp_size, "tp": args.tp_size} + mesh_resource = te.MeshResource(fsdp_resource="zp", tp_resource="tp", cp_resource="cp") + if args.comm_type == "AG": + input_specs = ["tp", None] + weight_specs = ["zp", "tp"] + elif args.comm_type == "RS": + input_specs = [None, "tp"] + weight_specs = ["tp", "zp"] + weight_no_fsdp = ["tp", None] + else: + mesh_shape = {"tp": args.tp_size} + mesh_resource = te.MeshResource(tp_resource="tp", cp_resource="cp") + if args.comm_type == "AG": + input_specs = ["tp", None] + weight_specs = [None, "tp"] + elif args.comm_type == "RS": + input_specs = [None, "tp"] + weight_specs = ["tp", None] + weight_no_fsdp = weight_specs + +# Mesh setup and sharding definitions +devices = mesh_utils.create_device_mesh((args.num_gpus,), devices=jax.devices()[: args.num_gpus]) +mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) +no_sharding = NamedSharding(mesh, PartitionSpec(None)) +input_sharding = NamedSharding(mesh, PartitionSpec(*input_specs)) +weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_specs)) +weight_no_fsdp_sharding = NamedSharding(mesh, PartitionSpec(*weight_no_fsdp)) + +# Operand initialization +key = jax.random.PRNGKey(0) +key1, key2 = jax.random.split(key, 2) +lhs_data = jax.random.normal(key1, lhs_shape, dtype=dtype) +rhs_data = jax.random.normal(key2, rhs_shape, dtype=dtype) +lhs = jax.device_put(lhs_data, input_sharding) +rhs = jax.device_put(rhs_data, weight_sharding) +dimension_numbers = (((-1, ), (0, )), ((0, ), ())) + +# Name of comm+GEMM overlap layer +overlap_method = tex.CommOverlapMethod.RING_EXCHANGE +comm_type = tex.CommOverlapType.AG if args.comm_type == "AG" else tex.CommOverlapType.RS + +# Bootstrap Userbuffers communicators and communication buffers +# NOTE: All-gather overlap requires buffer to be sized the LHS operand's global shape. +# Reduce-scatter overlap requires buffer to be sized to the GEMM output's global shape. +output_shape = (*lhs_shape[:-1], rhs_shape[-1]) +buffer_shape = list(lhs_shape if comm_type == tex.CommOverlapType.AG else output_shape).copy() +if batched: + # The only all-gathered dimension is sequence, batch is still sharded for the buffer + buffer_shape[0] = buffer_shape[0] // (args.dp_size * args.fsdp_size) +overlap_helper = CommOverlapHelper( + method=overlap_method, + comm_type=comm_type, + buffer_shape=buffer_shape, + buffer_dtype=dtype, + tp_size=args.tp_size, + tp_resource="tp", + sp_resource="tp", +) +if myrank == 0: + print(f"{myrank}: OVERLAP CONFIG:", flush=True) + pprint(overlap_helper) + print( + f"\n{myrank}: INPUTS {lhs.shape} x {rhs.shape}\n" + + f"{myrank}: LHS sharding: {lhs.sharding.spec}\n" + + f"{myrank}: RHS sharding: {rhs.sharding.spec}\n", + flush=True, + ) + +@jax.jit +def _gemm_wrapper(x, y): + return partial( + gemm, + dimension_numbers=(((-1, ), (0, )), ((0, ), ())), + comm_overlap=overlap_helper, + )(x, y) + +rhs_no_fsdp = jax.lax.with_sharding_constraint(rhs, weight_no_fsdp_sharding) + +with te.sharding.global_shard_guard(mesh_resource): + output = _gemm_wrapper(lhs, rhs) + +jax.block_until_ready(output) +if myrank == 0: + print( + f"{myrank}: {'AG -> GEMM' if args.comm_type == 'AG' else 'GEMM -> RS'} OUTPUT " + + f"{output.shape}\n" + + f"{myrank}: Sharding: {get_padded_spec(output.sharding.spec, output.ndim)}\n", + flush=True, + ) + +if args.check_result: + ref_global = jnp.matmul( + jax.device_put(lhs_data, no_sharding), jax.device_put(rhs_data, no_sharding) + ) + jax.block_until_ready(ref_global) + if myrank == 0: + print(f"{myrank}: Global reference: {ref_global}\n", flush=True) + + output_global = jax.lax.with_sharding_constraint(output, no_sharding) + jax.block_until_ready(output_global) + if myrank == 0: + print(f"{myrank}: Global output: {output_global}\n", flush=True) + + diff = jnp.abs(ref_global - output_global).flatten() + if myrank == 0: + print(f"{myrank}: Global difference: {diff}\n", flush=True) + + m = jnp.argmax(diff).item() + abs_err = diff[m].item() + rel_err = abs_err / max(abs(ref_global.flatten()[m]), 1e-5) + + rtol = 0.02 + atol = 0.001 + numerics_failed = False + if rel_err > rtol and abs_err > atol: + numerics_failed = True + numerics_info = ( + "NUMERICAL CHECK FAILED: " + + f"Outputs not close enough at index {m} " + + f"with {output.flatten()[m].item()} vs {ref_global.flatten()[m].item()} | " + + f"rel. error = {rel_err} (tol = {rtol}) | " + + f"abs. error = {abs_err} (tol = {atol})" + ) + else: + numerics_info = "NUMERICAL CHECK PASSED: " + if rel_err <= rtol: + numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + ( + " | " if abs_err < atol else "" + ) + if abs_err <= atol: + numerics_info += f"abs. error = {abs_err} (tol = {atol})" + + if myrank == 0: + print(numerics_info + "\n", end="", flush=True) + +tex.destroy_all_comm_overlap_buffers() diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index de18a12427..b823ac0671 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -101,6 +101,8 @@ class CommOverlapCore { virtual ~CommOverlapCore(); + void* get_ubuf_dptr() { return _ubuf.dptr(); } + void set_ubuf_scale_inv(float *scale_inv) { _ubuf_scale_inv = scale_inv; _ubuf_scale_inv_initialized = true; @@ -117,6 +119,8 @@ class CommOverlapCore { TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, const std::vector &shape); + + int get_tp_size() { return _tp_size; } bool is_atomic_gemm() { return _atomic_gemm; } diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 861f04d0b4..beb96545a7 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -122,30 +122,31 @@ .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ - py::class_>(m, "CommOverlapCore", \ - pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \ - py::call_guard()) \ + pybind11::class_>(m, "CommOverlapCore", \ + pybind11::module_local()) \ + .def(pybind11::init([]() { return new transformer_engine::CommOverlapCore(); }), \ + pybind11::call_guard()) \ .def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \ - py::call_guard()) \ + pybind11::call_guard()) \ .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ - py::call_guard()) \ + pybind11::call_guard()) \ .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ - py::call_guard()); \ - py::class_, \ - transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ - py::call_guard()); \ - py::class_, \ - transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \ - pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ - py::call_guard()); \ + pybind11::call_guard()); \ + pybind11::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapBase", \ + pybind11::module_local()) \ + .def(pybind11::init([]() { return new transformer_engine::CommOverlapBase(); }), \ + pybind11::call_guard()); \ + pybind11::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \ + pybind11::module_local()) \ + .def(pybind11::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ + pybind11::call_guard()); \ m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ - py::call_guard(), py::arg("device_id") = -1); \ + pybind11::call_guard(), pybind11::arg("device_id") = -1); \ m.def( \ "get_stream_priority_range", \ [](int device_id = -1) { \ @@ -153,14 +154,14 @@ transformer_engine::cuda::stream_priority_range(&low_pri, &high_pri, device_id); \ return std::make_pair(low_pri, high_pri); \ }, \ - py::call_guard(), py::arg("device_id") = -1); \ + pybind11::call_guard(), pybind11::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ - py::call_guard()); \ - m.dev("get_qkv_format", &transformer_engine::nvte_get_qkv_format, \ - py::call_guard()); \ + pybind11::call_guard()); \ + m.def("get_qkv_format", &nvte_get_qkv_format, \ + pybind11::call_guard()); \ m.def("get_num_compute_streams", &nvte_get_num_compute_streams, \ - py::call_guard()); \ - m.def("is_non_nt_fp8_gemm_supported", &transformer_engine::nvte_is_non_tn_fp8_gemm_supported, \ - py::call_guard()); + pybind11::call_guard()); \ + m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported, \ + pybind11::call_guard()); #endif diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index fe4109cee8..4d1e8316c3 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -17,7 +17,7 @@ from transformer_engine_jax import NVTE_Mask_Type from transformer_engine_jax import NVTE_QKV_Layout from transformer_engine_jax import NVTE_QKV_Format -from transformer_engine_jax import nvte_get_qkv_format +from transformer_engine_jax import get_qkv_format from . import cpp_extensions as tex @@ -109,7 +109,7 @@ def get_qkv_format(self): """ Return the corresponding qkv_format (BSHD, SBHD, THD) """ - return QKVFormat(nvte_get_qkv_format(self.value)) + return QKVFormat(get_qkv_format(self.value)) def is_qkvpacked(self): """ diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index af82ef9d26..8426cfdc0b 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -5,11 +5,10 @@ import math import operator -from abc import ABC from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Tuple, Sequence, Union from functools import partial, reduce +from typing import Tuple, Sequence, Union import jax import jax.numpy as jnp @@ -35,10 +34,13 @@ apply_padding_to_scale_inv, remove_padding_from_scale_inv, ) -from .misc import get_padded_spec +from .misc import get_padded_spec, jax_dtype_to_te_dtype +from ..sharding import global_mesh_resource __all__ = [ + "CommOverlapHelper", + "CommOverlapHelperSet", "gemm", "grouped_gemm", "gemm_uses_jax_dot", @@ -50,7 +52,8 @@ num_cublas_streams = tex.get_num_compute_streams() -num_comm_overlap_max_streams = 3 +CUDA_STREAM_PRIORITY_LOWEST = None +CUDA_STREAM_PRIORITY_HIGHEST = None def get_cublas_workspace_size_bytes() -> None: @@ -151,6 +154,690 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ return lhs_q, rhs_q +@dataclass(frozen=True) +class CommOverlapHelper: + """ + Helper object that carries comm+GEMM overlap configuration, initializes the internal + communication buffer, and generates lowering arguments and partitioning rules for + the GemmPrimitive. + """ + # Core init arguments + method: tex.CommOverlapMethod = field(default=tex.CommOverlapMethod.NONE) + comm_type: tex.CommOverlapType = field(default=tex.CommOverlapType.NONE) + buffer_shape: Sequence[int] = field(default=None) + buffer_dtype: jnp.dtype = field(default=jnp.bfloat16) + tp_size: int = field(default=None) + + # Userbuffers bootstrap kwargs + num_splits: int = field(default=None, kw_only=True) + num_max_streams: int = field(default=3, kw_only=True) + comm_cga_size: int = field(default=None, kw_only=True) + gemm_priority: int = field(default=None, kw_only=True) + comm_priority: int = field(default=None, kw_only=True) + num_comm_sm: int = field(default=None, kw_only=True) + set_sm_margin: bool = field(default=None, kw_only=True) + use_ce: bool = field(default=None, kw_only=True) + atomic_gemm: bool = field(default=False, kw_only=True) + rs_overlap_first_gemm: bool = field(default=False, kw_only=True) + aggregate_ag: bool = field(default=False, kw_only=True) + + # Other kwargs not passed to Userbuffers + tp_resource: str = field(default=None, kw_only=True) + sp_resource: str = field(default=None, kw_only=True) + output_all_gathered_lhs: bool = field(default=False, kw_only=True) + flatten_axis: int = field(default=-1, kw_only=True) + + # Internal attributes + is_enabled: bool = field(default=False, init=False, compare=True) + unique_id: int = field(default=None, init=False, compare=False) + sharded_impl: bool = field(default=False, init=False, compare=False) + gather_dim: int = field(default=-2, init=False, compare=False) + scatter_dim: int = field(default=-2, init=False, compare=False) + + def __post_init__(self): + # Update global min/max CUDA stream priority values if not already done + global CUDA_STREAM_PRIORITY_LOWEST, CUDA_STREAM_PRIORITY_HIGHEST + if CUDA_STREAM_PRIORITY_LOWEST is None or CUDA_STREAM_PRIORITY_HIGHEST is None: + ( + CUDA_STREAM_PRIORITY_LOWEST, + CUDA_STREAM_PRIORITY_HIGHEST, + ) = tex.get_stream_priority_range() + + object.__setattr__(self, "is_enabled", self.method != tex.CommOverlapMethod.NONE) + if self.is_enabled: + assert self.buffer_shape is not None, ( + f"CommOverlapHelper: {self.buffer_shape} is not a valid buffer shape." + ) + assert self.comm_type != tex.CommOverlapType.NONE, ( + f"CommOverlapHelper: {self.comm_type} is not a valid collective type for " + f"{self.method}." + ) + assert self.tp_size % 2 == 0, ( + "CommOverlapHelper: Tensor-parallel axis size must be divisible by 2, got " + f"{self.tp_size}." + ) + if not self.is_bulk() and not self.is_p2p(): + # Pipelined overlap is only for reduce-scatter + assert self.comm_type != tex.CommOverlapType.AG, ( + f"CommOverlapHelper: {self.comm_type} is not a valid collective type for " + f"{self.method}." + ) + + # Collapse buffer shape to 2D + if len(self.buffer_shape) > 2: + if self.flatten_axis < 0: + object.__setattr__(self, "flatten_axis", self.flatten_axis + len(self.buffer_shape)) + object.__setattr__( + self, + "buffer_shape", + ( + reduce(operator.mul, self.buffer_shape[ : self.flatten_axis]), + reduce(operator.mul, self.buffer_shape[self.flatten_axis : ]) + ) + ) + + # Num splits for P2P overlap is always fixed to TP size + if self.is_p2p(): + object.__setattr__(self, "num_splits", self.tp_size) + elif self.num_splits is None: + object.__setattr__(self, "num_splits", self.tp_size) + + # Set conditional defaults for config options not specified at init time + if self.comm_cga_size is None: + object.__setattr__(self, "comm_cga_size", 1 if self.is_p2p() else 2) + if self.num_comm_sm is None: + object.__setattr__(self, "num_comm_sm", 1 if self.is_p2p() else 16) + if self.set_sm_margin is None: + object.__setattr__(self, "set_sm_margin", not self.is_p2p()) + if self.use_ce is None: + object.__setattr__(self, "use_ce", self.is_p2p()) + if self.gemm_priority is None: + object.__setattr__(self, "gemm_priority", CUDA_STREAM_PRIORITY_LOWEST) + if self.comm_priority is None: + object.__setattr__(self, "comm_priority", CUDA_STREAM_PRIORITY_HIGHEST) + + # Update mesh resources for tensor- and sequence-parallel dimensions + if self.tp_resource is None: + object.__setattr__(self, "tp_resource", global_mesh_resource().tp_resource) + if self.sp_resource is None: + object.__setattr__(self, "sp_resource", global_mesh_resource().cp_resource) + + # Allocate the communication buffer + args, kwargs = self.get_bootstrap_args_kwargs() + object.__setattr__(self, "unique_id", tex.create_comm_overlap_buffer(*args, **kwargs)) + + def _set_sharded_impl(self, value): + assert isinstance(value, bool) + object.__setattr__(self, "sharded_impl", value) + + def _set_gather_dim(self, value): + assert isinstance(value, int) + object.__setattr__(self, "gather_dim", value) + + def _set_scatter_dim(self, value): + assert isinstance(value, int) + object.__setattr__(self, "scatter_dim", value) + + def is_bulk(self): + """Check if this is a bulk overlap.""" + return self.method == tex.CommOverlapMethod.BULK + + def is_p2p(self): + """Check if this is a peer-to-peer (ring-exchange) overlap.""" + return self.method == tex.CommOverlapMethod.RING_EXCHANGE + + def is_all_gather(self): + """Check if the overlapped collective is an all-gather.""" + return self.comm_type == tex.CommOverlapType.AG + + def is_reduce_scatter(self): + """Check if the overlapped collective is a reduce-scatter.""" + return self.comm_type == tex.CommOverlapType.RS + + def has_aux_output(self): + """Check if the comm+GEMM overlap has an auxiliary output.""" + return ( + self.is_enabled + and (self.is_bulk() or (self.is_all_gather() and self.output_all_gathered_lhs)) + ) + + def get_bootstrap_args_kwargs(self): + """Generate positional and keyword arguments to bootstrap Userbuffers.""" + args = ( + self.comm_type, + self.method, + self.buffer_shape, + jax_dtype_to_te_dtype(self.buffer_dtype), + self.tp_size, + ) + kwargs = { + "num_splits" : self.num_splits, + "num_max_streams" : self.num_max_streams, + "comm_cga_size" : self.comm_cga_size, + "gemm_priority" : self.gemm_priority, + "comm_priority" : self.comm_priority, + "num_comm_sm" : self.num_comm_sm, + "set_sm_margin" : self.set_sm_margin, + "use_ce" : self.use_ce, + "atomic_gemm" : self.atomic_gemm, + "rs_overlap_first_gemm" : self.rs_overlap_first_gemm, + "aggregate_ag" : self.aggregate_ag + } + return args, kwargs + + def get_lowering_kwargs(self): + """Generate a dictionary of keyword arguments used in GemmPrimitive.lowering().""" + aux_axis_boundary = -1 + if self.is_enabled and self.sharded_impl: + if self.is_all_gather(): + assert self.gather_dim >= 0, ( + "Internal TE error: CommOverlapHelper.gather_dim is not set correctly in " + "GemmPrimitive." + ) + aux_axis_boundary = self.gather_dim + 1 + elif self.is_reduce_scatter(): + assert self.scatter_dim >= 0, ( + "Internal TE error: CommOverlapHelper.scatter_dim is not set correctly in " + "GemmPrimitive." + ) + aux_axis_boundary = self.scatter_dim + 1 + + return { + "comm_overlap_id" : self.unique_id, + "comm_overlap_method" : int(self.method.value), + "comm_type" : int(self.comm_type.value), + "aux_axis_boundary" : aux_axis_boundary, + } + + @staticmethod + def _check_operand_specs(lhs_specs, rhs_specs, dimension_numbers): + (lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = dimension_numbers + lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) + + def _split_specs(specs, contracting_dims, batch_dims): + ndims = len(specs) + cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims)) + + # Batch specs + bspecs = tuple(specs[i] for i in bdims) + + # Non-batch leading dimension specs + lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims) + + # Non-batch contracting dimension specs + cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims) + + return bspecs, lspecs, cspecs + + ( + (lhs_bspecs, lhs_lspecs, lhs_cspecs), + (rhs_bspecs, rhs_lspecs, rhs_cspecs), + ) = map( + _split_specs, + (lhs_specs, rhs_specs), + (lhs_cdims, rhs_cdims), + (lhs_bdims, rhs_bdims), + ) + + # Batched dimensions must have the same sharding + if len(lhs_bdims) > 0 and len(rhs_bdims) > 0: + assert all( + lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs) + ), ( + "cuBLAS GEMM operand batch dimensions must have the same sharding: " + f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}." + ) + + # Only one each of the non-batched leading dimensions and non-batched contracting + # dimensions can be sharded + lhs_ldims, rhs_ldims = map( + lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude), + (lhs_ndim, rhs_ndim), + (lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims), + ) + (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map( + lambda specs: tuple(spec for spec in specs if spec is not None), + (lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs), + ) + assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, ( + "cuBLAS GEMM operands can have only one sharded non-batched leading dimension: " + f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}." + ) + assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, ( + "cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: " + f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}." + ) + + # Extract single leading and contracting dimension specs + (lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map( + lambda specs: None if len(specs) == 0 else specs[0], + (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none), + ) + return (lhs_lspec, lhs_cspec), (rhs_lspec, rhs_cspec) + + def _get_no_overlap_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): + (lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = dimension_numbers + lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) + + (lhs_lspec, lhs_cspec), (rhs_lspec, rhs_cspec) = self._check_operand_specs( + lhs_specs, rhs_specs, ((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)) + ) + + # Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts + # with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands. + # 1. K1 == K2 != None and N == None + # LHS: (B, M, K) + # RHS: (B, None, K) + # OUT: (B, M, None) --(AR)-> (B, M, None) + # 2. K1 == K2 != None and M == N != None + # LHS: (B, M, K) + # RHS: (B, N, K)--(AG)->(B, None, K) + # OUT: (B, M, None) --(RS)--> (B, M, N) + # 3. M == N + # LHS: (B, M, K)--(AG)->(B, M, None) + # RHS: (B, M, K)--(AG)->(B, None, None) + # OUT: (B, M, None) + # 4. M != N + # LHS: (B, M, K)--(AG)->(B, M, None) + # RHS: (B, N, K)--(AG)->(B, N, None) + # OUT: (B, M, N) + reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec + all_reduce_output = reduce_flag and rhs_lspec is None + reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec + all_reduce_spec = reduce_scatter_spec = scatter_dim = None + + lhs_non_contracting_specs, rhs_non_contracting_specs = map( + lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims), + (lhs_specs, rhs_specs), + (lhs_cdims, rhs_cdims), + ) + out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs) + if reduce_scatter_output: + # All-gather (if necessary) the non-batch non-contracting dimension of RHS + # LHS: (B, M, K) + # RHS: (B, N, K) --(AG)-> (B, None, K) + # OUT: (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N) + rhs_spec = tuple( + rhs_spec[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim) + ) + reduce_scatter_spec = lhs_cspec + scatter_dim = out_specs.index(rhs_lspec) + + elif all_reduce_output: + # Set all output trailing dimensions to zero + out_specs = ( + *lhs_non_contracting_specs, + *[None for _ in range(len(rhs_non_contracting_specs))], + ) + all_reduce_spec = lhs_cspec + else: + # All-gather (if necessary) the non-batch contracting dimensions + # LHS: (B, M, K) --(AG)-> (B, M, None) + # RHS: (B, N, K) --(AG)-> (B, N, None) + # OUT: (B, M, None) x (B, N, None)^T = (B, M, N) + lhs_specs = tuple( + None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i] + for i in range(lhs_ndim) + ) + rhs_specs = tuple( + None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i] + for i in range(rhs_ndim) + ) + # Check if RHS non-contracting spec also appears in the LHS non-contracting specs + if rhs_lspec is not None and rhs_lspec in tuple( + lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims + ): + # All-gather (if necessary) the non-batch non-contracting dimensions of RHS + # LHS: (B, M, None) + # RHS: (B, N, None) --(AG)-> (B, None, None) + # OUT: (B, M, None) x (B, None, None)^T = (B, M, None) + rhs_specs = tuple( + None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i] + for i in range(rhs_ndim) + ) + # Set all output trailing dimensions to zero + out_specs = ( + *lhs_non_contracting_specs, + *[None for _ in range(len(rhs_non_contracting_specs))], + ) + + # Bias and Pre-GeLU sharding is based on GEMM output + bias_specs = out_specs[len(lhs_non_contracting_specs) :] + gelu_specs = out_specs + + return ( + (lhs_specs, rhs_specs, bias_specs, gelu_specs, aux_in_specs), + (out_specs, bias_specs, gelu_specs, (None, )), + (all_reduce_spec, reduce_scatter_spec, scatter_dim), + ) + + def _get_bulk_overlap_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): + assert self.sp_resource in aux_in_specs, ( + "CommOverlapHelper: Auxiliary input for bulk all-gather overlap is not sharded " + f"over the sequence-parallel mesh resource {self.sp_resource} in any dimension." + ) + + aux_out_specs = (None, ) + bulk_comm_dim = aux_in_specs.index(self.sp_resource) + aux_in_specs_batch = aux_in_specs[ : bulk_comm_dim] + aux_in_specs_tail = aux_in_specs[bulk_comm_dim + 1: ] + if self.is_all_gather(): + assert all(spec is None for spec in aux_in_specs_tail), ( + "CommOverlapHelper: Trailing dimensions of the auxiliary input for bulk all-gather " + "overlap cannot be sharded." + ) + self._set_gather_dim(bulk_comm_dim) + aux_out_specs = ( + *aux_in_specs_batch, + None, # all-gathered dimension + *[None for _ in range(len(aux_in_specs_tail))] + ) + else: + assert all(spec is None for spec in aux_in_specs[bulk_comm_dim : ]), ( + "CommOverlapHelper: Non-batch dimensions of the auxiliary input for bulk " + "reduce-scatter overlap cannot be sharded." + ) + self._set_scatter_dim(bulk_comm_dim) + aux_out_specs = ( + *aux_in_specs_batch, + self.sp_resource, + *[None for _ in range(len(aux_in_specs_tail))], + ) + + # GEMM is independent of communication so specs are as if there is no overlap + operand_specs, output_specs, xla_reduce_info = self._get_specs_no_overlap( + lhs_specs, rhs_specs, aux_in_specs, dimension_numbers + ) + + return ( + operand_specs, + (*output_specs[:-1], aux_out_specs), + xla_reduce_info, + ) + + + def _get_all_gather_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): + contracting_dims, batch_dims = dimension_numbers + lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) + lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( + sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batch_dims + ) + + (lhs_lspec, _), _ = self._check_operand_specs( + lhs_specs, rhs_specs, ((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)) + ) + assert lhs_lspec == self.sp_resource, ( + "CommOverlapHelper: Non-batch leading dimension of the LHS operand for AG->GEMM " + f"overlap must be sharded over the sequence-parallel mesh resource {self.sp_resource}, " + f"but got {lhs_lspec} sharding instead." + ) + + # AG->GEMM overlap: Require non-batched contracting dimensions to be unsharded (e.g. FSDP) + # LHS: (B, M, None) + # RHS: (N, None) + # OUT: (B, M, None) --(UB-AG)-> (B, None, None) x (N, None)^T = (B, None, N) + lhs_specs = tuple( + None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i] for i in range(lhs_ndim) + ) + rhs_specs = tuple( + None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i] for i in range(rhs_ndim) + ) + + # GEMM output spec keeps LHS batch spec and RHS non-contracting specs, but is None + # in the non-batched leading dimensions. + lhs_non_cspecs_gathered = list( + lhs_specs[i] if i in lhs_bdims else None for i in range(lhs_ndim) if i not in lhs_cdims + ) + rhs_non_cspecs = tuple( + rhs_specs[i] for i in range(rhs_ndim) if i not in rhs_cdims + ) + out_specs = (*lhs_non_cspecs_gathered, *rhs_non_cspecs) + + # Bias and Pre-GeLU sharding is based on GEMM output + bias_specs = out_specs[len(lhs_non_cspecs_gathered) : ] + gelu_specs = out_specs + + # Auxiliary input/output specs depend on bulk vs. non-bulk overlap + aux_out_specs = (None, ) + if self.output_all_gathered_lhs: + # Auxiliary output is the same as the LHS spec, except the gathered dimension unsharded + self._set_gather_dim(lhs_specs.index(lhs_lspec)) + aux_out_specs = list(lhs_specs).copy() + aux_out_specs[lhs_specs.index(lhs_lspec)] = None + + return ( + (lhs_specs, rhs_specs, bias_specs, gelu_specs, aux_in_specs), + (out_specs, bias_specs, gelu_specs, aux_out_specs), + (None, None, None), + ) + + def _get_reduce_scatter_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): + contracting_dims, batch_dims = dimension_numbers + lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) + lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( + sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batch_dims + ) + + (_, lhs_cspec), (_, rhs_cspec) = self._check_operand_specs( + lhs_specs, rhs_specs, ((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)) + ) + assert lhs_cspec == rhs_cspec == self.tp_resource, ( + "CommOverlapHelper: Non-batched contracting dimensions of LHS and RHS operands for " + "GEMM->RS overlap must be sharded over the tensor-parallel resource " + f"{self.tp_resource}, but got LHS:{lhs_cspec} and RHS:{rhs_cspec} sharding instead." + ) + + # GEMM->RS overlap: Require non-contracting non-batch dimensions to be unsharded (e.g. FSDP) + # LHS: (B, M, K) --(XLA-AG)-> (B, None, K) + # RHS: (N, K) --(XLA-AG)-> (None, K) + # OUT: (B, None, K) x (B, None, K) = (B, None, None) --(UB-RS)-> (B, M, None) + lhs_specs = tuple( + None if i not in lhs_bdims + lhs_cdims else lhs_specs[i] for i in range(lhs_ndim) + ) + rhs_specs = tuple( + None if i not in rhs_bdims + rhs_cdims else rhs_specs[i] for i in range(rhs_ndim) + ) + + # GEMM output is the internal communication buffer, but we will use the XLA output buffer + # as the final reduce-scattered output so we shard it accordingly here. + lhs_specs_scattered = list(lhs_specs).copy() + for i in range(lhs_ndim): + if i not in lhs_bdims: + # Update only the first non-batch leading dimension to the TP resource + lhs_specs_scattered[i] = self.tp_resource + break + lhs_non_cspecs_scattered = tuple( + lhs_specs_scattered[i] for i in range(lhs_ndim) if i not in lhs_cdims + ) + rhs_non_cspecs = tuple( + rhs_specs[i] for i in range(rhs_ndim) if i not in rhs_cdims + ) + out_specs = (*lhs_non_cspecs_scattered, *rhs_non_cspecs) + self._set_scatter_dim(out_specs.index(self.tp_resource)) + + + # Bias and Pre-GeLU sharding is based on GEMM output + bias_specs = out_specs[len(lhs_non_cspecs_scattered) : ] + gelu_specs = out_specs + + return ( + (lhs_specs, rhs_specs, bias_specs, gelu_specs, aux_in_specs), + (out_specs, bias_specs, gelu_specs, (None, )), + (None, None, None), + ) + + def get_partitioning_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): + """ + Correct operand specs to partititions suitable for the GemmPrimitive, and infer the + partition specs of the outputs. + """ + if self.is_bulk(): + return self._get_bulk_overlap_rules( + lhs_specs, rhs_specs, aux_in_specs, dimension_numbers + ) + + impl_map = { + tex.CommOverlapType.NONE : self._get_no_overlap_rules, + tex.CommOverlapType.AG : self._get_all_gather_rules, + tex.CommOverlapType.RS : self._get_reduce_scatter_rules, + } + return impl_map[self.comm_type](lhs_specs, rhs_specs, aux_in_specs, dimension_numbers) + + +@dataclass(frozen=True) +class CommOverlapHelperSet: + """ + A set of CommOverlapHelper objects that provide complementary comm+GEMM overlap configurations + for FPROP, DGRAD and WGRAD GEMMs in FWD/BWD passes through Dense-layers. + """ + fprop: CommOverlapHelper = field(default=None) + dgrad: CommOverlapHelper = field(default=None) + wgrad: CommOverlapHelper = field(default=None) + + def _sanity_check(self): + if not self.fprop.is_enabled: + assert self.dgrad is None or not self.dgrad.is_enabled, ( + "CommOverlapHelperSet: Comm+GEMM overlap for DGRAD requires comm+GEMM overlap " + "for FPROP to be enabled first." + ) + assert self.wgrad is None or not self.wgrad.is_enabled, ( + "CommOverlapHelperSet: Comm+GEMM overlap for WGRAD requires comm+GEMM overlap " + "for FPROP to be enabled first." + ) + return + + assert not self.fprop.is_bulk(), ( + "CommOverlapHelperSet: Comm+GEMM overlap for FPROP does not support bulk collectives." + ) + + if self.fprop.is_all_gather(): + if self.dgrad is not None: + if self.fprop.output_all_gathered_lhs: + assert not self.dgrad.is_enabled, ( + "CommOverlapHelperSet: AG->GEMM FPROP does not have a corresponding DGRAD " + "overlap when it is configured to return a copy of the all-gathered LHS " + "operand as the auxiliary output." + ) + + elif self.dgrad.is_enabled: + assert ( + (self.dgrad.is_bulk() and self.dgrad.is_all_gather()) + or (not self.dgrad_is_bulk() and self.dgrad.is_reduce_scatter()) + ), ( + "CommOverlapHelperSet: AG->GEMM FPROP requires DGRAD overlap to be either " + "BULK-AG or GEMM->RS." + ) + + if self.wgrad is not None: + if ( + self.dgrad is not None + and self.dgrad.is_enabled + and self.dgrad.is_bulk() # not checking all-gather because we enforced it above + ): + assert ( + self.wgrad.is_enabled + and self.wgrad.is_bulk() + and self.wgrad.is_reduce_scatter() + ), ( + "CommOverlapHelperSet: AG->GEMM FPROP with BULK-AG DGRAD requires " + "WGRAD to overlap with BULK-RS." + ) + else: + assert not self.wgrad.is_enabled, ( + "CommOverlapHelperSet: AG->GEMM FPROP does not have a corresponding WGRAD " + "overlap when DGRAD does not overlap with BULK-AG." + ) + + elif self.fprop.is_reduce_scatter(): + if self.dgrad is not None and self.dgrad.is_enabled: + assert not self.dgrad.is_bulk() and self.dgrad.is_all_gather(), ( + "CommOverlapHelperSet: GEMM->RS overlap in FPROP requires DGRAD overlap to " + "be AG->GEMM." + ) + + if self.wgrad is not None: + assert not self.wgrad.is_enabled, ( + "CommOverlapHelperSet: GEMM->RS overlap in FPROP does not have a " + "corresponding WGRAD overlap." + ) + + def __post_init__(self): + if self.fprop is None: + object.__setattr__(self, "fprop", CommOverlapHelper()) + object.__setattr__(self, "dgrad", CommOverlapHelper()) + object.__setattr__(self, "wgrad", CommOverlapHelper()) + + self._sanity_check() + + if self.fprop.is_enabled: + # FWD/BWD paths with overlap: + # + # 1. AG->GEMM: (B, M, None) --(LHS AG)-> (B, None, None) x (None, N) = (B, None, N) + # DGRAD + Bulk-AG: (B, None, N) x (None, N)^T = (B, None, None) + # (B, M, None) --(LHS bulk-AG)-> (B, None, None) + # WGRAD + Bulk-RS: (B, None, None)^T x (B, None, N) = (None, N) + # (B, None, None) --(DGRAD bulk RS)-> (B, M, None) + # + # 2. GEMM->RS in FPROP: (B, None, K) x (K, None) = (B, None, None) --(RS)-> (B, M, None) + # AG->DGRAD: (B, M, None) --(GRAD AG)-> (B, None, None) x (K, None)^T = (B, None, K) + # WGRAD w/ AG-GRAD from DGRAD: (B, None, K)^T x (B, None, None) = (K, None) + + if self.dgrad is None: + if self.fprop.is_all_gather() and self.fprop.output_all_gathered_lhs: + # If the AG->GEMM FPROP already saved the all-gathered LHS in the autograd + # context, we don't need to overlap a BULK-AG for it with DGRAD. + object.__setattr__(self, "dgrad", CommOverlapHelper()) + + else: + # Otherwise, AG->GEMM FPROP needs BULK-AG DGRAD, and GEMM->RS FPROP needs + # AG->GEMM DGRAD w/ all-gathered gradient returned as auxiliary output to be + # re-used in WGRAD. + object.__setattr__( + self, + "dgrad", + CommOverlapHelper( + method=( + tex.CommOverlapMethod.BULK + if self.fprop.is_all_gather() + else tex.CommOverlapMethod.RING_EXCHANGE + ), + comm_type=tex.CommOverlapType.AG, + buffer_shape=self.fprop.buffer_shape, + buffer_dtype=self.fprop.buffer_dtype, + tp_size=self.fprop.tp_size, + tp_resource=self.fprop.tp_resource, + sp_resource=self.fprop.sp_resource, + output_all_gathered_lhs=self.fprop.is_reduce_scatter(), + ) + ) + + if self.wgrad is None: + if ( + self.fprop.is_all_gather() + and self.dgrad.is_enabled + and self.dgrad.is_bulk() + and self.dgrad.is_all_gather() + ): + # If FPROP does AG->GEMM and DGRAD does BULK-AG, WGRAD needs to do a BULK-RS + object.__setattr__( + self, + "wgrad", + CommOverlapHelper( + method=tex.CommOverlapMethod.BULK, + comm_type=tex.CommOverlapType.RS, + buffer_shape=self.fprop.buffer_shape, + buffer_dtype=self.fprop.buffer_dtype, + tp_size=self.fprop.tp_size, + tp_resource=self.fprop.tp_resource, + sp_resource=self.fprop.sp_resource, + ) + ) + + else: + # Otherwise, WGRAD does not support comm+GEMM overlap + object.__setattr__(self, "wgrad", CommOverlapHelper()) + + class GemmPrimitive(BasePrimitive): """ Primitive for cuBLAS GEMM @@ -158,7 +845,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14) + impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16) inner_primitive = None outer_primitive = None @@ -170,6 +857,7 @@ def abstract( rhs_scale_inv, bias, gelu_input, + aux_in, out_dtype, dimension_numbers, lhs_quantized_colwise, @@ -179,6 +867,7 @@ def abstract( fuse_gelu, grad, use_split_accumulator, + comm_overlap, ): del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator @@ -227,10 +916,68 @@ def abstract( (lhs.shape, rhs.shape), (lhs_contracting_dims, rhs_contracting_dims), ) - out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape) + out_shape = [*lhs_non_contracting_shape, *rhs_non_contracting_shape] output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) - # Validate bias + # Auxiliary output for comm+GEMM overlap + aux_out_shape = (0, ) + aux_out_dtype = jnp.bfloat16 + if comm_overlap.is_enabled: + if comm_overlap.is_bulk(): + # Bulk overlap will all-gather or reduce-scatter the tensor in the auxiliary input + # and return the result of the collective in the auxiliary output + assert aux_in.size > 0, ( + "cuBLAS GEMM w/ bulk collective overlap requires an auxiliary input." + ) + assert aux_in.ndim > 1, ( + "cuBLAS GEMM w/ bulk collective overlap only supports multidimensional " + "auxiliary inputs." + ) + + aux_out_shape = list(aux_in.shape).copy() + aux_out_dtype = aux_in.dtype + if comm_overlap.sharded_impl: + if comm_overlap["comm_type"] == tex.CommOverlapType.AG: + aux_out_shape[comm_overlap.gather_dim] *= comm_overlap.tp_size + else: + assert aux_in.shape[comm_overlap.scatter_dim] % comm_overlap.tp_size, ( + "cuBLAS GEMM w/ bulk reduce-scatter overlap requires the auxiliary " + "input to be divisible by tensor-parallel size in dimension index " + f"{comm_overlap.scatter_dim}." + ) + aux_out_shape[comm_overlap.scatter_dim] = ( + aux_out_shape[comm_overlap.scatter_dim] // comm_overlap.tp_size + ) + + elif comm_overlap.is_all_gather(): + # Sharded abstract multiplies gathered dimension by TP size + if comm_overlap.sharded_impl: + out_shape[comm_overlap.gather_dim] *= comm_overlap.tp_size + output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + + # AG->GEMM overlap can copy all-gathered LHS into the auxiliary buffer + if comm_overlap.output_all_gathered_lhs: + aux_out_shape = list(lhs.shape).copy() + aux_out_dtype = lhs.dtype + + # Sharded abstract multiplies gathered dimension by TP size + if comm_overlap.sharded_impl: + aux_out_shape[comm_overlap.gather_dim] *= comm_overlap.tp_size + elif comm_overlap.is_reduce_scatter(): + # GEMM->RS auxiliary output is the reduce-scattered output + rs_out_shape = list(out_shape).copy() + + # Sharded abstract divides scattered dimension by TP size + if comm_overlap.sharded_impl: + rs_out_shape[comm_overlap.scatter_dim] = ( + rs_out_shape[comm_overlap.scatter_dim] // comm_overlap.tp_size + ) + + output = jax.core.ShapedArray(shape=rs_out_shape, dtype=out_dtype) + + aux_out = jax.core.ShapedArray(shape=aux_out_shape, dtype=aux_out_dtype) + + # Validate bias -- shape always depends on pure GEMM output even for GEMM->RS overlap bias_shape = (0,) bias_dtype = out_dtype if fuse_bias: @@ -249,7 +996,7 @@ def abstract( bias_shape = rhs_non_contracting_shape bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype) - # Validate pre-GeLU + # Validate pre-GeLU -- shape always depends on pure GEMM output even for GEMM->RS overlap pre_gelu_shape = (0,) pre_gelu_dtype = out_dtype if fuse_gelu: @@ -278,13 +1025,17 @@ def abstract( lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype) rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype) - # Declare cuBLAS workspace - # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not - # necessarily 256 bytes aligned, we add some padding to ensure alignment. - workspace_size = get_cublas_workspace_size_bytes() + 256 + # Size cuBLAS workspace -- multiplied by number of comm+GEMM overlap compute streams + workspace_size = get_cublas_workspace_size_bytes() + if comm_overlap.is_enabled: + workspace_size *= comm_overlap.num_max_streams + + # cuBLAS requires workspace pointers aligned to 256 bytes but XLA does not guarantee that + # so we add to the size here and align the pointer in the C++ custom call. + workspace_size += 256 workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) - return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace + return output, bias_grad, pre_gelu_out, aux_out, lhs_swizzle, rhs_swizzle, workspace @staticmethod def outer_abstract(*args, **kwargs): @@ -300,6 +1051,7 @@ def lowering( rhs_scale_inv, bias, gelu_input, + aux_in, out_dtype, dimension_numbers, lhs_quantized_colwise, @@ -309,6 +1061,7 @@ def lowering( fuse_gelu, grad, use_split_accumulator, + comm_overlap, ): del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype contracting_dims, _ = dimension_numbers @@ -318,7 +1071,7 @@ def lowering( (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) ) - args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input) + args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, aux_in) kwargs = { "scaling_mode": int(scaling_mode.value), "lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), @@ -329,10 +1082,8 @@ def lowering( "fuse_gelu": fuse_gelu, "grad": grad, "use_split_accumulator": use_split_accumulator, - "comm_overlap_id": -1, - "comm_overlap_method": tex.CommOverlapMethod.NONE, - "comm_type": tex.CommOverlapType.NONE, } + kwargs.update(comm_overlap.get_lowering_kwargs()) operand_output_aliases = {} if fuse_bias and not grad: @@ -353,6 +1104,7 @@ def impl( rhs_scale_inv, bias, gelu_input, + aux_in, out_dtype, dimension_numbers, lhs_quantized_colwise, @@ -362,6 +1114,8 @@ def impl( fuse_gelu, grad, use_split_accumulator, + comm_overlap, + ): lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), dimension_numbers[0]) lhs_transposed, rhs_transposed = _get_gemm_layout( @@ -390,6 +1144,7 @@ def impl( rhs_scale_inv, bias, gelu_input, + aux_in, out_dtype=out_dtype, dimension_numbers=dimension_numbers, lhs_quantized_colwise=lhs_quantized_colwise, @@ -399,6 +1154,7 @@ def impl( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + comm_overlap=comm_overlap, ) return outputs[:-3] # discard workspace arrays @@ -415,9 +1171,10 @@ def batcher( fuse_gelu, grad, use_split_accumulator, + comm_overlap, ): assert GemmPrimitive.outer_primitive is not None - lhs, _, rhs, *_ = batched_args + lhs, _, rhs, *_, aux_in_bdims = batched_args lhs_bdims, _, rhs_bdims, *_ = batch_dims contracting_dims, batch_dims = dimension_numbers arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batch_dims) @@ -445,6 +1202,16 @@ def batcher( if fuse_gelu and not grad: pre_gelu_bdims = out_bdims + aux_out_bdims = (None, ) + if comm_overlap.is_enabled: + if comm_overlap.is_bulk(): + # Bulk overlap auxiliary output must have the same batch dims as the auxiliary input + aux_out_bdims = aux_in_bdims + elif comm_overlap.is_all_gather() and comm_overlap.output_all_gathered_lhs: + # AG->GEMM overlap with all-gathered LHS output must have same batch dims as + # sharded LHS input + aux_out_bdims = arg_lhs_bdims + return ( GemmPrimitive.outer_primitive.bind( *batched_args, @@ -457,165 +1224,9 @@ def batcher( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + comm_overlap=comm_overlap, ), - (out_bdims, bias_bdims, pre_gelu_bdims), - ) - - @staticmethod - def _decompose_operand_specs(specs, contracting_dims, batch_dims): - ndims = len(specs) - cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims)) - - # Batch specs - bspecs = tuple(specs[i] for i in bdims) - - # Non-batch leading dimension specs - lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims) - - # Non-batch contracting dimension specs - cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims) - - return bspecs, lspecs, cspecs - - @staticmethod - def _parse_operand_output_specs(arg_infos, dimension_numbers): - lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) - contracting_dims, batch_dims = dimension_numbers - lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) - lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( - sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batch_dims - ) - ( - (lhs_bspecs, lhs_lspecs, lhs_cspecs), - (rhs_bspecs, rhs_lspecs, rhs_cspecs), - ) = map( - GemmPrimitive._decompose_operand_specs, - (lhs_specs, rhs_specs), - (lhs_cdims, rhs_cdims), - (lhs_bdims, rhs_bdims), - ) - - # Batched dimensions must have the same sharding - if len(lhs_bdims) > 0 and len(rhs_bdims) > 0: - assert all( - lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs) - ), ( - "cuBLAS GEMM operand batch dimensions must have the same sharding: " - f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}." - ) - - # Only one each of the non-batched leading dimensions and non-batched contracting - # dimensions can be sharded - lhs_ldims, rhs_ldims = map( - lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude), - (lhs_ndim, rhs_ndim), - (lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims), - ) - (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map( - lambda specs: tuple(spec for spec in specs if spec is not None), - (lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs), - ) - assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, ( - "cuBLAS GEMM operands can have only one sharded non-batched leading dimension: " - f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}." - ) - assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, ( - "cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: " - f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}." - ) - - # Extract single leading and contracting dimension specs - (lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map( - lambda specs: None if len(specs) == 0 else specs[0], - (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none), - ) - - # Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts - # with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands. - # 1. K1 == K2 != None and N == None - # LHS: (B, M, K) - # RHS: (B, None, K) - # OUT: (B, M, None) --(AR)-> (B, M, None) - # 2. K1 == K2 != None and M == N != None - # LHS: (B, M, K) - # RHS: (B, N, K)--(AG)->(B, None, K) - # OUT: (B, M, None) --(RS)--> (B, M, N) - # 3. M == N - # LHS: (B, M, K)--(AG)->(B, M, None) - # RHS: (B, M, K)--(AG)->(B, None, None) - # OUT: (B, M, None) - # 4. M != N - # LHS: (B, M, K)--(AG)->(B, M, None) - # RHS: (B, N, K)--(AG)->(B, N, None) - # OUT: (B, M, N) - reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec - all_reduce_output = reduce_flag and rhs_lspec is None - reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec - all_reduce_spec = reduce_scatter_spec = scatter_dim = None - - lhs_non_contracting_specs, rhs_non_contracting_specs = map( - lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims), - (lhs_specs, rhs_specs), - (lhs_cdims, rhs_cdims), - ) - out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs) - if reduce_scatter_output: - # All-gather (if necessary) the non-batch non-contracting dimension of RHS - # (B, N, K) --(AG)-> (B, None, K) - # (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N) - rhs_spec = tuple( - rhs_spec[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim) - ) - reduce_scatter_spec = lhs_cspec - scatter_dim = out_specs.index(rhs_lspec) - - elif all_reduce_output: - # Set all output trailing dimensions to zero - out_specs = ( - *lhs_non_contracting_specs, - *[None for _ in range(len(rhs_non_contracting_specs))], - ) - all_reduce_spec = lhs_cspec - else: - # All-gather (if necessary) the non-batch contracting dimensions - # (B, M, K) --(AG)-> (B, M, None) - # (B, N, K) --(AG)-> (B, N, None) - # (B, M, None) x (B, N, None)^T = (B, M, N) - lhs_specs = tuple( - None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i] - for i in range(lhs_ndim) - ) - rhs_specs = tuple( - None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i] - for i in range(rhs_ndim) - ) - # Check if RHS non-contracting spec also appears in the LHS non-contracting specs - if rhs_lspec is not None and rhs_lspec in tuple( - lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims - ): - # All-gather (if necessary) the non-batch non-contracting dimensions of RHS - # (B, N, None) --(AG)-> (B, None, None) - # (B, M, None) x (B, None, None)^T = (B, M, None) - rhs_specs = tuple( - None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i] - for i in range(rhs_ndim) - ) - # Set all output trailing dimensions to zero - out_specs = ( - *lhs_non_contracting_specs, - *[None for _ in range(len(rhs_non_contracting_specs))], - ) - - # Bias and Pre-GeLU sharding is based on GEMM output - bias_specs = out_specs[len(lhs_non_contracting_specs) :] - gelu_specs = out_specs - - return ( - (lhs_specs, rhs_specs, bias_specs, gelu_specs), - (out_specs, bias_specs, gelu_specs), - all_reduce_spec, - reduce_scatter_spec, - scatter_dim, + (out_bdims, bias_bdims, pre_gelu_bdims, aux_out_bdims), ) @staticmethod @@ -629,6 +1240,7 @@ def infer_sharding_from_operands( fuse_gelu, grad, use_split_accumulator, + comm_overlap, mesh, arg_infos, result_infos, @@ -642,22 +1254,28 @@ def infer_sharding_from_operands( ) del use_split_accumulator, result_infos - (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( - GemmPrimitive._parse_operand_output_specs(arg_infos, dimension_numbers) + lhs_specs, _, rhs_specs, *_, aux_in_specs = map(get_padded_spec, arg_infos) + ( + _, (out_specs, bias_grad_specs, pre_gelu_specs, aux_out_specs), *_ + ) = comm_overlap.get_partitioning_rules( + lhs_specs, rhs_specs, aux_in_specs, dimension_numbers ) - out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) - # Discard bias gradient spec if there is no bias fusion + # Discard bias gradient and pre-GeLU output specs based on fusion choices if not fuse_bias: - dbias_specs = (None,) - dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs)) - - # Discard pre-GeLU output spec if there is no GeLU fusion + bias_grad_specs = (None,) if not fuse_gelu: pre_gelu_specs = (None,) - pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)) - return [out_sharding, dbias_sharding, pre_gelu_sharding] + # Assemble output shardings + out_shardings = list( + map( + lambda specs: NamedSharding(mesh, PartitionSpec(*specs)), + (out_specs, bias_grad_specs, pre_gelu_specs, aux_out_specs) + ) + ) + + return out_shardings @staticmethod def partition( @@ -670,56 +1288,63 @@ def partition( fuse_gelu, grad, use_split_accumulator, + comm_overlap, mesh, arg_infos, result_infos, ): del result_infos + lhs_specs, _, rhs_specs, *_, aux_in_specs = map(get_padded_spec, arg_infos) ( - (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), - (out_specs, dbias_specs, pre_gelu_specs), - all_reduce_spec, - reduce_scatter_spec, - scatter_dim, - ) = GemmPrimitive._parse_operand_output_specs(arg_infos, dimension_numbers) - - # Assemble argument shardings - # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. - none_sharding = NamedSharding(mesh, PartitionSpec(None)) - lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs)) - rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs)) - arg_shardings = ( - lhs_sharding, - lhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, - rhs_sharding, - rhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + (lhs_specs, rhs_specs, bias_specs, gelu_input_specs, aux_in_specs), + (out_specs, bias_grad_specs, pre_gelu_specs, aux_out_specs), + (all_reduce_spec, reduce_scatter_spec, scatter_dim), + ) = comm_overlap.get_partitioning_rules( + lhs_specs, rhs_specs, aux_in_specs, dimension_numbers ) - # Discard bias input spec if there is no bias fusion - if not fuse_bias: - bias_input_specs = (None,) - arg_shardings += (NamedSharding(mesh, PartitionSpec(*bias_input_specs)),) + # Block scale inverses match their operands, but tensor scale inverses are unsharded. + lhs_scale_specs = (None, ) + rhs_scale_specs = (None, ) + if scaling_mode.is_1d_block_scaling() and not comm_overlap.is_enabled: + lhs_scale_specs = lhs_specs + rhs_scale_specs = rhs_specs - # Discard pre-GeLU input spec if there is no GeLU fusion + # Discard bias and pre-GeLU specs based on fusion choices + if not fuse_bias: + bias_specs = (None,) + bias_grad_specs = (None,) if not fuse_gelu: gelu_input_specs = (None,) - arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),) - - # Assemble output shardings - out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))] + pre_gelu_specs = (None,) - # Discard bias gradient spec if there is no bias fusion - if not fuse_bias: - dbias_specs = (None,) - out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs))) + # Assemble argument shardings + arg_shardings = tuple( + map( + lambda specs: NamedSharding(mesh, PartitionSpec(*specs)), + ( + lhs_specs, + lhs_scale_specs, + rhs_specs, + rhs_scale_specs, + bias_specs, + gelu_input_specs, + aux_in_specs + ), + ) + ) - # Discard pre-GeLU output spec if there is no GeLU fusion - if not fuse_gelu: - pre_gelu_specs = (None,) - out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))) + # Assemble output shardings + out_shardings = list( + map( + lambda specs: NamedSharding(mesh, PartitionSpec(*specs)), + (out_specs, bias_grad_specs, pre_gelu_specs, aux_out_specs), + ) + ) - def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): + def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, aux_in): + comm_overlap._set_sharded_impl(True) outputs = GemmPrimitive.impl( lhs, lhs_scale_inv, @@ -727,6 +1352,7 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): rhs_scale_inv, bias, gelu_input, + aux_in, out_dtype=out_dtype, dimension_numbers=dimension_numbers, lhs_quantized_colwise=lhs_quantized_colwise, @@ -736,7 +1362,9 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + comm_overlap=comm_overlap, ) + comm_overlap._set_sharded_impl(False) # All-Reduce/Reduce-Scatter GEMM output if all_reduce_spec is not None: @@ -770,48 +1398,6 @@ def shardy_sharding_rule(*args, **kwargs): register_primitive(GemmPrimitive) -@dataclass -class CommOverlapHelper: - - buffer_shape: Sequence[int] = (0, ) - buffer_dtype: jnp.dtype = jnp.bfloat16 - tp_size: int = 1 - comm_type: tex.CommOverlapType = tex.CommOverlapType.NONE - method: tex.CommOverlapMethod = tex.CommOverlapMethod.NONE - unique_id: int = field(default=-1, init=False) - - def __post_init__(self): - if self.comm_type == tex.CommOverlapType.NONE: - assert self.method == tex.CommOverlapMethod.NONE - else: - assert self.method != tex.CommOverlapMethod.NONE - if self.comm_type == tex.CommOverlapType.AG: - assert self.method != tex.CommOverlapType.PIPELINE, ( - "Comm+GEMM overlap w/ PIPELINE method does not support all-gather." - ) - else: - # Reduce-Scatter overlap always needs an auxiliary output - self.needs_aux_out = True - assert self.tp_size > 1, ( - "Comm+GEMM overlap requires tensor-parallel size larger than 1." - ) - - def is_enabled(self): - return self.method != tex.CommOverlapMethod.NONE - - def is_bulk(self): - return self.method == tex.CommOverlapMethod.BULK - - def is_p2p(self): - return self.method == tex.CommOverlapMethod.RING_EXCHANGE - - def is_all_gather(self): - return self.comm_type == tex.CommOverlapType.AG - - def is_reduce_scatter(self): - return self.comm_type == tex.CommOverlapType.RS - - def gemm_uses_jax_dot() -> bool: """Check if the GEMM call directs to the TE custom cuBLAS call or native JAX dot.""" return not GemmPrimitive.enabled() @@ -832,6 +1418,7 @@ def _te_gemm( rhs: Union[jax.Array, ScaledTensor], bias: jax.Array = None, gelu_input: jax.Array = None, + aux_in: jax.Array = None, lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]] = (((-1,), (0,)), ((), ())), @@ -839,6 +1426,7 @@ def _te_gemm( fuse_gelu: bool = False, grad: bool = False, use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, + comm_overlap: CommOverlapHelper = CommOverlapHelper(), ) -> Tuple[jax.Array, ...]: # Prepare non-quantized GEMM operands lhs_data = lhs @@ -888,12 +1476,14 @@ def _te_gemm( rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis) - # Dummy empties for bias and gelu + # Dummy empties for bias, gelu and aux_in out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype if bias is None or not (fuse_bias and not grad): bias = jnp.empty(0, dtype=out_dtype) if gelu_input is None or not (fuse_gelu and grad): gelu_input = jnp.empty(0, dtype=out_dtype) + if aux_in is None or not comm_overlap.is_enabled: + aux_in = jnp.empty(0, dtype=jnp.bfloat16) return GemmPrimitive.outer_primitive.bind( lhs_data, @@ -902,6 +1492,7 @@ def _te_gemm( rhs_scale_inv, bias, gelu_input, + aux_in, out_dtype=out_dtype, dimension_numbers=((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)), lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False, @@ -911,6 +1502,7 @@ def _te_gemm( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + comm_overlap=comm_overlap, ) @@ -1259,6 +1851,8 @@ def gemm( use_split_accumulator: bool, default = True Enable promoting some intermediate sums to higher precision when accumulating the result in the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. + comm_overlap: CommOverlapHelper, default = None + Helper object that manages comm+GEMM overlap options. Returns ------- @@ -1310,13 +1904,18 @@ def gemm( # Discard empty outputs grad = kwargs.get("grad", False) + comm_overlap = kwargs.get("comm_overlap", CommOverlapHelper()) clean_outputs = outputs[0] # first output is the final result and is never empty - if (fuse_bias and grad) or (fuse_gelu and not grad): + if (fuse_bias and grad) or (fuse_gelu and not grad) or comm_overlap.has_aux_output(): clean_outputs = (outputs[0],) if fuse_bias and grad: # only return bias gradient if it exists clean_outputs += (outputs[1],) if fuse_gelu and not grad: # only return pre-GeLU output if it exists clean_outputs += (outputs[2],) + if comm_overlap.has_aux_output(): + # only return aux output for bulk overlap or non-bulk all-gather overlap + # with gathered LHS output + clean_outputs += (outputs[3],) return clean_outputs diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 6c2c021116..1339d3c5dd 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -125,6 +125,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( // GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); +int64_t CreateCommOverlapBuffer(CommOverlapType comm_type, CommOverlapMethod method, + const std::vector &buffer_shape, DType buffer_dtype, + int tp_size, int num_splits = 3, int num_max_streams = 3, + int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, + int num_comm_sm = 16, int set_sm_margin = false, bool use_ce = true, + bool atomic_gemm = false, bool rs_overlap_first_gemm = false, + bool aggregate_ag = false); + +void DestroyCommOverlapBuffer(size_t unique_id); + +void DestroyAllCommOverlapBuffers(); + // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 35d629f837..f0c86910e6 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -132,15 +132,30 @@ int64_t CreateCommOverlapBuffer(CommOverlapType comm_type, CommOverlapMethod met return unique_id; } +void DestroyCommOverlapBuffer(size_t unique_id) { + auto it = comm_overlaps.find(unique_id); + if (it != comm_overlaps.end()) { + delete it->second; + comm_overlaps.erase(it); + } +} + +void DestroyAllCommOverlapBuffers() { + for (auto it = comm_overlaps.begin(); it != comm_overlaps.end();) { + delete it->second; + it = comm_overlaps.erase(it); + } +} + Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type aux_in, Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, Result_Type aux_out, Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, - int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, - bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, - bool use_split_accumulator, int64_t comm_overlap_id, - CommOverlapMethod comm_overlap_method, CommOverlapType comm_type) { + CommOverlapMethod comm_overlap_method, CommOverlapType comm_type, + int64_t comm_overlap_id, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, + int64_t aux_axis_boundary, bool lhs_transposed, bool rhs_transposed, + bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { // Operands (this includes swizzling MXFP8 scaling factors) // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) @@ -153,15 +168,16 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand( stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); - // Output tensor + // Output tensor -- create with nullptr for GEMM->RS overlap because GEMM output goes into + // the communication buffer. We can use the XLA output buffer for the reduce-scattered + // auxiliary output tensor later. std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); - auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); - NVTE_CHECK(out_.numel() == output->element_count(), - "cuBLAS GEMM output buffer size is incorrect, expected ", - out_.numel(), " elements ", to_string_like(out_shape), " but got ", - output->element_count(), " elements ", to_string_like(output->dimensions())); + void* out_ptr = + (comm_type == CommOverlapType::RS && comm_overlap_method != CommOverlapMethod::BULK) + ? comm_overlaps[comm_overlap_id]->get_ubuf_dptr() : output->untyped_data(); + auto out_ = TensorWrapper(out_ptr, out_shape, out_dtype); // Bias input to forward pass or bias gradient output from backward pass void *bias_ptr = nullptr; @@ -203,6 +219,11 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); if (comm_type == CommOverlapType::NONE) { + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, expected ", + out_.numel(), " elements ", to_string_like(out_shape), " but got ", + output->element_count(), " elements ", to_string_like(output->dimensions())); + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), rhs_transposed, lhs_transposed, grad, workspace_.data(), false, use_split_accumulator, num_math_sm, stream); @@ -216,15 +237,17 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto aux_out_dtype = convert_ffi_datatype_to_te_dtype(aux_out->element_type()); if ((comm_type == CommOverlapType::AG && aux_out->element_count() > 0) || comm_type == CommOverlapType::RS) { - std::vector aux_out_shape = {product(aux_out_dims, 0, aux_out_dims.size() - 1), - static_cast(aux_out_dims.back())}; + std::vector aux_out_shape = { + product(aux_out_dims, 0, aux_axis_boundary), + product(aux_out_dims, aux_axis_boundary, aux_out_dims.size())}; } auto aux_out_ = TensorWrapper(aux_out->untyped_data(), aux_out_shape, aux_out_dtype); // Copy the auxiliary data into the communications buffer auto aux_in_dims = aux_in.dimensions(); - std::vector aux_in_shape = {product(aux_in_dims, 0, aux_in_dims.size() -1 ), - static_cast(aux_in_dims.back())}; + std::vector aux_in_shape = { + product(aux_in_dims, 0, aux_axis_boundary), + product(aux_in_dims, aux_axis_boundary, aux_in_dims.size())}; auto aux_in_dtype = convert_ffi_datatype_to_te_dtype(aux_in.element_type()); auto aux_in_ = TensorWrapper(aux_in.untyped_data(), aux_in_shape, aux_in_dtype); if (comm_type == CommOverlapType::AG && aux_out->element_count() > 0) { @@ -246,18 +269,18 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i stream); } else if (comm_type == CommOverlapType::RS) { // Prepare the auxiliary buffer for the reduce-scattered GEMM output - auto aux_out_shape = std::vector(out_shape); - aux_out_shape.at(0) /= tp_size; - auto aux_out_dtype = convert_ffi_datatype_to_te_dtype(aux_out->element_type()); - auto aux_out_ = TensorWrapper(aux_out->untyped_data(), aux_out_shape, aux_out_dtype); - NVTE_CHECK(aux_out_.numel() == aux_out->element_count(), - "cuBLAS GEMM->RS overlap auxiliary buffer is sized incorrectly, expected ", - aux_out_.numel(), " elements ", to_string_like(aux_out_shape), " but got ", - aux_out->element_count(), " elements ", to_string_like(aux_out->dimensions())); + auto rs_out_shape = std::vector(out_shape); + rs_out_shape.at(0) /= tp_size; + auto rs_out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + auto rs_out_ = TensorWrapper(output->untyped_data(), rs_out_shape, rs_out_dtype); + NVTE_CHECK(rs_out_.numel() == output->element_count(), + "cuBLAS GEMM->RS overlap output buffer is sized incorrectly, expected ", + rs_out_.numel(), " elements ", to_string_like(rs_out_shape), " but got ", + output->element_count(), " elements ", to_string_like(output->dimensions())); // Launch GEMM+RS executor->split_overlap_rs(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, - workspace_, grad, false, use_split_accumulator, aux_out_, + workspace_, grad, false, use_split_accumulator, rs_out_, stream); } else if (comm_type == CommOverlapType::AG) { // Prepare the auxiliary buffer for all-gathered LHS @@ -299,24 +322,27 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Arg() // rhs_scale_inv .Arg() // bias .Arg() // gelu_input + .Arg() // aux_in .Ret() // output .Ret() // bias_grad .Ret() // pre_gelu_out + .Ret() // aux_out .Ret() // lhs_swizzled .Ret() // rhs_swizzled .Ret() // workspace .Attr("scaling_mode") + .Attr("comm_overlap_method") + .Attr("comm_type") + .Attr("comm_overlap_id") .Attr("lhs_axis_boundary") .Attr("rhs_axis_boundary") + .Attr("aux_axis_boundary") .Attr("lhs_transposed") .Attr("rhs_transposed") .Attr("fuse_bias") .Attr("fuse_gelu") .Attr("grad") - .Attr("use_split_accumulator") - .Attr("comm_overlap_id") - .Attr("comm_overlap_method") - .Attr("comm_type"), + .Attr("use_split_accumulator"), FFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index d0ef724b20..fc899a92de 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -69,27 +69,51 @@ pybind11::dict Registrations() { return dict; } +} // namespace jax +} // namespace transformer_engine + PYBIND11_MODULE(transformer_engine_jax, m) { NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) - m.def("registrations", &Registrations); - m.def("get_fused_attn_backend", &GetFusedAttnBackend); - m.def("get_cuda_version", &GetCudaRuntimeVersion); - m.def("get_cudnn_version", &GetCudnnRuntimeVersion); - m.def("get_device_compute_capability", &GetDeviceComputeCapability); + m.def("registrations", &transformer_engine::jax::Registrations); + m.def("get_fused_attn_backend", &transformer_engine::jax::GetFusedAttnBackend); + m.def("get_cuda_version", &transformer_engine::jax::GetCudaRuntimeVersion); + m.def("get_cudnn_version", &transformer_engine::jax::GetCudnnRuntimeVersion); + m.def("get_device_compute_capability", &transformer_engine::jax::GetDeviceComputeCapability); m.def("get_cublasLt_version", &cublasLtGetVersion); - m.def("get_dact_dbias_quantize_workspace_sizes", &GetDActDBiasQuantizeWorkspaceSizes); - m.def("get_dbias_quantize_workspace_sizes", &GetDBiasQuantizeWorkspaceSizes); - m.def("get_norm_fwd_workspace_sizes", &GetNormForwardWorkspaceSizes); - m.def("get_norm_bwd_workspace_sizes", &GetNormBackwardWorkspaceSizes); - m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); - m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); - - pybind11::enum_(m, "JAXX_Scaling_Mode", pybind11::module_local()) - .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING) - .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) - .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING) - .value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING); + m.def("get_dact_dbias_quantize_workspace_sizes", + &transformer_engine::jax::GetDActDBiasQuantizeWorkspaceSizes); + m.def("get_dbias_quantize_workspace_sizes", + &transformer_engine::jax::GetDBiasQuantizeWorkspaceSizes); + m.def("get_norm_fwd_workspace_sizes", &transformer_engine::jax::GetNormForwardWorkspaceSizes); + m.def("get_norm_bwd_workspace_sizes", &transformer_engine::jax::GetNormBackwardWorkspaceSizes); + m.def("get_fused_attn_fwd_workspace_sizes", + &transformer_engine::jax::GetFusedAttnForwardWorkspaceSizes); + m.def("get_fused_attn_bwd_workspace_sizes", + &transformer_engine::jax::GetFusedAttnBackwardWorkspaceSizes); + m.def("create_comm_overlap_buffer", &transformer_engine::jax::CreateCommOverlapBuffer, + pybind11::arg("comm_type"), pybind11::arg("method"), pybind11::arg("buffer_shape"), + pybind11::arg("buffer_dtype"), pybind11::arg("tp_size"), pybind11::pos_only(), + pybind11::kw_only(), pybind11::arg("num_splits") = 4, pybind11::arg("num_max_streams") = 3, + pybind11::arg("comm_cga_size") = 2, pybind11::arg("gemm_priority") = 0, + pybind11::arg("comm_priority") = 0, pybind11::arg("num_comm_sm") = 16, + pybind11::arg("set_sm_margin") = true, pybind11::arg("use_ce") = true, + pybind11::arg("atomic_gemm") = false, pybind11::arg("rs_overlap_first_gemm") = false, + pybind11::arg("aggregate_ag") = false, + pybind11::call_guard()); + m.def("destroy_comm_overlap_buffer", &transformer_engine::jax::DestroyCommOverlapBuffer, + pybind11::call_guard()); + m.def("destroy_all_comm_overlap_buffers", &transformer_engine::jax::DestroyAllCommOverlapBuffers, + pybind11::call_guard()); + + pybind11::enum_(m, "JAXX_Scaling_Mode", + pybind11::module_local()) + .value("NO_SCALING", transformer_engine::jax::JAXX_Scaling_Mode::NO_SCALING) + .value("DELAYED_TENSOR_SCALING", + transformer_engine::jax::JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + .value("MXFP8_1D_SCALING", transformer_engine::jax::JAXX_Scaling_Mode::MXFP8_1D_SCALING) + .value("CURRENT_TENSOR_SCALING", + transformer_engine::jax::JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING); pybind11::enum_(m, "QuantizeLayout", pybind11::module_local()) @@ -98,5 +122,4 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE); } -} // namespace jax -} // namespace transformer_engine + diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 0d4a1b7524..0670a4811f 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -31,6 +31,7 @@ def dense( input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, batch_first: bool = True, + comm_overlaps: tex.CommOverlapHelperSet = tex.CommOverlapHelperSet(), quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. @@ -45,6 +46,7 @@ def dense( bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract batch_first: Assume that X is batched in the first dimension. + comm_overlaps: A set of CommOverlapHelper objecst for FPROP, DGRAD and WGRAD GEMMs. quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: @@ -59,13 +61,15 @@ def dense( output += jnp.reshape(bias, bias_new_shape) else: output = _dense( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, comm_overlaps, + quantizer_set ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) -def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set): +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7)) +def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, comm_overlaps, + quantizer_set): """Internal implementation of dense layer transformation with custom VJP. This function implements the core dense layer transformation logic with support @@ -78,20 +82,23 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_fir contracting_dims: Contracting dimensions specification input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix - quantizer_set: QuantizerSet which contains quantizers for different tensor types batch_first: Assume that X is batched in the first dimension. + comm_overlaps: A set of CommOverlapHelper objecst for FPROP, DGRAD and WGRAD GEMMs. + quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor """ output, _ = _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, comm_overlaps, + quantizer_set ) return output def _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, comm_overlaps, + quantizer_set ): """Forward pass rule for dense layer transformation. @@ -144,6 +151,7 @@ def _dense_fwd_rule( dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + comm_overlap=comm_overlaps.fprop, ) if use_bias and tex.gemm_uses_jax_dot(): @@ -164,7 +172,7 @@ def _dense_fwd_rule( def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad + contracting_dims, input_axes, kernel_axes, batch_first, comm_overlaps, ctx, grad ): # pylint: disable=unused-argument """Backward pass rule for dense layer transformation. @@ -194,6 +202,21 @@ def _dense_bwd_rule( noop_scaled_tensor=True, ) + # If casted_x has transposed data-layout, we need to untranspose it here, and then transpose + # it back after the bulk-AG. This should ideally never be necessary if the data layouts are + # handled correctly in the tensor usages. + dgrad_aux_in = None + dgrad_aux_transposed_axes = ( + *tuple(range(casted_x_lhs.flatten_axis, casted_x_lhs.ndim)), + *tuple(range(casted_x_lhs.flatten_axis)), + ) + if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_gathered_lhs: + dgrad_aux_in = ( + casted_x_lhs.data.transpose(dgrad_aux_transposed_axes) + if casted_x_lhs.data_layout == "T" + else casted_x_lhs.data + ) + # GEMM NT # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim g_contracting_dim = tuple( @@ -207,8 +230,9 @@ def _dense_bwd_rule( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, dimension_numbers=((g_contracting_dim, k_contracting_dim), ((x_bdim,), ())), + comm_overlap=comm_overlaps.dgrad, + aux_in=dgrad_aux_in, ) - dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) # GEMM TN # x_non_contracting_dims @@ -216,13 +240,40 @@ def _dense_bwd_rule( range(0, len(x_shape) - len(fwd_x_contracting_dims)) ) + casted_grad_rhs = casted_grad.get_tensor(usage=TensorUsage.RHS) + if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_gathered_lhs: + # LHS was bulk all-gathered during DGRAD and returned as auxiliary input + casted_x_lhs.data = ( + dgrad[-1].transpose(dgrad_aux_transposed_axes) + if casted_x_lhs.data_layout == "T" + else dgrad[-1] + ) + # DGRAD output will need to be bulk reduce-scattered during WGRAD + dgrad = dgrad[0] + elif comm_overlaps.dgrad.is_all_gather() and comm_overlaps.dgrad.output_gathered_lhs: + # GRAD was all-gathered for DGRAD and a copy of the gathered GRAD is in the auxiliary output + casted_grad_rhs.data = ( + dgrad[-1].transpose(*range(casted_grad_rhs.flatten_axis, casted_grad_rhs.ndim), + *range(casted_grad_rhs.flatten_axis)) + if casted_grad_rhs.data_layout == "T" + else dgrad[-1] + ) + dgrad = dgrad[1] + wgrad = tex.gemm( casted_x_lhs, - casted_grad.get_tensor(usage=TensorUsage.RHS), + casted_grad_rhs, dimension_numbers=((x_contracting_dim, g_contracting_dim), ((x_bdim,), (x_bdim,))), + comm_overlap=comm_overlaps.wgrad, + aux_in=(dgrad if comm_overlaps.wgrad.is_bulk() else None), ) - wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) + if comm_overlaps.wgrad.is_bulk(): + # DGRAD was bulk reduce-scattered during WGRAD and returned as auxiliary output + dgrad = wgrad[-1] + wgrad = wgrad[0] + dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) + wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) return dgrad, wgrad, dbias, quantizer_set diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index bd311472f0..00f7aab1a2 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -7,6 +7,7 @@ from functools import reduce import operator from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union +from dataclasses import field import numpy as np import jax.numpy as jnp @@ -29,18 +30,21 @@ jax_scaled_softmax, jax_scaled_masked_softmax, jax_scaled_upper_triang_masked_softmax, + CommOverlapHelper, + CommOverlapHelperSet, ) from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode -from ..sharding import get_non_contracting_logical_axes +from ..sharding import get_non_contracting_logical_axes, get_padded_spec + +import transformer_engine_jax as tex PRNGKey = Any Shape = Tuple[int, ...] -DType = jnp.dtype -Array = jnp.ndarray +jnp.dtype = jnp.dtype PrecisionLike = Union[ None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] ] -Initializer = Callable[[PRNGKey, Shape, DType], Array] +Initializer = Callable[[PRNGKey, Shape, jnp.dtype], jnp.ndarray] def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: @@ -108,7 +112,7 @@ def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Calla raise ValueError(f"don't know how to convert {fn_or_string} to an activation function") -def _combine_biases(*masks: List[Array]): +def _combine_biases(*masks: List[jnp.ndarray]): """Combine attention biases.""" masks = [m for m in masks if m is not None] if not masks: @@ -149,6 +153,77 @@ def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, return output +def _generate_comm_overlap_metas( + inputs_shape: Sequence[int], + param_shape: Sequence[int], + param_partitioning: nn.LogicallyPartitioned, + enabled: bool = True, + config: dict = {}, +): + if not enabled: + return CommOverlapHelperSet() + + param_sharding = param_partitioning.get_sharding() + param_specs = get_padded_spec(param_sharding.spec, len(param_shape)) + column_parallel = param_specs[-1] is not None + row_parallel = any(spec is not None for spec in param_specs[:-1]) + + comm_type = config.pop("comm_type", None) + if row_parallel and column_parallel: + assert comm_type is not None, ( + "Collective type for communication overlap must be explicitly set via " + "`comm_overlap_config={'comm_type' : ... }` when module parameters are " + "sharded in both contracting and non-contracting dimensions " + "(e.g. FSDP+TP sharding)." + ) + row_parallel = comm_type == tex.CommOverlapType.RS + column_parallel = comm_type == tex.CommOverlapType.AG + + mesh = param_sharding.mesh + buffer_shape = inputs_shape + tp_size = 1 + tp_resource = None + if row_parallel: + contracting_specs = tuple(spec for spec in param_specs[:-1] if spec is not None) + assert len(contracting_specs) == 1, ( + "Module parameter cannot have more than one sharded contracting dimension " + "GEMM->RS overlap is enabled." + ) + tp_resource = contracting_specs[0] + tp_size = mesh.shape[mesh.axis_names.index(tp_resource)] + comm_type = tex.CommOverlapType.RS + buffer_shape = (*inputs_shape[:-1], param_shape[-1]) + + elif column_parallel: + tp_resource = param_specs[-1] + assert tp_resource is not None, ( + "Module parameter must be sharded in the non-contracting dimension when " + "AG->GEMM overlap is enabled." + ) + tp_size = mesh.shape[mesh.axis_names.index(tp_resource)] + comm_type = tex.CommOverlapType.AG + + else: + raise AssertionError("") + + method = config.pop("method", tex.CommOverlapMethod.RING_EXCHANGE) + buffer_shape = config.pop("buffer_shape", buffer_shape) + buffer_dtype = config.pop("buffer_dtype", jnp.bfloat16) + tp_size = config.pop("tp_size", tp_size) + tp_resource = config.pop("tp_resource", tp_resource) + return CommOverlapHelperSet( + fprop=CommOverlapHelper( + method=method, + comm_type=comm_type, + buffer_shape=buffer_shape, + buffer_dtype=buffer_dtype, + tp_size=tp_size, + tp_resource=tp_resource, + **config, + ) + ) + + class Softmax(nn.Module): # pylint: disable=too-few-public-methods r""" Applies softmax over a mini-batch of inputs. @@ -172,7 +247,7 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods softmax_type: SoftmaxType = SoftmaxType.SCALED @nn.compact - def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray: + def __call__(self, inputs: jnp.ndarray, mask: jnp.ndarray = None, bias: jnp.ndarray = None) -> jnp.ndarray: batch = inputs.shape[0] heads = inputs.shape[1] q_seqlen = inputs.shape[2] @@ -287,7 +362,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods scale_axes: Tuple[str, ...] = ("embed",) bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] = ("embed",) - dtype: DType = jnp.float32 + dtype: jnp.dtype = jnp.float32 transpose_batch_sequence: bool = False def __post_init__(self): @@ -415,12 +490,17 @@ class DenseGeneral(TransformerEngineBase): Indicate the logical axes of sharding constraint to the input, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint. + enable_comm_overlap: bool, default = False + Enable fine-grained All-Gather or Reduce-Scatter overlap with GEMM for sequence-parallel + inputs. + comm_overlap_config: dict, default = {} + Optional config dictionary for controlling communication overlap options. Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. - transpose_batch_sequence : bool, default = True + transpose_batch_sequence : bool, default = False Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). @@ -436,9 +516,11 @@ class DenseGeneral(TransformerEngineBase): low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 - dtype: DType = jnp.float32 + dtype: jnp.dtype = jnp.float32 transpose_batch_sequence: bool = False input_axes: Tuple[str, ...] = () + enable_comm_overlap: bool = False + comm_overlap_config: dict = field(default_factory=dict) # pylint: disable=invalid-field-call def __post_init__(self): if self.kernel_init is None: @@ -448,7 +530,7 @@ def __post_init__(self): super().__post_init__() @nn.compact - def __call__(self, inputs: Array) -> Array: + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: """ Apply the dense layer transformation to the input. @@ -476,9 +558,15 @@ def __call__(self, inputs: Array) -> Array: "Expected len(kernel_shape) to match len(kernel_axes)," f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}" ) + else: + assert not self.enable_comm_overlap, ( + "Communication + GEMM overlap requires the dot kernel sharding to be defined in " + "`kernel_axes`." + ) + kernel_partitioning = nn.with_logical_partitioning(self.kernel_init, self.kernel_axes) kernel = self.param( "kernel", - nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_partitioning, kernel_shape, self.dtype, ) @@ -505,6 +593,13 @@ def __call__(self, inputs: Array) -> Array: input_axes=self.input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, + comm_overlaps=_generate_comm_overlap_metas( + inputs.shape, + kernel_shape, + kernel_partitioning, + enabled=self.enable_comm_overlap, + config=self.comm_overlap_config, + ) ) if self.enable_low_rank_adaptation: @@ -617,12 +712,16 @@ class LayerNormDenseGeneral(TransformerEngineBase): Indicate the logical axes of sharding constraint to the input of dot, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint. + enable_comm_overlap: bool, default = False + Enable fine-grained All-Gather overlap with GEMM for sequence-parallel inputs. + comm_overlap_config: dict, default = {} + Optional config dictionary for controlling communication overlap options. Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. - transpose_batch_sequence : bool, default = True + transpose_batch_sequence : bool, default = False Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). @@ -650,11 +749,13 @@ class LayerNormDenseGeneral(TransformerEngineBase): low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 - dtype: DType = jnp.float32 - transpose_batch_sequence: bool = True + dtype: jnp.dtype = jnp.float32 + transpose_batch_sequence: bool = False layernorm_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] = None depth_scaling: float = None + enable_comm_overlap: bool = False + comm_overlap_config: dict = field(default_factory=dict) # pylint: disable=invalid-field-call def __post_init__(self): if self.kernel_init is None: @@ -672,7 +773,7 @@ def __post_init__(self): super().__post_init__() @nn.compact - def __call__(self, inputs: Array) -> Array: + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: """ Apply layer normalization to the input followed by a dense layer transformation. @@ -742,9 +843,21 @@ def __call__(self, inputs: Array) -> Array: axis = _normalize_axes(axis, y.ndim) kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features + + if self.kernel_axes: + assert len(kernel_shape) == len(self.kernel_axes), ( + "Expected len(kernel_shape) to match len(kernel_axes)," + f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}" + ) + else: + assert not self.enable_comm_overlap, ( + "Communication + GEMM overlap requires the dot kernel sharding to be defined in " + "`kernel_axes`." + ) + kernel_partitioning = nn.with_logical_partitioning(self.kernel_init, self.kernel_axes) kernel = self.param( "kernel", - nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_partitioning, kernel_shape, self.dtype, ) @@ -753,6 +866,16 @@ def __call__(self, inputs: Array) -> Array: contract_ind = tuple(range(0, len(axis))) + if self.enable_comm_overlap: + # All-Gather is the only supported collective to overlap in LayerNormDenseGeneral + self.comm_overlap_config.update({"comm_type" : tex.CommOverlapType.AG}) + comm_overlaps = _generate_comm_overlap_metas( + inputs.shape, + kernel_shape, + kernel_partitioning, + enabled=self.enable_comm_overlap, + config=self.comm_overlap_config, + ) if fuse_layernorm: z = layernorm_dense( y, @@ -766,6 +889,7 @@ def __call__(self, inputs: Array) -> Array: dot_input_axes=self.dot_input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, + comm_overlaps=comm_overlaps, ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) @@ -776,6 +900,7 @@ def __call__(self, inputs: Array) -> Array: input_axes=self.dot_input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, + comm_overlaps=comm_overlaps, ) if self.enable_low_rank_adaptation: @@ -924,12 +1049,25 @@ class LayerNormMLP(TransformerEngineBase): Indicate the logical axes of sharding constraint to the input of 2nd dot, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint. + enable_comm_overlap: bool, default = False + Enable fine-grained All-Gather overlap with the 1st dot and Reduce-Scatter overlap with + the 2nd dot. + enable_dot_1_comm_overlap: bool, default = False + Enable fine-grained All-Gather overlap with the 1st dot. This option is overriden by + `enable_comm_overlap=True`. + enable_dot_2_comm_overlap: bool, default = False + Enable fine-grained Reduce-Scatter overlap with the 2nd dot. This option is overriden by + `enable_comm_overlap=True`. + dot_1_comm_overlap_config: dict, default = {} + Optional config dictionary for controlling communication overlap options for the 1st dot. + dot_2_comm_overlap_config: dict, default = {} + Optional config dictionary for controlling communication overlap options for the 2nd dot. Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. - transpose_batch_sequence : bool, default = True + transpose_batch_sequence : bool, default = False Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). @@ -960,11 +1098,16 @@ class LayerNormMLP(TransformerEngineBase): low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 - dtype: DType = jnp.float32 - transpose_batch_sequence: bool = True + dtype: jnp.dtype = jnp.float32 + transpose_batch_sequence: bool = False layernorm_input_axes: Tuple[str, ...] = None dot_1_input_axes: Tuple[str, ...] = None dot_2_input_axes: Tuple[str, ...] = None + enable_comm_overlap: bool = False + enable_dot_1_comm_overlap: bool = False + enable_dot_2_comm_overlap: bool = False + dot_1_comm_overlap_config: dict = field(default_factory=dict) # pylint: disable=invalid-field-call + dot_2_comm_overlap_config: dict = field(default_factory=dict) # pylint: disable=invalid-field-call def __post_init__(self): if self.kernel_init is None: @@ -978,7 +1121,7 @@ def __post_init__(self): super().__post_init__() @nn.compact - def __call__(self, inputs: Array, deterministic: bool = False) -> Array: + def __call__(self, inputs: jnp.ndarray, deterministic: bool = False) -> jnp.ndarray: """ Apply layer normalization to the input followed by a feedforward network (MLP Block). @@ -1082,9 +1225,21 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): axis = _canonicalize_tuple(self.axis) axis = _normalize_axes(axis, y.ndim) kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim) + self.enable_dot_1_comm_overlap = self.enable_dot_1_comm_overlap or self.enable_comm_overlap + if self.kernel_1_axes: + assert len(kernel_1_each_shape) == len(self.kernel_axes), ( + "Expected len(kernel_1_shape) to match len(kernel_1_axes)," + f"got kernel_shape {kernel_1_each_shape} and kernel_axes {self.kernel_1_axes}" + ) + else: + assert not self.enable_dot_1_comm_overlap, ( + "Communication + GEMM overlap for the 1st dot requires the kernel sharding to be " + "defined in `kernel_1_axes`." + ) + kernel_1_partitioning = nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1) kernel_1 = self.param( "wi_kernel", - nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1), + kernel_1_partitioning, num_activations, -2, kernel_1_each_shape, @@ -1097,9 +1252,21 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): hidden_size = inputs.shape[-1] hidden_size_tuple = _canonicalize_tuple(hidden_size) kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple + self.enable_dot_2_comm_overlap = self.enable_dot_2_comm_overlap or self.enable_comm_overlap + if self.kernel_2_axes: + assert len(kernel_2_shape) == len(self.kernel_2_axes), ( + "Expected len(kernel_2_shape) to match len(kernel_2_axes)," + f"got kernel_shape {kernel_2_shape} and kernel_axes {self.kernel_2_axes}" + ) + else: + assert not self.enable_dot_2_comm_overlap, ( + "Communication + GEMM overlap for the 2nd dot requires the kernel sharding to be " + "defined in `kernel_2_axes`." + ) + kernel_2_partitioning = nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2) kernel_2 = self.param( "wo_kernel", - nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2), + kernel_2_partitioning, kernel_2_shape, self.dtype, ) @@ -1131,6 +1298,26 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name = "ffn1" ffn2_ckpt_name = "ffn2" + if self.enable_dot_1_comm_overlap: + # All-Gather is the only supported collective to overlap with the 1st dot + self.dot_1_comm_overlap_config.update({"comm_type" : tex.CommOverlapType.AG}) + ffn1_comm_overlaps = _generate_comm_overlap_metas( + inputs.shape, + kernel_1_each_shape, + kernel_1_partitioning, + enabled=self.enable_dot_1_comm_overlap, + config=self.enable_dot_1_comm_overlap, + ) + if self.enable_dot_2_comm_overlap: + # Reduce-Scatter is the only supported collective to overlap with the 2nd dot + self.dot_2_comm_overlap_config.update({"comm_type" : tex.CommOverlapType.RS}) + ffn2_comm_overlaps = _generate_comm_overlap_metas( + inputs.shape, + kernel_2_shape, + kernel_2_partitioning, + enabled=self.enable_dot_2_comm_overlap, + config=self.enable_dot_2_comm_overlap, + ) if use_fused_layernorm_mlp: out = layernorm_mlp( y, @@ -1150,6 +1337,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn2_ckpt_name=ffn2_ckpt_name, activation_type=normalized_acts, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), + ffn1_comm_overlaps=ffn1_comm_overlaps, + ffn2_comm_overlaps=ffn2_comm_overlaps, ) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) @@ -1168,6 +1357,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): dot_input_axes=self.dot_1_input_axes, kernel_axes=self.kernel_axes_1, quantizer_set=ffn1_quantizer_set, + comm_overlaps=ffn1_comm_overlaps, ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes) @@ -1178,6 +1368,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): input_axes=self.dot_1_input_axes, kernel_axes=self.kernel_axes_1, quantizer_set=ffn1_quantizer_set, + comm_overlaps=ffn1_comm_overlaps, ) if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None: @@ -1259,6 +1450,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): input_axes=self.dot_2_input_axes, kernel_axes=self.kernel_axes_2, quantizer_set=ffn2_quantizer_set, + comm_overlaps=ffn2_comm_overlaps, ) if self.enable_low_rank_adaptation: diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 62fa2cfcd2..57caaba5a8 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -38,6 +38,7 @@ def layernorm_dense( dot_input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, batch_first: bool = True, + comm_overlaps: tex.CommOverlapHelperSet = tex.CommOverlapHelperSet(), quantizer_set: QuantizerSet = noop_quantizer_set, ) -> jnp.ndarray: """Apply layer normalization followed by dense layer transformation. @@ -59,6 +60,7 @@ def layernorm_dense( dot_input_axes: Logical axes for sharding the matrix multiplication input kernel_axes: Logical axes for sharding the weight matrix batch_first: Assume that X is batched in the first dimension. + comm_overlaps: A set of CommOverlapHelper objecst for FPROP, DGRAD and WGRAD GEMMs. quantizer_set: Set of quantizers for different tensor types Returns: @@ -83,6 +85,7 @@ def layernorm_dense( dot_input_axes, kernel_axes, batch_first, + comm_overlaps, quantizer_set, ) return output @@ -98,6 +101,7 @@ def layernorm_dense( 9, 10, 11, + 12, ), ) def _layernorm_dense( @@ -113,6 +117,7 @@ def _layernorm_dense( dot_input_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...], batch_first: bool, + comm_overlaps: tex.CommOverlapHelperSet, quantizer_set, ): """Internal implementation of layernorm_dense with custom VJP. @@ -133,6 +138,7 @@ def _layernorm_dense( layernorm_input_axes: Logical axes for layernorm sharding dot_input_axes: Logical axes for matrix multiplication sharding batch_first: Assume that X is batched in the first dimension. + comm_overlaps: A set of CommOverlapHelper objecst for FPROP, DGRAD and WGRAD GEMMs. quantizer_set: Set of quantizers Returns: @@ -151,6 +157,7 @@ def _layernorm_dense( dot_input_axes, kernel_axes, batch_first, + comm_overlaps, quantizer_set, ) return output @@ -169,6 +176,7 @@ def _layernorm_dense_fwd_rule( dot_input_axes, kernel_axes, batch_first, + comm_overlaps, quantizer_set, ): """Forward pass rule for layernorm_dense. @@ -212,7 +220,9 @@ def _layernorm_dense_fwd_rule( casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) # NN GEMM - # (batch..., hidden_in) x (hidden_in, hidden_out...) + # (batch..., sequence, hidden_in) x (hidden_in, hidden_out...) + # NOTE: Comm+GEMM overlap can only do AG->GEMM here to all-gather a sequence-parallel layernorm + # output because the weights for a QKV projection is always column-parallel. use_bias = bias is not None output = tex.gemm( casted_ln_out.get_tensor(TensorUsage.LHS), @@ -220,14 +230,31 @@ def _layernorm_dense_fwd_rule( dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + comm_overlap=comm_overlaps.fprop, ) + # If Comm+GEMM overlap for FPROP was configured to return the all-gathered layernorm output + # as the auxiliary output, we may need to transpose it here to match the expected data + # layout in the backward pass. Otherwise, the + casted_ln_out_for_bwd = casted_ln_out.get_tensor(TensorUsage.LHS_TRANS) + ln_out_transposed_dims = ( + *tuple(range(casted_ln_out_for_bwd.flatten_axis, casted_ln_out_for_bwd.ndim)), + *tuple(range(casted_ln_out_for_bwd.flatten_axis)) + ) + if comm_overlaps.fprop.output_all_gathered_lhs: + casted_ln_out_for_bwd.data = ( + output[-1].transpose(ln_out_transposed_dims) + if casted_ln_out_for_bwd.data_layout == "T" + else output[-1] + ) + output = output[0] + if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) ctx = ( - casted_ln_out.get_tensor(TensorUsage.LHS_TRANS), + casted_ln_out_for_bwd, casted_kernel.get_tensor(TensorUsage.RHS_TRANS), x.shape, kernel.shape, @@ -255,6 +282,7 @@ def _layernorm_dense_bwd_rule( dot_input_axes, # pylint: disable=unused-argument kernel_axes, batch_first, # pylint: disable=unused-argument + comm_overlaps, ctx, grad, ): @@ -304,26 +332,59 @@ def _layernorm_dense_bwd_rule( dim for dim in range(len(kernel_shape)) if dim not in k_contracting_dims_in_fwd ) + # If casted_ln_out has transposed data-layout, we need to untranspose it here, and then + # transpose it back after the bulk-AG. This should ideally never be necessary if the data + # layouts are handled correctly in the tensor usages. + dgrad_aux_in = None + dgrad_aux_transposed_axes = ( + *tuple(range(casted_ln_out.flatten_axis, casted_ln_out.ndim)), + *tuple(range(casted_ln_out.flatten_axis)), + ) + if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_gathered_lhs: + dgrad_aux_in = ( + casted_ln_out.data.transpose(dgrad_aux_transposed_axes) + if casted_ln_out.data_layout == "T" + else casted_ln_out.data + ) + # NT GEMM dgrad = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, dimension_numbers=((g_constracting_dim, k_constracting_dim), ((x_bdim,), ())), + comm_overlap=comm_overlaps.dgrad, + aux_in=dgrad_aux_in, ) - dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) - g_constracting_dim = x_constracting_dim = tuple( range(0, len(x_shape) - len(x_contracting_dims_in_fwd)) ) # TN GEMM + casted_grad_rhs = casted_grad.get_tensor(usage=TensorUsage.RHS) + if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_gathered_lhs: + # LHS was bulk all-gathered during DGRAD and returned as auxiliary input + casted_ln_out.data = ( + dgrad[-1].transpose(dgrad_aux_transposed_axes) + if casted_ln_out.data_layout == "T" + else dgrad[-1] + ) + # DGRAD output will need to be bulk reduce-scattered during WGRAD + dgrad = dgrad[0] + wgrad = tex.gemm( casted_ln_out, - casted_grad.get_tensor(TensorUsage.RHS), + casted_grad_rhs, dimension_numbers=((x_constracting_dim, g_constracting_dim), ((x_bdim,), (x_bdim,))), + comm_overlap=comm_overlaps.wgrad, + aux_in=(dgrad if comm_overlaps.wgrad.is_bulk() else None), ) + if comm_overlaps.wgrad.is_bulk(): + # DGRAD was bulk reduce-scattered during WGRAD and returned as auxiliary output + dgrad = wgrad[-1] + wgrad = wgrad[0] + dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) dx, dgamma, dbeta = tex.normalization_bwd( diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 5d129aa54d..a5a087f05b 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -49,6 +49,8 @@ def layernorm_mlp( ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), batch_first: bool = True, + ffn1_comm_overlaps: tex.CommOverlapHelperSet = tex.CommOverlapHelperSet(), + ffn2_comm_overlaps: tex.CommOverlapHelperSet = tex.CommOverlapHelperSet(), quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> jnp.ndarray: """Apply layer normalization followed by MLP block. @@ -80,6 +82,8 @@ def layernorm_mlp( ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation + ffn1_comm_overlaps: A set of CommOverlapHelper objects for FFN1 FPROP, DGRAD and WGRAD. + ffn2_comm_overlaps: A set of CommOverlapHelper objects for FFN2 FPROP, DGRAD and WGRAD. batch_first: Assume that X is batched in the first dimension. quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations @@ -127,12 +131,14 @@ def layernorm_mlp( ffn2_ckpt_name, activation_type, batch_first, + ffn1_comm_overlaps, + ffn2_comm_overlaps, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -153,6 +159,8 @@ def _layernorm_mlp( ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], batch_first: bool, + ffn1_comm_overlaps: tex.CommOverlapHelperSet, + ffn2_comm_overlaps: tex.CommOverlapHelperSet, quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -179,6 +187,8 @@ def _layernorm_mlp( ffn2_ckpt_name: Name for second feed-forward network checkpointing activation_type: Activation function(s) batch_first: Assume that X is batched in the first dimension. + ffn1_comm_overlaps: A set of CommOverlapHelper objects for FFN1 FPROP, DGRAD and WGRAD. + ffn2_comm_overlaps: A set of CommOverlapHelper objects for FFN2 FPROP, DGRAD and WGRAD. quantizer_sets: Tuple of quantizer sets Returns: @@ -204,6 +214,8 @@ def _layernorm_mlp( ffn2_ckpt_name, activation_type, batch_first, + ffn1_comm_overlaps, + ffn2_comm_overlaps, quantizer_sets, ) return output @@ -229,6 +241,8 @@ def _layernorm_mlp_fwd_rule( ffn2_ckpt_name, activation_type, batch_first, + ffn1_comm_overlaps, + ffn2_comm_overlaps, quantizer_sets, ): """Forward pass rule for layernorm_mlp. @@ -287,16 +301,23 @@ def _layernorm_mlp_fwd_rule( ) # NN GEMM - # (batch..., hidden_in) x (hidden_in, hidden_out) + # (batch..., sequence, hidden_in) x (hidden_in, hidden_out) + # NOTE: Comm+GEMM overlap can only do AG->GEMM here to all-gather a sequence-parallel layernorm + # output. dot_1_output = tex.gemm( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel_1.get_tensor(TensorUsage.RHS), dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), bias=bias_1 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, + comm_overlap=ffn1_comm_overlaps.fprop, ) - if dot_1_input_axes is not None and kernel_1_axes is not None: + if ( + not ffn1_comm_overlaps.fprop.is_enabled + and dot_1_input_axes is not None + and kernel_1_axes is not None + ): dot_1_output_axes = ( *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims), *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims), @@ -315,7 +336,8 @@ def _layernorm_mlp_fwd_rule( dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True ) - casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) + if not ffn2_comm_overlaps.fprop.is_enabled: + casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) casted_kernel_2 = tex.quantize( kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True @@ -323,12 +345,16 @@ def _layernorm_mlp_fwd_rule( # NN GEMM # (batch..., hidden_in) x (hidden_out, hidden_in) + # NOTE: Comm+GEMM overlap can only do GEMM->RS to reduce-scatter the FFN2 output. We don't need + # an auxiliary input/output here for this because it's already handled in the custom op + # and the returned array is the final reduce-scattered result. dot_2_output = tex.gemm( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, + comm_overlap=ffn2_comm_overlaps.fprop, ) if use_bias_2 and tex.gemm_uses_jax_dot(): @@ -375,6 +401,8 @@ def _layernorm_mlp_bwd_rule( ffn2_ckpt_name, activation_type, batch_first, + ffn1_comm_overlaps, + ffn2_comm_overlaps, ctx, grad, ): @@ -433,24 +461,40 @@ def _layernorm_mlp_bwd_rule( # NT GEMM # (batch..., hidden_out) x (hidden_in, hidden_out) + # NOTE: The only possible comm. overlap with FFN2 DGRAD is an AG+GEMM with all-gathered + # gradient returned in the auxiliary output to be re-used in the FFN2 WGRAD GEMM. dgrad_2 = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel_2, dimension_numbers=((g_contracting_dims_2, k_contracting_dims_2), ((x_bdim,), ())), + comm_overlap=ffn2_comm_overlaps.dgrad, ) - dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) - x_contracting_dims = g_contracting_dims = tuple( range(0, len(x.shape) - len(x_contracting_dims_in_fwd)) ) # TN GEMM # (hidden, batch...,) x (hidden, batch...) + # NOTE: There is no possible comm. overlap with FFN2 WGRAD, but we need to re-use the + # all-gathered gradient returned in the auxiliary output of FFN2 DGRAD. + casted_grad_rhs = casted_grad.get_tensor(usage=TensorUsage.RHS) + if ffn2_comm_overlaps.dgrad.is_enabled: + casted_grad_rhs.data = ( + dgrad_2[-1].transpose(*range(casted_grad_rhs.flatten_axis, casted_grad_rhs.ndim), + *range(casted_grad_rhs.flatten_axis)) + if casted_grad_rhs.data_layout == "T" + else dgrad_2[-1] + ) + dgrad_2 = dgrad_2[1] + else: + dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) + wgrad_2 = tex.gemm( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim,))), + comm_overlap=ffn2_comm_overlaps.wgrad, ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -473,14 +517,39 @@ def _layernorm_mlp_bwd_rule( dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd ) + # If FFN1 DGRAD is bulk all-gathering the layernorm output, but the layernorm output + # has transposed data layout, we need to un-transpose it here before the all-gather and + # transpose it again before using it in FFN1 WGRAD. Also make sure we do not already have the + # the gathered layernorm output from FPROP. + # NOTE: This transpose should not be necessary if the tensor usages work correctly! + dgrad_1_aux_in = None + ln_out_transposed_dims = ( + *tuple(range(casted_ln_out.flatten_axis, casted_ln_out.ndim)), + *tuple(range(casted_ln_out.flatten_axis)) + ) + if ffn1_comm_overlaps.dgrad.is_bulk() and not ffn1_comm_overlaps.fprop.output_gathered_lhs: + dgrad_1_aux_in = ( + casted_ln_out.data.transpose(ln_out_transposed_dims) + if casted_ln_out.data_layout == "T" + else casted_ln_out.data + ) + # NT GEMM dgrad_1 = tex.gemm( casted_dact_out.get_tensor(TensorUsage.LHS), casted_kernel_1, dimension_numbers=((g_contracting_dims_1, k_contracting_dims_1), ((x_bdim,), ())), + comm_overlap=ffn1_comm_overlaps.dgrad, + aux_in=dgrad_1_aux_in, ) - dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) + if ffn1_comm_overlaps.dgrad.is_bulk() and not ffn1_comm_overlaps.fprop.output_gathered_lhs: + casted_ln_out.data = ( + dgrad_1[-1].transpose(ln_out_transposed_dims) + if casted_ln_out.data_layout == "T" + else dgrad_1[-1] + ) + dgrad_1 = dgrad_1[0] # TN GEMM # (hidden, batch...) x (hidden, batch...) @@ -488,8 +557,15 @@ def _layernorm_mlp_bwd_rule( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim,))), + comm_overlap=ffn1_comm_overlaps.wgrad, + aux_in=(dgrad_1 if ffn1_comm_overlaps.wgrad.is_bulk() else None), ) + if ffn1_comm_overlaps.wgrad.is_bulk(): + # FFN1 DGRAD was bulk reduce-scattered during FFN2 WGRAD and returned as auxiliary output + dgrad_1 = wgrad_1[-1] + wgrad_1 = wgrad_1[0] + dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) dx, dgamma, dbeta = tex.normalization_bwd( From aeddd66cbb92d1501eff27d929162b849f555a35 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 4 Jul 2025 07:41:50 +0000 Subject: [PATCH 22/27] Comm+GEMM overlap working with row-parallel DenseGeneral FWD/BWD Signed-off-by: Alp Dener --- .../jax/comm_overlap/flax_with_overlap.py | 252 ++++++++++ .../gemm_with_overlap.py} | 9 - .../comm_overlap/layer_prim_with_overlap.py | 382 +++++++++++++++ transformer_engine/jax/cpp_extensions/gemm.py | 447 +++++++++++------- transformer_engine/jax/dense.py | 14 +- transformer_engine/jax/flax/module.py | 10 +- transformer_engine/jax/layernorm_dense.py | 12 +- transformer_engine/jax/layernorm_mlp.py | 23 +- transformer_engine/jax/quantize/tensor.py | 3 +- 9 files changed, 963 insertions(+), 189 deletions(-) create mode 100644 examples/jax/comm_overlap/flax_with_overlap.py rename examples/jax/{comm_gemm_overlap/comm_gemm_overlap.py => comm_overlap/gemm_with_overlap.py} (95%) create mode 100644 examples/jax/comm_overlap/layer_prim_with_overlap.py diff --git a/examples/jax/comm_overlap/flax_with_overlap.py b/examples/jax/comm_overlap/flax_with_overlap.py new file mode 100644 index 0000000000..8a37757afb --- /dev/null +++ b/examples/jax/comm_overlap/flax_with_overlap.py @@ -0,0 +1,252 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Comm+GEMM Overlap with TE/JAX""" +import os +import argparse +from functools import partial + +from mpi4py import MPI + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils +from flax.linen import partitioning as nn_partitioning + +import transformer_engine.jax as te +import transformer_engine_jax as tex +from transformer_engine.jax.sharding import ( + get_padded_spec, + MeshResource, + HIDDEN_AXES, + HIDDEN_TP_AXES, + BATCH_AXES, + SEQLEN_TP_AXES, + SEQLEN_AXES, + W_NO_SHARD_AXES, + W_FSDP_AXES, + W_TP_AXES, + W_JOINED_AXES, +) +from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP +from transformer_engine.common import recipe + +# This script needs to be launched via `mpirun` with 1 process per GPU +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +myrank = MPI.COMM_WORLD.Get_rank() +numranks = MPI.COMM_WORLD.Get_size() +jax.clear_caches() +jax.distributed.initialize(cluster_detection_method="mpi4py") +assert jax.local_device_count() == 1, ( + f"[{myrank}|{numranks}] Expected 1 GPU per process, found {jax.local_device_count()}" +) + +# Parse script arguments +_supported_layers = (DenseGeneral, LayerNormDenseGeneral, LayerNormMLP) +TE_LAYER_MAP = dict((layer.__name__.lower(), layer) for layer in _supported_layers) +def _te_flax_layer(layer_name): + assert isinstance(layer_name, str) and layer_name.lower() in TE_LAYER_MAP + return TE_LAYER_MAP[layer_name.lower()] + +parser = argparse.ArgumentParser() +parser.add_argument("-dp", "--dp-size", type=int, default=2) +parser.add_argument("-tp", "--tp-size", type=int, default=numranks // 2) +parser.add_argument("-np", "--num-gpus", type=int, default=numranks) +parser.add_argument("--batch-size", type=int, default=2) +parser.add_argument("--seq-length", type=int, default=8192) +parser.add_argument("--hidden-size", type=int, default=16384) +parser.add_argument("--activation-size", type=int, default=53248) +parser.add_argument("--no-batch", action="store_true") +parser.add_argument("--no-fsdp", action="store_true") +parser.add_argument("--layer-type", type=_te_flax_layer, default=DenseGeneral, + choices=TE_LAYER_MAP.keys()) +parser.add_argument("--fp8-recipe", type=str.lower, default="none", + choices=["none", "current", "delayed", "mxfp8"]) +parser.add_argument("--check-result", action="store_true") +parser.add_argument("--seed", type=int, default=42) +args = parser.parse_args() + +# FP8 recipe +fp8_recipe = None +match args.fp8_recipe: + case "current": + fp8_recipe = recipe.Float8CurrentScaling() + case "delayed": + fp8_recipe = recipe.DelayedScaling() + case "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling() + case _: + fp8_recipe = None + +# Single GPU evaluation +layer_kwargs = { "use_bias" : True } +match args.layer_type: + case DenseGeneral: + layer_kwargs.update({"features" : args.hidden_size, "name" : "proj"}) + case LayerNormDenseGeneral: + layer_kwargs.update( + { + "features" : 3 * args.hidden_size, + "return_layernorm_output" : False, + "name" : "qkv" + } + ) + case LayerNormMLP: + layer_kwargs.update( + { + "intermediate_dim" : args.activation_size, + "return_layernorm_output" : False, + "name" : "mlp" + } + ) + +rng = jax.random.PRNGKey(args.seed) +rng, params_rng = jax.random_split(rng) +init_rngs = {"params" : params_rng} + +dtype = jnp.bfloat16 +input_shape = (args.seq_length, args.hidden_size) +if not args.no_batch: + input_shape = (args.batch_size, ) + input_shape +x = jnp.random.normal(rng, input_shape, dtype=jnp.bfloat16) + +with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + model_single = partial(args.layer_type, **layer_kwargs) + params_single = model_single.init(init_rngs, x, deterministic=True) + output_single = model_single.apply(params_single, x, deterministic=True) + +# Resources and partition specs +DEVICE_DP_AXIS = "dp" +DEVICE_TP_AXIS = "tp" +mesh_shape = (args.dp_size, args.tp_size) +mesh_axes = (DEVICE_DP_AXIS, DEVICE_TP_AXIS) +mesh_resource = MeshResource( + dp_resource=DEVICE_DP_AXIS if args.no_fsdp else None, + fsdp_resource=None if args.no_fsdp else DEVICE_DP_AXIS, + tp_resource=DEVICE_TP_AXIS, +) + +INPUT_AXES = (SEQLEN_TP_AXES if args.layer_type != DenseGeneral else SEQLEN_AXES, + HIDDEN_AXES if args.layer_type != DenseGeneral else HIDDEN_TP_AXES) +INTERMEDIATE_AXES = (SEQLEN_AXES, HIDDEN_TP_AXES) +if not args.no_batch: + INPUT_AXES = (BATCH_AXES, ) + INPUT_AXES + INTERMEDIATE_AXES = (BATCH_AXES, ) + INTERMEDIATE_AXES + +LN_SCALE_AXES = LN_BIAS_AXES = (W_NO_SHARD_AXES, ) + +KERNEL_AXES_ROW_PARALLEL = (W_TP_AXES, W_FSDP_AXES) +BIAS_AXES_ROW_PARALLEL = (W_NO_SHARD_AXES, ) +KERNEL_AXES_COL_PARALLEL = (W_FSDP_AXES, W_TP_AXES) +BIAS_AXES_COL_PARALLEL = (W_TP_AXES, ) +if args.layer_type == LayerNormMLP: + KERNEL_AXES_COL_PARALLEL = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES) + BIAS_AXES_COL_PARALLEL = (W_JOINED_AXES, W_NO_SHARD_AXES) + +# Multi GPU evaluation +layer_kwargs.update({"enable_comm_overlap" : True}) +if args.layer_type in (DenseGeneral, LayerNormDenseGeneral): + layer_kwargs.update( + { + "kernel_axes" : KERNEL_AXES_COL_PARALLEL, + "bias_axes" : BIAS_AXES_COL_PARALLEL, + "comm_overlap_config" : {"method" : tex.CommOverlapMethod.RING_EXCHANGE}, + } + ) + if args.layer_type == LayerNormDenseGeneral: + layer_kwargs.update( + { + "layernorm_input_axes" : INPUT_AXES, + "scale_axes" : LN_SCALE_AXES, + "ln_bias_axes" : LN_BIAS_AXES, + "dot_input_axes" : INPUT_AXES, + } + ) +else: + layer_kwargs.update( + { + "layernorm_input_axes" : INPUT_AXES, + "scale_axes" : LN_SCALE_AXES, + "ln_bias_axes" : LN_BIAS_AXES, + "dot_1_input_axes" : INPUT_AXES, + "kernel_1_axes" : KERNEL_AXES_COL_PARALLEL, + "bias_axes_1" : BIAS_AXES_COL_PARALLEL, + "dot_2_input_axes" : INTERMEDIATE_AXES, + "kernel_2_axes" : KERNEL_AXES_ROW_PARALLEL, + "bias_axes_2" : BIAS_AXES_ROW_PARALLEL, + "dot_1_comm_overlap_config" : {"method" : tex.CommOverlapMethod.RING_EXCHANGE}, + "dot_2_comm_overlap_config" : {"method" : tex.CommOverlapMethod.RING_EXCHANGE}, + } + ) + +device_mesh = mesh_utils.create_device_mesh((args.dp_size, args.tp_size)) +mesh = Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)) +axis_rules = nn_partitioning.axis_rules( + ( + (BATCH_AXES, DEVICE_DP_AXIS), + (SEQLEN_AXES, None), + (SEQLEN_TP_AXES, DEVICE_TP_AXIS), + (HIDDEN_AXES, None), + (HIDDEN_TP_AXES, DEVICE_TP_AXIS), + (W_NO_SHARD_AXES, None), + (W_JOINED_AXES, None), + (W_FSDP_AXES, None if args.no_fsdp else DEVICE_DP_AXIS), + (W_TP_AXES, DEVICE_TP_AXIS), + ) +) +with mesh, axis_rules, te.fp8_autocast( + enabled=fp8_recipe is not None, + fp8_recipe=fp8_recipe, + mesh_resource=mesh_resource, +): + model_sharded = partial(args.layer_type, **layer_kwargs) + params_sharded = model_sharded.init(init_rngs, x, deterministic=True) + output_sharded = model_sharded.apply(params_sharded, x, deterministic=True) + +if myrank == 0: + print( + f"{myrank}: {args.layer_type.__name__} OUTPUT {output_sharded.shape}\n" + + f" Sharding: {get_padded_spec(output_sharded.sharding.spec, output_sharded.ndim)}\n", + flush=True, + ) + +if args.check_result: + output_gathered = jax.lax.with_sharding_constraint( + output_sharded, NamedSharding(mesh, PartitionSpec(None)) + ) + jax.block_until_ready(output_gathered) + + diff = jnp.abs(output_single - output_gathered).flatten() + if myrank == 0: + print(f"{myrank}: Global output difference: {diff}\n", flush=True) + + m = jnp.argmax(diff).item() + abs_err = diff[m].item() + rel_err = abs_err / max(abs(output_single.flatten()[m]), 1e-5) + + rtol = 0.02 + atol = 0.001 + numerics_failed = False + if rel_err > rtol and abs_err > atol: + numerics_failed = True + numerics_info = ( + "NUMERICAL CHECK FAILED: " + + f"Outputs not close enough at index {m} " + + f"with {output_gathered.flatten()[m].item()} vs {output_single.flatten()[m].item()} " + + f"| rel. error = {rel_err} (tol = {rtol}) " + + f"| abs. error = {abs_err} (tol = {atol})" + ) + else: + numerics_info = "NUMERICAL CHECK PASSED: " + if rel_err <= rtol: + numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + ( + " | " if abs_err < atol else "" + ) + if abs_err <= atol: + numerics_info += f"abs. error = {abs_err} (tol = {atol})" + + if myrank == 0: + print(numerics_info + "\n", end="", flush=True) + +tex.destroy_all_comm_overlap_buffers() diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_overlap/gemm_with_overlap.py similarity index 95% rename from examples/jax/comm_gemm_overlap/comm_gemm_overlap.py rename to examples/jax/comm_overlap/gemm_with_overlap.py index 85744f43c8..6a03976f3a 100644 --- a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py +++ b/examples/jax/comm_overlap/gemm_with_overlap.py @@ -63,7 +63,6 @@ fsdp = not args.no_fsdp input_specs = [None] * len(lhs_shape) weight_specs = [None] * len(rhs_shape) -weight_no_fsdp = weight_specs.copy() if batched: lhs_shape = [args.batch_size] + lhs_shape if fsdp: @@ -74,11 +73,9 @@ if args.comm_type == "AG": input_specs = [("dp", "zp"), "tp", None] weight_specs = ["zp", "tp"] - weight_no_fsdp = [None, "tp"] elif args.comm_type == "RS": input_specs = [("dp", "zp"), None, "tp"] weight_specs = ["tp", "zp"] - weight_no_fsdp = ["tp", None] else: mesh_shape = {"dp": args.dp_size, "tp": args.tp_size} mesh_resource = te.MeshResource( @@ -92,7 +89,6 @@ elif args.comm_type == "RS": input_specs = ["dp", None, "tp"] weight_specs = ["tp", None] - weight_no_fsdp = weight_specs else: if fsdp: mesh_shape = {"zp": args.fsdp_size, "tp": args.tp_size} @@ -103,7 +99,6 @@ elif args.comm_type == "RS": input_specs = [None, "tp"] weight_specs = ["tp", "zp"] - weight_no_fsdp = ["tp", None] else: mesh_shape = {"tp": args.tp_size} mesh_resource = te.MeshResource(tp_resource="tp", cp_resource="cp") @@ -113,7 +108,6 @@ elif args.comm_type == "RS": input_specs = [None, "tp"] weight_specs = ["tp", None] - weight_no_fsdp = weight_specs # Mesh setup and sharding definitions devices = mesh_utils.create_device_mesh((args.num_gpus,), devices=jax.devices()[: args.num_gpus]) @@ -121,7 +115,6 @@ no_sharding = NamedSharding(mesh, PartitionSpec(None)) input_sharding = NamedSharding(mesh, PartitionSpec(*input_specs)) weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_specs)) -weight_no_fsdp_sharding = NamedSharding(mesh, PartitionSpec(*weight_no_fsdp)) # Operand initialization key = jax.random.PRNGKey(0) @@ -171,8 +164,6 @@ def _gemm_wrapper(x, y): comm_overlap=overlap_helper, )(x, y) -rhs_no_fsdp = jax.lax.with_sharding_constraint(rhs, weight_no_fsdp_sharding) - with te.sharding.global_shard_guard(mesh_resource): output = _gemm_wrapper(lhs, rhs) diff --git a/examples/jax/comm_overlap/layer_prim_with_overlap.py b/examples/jax/comm_overlap/layer_prim_with_overlap.py new file mode 100644 index 0000000000..3f5a98d439 --- /dev/null +++ b/examples/jax/comm_overlap/layer_prim_with_overlap.py @@ -0,0 +1,382 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Comm+GEMM Overlap with TE/JAX""" +import os +import argparse +from functools import partial + +from mpi4py import MPI + +import numpy as np + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils + +import transformer_engine.jax as te +from transformer_engine.common import recipe +from transformer_engine.jax.sharding import ( + MeshResource, + global_shard_guard, + generate_pspec, + BATCH_AXES, + SEQLEN_AXES, + SEQLEN_TP_AXES, + HIDDEN_AXES, + HIDDEN_TP_AXES, + JOINED_AXES, + W_FSDP_AXES, + W_NO_SHARD_AXES, + W_JOINED_AXES, + W_TP_AXES, +) +from transformer_engine.jax.dense import dense +from transformer_engine.jax.layernorm_dense import layernorm_dense +from transformer_engine.jax.layernorm_mlp import layernorm_mlp +from transformer_engine.jax.cpp_extensions import CommOverlapHelper, CommOverlapHelperSet + +import transformer_engine_jax as tex + +# This script needs to be launched via `mpirun` with 1 process per GPU +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +myrank = MPI.COMM_WORLD.Get_rank() +numranks = MPI.COMM_WORLD.Get_size() +jax.clear_caches() +jax.distributed.initialize(cluster_detection_method="mpi4py") +assert jax.local_device_count() == 1, ( + f"[{myrank}|{numranks}] Expected 1 GPU per process, found {jax.local_device_count()}" +) + +# Parse script arguments +_supported_prims = (dense, layernorm_dense, layernorm_mlp) +TE_PRIM_MAP = dict((prim.__name__.lower(), prim) for prim in _supported_prims) +def _te_layer_prim(prim_name): + assert isinstance(prim_name, str) and prim_name.lower() in TE_PRIM_MAP + return TE_PRIM_MAP[prim_name.lower()] + +parser = argparse.ArgumentParser() +parser.add_argument("-dp", "--dp-size", type=int, default=1) +parser.add_argument("-zp", "--fsdp-size", type=int, default=2) +parser.add_argument("-tp", "--tp-size", type=int, default=numranks // 2) +parser.add_argument("-np", "--num-gpus", type=int, default=numranks) +parser.add_argument("--batch-size", type=int, default=2) +parser.add_argument("--seq-length", type=int, default=8192) +parser.add_argument("--hidden-size", type=int, default=16384) +parser.add_argument("--activation-size", type=int, default=53248) +parser.add_argument("--no-batch", action="store_true") +parser.add_argument("--no-fsdp", action="store_true") +parser.add_argument("--layer-type", type=_te_layer_prim, default=dense, + choices=TE_PRIM_MAP.keys()) +parser.add_argument("--fp8-recipe", type=str.lower, default="none", + choices=["none", "current", "delayed", "mxfp8"]) +parser.add_argument("--check-result", action="store_true") +parser.add_argument("--seed", type=int, default=42) +args = parser.parse_args() + +# FP8 recipe +fp8_recipe = None +match args.fp8_recipe: + case "current": + fp8_recipe = recipe.Float8CurrentScaling() + case "delayed": + fp8_recipe = recipe.DelayedScaling() + case "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling() + case _: + fp8_recipe = None + +# Declare inputs +dtype = jnp.bfloat16 +input_shape = (args.seq_length, args.hidden_size) +if not args.no_batch: + input_shape = (args.batch_size, ) + input_shape +features = args.hidden_size # post-attention projection +if args.layer_type is layernorm_dense: + features *= 3 # QKV projection +kernel_shape = ( + (args.hidden_size, 1, args.activation_size) # MLP FFN1 + if args.layer_type is layernorm_mlp + else (args.hidden_size, features) +) +bias_shape = (1, args.activation_size) if args.layer_type is layernorm_mlp else (features, ) + +rng = jax.random.PRNGKey(args.seed) +rng, params_rng = jax.random.split(rng) +params_rng, kernel_rng = jax.random.split(params_rng) +params_rng, bias_rng = jax.random.split(params_rng) +x = jax.random.normal(rng, input_shape, dtype=jnp.bfloat16) + +gamma = beta = None +if args.layer_type in (layernorm_dense, layernorm_mlp): + params_rng, gamma_rng = jax.random.split(params_rng) + gamma = jax.random.normal(gamma_rng, (args.hidden_size, ), dtype=jnp.bfloat16) + params_rng, beta_rng = jax.random.split(params_rng) + beta = jax.random.normal(beta_rng, (args.hidden_size, ), dtype=jnp.bfloat16) + +kernel_1 = jax.random.normal(kernel_rng, kernel_shape, dtype=jnp.bfloat16) +bias_1 = jax.random.normal(bias_rng, bias_shape, dtype=jnp.bfloat16) + +kernel_2 = bias_2 = None +if args.layer_type is layernorm_mlp: + kernel_rng, kernel_2_rng = jax.random.split(kernel_rng) + kernel_2 = jax.random.normal(kernel_2_rng, (args.activation_size, args.hidden_size), + dtype=jnp.bfloat16) + bias_rng, bias_2_rng = jax.random.split(bias_rng) + bias_2 = jax.random.normal(bias_2_rng, (args.hidden_size, ), dtype=jnp.bfloat16) + +if myrank == 0: + print(f"[{myrank}|{numranks}] {args.layer_type.__name__} inputs:\n" + + f" x: {x.shape}\n" + + f" gamma: {gamma.shape if gamma is not None else None}\n" + + f" beta: {beta.shape if beta is not None else None}\n" + + f" kernel_1: {kernel_1.shape}\n" + + f" bias_1: {bias_1.shape}\n" + + f" kernel_2: {kernel_2.shape if kernel_2 is not None else None}\n" + + f" bias_2: {bias_2.shape if bias_2 is not None else None}\n") + +# Single GPU evaluation +def _eval_layer_serial(layer_type_, x_, gamma_, beta_, kernel_1_, bias_1_, kernel_2_, bias_2_): + layer_args = [] + layer_kwargs = {} + + if layer_type_ is dense: + layer_args = (x_, kernel_1_, bias_1_) + layer_kwargs = { + "contracting_dims" : ((x.ndim - 1, ), (0, )) + } + + elif layer_type_ is layernorm_dense: + layer_args = (x_, kernel_1_, gamma_, beta_, bias_1_) + + elif layer_type_ is layernorm_mlp: + layer_args = (x_, gamma_, beta_, kernel_1_, bias_1_, kernel_2_, bias_2_) + + return jnp.mean(layer_type_(*layer_args, **layer_kwargs)) + +with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + fwd_bwd_serial = jax.jit(jax.value_and_grad(partial(_eval_layer_serial, args.layer_type), + argnums=range(7))) + output_serial, grads_serial = fwd_bwd_serial(x, gamma, beta, kernel_1, bias_1, kernel_2, bias_2) + +# Device mesh and logical axis resources +DEVICE_FSDP_AXIS = "zp" +DEVICE_DP_AXIS = "dp" +DEVICE_TP_AXIS = "tp" +mesh_shape = { + DEVICE_TP_AXIS: args.tp_size +} +if not args.no_batch: + mesh_shape[DEVICE_DP_AXIS] = args.dp_size +if not args.no_fsdp: + mesh_shape[DEVICE_FSDP_AXIS] = args.fsdp_size +devices = mesh_utils.create_device_mesh((args.num_gpus,), devices=jax.devices()[: args.num_gpus]) +mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) +mesh_resource = MeshResource( + dp_resource=None if args.no_batch else DEVICE_DP_AXIS, + fsdp_resource=None if args.no_fsdp else DEVICE_FSDP_AXIS, + tp_resource=DEVICE_TP_AXIS, +) +if myrank == 0: + print(f"[{myrank}|{numranks}] Device mesh: {mesh}\n") + +# Logical axes +INPUT_AXES = (SEQLEN_AXES if args.layer_type is dense else SEQLEN_TP_AXES, + HIDDEN_TP_AXES if args.layer_type is dense else HIDDEN_AXES) +INTERMEDIATE_AXES = (JOINED_AXES, HIDDEN_TP_AXES) +if not args.no_batch: + INPUT_AXES = (BATCH_AXES, ) + INPUT_AXES + INTERMEDIATE_AXES = (BATCH_AXES, ) + INTERMEDIATE_AXES + +LN_SCALE_AXES = LN_BIAS_AXES = (W_NO_SHARD_AXES, ) + +KERNEL_AXES_ROW_PARALLEL = (W_TP_AXES, W_FSDP_AXES) +BIAS_AXES_ROW_PARALLEL = (W_NO_SHARD_AXES, ) +KERNEL_AXES_COL_PARALLEL = (W_FSDP_AXES, W_TP_AXES) +BIAS_AXES_COL_PARALLEL = (W_TP_AXES, ) +if args.layer_type is layernorm_mlp: + KERNEL_AXES_COL_PARALLEL = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES) + BIAS_AXES_COL_PARALLEL = (W_JOINED_AXES, W_TP_AXES) + +KERNEL_1_AXES = KERNEL_AXES_ROW_PARALLEL if args.layer_type is dense else KERNEL_AXES_COL_PARALLEL +BIAS_1_AXES = BIAS_AXES_ROW_PARALLEL if args.layer_type is dense else BIAS_AXES_COL_PARALLEL +KERNEL_2_AXES = KERNEL_AXES_ROW_PARALLEL if args.layer_type is layernorm_mlp else None +BIAS_2_AXES = BIAS_AXES_ROW_PARALLEL if args.layer_type is layernorm_mlp else None + +# Multi GPU evaluation +def _eval_layer_sharded( + layer_type_, + comm_overlaps_, + x_, + gamma_, + beta_, + kernel_1_, + bias_1_, + kernel_2_, + bias_2_, +): + layer_args = [] + layer_kwargs = {} + + if layer_type_ is dense: + layer_args = (x_, kernel_1_, bias_1_) + layer_kwargs = { + "input_axes" : INPUT_AXES, + "kernel_axes" : KERNEL_AXES_ROW_PARALLEL, + "comm_overlaps": comm_overlaps_[0], + "contracting_dims" : ((x.ndim - 1, ), (0, )) + } + + elif layer_type_ is layernorm_dense: + layer_args = (x_, kernel_1_, gamma_, beta_, bias_1_) + layer_kwargs = { + "layernorm_input_axes" : INPUT_AXES, + "dot_input_axes": INPUT_AXES, + "kernel_axes" : KERNEL_AXES_COL_PARALLEL, + "comm_overlaps": comm_overlaps_[0], + } + + elif layer_type_ is layernorm_mlp: + layer_args = (x_, gamma_, beta_, kernel_1_, bias_1_, kernel_2_, bias_2_) + layer_kwargs = { + "norm_input_axes" : INPUT_AXES, + "dot_1_input_axes" : INPUT_AXES, + "kernel_1_axes" : KERNEL_AXES_COL_PARALLEL, + "dot_2_input_axes" : INTERMEDIATE_AXES, + "kernel_2_axes" : KERNEL_AXES_ROW_PARALLEL, + "ffn1_comm_overlaps" : comm_overlaps_[0], + "ffn2_comm_overlaps" : comm_overlaps_[1], + } + + return jnp.mean(layer_type_(*layer_args, **layer_kwargs)) + +with mesh, global_shard_guard(mesh_resource), te.fp8_autocast( + enabled=fp8_recipe is not None, + fp8_recipe=fp8_recipe, + mesh_resource=mesh_resource, +): + # Comm+GEMM overlap configs + # NOTE: Need to set `tp_resource=` kwarg when *not* initializing under a `global_shard_guard()`. + # Also need `logical_tp_axis=` and `logical_sp_axis=` kwargs if they differ from TE's + # built-in logical axis names. + buffer_shape = list(input_shape).copy() + if not args.no_batch: + buffer_shape[0] = buffer_shape[0] // (args.dp_size * args.fsdp_size) + fprop_1_overlap = CommOverlapHelper( + comm_type=tex.CommOverlapType.RS if args.layer_type is dense else tex.CommOverlapType.AG, + method=tex.CommOverlapMethod.RING_EXCHANGE, + buffer_shape=buffer_shape, + ) + comm_overlaps = [ + CommOverlapHelperSet(fprop=fprop_1_overlap) + ] + if args.layer_type is layernorm_mlp: + fprop_2_overlap = CommOverlapHelper( + comm_type=tex.CommOverlapType.RS, + method=tex.CommOverlapMethod.RING_EXCHANGE, + buffer_shape=buffer_shape, + ) + comm_overlaps.append(CommOverlapHelperSet(fprop=fprop_2_overlap)) + + x_sharding = NamedSharding(mesh, generate_pspec(INPUT_AXES)) + x = jax.device_put(x, x_sharding) + + gamma_sharding = beta_sharding = None + if gamma is not None: + gamma_sharding = NamedSharding(mesh, generate_pspec(LN_SCALE_AXES)) + gamma = jax.device_put(gamma, gamma_sharding) + if beta is not None: + beta_sharding = NamedSharding(mesh, generate_pspec(LN_BIAS_AXES)) + beta = jax.device_put(beta, beta_sharding) + + kernel_1_sharding = NamedSharding(mesh, generate_pspec(KERNEL_1_AXES)) + bias_1_sharding = NamedSharding(mesh, generate_pspec(BIAS_1_AXES)) + + kernel_2_sharding = bias_2_sharding = None + if kernel_2 is not None: + kernel_2_sharding = NamedSharding(mesh, generate_pspec(KERNEL_2_AXES)) + kernel_2 = jax.device_put(kernel_2, kernel_2_sharding) + if bias_2 is not None: + bias_2_sharding = NamedSharding(mesh, generate_pspec(BIAS_2_AXES)) + bias_2 = jax.device_put(bias_2, bias_2_sharding) + + input_shardings = ( + x_sharding, + gamma_sharding, + beta_sharding, + kernel_1_sharding, + bias_1_sharding, + kernel_2_sharding, + bias_2_sharding, + ) + output_shardings = ( + NamedSharding(mesh, PartitionSpec()), + input_shardings, + ) + value_and_grad_sharded = jax.jit( + jax.value_and_grad( + partial(_eval_layer_sharded, args.layer_type,comm_overlaps), + argnums=range(7) + ), + in_shardings=input_shardings, + out_shardings=output_shardings, + ) + + output_sharded, grads_sharded = value_and_grad_sharded( + x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2 + ) + +if args.check_result: + diff = jnp.abs(output_serial - output_sharded) + if myrank == 0: + print(f"[{myrank}|{numranks}] Output: serial = {output_serial} | sharded = {output_sharded}") + rel_err = diff / max(abs(diff), 1e-5) + if rel_err > 0.02 and diff > 0.001: + if myrank == 0: + print("NUMERICAL CHECK_FAILED: Output not close enough!\n") + else: + if myrank == 0: + print("NUMERICAL CHECK PASSED\n") + + labels = ("dX", "dGamma", "dBeta", "dKernel_1", "dBias_1", "dKernel_2", "dBias_2") + for i, (serial, sharded) in enumerate(zip(grads_serial, grads_sharded)): + if serial is not None and sharded is not None: + if myrank == 0: + print(f"[{myrank}|{numranks}] {labels[i]} : {sharded.shape}\n" + + f" Sharding: {sharded.sharding.spec}\n") + gathered = jax.lax.with_sharding_constraint( + sharded, NamedSharding(mesh, PartitionSpec(None)) + ) + jax.block_until_ready(gathered) + diff = jnp.abs(serial - gathered).flatten() + if myrank == 0: + print(f"{myrank}: Global {labels[i]} difference: {diff}\n", flush=True) + + m = jnp.argmax(diff).item() + abs_err = diff[m].item() + rel_err = abs_err / max(abs(output_serial.flatten()[m]), 1e-5) + + rtol = 0.02 + atol = 0.001 + if rel_err > rtol and abs_err > atol: + numerics_info = ( + "NUMERICAL CHECK FAILED: " + + f"{labels[i]} not close enough at index {m} " + + f"with {gathered.flatten()[m].item()} vs {serial.flatten()[m].item()} " + + f"| rel. error = {rel_err} (tol = {rtol}) " + + f"| abs. error = {abs_err} (tol = {atol})" + ) + else: + numerics_info = "NUMERICAL CHECK PASSED: " + if rel_err <= rtol: + numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + ( + " | " if abs_err < atol else "" + ) + if abs_err <= atol: + numerics_info += f"abs. error = {abs_err} (tol = {atol})" + + if myrank == 0: + print(numerics_info + "\n", end="", flush=True) + +tex.destroy_all_comm_overlap_buffers() diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 8426cfdc0b..03131e5d92 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -35,7 +35,13 @@ remove_padding_from_scale_inv, ) from .misc import get_padded_spec, jax_dtype_to_te_dtype -from ..sharding import global_mesh_resource +from ..sharding import ( + global_mesh_resource, + get_mesh_axis_size, + generate_pspec, + W_TP_AXES, + SEQLEN_TP_AXES, +) __all__ = [ @@ -162,18 +168,20 @@ class CommOverlapHelper: the GemmPrimitive. """ # Core init arguments - method: tex.CommOverlapMethod = field(default=tex.CommOverlapMethod.NONE) comm_type: tex.CommOverlapType = field(default=tex.CommOverlapType.NONE) + method: tex.CommOverlapMethod = field(default=tex.CommOverlapMethod.NONE) buffer_shape: Sequence[int] = field(default=None) buffer_dtype: jnp.dtype = field(default=jnp.bfloat16) - tp_size: int = field(default=None) + tp_size: int = field( + default_factory=lambda:get_mesh_axis_size(global_mesh_resource().tp_resource) + ) # Userbuffers bootstrap kwargs num_splits: int = field(default=None, kw_only=True) num_max_streams: int = field(default=3, kw_only=True) comm_cga_size: int = field(default=None, kw_only=True) - gemm_priority: int = field(default=None, kw_only=True) - comm_priority: int = field(default=None, kw_only=True) + gemm_priority: int = field(default=CUDA_STREAM_PRIORITY_LOWEST, kw_only=True) + comm_priority: int = field(default=CUDA_STREAM_PRIORITY_HIGHEST, kw_only=True) num_comm_sm: int = field(default=None, kw_only=True) set_sm_margin: bool = field(default=None, kw_only=True) use_ce: bool = field(default=None, kw_only=True) @@ -182,14 +190,15 @@ class CommOverlapHelper: aggregate_ag: bool = field(default=False, kw_only=True) # Other kwargs not passed to Userbuffers - tp_resource: str = field(default=None, kw_only=True) - sp_resource: str = field(default=None, kw_only=True) + tp_resource: str = field(default_factory=lambda:global_mesh_resource().tp_resource) + logical_tp_axis: str = field(default=W_TP_AXES, kw_only=True) + logical_sp_axis: str = field(default=SEQLEN_TP_AXES, kw_only=True) output_all_gathered_lhs: bool = field(default=False, kw_only=True) flatten_axis: int = field(default=-1, kw_only=True) # Internal attributes - is_enabled: bool = field(default=False, init=False, compare=True) - unique_id: int = field(default=None, init=False, compare=False) + is_enabled: bool = field(default=False, init=False) + unique_id: int = field(default=-1, init=False, compare=False) sharded_impl: bool = field(default=False, init=False, compare=False) gather_dim: int = field(default=-2, init=False, compare=False) scatter_dim: int = field(default=-2, init=False, compare=False) @@ -202,25 +211,38 @@ def __post_init__(self): CUDA_STREAM_PRIORITY_LOWEST, CUDA_STREAM_PRIORITY_HIGHEST, ) = tex.get_stream_priority_range() + if self.gemm_priority is None: + object.__setattr__(self, "gemm_priority", CUDA_STREAM_PRIORITY_LOWEST) + if self.comm_priority is None: + object.__setattr__(self, "comm_priority", CUDA_STREAM_PRIORITY_HIGHEST) - object.__setattr__(self, "is_enabled", self.method != tex.CommOverlapMethod.NONE) - if self.is_enabled: - assert self.buffer_shape is not None, ( - f"CommOverlapHelper: {self.buffer_shape} is not a valid buffer shape." - ) - assert self.comm_type != tex.CommOverlapType.NONE, ( + if self.method != tex.CommOverlapMethod.NONE or self.comm_type != tex.CommOverlapType.NONE: + assert self.method != tex.CommOverlapMethod.NONE, ( f"CommOverlapHelper: {self.comm_type} is not a valid collective type for " f"{self.method}." ) + assert self.comm_type != tex.CommOverlapType.NONE, ( + f"CommOverlapHelper: {self.method} is not a valid overlap method for " + f"{self.comm_type}." + ) + assert self.buffer_shape is not None and len(self.buffer_shape) >= 2, ( + f"CommOverlapHelper: {self.buffer_shape} is not a valid buffer shape." + ) + assert self.tp_resource is not None, ( + "CommOverlapHelper: Communication + GEMM overlap requires a valid TP resource. " + "This must either be specified via the `tp_resource=` keyword, or " + "`CommOverlapHelper` needs to be initialized under a " + "`te.sharding.global_shard_guard()` using a `te.sharding.MeshResource()` with a " + "valid tensor-parallel mesh axis name." + ) assert self.tp_size % 2 == 0, ( - "CommOverlapHelper: Tensor-parallel axis size must be divisible by 2, got " - f"{self.tp_size}." + f"CommOverlapHelper: Tensor-parallel axis of {self.tp_size} is not divisible by 2." ) if not self.is_bulk() and not self.is_p2p(): # Pipelined overlap is only for reduce-scatter - assert self.comm_type != tex.CommOverlapType.AG, ( - f"CommOverlapHelper: {self.comm_type} is not a valid collective type for " - f"{self.method}." + assert not self.is_all_gather(), ( + f"CommOverlapHelper: {self.method} is not a valid overlap method for " + f"{self.comm_type}." ) # Collapse buffer shape to 2D @@ -251,20 +273,11 @@ def __post_init__(self): object.__setattr__(self, "set_sm_margin", not self.is_p2p()) if self.use_ce is None: object.__setattr__(self, "use_ce", self.is_p2p()) - if self.gemm_priority is None: - object.__setattr__(self, "gemm_priority", CUDA_STREAM_PRIORITY_LOWEST) - if self.comm_priority is None: - object.__setattr__(self, "comm_priority", CUDA_STREAM_PRIORITY_HIGHEST) - - # Update mesh resources for tensor- and sequence-parallel dimensions - if self.tp_resource is None: - object.__setattr__(self, "tp_resource", global_mesh_resource().tp_resource) - if self.sp_resource is None: - object.__setattr__(self, "sp_resource", global_mesh_resource().cp_resource) # Allocate the communication buffer args, kwargs = self.get_bootstrap_args_kwargs() object.__setattr__(self, "unique_id", tex.create_comm_overlap_buffer(*args, **kwargs)) + object.__setattr__(self, "is_enabled", True) def _set_sharded_impl(self, value): assert isinstance(value, bool) @@ -512,13 +525,13 @@ def _get_no_overlap_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_nu ) def _get_bulk_overlap_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): - assert self.sp_resource in aux_in_specs, ( + assert self.tp_resource in aux_in_specs, ( "CommOverlapHelper: Auxiliary input for bulk all-gather overlap is not sharded " - f"over the sequence-parallel mesh resource {self.sp_resource} in any dimension." + f"over the tensor-parallel mesh resource {self.tp_resource} in any dimension." ) aux_out_specs = (None, ) - bulk_comm_dim = aux_in_specs.index(self.sp_resource) + bulk_comm_dim = aux_in_specs.index(self.tp_resource) aux_in_specs_batch = aux_in_specs[ : bulk_comm_dim] aux_in_specs_tail = aux_in_specs[bulk_comm_dim + 1: ] if self.is_all_gather(): @@ -540,7 +553,7 @@ def _get_bulk_overlap_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_ self._set_scatter_dim(bulk_comm_dim) aux_out_specs = ( *aux_in_specs_batch, - self.sp_resource, + self.tp_resource, *[None for _ in range(len(aux_in_specs_tail))], ) @@ -566,16 +579,16 @@ def _get_all_gather_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_nu (lhs_lspec, _), _ = self._check_operand_specs( lhs_specs, rhs_specs, ((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)) ) - assert lhs_lspec == self.sp_resource, ( + assert lhs_lspec == self.tp_resource, ( "CommOverlapHelper: Non-batch leading dimension of the LHS operand for AG->GEMM " - f"overlap must be sharded over the sequence-parallel mesh resource {self.sp_resource}, " + f"overlap must be sharded over the tensor-parallel mesh resource {self.tp_resource}, " f"but got {lhs_lspec} sharding instead." ) # AG->GEMM overlap: Require non-batched contracting dimensions to be unsharded (e.g. FSDP) # LHS: (B, M, None) - # RHS: (N, None) - # OUT: (B, M, None) --(UB-AG)-> (B, None, None) x (N, None)^T = (B, None, N) + # RHS: (None, N) + # OUT: (B, M, None) --(AG)-> (B, None, None) x (None, N) = (B, None, N) lhs_specs = tuple( None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i] for i in range(lhs_ndim) ) @@ -614,8 +627,11 @@ def _get_all_gather_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_nu def _get_reduce_scatter_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): contracting_dims, batch_dims = dimension_numbers lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) - lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( - sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batch_dims + lhs_cdims, rhs_cdims = map( + sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims + ) + lhs_bdims, rhs_bdims = map( + sanitize_dims, (lhs_ndim, rhs_ndim), batch_dims ) (_, lhs_cspec), (_, rhs_cspec) = self._check_operand_specs( @@ -628,9 +644,9 @@ def _get_reduce_scatter_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimensio ) # GEMM->RS overlap: Require non-contracting non-batch dimensions to be unsharded (e.g. FSDP) - # LHS: (B, M, K) --(XLA-AG)-> (B, None, K) - # RHS: (N, K) --(XLA-AG)-> (None, K) - # OUT: (B, None, K) x (B, None, K) = (B, None, None) --(UB-RS)-> (B, M, None) + # LHS: (B, None, K) + # RHS: (K, None) + # OUT: (B, None, K) x (K, None) = (B, None, None) --(UB-RS)-> (B, M, None) lhs_specs = tuple( None if i not in lhs_bdims + lhs_cdims else lhs_specs[i] for i in range(lhs_ndim) ) @@ -640,24 +656,25 @@ def _get_reduce_scatter_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimensio # GEMM output is the internal communication buffer, but we will use the XLA output buffer # as the final reduce-scattered output so we shard it accordingly here. - lhs_specs_scattered = list(lhs_specs).copy() - for i in range(lhs_ndim): - if i not in lhs_bdims: - # Update only the first non-batch leading dimension to the TP resource - lhs_specs_scattered[i] = self.tp_resource - break - lhs_non_cspecs_scattered = tuple( - lhs_specs_scattered[i] for i in range(lhs_ndim) if i not in lhs_cdims + lhs_bspecs = tuple( + lhs_specs[i] for i in range(lhs_ndim) if i in lhs_bdims and i not in lhs_cdims + ) + lhs_lspecs = tuple( + lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_bdims + lhs_cdims ) rhs_non_cspecs = tuple( rhs_specs[i] for i in range(rhs_ndim) if i not in rhs_cdims ) - out_specs = (*lhs_non_cspecs_scattered, *rhs_non_cspecs) + out_specs = ( + *lhs_bspecs, + self.tp_resource, + *[None for _ in range(len(lhs_lspecs) - 1)], + *rhs_non_cspecs + ) self._set_scatter_dim(out_specs.index(self.tp_resource)) - # Bias and Pre-GeLU sharding is based on GEMM output - bias_specs = out_specs[len(lhs_non_cspecs_scattered) : ] + bias_specs = out_specs[len(lhs_bspecs) + len(lhs_lspecs) : ] gelu_specs = out_specs return ( @@ -683,6 +700,91 @@ def get_partitioning_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_n } return impl_map[self.comm_type](lhs_specs, rhs_specs, aux_in_specs, dimension_numbers) + def get_logical_grad_axes(self, lhs_axes, rhs_axes, dimension_numbers): + """ + Combine LHS and RHS operand logical axis names in the forward pass into the gradient's + logical axes in the backward pass. + """ + if not lhs_axes or not rhs_axes: + assert not lhs_axes and not rhs_axes, ( + "CommOverlapHelper: Logical axes must either be defined or not defined for both " + "forward operands." + ) + return None + + contracting_dims, batch_dims = dimension_numbers + lhs_ndim, rhs_ndim = map(len, (lhs_axes, rhs_axes)) + lhs_cdims, rhs_cdims = map( + sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims + ) + lhs_bdims, rhs_bdims = map( + sanitize_dims, (lhs_ndim, rhs_ndim), batch_dims + ) + + lhs_batch_axes = tuple( + lhs_axes[i] for i in range(lhs_ndim) if i in lhs_bdims and i not in lhs_cdims + ) + lhs_leading_axes = tuple( + lhs_axes[i] for i in range(lhs_ndim) if i not in lhs_bdims + lhs_cdims + ) + rhs_non_contracting_axes = tuple( + rhs_axes[i] for i in range(rhs_ndim) if i not in rhs_cdims + ) + + grad_axes = (*lhs_batch_axes, *lhs_leading_axes, *rhs_non_contracting_axes) + if self.is_enabled and not self.is_bulk(): + if self.is_all_gather(): + grad_axes = ( + *lhs_batch_axes, + *[None for _ in range(len(lhs_leading_axes))], + *rhs_non_contracting_axes, + ) + elif self.is_reduce_scatter(): + grad_axes = ( + *lhs_batch_axes, + self.logical_sp_axis, + *[None for _ in range(len(lhs_leading_axes) - 1)], + *[None for _ in range(len(rhs_non_contracting_axes))], + ) + else: + # Generate grad axes without any communication overlap + lhs_specs = generate_pspec(lhs_axes) + lhs_lspec = tuple( + lhs_specs[i] for i in range(lhs_ndim) + if i not in lhs_bdims + lhs_cdims and lhs_specs[i] is not None + ) + lhs_lspec = None if len(lhs_lspec) == 0 else lhs_lspec[0] + lhs_cspec = tuple( + lhs_specs[i] for i in lhs_cdims if i not in lhs_bdims and lhs_specs[i] is not None + ) + lhs_cspec = None if len(lhs_cspec) == 0 else lhs_cspec[0] + + rhs_specs = generate_pspec(rhs_axes) + rhs_lspec = tuple( + rhs_specs[i] for i in range(rhs_ndim) + if i not in rhs_bdims + rhs_cdims and rhs_specs[i] is not None + ) + rhs_lspec = None if len(rhs_lspec) == 0 else rhs_lspec[0] + rhs_cspec = tuple( + rhs_specs[i] for i in rhs_cdims if i not in rhs_bdims and rhs_specs[i] is not None + ) + rhs_cspec = None if len(rhs_cspec) == 0 else rhs_cspec[0] + + if not ( + lhs_cspec is not None + and lhs_cspec == rhs_cspec + and lhs_lspec is not None + and lhs_lspec == rhs_lspec + ): + # Trailing dimension is not scattered (i.e. not doing jax.lax.psum_scatter) + grad_axes = ( + *lhs_batch_axes, + *lhs_leading_axes, + *[None for _ in range(len(rhs_non_contracting_axes))] + ) + + return grad_axes + @dataclass(frozen=True) class CommOverlapHelperSet: @@ -695,147 +797,162 @@ class CommOverlapHelperSet: wgrad: CommOverlapHelper = field(default=None) def _sanity_check(self): - if not self.fprop.is_enabled: - assert self.dgrad is None or not self.dgrad.is_enabled, ( - "CommOverlapHelperSet: Comm+GEMM overlap for DGRAD requires comm+GEMM overlap " - "for FPROP to be enabled first." - ) - assert self.wgrad is None or not self.wgrad.is_enabled, ( - "CommOverlapHelperSet: Comm+GEMM overlap for WGRAD requires comm+GEMM overlap " - "for FPROP to be enabled first." + # Require any argument that exists to be a `CommOverlapHelper` instance + for overlap, name in zip((self.fprop, self.dgrad, self.wgrad), ("fprop", "dgrad", "wgrad")): + if overlap is not None: + assert isinstance(overlap, CommOverlapHelper), ( + f"CommOverlapHelperSet: Expected `{name}` to be a {CommOverlapHelper} but got " + f"{type(overlap)} instead." + ) + + # If FPROP overlap is not defined or not enabled, require DGRAD and WGRAD to also not be + # be defined or not enabled + if self.fprop is None or not self.fprop.is_enabled: + assert ( + (self.dgrad is None or not self.dgrad.is_enabled) + and (self.wgrad is None or not self.wgrad.is_enabled) + ), ( + "CommOverlapHelperSet: Cannot do communication overlap for DGRAD and/or WGRAD when " + "there is no communication overlap for FPROP." ) return assert not self.fprop.is_bulk(), ( - "CommOverlapHelperSet: Comm+GEMM overlap for FPROP does not support bulk collectives." + "CommOverlapHelperSet: Cannot overlap bulk collectives with FPROP." ) if self.fprop.is_all_gather(): - if self.dgrad is not None: - if self.fprop.output_all_gathered_lhs: - assert not self.dgrad.is_enabled, ( - "CommOverlapHelperSet: AG->GEMM FPROP does not have a corresponding DGRAD " - "overlap when it is configured to return a copy of the all-gathered LHS " - "operand as the auxiliary output." - ) - - elif self.dgrad.is_enabled: - assert ( - (self.dgrad.is_bulk() and self.dgrad.is_all_gather()) - or (not self.dgrad_is_bulk() and self.dgrad.is_reduce_scatter()) - ), ( - "CommOverlapHelperSet: AG->GEMM FPROP requires DGRAD overlap to be either " - "BULK-AG or GEMM->RS." + if self.dgrad is not None and self.dgrad.is_enabled: + if self.dgrad.is_bulk() and self.dgrad.is_all_gather(): + assert not self.fprop.output_all_gathered_lhs, ( + "CommOverlapHelperSet: AG->GEMM FPROP does not support BULK-AG overlap for " + "DGRAD when the all-gathered LHS is already saved in the forward pass." ) - - if self.wgrad is not None: - if ( - self.dgrad is not None - and self.dgrad.is_enabled - and self.dgrad.is_bulk() # not checking all-gather because we enforced it above - ): assert ( - self.wgrad.is_enabled + self.wgrad is not None + and self.wgrad.is_enabled and self.wgrad.is_bulk() and self.wgrad.is_reduce_scatter() ), ( - "CommOverlapHelperSet: AG->GEMM FPROP with BULK-AG DGRAD requires " - "WGRAD to overlap with BULK-RS." + "CommOverlapHelperSet: AG->GEMM FPROP with BULK-AG overlap for DGRAD " + "requires BULK-RS overlap for WGRAD." ) + + elif not self.dgrad.is_bulk() and self.dgrad.is_reduce_scatter(): + assert self.wgrad is None or not self.wgrad.is_enabled, ( + "CommOverlapHelperSet: AG->GEMM FPROP with GEMM->RS DGRAD does not support " + "communication overlap for WGRAD." + ) + else: - assert not self.wgrad.is_enabled, ( - "CommOverlapHelperSet: AG->GEMM FPROP does not have a corresponding WGRAD " - "overlap when DGRAD does not overlap with BULK-AG." + raise AssertionError( + "CommOverlapHelperSet: AG->GEMM FPROP requires communication overlap for " + "DGRAD to be either BULK-AG or GEMM->RS." ) + else: + assert self.wgrad is None or not self.wgrad.is_enabled, ( + "CommOverlapHelperSet: AG->GEMM FPROP with no communication overlap for DGRAD" + "does not support communication overlap for WGRAD." + ) elif self.fprop.is_reduce_scatter(): if self.dgrad is not None and self.dgrad.is_enabled: assert not self.dgrad.is_bulk() and self.dgrad.is_all_gather(), ( - "CommOverlapHelperSet: GEMM->RS overlap in FPROP requires DGRAD overlap to " - "be AG->GEMM." + "CommOverlapHelperSet: GEMM->RS FPROP requires communication overlap for DGRAD " + "to be AG->GEMM." ) - if self.wgrad is not None: - assert not self.wgrad.is_enabled, ( - "CommOverlapHelperSet: GEMM->RS overlap in FPROP does not have a " - "corresponding WGRAD overlap." - ) + assert self.wgrad is None or not self.wgrad.is_enabled, ( + "CommOverlapHelperSet: GEMM->RS FPROP does not support communication overlap " + "for WGRAD." + ) + + else: + raise RuntimeError( + "CommOverlapHelperSet: Internal TE error, unrecognized collective type " + f"{self.fprop.comm_type} in communication overlap for FPROP." + ) def __post_init__(self): + self._sanity_check() + if self.fprop is None: object.__setattr__(self, "fprop", CommOverlapHelper()) - object.__setattr__(self, "dgrad", CommOverlapHelper()) - object.__setattr__(self, "wgrad", CommOverlapHelper()) - self._sanity_check() + # Column-parallel layers: QKV projection and MLP FFN1 + # FPROP with AG->GEMM: + # LHS:(B, M, None)--(AG)->(B, None, None) x RHS:(None, N) = OUT:(B, None, N) + # DGRAD w/ BULK-AG for LHS: + # GRAD:(B, None, N) x RHS:(None, N)^T = DGRAD:(B, None, None) + # LHS:(B, M, None)--(BULK-AG)->(B, None, None) + # WGRAD w/ BULK-RS for DGRAD: + # LHS:(B, None, None)^T x GRAD:(B, None, N) = WGRAD:(None, N) + # DGRAD:(B, None, None)--(BULK-RS)->(B, M, None) + # + # Row-parallel layers: Post-attention projection and MLP FFN2 + # FPROP with GEMM->RS: + # LHS:(B, None, K) x RHS:(K, None) = (B, None, None)--(RS)->(B, M, None) + # DGRAD with AG->GEMM (all-gathered GRAD saved for WGRAD): + # GRAD:(B, M, None)--(AG)->(B, None, None) x RHS:(K, None)^T = (B, None, K) + # WGRAD with NO OVERLAP: + # LHS:(B, None, K)^T x GRAD:(B, None, None) = (K, None) + if self.dgrad is None: + dgrad_overlap = None + + if self.fprop.is_all_gather() and not self.fprop.output_all_gathered_lhs: + # FPROP AG->GEMM and DGRAD BULK-AG for LHS if all-gathered LHS is not saved + # from FPROP + dgrad_overlap = CommOverlapHelper( + method=tex.CommOverlapMethod.BULK, + comm_type=tex.CommOverlapType.AG, + buffer_shape=self.fprop.buffer_shape, + buffer_dtype=self.fprop.buffer_dtype, + tp_size=self.fprop.tp_size, + logical_tp_axis=self.fprop.logical_tp_axis, + logical_sp_axis=self.fprop.logical_sp_axis, + ) - if self.fprop.is_enabled: - # FWD/BWD paths with overlap: - # - # 1. AG->GEMM: (B, M, None) --(LHS AG)-> (B, None, None) x (None, N) = (B, None, N) - # DGRAD + Bulk-AG: (B, None, N) x (None, N)^T = (B, None, None) - # (B, M, None) --(LHS bulk-AG)-> (B, None, None) - # WGRAD + Bulk-RS: (B, None, None)^T x (B, None, N) = (None, N) - # (B, None, None) --(DGRAD bulk RS)-> (B, M, None) - # - # 2. GEMM->RS in FPROP: (B, None, K) x (K, None) = (B, None, None) --(RS)-> (B, M, None) - # AG->DGRAD: (B, M, None) --(GRAD AG)-> (B, None, None) x (K, None)^T = (B, None, K) - # WGRAD w/ AG-GRAD from DGRAD: (B, None, K)^T x (B, None, None) = (K, None) - - if self.dgrad is None: - if self.fprop.is_all_gather() and self.fprop.output_all_gathered_lhs: - # If the AG->GEMM FPROP already saved the all-gathered LHS in the autograd - # context, we don't need to overlap a BULK-AG for it with DGRAD. - object.__setattr__(self, "dgrad", CommOverlapHelper()) + elif self.fprop.is_reduce_scatter(): + # FPROP GEMM->RS and DGRAD AG->GEMM + dgrad_overlap = CommOverlapHelper( + method=tex.CommOverlapMethod.RING_EXCHANGE, + comm_type=tex.CommOverlapType.AG, + buffer_shape=self.fprop.buffer_shape, + buffer_dtype=self.fprop.buffer_dtype, + tp_size=self.fprop.tp_size, + logical_tp_axis=self.fprop.logical_tp_axis, + logical_sp_axis=self.fprop.logical_sp_axis, + ) - else: - # Otherwise, AG->GEMM FPROP needs BULK-AG DGRAD, and GEMM->RS FPROP needs - # AG->GEMM DGRAD w/ all-gathered gradient returned as auxiliary output to be - # re-used in WGRAD. - object.__setattr__( - self, - "dgrad", - CommOverlapHelper( - method=( - tex.CommOverlapMethod.BULK - if self.fprop.is_all_gather() - else tex.CommOverlapMethod.RING_EXCHANGE - ), - comm_type=tex.CommOverlapType.AG, - buffer_shape=self.fprop.buffer_shape, - buffer_dtype=self.fprop.buffer_dtype, - tp_size=self.fprop.tp_size, - tp_resource=self.fprop.tp_resource, - sp_resource=self.fprop.sp_resource, - output_all_gathered_lhs=self.fprop.is_reduce_scatter(), - ) - ) + else: + dgrad_overlap = CommOverlapHelper() - if self.wgrad is None: - if ( - self.fprop.is_all_gather() - and self.dgrad.is_enabled - and self.dgrad.is_bulk() - and self.dgrad.is_all_gather() - ): - # If FPROP does AG->GEMM and DGRAD does BULK-AG, WGRAD needs to do a BULK-RS - object.__setattr__( - self, - "wgrad", - CommOverlapHelper( - method=tex.CommOverlapMethod.BULK, - comm_type=tex.CommOverlapType.RS, - buffer_shape=self.fprop.buffer_shape, - buffer_dtype=self.fprop.buffer_dtype, - tp_size=self.fprop.tp_size, - tp_resource=self.fprop.tp_resource, - sp_resource=self.fprop.sp_resource, - ) - ) + object.__setattr__(self, "dgrad", dgrad_overlap) - else: - # Otherwise, WGRAD does not support comm+GEMM overlap - object.__setattr__(self, "wgrad", CommOverlapHelper()) + if self.wgrad is None: + wgrad_overlap = self.wgrad + + if ( + self.fprop.is_all_gather() + and self.dgrad.is_enabled + and self.dgrad.is_bulk() + and self.dgrad.is_all_gather() + ): + # FPROP AG->GEMM, DGRAD BULK-AG for LHS and WGRAD BULK-RS for DGRAD + wgrad_overlap = CommOverlapHelper( + method=tex.CommOverlapMethod.BULK, + comm_type=tex.CommOverlapType.RS, + buffer_shape=self.fprop.buffer_shape, + buffer_dtype=self.fprop.buffer_dtype, + tp_size=self.fprop.tp_size, + logical_tp_axis=self.fprop.logical_tp_axis, + logical_sp_axis=self.fprop.logical_sp_axis, + ) + + else: + wgrad_overlap = CommOverlapHelper() + + object.__setattr__(self, "wgrad", wgrad_overlap) class GemmPrimitive(BasePrimitive): diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 0670a4811f..a02a38d38f 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -201,6 +201,12 @@ def _dense_bwd_rule( quantizer=quantizer_set.dgrad, noop_scaled_tensor=True, ) + casted_grad = with_sharding_constraint_by_logical_axes( + casted_grad, + comm_overlaps.fprop.get_logical_grad_axes( + input_axes, kernel_axes, (contracting_dims, ((x_bdim, ), ())) + ) + ) # If casted_x has transposed data-layout, we need to untranspose it here, and then transpose # it back after the bulk-AG. This should ideally never be necessary if the data layouts are @@ -210,7 +216,7 @@ def _dense_bwd_rule( *tuple(range(casted_x_lhs.flatten_axis, casted_x_lhs.ndim)), *tuple(range(casted_x_lhs.flatten_axis)), ) - if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_gathered_lhs: + if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_all_gathered_lhs: dgrad_aux_in = ( casted_x_lhs.data.transpose(dgrad_aux_transposed_axes) if casted_x_lhs.data_layout == "T" @@ -241,7 +247,7 @@ def _dense_bwd_rule( ) casted_grad_rhs = casted_grad.get_tensor(usage=TensorUsage.RHS) - if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_gathered_lhs: + if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_all_gathered_lhs: # LHS was bulk all-gathered during DGRAD and returned as auxiliary input casted_x_lhs.data = ( dgrad[-1].transpose(dgrad_aux_transposed_axes) @@ -250,7 +256,7 @@ def _dense_bwd_rule( ) # DGRAD output will need to be bulk reduce-scattered during WGRAD dgrad = dgrad[0] - elif comm_overlaps.dgrad.is_all_gather() and comm_overlaps.dgrad.output_gathered_lhs: + elif comm_overlaps.dgrad.is_all_gather() and comm_overlaps.dgrad.output_all_gathered_lhs: # GRAD was all-gathered for DGRAD and a copy of the gathered GRAD is in the auxiliary output casted_grad_rhs.data = ( dgrad[-1].transpose(*range(casted_grad_rhs.flatten_axis, casted_grad_rhs.ndim), @@ -258,7 +264,7 @@ def _dense_bwd_rule( if casted_grad_rhs.data_layout == "T" else dgrad[-1] ) - dgrad = dgrad[1] + dgrad = dgrad[0] wgrad = tex.gemm( casted_x_lhs, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 00f7aab1a2..3f485212e7 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -599,7 +599,8 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: kernel_partitioning, enabled=self.enable_comm_overlap, config=self.comm_overlap_config, - ) + ), + batch_first=not self.transpose_batch_sequence ) if self.enable_low_rank_adaptation: @@ -888,8 +889,9 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: layernorm_input_axes=self.layernorm_input_axes, dot_input_axes=self.dot_input_axes, kernel_axes=self.kernel_axes, + batch_first=not self.transpose_batch_sequence, quantizer_set=quantizer_set, - comm_overlaps=comm_overlaps, + comm_overlaps=comm_overlaps ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) @@ -899,6 +901,7 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: contracting_dims=(axis, contract_ind), input_axes=self.dot_input_axes, kernel_axes=self.kernel_axes, + batch_first=not self.transpose_batch_sequence, quantizer_set=quantizer_set, comm_overlaps=comm_overlaps, ) @@ -1336,6 +1339,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name=ffn1_ckpt_name, ffn2_ckpt_name=ffn2_ckpt_name, activation_type=normalized_acts, + batch_first=not self.transpose_batch_sequence, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), ffn1_comm_overlaps=ffn1_comm_overlaps, ffn2_comm_overlaps=ffn2_comm_overlaps, @@ -1356,6 +1360,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): layernorm_input_axes=self.layernorm_input_axes, dot_input_axes=self.dot_1_input_axes, kernel_axes=self.kernel_axes_1, + batch_first=not self.transpose_batch_sequence, quantizer_set=ffn1_quantizer_set, comm_overlaps=ffn1_comm_overlaps, ) @@ -1367,6 +1372,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): contracting_dims=(axis, contract_ind), input_axes=self.dot_1_input_axes, kernel_axes=self.kernel_axes_1, + batch_first=not self.transpose_batch_sequence, quantizer_set=ffn1_quantizer_set, comm_overlaps=ffn1_comm_overlaps, ) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 57caaba5a8..f50759fa0a 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -322,6 +322,14 @@ def _layernorm_dense_bwd_rule( quantizer=quantizer_set.dgrad, noop_scaled_tensor=True, ) + casted_grad = with_sharding_constraint_by_logical_axes( + casted_grad, + comm_overlaps.fprop.get_logical_grad_axes( + dot_input_axes, + kernel_axes, + ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim, ), ())) + ) + ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim g_constracting_dim = tuple( @@ -340,7 +348,7 @@ def _layernorm_dense_bwd_rule( *tuple(range(casted_ln_out.flatten_axis, casted_ln_out.ndim)), *tuple(range(casted_ln_out.flatten_axis)), ) - if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_gathered_lhs: + if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_all_gathered_lhs: dgrad_aux_in = ( casted_ln_out.data.transpose(dgrad_aux_transposed_axes) if casted_ln_out.data_layout == "T" @@ -362,7 +370,7 @@ def _layernorm_dense_bwd_rule( # TN GEMM casted_grad_rhs = casted_grad.get_tensor(usage=TensorUsage.RHS) - if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_gathered_lhs: + if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_all_gathered_lhs: # LHS was bulk all-gathered during DGRAD and returned as auxiliary input casted_ln_out.data = ( dgrad[-1].transpose(dgrad_aux_transposed_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index a5a087f05b..b5881cb0ff 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -443,12 +443,17 @@ def _layernorm_mlp_bwd_rule( ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets - # Since the sharding of outputs should be the same as dot_1's input - grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) - casted_grad, dbias_2 = tex.quantize_dbias( grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True ) + casted_grad = with_sharding_constraint_by_logical_axes( + casted_grad, + ffn1_comm_overlaps.fprop.get_logical_grad_axes( + dot_2_input_axes, + kernel_2_axes, + ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim, ), ())) + ) + ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim g_contracting_dims_2 = tuple( @@ -506,6 +511,14 @@ def _layernorm_mlp_bwd_rule( quantizer=ffn2_quantizer_set.dgrad, noop_scaled_tensor=True, ) + casted_dact_out = with_sharding_constraint_by_logical_axes( + casted_dact_out, + ffn1_comm_overlaps.fprop.get_logical_grad_axes( + dot_1_input_axes, + kernel_1_axes, + ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim, ), ())) + ) + ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim @@ -527,7 +540,7 @@ def _layernorm_mlp_bwd_rule( *tuple(range(casted_ln_out.flatten_axis, casted_ln_out.ndim)), *tuple(range(casted_ln_out.flatten_axis)) ) - if ffn1_comm_overlaps.dgrad.is_bulk() and not ffn1_comm_overlaps.fprop.output_gathered_lhs: + if ffn1_comm_overlaps.dgrad.is_bulk() and not ffn1_comm_overlaps.fprop.output_all_gathered_lhs: dgrad_1_aux_in = ( casted_ln_out.data.transpose(ln_out_transposed_dims) if casted_ln_out.data_layout == "T" @@ -543,7 +556,7 @@ def _layernorm_mlp_bwd_rule( aux_in=dgrad_1_aux_in, ) - if ffn1_comm_overlaps.dgrad.is_bulk() and not ffn1_comm_overlaps.fprop.output_gathered_lhs: + if ffn1_comm_overlaps.dgrad.is_bulk() and not ffn1_comm_overlaps.fprop.output_all_gathered_lhs: casted_ln_out.data = ( dgrad_1[-1].transpose(ln_out_transposed_dims) if casted_ln_out.data_layout == "T" diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index d454f8f43a..a1ea83152b 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -201,6 +201,7 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st return self # axis_names were given for N layout, so needs to be transpose for T layout + axis_names = logical_axis_names if self.data_layout == "T": assert self.flatten_axis > 0 assert len(logical_axis_names) == self.data.ndim @@ -209,8 +210,6 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st *logical_axis_names[flatten_axis:], *logical_axis_names[:flatten_axis], ) - else: - axis_names = logical_axis_names data = with_sharding_constraint_by_logical_axes(self.data, axis_names) From 74ab64952d1eaa954badf5de0f37e786f349a7be Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 4 Jul 2025 07:56:55 +0000 Subject: [PATCH 23/27] fixed AG->GEMM overlap auxiliary output for all-gathered LHS copy Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 27 ++++++++++--------- transformer_engine/jax/cpp_extensions/gemm.py | 1 + 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 86bc05df29..76fea54c9f 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -927,6 +927,15 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, const bool do_gelu = pre_gelu_out.numel() > 0; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + // Check B copy sizing + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ", + _ubuf.numel(), " elements but got ", B_copy.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(), + "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8, + "-bit data type but got ", B_copy.element_size() * 8, "-bit"); + } + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); @@ -995,12 +1004,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } else { @@ -1048,16 +1051,16 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } + // Copy all-gathered B from communication buffer into auxiliary output + if (B_copy.numel() > 0) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), + cudaMemcpyDeviceToDevice, _stream_send[0])); + } + _ub_comm->sms = ori_sms; for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 03131e5d92..a20a0c98a0 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -922,6 +922,7 @@ def __post_init__(self): tp_size=self.fprop.tp_size, logical_tp_axis=self.fprop.logical_tp_axis, logical_sp_axis=self.fprop.logical_sp_axis, + output_all_gathered_lhs=True, ) else: From 1c7d5a33d7ffb7cf51319bf5945aba561a74049e Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 4 Jul 2025 08:22:56 +0000 Subject: [PATCH 24/27] comm+GEMM overlap working for column-parallel layernorm_dense FWD/BWD Signed-off-by: Alp Dener --- .../jax/comm_overlap/flax_with_overlap.py | 8 ++++---- .../comm_overlap/layer_prim_with_overlap.py | 9 ++++----- transformer_engine/jax/cpp_extensions/gemm.py | 11 +++++----- transformer_engine/jax/layernorm_dense.py | 20 +++++++++---------- transformer_engine/jax/layernorm_mlp.py | 2 +- 5 files changed, 24 insertions(+), 26 deletions(-) diff --git a/examples/jax/comm_overlap/flax_with_overlap.py b/examples/jax/comm_overlap/flax_with_overlap.py index 8a37757afb..84bca037da 100644 --- a/examples/jax/comm_overlap/flax_with_overlap.py +++ b/examples/jax/comm_overlap/flax_with_overlap.py @@ -44,10 +44,10 @@ # Parse script arguments _supported_layers = (DenseGeneral, LayerNormDenseGeneral, LayerNormMLP) -TE_LAYER_MAP = dict((layer.__name__.lower(), layer) for layer in _supported_layers) +_layer_map = dict((layer.__name__.lower(), layer) for layer in _supported_layers) def _te_flax_layer(layer_name): - assert isinstance(layer_name, str) and layer_name.lower() in TE_LAYER_MAP - return TE_LAYER_MAP[layer_name.lower()] + assert isinstance(layer_name, str) and layer_name.lower() in _layer_map + return _layer_map[layer_name.lower()] parser = argparse.ArgumentParser() parser.add_argument("-dp", "--dp-size", type=int, default=2) @@ -60,7 +60,7 @@ def _te_flax_layer(layer_name): parser.add_argument("--no-batch", action="store_true") parser.add_argument("--no-fsdp", action="store_true") parser.add_argument("--layer-type", type=_te_flax_layer, default=DenseGeneral, - choices=TE_LAYER_MAP.keys()) + choices=_supported_layers) parser.add_argument("--fp8-recipe", type=str.lower, default="none", choices=["none", "current", "delayed", "mxfp8"]) parser.add_argument("--check-result", action="store_true") diff --git a/examples/jax/comm_overlap/layer_prim_with_overlap.py b/examples/jax/comm_overlap/layer_prim_with_overlap.py index 3f5a98d439..0d6554cf77 100644 --- a/examples/jax/comm_overlap/layer_prim_with_overlap.py +++ b/examples/jax/comm_overlap/layer_prim_with_overlap.py @@ -51,10 +51,10 @@ # Parse script arguments _supported_prims = (dense, layernorm_dense, layernorm_mlp) -TE_PRIM_MAP = dict((prim.__name__.lower(), prim) for prim in _supported_prims) +_prim_map = dict((prim.__name__.lower(), prim) for prim in _supported_prims) def _te_layer_prim(prim_name): - assert isinstance(prim_name, str) and prim_name.lower() in TE_PRIM_MAP - return TE_PRIM_MAP[prim_name.lower()] + assert isinstance(prim_name, str) and prim_name.lower() in _prim_map + return _prim_map[prim_name.lower()] parser = argparse.ArgumentParser() parser.add_argument("-dp", "--dp-size", type=int, default=1) @@ -67,8 +67,7 @@ def _te_layer_prim(prim_name): parser.add_argument("--activation-size", type=int, default=53248) parser.add_argument("--no-batch", action="store_true") parser.add_argument("--no-fsdp", action="store_true") -parser.add_argument("--layer-type", type=_te_layer_prim, default=dense, - choices=TE_PRIM_MAP.keys()) +parser.add_argument("--layer-type", type=_te_layer_prim, default=dense, choices=_supported_prims) parser.add_argument("--fp8-recipe", type=str.lower, default="none", choices=["none", "current", "delayed", "mxfp8"]) parser.add_argument("--check-result", action="store_true") diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index a20a0c98a0..296c655a2d 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -527,7 +527,7 @@ def _get_no_overlap_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_nu def _get_bulk_overlap_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): assert self.tp_resource in aux_in_specs, ( "CommOverlapHelper: Auxiliary input for bulk all-gather overlap is not sharded " - f"over the tensor-parallel mesh resource {self.tp_resource} in any dimension." + f"over the tensor-parallel mesh resource '{self.tp_resource}' in any dimension." ) aux_out_specs = (None, ) @@ -558,7 +558,7 @@ def _get_bulk_overlap_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_ ) # GEMM is independent of communication so specs are as if there is no overlap - operand_specs, output_specs, xla_reduce_info = self._get_specs_no_overlap( + operand_specs, output_specs, xla_reduce_info = self._get_no_overlap_rules( lhs_specs, rhs_specs, aux_in_specs, dimension_numbers ) @@ -900,11 +900,10 @@ def __post_init__(self): dgrad_overlap = None if self.fprop.is_all_gather() and not self.fprop.output_all_gathered_lhs: - # FPROP AG->GEMM and DGRAD BULK-AG for LHS if all-gathered LHS is not saved - # from FPROP + # FPROP AG->GEMM and DGRAD GEMM->RS dgrad_overlap = CommOverlapHelper( - method=tex.CommOverlapMethod.BULK, - comm_type=tex.CommOverlapType.AG, + method=tex.CommOverlapMethod.RING_EXCHANGE, + comm_type=tex.CommOverlapType.RS, buffer_shape=self.fprop.buffer_shape, buffer_dtype=self.fprop.buffer_dtype, tp_size=self.fprop.tp_size, diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index f50759fa0a..a1e8a58e79 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -343,17 +343,17 @@ def _layernorm_dense_bwd_rule( # If casted_ln_out has transposed data-layout, we need to untranspose it here, and then # transpose it back after the bulk-AG. This should ideally never be necessary if the data # layouts are handled correctly in the tensor usages. - dgrad_aux_in = None - dgrad_aux_transposed_axes = ( + casted_ln_out_transposed_axes = ( *tuple(range(casted_ln_out.flatten_axis, casted_ln_out.ndim)), *tuple(range(casted_ln_out.flatten_axis)), ) - if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_all_gathered_lhs: - dgrad_aux_in = ( - casted_ln_out.data.transpose(dgrad_aux_transposed_axes) - if casted_ln_out.data_layout == "T" - else casted_ln_out.data - ) + casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) + if ( + comm_overlaps.dgrad.is_bulk() + and not comm_overlaps.fprop.output_all_gathered_lhs + and casted_ln_out.data_layout == "T" + ): + casted_ln_out.data = jnp.transpose(casted_ln_out.data, casted_ln_out_transposed_axes) # NT GEMM dgrad = tex.gemm( @@ -361,7 +361,7 @@ def _layernorm_dense_bwd_rule( casted_kernel, dimension_numbers=((g_constracting_dim, k_constracting_dim), ((x_bdim,), ())), comm_overlap=comm_overlaps.dgrad, - aux_in=dgrad_aux_in, + aux_in=casted_ln_out.data, ) g_constracting_dim = x_constracting_dim = tuple( @@ -373,7 +373,7 @@ def _layernorm_dense_bwd_rule( if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_all_gathered_lhs: # LHS was bulk all-gathered during DGRAD and returned as auxiliary input casted_ln_out.data = ( - dgrad[-1].transpose(dgrad_aux_transposed_axes) + dgrad[-1].transpose(casted_ln_out_transposed_axes) if casted_ln_out.data_layout == "T" else dgrad[-1] ) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index b5881cb0ff..a10bbda821 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -491,7 +491,7 @@ def _layernorm_mlp_bwd_rule( if casted_grad_rhs.data_layout == "T" else dgrad_2[-1] ) - dgrad_2 = dgrad_2[1] + dgrad_2 = dgrad_2[0] else: dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) From b4ff96146c8e8f037723291cd1b59fb324865ea3 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 4 Jul 2025 09:02:59 +0000 Subject: [PATCH 25/27] comm+GEMM overlap working with layernorm_mlp FWD/BWD Signed-off-by: Alp Dener --- .../comm_overlap/layer_prim_with_overlap.py | 10 +++--- transformer_engine/jax/cpp_extensions/gemm.py | 20 +++++------ transformer_engine/jax/dense.py | 2 +- transformer_engine/jax/layernorm_dense.py | 17 ++++----- transformer_engine/jax/layernorm_mlp.py | 35 +++++++++---------- 5 files changed, 41 insertions(+), 43 deletions(-) diff --git a/examples/jax/comm_overlap/layer_prim_with_overlap.py b/examples/jax/comm_overlap/layer_prim_with_overlap.py index 0d6554cf77..f7e755c2ef 100644 --- a/examples/jax/comm_overlap/layer_prim_with_overlap.py +++ b/examples/jax/comm_overlap/layer_prim_with_overlap.py @@ -150,7 +150,7 @@ def _eval_layer_serial(layer_type_, x_, gamma_, beta_, kernel_1_, bias_1_, kerne layer_args = (x_, kernel_1_, gamma_, beta_, bias_1_) elif layer_type_ is layernorm_mlp: - layer_args = (x_, gamma_, beta_, kernel_1_, bias_1_, kernel_2_, bias_2_) + layer_args = (x_, gamma_, beta_, (kernel_1_, kernel_2_), (bias_1_, bias_2_)) return jnp.mean(layer_type_(*layer_args, **layer_kwargs)) @@ -183,7 +183,7 @@ def _eval_layer_serial(layer_type_, x_, gamma_, beta_, kernel_1_, bias_1_, kerne # Logical axes INPUT_AXES = (SEQLEN_AXES if args.layer_type is dense else SEQLEN_TP_AXES, HIDDEN_TP_AXES if args.layer_type is dense else HIDDEN_AXES) -INTERMEDIATE_AXES = (JOINED_AXES, HIDDEN_TP_AXES) +INTERMEDIATE_AXES = (SEQLEN_AXES, HIDDEN_TP_AXES) if not args.no_batch: INPUT_AXES = (BATCH_AXES, ) + INPUT_AXES INTERMEDIATE_AXES = (BATCH_AXES, ) + INTERMEDIATE_AXES @@ -191,7 +191,7 @@ def _eval_layer_serial(layer_type_, x_, gamma_, beta_, kernel_1_, bias_1_, kerne LN_SCALE_AXES = LN_BIAS_AXES = (W_NO_SHARD_AXES, ) KERNEL_AXES_ROW_PARALLEL = (W_TP_AXES, W_FSDP_AXES) -BIAS_AXES_ROW_PARALLEL = (W_NO_SHARD_AXES, ) +BIAS_AXES_ROW_PARALLEL = (W_FSDP_AXES, ) KERNEL_AXES_COL_PARALLEL = (W_FSDP_AXES, W_TP_AXES) BIAS_AXES_COL_PARALLEL = (W_TP_AXES, ) if args.layer_type is layernorm_mlp: @@ -237,7 +237,7 @@ def _eval_layer_sharded( } elif layer_type_ is layernorm_mlp: - layer_args = (x_, gamma_, beta_, kernel_1_, bias_1_, kernel_2_, bias_2_) + layer_args = (x_, gamma_, beta_, (kernel_1_, kernel_2_), (bias_1_, bias_2_)) layer_kwargs = { "norm_input_axes" : INPUT_AXES, "dot_1_input_axes" : INPUT_AXES, @@ -323,7 +323,7 @@ def _eval_layer_sharded( ) output_sharded, grads_sharded = value_and_grad_sharded( - x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2 + x, gamma, beta, kernel_1, bias_1, kernel_2, bias_2 ) if args.check_result: diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 296c655a2d..86c48557bb 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -605,6 +605,7 @@ def _get_all_gather_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_nu rhs_specs[i] for i in range(rhs_ndim) if i not in rhs_cdims ) out_specs = (*lhs_non_cspecs_gathered, *rhs_non_cspecs) + self._set_gather_dim(lhs_specs.index(lhs_lspec)) # Bias and Pre-GeLU sharding is based on GEMM output bias_specs = out_specs[len(lhs_non_cspecs_gathered) : ] @@ -614,9 +615,8 @@ def _get_all_gather_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_nu aux_out_specs = (None, ) if self.output_all_gathered_lhs: # Auxiliary output is the same as the LHS spec, except the gathered dimension unsharded - self._set_gather_dim(lhs_specs.index(lhs_lspec)) aux_out_specs = list(lhs_specs).copy() - aux_out_specs[lhs_specs.index(lhs_lspec)] = None + aux_out_specs[self.gather_dim] = None return ( (lhs_specs, rhs_specs, bias_specs, gelu_specs, aux_in_specs), @@ -700,10 +700,10 @@ def get_partitioning_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_n } return impl_map[self.comm_type](lhs_specs, rhs_specs, aux_in_specs, dimension_numbers) - def get_logical_grad_axes(self, lhs_axes, rhs_axes, dimension_numbers): + def get_logical_output_axes(self, lhs_axes, rhs_axes, dimension_numbers): """ - Combine LHS and RHS operand logical axis names in the forward pass into the gradient's - logical axes in the backward pass. + Compute the logical axis names for the GEMM output axes based on LHS and RHS operands' + logical axis names. """ if not lhs_axes or not rhs_axes: assert not lhs_axes and not rhs_axes, ( @@ -731,16 +731,16 @@ def get_logical_grad_axes(self, lhs_axes, rhs_axes, dimension_numbers): rhs_axes[i] for i in range(rhs_ndim) if i not in rhs_cdims ) - grad_axes = (*lhs_batch_axes, *lhs_leading_axes, *rhs_non_contracting_axes) + out_axes = (*lhs_batch_axes, *lhs_leading_axes, *rhs_non_contracting_axes) if self.is_enabled and not self.is_bulk(): if self.is_all_gather(): - grad_axes = ( + out_axes = ( *lhs_batch_axes, *[None for _ in range(len(lhs_leading_axes))], *rhs_non_contracting_axes, ) elif self.is_reduce_scatter(): - grad_axes = ( + out_axes = ( *lhs_batch_axes, self.logical_sp_axis, *[None for _ in range(len(lhs_leading_axes) - 1)], @@ -777,13 +777,13 @@ def get_logical_grad_axes(self, lhs_axes, rhs_axes, dimension_numbers): and lhs_lspec == rhs_lspec ): # Trailing dimension is not scattered (i.e. not doing jax.lax.psum_scatter) - grad_axes = ( + out_axes = ( *lhs_batch_axes, *lhs_leading_axes, *[None for _ in range(len(rhs_non_contracting_axes))] ) - return grad_axes + return out_axes @dataclass(frozen=True) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index a02a38d38f..a3f533fbef 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -203,7 +203,7 @@ def _dense_bwd_rule( ) casted_grad = with_sharding_constraint_by_logical_axes( casted_grad, - comm_overlaps.fprop.get_logical_grad_axes( + comm_overlaps.fprop.get_logical_output_axes( input_axes, kernel_axes, (contracting_dims, ((x_bdim, ), ())) ) ) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index a1e8a58e79..ca7153b346 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -324,7 +324,7 @@ def _layernorm_dense_bwd_rule( ) casted_grad = with_sharding_constraint_by_logical_axes( casted_grad, - comm_overlaps.fprop.get_logical_grad_axes( + comm_overlaps.fprop.get_logical_output_axes( dot_input_axes, kernel_axes, ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim, ), ())) @@ -343,17 +343,18 @@ def _layernorm_dense_bwd_rule( # If casted_ln_out has transposed data-layout, we need to untranspose it here, and then # transpose it back after the bulk-AG. This should ideally never be necessary if the data # layouts are handled correctly in the tensor usages. + dgrad_aux_in = None casted_ln_out_transposed_axes = ( *tuple(range(casted_ln_out.flatten_axis, casted_ln_out.ndim)), *tuple(range(casted_ln_out.flatten_axis)), ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) - if ( - comm_overlaps.dgrad.is_bulk() - and not comm_overlaps.fprop.output_all_gathered_lhs - and casted_ln_out.data_layout == "T" - ): - casted_ln_out.data = jnp.transpose(casted_ln_out.data, casted_ln_out_transposed_axes) + if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_all_gathered_lhs: + dgrad_aux_in = ( + casted_ln_out.data.transpose(casted_ln_out_transposed_axes) + if casted_ln_out.data_layout == "T" + else casted_ln_out.data + ) # NT GEMM dgrad = tex.gemm( @@ -361,7 +362,7 @@ def _layernorm_dense_bwd_rule( casted_kernel, dimension_numbers=((g_constracting_dim, k_constracting_dim), ((x_bdim,), ())), comm_overlap=comm_overlaps.dgrad, - aux_in=casted_ln_out.data, + aux_in=dgrad_aux_in, ) g_constracting_dim = x_constracting_dim = tuple( diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index a10bbda821..737e72be94 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -37,7 +37,7 @@ def layernorm_mlp( beta: jnp.ndarray, kernels: List[jnp.ndarray], biases: List[jnp.ndarray], - norm_type: str, + norm_type: str = "layernorm", zero_centered_gamma: bool = False, epsilon: float = 1e-6, norm_input_axes: Tuple[str, ...] = None, @@ -259,8 +259,6 @@ def _layernorm_mlp_fwd_rule( Returns: Tuple of (output, context) for automatic differentiation """ - del kernel_2_axes - ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets # x should be in shape of (batch..., hidden) @@ -299,6 +297,7 @@ def _layernorm_mlp_fwd_rule( casted_kernel_1 = tex.quantize( kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True ) + casted_kernel_1 = with_sharding_constraint_by_logical_axes(casted_kernel_1, kernel_1_axes) # NN GEMM # (batch..., sequence, hidden_in) x (hidden_in, hidden_out) @@ -312,17 +311,14 @@ def _layernorm_mlp_fwd_rule( fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, comm_overlap=ffn1_comm_overlaps.fprop, ) - - if ( - not ffn1_comm_overlaps.fprop.is_enabled - and dot_1_input_axes is not None - and kernel_1_axes is not None - ): - dot_1_output_axes = ( - *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims), - *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims), + dot_1_output = with_sharding_constraint_by_logical_axes( + dot_1_output, + ffn1_comm_overlaps.fprop.get_logical_output_axes( + dot_1_input_axes, + kernel_1_axes, + ((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())) ) - dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) + ) if use_bias_1 and tex.gemm_uses_jax_dot(): bias_1_shape = bias_1.shape @@ -336,12 +332,12 @@ def _layernorm_mlp_fwd_rule( dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True ) - if not ffn2_comm_overlaps.fprop.is_enabled: - casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) + casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) casted_kernel_2 = tex.quantize( kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True ) + casted_kernel_2 = with_sharding_constraint_by_logical_axes(casted_kernel_2, kernel_2_axes) # NN GEMM # (batch..., hidden_in) x (hidden_out, hidden_in) @@ -448,7 +444,7 @@ def _layernorm_mlp_bwd_rule( ) casted_grad = with_sharding_constraint_by_logical_axes( casted_grad, - ffn1_comm_overlaps.fprop.get_logical_grad_axes( + ffn2_comm_overlaps.fprop.get_logical_output_axes( dot_2_input_axes, kernel_2_axes, ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim, ), ())) @@ -492,8 +488,6 @@ def _layernorm_mlp_bwd_rule( else dgrad_2[-1] ) dgrad_2 = dgrad_2[0] - else: - dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) wgrad_2 = tex.gemm( casted_act_out, @@ -501,6 +495,8 @@ def _layernorm_mlp_bwd_rule( dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim,))), comm_overlap=ffn2_comm_overlaps.wgrad, ) + + dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) casted_dact_out, dbias_1 = tex.quantize_dact_dbias( @@ -513,7 +509,7 @@ def _layernorm_mlp_bwd_rule( ) casted_dact_out = with_sharding_constraint_by_logical_axes( casted_dact_out, - ffn1_comm_overlaps.fprop.get_logical_grad_axes( + ffn1_comm_overlaps.fprop.get_logical_output_axes( dot_1_input_axes, kernel_1_axes, ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim, ), ())) @@ -540,6 +536,7 @@ def _layernorm_mlp_bwd_rule( *tuple(range(casted_ln_out.flatten_axis, casted_ln_out.ndim)), *tuple(range(casted_ln_out.flatten_axis)) ) + casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) if ffn1_comm_overlaps.dgrad.is_bulk() and not ffn1_comm_overlaps.fprop.output_all_gathered_lhs: dgrad_1_aux_in = ( casted_ln_out.data.transpose(ln_out_transposed_dims) From 95564fc2313b30928bd5cd42a42c91eb7f5f7fef Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 4 Jul 2025 10:20:54 +0000 Subject: [PATCH 26/27] te.flax modules updated for comm+GEMM overlap but untested Signed-off-by: Alp Dener --- transformer_engine/jax/flax/module.py | 147 +++++++++++--------------- 1 file changed, 62 insertions(+), 85 deletions(-) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 3f485212e7..469bfc60ea 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -34,7 +34,10 @@ CommOverlapHelperSet, ) from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode -from ..sharding import get_non_contracting_logical_axes, get_padded_spec +from ..sharding import ( + get_non_contracting_logical_axes, + global_mesh_resource, +) import transformer_engine_jax as tex @@ -153,77 +156,48 @@ def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, return output -def _generate_comm_overlap_metas( - inputs_shape: Sequence[int], +def _generate_comm_overlap_meta( + input_shape: Sequence[int], + input_axes: Sequence[str], param_shape: Sequence[int], - param_partitioning: nn.LogicallyPartitioned, - enabled: bool = True, - config: dict = {}, + param_axes: Sequence[str], + config: dict, ): - if not enabled: + method = config.pop("method", tex.CommOverlapMethod.RING_EXCHANGE) + if method == tex.CommOverlapMethod.NONE: return CommOverlapHelperSet() - param_sharding = param_partitioning.get_sharding() - param_specs = get_padded_spec(param_sharding.spec, len(param_shape)) - column_parallel = param_specs[-1] is not None - row_parallel = any(spec is not None for spec in param_specs[:-1]) - - comm_type = config.pop("comm_type", None) - if row_parallel and column_parallel: - assert comm_type is not None, ( - "Collective type for communication overlap must be explicitly set via " - "`comm_overlap_config={'comm_type' : ... }` when module parameters are " - "sharded in both contracting and non-contracting dimensions " - "(e.g. FSDP+TP sharding)." - ) - row_parallel = comm_type == tex.CommOverlapType.RS - column_parallel = comm_type == tex.CommOverlapType.AG - - mesh = param_sharding.mesh - buffer_shape = inputs_shape - tp_size = 1 - tp_resource = None - if row_parallel: - contracting_specs = tuple(spec for spec in param_specs[:-1] if spec is not None) - assert len(contracting_specs) == 1, ( - "Module parameter cannot have more than one sharded contracting dimension " - "GEMM->RS overlap is enabled." - ) - tp_resource = contracting_specs[0] - tp_size = mesh.shape[mesh.axis_names.index(tp_resource)] - comm_type = tex.CommOverlapType.RS - buffer_shape = (*inputs_shape[:-1], param_shape[-1]) - - elif column_parallel: - tp_resource = param_specs[-1] - assert tp_resource is not None, ( - "Module parameter must be sharded in the non-contracting dimension when " - "AG->GEMM overlap is enabled." - ) - tp_size = mesh.shape[mesh.axis_names.index(tp_resource)] - comm_type = tex.CommOverlapType.AG + tp_resource = config.pop( + "tp_resource", global_mesh_resource().tp_resource + ) - else: - raise AssertionError("") + input_sp_dim = list(nn.logical_to_mesh_axes(input_axes)).index(tp_resource) + logical_sp_axis = config.pop("logical_sp_axis", input_axes[input_sp_dim]) + + param_tp_dim = list(nn.logical_to_mesh_axes(param_axes)).index(tp_resource) + logical_tp_axis = config.pop("logical_tp_axis", param_axes[param_tp_dim]) + + row_parallel = param_tp_dim == 0 + comm_type = tex.CommOverlapType.RS if row_parallel else tex.CommOverlapType.AG + _ = config.pop("comm_type") + + buffer_shape = config.pop( + "buffer_shape", + (*input_shape[:-1], param_shape[-1]) if row_parallel else input_shape + ) - method = config.pop("method", tex.CommOverlapMethod.RING_EXCHANGE) - buffer_shape = config.pop("buffer_shape", buffer_shape) - buffer_dtype = config.pop("buffer_dtype", jnp.bfloat16) - tp_size = config.pop("tp_size", tp_size) - tp_resource = config.pop("tp_resource", tp_resource) return CommOverlapHelperSet( fprop=CommOverlapHelper( - method=method, comm_type=comm_type, + method=method, buffer_shape=buffer_shape, - buffer_dtype=buffer_dtype, - tp_size=tp_size, tp_resource=tp_resource, + logical_tp_axis=logical_tp_axis, + logical_sp_axis=logical_sp_axis, **config, ) ) - class Softmax(nn.Module): # pylint: disable=too-few-public-methods r""" Applies softmax over a mini-batch of inputs. @@ -586,6 +560,9 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: quantizer_set = self.generate_quantizer_set() contract_ind = tuple(range(0, len(axis))) + + if not self.enable_comm_overlap: + self.comm_overlap_config.update({"method" : tex.CommOverlapMethod.NONE}) y = dense( inputs, kernel, @@ -593,12 +570,12 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: input_axes=self.input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, - comm_overlaps=_generate_comm_overlap_metas( + comm_overlaps=_generate_comm_overlap_meta( inputs.shape, - kernel_shape, - kernel_partitioning, - enabled=self.enable_comm_overlap, - config=self.comm_overlap_config, + self.input_axes, + kernel.shape, + self.kernel_axes, + self.comm_overlap_method, ), batch_first=not self.transpose_batch_sequence ) @@ -867,15 +844,15 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: contract_ind = tuple(range(0, len(axis))) - if self.enable_comm_overlap: - # All-Gather is the only supported collective to overlap in LayerNormDenseGeneral - self.comm_overlap_config.update({"comm_type" : tex.CommOverlapType.AG}) - comm_overlaps = _generate_comm_overlap_metas( + # All-Gather is the only supported collective to overlap in LayerNormDenseGeneral + if not self.enable_comm_overlap: + self.comm_overlap_config.update({"method" : tex.CommOverlapMethod.NONE}) + comm_overlaps = _generate_comm_overlap_meta( inputs.shape, + self.layernorm_input_axes, kernel_shape, - kernel_partitioning, - enabled=self.enable_comm_overlap, - config=self.comm_overlap_config, + self.kernel_axes, + self.comm_overlap_config, ) if fuse_layernorm: z = layernorm_dense( @@ -1301,25 +1278,25 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name = "ffn1" ffn2_ckpt_name = "ffn2" - if self.enable_dot_1_comm_overlap: - # All-Gather is the only supported collective to overlap with the 1st dot - self.dot_1_comm_overlap_config.update({"comm_type" : tex.CommOverlapType.AG}) - ffn1_comm_overlaps = _generate_comm_overlap_metas( + + if not self.enable_dot_1_comm_overlap: + self.dot_1_comm_overlap_config.update({"method" : tex.CommOverlapMethod.NONE}) + ffn1_comm_overlaps = _generate_comm_overlap_meta( inputs.shape, - kernel_1_each_shape, - kernel_1_partitioning, - enabled=self.enable_dot_1_comm_overlap, - config=self.enable_dot_1_comm_overlap, + self.layernorm_input_axes, + kernel_1.shape, + self.kernel_axes_1, + self.dot_1_comm_overlap_config, ) - if self.enable_dot_2_comm_overlap: - # Reduce-Scatter is the only supported collective to overlap with the 2nd dot - self.dot_2_comm_overlap_config.update({"comm_type" : tex.CommOverlapType.RS}) - ffn2_comm_overlaps = _generate_comm_overlap_metas( + + if not self.enable_dot_2_comm_overlap: + self.dot_2_comm_overlap_config.update({"method" : tex.CommOverlapMethod.NONE}) + ffn2_comm_overlaps = _generate_comm_overlap_meta( inputs.shape, - kernel_2_shape, - kernel_2_partitioning, - enabled=self.enable_dot_2_comm_overlap, - config=self.enable_dot_2_comm_overlap, + self.dot_2_input_axes, + kernel_2.shape, + self.kernel_axes_2, + self.dot_2_comm_overlap_config, ) if use_fused_layernorm_mlp: out = layernorm_mlp( From 3330052b0f33d549764eb446bc51839e8183f544 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 4 Jul 2025 10:23:29 +0000 Subject: [PATCH 27/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/comm_overlap/flax_with_overlap.py | 105 ++++++----- .../jax/comm_overlap/gemm_with_overlap.py | 6 +- .../comm_overlap/layer_prim_with_overlap.py | 147 ++++++++------- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 2 - .../transformer_engine/comm_gemm_overlap.h | 18 +- transformer_engine/jax/cpp_extensions/gemm.py | 170 ++++++++---------- transformer_engine/jax/csrc/extensions.h | 2 +- .../jax/csrc/extensions/gemm.cpp | 25 ++- .../jax/csrc/extensions/pybind.cpp | 3 - transformer_engine/jax/dense.py | 56 ++++-- transformer_engine/jax/flax/module.py | 33 ++-- transformer_engine/jax/layernorm_dense.py | 6 +- transformer_engine/jax/layernorm_mlp.py | 20 ++- 13 files changed, 311 insertions(+), 282 deletions(-) diff --git a/examples/jax/comm_overlap/flax_with_overlap.py b/examples/jax/comm_overlap/flax_with_overlap.py index 84bca037da..801fc0bfd7 100644 --- a/examples/jax/comm_overlap/flax_with_overlap.py +++ b/examples/jax/comm_overlap/flax_with_overlap.py @@ -38,17 +38,20 @@ numranks = MPI.COMM_WORLD.Get_size() jax.clear_caches() jax.distributed.initialize(cluster_detection_method="mpi4py") -assert jax.local_device_count() == 1, ( - f"[{myrank}|{numranks}] Expected 1 GPU per process, found {jax.local_device_count()}" -) +assert ( + jax.local_device_count() == 1 +), f"[{myrank}|{numranks}] Expected 1 GPU per process, found {jax.local_device_count()}" # Parse script arguments _supported_layers = (DenseGeneral, LayerNormDenseGeneral, LayerNormMLP) _layer_map = dict((layer.__name__.lower(), layer) for layer in _supported_layers) + + def _te_flax_layer(layer_name): assert isinstance(layer_name, str) and layer_name.lower() in _layer_map return _layer_map[layer_name.lower()] + parser = argparse.ArgumentParser() parser.add_argument("-dp", "--dp-size", type=int, default=2) parser.add_argument("-tp", "--tp-size", type=int, default=numranks // 2) @@ -59,10 +62,12 @@ def _te_flax_layer(layer_name): parser.add_argument("--activation-size", type=int, default=53248) parser.add_argument("--no-batch", action="store_true") parser.add_argument("--no-fsdp", action="store_true") -parser.add_argument("--layer-type", type=_te_flax_layer, default=DenseGeneral, - choices=_supported_layers) -parser.add_argument("--fp8-recipe", type=str.lower, default="none", - choices=["none", "current", "delayed", "mxfp8"]) +parser.add_argument( + "--layer-type", type=_te_flax_layer, default=DenseGeneral, choices=_supported_layers +) +parser.add_argument( + "--fp8-recipe", type=str.lower, default="none", choices=["none", "current", "delayed", "mxfp8"] +) parser.add_argument("--check-result", action="store_true") parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() @@ -80,35 +85,31 @@ def _te_flax_layer(layer_name): fp8_recipe = None # Single GPU evaluation -layer_kwargs = { "use_bias" : True } +layer_kwargs = {"use_bias": True} match args.layer_type: case DenseGeneral: - layer_kwargs.update({"features" : args.hidden_size, "name" : "proj"}) + layer_kwargs.update({"features": args.hidden_size, "name": "proj"}) case LayerNormDenseGeneral: layer_kwargs.update( - { - "features" : 3 * args.hidden_size, - "return_layernorm_output" : False, - "name" : "qkv" - } + {"features": 3 * args.hidden_size, "return_layernorm_output": False, "name": "qkv"} ) case LayerNormMLP: layer_kwargs.update( { - "intermediate_dim" : args.activation_size, - "return_layernorm_output" : False, - "name" : "mlp" + "intermediate_dim": args.activation_size, + "return_layernorm_output": False, + "name": "mlp", } ) rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random_split(rng) -init_rngs = {"params" : params_rng} +init_rngs = {"params": params_rng} dtype = jnp.bfloat16 input_shape = (args.seq_length, args.hidden_size) if not args.no_batch: - input_shape = (args.batch_size, ) + input_shape + input_shape = (args.batch_size,) + input_shape x = jnp.random.normal(rng, input_shape, dtype=jnp.bfloat16) with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): @@ -127,56 +128,58 @@ def _te_flax_layer(layer_name): tp_resource=DEVICE_TP_AXIS, ) -INPUT_AXES = (SEQLEN_TP_AXES if args.layer_type != DenseGeneral else SEQLEN_AXES, - HIDDEN_AXES if args.layer_type != DenseGeneral else HIDDEN_TP_AXES) +INPUT_AXES = ( + SEQLEN_TP_AXES if args.layer_type != DenseGeneral else SEQLEN_AXES, + HIDDEN_AXES if args.layer_type != DenseGeneral else HIDDEN_TP_AXES, +) INTERMEDIATE_AXES = (SEQLEN_AXES, HIDDEN_TP_AXES) if not args.no_batch: - INPUT_AXES = (BATCH_AXES, ) + INPUT_AXES - INTERMEDIATE_AXES = (BATCH_AXES, ) + INTERMEDIATE_AXES + INPUT_AXES = (BATCH_AXES,) + INPUT_AXES + INTERMEDIATE_AXES = (BATCH_AXES,) + INTERMEDIATE_AXES -LN_SCALE_AXES = LN_BIAS_AXES = (W_NO_SHARD_AXES, ) +LN_SCALE_AXES = LN_BIAS_AXES = (W_NO_SHARD_AXES,) KERNEL_AXES_ROW_PARALLEL = (W_TP_AXES, W_FSDP_AXES) -BIAS_AXES_ROW_PARALLEL = (W_NO_SHARD_AXES, ) +BIAS_AXES_ROW_PARALLEL = (W_NO_SHARD_AXES,) KERNEL_AXES_COL_PARALLEL = (W_FSDP_AXES, W_TP_AXES) -BIAS_AXES_COL_PARALLEL = (W_TP_AXES, ) +BIAS_AXES_COL_PARALLEL = (W_TP_AXES,) if args.layer_type == LayerNormMLP: KERNEL_AXES_COL_PARALLEL = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES) BIAS_AXES_COL_PARALLEL = (W_JOINED_AXES, W_NO_SHARD_AXES) # Multi GPU evaluation -layer_kwargs.update({"enable_comm_overlap" : True}) +layer_kwargs.update({"enable_comm_overlap": True}) if args.layer_type in (DenseGeneral, LayerNormDenseGeneral): layer_kwargs.update( { - "kernel_axes" : KERNEL_AXES_COL_PARALLEL, - "bias_axes" : BIAS_AXES_COL_PARALLEL, - "comm_overlap_config" : {"method" : tex.CommOverlapMethod.RING_EXCHANGE}, + "kernel_axes": KERNEL_AXES_COL_PARALLEL, + "bias_axes": BIAS_AXES_COL_PARALLEL, + "comm_overlap_config": {"method": tex.CommOverlapMethod.RING_EXCHANGE}, } ) if args.layer_type == LayerNormDenseGeneral: layer_kwargs.update( { - "layernorm_input_axes" : INPUT_AXES, - "scale_axes" : LN_SCALE_AXES, - "ln_bias_axes" : LN_BIAS_AXES, - "dot_input_axes" : INPUT_AXES, + "layernorm_input_axes": INPUT_AXES, + "scale_axes": LN_SCALE_AXES, + "ln_bias_axes": LN_BIAS_AXES, + "dot_input_axes": INPUT_AXES, } ) else: layer_kwargs.update( { - "layernorm_input_axes" : INPUT_AXES, - "scale_axes" : LN_SCALE_AXES, - "ln_bias_axes" : LN_BIAS_AXES, - "dot_1_input_axes" : INPUT_AXES, - "kernel_1_axes" : KERNEL_AXES_COL_PARALLEL, - "bias_axes_1" : BIAS_AXES_COL_PARALLEL, - "dot_2_input_axes" : INTERMEDIATE_AXES, - "kernel_2_axes" : KERNEL_AXES_ROW_PARALLEL, - "bias_axes_2" : BIAS_AXES_ROW_PARALLEL, - "dot_1_comm_overlap_config" : {"method" : tex.CommOverlapMethod.RING_EXCHANGE}, - "dot_2_comm_overlap_config" : {"method" : tex.CommOverlapMethod.RING_EXCHANGE}, + "layernorm_input_axes": INPUT_AXES, + "scale_axes": LN_SCALE_AXES, + "ln_bias_axes": LN_BIAS_AXES, + "dot_1_input_axes": INPUT_AXES, + "kernel_1_axes": KERNEL_AXES_COL_PARALLEL, + "bias_axes_1": BIAS_AXES_COL_PARALLEL, + "dot_2_input_axes": INTERMEDIATE_AXES, + "kernel_2_axes": KERNEL_AXES_ROW_PARALLEL, + "bias_axes_2": BIAS_AXES_ROW_PARALLEL, + "dot_1_comm_overlap_config": {"method": tex.CommOverlapMethod.RING_EXCHANGE}, + "dot_2_comm_overlap_config": {"method": tex.CommOverlapMethod.RING_EXCHANGE}, } ) @@ -195,10 +198,14 @@ def _te_flax_layer(layer_name): (W_TP_AXES, DEVICE_TP_AXIS), ) ) -with mesh, axis_rules, te.fp8_autocast( - enabled=fp8_recipe is not None, - fp8_recipe=fp8_recipe, - mesh_resource=mesh_resource, +with ( + mesh, + axis_rules, + te.fp8_autocast( + enabled=fp8_recipe is not None, + fp8_recipe=fp8_recipe, + mesh_resource=mesh_resource, + ), ): model_sharded = partial(args.layer_type, **layer_kwargs) params_sharded = model_sharded.init(init_rngs, x, deterministic=True) diff --git a/examples/jax/comm_overlap/gemm_with_overlap.py b/examples/jax/comm_overlap/gemm_with_overlap.py index 6a03976f3a..615be69e15 100644 --- a/examples/jax/comm_overlap/gemm_with_overlap.py +++ b/examples/jax/comm_overlap/gemm_with_overlap.py @@ -123,7 +123,7 @@ rhs_data = jax.random.normal(key2, rhs_shape, dtype=dtype) lhs = jax.device_put(lhs_data, input_sharding) rhs = jax.device_put(rhs_data, weight_sharding) -dimension_numbers = (((-1, ), (0, )), ((0, ), ())) +dimension_numbers = (((-1,), (0,)), ((0,), ())) # Name of comm+GEMM overlap layer overlap_method = tex.CommOverlapMethod.RING_EXCHANGE @@ -156,14 +156,16 @@ flush=True, ) + @jax.jit def _gemm_wrapper(x, y): return partial( gemm, - dimension_numbers=(((-1, ), (0, )), ((0, ), ())), + dimension_numbers=(((-1,), (0,)), ((0,), ())), comm_overlap=overlap_helper, )(x, y) + with te.sharding.global_shard_guard(mesh_resource): output = _gemm_wrapper(lhs, rhs) diff --git a/examples/jax/comm_overlap/layer_prim_with_overlap.py b/examples/jax/comm_overlap/layer_prim_with_overlap.py index f7e755c2ef..790c1092ad 100644 --- a/examples/jax/comm_overlap/layer_prim_with_overlap.py +++ b/examples/jax/comm_overlap/layer_prim_with_overlap.py @@ -45,17 +45,20 @@ numranks = MPI.COMM_WORLD.Get_size() jax.clear_caches() jax.distributed.initialize(cluster_detection_method="mpi4py") -assert jax.local_device_count() == 1, ( - f"[{myrank}|{numranks}] Expected 1 GPU per process, found {jax.local_device_count()}" -) +assert ( + jax.local_device_count() == 1 +), f"[{myrank}|{numranks}] Expected 1 GPU per process, found {jax.local_device_count()}" # Parse script arguments _supported_prims = (dense, layernorm_dense, layernorm_mlp) _prim_map = dict((prim.__name__.lower(), prim) for prim in _supported_prims) + + def _te_layer_prim(prim_name): assert isinstance(prim_name, str) and prim_name.lower() in _prim_map return _prim_map[prim_name.lower()] + parser = argparse.ArgumentParser() parser.add_argument("-dp", "--dp-size", type=int, default=1) parser.add_argument("-zp", "--fsdp-size", type=int, default=2) @@ -68,8 +71,9 @@ def _te_layer_prim(prim_name): parser.add_argument("--no-batch", action="store_true") parser.add_argument("--no-fsdp", action="store_true") parser.add_argument("--layer-type", type=_te_layer_prim, default=dense, choices=_supported_prims) -parser.add_argument("--fp8-recipe", type=str.lower, default="none", - choices=["none", "current", "delayed", "mxfp8"]) +parser.add_argument( + "--fp8-recipe", type=str.lower, default="none", choices=["none", "current", "delayed", "mxfp8"] +) parser.add_argument("--check-result", action="store_true") parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() @@ -90,7 +94,7 @@ def _te_layer_prim(prim_name): dtype = jnp.bfloat16 input_shape = (args.seq_length, args.hidden_size) if not args.no_batch: - input_shape = (args.batch_size, ) + input_shape + input_shape = (args.batch_size,) + input_shape features = args.hidden_size # post-attention projection if args.layer_type is layernorm_dense: features *= 3 # QKV projection @@ -99,7 +103,7 @@ def _te_layer_prim(prim_name): if args.layer_type is layernorm_mlp else (args.hidden_size, features) ) -bias_shape = (1, args.activation_size) if args.layer_type is layernorm_mlp else (features, ) +bias_shape = (1, args.activation_size) if args.layer_type is layernorm_mlp else (features,) rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) @@ -110,9 +114,9 @@ def _te_layer_prim(prim_name): gamma = beta = None if args.layer_type in (layernorm_dense, layernorm_mlp): params_rng, gamma_rng = jax.random.split(params_rng) - gamma = jax.random.normal(gamma_rng, (args.hidden_size, ), dtype=jnp.bfloat16) + gamma = jax.random.normal(gamma_rng, (args.hidden_size,), dtype=jnp.bfloat16) params_rng, beta_rng = jax.random.split(params_rng) - beta = jax.random.normal(beta_rng, (args.hidden_size, ), dtype=jnp.bfloat16) + beta = jax.random.normal(beta_rng, (args.hidden_size,), dtype=jnp.bfloat16) kernel_1 = jax.random.normal(kernel_rng, kernel_shape, dtype=jnp.bfloat16) bias_1 = jax.random.normal(bias_rng, bias_shape, dtype=jnp.bfloat16) @@ -120,20 +124,24 @@ def _te_layer_prim(prim_name): kernel_2 = bias_2 = None if args.layer_type is layernorm_mlp: kernel_rng, kernel_2_rng = jax.random.split(kernel_rng) - kernel_2 = jax.random.normal(kernel_2_rng, (args.activation_size, args.hidden_size), - dtype=jnp.bfloat16) + kernel_2 = jax.random.normal( + kernel_2_rng, (args.activation_size, args.hidden_size), dtype=jnp.bfloat16 + ) bias_rng, bias_2_rng = jax.random.split(bias_rng) - bias_2 = jax.random.normal(bias_2_rng, (args.hidden_size, ), dtype=jnp.bfloat16) + bias_2 = jax.random.normal(bias_2_rng, (args.hidden_size,), dtype=jnp.bfloat16) if myrank == 0: - print(f"[{myrank}|{numranks}] {args.layer_type.__name__} inputs:\n" - + f" x: {x.shape}\n" - + f" gamma: {gamma.shape if gamma is not None else None}\n" - + f" beta: {beta.shape if beta is not None else None}\n" - + f" kernel_1: {kernel_1.shape}\n" - + f" bias_1: {bias_1.shape}\n" - + f" kernel_2: {kernel_2.shape if kernel_2 is not None else None}\n" - + f" bias_2: {bias_2.shape if bias_2 is not None else None}\n") + print( + f"[{myrank}|{numranks}] {args.layer_type.__name__} inputs:\n" + + f" x: {x.shape}\n" + + f" gamma: {gamma.shape if gamma is not None else None}\n" + + f" beta: {beta.shape if beta is not None else None}\n" + + f" kernel_1: {kernel_1.shape}\n" + + f" bias_1: {bias_1.shape}\n" + + f" kernel_2: {kernel_2.shape if kernel_2 is not None else None}\n" + + f" bias_2: {bias_2.shape if bias_2 is not None else None}\n" + ) + # Single GPU evaluation def _eval_layer_serial(layer_type_, x_, gamma_, beta_, kernel_1_, bias_1_, kernel_2_, bias_2_): @@ -142,9 +150,7 @@ def _eval_layer_serial(layer_type_, x_, gamma_, beta_, kernel_1_, bias_1_, kerne if layer_type_ is dense: layer_args = (x_, kernel_1_, bias_1_) - layer_kwargs = { - "contracting_dims" : ((x.ndim - 1, ), (0, )) - } + layer_kwargs = {"contracting_dims": ((x.ndim - 1,), (0,))} elif layer_type_ is layernorm_dense: layer_args = (x_, kernel_1_, gamma_, beta_, bias_1_) @@ -154,18 +160,18 @@ def _eval_layer_serial(layer_type_, x_, gamma_, beta_, kernel_1_, bias_1_, kerne return jnp.mean(layer_type_(*layer_args, **layer_kwargs)) + with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): - fwd_bwd_serial = jax.jit(jax.value_and_grad(partial(_eval_layer_serial, args.layer_type), - argnums=range(7))) + fwd_bwd_serial = jax.jit( + jax.value_and_grad(partial(_eval_layer_serial, args.layer_type), argnums=range(7)) + ) output_serial, grads_serial = fwd_bwd_serial(x, gamma, beta, kernel_1, bias_1, kernel_2, bias_2) # Device mesh and logical axis resources DEVICE_FSDP_AXIS = "zp" DEVICE_DP_AXIS = "dp" DEVICE_TP_AXIS = "tp" -mesh_shape = { - DEVICE_TP_AXIS: args.tp_size -} +mesh_shape = {DEVICE_TP_AXIS: args.tp_size} if not args.no_batch: mesh_shape[DEVICE_DP_AXIS] = args.dp_size if not args.no_fsdp: @@ -181,19 +187,21 @@ def _eval_layer_serial(layer_type_, x_, gamma_, beta_, kernel_1_, bias_1_, kerne print(f"[{myrank}|{numranks}] Device mesh: {mesh}\n") # Logical axes -INPUT_AXES = (SEQLEN_AXES if args.layer_type is dense else SEQLEN_TP_AXES, - HIDDEN_TP_AXES if args.layer_type is dense else HIDDEN_AXES) +INPUT_AXES = ( + SEQLEN_AXES if args.layer_type is dense else SEQLEN_TP_AXES, + HIDDEN_TP_AXES if args.layer_type is dense else HIDDEN_AXES, +) INTERMEDIATE_AXES = (SEQLEN_AXES, HIDDEN_TP_AXES) if not args.no_batch: - INPUT_AXES = (BATCH_AXES, ) + INPUT_AXES - INTERMEDIATE_AXES = (BATCH_AXES, ) + INTERMEDIATE_AXES + INPUT_AXES = (BATCH_AXES,) + INPUT_AXES + INTERMEDIATE_AXES = (BATCH_AXES,) + INTERMEDIATE_AXES -LN_SCALE_AXES = LN_BIAS_AXES = (W_NO_SHARD_AXES, ) +LN_SCALE_AXES = LN_BIAS_AXES = (W_NO_SHARD_AXES,) KERNEL_AXES_ROW_PARALLEL = (W_TP_AXES, W_FSDP_AXES) -BIAS_AXES_ROW_PARALLEL = (W_FSDP_AXES, ) +BIAS_AXES_ROW_PARALLEL = (W_FSDP_AXES,) KERNEL_AXES_COL_PARALLEL = (W_FSDP_AXES, W_TP_AXES) -BIAS_AXES_COL_PARALLEL = (W_TP_AXES, ) +BIAS_AXES_COL_PARALLEL = (W_TP_AXES,) if args.layer_type is layernorm_mlp: KERNEL_AXES_COL_PARALLEL = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES) BIAS_AXES_COL_PARALLEL = (W_JOINED_AXES, W_TP_AXES) @@ -203,6 +211,7 @@ def _eval_layer_serial(layer_type_, x_, gamma_, beta_, kernel_1_, bias_1_, kerne KERNEL_2_AXES = KERNEL_AXES_ROW_PARALLEL if args.layer_type is layernorm_mlp else None BIAS_2_AXES = BIAS_AXES_ROW_PARALLEL if args.layer_type is layernorm_mlp else None + # Multi GPU evaluation def _eval_layer_sharded( layer_type_, @@ -221,39 +230,44 @@ def _eval_layer_sharded( if layer_type_ is dense: layer_args = (x_, kernel_1_, bias_1_) layer_kwargs = { - "input_axes" : INPUT_AXES, - "kernel_axes" : KERNEL_AXES_ROW_PARALLEL, + "input_axes": INPUT_AXES, + "kernel_axes": KERNEL_AXES_ROW_PARALLEL, "comm_overlaps": comm_overlaps_[0], - "contracting_dims" : ((x.ndim - 1, ), (0, )) + "contracting_dims": ((x.ndim - 1,), (0,)), } elif layer_type_ is layernorm_dense: layer_args = (x_, kernel_1_, gamma_, beta_, bias_1_) layer_kwargs = { - "layernorm_input_axes" : INPUT_AXES, + "layernorm_input_axes": INPUT_AXES, "dot_input_axes": INPUT_AXES, - "kernel_axes" : KERNEL_AXES_COL_PARALLEL, + "kernel_axes": KERNEL_AXES_COL_PARALLEL, "comm_overlaps": comm_overlaps_[0], } elif layer_type_ is layernorm_mlp: layer_args = (x_, gamma_, beta_, (kernel_1_, kernel_2_), (bias_1_, bias_2_)) layer_kwargs = { - "norm_input_axes" : INPUT_AXES, - "dot_1_input_axes" : INPUT_AXES, - "kernel_1_axes" : KERNEL_AXES_COL_PARALLEL, - "dot_2_input_axes" : INTERMEDIATE_AXES, - "kernel_2_axes" : KERNEL_AXES_ROW_PARALLEL, - "ffn1_comm_overlaps" : comm_overlaps_[0], - "ffn2_comm_overlaps" : comm_overlaps_[1], + "norm_input_axes": INPUT_AXES, + "dot_1_input_axes": INPUT_AXES, + "kernel_1_axes": KERNEL_AXES_COL_PARALLEL, + "dot_2_input_axes": INTERMEDIATE_AXES, + "kernel_2_axes": KERNEL_AXES_ROW_PARALLEL, + "ffn1_comm_overlaps": comm_overlaps_[0], + "ffn2_comm_overlaps": comm_overlaps_[1], } return jnp.mean(layer_type_(*layer_args, **layer_kwargs)) -with mesh, global_shard_guard(mesh_resource), te.fp8_autocast( - enabled=fp8_recipe is not None, - fp8_recipe=fp8_recipe, - mesh_resource=mesh_resource, + +with ( + mesh, + global_shard_guard(mesh_resource), + te.fp8_autocast( + enabled=fp8_recipe is not None, + fp8_recipe=fp8_recipe, + mesh_resource=mesh_resource, + ), ): # Comm+GEMM overlap configs # NOTE: Need to set `tp_resource=` kwarg when *not* initializing under a `global_shard_guard()`. @@ -267,9 +281,7 @@ def _eval_layer_sharded( method=tex.CommOverlapMethod.RING_EXCHANGE, buffer_shape=buffer_shape, ) - comm_overlaps = [ - CommOverlapHelperSet(fprop=fprop_1_overlap) - ] + comm_overlaps = [CommOverlapHelperSet(fprop=fprop_1_overlap)] if args.layer_type is layernorm_mlp: fprop_2_overlap = CommOverlapHelper( comm_type=tex.CommOverlapType.RS, @@ -301,13 +313,13 @@ def _eval_layer_sharded( bias_2 = jax.device_put(bias_2, bias_2_sharding) input_shardings = ( - x_sharding, - gamma_sharding, - beta_sharding, - kernel_1_sharding, - bias_1_sharding, - kernel_2_sharding, - bias_2_sharding, + x_sharding, + gamma_sharding, + beta_sharding, + kernel_1_sharding, + bias_1_sharding, + kernel_2_sharding, + bias_2_sharding, ) output_shardings = ( NamedSharding(mesh, PartitionSpec()), @@ -315,8 +327,7 @@ def _eval_layer_sharded( ) value_and_grad_sharded = jax.jit( jax.value_and_grad( - partial(_eval_layer_sharded, args.layer_type,comm_overlaps), - argnums=range(7) + partial(_eval_layer_sharded, args.layer_type, comm_overlaps), argnums=range(7) ), in_shardings=input_shardings, out_shardings=output_shardings, @@ -329,7 +340,9 @@ def _eval_layer_sharded( if args.check_result: diff = jnp.abs(output_serial - output_sharded) if myrank == 0: - print(f"[{myrank}|{numranks}] Output: serial = {output_serial} | sharded = {output_sharded}") + print( + f"[{myrank}|{numranks}] Output: serial = {output_serial} | sharded = {output_sharded}" + ) rel_err = diff / max(abs(diff), 1e-5) if rel_err > 0.02 and diff > 0.001: if myrank == 0: @@ -342,8 +355,10 @@ def _eval_layer_sharded( for i, (serial, sharded) in enumerate(zip(grads_serial, grads_sharded)): if serial is not None and sharded is not None: if myrank == 0: - print(f"[{myrank}|{numranks}] {labels[i]} : {sharded.shape}\n" - + f" Sharding: {sharded.sharding.spec}\n") + print( + f"[{myrank}|{numranks}] {labels[i]} : {sharded.shape}\n" + + f" Sharding: {sharded.sharding.spec}\n" + ) gathered = jax.lax.with_sharding_constraint( sharded, NamedSharding(mesh, PartitionSpec(None)) ) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 76fea54c9f..3e3f9c6be9 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -145,7 +145,6 @@ void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_stream } } - CommOverlapCore::~CommOverlapCore() { cudaEventDestroy(_stop_comm); cudaEventDestroy(_start_comm); @@ -747,7 +746,6 @@ void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DTy NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); } - CommOverlapP2PBase::~CommOverlapP2PBase() { cudaEventDestroy(_stop_recv); cudaEventDestroy(_stop_send); diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index b823ac0671..77560a9482 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -26,18 +26,9 @@ namespace transformer_engine { */ bool ubuf_built_with_mpi(); -enum class CommOverlapType : int64_t { - NONE = 0, - RS = 1, - AG = 2 -}; +enum class CommOverlapType : int64_t { NONE = 0, RS = 1, AG = 2 }; -enum class CommOverlapMethod : int64_t { - NONE = 0, - BULK = 1, - PIPELINE = 2, - RING_EXCHANGE = 3 -}; +enum class CommOverlapMethod : int64_t { NONE = 0, BULK = 1, PIPELINE = 2, RING_EXCHANGE = 3 }; enum class CommOverlapAlgo : int64_t { NO_OVERLAP = 0, @@ -98,10 +89,9 @@ class CommOverlapCore { int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, bool atomic_gemm); - virtual ~CommOverlapCore(); - void* get_ubuf_dptr() { return _ubuf.dptr(); } + void *get_ubuf_dptr() { return _ubuf.dptr(); } void set_ubuf_scale_inv(float *scale_inv) { _ubuf_scale_inv = scale_inv; @@ -119,8 +109,6 @@ class CommOverlapCore { TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, const std::vector &shape); - - int get_tp_size() { return _tp_size; } bool is_atomic_gemm() { return _atomic_gemm; } diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 86c48557bb..8fe5fb2d66 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -167,13 +167,14 @@ class CommOverlapHelper: communication buffer, and generates lowering arguments and partitioning rules for the GemmPrimitive. """ + # Core init arguments comm_type: tex.CommOverlapType = field(default=tex.CommOverlapType.NONE) method: tex.CommOverlapMethod = field(default=tex.CommOverlapMethod.NONE) buffer_shape: Sequence[int] = field(default=None) buffer_dtype: jnp.dtype = field(default=jnp.bfloat16) tp_size: int = field( - default_factory=lambda:get_mesh_axis_size(global_mesh_resource().tp_resource) + default_factory=lambda: get_mesh_axis_size(global_mesh_resource().tp_resource) ) # Userbuffers bootstrap kwargs @@ -190,7 +191,7 @@ class CommOverlapHelper: aggregate_ag: bool = field(default=False, kw_only=True) # Other kwargs not passed to Userbuffers - tp_resource: str = field(default_factory=lambda:global_mesh_resource().tp_resource) + tp_resource: str = field(default_factory=lambda: global_mesh_resource().tp_resource) logical_tp_axis: str = field(default=W_TP_AXES, kw_only=True) logical_sp_axis: str = field(default=SEQLEN_TP_AXES, kw_only=True) output_all_gathered_lhs: bool = field(default=False, kw_only=True) @@ -225,9 +226,9 @@ def __post_init__(self): f"CommOverlapHelper: {self.method} is not a valid overlap method for " f"{self.comm_type}." ) - assert self.buffer_shape is not None and len(self.buffer_shape) >= 2, ( - f"CommOverlapHelper: {self.buffer_shape} is not a valid buffer shape." - ) + assert ( + self.buffer_shape is not None and len(self.buffer_shape) >= 2 + ), f"CommOverlapHelper: {self.buffer_shape} is not a valid buffer shape." assert self.tp_resource is not None, ( "CommOverlapHelper: Communication + GEMM overlap requires a valid TP resource. " "This must either be specified via the `tp_resource=` keyword, or " @@ -235,9 +236,9 @@ def __post_init__(self): "`te.sharding.global_shard_guard()` using a `te.sharding.MeshResource()` with a " "valid tensor-parallel mesh axis name." ) - assert self.tp_size % 2 == 0, ( - f"CommOverlapHelper: Tensor-parallel axis of {self.tp_size} is not divisible by 2." - ) + assert ( + self.tp_size % 2 == 0 + ), f"CommOverlapHelper: Tensor-parallel axis of {self.tp_size} is not divisible by 2." if not self.is_bulk() and not self.is_p2p(): # Pipelined overlap is only for reduce-scatter assert not self.is_all_gather(), ( @@ -248,14 +249,16 @@ def __post_init__(self): # Collapse buffer shape to 2D if len(self.buffer_shape) > 2: if self.flatten_axis < 0: - object.__setattr__(self, "flatten_axis", self.flatten_axis + len(self.buffer_shape)) + object.__setattr__( + self, "flatten_axis", self.flatten_axis + len(self.buffer_shape) + ) object.__setattr__( self, "buffer_shape", ( - reduce(operator.mul, self.buffer_shape[ : self.flatten_axis]), - reduce(operator.mul, self.buffer_shape[self.flatten_axis : ]) - ) + reduce(operator.mul, self.buffer_shape[: self.flatten_axis]), + reduce(operator.mul, self.buffer_shape[self.flatten_axis :]), + ), ) # Num splits for P2P overlap is always fixed to TP size @@ -309,9 +312,8 @@ def is_reduce_scatter(self): def has_aux_output(self): """Check if the comm+GEMM overlap has an auxiliary output.""" - return ( - self.is_enabled - and (self.is_bulk() or (self.is_all_gather() and self.output_all_gathered_lhs)) + return self.is_enabled and ( + self.is_bulk() or (self.is_all_gather() and self.output_all_gathered_lhs) ) def get_bootstrap_args_kwargs(self): @@ -324,17 +326,17 @@ def get_bootstrap_args_kwargs(self): self.tp_size, ) kwargs = { - "num_splits" : self.num_splits, - "num_max_streams" : self.num_max_streams, - "comm_cga_size" : self.comm_cga_size, - "gemm_priority" : self.gemm_priority, - "comm_priority" : self.comm_priority, - "num_comm_sm" : self.num_comm_sm, - "set_sm_margin" : self.set_sm_margin, - "use_ce" : self.use_ce, - "atomic_gemm" : self.atomic_gemm, - "rs_overlap_first_gemm" : self.rs_overlap_first_gemm, - "aggregate_ag" : self.aggregate_ag + "num_splits": self.num_splits, + "num_max_streams": self.num_max_streams, + "comm_cga_size": self.comm_cga_size, + "gemm_priority": self.gemm_priority, + "comm_priority": self.comm_priority, + "num_comm_sm": self.num_comm_sm, + "set_sm_margin": self.set_sm_margin, + "use_ce": self.use_ce, + "atomic_gemm": self.atomic_gemm, + "rs_overlap_first_gemm": self.rs_overlap_first_gemm, + "aggregate_ag": self.aggregate_ag, } return args, kwargs @@ -356,10 +358,10 @@ def get_lowering_kwargs(self): aux_axis_boundary = self.scatter_dim + 1 return { - "comm_overlap_id" : self.unique_id, - "comm_overlap_method" : int(self.method.value), - "comm_type" : int(self.comm_type.value), - "aux_axis_boundary" : aux_axis_boundary, + "comm_overlap_id": self.unique_id, + "comm_overlap_method": int(self.method.value), + "comm_type": int(self.comm_type.value), + "aux_axis_boundary": aux_axis_boundary, } @staticmethod @@ -520,7 +522,7 @@ def _get_no_overlap_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_nu return ( (lhs_specs, rhs_specs, bias_specs, gelu_specs, aux_in_specs), - (out_specs, bias_specs, gelu_specs, (None, )), + (out_specs, bias_specs, gelu_specs, (None,)), (all_reduce_spec, reduce_scatter_spec, scatter_dim), ) @@ -530,10 +532,10 @@ def _get_bulk_overlap_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_ f"over the tensor-parallel mesh resource '{self.tp_resource}' in any dimension." ) - aux_out_specs = (None, ) + aux_out_specs = (None,) bulk_comm_dim = aux_in_specs.index(self.tp_resource) - aux_in_specs_batch = aux_in_specs[ : bulk_comm_dim] - aux_in_specs_tail = aux_in_specs[bulk_comm_dim + 1: ] + aux_in_specs_batch = aux_in_specs[:bulk_comm_dim] + aux_in_specs_tail = aux_in_specs[bulk_comm_dim + 1 :] if self.is_all_gather(): assert all(spec is None for spec in aux_in_specs_tail), ( "CommOverlapHelper: Trailing dimensions of the auxiliary input for bulk all-gather " @@ -543,10 +545,10 @@ def _get_bulk_overlap_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_ aux_out_specs = ( *aux_in_specs_batch, None, # all-gathered dimension - *[None for _ in range(len(aux_in_specs_tail))] + *[None for _ in range(len(aux_in_specs_tail))], ) else: - assert all(spec is None for spec in aux_in_specs[bulk_comm_dim : ]), ( + assert all(spec is None for spec in aux_in_specs[bulk_comm_dim:]), ( "CommOverlapHelper: Non-batch dimensions of the auxiliary input for bulk " "reduce-scatter overlap cannot be sharded." ) @@ -568,7 +570,6 @@ def _get_bulk_overlap_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_ xla_reduce_info, ) - def _get_all_gather_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): contracting_dims, batch_dims = dimension_numbers lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) @@ -601,18 +602,16 @@ def _get_all_gather_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_nu lhs_non_cspecs_gathered = list( lhs_specs[i] if i in lhs_bdims else None for i in range(lhs_ndim) if i not in lhs_cdims ) - rhs_non_cspecs = tuple( - rhs_specs[i] for i in range(rhs_ndim) if i not in rhs_cdims - ) + rhs_non_cspecs = tuple(rhs_specs[i] for i in range(rhs_ndim) if i not in rhs_cdims) out_specs = (*lhs_non_cspecs_gathered, *rhs_non_cspecs) self._set_gather_dim(lhs_specs.index(lhs_lspec)) # Bias and Pre-GeLU sharding is based on GEMM output - bias_specs = out_specs[len(lhs_non_cspecs_gathered) : ] + bias_specs = out_specs[len(lhs_non_cspecs_gathered) :] gelu_specs = out_specs # Auxiliary input/output specs depend on bulk vs. non-bulk overlap - aux_out_specs = (None, ) + aux_out_specs = (None,) if self.output_all_gathered_lhs: # Auxiliary output is the same as the LHS spec, except the gathered dimension unsharded aux_out_specs = list(lhs_specs).copy() @@ -627,12 +626,8 @@ def _get_all_gather_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_nu def _get_reduce_scatter_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): contracting_dims, batch_dims = dimension_numbers lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) - lhs_cdims, rhs_cdims = map( - sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims - ) - lhs_bdims, rhs_bdims = map( - sanitize_dims, (lhs_ndim, rhs_ndim), batch_dims - ) + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims) + lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), batch_dims) (_, lhs_cspec), (_, rhs_cspec) = self._check_operand_specs( lhs_specs, rhs_specs, ((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)) @@ -659,27 +654,23 @@ def _get_reduce_scatter_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimensio lhs_bspecs = tuple( lhs_specs[i] for i in range(lhs_ndim) if i in lhs_bdims and i not in lhs_cdims ) - lhs_lspecs = tuple( - lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_bdims + lhs_cdims - ) - rhs_non_cspecs = tuple( - rhs_specs[i] for i in range(rhs_ndim) if i not in rhs_cdims - ) + lhs_lspecs = tuple(lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_bdims + lhs_cdims) + rhs_non_cspecs = tuple(rhs_specs[i] for i in range(rhs_ndim) if i not in rhs_cdims) out_specs = ( *lhs_bspecs, self.tp_resource, *[None for _ in range(len(lhs_lspecs) - 1)], - *rhs_non_cspecs + *rhs_non_cspecs, ) self._set_scatter_dim(out_specs.index(self.tp_resource)) # Bias and Pre-GeLU sharding is based on GEMM output - bias_specs = out_specs[len(lhs_bspecs) + len(lhs_lspecs) : ] + bias_specs = out_specs[len(lhs_bspecs) + len(lhs_lspecs) :] gelu_specs = out_specs return ( (lhs_specs, rhs_specs, bias_specs, gelu_specs, aux_in_specs), - (out_specs, bias_specs, gelu_specs, (None, )), + (out_specs, bias_specs, gelu_specs, (None,)), (None, None, None), ) @@ -694,9 +685,9 @@ def get_partitioning_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_n ) impl_map = { - tex.CommOverlapType.NONE : self._get_no_overlap_rules, - tex.CommOverlapType.AG : self._get_all_gather_rules, - tex.CommOverlapType.RS : self._get_reduce_scatter_rules, + tex.CommOverlapType.NONE: self._get_no_overlap_rules, + tex.CommOverlapType.AG: self._get_all_gather_rules, + tex.CommOverlapType.RS: self._get_reduce_scatter_rules, } return impl_map[self.comm_type](lhs_specs, rhs_specs, aux_in_specs, dimension_numbers) @@ -714,12 +705,8 @@ def get_logical_output_axes(self, lhs_axes, rhs_axes, dimension_numbers): contracting_dims, batch_dims = dimension_numbers lhs_ndim, rhs_ndim = map(len, (lhs_axes, rhs_axes)) - lhs_cdims, rhs_cdims = map( - sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims - ) - lhs_bdims, rhs_bdims = map( - sanitize_dims, (lhs_ndim, rhs_ndim), batch_dims - ) + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims) + lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), batch_dims) lhs_batch_axes = tuple( lhs_axes[i] for i in range(lhs_ndim) if i in lhs_bdims and i not in lhs_cdims @@ -727,9 +714,7 @@ def get_logical_output_axes(self, lhs_axes, rhs_axes, dimension_numbers): lhs_leading_axes = tuple( lhs_axes[i] for i in range(lhs_ndim) if i not in lhs_bdims + lhs_cdims ) - rhs_non_contracting_axes = tuple( - rhs_axes[i] for i in range(rhs_ndim) if i not in rhs_cdims - ) + rhs_non_contracting_axes = tuple(rhs_axes[i] for i in range(rhs_ndim) if i not in rhs_cdims) out_axes = (*lhs_batch_axes, *lhs_leading_axes, *rhs_non_contracting_axes) if self.is_enabled and not self.is_bulk(): @@ -750,7 +735,8 @@ def get_logical_output_axes(self, lhs_axes, rhs_axes, dimension_numbers): # Generate grad axes without any communication overlap lhs_specs = generate_pspec(lhs_axes) lhs_lspec = tuple( - lhs_specs[i] for i in range(lhs_ndim) + lhs_specs[i] + for i in range(lhs_ndim) if i not in lhs_bdims + lhs_cdims and lhs_specs[i] is not None ) lhs_lspec = None if len(lhs_lspec) == 0 else lhs_lspec[0] @@ -761,7 +747,8 @@ def get_logical_output_axes(self, lhs_axes, rhs_axes, dimension_numbers): rhs_specs = generate_pspec(rhs_axes) rhs_lspec = tuple( - rhs_specs[i] for i in range(rhs_ndim) + rhs_specs[i] + for i in range(rhs_ndim) if i not in rhs_bdims + rhs_cdims and rhs_specs[i] is not None ) rhs_lspec = None if len(rhs_lspec) == 0 else rhs_lspec[0] @@ -780,7 +767,7 @@ def get_logical_output_axes(self, lhs_axes, rhs_axes, dimension_numbers): out_axes = ( *lhs_batch_axes, *lhs_leading_axes, - *[None for _ in range(len(rhs_non_contracting_axes))] + *[None for _ in range(len(rhs_non_contracting_axes))], ) return out_axes @@ -792,6 +779,7 @@ class CommOverlapHelperSet: A set of CommOverlapHelper objects that provide complementary comm+GEMM overlap configurations for FPROP, DGRAD and WGRAD GEMMs in FWD/BWD passes through Dense-layers. """ + fprop: CommOverlapHelper = field(default=None) dgrad: CommOverlapHelper = field(default=None) wgrad: CommOverlapHelper = field(default=None) @@ -808,18 +796,17 @@ def _sanity_check(self): # If FPROP overlap is not defined or not enabled, require DGRAD and WGRAD to also not be # be defined or not enabled if self.fprop is None or not self.fprop.is_enabled: - assert ( - (self.dgrad is None or not self.dgrad.is_enabled) - and (self.wgrad is None or not self.wgrad.is_enabled) + assert (self.dgrad is None or not self.dgrad.is_enabled) and ( + self.wgrad is None or not self.wgrad.is_enabled ), ( "CommOverlapHelperSet: Cannot do communication overlap for DGRAD and/or WGRAD when " "there is no communication overlap for FPROP." ) return - assert not self.fprop.is_bulk(), ( - "CommOverlapHelperSet: Cannot overlap bulk collectives with FPROP." - ) + assert ( + not self.fprop.is_bulk() + ), "CommOverlapHelperSet: Cannot overlap bulk collectives with FPROP." if self.fprop.is_all_gather(): if self.dgrad is not None and self.dgrad.is_enabled: @@ -1037,15 +1024,15 @@ def abstract( output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) # Auxiliary output for comm+GEMM overlap - aux_out_shape = (0, ) + aux_out_shape = (0,) aux_out_dtype = jnp.bfloat16 if comm_overlap.is_enabled: if comm_overlap.is_bulk(): # Bulk overlap will all-gather or reduce-scatter the tensor in the auxiliary input # and return the result of the collective in the auxiliary output - assert aux_in.size > 0, ( - "cuBLAS GEMM w/ bulk collective overlap requires an auxiliary input." - ) + assert ( + aux_in.size > 0 + ), "cuBLAS GEMM w/ bulk collective overlap requires an auxiliary input." assert aux_in.ndim > 1, ( "cuBLAS GEMM w/ bulk collective overlap only supports multidimensional " "auxiliary inputs." @@ -1232,7 +1219,6 @@ def impl( grad, use_split_accumulator, comm_overlap, - ): lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), dimension_numbers[0]) lhs_transposed, rhs_transposed = _get_gemm_layout( @@ -1319,7 +1305,7 @@ def batcher( if fuse_gelu and not grad: pre_gelu_bdims = out_bdims - aux_out_bdims = (None, ) + aux_out_bdims = (None,) if comm_overlap.is_enabled: if comm_overlap.is_bulk(): # Bulk overlap auxiliary output must have the same batch dims as the auxiliary input @@ -1372,10 +1358,10 @@ def infer_sharding_from_operands( del use_split_accumulator, result_infos lhs_specs, _, rhs_specs, *_, aux_in_specs = map(get_padded_spec, arg_infos) - ( - _, (out_specs, bias_grad_specs, pre_gelu_specs, aux_out_specs), *_ - ) = comm_overlap.get_partitioning_rules( - lhs_specs, rhs_specs, aux_in_specs, dimension_numbers + (_, (out_specs, bias_grad_specs, pre_gelu_specs, aux_out_specs), *_) = ( + comm_overlap.get_partitioning_rules( + lhs_specs, rhs_specs, aux_in_specs, dimension_numbers + ) ) # Discard bias gradient and pre-GeLU output specs based on fusion choices @@ -1388,7 +1374,7 @@ def infer_sharding_from_operands( out_shardings = list( map( lambda specs: NamedSharding(mesh, PartitionSpec(*specs)), - (out_specs, bias_grad_specs, pre_gelu_specs, aux_out_specs) + (out_specs, bias_grad_specs, pre_gelu_specs, aux_out_specs), ) ) @@ -1422,8 +1408,8 @@ def partition( ) # Block scale inverses match their operands, but tensor scale inverses are unsharded. - lhs_scale_specs = (None, ) - rhs_scale_specs = (None, ) + lhs_scale_specs = (None,) + rhs_scale_specs = (None,) if scaling_mode.is_1d_block_scaling() and not comm_overlap.is_enabled: lhs_scale_specs = lhs_specs rhs_scale_specs = rhs_specs @@ -1447,7 +1433,7 @@ def partition( rhs_scale_specs, bias_specs, gelu_input_specs, - aux_in_specs + aux_in_specs, ), ) ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 1339d3c5dd..6432eb5f77 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -13,9 +13,9 @@ #include #include #include +#include #include #include -#include #include #include diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f0c86910e6..35f84af543 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -174,9 +174,10 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); - void* out_ptr = + void *out_ptr = (comm_type == CommOverlapType::RS && comm_overlap_method != CommOverlapMethod::BULK) - ? comm_overlaps[comm_overlap_id]->get_ubuf_dptr() : output->untyped_data(); + ? comm_overlaps[comm_overlap_id]->get_ubuf_dptr() + : output->untyped_data(); auto out_ = TensorWrapper(out_ptr, out_shape, out_dtype); // Bias input to forward pass or bias gradient output from backward pass @@ -220,13 +221,13 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); if (comm_type == CommOverlapType::NONE) { NVTE_CHECK(out_.numel() == output->element_count(), - "cuBLAS GEMM output buffer size is incorrect, expected ", - out_.numel(), " elements ", to_string_like(out_shape), " but got ", - output->element_count(), " elements ", to_string_like(output->dimensions())); + "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ", + to_string_like(out_shape), " but got ", output->element_count(), " elements ", + to_string_like(output->dimensions())); nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), - rhs_transposed, lhs_transposed, grad, workspace_.data(), false, - use_split_accumulator, num_math_sm, stream); + rhs_transposed, lhs_transposed, grad, workspace_.data(), false, + use_split_accumulator, num_math_sm, stream); } else { auto executor = comm_overlaps[comm_overlap_id]; auto tp_size = executor->get_tp_size(); @@ -235,8 +236,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto aux_out_dims = aux_out->dimensions(); std::vector aux_out_shape = {0}; auto aux_out_dtype = convert_ffi_datatype_to_te_dtype(aux_out->element_type()); - if ((comm_type == CommOverlapType::AG && aux_out->element_count() > 0) - || comm_type == CommOverlapType::RS) { + if ((comm_type == CommOverlapType::AG && aux_out->element_count() > 0) || + comm_type == CommOverlapType::RS) { std::vector aux_out_shape = { product(aux_out_dims, 0, aux_axis_boundary), product(aux_out_dims, aux_axis_boundary, aux_out_dims.size())}; @@ -280,8 +281,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i // Launch GEMM+RS executor->split_overlap_rs(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, - workspace_, grad, false, use_split_accumulator, rs_out_, - stream); + workspace_, grad, false, use_split_accumulator, rs_out_, stream); } else if (comm_type == CommOverlapType::AG) { // Prepare the auxiliary buffer for all-gathered LHS std::vector aux_out_shape = {0}; @@ -302,8 +302,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i // Launch AG+GEMM executor->split_overlap_ag(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, - workspace_, grad, false, use_split_accumulator, aux_out_, - stream); + workspace_, grad, false, use_split_accumulator, aux_out_, stream); } else { NVTE_ERROR("cuBLAS GEMM w/ comm. overlap invoked with invalid collective type (", static_cast(comm_type), ")"); diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index fc899a92de..29516a5e54 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -5,7 +5,6 @@ ************************************************************************/ #include "../extensions.h" - #include "common/util/pybind_helper.h" namespace transformer_engine { namespace jax { @@ -121,5 +120,3 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE) .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE); } - - diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index a3f533fbef..a1e149abe8 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -61,15 +61,31 @@ def dense( output += jnp.reshape(bias, bias_new_shape) else: output = _dense( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, comm_overlaps, - quantizer_set + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + batch_first, + comm_overlaps, + quantizer_set, ) return output @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7)) -def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, comm_overlaps, - quantizer_set): +def _dense( + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + batch_first, + comm_overlaps, + quantizer_set, +): """Internal implementation of dense layer transformation with custom VJP. This function implements the core dense layer transformation logic with support @@ -90,15 +106,29 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_fir Transformed output tensor """ output, _ = _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, comm_overlaps, - quantizer_set + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + batch_first, + comm_overlaps, + quantizer_set, ) return output def _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, comm_overlaps, - quantizer_set + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + batch_first, + comm_overlaps, + quantizer_set, ): """Forward pass rule for dense layer transformation. @@ -204,8 +234,8 @@ def _dense_bwd_rule( casted_grad = with_sharding_constraint_by_logical_axes( casted_grad, comm_overlaps.fprop.get_logical_output_axes( - input_axes, kernel_axes, (contracting_dims, ((x_bdim, ), ())) - ) + input_axes, kernel_axes, (contracting_dims, ((x_bdim,), ())) + ), ) # If casted_x has transposed data-layout, we need to untranspose it here, and then transpose @@ -259,8 +289,10 @@ def _dense_bwd_rule( elif comm_overlaps.dgrad.is_all_gather() and comm_overlaps.dgrad.output_all_gathered_lhs: # GRAD was all-gathered for DGRAD and a copy of the gathered GRAD is in the auxiliary output casted_grad_rhs.data = ( - dgrad[-1].transpose(*range(casted_grad_rhs.flatten_axis, casted_grad_rhs.ndim), - *range(casted_grad_rhs.flatten_axis)) + dgrad[-1].transpose( + *range(casted_grad_rhs.flatten_axis, casted_grad_rhs.ndim), + *range(casted_grad_rhs.flatten_axis), + ) if casted_grad_rhs.data_layout == "T" else dgrad[-1] ) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 469bfc60ea..4aa0c75c25 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -167,9 +167,7 @@ def _generate_comm_overlap_meta( if method == tex.CommOverlapMethod.NONE: return CommOverlapHelperSet() - tp_resource = config.pop( - "tp_resource", global_mesh_resource().tp_resource - ) + tp_resource = config.pop("tp_resource", global_mesh_resource().tp_resource) input_sp_dim = list(nn.logical_to_mesh_axes(input_axes)).index(tp_resource) logical_sp_axis = config.pop("logical_sp_axis", input_axes[input_sp_dim]) @@ -182,8 +180,7 @@ def _generate_comm_overlap_meta( _ = config.pop("comm_type") buffer_shape = config.pop( - "buffer_shape", - (*input_shape[:-1], param_shape[-1]) if row_parallel else input_shape + "buffer_shape", (*input_shape[:-1], param_shape[-1]) if row_parallel else input_shape ) return CommOverlapHelperSet( @@ -198,6 +195,7 @@ def _generate_comm_overlap_meta( ) ) + class Softmax(nn.Module): # pylint: disable=too-few-public-methods r""" Applies softmax over a mini-batch of inputs. @@ -221,7 +219,9 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods softmax_type: SoftmaxType = SoftmaxType.SCALED @nn.compact - def __call__(self, inputs: jnp.ndarray, mask: jnp.ndarray = None, bias: jnp.ndarray = None) -> jnp.ndarray: + def __call__( + self, inputs: jnp.ndarray, mask: jnp.ndarray = None, bias: jnp.ndarray = None + ) -> jnp.ndarray: batch = inputs.shape[0] heads = inputs.shape[1] q_seqlen = inputs.shape[2] @@ -562,7 +562,7 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: contract_ind = tuple(range(0, len(axis))) if not self.enable_comm_overlap: - self.comm_overlap_config.update({"method" : tex.CommOverlapMethod.NONE}) + self.comm_overlap_config.update({"method": tex.CommOverlapMethod.NONE}) y = dense( inputs, kernel, @@ -577,7 +577,7 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: self.kernel_axes, self.comm_overlap_method, ), - batch_first=not self.transpose_batch_sequence + batch_first=not self.transpose_batch_sequence, ) if self.enable_low_rank_adaptation: @@ -846,7 +846,7 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: # All-Gather is the only supported collective to overlap in LayerNormDenseGeneral if not self.enable_comm_overlap: - self.comm_overlap_config.update({"method" : tex.CommOverlapMethod.NONE}) + self.comm_overlap_config.update({"method": tex.CommOverlapMethod.NONE}) comm_overlaps = _generate_comm_overlap_meta( inputs.shape, self.layernorm_input_axes, @@ -868,7 +868,7 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: kernel_axes=self.kernel_axes, batch_first=not self.transpose_batch_sequence, quantizer_set=quantizer_set, - comm_overlaps=comm_overlaps + comm_overlaps=comm_overlaps, ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) @@ -1086,8 +1086,12 @@ class LayerNormMLP(TransformerEngineBase): enable_comm_overlap: bool = False enable_dot_1_comm_overlap: bool = False enable_dot_2_comm_overlap: bool = False - dot_1_comm_overlap_config: dict = field(default_factory=dict) # pylint: disable=invalid-field-call - dot_2_comm_overlap_config: dict = field(default_factory=dict) # pylint: disable=invalid-field-call + dot_1_comm_overlap_config: dict = field( + default_factory=dict + ) # pylint: disable=invalid-field-call + dot_2_comm_overlap_config: dict = field( + default_factory=dict + ) # pylint: disable=invalid-field-call def __post_init__(self): if self.kernel_init is None: @@ -1278,9 +1282,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name = "ffn1" ffn2_ckpt_name = "ffn2" - if not self.enable_dot_1_comm_overlap: - self.dot_1_comm_overlap_config.update({"method" : tex.CommOverlapMethod.NONE}) + self.dot_1_comm_overlap_config.update({"method": tex.CommOverlapMethod.NONE}) ffn1_comm_overlaps = _generate_comm_overlap_meta( inputs.shape, self.layernorm_input_axes, @@ -1290,7 +1293,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ) if not self.enable_dot_2_comm_overlap: - self.dot_2_comm_overlap_config.update({"method" : tex.CommOverlapMethod.NONE}) + self.dot_2_comm_overlap_config.update({"method": tex.CommOverlapMethod.NONE}) ffn2_comm_overlaps = _generate_comm_overlap_meta( inputs.shape, self.dot_2_input_axes, diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index ca7153b346..09bb0cfb9a 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -239,7 +239,7 @@ def _layernorm_dense_fwd_rule( casted_ln_out_for_bwd = casted_ln_out.get_tensor(TensorUsage.LHS_TRANS) ln_out_transposed_dims = ( *tuple(range(casted_ln_out_for_bwd.flatten_axis, casted_ln_out_for_bwd.ndim)), - *tuple(range(casted_ln_out_for_bwd.flatten_axis)) + *tuple(range(casted_ln_out_for_bwd.flatten_axis)), ) if comm_overlaps.fprop.output_all_gathered_lhs: casted_ln_out_for_bwd.data = ( @@ -327,8 +327,8 @@ def _layernorm_dense_bwd_rule( comm_overlaps.fprop.get_logical_output_axes( dot_input_axes, kernel_axes, - ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim, ), ())) - ) + ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim,), ())), + ), ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 737e72be94..84df3e29f1 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -316,8 +316,8 @@ def _layernorm_mlp_fwd_rule( ffn1_comm_overlaps.fprop.get_logical_output_axes( dot_1_input_axes, kernel_1_axes, - ((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())) - ) + ((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), + ), ) if use_bias_1 and tex.gemm_uses_jax_dot(): @@ -447,8 +447,8 @@ def _layernorm_mlp_bwd_rule( ffn2_comm_overlaps.fprop.get_logical_output_axes( dot_2_input_axes, kernel_2_axes, - ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim, ), ())) - ) + ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim,), ())), + ), ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -482,8 +482,10 @@ def _layernorm_mlp_bwd_rule( casted_grad_rhs = casted_grad.get_tensor(usage=TensorUsage.RHS) if ffn2_comm_overlaps.dgrad.is_enabled: casted_grad_rhs.data = ( - dgrad_2[-1].transpose(*range(casted_grad_rhs.flatten_axis, casted_grad_rhs.ndim), - *range(casted_grad_rhs.flatten_axis)) + dgrad_2[-1].transpose( + *range(casted_grad_rhs.flatten_axis, casted_grad_rhs.ndim), + *range(casted_grad_rhs.flatten_axis) + ) if casted_grad_rhs.data_layout == "T" else dgrad_2[-1] ) @@ -512,8 +514,8 @@ def _layernorm_mlp_bwd_rule( ffn1_comm_overlaps.fprop.get_logical_output_axes( dot_1_input_axes, kernel_1_axes, - ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim, ), ())) - ) + ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim,), ())), + ), ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -534,7 +536,7 @@ def _layernorm_mlp_bwd_rule( dgrad_1_aux_in = None ln_out_transposed_dims = ( *tuple(range(casted_ln_out.flatten_axis, casted_ln_out.ndim)), - *tuple(range(casted_ln_out.flatten_axis)) + *tuple(range(casted_ln_out.flatten_axis)), ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) if ffn1_comm_overlaps.dgrad.is_bulk() and not ffn1_comm_overlaps.fprop.output_all_gathered_lhs: