diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index a21d5ecb57..2a1ac0f8fa 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -30,7 +30,7 @@ for TEST_CASE in "${TEST_CASES[@]}"; do LOG_FILE="${TEST_CASE}_gpu_${i}.log" # Run pytest and redirect stdout and stderr to the log file - pytest -c "$TE_PATH/tests/jax/pytest.ini" \ + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ --num-process=$NUM_GPUS \ --process-id=$i > "$LOG_FILE" 2>&1 & diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 1f45d10faf..7d5fefddaa 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -25,6 +25,7 @@ assert_params_sufficiently_sharded, ) import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_fp8_available, ScalingMode @@ -465,14 +466,14 @@ class TestEncoder(unittest.TestCase): is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) def setUp(self): - """Run 3 epochs for testing""" - self.args = encoder_parser(["--epochs", "3"]) + """Run 5 epochs for testing""" + self.args = encoder_parser(["--epochs", "5"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -480,7 +481,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -488,14 +489,14 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_with_sp(self): """Test Transformer Engine with BF16 + SP""" self.args.enable_sp = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp(self): @@ -504,7 +505,7 @@ def test_te_delayed_scaling_fp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8_with_sp(self): @@ -513,14 +514,14 @@ def test_te_mxfp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 
0.43 and actual[1] > 0.80 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_shardy(self): @@ -529,7 +530,7 @@ def test_te_delayed_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.80 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp_shardy(self): @@ -539,9 +540,32 @@ def test_te_delayed_scaling_fp8_with_sp_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.785 + assert actual[0] < 0.43 and actual[1] > 0.80 - # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. + @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) + @unittest.skipIf( + tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." + ) + def test_te_mxfp8_shardy(self): + """Test Transformer Engine with MXFP8""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "MXFP8BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.43 and actual[1] > 0.80 + + @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) + @unittest.skipIf( + tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." 
+ ) + def test_te_mxfp8_with_sp_shardy(self): + """Test Transformer Engine with MXFP8 + SP""" + self.args.enable_shardy = True + self.args.enable_sp = True + self.args.use_fp8 = True + self.args.fp8_recipe = "MXFP8BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.43 and actual[1] > 0.80 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 12148b0e29..40431b8cdd 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -21,6 +21,7 @@ from common import is_bf16_supported, get_fp8_recipe_from_name_string import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_fp8_available, ScalingMode @@ -430,14 +431,14 @@ class TestEncoder(unittest.TestCase): is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) def setUp(self): - """Run 3 epochs for testing""" - self.args = encoder_parser(["--epochs", "3"]) + """Run 5 epochs for testing""" + self.args = encoder_parser(["--epochs", "6"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.50 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -445,7 +446,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.50 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8(self): @@ -453,7 +454,7 @@ def test_te_current_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.50 and actual[1] > 0.75 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -461,14 +462,14 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.50 and actual[1] > 0.75 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.50 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_shardy(self): @@ -477,9 +478,7 @@ def test_te_delayed_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 - - # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. 
+ assert actual[0] < 0.50 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8_shardy(self): @@ -488,7 +487,19 @@ def test_te_current_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.50 and actual[1] > 0.75 + + @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) + @unittest.skipIf( + tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." + ) + def test_te_mxfp8_shardy(self): + """Test Transformer Engine with MXFP8""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "MXFP8BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.50 and actual[1] > 0.75 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 580824cefa..c1da4db4a9 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -28,8 +28,8 @@ get_fp8_recipe_from_name_string, ) import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -584,8 +584,8 @@ class TestEncoder(unittest.TestCase): """Encoder unittests""" def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): - """Run 3 epochs for testing""" - args = encoder_parser([]) + """Run 5 epochs for testing""" + args = encoder_parser(["--epochs", "5"]) num_gpu = self.num_process tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1 @@ -607,7 +607,7 @@ def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): def test_te_bf16(self): """Test Transformer Engine with BF16""" result = self.exec(False, None) - assert result[0] < 0.505 and result[1] > 0.755 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" @@ -615,7 +615,7 @@ def test_te_bf16(self): def test_te_delayed_scaling_fp8(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling") - assert result[0] < 0.506 and result[1] > 0.753 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" @@ -623,7 +623,7 @@ def test_te_delayed_scaling_fp8(self): def test_te_current_scaling_fp8(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling") - assert result[0] < 0.507 and result[1] > 0.753 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" @@ -631,13 +631,13 @@ def test_te_current_scaling_fp8(self): def test_te_mxfp8(self): """Test Transformer Engine with MXFP8""" result = self.exec(True, "MXFP8BlockScaling") - assert result[0] < 0.505 and result[1] > 0.754 + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" result = self.exec(False, None, enable_shardy=True) - assert result[0] < 0.505 and result[1] > 0.755 + assert 
result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" @@ -645,9 +645,7 @@ def test_te_bf16_shardy(self): def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling", enable_shardy=True) - assert result[0] < 0.506 and result[1] > 0.753 - - # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. + assert result[0] < 0.43 and result[1] > 0.80 @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" @@ -655,7 +653,18 @@ def test_te_delayed_scaling_fp8_shardy(self): def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling", enable_shardy=True) - assert result[0] < 0.507 and result[1] > 0.753 + assert result[0] < 0.43 and result[1] > 0.80 + + @unittest.skipIf( + not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" + ) + @unittest.skipIf( + tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." + ) + def test_te_mxfp8_shardy(self): + """Test Transformer Engine with MXFP8""" + result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True) + assert result[0] < 0.43 and result[1] > 0.80 if __name__ == "__main__": diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 22bd4bfc25..d6a17ac372 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -13,6 +13,7 @@ from utils import ( assert_allclose, pytest_parametrize_wrapper, + use_jax_gemm, ) from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm_mlp import layernorm_mlp @@ -30,7 +31,6 @@ from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax.quantize import ( - DelayedScaleQuantizer, ScaledTensor, ScaledTensor1x, ScaledTensor2x, @@ -851,6 +851,22 @@ def test_quantize_dact_dbias_mxfp8_scaling( ) +valid_fp8_gemm_operand_types = [ + (jnp.float8_e4m3fn, jnp.float8_e4m3fn), + (jnp.float8_e5m2, jnp.float8_e4m3fn), + (jnp.float8_e4m3fn, jnp.float8_e5m2), +] + + +def _use_jax_fp8_gemm(enabled=False): + import os + + if enabled: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") + + class TestDense: def _ref_gemm_with_jnp_dot(self, a, b, data_layout): if data_layout[0] == "T": @@ -883,27 +899,47 @@ def _generate_gemm_input(self, m, n, k, data_layout): def test_gemm_bf16(self, m, n, k, data_layout): x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) - primitive_out = tex.gemm(x, w, contracting_dims) + primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) + @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) - def 
test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm): + if ( + not with_jax_gemm + and scaling_mode.is_1d_block_scaling() + and jnp.float8_e5m2 in (x_qtype, w_qtype) + ): + pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.") + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False - ) - primitive_out = tex.gemm( - x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set + scaling_mode=scaling_mode, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, + is_2x2x=False, ) + with use_jax_gemm(enabled=with_jax_gemm): + primitive_out = tex.gemm( + x, + w, + contracting_dims=contracting_dims, + lhs_quantizer=( + quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad + ), + rhs_quantizer=( + quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad + ), + ) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) - assert_allclose(primitive_out, ref_out, dtype=q_dtype) + assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) def test_dense_grad_bf16(self, m, n, k): @@ -932,9 +968,9 @@ def ref_func(x, w, data_layout): @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm): data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -956,23 +992,27 @@ def ref_func(x, w, bias, data_layout): value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True + scaling_mode=scaling_mode, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, + is_2x2x=True, ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 - for _ in range(n_iterations): - primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( - value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) - ) + with use_jax_gemm(enabled=with_jax_gemm): + for _ in range(n_iterations): + primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( + value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) + ) ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func( x, w, bias, data_layout ) - assert_allclose(primitive_out, ref_out, dtype=q_dtype) - assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype) - assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype) - assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype) + assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn) + assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) + assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2) + assert_allclose(primitive_bias_grad, 
ref_bias_grad, dtype=jnp.float8_e5m2) @pytest.fixture(name="random_inputs") @@ -996,20 +1036,13 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan class TestFusedDense: @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) - @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) - def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm): """ Test layernorm_dense VJP Rule """ - # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode in ( - ScalingMode.DELAYED_TENSOR_SCALING, - ScalingMode.CURRENT_TENSOR_SCALING, - ): - pytest.skip("E5M2 is not supported in normalization with TE Backend!") - # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False eps = 1e-6 @@ -1025,8 +1058,8 @@ def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): quantizer_set = QuantizerFactory.create_set( scaling_mode=scaling_mode, - fwd_dtype=q_dtype, - bwd_dtype=q_dtype, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, is_2x2x=True, ) @@ -1064,41 +1097,35 @@ def ref_func(x, w, gamma, beta): ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 - for _ in range(n_iterations): - prim_out, ( - prim_x_grad, - prim_w_grad, - prim_gamma_grad, - prim_beta_grad, - ) = value_n_grad_prim_func(x, w, gamma, beta) - - assert_allclose(prim_out, ref_out, dtype=q_dtype) - assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) - assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype) - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype) + with use_jax_gemm(enabled=with_jax_gemm): + for _ in range(n_iterations): + prim_out, ( + prim_x_grad, + prim_w_grad, + prim_gamma_grad, + prim_beta_grad, + ) = value_n_grad_prim_func(x, w, gamma, beta) + + assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) + assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2) if beta is not None: - assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype) + assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) - @pytest.mark.parametrize("use_bias", [True, False]) + @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( - self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias + self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm ): """ Test layernorm_mlp VJP Rule """ - # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode in ( - 
ScalingMode.DELAYED_TENSOR_SCALING, - ScalingMode.CURRENT_TENSOR_SCALING, - ): - pytest.skip("E5M2 is not supported in normalization with TE Backend!") - # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False eps = 1e-6 @@ -1123,8 +1150,8 @@ def test_layernorm_mlp_grad( quantizer_sets = QuantizerFactory.create_set( n_quantizer_sets=2, scaling_mode=scaling_mode, - fwd_dtype=q_dtype, - bwd_dtype=q_dtype, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, is_2x2x=True, ) @@ -1153,14 +1180,13 @@ def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ln_out = _ref_jax_norm_impl( x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None ) - # TODO: replace gemm with jnp.dot - linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,))) + linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ()))) if use_bias: bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape linear_1_out += jnp.reshape(bias_1, bias_1_shape) x = _jax_act_lu(linear_1_out, activation_type) - linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,))) + linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ()))) if use_bias: bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape linear_2_out += jnp.reshape(bias_2, bias_2_shape) @@ -1174,15 +1200,16 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): value_n_grad_ref_func = value_and_grad(ref_func, range(6)) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 - for _ in range(n_iterations): - prim_out, ( - prim_x_grad, - prim_gamma_grad, - prim_kernel_1_grad, - prim_kernel_2_grad, - prim_bias_1_grad, - prim_bias_2_grad, - ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) + with use_jax_gemm(enabled=with_jax_gemm): + for _ in range(n_iterations): + prim_out, ( + prim_x_grad, + prim_gamma_grad, + prim_kernel_1_grad, + prim_kernel_2_grad, + prim_bias_1_grad, + prim_bias_2_grad, + ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) ref_out, ( ref_x_grad, @@ -1193,18 +1220,18 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ref_bias_2_grad, ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) - assert_allclose(prim_out, ref_out, dtype=q_dtype) + assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) - assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype) + assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2) if use_bias: - assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype) + assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2) - assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype) + assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2) if use_bias: - assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype) + assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2) - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype) - assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) # E5M2 * E5M2 is not supported diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 19037409f6..bfe558f25e 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ 
b/tests/jax/test_distributed_layernorm_mlp.py @@ -13,6 +13,7 @@ assert_tree_like_allclose, is_devices_enough, pytest_parametrize_wrapper, + use_jax_gemm, ) from transformer_engine.common import recipe @@ -147,7 +148,15 @@ def layernorm_fp8_mlp_prim_func( ) def _test_layernorm_mlp_grad( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + use_shardy, + with_jax_gemm, ): jax.config.update("jax_use_shardy_partitioner", use_shardy) device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config @@ -157,72 +166,83 @@ def _test_layernorm_mlp_grad( input_shape, activation_type, use_bias, dtype ) static_inputs = [layernorm_type, activation_type] - value_and_grad_func = jax.value_and_grad( - self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs)) - ) - # Single GPU - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - single_jitter = jax.jit( - value_and_grad_func, - static_argnums=range(len(inputs), len(static_inputs) + len(inputs)), - ) - single_fwd, single_grads = single_jitter(*inputs, *static_inputs) - - # Multi GPUs - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): - k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp")) - k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) - k1_ = jax.device_put(k1, k1_sharding) - k2_ = jax.device_put(k2, k2_sharding) - if use_bias: - b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp")) - b1_ = jax.device_put(b1, b1_sharding) - else: - b1_sharding = b1_ = None - multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]] - - # Position ref for sharding pspec lists - # x, gamma, k1, k2, b1, - # b2 - in_shardings = ( - None, - None, - k1_sharding, - k2_sharding, - b1_sharding, - None, - ) - out_shardings = ( - None, - (None, None, k1_sharding, k2_sharding, b1_sharding, None), + with use_jax_gemm(enabled=with_jax_gemm): + value_and_grad_func = jax.value_and_grad( + self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs)) ) - multi_jitter = jax.jit( - value_and_grad_func, - in_shardings=in_shardings, - out_shardings=out_shardings, - static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1), - ) # +1 for multi_gpus - - multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) - - assert_allclose(multi_fwd, single_fwd, dtype=dtype) + # Single GPU + with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + single_jitter = jax.jit( + value_and_grad_func, + static_argnums=range(len(inputs), len(static_inputs) + len(inputs)), + ) + single_fwd, single_grads = single_jitter(*inputs, *static_inputs) + + # Multi GPUs + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + with mesh, fp8_autocast( + enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + ): + k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp")) + k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) + k1_ = jax.device_put(k1, k1_sharding) + k2_ = jax.device_put(k2, k2_sharding) + if use_bias: + b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp")) + b1_ = jax.device_put(b1, b1_sharding) + else: + b1_sharding = b1_ = None + multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]] + + # Position ref for sharding pspec lists + # x, gamma, k1, k2, 
b1, + # b2 + in_shardings = ( + None, + None, + k1_sharding, + k2_sharding, + b1_sharding, + None, + ) + out_shardings = ( + None, + (None, None, k1_sharding, k2_sharding, b1_sharding, None), + ) + + multi_jitter = jax.jit( + value_and_grad_func, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=range( + len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1 + ), + ) # +1 for multi_gpus + + multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) + + fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn + bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2 + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) for i in range(len(inputs)): if multi_grads[i] is not None: if isinstance(multi_grads[i], list): assert isinstance(single_grads[i], list) for m_grad, s_grad in zip(multi_grads[i], single_grads[i]): assert_allclose( - m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close" + m_grad, + s_grad, + dtype=bwd_test_type, + err_msg=f"multi_grads[{i}] is not close", ) else: assert_allclose( multi_grads[i], single_grads[i], - dtype=dtype, + dtype=bwd_test_type, err_msg=f"multi_grads[{i}] is not close", ) @@ -233,8 +253,16 @@ def _test_layernorm_mlp_grad( @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + with_jax_gemm, ): self._test_layernorm_mlp_grad( mesh_config, @@ -244,6 +272,7 @@ def test_layernorm_mlp_grad( dtype, fp8_recipe, use_shardy=False, + with_jax_gemm=with_jax_gemm, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -252,19 +281,29 @@ def test_layernorm_mlp_grad( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + with_jax_gemm, ): - # We don't test block scaling with Shardy because at the time of writing, - # it is not supported in JAX's scaled_matmul_stablehlo. 
+ if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") self._test_layernorm_mlp_grad( mesh_config, activation_type, use_bias, input_shape, dtype, - fp8_recipe=recipe.DelayedScaling(), + fp8_recipe=fp8_recipe, use_shardy=True, + with_jax_gemm=with_jax_gemm, ) def _test_layernorm_mlp( @@ -277,6 +316,7 @@ def _test_layernorm_mlp( use_fp8, fp8_recipe, use_shardy, + with_jax_gemm, ): jax.config.update("jax_use_shardy_partitioner", use_shardy) batch, seqlen, hidden_in = input_shape @@ -288,48 +328,49 @@ def _test_layernorm_mlp( x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) init_rngs = {"params": subkeys[1]} - # Single GPUs - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): - ln_mlp_single = LayerNormMLP( - layernorm_type=layernorm_type, - transpose_batch_sequence=False, # input: [batch, seqlen, hidden] - intermediate_dim=INTERMEDIATE, - activations=activation_type, - use_bias=use_bias, - ) - params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) - mlp_out_single, ln_out_single = ln_mlp_single.apply( - params_single, x, deterministic=True - ) - - # Multi GPUs - device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast( - enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource - ): - ln_mlp_sharded = LayerNormMLP( - layernorm_type=layernorm_type, - transpose_batch_sequence=False, - intermediate_dim=INTERMEDIATE, - activations=activation_type, - scale_axes=LN_SCALE_AXES, - ln_bias_axes=LN_BIAS_AXES, - kernel_axes_1=KERNEL_1_AXES, - kernel_axes_2=KERNEL_2_AXES, - use_bias=use_bias, - bias_axes_1=BIAS_1_AXES, - bias_axes_2=BIAS_2_AXES, - layernorm_input_axes=LAYERNORM_INPUT_AXES, - dot_1_input_axes=DOT_1_INPUT_AXES, - dot_2_input_axes=DOT_2_INPUT_AXES, - name="mlp", - ) - params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True) - mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( - params_sharded, x, deterministic=True - ) + with use_jax_gemm(enabled=with_jax_gemm): + # Single GPUs + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + ln_mlp_single = LayerNormMLP( + layernorm_type=layernorm_type, + transpose_batch_sequence=False, # input: [batch, seqlen, hidden] + intermediate_dim=INTERMEDIATE, + activations=activation_type, + use_bias=use_bias, + ) + params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) + mlp_out_single, ln_out_single = ln_mlp_single.apply( + params_single, x, deterministic=True + ) + + # Multi GPUs + device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + with mesh, fp8_autocast( + enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + ): + ln_mlp_sharded = LayerNormMLP( + layernorm_type=layernorm_type, + transpose_batch_sequence=False, + intermediate_dim=INTERMEDIATE, + activations=activation_type, + scale_axes=LN_SCALE_AXES, + ln_bias_axes=LN_BIAS_AXES, + kernel_axes_1=KERNEL_1_AXES, + kernel_axes_2=KERNEL_2_AXES, + use_bias=use_bias, + bias_axes_1=BIAS_1_AXES, + bias_axes_2=BIAS_2_AXES, + layernorm_input_axes=LAYERNORM_INPUT_AXES, + dot_1_input_axes=DOT_1_INPUT_AXES, + dot_2_input_axes=DOT_2_INPUT_AXES, + name="mlp", + ) + params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True) + 
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( + params_sharded, x, deterministic=True + ) # Make sure params values are the same assert_tree_like_allclose(params_sharded["params"], params_single["params"]) @@ -355,9 +396,9 @@ def _test_layernorm_mlp( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("use_shardy", [False, True]) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer( - self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy + self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm ): self._test_layernorm_mlp( mesh_config, @@ -367,7 +408,8 @@ def test_layernorm_mlp_layer( dtype, use_fp8=False, fp8_recipe=None, - use_shardy=use_shardy, + use_shardy=False, + with_jax_gemm=with_jax_gemm, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -377,8 +419,9 @@ def test_layernorm_mlp_layer( @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm ): self._test_layernorm_mlp( mesh_config, @@ -389,4 +432,51 @@ def test_layernorm_mlp_layer_fp8( use_fp8=True, fp8_recipe=fp8_recipe, use_shardy=False, + with_jax_gemm=with_jax_gemm, + ) + + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_layernorm_mlp_layer_shardy( + self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm + ): + self._test_layernorm_mlp( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8=False, + fp8_recipe=None, + use_shardy=True, + with_jax_gemm=with_jax_gemm, + ) + + @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) + @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_layernorm_mlp_layer_fp8_shardy( + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm + ): + if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + self._test_layernorm_mlp( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8=True, + fp8_recipe=fp8_recipe, + use_shardy=True, + with_jax_gemm=with_jax_gemm, ) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 39eb368375..13b2b9148f 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -3,11 +3,12 @@ # See LICENSE 
for license information. """Utility for the TE layer tests""" +import os import functools import math import operator -from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional -import os +from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional, NewType +from contextlib import contextmanager import jax import jax.numpy as jnp @@ -20,7 +21,6 @@ import pytest from transformer_engine.jax.attention import ( - AttnMaskType, canonicalize_attn_mask_type, make_swa_mask, ) @@ -28,8 +28,8 @@ PRNGKey = Any Shape = Tuple[int, ...] -DType = jnp.dtype -Array = Any +DType = NewType("DType", jnp.dtype) +Array = NewType("Array", jnp.ndarray) PrecisionLike = Union[ None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] ] @@ -1519,7 +1519,7 @@ def dtype_tols( TEDType.kFloat8E5M2: jnp.float8_e5m2, }[dtype] elif isinstance(dtype, np.dtype): - dtype = jnp.dtype(dtype) + dtype = DType(dtype) # Expect bit-wise accuracy for integer dtypes if not jnp.issubdtype(dtype, jnp.floating): @@ -1600,3 +1600,20 @@ def print_debug_tensor_stats(prefix, tensor, hist=False): fmt = fmt + "\n {}\n {}" jax.debug.print(fmt, *args) + + +@contextmanager +def use_jax_gemm(enabled=False): + orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None) + + try: + if enabled: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + yield + + finally: + if enabled: + if orig_custom_calls_filter is None: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") + else: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index ce66bba3cf..57133f48aa 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -415,37 +415,35 @@ def shardy_sharding_rule( result_types, ): del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types - + prefix = "ActLuPrimitive_" x_rank = len(value_types[0].shape) scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - x_rank - 1, unique_var="ActLuPrimitive_i", flatten_axis=-2 + x_rank - 1, unique_var=prefix + "x", flatten_axis=-2 ) - x_axes = scale_rules.input_spec + (f"x{x_rank-1}",) + x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",) out = (*x_axes[:-2], x_axes[-1]) scale_inv = scale_rules.rowwise_rule - colwise_scale_inv = scale_rules.colwise_rule + colwise_out = (prefix + "out_colwise",) + colwise_scale_inv = (prefix + "scale_inv_colwise",) if is_2x: + colwise_scale_inv = scale_rules.colwise_rule if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out = tuple( multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2) ) else: colwise_out = out - else: - colwise_out = ("j",) - colwise_scale_inv = ("k",) # amax is always a unit tensor. 
- amax = ("l",) + amax = (prefix + "amax",) return SdyShardingRule( ( x_axes, - "…1", + ("…1",), ), (out, colwise_out, scale_inv, colwise_scale_inv, amax), - **scale_rules.factor_sizes, ) @@ -890,28 +888,26 @@ def shardy_sharding_rule( result_types, ): del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types - - x_rank = len(value_types[1].shape) + prefix = "BaseDActLuDBiasQuantizePrimitive_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - x_rank, unique_var="BaseDActLuDBiasQuantizePrimitive_i", flatten_axis=-2 + len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2 ) x_axes = scale_rules.input_spec + dz_axes = (*x_axes[:-2], x_axes[-1]) out = x_axes + colwise_out = (prefix + "out_colwise",) if is_2x: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) else: - colwise_out = tuple(x_axes) - else: - colwise_out = ("j",) + colwise_out = out - dbias = x_axes[-2:] if is_dbias else ("k",) - amax = ("…4",) + dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) + amax = (prefix + "amax",) return SdyShardingRule( - (("…0",), tuple(x_axes), ("…2",)), + (dz_axes, x_axes, ("…2",)), (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), - **scale_rules.factor_sizes, ) @@ -985,6 +981,7 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + noop_scaled_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -993,6 +990,7 @@ def act_lu( Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: If quantizer is None: @@ -1037,6 +1035,10 @@ def act_lu( is_outer=True, ) out = out.reshape(output_shape) + if noop_scaled_tensor: + return ScaledTensorFactory.create_2x( + out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype + ) return out if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -1090,6 +1092,7 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1100,6 +1103,7 @@ def quantize_dact_dbias( activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). is_dbias: If True, compute bias gradient. Defaults to True. quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. 
Returns: Tuple[ScaledTensor, jnp.ndarray]: A tuple containing: @@ -1113,13 +1117,49 @@ def quantize_dact_dbias( f" {x.shape} and act_len {act_len}" ) + scale = jnp.empty((), jnp.float32) + act_type_id = ActivationEnum[activation_type] PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive - if not PrimitiveClass.enabled(): + if not PrimitiveClass.enabled() or ( + quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE + ): return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) - # TE/common does not support colwise-only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: - return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) + if quantizer is None: + output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( + dz, + x, + scale, + # outputs float32 for dbias accumulation + out_dtype=(jnp.float32 if is_dbias else x.dtype), + # default value for no scaling, TE/common ignore this value when scale is unset + scaling_mode=ScalingMode.NO_SCALING.value, + is_2x=False, # unused + scale_dtype=jnp.float32, # unused + is_dbias=False, + act_enum=act_type_id, + act_len=act_len, + is_outer=True, + ) + output = output.astype(x.dtype) + dbias = None + if is_dbias: + dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) + + if noop_scaled_tensor: + return ( + ScaledTensorFactory.create_2x( + output, + None, + output, + None, + ScalingMode.NO_SCALING, + dq_dtype=output.dtype, + ), + dbias, + ) + + return output, dbias # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): @@ -1145,31 +1185,6 @@ def quantize_dact_dbias( if war_output is not None: return war_output - scale = jnp.empty((), jnp.float32) - - act_type_id = ActivationEnum[activation_type] - - if quantizer is None: - output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( - dz, - x, - scale, - # outputs float32 for dbias accumulation - out_dtype=(jnp.float32 if is_dbias else x.dtype), - # default value for no scaling, TE/common ignore this value when scale is unset - scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, # unused - scale_dtype=jnp.float32, # unused - is_dbias=False, - act_enum=act_type_id, - act_len=act_len, - is_outer=True, - ) - dbias = None - if is_dbias: - dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) - return output.astype(x.dtype), dbias - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = dact_lu( @@ -1183,7 +1198,7 @@ def quantize_dact_dbias( ) return out, dbias - if isinstance(quantizer, DelayedScaleQuantizer): + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale # TE/common dact_dbias_quantize does not support gated act yet @@ -1243,6 +1258,7 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + noop_scale_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1252,6 +1268,7 @@ def dact_lu( x: Input tensor that was used in forward pass. activation_type: Type of activation function that was applied. quantizer: Optional quantizer for FP8 quantization of the output gradient. 
+ noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: The gradient of the activation with respect to the input. @@ -1262,5 +1279,6 @@ def dact_lu( activation_type=activation_type, is_dbias=False, quantizer=quantizer, + noop_scaled_tensor=noop_scale_tensor, ) return output diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4188d496f1..b306deac14 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -3,19 +3,26 @@ # See LICENSE for license information. """JAX te modules""" -from typing import Tuple, Sequence, Union, Dict -from functools import partial, reduce -import operator import math +import operator +from collections.abc import Iterable +from typing import Tuple, Sequence, Union +from functools import partial, reduce + import jax import jax.numpy as jnp -from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams +from jax import dtypes +from jax.sharding import NamedSharding, PartitionSpec +from jax.experimental.custom_partitioning import SdyShardingRule + +import transformer_engine_jax as tex +from transformer_engine_jax import get_num_compute_streams from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize - from ..quantize import ( ScaledTensor, + ScaledTensor2x, GroupedScaledTensor1x, ScalingMode, Quantizer, @@ -25,10 +32,20 @@ QuantizeLayout, noop_quantizer_set, is_fp8_gemm_with_all_layouts_supported, + apply_padding_to_scale_inv, + remove_padding_from_scale_inv, ) +from .misc import get_padded_spec -__all__ = ["gemm", "grouped_gemm"] +__all__ = [ + "gemm", + "grouped_gemm", + "gemm_uses_jax_dot", + "sanitize_dims", + "get_non_contracting_dims", + "transpose_dims", +] num_cublas_streams = get_num_compute_streams() @@ -36,11 +53,936 @@ def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if get_device_compute_capability(0) >= 90: + if tex.get_device_compute_capability(0) >= 90: return 33_554_432 return 4_194_304 +def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]: + """Convert relative (negative) indexes to absolute dimension numbers.""" + dims_ = dims if isinstance(dims, Iterable) else (dims,) + if len(dims_) == 0: + return dims_ + return tuple(ndim + dim if dim < 0 else dim for dim in dims_ if dim is not None) + + +def get_non_contracting_dims(ndim, contracting_dims): + """Return a tuple of dimensions not included in the contracting dimensions.""" + contracting_dims = sanitize_dims(ndim, contracting_dims) + return tuple(dim for dim in range(ndim) if dim not in contracting_dims) + + +def transpose_dims(ndim, dims_to_transpose, flatten_axis=-1): + """Compute the new dimension numbers after transpose.""" + if len(dims_to_transpose) == 0: + return dims_to_transpose + flatten_axis = ndim - flatten_axis if flatten_axis > 0 else flatten_axis + transposed_dims = (*range(flatten_axis, ndim), *range(flatten_axis)) + return tuple(transposed_dims.index(dim) for dim in dims_to_transpose) + + +def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool: + lhs, rhs, e4m3, e5m2 = map( + dtypes.canonicalize_dtype, + ( + lhs_dtype, + rhs_dtype, + jnp.float8_e4m3fn, + jnp.float8_e5m2, + ), + ) + + # FP8 GEMM supports (e4m3 x e4m3), (e4m3 x e5m2) and (e5m2 x e4m3) + if (lhs is e4m3 and rhs in (e4m3, e5m2)) or (lhs in (e4m3, e5m2) and rhs is e4m3): + return True + + # 
Any other combination of data types is not supported + return False + + +def _get_gemm_layout( + operand_ndims: Tuple[int, int], contracting_dims: Tuple[Sequence[int], Sequence[int]] +) -> Tuple[bool, bool]: + lhs_contracting, rhs_contracting = map(sanitize_dims, operand_ndims, contracting_dims) + lhs_is_transposed = operand_ndims[0] - 1 not in lhs_contracting + rhs_is_transposed = operand_ndims[1] - 1 in rhs_contracting + return lhs_is_transposed, rhs_is_transposed + + +def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims): + lhs_q = lhs + rhs_q = rhs + + if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: + lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0]) + lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims + need_lhs_colwise = lhs_is_transposed and ( + lhs_quantizer.scaling_mode.is_1d_block_scaling() + or not is_fp8_gemm_with_all_layouts_supported() + ) + flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims) + lhs_q = lhs_quantizer.quantize( + lhs, + is_rowwise=not need_lhs_colwise, + is_colwise=need_lhs_colwise, + flatten_axis=flatten_axis, + ) + + if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: + rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) + rhs_is_transposed = rhs.ndim - 1 in rhs_cdims + need_rhs_colwise = not rhs_is_transposed and ( + rhs_quantizer.scaling_mode.is_1d_block_scaling() + or not is_fp8_gemm_with_all_layouts_supported() + ) + flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1 + rhs_q = rhs_quantizer.quantize( + rhs, + is_rowwise=not need_rhs_colwise, + is_colwise=need_rhs_colwise, + flatten_axis=flatten_axis, + ) + + assert not isinstance(lhs_q, ScaledTensor2x) + assert not isinstance(rhs_q, ScaledTensor2x) + + return lhs_q, rhs_q + + +class GemmPrimitive(BasePrimitive): + """ + Primitive for cuBLAS GEMM + """ + + name = "te_gemm_ffi" + multiple_results = True + impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + batched_dims, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ): + del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator + + def _dims_are_consecutive(dims): + if len(dims) <= 1: + return True + return sorted(dims) == list(range(min(dims), max(dims) + 1)) + + # Sanity-check operand layouts and types + operand_ndims = (lhs.ndim, rhs.ndim) + + ( + lhs_contracting_dims, + rhs_contracting_dims, + ) = map(sanitize_dims, operand_ndims, contracting_dims) + assert _dims_are_consecutive(lhs_contracting_dims), ( + "cuBLAS GEMM expected consecutive contracting dimensions for LHS operand, but got " + f"{lhs_contracting_dims}." + ) + assert _dims_are_consecutive(rhs_contracting_dims), ( + "cuBLAS GEMM expected consecutive contracting dimensions for RHS operand, but got " + f"{rhs_contracting_dims}." + ) + + ( + lhs_batch_dims, + rhs_batch_dims, + ) = map(sanitize_dims, operand_ndims, batched_dims) + assert _dims_are_consecutive(lhs_batch_dims), ( + "cuBLAS GEMM expected consecutive batch dimensions for LHS operand, but got " + f"{lhs_batch_dims}." + ) + assert _dims_are_consecutive(rhs_batch_dims), ( + "cuBLAS GEMM expected consecutive batch dimensions for RHS operand, but got " + f"{rhs_batch_dims}." 
+ ) + if len(lhs_batch_dims) == 0: + assert ( + len(rhs_batch_dims) == 0 + ), "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched." + elif len(rhs_batch_dims) != 0: + assert all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) and all( + bdim in rhs_contracting_dims for bdim in rhs_batch_dims + ), "cuBLAS GEMM batched dimensions must be contracting when both operands are batched." + + lhs_contracting_size, rhs_contracting_size = map( + lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), + (lhs.shape, rhs.shape), + (lhs_contracting_dims, rhs_contracting_dims), + ) + assert lhs_contracting_size == rhs_contracting_size, ( + "cuBLAS GEMM operands have incompatible contracting dimensions: " + f"{lhs.shape} @ idx {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}." + ) + + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims) + if scaling_mode != ScalingMode.NO_SCALING: + assert _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype), ( + "cuBLAS GEMM quantized operands have incompatible data types: " + f"{lhs.dtype} x {rhs.dtype}." + ) + assert ( + lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0 + ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." + if ( + scaling_mode != ScalingMode.MXFP8_1D_SCALING + and not tex.is_non_nt_fp8_gemm_supported() + ): + assert not lhs_is_transposed and rhs_is_transposed, ( + "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) " + "require non-transposed LHS and transposed RHS operands " + "(`contracting_dims=((-1, ), (-1, ))`)." + ) + + # Determine output shape and dtype + assert ( + dtypes.canonicalize_dtype(out_dtype).itemsize > 1 + ), "cuBLAS GEMM custom op does not support 8-bit quantized output types." + lhs_non_contracting_shape, rhs_non_contracting_shape = map( + lambda shape, dims: [shape[dim] for dim in range(len(shape)) if dim not in dims], + (lhs.shape, rhs.shape), + (lhs_contracting_dims, rhs_contracting_dims), + ) + out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape) + output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + + # Validate bias + bias_shape = (0,) + bias_dtype = out_dtype + if fuse_bias: + expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape) + if not grad: + assert bias.size == expected_bias_size, ( + "cuBLAS GEMM bias tensor has incorrect shape, " + f"expected ({expected_bias_size}, ) but found {bias.shape}." + ) + assert bias.dtype == out_dtype, ( + "cuBLAS GEMM bias tensor has incorrect data type, " + f"expected {bias_dtype} but found {bias.dtype}." + ) + bias_shape = bias.shape + else: + bias_shape = rhs_non_contracting_shape + bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype) + + # Validate pre-GeLU + pre_gelu_shape = (0,) + pre_gelu_dtype = out_dtype + if fuse_gelu: + pre_gelu_shape = out_shape + if grad: + pre_gelu_ndim = len(pre_gelu_shape) + assert gelu_input.ndim == pre_gelu_shape and all( + gelu_input.shape[i] == pre_gelu_shape[i] for i in range(pre_gelu_ndim) + ), ( + "cuBLAS GEMM pre-GeLU tensor has incorrect shape, " + f"expected {pre_gelu_shape} but found {gelu_input.shape}." + ) + assert gelu_input.dtype == out_dtype, ( + "cuBLAS GEMM pre-GeLU tensor has incorrect data type, " + f"expected {pre_gelu_dtype} but found {gelu_input.dtype}." 
+ ) + pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) + + # Need extra workspace for swizzled scale factors + lhs_swizzle_size = 0 + rhs_swizzle_size = 0 + swizzle_dtype = jnp.uint8 + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + lhs_swizzle_size = lhs_scale_inv.size + rhs_swizzle_size = rhs_scale_inv.size + lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype) + rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype) + + # Declare cuBLAS workspace + # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not + # necessarily 256 bytes aligned, we add some padding to ensure alignment. + workspace_size = get_cublas_workspace_size_bytes() + 256 + workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) + + return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace + + @staticmethod + def outer_abstract(*args, **kwargs): + outputs = GemmPrimitive.abstract(*args, **kwargs) + return outputs[:-3] # discard workspace arrays + + @staticmethod + def lowering( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + batched_dims, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ): + del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype + lhs_aval, _, rhs_aval, *_ = ctx.avals_in + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) + lhs_transposed, rhs_transposed = _get_gemm_layout( + (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) + ) + + args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input) + kwargs = { + "scaling_mode": int(scaling_mode.value), + "lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), + "rhs_axis_boundary": min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, + "lhs_transposed": lhs_transposed, + "rhs_transposed": rhs_transposed, + "fuse_bias": fuse_bias, + "fuse_gelu": fuse_gelu, + "grad": grad, + "use_split_accumulator": use_split_accumulator, + } + + operand_output_aliases = {} + if fuse_bias and not grad: + operand_output_aliases.update({4: 1}) # bias <-> bias_grad + if fuse_gelu and grad: + operand_output_aliases.update({5: 2}) # gelu_input <-> pre_gelu_out + + return jax.ffi.ffi_lowering( + GemmPrimitive.name, + operand_output_aliases=operand_output_aliases, + )(ctx, *args, **kwargs) + + @staticmethod + def impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + batched_dims, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ): + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) + lhs_transposed, rhs_transposed = _get_gemm_layout( + (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) + ) + + lhs_scale_inv = apply_padding_to_scale_inv( + lhs_scale_inv, + scaling_mode, + lhs.shape, + is_colwise=lhs_quantized_colwise, + flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), + ) + rhs_scale_inv = apply_padding_to_scale_inv( + rhs_scale_inv, + scaling_mode, + rhs.shape, + is_colwise=rhs_quantized_colwise, + flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, + ) + + outputs = GemmPrimitive.inner_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype=out_dtype, + 
contracting_dims=contracting_dims, + batched_dims=batched_dims, + lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + ) + return outputs[:-3] # discard workspace arrays + + @staticmethod + def batcher( + batched_args, + jax_batch_dims, + out_dtype, + contracting_dims, + batched_dims, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ): + assert GemmPrimitive.outer_primitive is not None + lhs, _, rhs, *_ = batched_args + lhs_bdims, _, rhs_bdims, *_ = jax_batch_dims + arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims) + arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims + assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), ( + "User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch " + f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}." + ) + arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims + assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), ( + "User-specified batch dimension(s) for cuBLAS GEMM RHS operand does not match batch " + f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}." + ) + + # Output is batched like the non-contracting batch dimensions of the LHS operand + lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims) + lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims) + out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims + + # Bias gradient is never batched + bias_bdims = (None,) + + # Pre-GeLU output, if exists, is batched like GEMM output + pre_gelu_bdims = (None,) + if fuse_gelu and not grad: + pre_gelu_bdims = out_bdims + + return ( + GemmPrimitive.outer_primitive.bind( + *batched_args, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + batched_dims=batched_dims, + lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + ), + (out_bdims, bias_bdims, pre_gelu_bdims), + ) + + @staticmethod + def _decompose_operand_specs(specs, contracting_dims, batch_dims): + ndims = len(specs) + cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims)) + + # Batch specs + bspecs = tuple(specs[i] for i in bdims) + + # Non-batch leading dimension specs + lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims) + + # Non-batch contracting dimension specs + cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims) + + return bspecs, lspecs, cspecs + + @staticmethod + def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims): + lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) + lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) + lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( + sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batched_dims + ) + ( + (lhs_bspecs, lhs_lspecs, lhs_cspecs), + (rhs_bspecs, rhs_lspecs, rhs_cspecs), + ) = map( + GemmPrimitive._decompose_operand_specs, + (lhs_specs, rhs_specs), + (lhs_cdims, rhs_cdims), + (lhs_bdims, rhs_bdims), + ) + + # Batched 
dimensions must have the same sharding + if len(lhs_bdims) > 0 and len(rhs_bdims) > 0: + assert all( + lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs) + ), ( + "cuBLAS GEMM operand batch dimensions must have the same sharding: " + f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}." + ) + + # Only one each of the non-batched leading dimensions and non-batched contracting + # dimensions can be sharded + lhs_ldims, rhs_ldims = map( + lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude), + (lhs_ndim, rhs_ndim), + (lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims), + ) + (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map( + lambda specs: tuple(spec for spec in specs if spec is not None), + (lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs), + ) + assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, ( + "cuBLAS GEMM operands can have only one sharded non-batched leading dimension: " + f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}." + ) + assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, ( + "cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: " + f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}." + ) + + # Extract single leading and contracting dimension specs + (lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map( + lambda specs: None if len(specs) == 0 else specs[0], + (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none), + ) + + # Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts + # with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands. + # 1. K1 == K2 != None and N == None + # LHS: (B, M, K) + # RHS: (B, None, K) + # OUT: (B, M, None) --(AR)-> (B, M, None) + # 2. K1 == K2 != None and M == N != None + # LHS: (B, M, K) + # RHS: (B, N, K)--(AG)->(B, None, K) + # OUT: (B, M, None) --(RS)--> (B, M, N) + # 3. M == N + # LHS: (B, M, K)--(AG)->(B, M, None) + # RHS: (B, M, K)--(AG)->(B, None, None) + # OUT: (B, M, None) + # 4. 
M != N + # LHS: (B, M, K)--(AG)->(B, M, None) + # RHS: (B, N, K)--(AG)->(B, N, None) + # OUT: (B, M, N) + reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec + all_reduce_output = reduce_flag and rhs_lspec is None + reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec + all_reduce_spec = reduce_scatter_spec = scatter_dim = None + + lhs_non_contracting_specs, rhs_non_contracting_specs = map( + lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims), + (lhs_specs, rhs_specs), + (lhs_cdims, rhs_cdims), + ) + out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs) + if reduce_scatter_output: + # All-gather (if necessary) the non-batch non-contracting dimension of RHS + # (B, N, K) --(AG)-> (B, None, K) + # (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N) + rhs_specs = tuple( + rhs_specs[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim) + ) + reduce_scatter_spec = lhs_cspec + scatter_dim = out_specs.index(rhs_lspec) + + elif all_reduce_output: + # Set all output trailing dimension specs to None + out_specs = ( + *lhs_non_contracting_specs, + *[None for _ in range(len(rhs_non_contracting_specs))], + ) + all_reduce_spec = lhs_cspec + else: + # All-gather (if necessary) the non-batch contracting dimensions + # (B, M, K) --(AG)-> (B, M, None) + # (B, N, K) --(AG)-> (B, N, None) + # (B, M, None) x (B, N, None)^T = (B, M, N) + lhs_specs = tuple( + None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i] + for i in range(lhs_ndim) + ) + rhs_specs = tuple( + None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i] + for i in range(rhs_ndim) + ) + # Check if RHS non-contracting spec also appears in the LHS non-contracting specs + if rhs_lspec is not None and rhs_lspec in tuple( + lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims + ): + # All-gather (if necessary) the non-batch non-contracting dimensions of RHS + # (B, N, None) --(AG)-> (B, None, None) + # (B, M, None) x (B, None, None)^T = (B, M, None) + rhs_specs = tuple( + None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i] + for i in range(rhs_ndim) + ) + # Set all output trailing dimension specs to None + out_specs = ( + *lhs_non_contracting_specs, + *[None for _ in range(len(rhs_non_contracting_specs))], + ) + + # Bias and Pre-GeLU sharding is based on GEMM output + bias_specs = out_specs[len(lhs_non_contracting_specs) :] + gelu_specs = out_specs + + return ( + (lhs_specs, rhs_specs, bias_specs, gelu_specs), + (out_specs, bias_specs, gelu_specs), + all_reduce_spec, + reduce_scatter_spec, + scatter_dim, + ) + + @staticmethod + def infer_sharding_from_operands( + out_dtype, + contracting_dims, + batched_dims, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + mesh, + arg_infos, + result_infos, + ): + del ( + out_dtype, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + grad, + ) + del use_split_accumulator, result_infos + + (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( + GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) + ) + out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) + + # Discard bias gradient spec if there is no bias fusion + if not fuse_bias: + dbias_specs = (None,) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs)) + + # Discard pre-GeLU output spec if there is no GeLU fusion + if not fuse_gelu: + pre_gelu_specs = (None,) +
pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)) + + return [out_sharding, dbias_sharding, pre_gelu_sharding] + + @staticmethod + def partition( + out_dtype, + contracting_dims, + batched_dims, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + mesh, + arg_infos, + result_infos, + ): + del result_infos + + ( + (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), + (out_specs, dbias_specs, pre_gelu_specs), + all_reduce_spec, + reduce_scatter_spec, + scatter_dim, + ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) + + # Assemble argument shardings + # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. + none_sharding = NamedSharding(mesh, PartitionSpec(None)) + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs)) + rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs)) + arg_shardings = ( + lhs_sharding, + lhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + rhs_sharding, + rhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding, + ) + + # Discard bias input spec if there is no bias fusion + if not fuse_bias: + bias_input_specs = (None,) + arg_shardings += (NamedSharding(mesh, PartitionSpec(*bias_input_specs)),) + + # Discard pre-GeLU input spec if there is no GeLU fusion + if not fuse_gelu: + gelu_input_specs = (None,) + arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),) + + # Assemble output shardings + out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))] + + # Discard bias gradient spec if there is no bias fusion + if not fuse_bias: + dbias_specs = (None,) + out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs))) + + # Discard pre-GeLU output spec if there is no GeLU fusion + if not fuse_gelu: + pre_gelu_specs = (None,) + out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))) + + def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): + outputs = GemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + batched_dims=batched_dims, + lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + ) + + # All-Reduce/Reduce-Scatter GEMM output + if all_reduce_spec is not None: + outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec) + if fuse_gelu and not grad: + outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec) + elif reduce_scatter_spec is not None: + outputs[0] = jax.lax.psum_scatter( + outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True + ) + if fuse_gelu and not grad: + outputs[2] = jax.lax.psum_scatter( + outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True + ) + + return outputs + + return mesh, _sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule( + out_dtype, + contracting_dims, + batched_dims, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + mesh, + operand_types, + result_types, + ): + del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator + del mesh, result_types + + prefix = "GemmPrimitive_" + + def 
_generate_operand_rules(name, ndim, cdims, bdims): + specs = [] + ldims = tuple(i for i in range(ndim) if i not in bdims + cdims) + for i in range(ndim): + dim_name = None + if i in bdims: + dim_idx = bdims.index(i) if len(bdims) > 1 else "" + dim_name = f"b{dim_idx}" + elif i in cdims: + dim_idx = cdims.index(i) if len(cdims) > 1 else "" + dim_name = f"k{dim_idx}" + else: + dim_idx = ldims.index(i) if len(ldims) > 1 else "" + dim_name = f"{name}_l{dim_idx}" + specs.append(prefix + dim_name) + return specs + + lhs, _, rhs, *_ = operand_types + operand_ndims = (len(lhs.shape), len(rhs.shape)) + (lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = map( + lambda dims: map(sanitize_dims, operand_ndims, dims), + (contracting_dims, batched_dims), + ) + lhs_specs, rhs_specs = map( + _generate_operand_rules, + ("lhs", "rhs"), + operand_ndims, + (lhs_cdims, rhs_cdims), + (lhs_bdims, rhs_bdims), + ) + lhs_scale_specs = ("…1",) + rhs_scale_specs = ("…2",) + if scaling_mode.is_1d_block_scaling(): + # Shardy rules for MXFP8 scales cannot be related to the operands because of the + # global-unpadding and local-padding workflow. This can potentially insert expensive + # re-shards in the partition call later if the scales are not already sharded correctly. + lhs_scale_specs, rhs_scale_specs = map( + lambda specs: tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs), + (lhs_specs, rhs_specs), + ) + + lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims) + rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) + out_spec = (*lhs_non_cspec, *rhs_non_cspec) + bias_spec = rhs_non_cspec if fuse_bias else ("…4",) + gelu_spec = out_spec if fuse_gelu else ("…5",) + + return SdyShardingRule( + operand_mappings=( + lhs_specs, + lhs_scale_specs, + rhs_specs, + rhs_scale_specs, + bias_spec, + gelu_spec, + ), + result_mappings=( + out_spec, + bias_spec, + gelu_spec, + ), + ) + + +register_primitive(GemmPrimitive) + + +def gemm_uses_jax_dot() -> bool: + """Check if the GEMM call directs to the TE custom cuBLAS call or native JAX dot.""" + return not GemmPrimitive.enabled() + + +def _get_scale_inv_without_padding(scaled_tensor): + return remove_padding_from_scale_inv( + scaled_tensor.scale_inv, + scaled_tensor.scaling_mode, + scaled_tensor.data.shape, + is_colwise=scaled_tensor.is_colwise, + flatten_axis=scaled_tensor.flatten_axis, + ) + + +def _te_gemm( + lhs: Union[jax.Array, ScaledTensor], + rhs: Union[jax.Array, ScaledTensor], + bias: jax.Array = None, + gelu_input: jax.Array = None, + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), + batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()), + fuse_bias: bool = False, + fuse_gelu: bool = False, + grad: bool = False, + use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, +) -> Tuple[jax.Array, ...]: + # Prepare non-quantized GEMM operands + lhs_data = lhs + rhs_data = rhs + lhs_scale_inv = jnp.empty(0, dtype=jnp.float32) + rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) + scaling_mode = ScalingMode.NO_SCALING + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) + lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims) + + # Quantize operands (if necessary) + lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, 
rhs_quantizer, contracting_dims) + + # Extract GEMM custom op inputs from quantized operands + if isinstance(lhs_q, ScaledTensor): + assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, ( + "cuBLAS GEMM with quantized LHS and non-quantized RHS operands requires a valid " + "`Quantizer` object to quantize the RHS operand." + ) + if isinstance(lhs_q, ScaledTensor2x): + # Choose the quantization of the contracting dimension(s) + lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() + scaling_mode = lhs_q.scaling_mode + lhs_data = lhs_q.data + lhs_scale_inv = _get_scale_inv_without_padding(lhs_q) + if lhs_q.data_layout == "T": + lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) + lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis) + + if isinstance(rhs_q, ScaledTensor): + assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( + "cuBLAS GEMM with non-quantized LHS and quantized RHS operands requires a valid " + "`Quantizer` object to quantize the LHS operand." + ) + if isinstance(rhs_q, ScaledTensor2x): + # Choose the quantization of the contracting dimension(s) + rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() + assert rhs_q.scaling_mode == lhs_q.scaling_mode, ( + "cuBLAS GEMM quantized operands have mismatched scaling types, " + f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." + ) + rhs_data = rhs_q.data + rhs_scale_inv = _get_scale_inv_without_padding(rhs_q) + if rhs_q.data_layout == "T": + rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) + rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis) + + # Dummy empties for bias and gelu + out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype + if bias is None or not (fuse_bias and not grad): + bias = jnp.empty(0, dtype=out_dtype) + if gelu_input is None or not (fuse_gelu and grad): + gelu_input = jnp.empty(0, dtype=out_dtype) + + return GemmPrimitive.outer_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + gelu_input, + out_dtype=out_dtype, + contracting_dims=(lhs_cdims, rhs_cdims), + batched_dims=(lhs_bdims, rhs_bdims), + lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False, + rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + ) + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM @@ -230,11 +1172,8 @@ def _shape_normalization(x, dimension_numbers, already_transposed: bool = False) def _calculate_remaining_shape(shape, contracting_dims): - return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims) - - -def _transpose_contract_dims(ndim, contracting_dims): - return tuple(ndim - i - 1 for i in contracting_dims)[::-1] + contracting_dims_ = sanitize_dims(len(shape), contracting_dims) + return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims_) # Apply jit to guarantee correctness of FP8 GEMM. 
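For reference, a minimal standalone sketch (not part of the diff) of how `contracting_dims` maps onto the NN/NT/TN layout flags used throughout the new `GemmPrimitive`; `sanitize_dims` is assumed here to simply wrap negative indices, mirroring the `_get_gemm_layout` helper added above:

from typing import Sequence, Tuple

def sanitize_dims(ndim: int, dims: Sequence[int]) -> Tuple[int, ...]:
    # Assumed behavior: wrap negative dimension indices into the range [0, ndim).
    return tuple(d % ndim for d in dims)

def get_gemm_layout(operand_ndims, contracting_dims):
    # Mirrors _get_gemm_layout(): LHS counts as transposed when its last dimension is NOT
    # contracted; RHS counts as transposed when its last dimension IS contracted.
    lhs_cdims, rhs_cdims = map(sanitize_dims, operand_ndims, contracting_dims)
    return (operand_ndims[0] - 1 not in lhs_cdims, operand_ndims[1] - 1 in rhs_cdims)

# (M, K) x (K, N) -> NN layout, neither operand transposed
assert get_gemm_layout((2, 2), ((1,), (0,))) == (False, False)
# (M, K) x (N, K) -> NT layout, the only FP8 layout on devices with compute capability < 10.0
assert get_gemm_layout((2, 2), ((-1,), (-1,))) == (False, True)
# (K, M) x (K, N) -> TN layout, transposed LHS
assert get_gemm_layout((2, 2), ((0,), (0,))) == (True, False)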
@@ -242,9 +1181,11 @@ def _transpose_contract_dims(ndim, contracting_dims): def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums if lhs.data_layout == "T": - lhs_contract = _transpose_contract_dims(lhs.data.ndim, lhs_contract) + lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis) + lhs_batch = transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis) if rhs.data_layout == "T": - rhs_contract = _transpose_contract_dims(rhs.data.ndim, rhs_contract) + rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis) + rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis) dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) @@ -315,12 +1256,12 @@ def _jax_gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - quantizer_set: Dict["str", Quantizer] = noop_quantizer_set, + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, ) -> jnp.ndarray: """ FP8 GEMM via JAX """ - dim_nums = (contracting_dims, ((), ())) def _jax_gemm_fp8_impl(lhs, rhs): @@ -340,37 +1281,16 @@ def _jax_gemm_fp8_impl(lhs, rhs): raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") - if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): - return _jax_gemm_fp8_impl(lhs, rhs) + lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) - if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor): - if quantizer_set != noop_quantizer_set: - assert type(quantizer_set.x) is type(quantizer_set.kernel) - if ( - quantizer_set.x.scaling_mode.is_tensor_scaling() - and is_fp8_gemm_with_all_layouts_supported() - ): - lhs_is_rowwise = rhs_is_rowwise = True - else: - (((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums - lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1 - rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1 - lhs_q = quantizer_set.x.quantize( - lhs, - is_rowwise=lhs_is_rowwise, - is_colwise=not lhs_is_rowwise, - ) - rhs_q = quantizer_set.kernel.quantize( - rhs, - is_rowwise=rhs_is_rowwise, - is_colwise=not rhs_is_rowwise, - ) - return _jax_gemm_fp8_impl(lhs_q, rhs_q) + if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor): + return _jax_gemm_fp8_impl(lhs_q, rhs_q) if ( isinstance(lhs, jnp.ndarray) and isinstance(rhs, jnp.ndarray) - and quantizer_set == noop_quantizer_set + and lhs_quantizer is None + and rhs_quantizer is None ): return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) @@ -380,30 +1300,109 @@ def _jax_gemm_fp8_impl(lhs, rhs): def gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - quantizer_set: QuantizerSet = noop_quantizer_set, -) -> jnp.ndarray: - """General matrix multiplication with optional quantization. - - Args: - lhs: First input matrix. - rhs: Second input matrix. - contracting_dims: Tuple of two sequences representing the contracting dimensions. - The first sequence represents the contracting dimensions of the first matrix, - and the second sequence represents the contracting dimensions of the second matrix. - quantizer_set: Set of quantizers for FP8 quantization of the output. - If None, no quantization is applied and the output has the same dtype as the inputs. 
- - Returns: - If quantizer_set is None: - The matrix multiplication result. - Shape: (M, N) - Dtype: Same as input dtype - If quantizer_set is provided: - A ScaledTensor containing the quantized matrix multiplication result. + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), + batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()), + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, + **kwargs, +) -> Tuple[jnp.ndarray, ...]: + r"""General matrix multiplication with optional quantization. + + Parameters + ---------- + lhs: Union[jax.Array, ScaledTensor] + Left-hand side operand in the matrix multiplication. + rhs: Union[jax.Array, ScaledTensor] + Right-hand side operand in the matrix multiplication. + lhs_quantizer: Quantizer, default = None + Object for down-casting the LHS operand for quantized GEMM. + rhs_quantizer: Quantizer, default = None + Object for down-casting the RHS operand for quantized GEMM. + contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, )) + Tuple of sequences representing the contracting dimensions of the operands. + batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()), + Tuple of sequences representing the batched dimensions of the operands. This is *not* used + to perform a batched matrix multiplication, but it is required to avoid a potentially + undesirable reduction in any batched contracting dimensions when invoked with sharded + operands (e.g. when computing weight gradients in a Flax module). + bias: jax.Array, default = None + Optional additive bias term, required for forward GEMM with bias fusion. Only supported + with TE's custom call to cuBLAS GEMM. + gelu_input: jax.Array, default = None + Pre-GeLU output from forward GEMM, required for backward/grad GEMM with dGeLU fusion. Only + supported with TE's custom call to cuBLAS GEMM. + fuse_bias: bool, default = False + Enable bias addition in forward GEMM or bias gradient in backward GEMM. Only supported with + TE's custom call to cuBLAS GEMM. + fuse_gelu: bool, default = False + Enable GeLU activation in forward GEMM or GeLU gradient in backward GEMM. Only supported + with TE's custom call to cuBLAS GEMM. + grad: bool, default = False + Flag for switching bias and GeLU fusions from forward to backward mode. Only supported with + TE's custom call to cuBLAS GEMM. + use_split_accumulator: bool, default = True + Enable promoting some intermediate sums to higher precision when accumulating the result in + the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. + + Returns + ------- + jax.Array: + Result of the operation. For TE's custom call to cuBLAS GEMM, this result can include the + GeLU application when `fuse_gelu=True` and `grad=False`, the GeLU gradient contribution + when `fuse_gelu=True` and `grad=True`, and the additive bias when `fuse_bias=True` and + `grad=False`. + Optional[jax.Array]: + Bias gradient when `fuse_bias=True` and `grad=True`. Only supported with TE's custom call + to cuBLAS GEMM. + Optional[jax.Array]: + Pre-GeLU GEMM output when `fuse_gelu=True` and `grad=False`. This is required as an input + to `_te_gemm()` with `fuse_gelu=True` and `grad=True` in the backward pass in order to + compute the GeLU contribution to the gradient. Only supported with TE's custom call to + cuBLAS GEMM. 
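+
+ Examples
+ --------
+ A minimal invocation with non-quantized operands (the shapes and the leading batch dimension
+ are illustrative only):
+
+ >>> x = jnp.zeros((32, 128, 1024), dtype=jnp.bfloat16)  # (batch, seqlen, hidden)
+ >>> kernel = jnp.zeros((1024, 3072), dtype=jnp.bfloat16)
+ >>> out = gemm(x, kernel, contracting_dims=((-1,), (0,)), batched_dims=((0,), ()))
+ >>> out.shape
+ (32, 128, 3072)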
""" + # Try to get LHS and RHS quantizers from a quantizer set for backward compatibility + if lhs_quantizer is None or rhs_quantizer is None: + quantizer_set = kwargs.get("quantizer_set", None) + if quantizer_set is not None: + lhs_quantizer = quantizer_set.x + rhs_quantizer = quantizer_set.kernel + + # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled + fuse_bias = kwargs.get("fuse_bias", False) + fuse_gelu = kwargs.get("fuse_gelu", False) + if not GemmPrimitive.enabled(): + assert kwargs.get("bias", None) is None and not fuse_gelu, ( + "TE GEMM was invoked with bias fusion options that are not supported by the " + "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " + "GEMM primitive is disabled." + ) + assert kwargs.get("gelu_input", None) is None and not fuse_bias, ( + "TE GEMM was invoked with GeLU fusion options that are not supported by the " + "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " + "GEMM primitive is disabled." + ) + return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) + + outputs = _te_gemm( + lhs, + rhs, + lhs_quantizer=lhs_quantizer, + rhs_quantizer=rhs_quantizer, + contracting_dims=contracting_dims, + batched_dims=batched_dims, + **kwargs, + ) - return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set) + # Discard empty outputs + grad = kwargs.get("grad", False) + clean_outputs = outputs[0] # first output is the final result and is never empty + if (fuse_bias and grad) or (fuse_gelu and not grad): + clean_outputs = (outputs[0],) + if fuse_bias and grad: # only return bias gradient if it exists + clean_outputs += (outputs[1],) + if fuse_gelu and not grad: # only return pre-GeLU output if it exists + clean_outputs += (outputs[2],) + return clean_outputs def grouped_gemm( diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 87f1c1913a..94dfaa45a4 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -198,14 +198,19 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to calculate dbias separately. This function checks if the workaround should be applied. """ + if quantizer is None: + return False + arch_l_100 = False for local_gpu_id in range(len(jax.local_devices())): if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100: arch_l_100 = True break + # _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE, + # but this fails when bias fusion is turned on with arch < 100. 
+ force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() return ( - quantizer is not None - and quantizer.q_layout == QuantizeLayout.ROWWISE + (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE) and arch_l_100 and is_dbias ) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 07ebb33114..3b563efbd0 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -587,16 +587,17 @@ def shardy_sharding_rule( result_types, ) + prefix = "NormFwdPrimitive_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[0].shape), unique_var="NormFwdPrimitive_i", flatten_axis=-1 + len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1 ) x_axes = scale_rules.input_spec - out = x_axes[:-1] + ("k",) - colwise_out = out if is_2x else ("…4",) + out = x_axes + colwise_out = out if is_2x else (prefix + "out_colwise",) rsigma = x_axes[:-1] - mu = ("…5",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma - amax = ("…6",) + mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma + amax = (prefix + "amax",) return SdyShardingRule( (x_axes, ("…1",), ("…2",), ("…3",)), @@ -609,7 +610,6 @@ def shardy_sharding_rule( mu, rsigma, ), - **scale_rules.factor_sizes, ) @@ -1276,6 +1276,7 @@ def normalization_fwd( epsilon: float, norm_type: str, quantizer: Optional[Quantizer], + noop_scaled_tensor: bool = False, ): """Common wrapper for normalization forward pass. @@ -1292,6 +1293,7 @@ def normalization_fwd( - 'layernorm': Layer normalization - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. 
Returns: A tuple containing: @@ -1319,6 +1321,15 @@ def normalization_fwd( else: raise ValueError(f"{norm_type=} is not supported.") + if quantizer is None and noop_scaled_tensor: + return ( + ScaledTensorFactory.create_2x( + output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype + ), + mu, + rsigma, + ) + return output, mu, rsigma diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 11b3cdc2a3..23e821b1a0 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -36,7 +36,6 @@ Quantizer, GroupedQuantizer, QuantizeLayout, - DelayedScaleQuantizer, ScalingMode, compute_scale_from_amax, ) @@ -489,9 +488,10 @@ def shardy_sharding_rule( ): del out_dtype, scale_dtype, is_outer, mesh, result_types + prefix = "BaseDBiasQuantizePrimitive_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( len(value_types[0].shape), - unique_var="BaseDBiasQuantizePrimitive_i", + unique_var=prefix + "x", flatten_axis=flatten_axis, ) @@ -499,22 +499,19 @@ def shardy_sharding_rule( colwise_scale_inv = scale_rules.colwise_rule out = x_axes + colwise_out = (prefix + "out_colwise",) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if ScalingMode(scaling_mode).is_tensor_scaling(): colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis)) else: colwise_out = x_axes - else: - colwise_out = ("j",) - colwise_scale_inv = ("k",) - dbias = x_axes[flatten_axis:] if is_dbias else ("l",) - amax = ("m",) + dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",) + amax = (prefix + "amax",) return SdyShardingRule( (x_axes, ("…1",)), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), - **scale_rules.factor_sizes, ) @@ -538,11 +535,12 @@ def _jax_quantize( def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1): - assert flatten_axis < 0 + sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis + assert sum_axis < dx.ndim, "Flatten axis out of bounds!" dtype = dtype or dx.dtype dbias = jnp.sum( dx.astype(jnp.float32), - axis=tuple(range(dx.ndim + flatten_axis)), + axis=tuple(range(sum_axis)), keepdims=False, ) return dbias.astype(dtype) @@ -568,6 +566,7 @@ def _quantize_dbias_impl( is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -577,24 +576,34 @@ def _quantize_dbias_impl( quantizer is not None ), "quantizer must be provided if dq_dtype is provided" + # Early-exit for non-quantized call dq_dtype = dq_dtype or x.dtype - - PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive - if not PrimitiveClass.enabled(): + if quantizer is None: + dbias = None if is_dbias: - return _jax_quantize_dbias( - x, - quantizer=quantizer, - dq_dtype=dq_dtype, - flatten_axis=flatten_axis, + dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) + if noop_scaled_tensor: + # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor() + # always works. 
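+ # For example (illustrative), quantize(x, quantizer=None, noop_scaled_tensor=True) returns a
+ # NO_SCALING ScaledTensor2x whose rowwise and colwise tensors both wrap the original x, so
+ # callers can keep a single get_tensor(usage=...) code path whether or not quantization is on.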
+ return ( + ScaledTensorFactory.create_2x( + x, + None, + x, + None, + ScalingMode.NO_SCALING, + dq_dtype=x.dtype, + data_layout="NN", + flatten_axis=flatten_axis, + ), + dbias, ) - return ( - _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), - None, - ) + return x, dbias - # TE/common doesn't support colwise only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: + # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, + # fall back on the native-JAX quantize implementation + PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive + if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled(): if is_dbias: return _jax_quantize_dbias( x, @@ -606,9 +615,8 @@ def _quantize_dbias_impl( _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), None, ) - scale = jnp.empty((), jnp.float32) - # TE/common dbias_quantize does not support 1x on arch < 100 + # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100 if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): out, _ = _quantize_dbias_impl( x=x, @@ -620,29 +628,23 @@ def _quantize_dbias_impl( dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias - if quantizer is None: - if is_dbias: - return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) - return x, None - + scale = jnp.empty((), jnp.float32) if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Globally reduce amax across all devices for current scaling so we have a single global scale. # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this # until the tensor is dequantized (e.g. in the GEMM). amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32) scale = compute_scale_from_amax(amax, quantizer.q_dtype) - - if isinstance(quantizer, DelayedScaleQuantizer): + elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale - is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) # It is faster to use 1x quantization for tensor scaling + is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) force_1x_quantization = ( quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() and is_1x_kernel_supported ) - q_layout = quantizer.q_layout if force_1x_quantization: q_layout = QuantizeLayout.ROWWISE @@ -698,6 +700,7 @@ def quantize( x: jnp.ndarray, quantizer: Quantizer, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -707,6 +710,8 @@ def quantize( quantizer: Quantizer for FP8 quantization of the output. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. + noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer + is None. Returns: A ScaledTensor containing the quantized input tensor. @@ -715,6 +720,7 @@ def quantize( x, quantizer=quantizer, flatten_axis=flatten_axis, + noop_scaled_tensor=noop_scaled_tensor, ) return out @@ -724,6 +730,7 @@ def quantize_dbias( quantizer: Quantizer, is_dbias: bool = True, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. 
@@ -734,6 +741,8 @@ def quantize_dbias( is_dbias: If True, compute bias gradient. Defaults to True. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. + noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when + quantizer is None. Returns: A tuple containing: @@ -743,7 +752,11 @@ def quantize_dbias( Shape: (K,) or empty if is_dbias is False. """ return _quantize_dbias_impl( - dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis + dz, + quantizer=quantizer, + is_dbias=is_dbias, + flatten_axis=flatten_axis, + noop_scaled_tensor=noop_scaled_tensor, ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 0789478348..59079fe3f0 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -119,6 +119,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right); +// GEMM +XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); + // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); diff --git a/transformer_engine/jax/csrc/extensions/ffi.cpp b/transformer_engine/jax/csrc/extensions/ffi.cpp index a760df4a79..e77c38e990 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.cpp +++ b/transformer_engine/jax/csrc/extensions/ffi.cpp @@ -38,12 +38,11 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { case xla::ffi::DataType::F8E4M3FN: return DType::kFloat8E4M3; break; - // case xla::ffi::DataType::F8E8M0FNU: - // return DType::kFloat8E8M0; - // break; + case xla::ffi::DataType::F8E8M0FNU: + return DType::kFloat8E8M0; + break; default: auto type_num = static_cast(type); - if (type_num == 33) return DType::kFloat8E8M0; NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d", static_cast(type_num)); break; diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index fa8c0b988f..ba2d65e3eb 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -6,11 +6,13 @@ #include "transformer_engine/gemm.h" #include +#include +#include #include "../extensions.h" #include "common/util/cuda_runtime.h" +#include "common/util/string.h" #include "common/util/system.h" -#include "transformer_engine/multi_stream.h" #include "transformer_engine/swizzle.h" #include "xla/ffi/api/c_api.h" @@ -25,6 +27,181 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { ~static_cast(255)); } +std::tuple> xla_buffer_to_nvte_gemm_operand( + cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv, + JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) { + // Set tensor data with collapsed 2D shape + auto buffer_dims = buffer.dimensions(); + std::vector input_shape = {product(buffer_dims, 0, axis_boundary), + product(buffer_dims, axis_boundary, buffer_dims.size())}; + auto input_dtype = convert_ffi_datatype_to_te_dtype(buffer.element_type()); + TensorWrapper input(get_nvte_scaling_mode(scaling_mode)); + + if (rowwise) { + input.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); + } else { + input.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); + } + + // Set scaling factor for quantized tensors + if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { + NVTE_CHECK(typeToSize(input_dtype) == 1, 
"Quantized GEMM requires 8-bit operands."); + NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM."); + + std::vector scale_shape = {1}; + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + // Block scaling also needs to be collapsed to match 2D data + scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary), + product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())}; + } + + auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + if (rowwise) { + input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } else { + input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } + + // Swizzle scaling factors for MXFP8 + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + // Get the swizzle buffer + NVTE_CHECK(swizzled_scale_inv->element_count() > 0, + "Missing swizzled inverse scale buffer in the JAX primitive."); + auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + auto swizzled_scale_inv_dtype = + convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type()); + NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1, + "Inverse scale factors need to have an 8-bit data type."); + + // Create tensor to hold swizzled scale factor + TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); + if (rowwise) { + output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); + } else { + output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, + scale_shape); + } + + // Launch swizzle kernel + nvte_swizzle_scaling_factors(input.data(), output.data(), stream); + + // Set swizzled scales into the input tensor + if (rowwise) { + input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); + } else { + input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, + scale_shape); + } + } + } + + return std::make_tuple(std::move(input), input_shape); +} + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, + Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace, + JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, + int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, + bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { + // Operands (this includes swizzling MXFP8 scaling factors) + // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when + // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) + bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); + bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; + bool make_rhs_rowwise = (always_rowwise) ? 
true : rhs_transposed; + auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( + stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise); + auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand( + stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); + + // Output tensor + std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], + (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; + auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, " + "expected ", + out_.numel(), " elements ", to_string_like(out_shape), " but got ", + output->element_count(), " elements ", to_string_like(output->dimensions())); + + // Bias input to forward pass or bias gradient output from backward pass + void *bias_ptr = nullptr; + std::vector bias_shape = {0}; + DType bias_dtype = out_dtype; + if (fuse_bias) { + if (!grad) { + NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(), + "Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad"); + } + bias_ptr = bias_grad->untyped_data(); + bias_shape.at(0) = bias_grad->dimensions().front(); + bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type()); + } + auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + + // Pre-GeLU output from forward pass or input to backward pass + void *pre_gelu_ptr = nullptr; + std::vector pre_gelu_shape = {0}; + DType pre_gelu_dtype = out_dtype; + if (gelu_input.element_count() > 0) { + if (grad) { + NVTE_CHECK(pre_gelu_out->untyped_data() == gelu_input.untyped_data(), + "Missing operand-output aliasing in GemmPrimitive: gelu_input <-> pre_gelu_out"); + } + pre_gelu_ptr = pre_gelu_out->untyped_data(); + pre_gelu_shape = {product(pre_gelu_out->dimensions(), 0, pre_gelu_out->dimensions().size() - 1), + static_cast(pre_gelu_out->dimensions().back())}; + pre_gelu_dtype = convert_ffi_datatype_to_te_dtype(pre_gelu_out->element_type()); + } + auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype); + + // cuBLAS workspace + 256 alignment enforcement + auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); + workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); + std::vector workspace_shape = {static_cast(workspace->element_count()) - 256}; + auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte); + + // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), + rhs_transposed, lhs_transposed, grad, workspace_.data(), false, + use_split_accumulator, num_math_sm, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Ret() // output + .Ret() // bias_grad + .Ret() // pre_gelu_out + .Ret() // lhs_swizzled + .Ret() // rhs_swizzled + .Ret() // workspace + .Attr("scaling_mode") + .Attr("lhs_axis_boundary") + .Attr("rhs_axis_boundary") + .Attr("lhs_transposed") + .Attr("rhs_transposed") + .Attr("fuse_bias") + .Attr("fuse_gelu") + .Attr("grad") + 
.Attr("use_split_accumulator"), + FFI_CudaGraph_Traits); + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 03194e9d72..af7f54feb6 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -47,6 +47,15 @@ enum class JAXX_Scaling_Mode : int64_t { CURRENT_TENSOR_SCALING = 3, }; +inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) { + return (mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING || + mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING); +} + +inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) { + return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING); +} + static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { switch (mode) { case JAXX_Scaling_Mode::NO_SCALING: diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 2d7801cc20..afbeb644c1 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -55,6 +55,11 @@ pybind11::dict Registrations() { pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler)); + // GEMM + dict["te_gemm_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); + // Grouped GEMM dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), @@ -78,6 +83,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); + m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 57170e85be..a0fc7b7af8 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -8,7 +8,7 @@ It implements matrix multiplication with optional bias addition and supports customizable contracting dimensions for flexible tensor operations. """ - +import warnings from typing import Tuple, Sequence from functools import partial import jax @@ -23,6 +23,16 @@ ) +DENSE_BATCH_FIRST_WARNING_ISSUED = False + + +def _issue_batch_first_warning(msg): + global DENSE_BATCH_FIRST_WARNING_ISSUED + if not DENSE_BATCH_FIRST_WARNING_ISSUED: + warnings.warn(msg, UserWarning) + DENSE_BATCH_FIRST_WARNING_ISSUED = True + + def dense( x: jnp.ndarray, kernel: jnp.ndarray, @@ -30,6 +40,7 @@ def dense( contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, + batch_first: bool = True, quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. 
@@ -43,25 +54,28 @@ def dense( kernel: Weight matrix for the dense layer transformation bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract + batch_first: Assume that X is batched in the first dimension. quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor """ # Remove when tex.quantize() can handle quantizer=None - if quantizer_set == noop_quantizer_set: + if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot(): x = with_sharding_constraint_by_logical_axes(x, input_axes) - output = tex.gemm(x, kernel, contracting_dims) + output = tex.gemm(x, kernel, contracting_dims=contracting_dims) if bias is not None: bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) else: - output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set) + output = _dense( + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set + ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) -def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) +def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set): """Internal implementation of dense layer transformation with custom VJP. This function implements the core dense layer transformation logic with support @@ -75,44 +89,81 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix quantizer_set: QuantizerSet which contains quantizers for different tensor types + batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. Returns: Transformed output tensor """ output, _ = _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set ) return output -def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): +def _dense_fwd_rule( + x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set +): """Forward pass rule for dense layer transformation. Returns: Tuple of (output, context) for backward pass """ - x_contracting_dims, k_contracting_dims = contracting_dims + x_contracting_dims, k_contracting_dims = map( + tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims + ) + + # Check supported input layout + x_is_transposed = x.ndim - 1 not in x_contracting_dims + k_is_transposed = kernel.ndim - 1 in k_contracting_dims + assert ( + not x_is_transposed and not k_is_transposed + ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel." + + # Determine X batch dimension + # - If `batch_first=True` -> (batch, leading..., contracting...) + # - Otherwise -> (leading..., batch, contracting...) + # NOTE: Always assume a single batch dimension + x_bdim = None + num_cdims = len(x_contracting_dims) + if x.ndim >= num_cdims + 2: + # Assume X is batched if it has at least +2 dimensions more than the number of contracting + # dimensions. 
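+ # For example (illustrative): a 3D input with a single contracting dimension satisfies
+ # x.ndim == 3 >= 1 + 2, so x_bdim = 0 for (batch, seqlen, hidden) inputs with batch_first=True,
+ # and x_bdim = x.ndim - num_cdims - 1 = 1 for (seqlen, batch, hidden) inputs otherwise.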
+ if not batch_first: + _issue_batch_first_warning( + "TE/JAX `dense()` layer implementation does not officially support sequence-first " + "inputs and may produce incorrect results when `batch_first=False`. Use " + "sequence-first inputs at your own discretion.", + ) + x_bdim = 0 if batch_first else x.ndim - num_cdims - 1 flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) - casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x) + casted_x = tex.quantize( + x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True + ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) casted_kernel = tex.quantize( - kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel + kernel, + flatten_axis=flatten_axis_k, + quantizer=quantizer_set.kernel, + noop_scaled_tensor=True, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) # GEMM NN + use_bias = bias is not None output = tex.gemm( casted_x.get_tensor(usage=TensorUsage.LHS), casted_kernel.get_tensor(usage=TensorUsage.RHS), - (x_contracting_dims, k_contracting_dims), + contracting_dims=(x_contracting_dims, k_contracting_dims), + batched_dims=((x_bdim,), ()), + bias=bias if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) - use_bias = bias is not None - if use_bias: + if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) @@ -124,20 +175,19 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, use_bias, quantizer_set, flatten_axis_k, + x_bdim, ) return output, ctx def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, ctx, grad + contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad ): # pylint: disable=unused-argument """Backward pass rule for dense layer transformation. 
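A standalone sketch of the batch-dimension inference used above (assuming a single batch dimension; the helper name `_infer_x_batch_dim` is illustrative, not part of the TE API):

def _infer_x_batch_dim(x_ndim, num_contracting_dims, batch_first=True):
    # Mirror of the rule above: X is treated as batched only when it has at
    # least two more dimensions than the number of contracting dimensions.
    if x_ndim < num_contracting_dims + 2:
        return None
    return 0 if batch_first else x_ndim - num_contracting_dims - 1

assert _infer_x_batch_dim(3, 1, batch_first=True) == 0   # (batch, seq, hidden_in)
assert _infer_x_batch_dim(3, 1, batch_first=False) == 1  # (seq, batch, hidden_in)
assert _infer_x_batch_dim(2, 1) is None                  # unbatched 2D input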
Returns: Tuple of gradients with respect to inputs """ - fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims - ( casted_x_lhs, casted_kernel_rhs, @@ -146,10 +196,19 @@ def _dense_bwd_rule( use_bias, quantizer_set, flatten_axis_k, + x_bdim, ) = ctx + fwd_x_contracting_dims, fwd_k_contracting_dims = map( + tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims + ) + casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad + grad, + is_dbias=use_bias, + flatten_axis=flatten_axis_k, + quantizer=quantizer_set.dgrad, + noop_scaled_tensor=True, ) # GEMM NT @@ -164,7 +223,8 @@ def _dense_bwd_rule( dgrad = tex.gemm( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, - (g_contracting_dim, k_contracting_dim), + contracting_dims=(g_contracting_dim, k_contracting_dim), + batched_dims=((x_bdim,), ()), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) @@ -177,7 +237,8 @@ def _dense_bwd_rule( wgrad = tex.gemm( casted_x_lhs, casted_grad.get_tensor(usage=TensorUsage.RHS), - (x_contracting_dim, g_contracting_dim), + contracting_dims=(x_contracting_dim, g_contracting_dim), + batched_dims=((x_bdim,), (x_bdim,)), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index bd311472f0..5992d36079 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -6,7 +6,7 @@ """ from functools import reduce import operator -from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union +from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType import numpy as np import jax.numpy as jnp @@ -15,12 +15,12 @@ from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name -from ..dense import dense +from ..dense import dense, _issue_batch_first_warning as _dense_warning from ..layernorm import canonicalize_norm_type from ..layernorm import layernorm -from ..layernorm_dense import layernorm_dense -from ..layernorm_mlp import layernorm_mlp +from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning +from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning from ..activation import activation from ..softmax import softmax, SoftmaxType from ..sharding import with_sharding_constraint_by_logical_axes @@ -35,8 +35,8 @@ PRNGKey = Any Shape = Tuple[int, ...] -DType = jnp.dtype -Array = jnp.ndarray +DType = NewType("DType", jnp.dtype) +Array = NewType("Array", jnp.ndarray) PrecisionLike = Union[ None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] ] @@ -441,6 +441,12 @@ class DenseGeneral(TransformerEngineBase): input_axes: Tuple[str, ...] = () def __post_init__(self): + if self.transpose_batch_sequence: + _dense_warning( + "TE/JAX DenseGeneral() module does not officially support sequence-first inputs " + "and may produce incorrect results when `transpose_batch_sequence=True`. Use " + "sequence-first inputs at your own discretion." 
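For reference, a plain-JAX sketch of what the two backward GEMMs above compute for a batch-first input (shapes are illustrative): the batch dimension is carried by both operands of the wgrad GEMM and reduced away, while dgrad keeps it.

import jax.numpy as jnp

x = jnp.ones((4, 16, 32))     # (batch, seq, hidden_in)
kernel = jnp.ones((32, 64))   # (hidden_in, hidden_out)
grad = jnp.ones((4, 16, 64))  # (batch, seq, hidden_out)

# dgrad: contract grad's hidden_out against the kernel's hidden_out
dgrad = jnp.einsum("bso,io->bsi", grad, kernel)
assert dgrad.shape == x.shape

# wgrad: batch and sequence dims contract together, so no batch dim remains
wgrad = jnp.einsum("bsi,bso->io", x, grad)
assert wgrad.shape == kernel.shape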
+ ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, "fan_in", "truncated_normal", dtype=self.dtype @@ -657,6 +663,12 @@ class LayerNormDenseGeneral(TransformerEngineBase): depth_scaling: float = None def __post_init__(self): + if self.transpose_batch_sequence: + _ln_dense_warning( + "TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first " + "inputs and may produce incorrect results when `transpose_batch_sequence=True`. " + "Use sequence-first inputs at your own discretion." + ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, @@ -967,6 +979,12 @@ class LayerNormMLP(TransformerEngineBase): dot_2_input_axes: Tuple[str, ...] = None def __post_init__(self): + if self.transpose_batch_sequence: + _ln_mlp_warning( + "TE/JAX LayerNormMLP() module does not officially support sequence-first inputs " + "and may produce incorrect results when `transpose_batch_sequence=True`. Use " + "sequence-first inputs at your own discretion." + ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, "fan_in", "truncated_normal", dtype=self.dtype diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index ea66e78302..5ccfc71c24 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -9,6 +9,7 @@ distributed training through sharding constraints. """ +import warnings from functools import partial from typing import Tuple @@ -25,6 +26,16 @@ ) +LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False + + +def _issue_batch_first_warning(msg): + global LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED + if not LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED: + warnings.warn(msg, UserWarning) + LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = True + + def layernorm_dense( x: jnp.ndarray, kernel: jnp.ndarray, @@ -37,6 +48,7 @@ def layernorm_dense( layernorm_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, + batch_first: bool = True, quantizer_set: QuantizerSet = noop_quantizer_set, ) -> jnp.ndarray: """Apply layer normalization followed by dense layer transformation. @@ -57,6 +69,7 @@ def layernorm_dense( layernorm_input_axes: Logical axes for sharding the layernorm input dot_input_axes: Logical axes for sharding the matrix multiplication input kernel_axes: Logical axes for sharding the weight matrix + batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. quantizer_set: Set of quantizers for different tensor types Returns: @@ -80,6 +93,7 @@ def layernorm_dense( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, quantizer_set, ) return output @@ -94,6 +108,7 @@ def layernorm_dense( 8, 9, 10, + 11, ), ) def _layernorm_dense( @@ -108,6 +123,7 @@ def _layernorm_dense( layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...], + batch_first: bool, quantizer_set, ): """Internal implementation of layernorm_dense with custom VJP. @@ -127,6 +143,7 @@ def _layernorm_dense( epsilon: Small constant for numerical stability layernorm_input_axes: Logical axes for layernorm sharding dot_input_axes: Logical axes for matrix multiplication sharding + batch_first: Assume that X is batched in the first dimension. 
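Usage-wise, these checks emit a single UserWarning per process the first time a sequence-first module is constructed; a rough sketch (assuming the `features` field and an installed Transformer Engine build):

import warnings
import transformer_engine.jax.flax as te_flax

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Warns once; the second construction sees the already-set module-level flag.
    _ = te_flax.DenseGeneral(features=64, transpose_batch_sequence=True)
    _ = te_flax.DenseGeneral(features=64, transpose_batch_sequence=True)
print([str(w.message) for w in caught])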
quantizer_set: Set of quantizers Returns: @@ -144,6 +161,7 @@ def _layernorm_dense( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, quantizer_set, ) return output @@ -161,6 +179,7 @@ def _layernorm_dense_fwd_rule( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, quantizer_set, ): """Forward pass rule for layernorm_dense. @@ -178,6 +197,17 @@ def _layernorm_dense_fwd_rule( k_contracting_dims = (0,) assert x.shape[-1] == kernel.shape[0] + x_bdim = None + if x.ndim > 2: + if not batch_first: + _issue_batch_first_warning( + "TE/JAX `layernorm_dense()` fused-layer implementation does not officially " + "support sequence-first inputs and may produce incorrect results when " + "`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first " + "inputs at your own discretion." + ) + x_bdim = 0 if batch_first else x.ndim - 2 + x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) casted_ln_out, mu, rsigma = tex.normalization_fwd( @@ -187,25 +217,31 @@ def _layernorm_dense_fwd_rule( zero_centered_gamma, epsilon, norm_type, - quantizer_set.x, + quantizer=quantizer_set.x, + noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) # Kernel in (hidden_in, hidden_out...) flatten_axis = 1 - len(kernel.shape) - casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel) + casted_kernel = tex.quantize( + kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True + ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) # NN GEMM # (batch..., hidden_in) x (hidden_in, hidden_out...) + use_bias = bias is not None output = tex.gemm( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel.get_tensor(TensorUsage.RHS), - (x_contracting_dims, k_contracting_dims), + contracting_dims=(x_contracting_dims, k_contracting_dims), + batched_dims=((x_bdim,), ()), + bias=bias if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) - use_bias = bias is not None - if use_bias: + if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) @@ -224,6 +260,7 @@ def _layernorm_dense_fwd_rule( use_bias, quantizer_set, flatten_axis, + x_bdim, ) return output, ctx @@ -236,6 +273,7 @@ def _layernorm_dense_bwd_rule( layernorm_input_axes, dot_input_axes, # pylint: disable=unused-argument kernel_axes, + batch_first, # pylint: disable=unused-argument ctx, grad, ): @@ -265,10 +303,15 @@ def _layernorm_dense_bwd_rule( use_bias, quantizer_set, flatten_axis, + x_bdim, ) = ctx casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad + grad, + is_dbias=use_bias, + flatten_axis=flatten_axis, + quantizer=quantizer_set.dgrad, + noop_scaled_tensor=True, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim @@ -284,7 +327,8 @@ def _layernorm_dense_bwd_rule( dgrad = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, - (g_constracting_dim, k_constracting_dim), + contracting_dims=(g_constracting_dim, k_constracting_dim), + batched_dims=((x_bdim,), ()), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) @@ -297,7 +341,8 @@ def _layernorm_dense_bwd_rule( wgrad = tex.gemm( casted_ln_out, casted_grad.get_tensor(TensorUsage.RHS), - (x_constracting_dim, 
g_constracting_dim), + contracting_dims=(x_constracting_dim, g_constracting_dim), + batched_dims=((x_bdim,), (x_bdim,)), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 18563fd255..507c49c7e9 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -13,6 +13,7 @@ quantization, and distributed training through sharding constraints. """ +import warnings from typing import List, Tuple, Sequence, Union, Callable from functools import partial @@ -31,6 +32,16 @@ from .sharding import get_non_contracting_logical_axes +LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False + + +def _issue_batch_first_warning(msg): + global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED + if not LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED: + warnings.warn(msg, UserWarning) + LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True + + def layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -48,6 +59,7 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + batch_first: bool = True, quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> jnp.ndarray: """Apply layer normalization followed by MLP block. @@ -79,6 +91,7 @@ def layernorm_mlp( ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation + batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations Returns: @@ -124,12 +137,13 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -149,6 +163,7 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + batch_first: bool, quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -174,6 +189,7 @@ def _layernorm_mlp( ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing activation_type: Activation function(s) + batch_first: Assume that X is batched in the first dimension. quantizer_sets: Tuple of quantizer sets Returns: @@ -198,6 +214,7 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, quantizer_sets, ) return output @@ -222,6 +239,7 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, quantizer_sets, ): """Forward pass rule for layernorm_mlp. @@ -254,6 +272,17 @@ def _layernorm_mlp_fwd_rule( assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] + x_bdim = None + if x.ndim > 2: + if not batch_first: + _issue_batch_first_warning( + "TE/JAX `layernorm_mlp()` fused-layer implementation does not officially " + "support sequence-first inputs and may produce incorrect results when " + "`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first " + "inputs at your own discretion." 
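The GEMM call sites in these fused layers (layernorm_dense above and layernorm_mlp below) all gate bias fusion the same way; a condensed sketch of the pattern, using only the keyword names visible at those call sites (whether unquantized jnp arrays are accepted on the custom-call path is an assumption):

import jax.numpy as jnp
import transformer_engine.jax.cpp_extensions as tex

def _gemm_with_optional_bias_fusion(lhs, rhs, bias, contracting_dims, batched_dims):
    # Fuse the bias into the TE GEMM custom call when it is in use, otherwise
    # fall back to adding the (reshaped) bias after the jax.nn dot path.
    fuse = (bias is not None) and not tex.gemm_uses_jax_dot()
    out = tex.gemm(
        lhs,
        rhs,
        contracting_dims=contracting_dims,
        batched_dims=batched_dims,
        bias=bias if fuse else None,
        fuse_bias=fuse,
    )
    if bias is not None and not fuse:
        out += jnp.reshape(bias, (1,) * (out.ndim - bias.ndim) + bias.shape)
    return out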
+ ) + x_bdim = 0 if batch_first else x.ndim - 2 + use_bias_1 = bias_1 is not None use_bias_2 = bias_1 is not None @@ -267,17 +296,23 @@ def _layernorm_mlp_fwd_rule( epsilon, norm_type, quantizer=ffn1_quantizer_set.x, + noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) - casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel) + casted_kernel_1 = tex.quantize( + kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True + ) # NN GEMM # (batch..., hidden_in) x (hidden_in, hidden_out) dot_1_output = tex.gemm( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel_1.get_tensor(TensorUsage.RHS), - (x_contracting_dims, k_contracting_dims), + contracting_dims=(x_contracting_dims, k_contracting_dims), + batched_dims=((x_bdim,), ()), + bias=bias_1 if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, ) if dot_1_input_axes is not None and kernel_1_axes is not None: @@ -287,7 +322,7 @@ def _layernorm_mlp_fwd_rule( ) dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) - if use_bias_1: + if use_bias_1 and tex.gemm_uses_jax_dot(): bias_1_shape = bias_1.shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape dot_1_output += jnp.reshape(bias_1, bias_1_new_shape) @@ -295,21 +330,28 @@ def _layernorm_mlp_fwd_rule( dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) # (batch..., hidden_in) -> (batch..., hidden) - casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x) + casted_act_out = tex.act_lu( + dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True + ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) - casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel) + casted_kernel_2 = tex.quantize( + kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True + ) # NN GEMM # (batch..., hidden_in) x (hidden_out, hidden_in) dot_2_output = tex.gemm( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), - (x_contracting_dims, k_contracting_dims), + contracting_dims=(x_contracting_dims, k_contracting_dims), + batched_dims=((x_bdim,), ()), + bias=bias_2 if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, ) - if use_bias_2: + if use_bias_2 and tex.gemm_uses_jax_dot(): bias_2_shape = bias_2.shape bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape dot_2_output += jnp.reshape(bias_2, bias_2_new_shape) @@ -334,6 +376,7 @@ def _layernorm_mlp_fwd_rule( use_bias_1, use_bias_2, quantizer_sets, + x_bdim, ) return dot_2_output, ctx @@ -351,6 +394,7 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, ctx, grad, ): @@ -367,7 +411,7 @@ def _layernorm_mlp_bwd_rule( Returns: Tuple of gradients for all input parameters """ - del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name + del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first ( x, mu, @@ -386,6 +430,7 @@ def _layernorm_mlp_bwd_rule( use_bias_1, use_bias_2, quantizer_sets, + x_bdim, ) = ctx ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets @@ -394,7 +439,7 @@ def _layernorm_mlp_bwd_rule( grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) casted_grad, dbias_2 = tex.quantize_dbias( - grad, 
is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad + grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -411,7 +456,8 @@ def _layernorm_mlp_bwd_rule( dgrad_2 = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel_2, - (g_contracting_dims_2, k_contracting_dims_2), + contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), + batched_dims=((x_bdim,), ()), ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) @@ -425,7 +471,8 @@ def _layernorm_mlp_bwd_rule( wgrad_2 = tex.gemm( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), - (x_contracting_dims, g_contracting_dims), + contracting_dims=(x_contracting_dims, g_contracting_dims), + batched_dims=((x_bdim,), (x_bdim,)), ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -435,6 +482,7 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, + noop_scaled_tensor=True, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -451,7 +499,8 @@ def _layernorm_mlp_bwd_rule( dgrad_1 = tex.gemm( casted_dact_out.get_tensor(TensorUsage.LHS), casted_kernel_1, - (g_contracting_dims_1, k_contracting_dims_1), + contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), + batched_dims=((x_bdim,), ()), ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) @@ -461,7 +510,8 @@ def _layernorm_mlp_bwd_rule( wgrad_1 = tex.gemm( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), - (x_contracting_dims, g_contracting_dims), + contracting_dims=(x_contracting_dims, g_contracting_dims), + batched_dims=((x_bdim,), (x_bdim,)), ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 06a2562fb1..2459190f1a 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -36,6 +36,22 @@ def dequantize(scaled_tensor): """Dequantizing given tensor to higher precision.""" +@dataclass +class NoopDequantizer(Dequantizer): + """No-op Dequantizer Class""" + + @staticmethod + def _dequantize_func(data, *args, **kwargs): + """A no-op dequantize function that returns the data without any changes.""" + del args, kwargs + return data + + @staticmethod + def dequantize(scaled_tensor): + """A no-op dequantize function that simply returns the data array in the ScaledTensor.""" + return scaled_tensor.data + + class TensorScaleDequantizer(Dequantizer): """ TensorScaling Dequantizer Class @@ -152,6 +168,7 @@ def dequantize(scaled_tensor): ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer, + ScalingMode.NO_SCALING: NoopDequantizer, } diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index c0617eafbb..122265ea27 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -9,7 +9,9 @@ """ from contextlib import contextmanager from enum import Enum -from typing import Optional, Tuple, Dict, Union +from typing import Optional, Tuple, Dict, Union, Sequence +from functools import reduce +import operator import jax import jax.numpy as jnp @@ -29,6 +31,8 
@@ "is_fp8_available", "update_collections", "get_delayed_scaling", + "apply_padding_to_scale_inv", + "remove_padding_from_scale_inv", "NVTE_FP8_COLLECTION_NAME", ] @@ -471,4 +475,115 @@ def update_collections(new: Collection, original: Collection) -> Collection: return new_coll +def remove_padding_from_scale_inv( + scale_inv: jax.Array, + scaling_mode: ScalingMode, + data_shape: Sequence[int], + is_colwise: bool = False, + flatten_axis: int = -1, +): + """ + Slice padding out of padded inverse scale factors. + + Args: + scale_inv: Inverse scale factor. + data_shape: Shape of the quantized data the inverse scale belongs to. + scaling_mode: ScalingMode representing the quantization method. + is_colwise: Whether the data was quantized column-wise. + flatten_axis: The axis along with the data could be flattened to 2D. + + Returns: + Inverse scale factor without padding. + """ + # Get expected unpadded scale shape and check if inverse scale already matches + unpadded_scale_shape = scaling_mode.get_scale_shape( + data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis + ) + if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == unpadded_scale_shape: + return scale_inv + + # Get the padded scale shape and make sure inverse scale matches + padded_scale_shape = scaling_mode.get_scale_shape( + data_shape, + is_colwise=is_colwise, + is_padded=True, + flatten_axis=flatten_axis, + ) + assert scale_inv.shape == padded_scale_shape, ( + f"Padded inverse scale factor has wrong shape, expected {padded_scale_shape} but got " + f"{scale_inv.shape} instead." + ) + + # Reshape scale inverse to 2D in two stages to preserve the flatten axis + padded_scale_shape_2d = ( + reduce(operator.mul, padded_scale_shape[:flatten_axis]), + reduce(operator.mul, padded_scale_shape[flatten_axis:]), + ) + scale_inv_2d = jnp.reshape( + jnp.reshape(scale_inv, (padded_scale_shape_2d[0], *scale_inv.shape[flatten_axis:])), + padded_scale_shape_2d, + ) + + # Slice reshaped 2D scale inverse using collapsed 2D unpadded_scale_shape + unpadded_scale_shape_2d = ( + reduce(operator.mul, unpadded_scale_shape[:flatten_axis]), + reduce(operator.mul, unpadded_scale_shape[flatten_axis:]), + ) + scale_inv_2d_unpadded = jnp.asarray( + scale_inv_2d[: unpadded_scale_shape_2d[0], : unpadded_scale_shape_2d[1]] + ) + + # Reshape 2D scale inverse back in two stages in order to preserve the flatten axis + scale_inv_unpadded = jnp.reshape( + jnp.reshape( + scale_inv_2d_unpadded, + (*unpadded_scale_shape[:flatten_axis], scale_inv_2d_unpadded.shape[1]), + ), + unpadded_scale_shape, + ) + return scale_inv_unpadded + + +def apply_padding_to_scale_inv( + scale_inv: jax.Array, + scaling_mode: ScalingMode, + data_shape: Sequence[int], + is_colwise: bool = False, + flatten_axis: int = -1, +): + """ + Pad the scale inverse with zeros to match the necessary padded shape for this scaling + mode. + + Args: + scale_inv: Inverse scale factor. + data_shape: Shape of the quantized data the inverse scale belongs to. + scaling_mode: ScalingMode representing the quantization method. + is_colwise: Whether the data was quantized column-wise. + flatten_axis: The axis along with the data could be flattened to 2D. + + Returns: + Padded inverse scale factor. 
+ """ + # Get the expected padded scale shape and check if inverse scale already matches + padded_scale_shape = scaling_mode.get_scale_shape( + data_shape, is_colwise=is_colwise, is_padded=True, flatten_axis=flatten_axis + ) + if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == padded_scale_shape: + return scale_inv + + # Get the expected unpadded scale shape and make sure inverse scales match + unpadded_scale_shape = scaling_mode.get_scale_shape( + data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis + ) + assert scale_inv.shape == unpadded_scale_shape, ( + f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " + f"{scale_inv.shape}." + ) + + # Pad the scales with the lowest representable value (2^-127) and return + pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape)) + return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127) + + NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index f45a05a399..fc4fd13531 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -17,7 +17,7 @@ import operator import numpy as np -from jax.experimental.custom_partitioning import CompoundFactor +from jax.experimental.custom_partitioning import BATCHING from jax.tree_util import register_pytree_node_class import jax.numpy as jnp @@ -252,8 +252,9 @@ def get_shardy_sharding_rules( The Shardy rules for the scaling mode """ del flatten_axis - input_spec = tuple(f"x{i}" for i in range(input_rank)) - return QuantizeShardyRules(input_spec, (unique_var,), (unique_var,), {}) + input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) + scale_var = BATCHING + unique_var + "_scale_inv" + return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl): @@ -488,31 +489,41 @@ def get_shardy_sharding_rules( Returns: The Shardy rules for the scaling mode """ - input_spec = [f"x{i}" for i in range(input_rank)] - - # We have to use two different factors in the two CompoundFactors because of Shardy - # verifier requirements, even though they are the same. - rowwise_var = unique_var - colwise_var = f"{unique_var}_" - input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise") - input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise") - - # The rowwise and colwise scale tensors should be sharded the same way as the input. - # However, we need to adjust the dimensions where the block scaling factor applies. - rowwise = input_spec.copy() - rowwise[-1] = rowwise_var - - colwise = input_spec.copy() - colwise[flatten_axis - 1] = colwise_var - - # This implementation needs to be updated for different block dims. - assert self._block_dims == (1, 32) + del flatten_axis + input_spec = [f"{unique_var}{i}" for i in range(input_rank)] + rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)] + colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)] + + # NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors. + # Unfortunately, because Shardy rules are applied to the inner primitive, the + # only way to preserve the relationship is to lower unpadded scales to the + # underlying custom call and pad them in C++. 
Until that's implemented, the + # Shardy rules for block scales have to be completely disconnected from the + # Shardy rules for the tensor they belong to. + + # # We have to use two different factors in the two CompoundFactors because of Shardy + # # verifier requirements, even though they are the same. + # rowwise_var = unique_var + # colwise_var = f"{unique_var}_" + # input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise") + # input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise") + + # # The rowwise and colwise scale tensors should be sharded the same way as the input. + # # However, we need to adjust the dimensions where the block scaling factor applies. + # rowwise = input_spec.copy() + # rowwise[-1] = rowwise_var + + # colwise = input_spec.copy() + # colwise[flatten_axis - 1] = colwise_var + + # # This implementation needs to be updated for different block dims. + # assert self._block_dims == (1, 32) return QuantizeShardyRules( tuple(input_spec), tuple(rowwise), tuple(colwise), - {"block_size_rowwise": 32, "block_size_colwise": 32}, + {}, # {"block_size_rowwise": 32, "block_size_colwise": 32}, ) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index b2156df95f..d454f8f43a 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -17,6 +17,7 @@ from transformer_engine_jax import QuantizeLayout +from .helper import apply_padding_to_scale_inv from .scaling_modes import ScalingMode, TensorUsage from .dequantizer import ScalingModeToDequantizerMap from ..sharding import ( @@ -56,6 +57,11 @@ def tree_unflatten(cls, aux_data, children): """ return cls(*children, *aux_data) + @property + @abstractmethod + def ndim(self): + """Number of dimensions of the underlying quantized array.""" + @abstractmethod def dequantize(self): """Dequantizes the tensor back to its original precision. @@ -127,24 +133,16 @@ def __post_init__(self): 0 < self.flatten_axis < len(self.data.shape) ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" - expected_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis - ) - expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis - ) - if self.scale_inv.shape != expected_scale_shape: - assert self.scale_inv.shape == expected_unpadded_scale_shape, ( - f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" - f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" - f" {self.scale_inv.shape}" - ) - pad_width = tuple( - (0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) - ) - # padding with the smallest number it can present - self.scale_inv = jnp.pad( - self.scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127 + if self.scaling_mode == ScalingMode.NO_SCALING: + self.scale_inv = jnp.empty((0,), dtype=jnp.float32) + + else: + self.scale_inv = apply_padding_to_scale_inv( + self.scale_inv, + self.scaling_mode, + self.data.shape, + is_colwise=self.is_colwise, + flatten_axis=self.flatten_axis, ) def tree_flatten(self): @@ -164,6 +162,10 @@ def tree_flatten(self): ) return (children, aux_data) + @property + def ndim(self): + return self.data.ndim + def dequantize(self): """Dequantizes the tensor using the stored dequantization function. 
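With NO_SCALING now mapped to NoopDequantizer (see dequantizer.py above) and an empty scale_inv, dequantizing an unscaled tensor is a pass-through; a minimal sketch using a stand-in object that exposes only the .data attribute read by the no-op path:

import jax.numpy as jnp
from transformer_engine.jax.quantize.dequantizer import NoopDequantizer

class _FakeScaledTensor:
    """Stand-in for a ScaledTensor; only the .data field is needed here."""
    def __init__(self, data):
        self.data = data

data = jnp.ones((8, 16), dtype=jnp.bfloat16)
out = NoopDequantizer.dequantize(_FakeScaledTensor(data))
assert out is data  # returned unchanged, no dequantization math applied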
@@ -347,6 +349,11 @@ def tree_flatten(self): aux_data = () return (children, aux_data) + @property + def ndim(self): + """Number of dimensions of the underlying row-wise tensor.""" + return self.rowwise_tensor.ndim + def dequantize(self): """Dequantizes the tensor using the row-wise component's dequantization.
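The ndim properties added above let the backward rules normalize negative contracting dimensions directly on saved ScaledTensors (via tex.sanitize_dims); a plain-Python stand-in for that normalization, which may differ from the real utility in details:

def _sanitize_dims(ndim, dims):
    # Map possibly-negative dimension indices to their non-negative form
    # for a tensor of the given rank.
    return tuple(d % ndim for d in dims)

assert _sanitize_dims(3, (-1,)) == (2,)
assert _sanitize_dims(4, (0, -2)) == (0, 2)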