From 1971658bcc21214db734da5c2f3043160108a2ff Mon Sep 17 00:00:00 2001 From: Philipp Hack Date: Thu, 10 Apr 2025 01:48:06 +0000 Subject: [PATCH 1/3] JAX collective GEMM without compute/communication overlap. Signed-off-by: Philipp Hack --- tests/jax/test_distributed_gemm.py | 307 +++++++ .../common/util/pybind_helper.h | 30 +- transformer_engine/jax/__init__.py | 7 +- .../jax/cpp_extensions/__init__.py | 3 +- transformer_engine/jax/cpp_extensions/gemm.py | 754 ++++++++++++++- transformer_engine/jax/cpp_extensions/misc.py | 7 + .../jax/cpp_extensions/quantization.py | 8 + .../jax/cpp_extensions/transpose.py | 855 ++++++++++++++++++ transformer_engine/jax/csrc/extensions.h | 12 + .../jax/csrc/extensions/gemm.cpp | 108 +++ .../jax/csrc/extensions/pybind.cpp | 1 + transformer_engine/jax/fp8.py | 262 ++++++ transformer_engine/jax/gemm.py | 484 ++++++++++ 13 files changed, 2816 insertions(+), 22 deletions(-) create mode 100644 tests/jax/test_distributed_gemm.py create mode 100644 transformer_engine/jax/cpp_extensions/transpose.py create mode 100644 transformer_engine/jax/fp8.py create mode 100644 transformer_engine/jax/gemm.py diff --git a/tests/jax/test_distributed_gemm.py b/tests/jax/test_distributed_gemm.py new file mode 100644 index 0000000000..60f0d7b6f6 --- /dev/null +++ b/tests/jax/test_distributed_gemm.py @@ -0,0 +1,307 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +import pytest +from functools import partial +from collections.abc import Iterable + +import numpy as np + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils + +import transformer_engine.jax as te +from transformer_engine.jax.gemm import gemm +from transformer_engine.jax.quantize import helper + +from utils import assert_allclose + + +jax.config.update("jax_enable_compilation_cache", False) + + +# AG+GEMM: (4, 32/P, 128) ----(AG)----> (4, 32, 128) x (128, 256/P) ----------> (4, 32, 256/P) +# - DGRAD: (4, 32, 256/P) x (128, 256/P)^T --(AR)--> (4, 32, 128) +# - WGRAD: (4, 32/P, 128)^T --(AG)--> (4, 32, 128)^T x (4, 32, 256/P) --------> (128, 256/P) + +# GEMM+AR: (4, 32, 256/P) x (256/P, 128) --(AR)--> (4, 32, 128) +# - DGRAD: (4, 32, 128) x (256/P, 128)^T ------> (4, 32, 256/P) +# - WGRAD: (4, 32, 256/P)^T --(AG)--> (4, 32, 256)^T x (4, 32, 128) --------> (256, 128) + +BATCH = 4 +BASE_SIZE = 16 +SEQ_LEN = BASE_SIZE * 8 +HIDDEN_SIZE = BASE_SIZE * 6 +FFN_HIDDEN_SIZE = BASE_SIZE * 16 + +COMM_TYPES = ["ALL_GATHER", "ALL_REDUCE"] +MESH_TYPES = ["FSDP_TP", "DP_TP", "TP"] +NUM_DEVICES = 4 + +is_fp8_supported, no_fp8_reason = helper.is_fp8_available() + + +def _get_mesh(parallel_dist): + jax.clear_caches() + + batched = False + fsdp = False + mesh_shape = dict(tp=NUM_DEVICES) + resources = dict(cp_resource="tp", tp_resource="tp") + if parallel_dist in ["DP_TP", "FSDP_TP"]: + batched = True + tp = NUM_DEVICES // 2 + dp = NUM_DEVICES // tp + mesh_shape.update(dict(tp=tp, dp=dp)) + resources.update(dict(dp_resource="dp")) + if parallel_dist == "FSDP_TP": + fsdp = True + dp = 1 + zp = NUM_DEVICES // tp + mesh_shape.update(dict(tp=tp, dp=1, zp=zp)) + resources.update(dict(fsdp_resource="zp")) + mesh_resource = te.MeshResource(**resources) + + devices = mesh_utils.create_device_mesh((NUM_DEVICES,), devices=jax.devices()[:NUM_DEVICES]) + + mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) + + return mesh, mesh_resource, batched, fsdp 
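For orientation, a minimal shape-only sketch of the AG+GEMM layout drawn in the comments above, using the module constants (BATCH, SEQ_LEN, HIDDEN_SIZE, FFN_HIDDEN_SIZE, NUM_DEVICES); the array names and the concatenate standing in for the all-gather are illustrative only and not part of the patch:

import jax.numpy as jnp

P = 4                                  # tensor-parallel size (NUM_DEVICES)
B, S, H, F = 4, 128, 96, 256           # BATCH, SEQ_LEN, HIDDEN_SIZE, FFN_HIDDEN_SIZE

lhs_local = jnp.zeros((B, S // P, H))  # per-device activation shard
rhs_local = jnp.zeros((H, F // P))     # per-device, column-sharded weight shard

# The all-gather over the tp axis restores the full sequence dimension before the GEMM;
# the output inherits the column sharding of the RHS operand.
lhs_gathered = jnp.concatenate([lhs_local] * P, axis=1)  # stands in for the all-gather
out_local = lhs_gathered @ rhs_local
assert out_local.shape == (B, S, F // P)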
+ + +def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bwd=False): + fp8_gemm = dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + + # Operand and output shapes + lhs_shape = ( + [SEQ_LEN, HIDDEN_SIZE] if fwd_comm_type == "ALL_GATHER" else [SEQ_LEN, FFN_HIDDEN_SIZE] + ) + rhs_shape = ( + [HIDDEN_SIZE, FFN_HIDDEN_SIZE] + if fwd_comm_type == "ALL_GATHER" + else [FFN_HIDDEN_SIZE, HIDDEN_SIZE] + ) + out_shape = [lhs_shape[0], rhs_shape[1]] + + if batched: + lhs_shape = [BATCH] + lhs_shape + out_shape = [BATCH] + out_shape + + # Operand and output partition specs + lhs_spec = ( + [mesh_resource.tp_resource, None] + if fwd_comm_type == "ALL_GATHER" + else [None, mesh_resource.tp_resource] + ) + rhs_spec = ( + [None, mesh_resource.tp_resource] + if fwd_comm_type == "ALL_GATHER" + else [mesh_resource.tp_resource, None] + ) + out_spec = [None, rhs_spec[-1]] + + # Modify RHS operand for FP8 + fsdp_gathered_rhs_spec = rhs_spec.copy() + if fp8_gemm: + rhs_shape = list(reversed(rhs_shape)) + rhs_spec = list(reversed(rhs_spec)) + fsdp_gathered_rhs_spec = list(reversed(fsdp_gathered_rhs_spec)) + + # Add batch dimensions and specs + if batched: + if fsdp: + lhs_spec = [(mesh_resource.dp_resource, mesh_resource.fsdp_resource)] + lhs_spec + rhs_spec = [mesh_resource.fsdp_resource if spec is None else spec for spec in rhs_spec] + out_spec = [(mesh_resource.dp_resource, mesh_resource.fsdp_resource)] + out_spec + else: + lhs_spec = [mesh_resource.dp_resource] + lhs_spec + out_spec = [mesh_resource.dp_resource] + out_spec + + # Allocate global operands on device + key = jax.random.PRNGKey(42) + split_keys = jax.random.split(key, 3 if fwd_bwd else 2) + mu = 0.0 + sigma = 0.023 + shapes = (lhs_shape, rhs_shape) + if fwd_bwd: + shapes += (out_shape,) + global_operands = list( + map( + lambda key, shape: jax.device_put( + mu + (sigma * jax.random.normal(key, shape, dtype=dtype)), + NamedSharding(mesh, PartitionSpec(None)), + ), + split_keys, + shapes, + ) + ) + + # Allocate sharded operands on device + partition_axes = (lhs_spec, rhs_spec) + if fwd_bwd: + partition_axes += (out_spec,) + local_operands = list( + map( + lambda x, spec: jax.device_put(x, NamedSharding(mesh, PartitionSpec(*spec))), + global_operands, + partition_axes, + ) + ) + + # Tranpose global RHS back to non-transpoosed orientation if it was originally allocated + # for FP8 GEMM + if fp8_gemm: + rhs_global = jnp.matrix_transpose(global_operands[1]) + global_operands = (global_operands[0], rhs_global, *global_operands[2:]) + + return ( + local_operands, + global_operands, + (out_shape, out_spec), + fsdp_gathered_rhs_spec, + ) + + +def _check_output(mesh, expected_out_shape, expected_out_specs, *tensors, fwd_bwd=False): + num_operands = 3 if fwd_bwd else 2 + ref_operands = tensors[:num_operands] + test_outputs = tensors[num_operands:] + + # Check number of dimensions + assert test_outputs[0].ndim == len(expected_out_shape), ( + f"Output has different number of dimensions ({test_outputs[0].ndim}) than expected " + + f"({len(expected_out_shape)})" + ) + + # Pad test output spec for unsharded dimensions + test_spec = te.sharding.get_padded_spec(test_outputs[0].sharding.spec, test_outputs[0].ndim) + + for i in range(test_outputs[0].ndim): + # Check shape + assert test_outputs[0].shape[i] == expected_out_shape[i], ( + f"Output with shape {test_outputs[0].shape} does not match expected shape " + + f"{expected_out_shape} in dimension index {i}." 
+ ) + + # Check shardings (with padded output spec) + spec_mismatch = False + if isinstance(expected_out_specs[i], str): + if test_spec[i] != expected_out_specs[i]: + spec_mismatch = True + elif isinstance(expected_out_specs[i], Iterable): + if not isinstance(test_spec[i], type(expected_out_specs[i])): + if test_spec[i] not in expected_out_specs[i]: + spec_mismatch = True + elif len(test_spec[i]) != len(expected_out_specs[i]): + spec_mismatch = True + else: + for j in range(len(expected_out_specs[i])): + if test_spec[i][j] != expected_out_specs[i][j]: + spec_mismatch = True + break + elif expected_out_specs[i] == None: + if test_spec[i] != None: + spec_mismatch = True + else: + raise RuntimeError("Internal TE error: Unrecognized reference partition spec type.") + if spec_mismatch: + raise AssertionError( + f"Output sharding {test_spec} does not match expected sharding " + + f"{expected_out_specs} in dimension index {i}." + ) + + def _native_gemm_fwd_bwd(lhs, rhs, grad): + fwd_out, vjp_fn = jax.vjp(jnp.dot, lhs, rhs) + lhs_grad, rhs_grad = vjp_fn(grad) + return fwd_out, lhs_grad, rhs_grad + + ref_fn = jax.jit(_native_gemm_fwd_bwd if fwd_bwd else jnp.dot) + + out_names = ["output"] + ref_outputs = ref_fn(*ref_operands) + if not fwd_bwd: + ref_outputs = [ref_outputs] + else: + out_names += ["dgrad", "wgrad"] + + for i, (test_out, ref_out) in enumerate(zip(test_outputs, ref_outputs)): + test_out_global = jax.lax.with_sharding_constraint( + test_out, NamedSharding(mesh, PartitionSpec(None)) + ) + try: + assert_allclose(ref_out, test_out_global) + except AssertionError as err: + raise AssertionError(f"Numerical mismatch in {out_names[i]}:\n" + str(err)) + + +@pytest.mark.parametrize("comm_type", COMM_TYPES) +@pytest.mark.parametrize("mesh_type", MESH_TYPES) +def test_gemm_impl(comm_type, mesh_type): + mesh, mesh_resource, batched, fsdp = _get_mesh(mesh_type) + + ( + local_operands, + global_operands, + output_info, + fsdp_gathered_rhs_spec, + ) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp) + + @jax.jit + def _test_fn(lhs, rhs): + rhs_no_fsdp = jax.lax.with_sharding_constraint( + rhs, NamedSharding(mesh, PartitionSpec(*fsdp_gathered_rhs_spec)) + ) + return te.cpp_extensions.collective_gemm_impl(lhs, rhs_no_fsdp, batched_output=batched) + + with te.sharding.global_shard_guard(mesh_resource): + output, *_ = _test_fn(*local_operands) + + _check_output(mesh, *output_info, *global_operands, output) + + +@pytest.mark.parametrize("comm_type", COMM_TYPES) +@pytest.mark.parametrize("mesh_type", MESH_TYPES) +def test_gemm_fwd_bwd(comm_type, mesh_type): + mesh, mesh_resource, batched, fsdp = _get_mesh(mesh_type) + + ( + local_operands, + global_operands, + output_info, + fsdp_gathered_rhs_spec, + ) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True) + + @jax.jit + def _test_fn(lhs, rhs, grad): + # Gather weights in FSDP axis + rhs_no_fsdp = jax.lax.with_sharding_constraint( + rhs, NamedSharding(mesh, PartitionSpec(*fsdp_gathered_rhs_spec)) + ) + + # FWD pass + fwd_out, vjp_fn = jax.vjp(gemm, lhs, rhs_no_fsdp) + + # BWD pass + lhs_grad, rhs_grad = vjp_fn(grad) + + return fwd_out, lhs_grad, rhs_grad + + print( + f"INPUTS: {local_operands[0].shape} x {local_operands[1].shape}\n" + + f" LHS sharding: {local_operands[0].sharding.spec}\n" + + f" RHS sharding: {local_operands[1].sharding.spec}\n" + ) + + with te.sharding.global_shard_guard(mesh_resource): + output, dgrad, wgrad = _test_fn(*local_operands) + + print( + f"{'AG + GEMM' if comm_type == 
'AG' else 'GEMM + AR'} output: " + + f"{output.shape} | {output.sharding.spec}\n" + + f"DGRAD: {dgrad.shape} | {dgrad.sharding.spec}\n" + + f"WGRAD: {wgrad.shape} | {wgrad.sharding.spec}\n" + ) + + _check_output(mesh, *output_info, *global_operands, output, dgrad, wgrad, fwd_bwd=True) diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index b8c8df37ee..f481ff39ab 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -91,30 +91,30 @@ .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ - py::class_>(m, "CommOverlapCore", \ pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \ - py::call_guard()) \ + .def(pybind11::init([]() { return new transformer_engine::CommOverlapCore(); }), \ + pybind11::call_guard()) \ .def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \ - py::call_guard()) \ + pybind11::call_guard()) \ .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ - py::call_guard()) \ + pybind11::call_guard()) \ .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ - py::call_guard()); \ - py::class_()); \ + pybind11::class_, \ transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ - py::call_guard()); \ - py::class_()); \ + pybind11::class_, \ transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \ pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ - py::call_guard()); \ + .def(pybind11::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ + pybind11::call_guard()); \ m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ - py::call_guard(), py::arg("device_id") = -1); \ + pybind11::call_guard(), pybind11::arg("device_id") = -1); \ m.def( \ "get_stream_priority_range", \ [](int device_id = -1) { \ @@ -122,8 +122,8 @@ transformer_engine::cuda::stream_priority_range(&low_pri, &high_pri, device_id); \ return std::make_pair(low_pri, high_pri); \ }, \ - py::call_guard(), py::arg("device_id") = -1); \ + pybind11::call_guard(), pybind11::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ - py::call_guard()); + pybind11::call_guard()); #endif diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index ab56d60f59..cad5697daa 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -31,10 +31,10 @@ from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension +module_name = "transformer_engine_jax" def _load_library(): """Load shared library with Transformer Engine C extensions""" - module_name = "transformer_engine_jax" if is_package_installed(module_name): assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." @@ -79,7 +79,9 @@ def _load_library(): spec.loader.exec_module(solib) -_load_library() +if module_name not in sys.modules: + _load_library() + from . import flax from . 
import quantize @@ -101,7 +103,6 @@ def _load_library(): ) __all__ = [ - "fp8_autocast", "MeshResource", "MajorShardingType", "ShardingResource", diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index ef8d76cd05..1afc172c9a 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -4,7 +4,8 @@ """Python interface for c++ extensions""" from .activation import * from .attention import * +from .gemm import * from .normalization import * from .quantization import * from .softmax import * -from .gemm import * +from .transpose import * diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0fad75817f..ab35ad1f54 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -3,15 +3,31 @@ # See LICENSE for license information. """JAX te modules""" -from typing import Tuple, Sequence, Union, Dict, List -from functools import partial, reduce +import warnings import operator +from functools import partial, reduce +from typing import Optional, Tuple, Sequence, Union, Dict, List +from packaging import version + from transformer_engine_jax import get_device_compute_capability import jax import jax.numpy as jnp +from jax import dtypes +from jax.sharding import PartitionSpec, NamedSharding +from jax.typing import ArrayLike from .base import BasePrimitive, register_primitive +from .misc import ( + jax_dtype_is_fp8, + get_padded_spec, + check_valid_batch_dims, +) +from ..sharding import ( + global_mesh_resource, + all_reduce_max_along_all_axes_except_PP, +) + from ..quantize import ( ScaledTensor, ScalingMode, @@ -20,12 +36,27 @@ noop_quantizer_set, ) +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + -__all__ = ["gemm", "grouped_gemm"] +__all__ = ["gemm", + "grouped_gemm", + "collective_fp8_gemm_impl", + "collective_gemm_impl"] num_cublas_streams = 4 +def sanitize_dims(dim, ndims): + return (ndims + dim) if dim < 0 else dim + + +def mirror_dim(dim, ndims): + return ndims - 2 if dim == ndims - 1 else ndims - 1 + def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" @@ -514,3 +545,720 @@ def grouped_gemm( out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape)) return out_tensors +class CollectiveGemmPrimitive(BasePrimitive): + """ + cuBlasLt GEMM Primitive w/ support for distributed inputs + """ + + name = "te_gemm" + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15) + multiple_results = True + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + lhs_aval, + lhs_scale_inv_aval, + rhs_aval, + rhs_scale_inv_aval, + bias_aval, + gelu_input_aval, + out_amax_aval, + out_scale_aval, + out_dtype, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): + """ + cuBlasLt GEMM abstract + """ + del grad, accumulate, use_split_accumulator + + # Validate operand dtypes + lhs_dtype = dtypes.canonicalize_dtype(lhs_aval.dtype) + rhs_dtype = dtypes.canonicalize_dtype(rhs_aval.dtype) + assert lhs_dtype == rhs_dtype, "Mismatched matrix dtypes for GEMM." 
+ is_fp8 = False + if jax_dtype_is_fp8(lhs_dtype): + assert ( + lhs_scale_inv_aval.size == 1 + and dtypes.canonicalize_dtype(lhs_scale_inv_aval.dtype) == jnp.float32 + ), "Missing LHS operand scale inverse in FP8 GEMM." + is_fp8 = True + if jax_dtype_is_fp8(rhs_dtype): + assert ( + rhs_scale_inv_aval.size == 1 + and dtypes.canonicalize_dtype(rhs_scale_inv_aval.dtype) == jnp.float32 + ), "Missing RHS operand scale inverse in FP8 GEMM." + + # Validate operand layouts + lhs_inner_dim, rhs_inner_dim = map( + sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) + ) + assert ( + lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim] + ), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}." + + lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 + rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 + assert not ( + lhs_trans and rhs_trans + ), "GEMM does not support transposed LHS and transposed RHS at the same time." + if is_fp8: + assert not lhs_trans, "FP8 GEMM does not support transposed LHS." + assert rhs_trans, "FP8 GEMM requires transposed RHS." + + # Validate output dtype + if jax_dtype_is_fp8(out_dtype): + assert jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8( + rhs_dtype + ), "FP8 GEMM output requires FP8 inputs." + assert ( + out_amax_aval.size == out_scale_aval.size == 1 + ), "Invalid/missing output amax and scale." + out_amax_updated_dtype = dtypes.canonicalize_dtype(out_amax_aval.dtype) + out_scale_updated_dtype = dtypes.canonicalize_dtype(out_scale_aval.dtype) + assert ( + out_amax_updated_dtype == out_scale_updated_dtype == jnp.float32 + ), "Invalid output amax or scale dtype." + else: + out_dtype = lhs_dtype + out_amax_updated_dtype = jnp.float32 + out_scale_updated_dtype = jnp.float32 + + # Make sure leading dimensions of RHS is broadcast-compatible with LHS + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, + (lhs_inner_dim, rhs_inner_dim), + (lhs_aval.ndim, rhs_aval.ndim), + ) + lhs_bdims = [ + dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim] + ] + lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] + lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) + + # Infer output shape + if batched_output: + assert ( + lhs_aval.ndim > 2 and rhs_aval.ndim == 2 + ), "Batched output requires batched LHS and non-batched RHS operands." + out_shape = ( + *lhs_batch_shape, + lhs_aval.shape[lhs_outer_dim], + rhs_aval.shape[rhs_outer_dim], + ) + else: + assert ( + lhs_aval.ndim == rhs_aval.ndim + ), "Non-batched output requires LHS and RHS operands with same number of dimensions." + if lhs_aval.ndim > 2: + rhs_bdims = [ + dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] + ] + rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims] + rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) + assert lhs_batch_size == rhs_batch_size, ( + f"Leading dimensins of RHS ({rhs_aval.shape=}) is not broadcast-compatible " + + f"with the leading dimensions of LHS ({lhs_aval.shape=})." + ) + out_shape = (lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) + + # Validate bias/bias_grad shape against inferred output + bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype + if fuse_bias: + assert ( + bias_aval.size > 0 and bias_aval.ndim == 1 and bias_aval.shape[0] == out_shape[-1] + ), "Incorrect bias shape." + bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + else: + assert bias_aval.size == 0, "Internal TE error." 
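+ # Example: for an LHS of shape (B, M, K), an RHS of shape (K, N) and contracting_dims=(2, 0),
+ # batched_output=True infers out_shape == (B, M, N), so a fused bias must be a 1-D vector
+ # of length N.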
+ + # Validate GELU input/output + gelu_shape = (0,) + if fuse_gelu: + gelu_shape = ( + (reduce(operator.mul, out_shape[:-1], 1), out_shape[-1]) + if len(out_shape) > 2 + else out_shape + ) + assert gelu_input_aval.ndim == 2 and all( + gelu_input_aval.shape[i] == gelu_shape[i] for i in range(len(gelu_shape)) + ), "Invalid GELU input shape." + assert gelu_input_aval.dtype == bias_dtype, "Invalid GELU dtype." + else: + assert gelu_input_aval.size == 0, "Internal TE error." + + # Create abstract arrays for all outputs + out_aval = lhs_aval.update(shape=out_shape, dtype=out_dtype) + out_amax_updated_aval = out_amax_aval.update( + shape=out_amax_aval.shape, dtype=out_amax_updated_dtype + ) + out_scale_updated_aval = out_scale_aval.update( + shape=out_scale_aval.shape, dtype=out_scale_updated_dtype + ) + pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_shape, dtype=bias_dtype) + bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) + workspace_aval = jax.core.ShapedArray( + shape=(get_cublas_workspace_size_bytes(),), dtype=jnp.uint8 + ) + + return ( + out_aval, + out_amax_updated_aval, + out_scale_updated_aval, + pre_gelu_out_aval, + bias_grad_aval, + workspace_aval, + ) + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + cuBlasLt GEMM outer abstract + """ + (out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval, _) = ( + CollectiveGemmPrimitive.abstract(*args, **kwargs) + ) + return out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval + @staticmethod + def lowering( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + *, + out_dtype, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): + """ + Fused attention fwd lowering rules + """ + del batched_output + lhs_aval, _, rhs_aval, _, _, *_ = ctx.avals_in + lhs_inner_dim, rhs_inner_dim = map( + sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) + ) + lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 + rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 + + operand_output_aliases = { + 4: 4, # bias <--> bias_grad + 5: 3, # gelu_input <--> pre_gelu_out + 6: 1, # out_amax <--> out_amax_updated + 7: 2, # out_scale <--> out_scale_updated + } + + name = "te_gemm_ffi" + return ffi.ffi_lowering(name, operand_output_aliases=operand_output_aliases)( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + lhs_trans=lhs_trans, + rhs_trans=rhs_trans, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + + @staticmethod + def impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): + assert CollectiveGemmPrimitive.inner_primitive is not None + + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, (lhs_inner_dim, rhs_inner_dim), (lhs.ndim, rhs.ndim) + ) + + # Infer output shape and collapse batch dimensions + lhs_2d_shape = rhs_2d_shape = None + lhs_layout = rhs_layout = None + lhs_batch_dims = [ + dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim] + ] + lhs_batch_shape = [lhs.shape[dim] for dim in lhs_batch_dims] + lhs_batch_size = 
reduce(operator.mul, lhs_batch_shape, 1) + contracting_dims_2d = list(contracting_dims).copy() + if lhs.ndim > 2 and rhs.ndim > 2: + # If both LHS and RHS are batched, the batch dimensions collapse into the + # contracting dimensions for both operands + lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_inner_dim], lhs.shape[lhs_outer_dim]) + lhs_layout = (*lhs_batch_dims, lhs_inner_dim, lhs_outer_dim) + contracting_dims_2d[0] = 0 + + rhs_batch_dims = [ + dim for dim in range(rhs.ndim) if dim not in [rhs_inner_dim, rhs_outer_dim] + ] + rhs_batch_shape = [rhs.shape[dim] for dim in rhs_batch_dims] + rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) + rhs_2d_shape = (rhs_batch_size * rhs.shape[rhs_inner_dim], rhs.shape[rhs_outer_dim]) + rhs_layout = (*rhs_batch_dims, rhs_inner_dim, rhs_outer_dim) + contracting_dims_2d[1] = 0 + elif lhs.ndim > 2: + # If only the LHS is batched,the batch dimension collapses into the outer dimension + lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_outer_dim], lhs.shape[lhs_inner_dim]) + lhs_layout = (*lhs_batch_dims, lhs_outer_dim, lhs_inner_dim) + contracting_dims_2d[0] = 1 + + # Reshape LHS and RHS into 2D and fix layouts for FP8 GEMM + if lhs_2d_shape is not None and lhs.ndim > 2: + lhs = jax.lax.reshape(lhs, lhs_2d_shape, dimensions=lhs_layout) + if jax_dtype_is_fp8(lhs.dtype): + lhs = jax.lax.transpose(lhs, (1, 0)) + contracting_dims_2d[0] = 1 + else: + contracting_dims_2d[0] = contracting_dims[0] + + if rhs_2d_shape is not None and rhs.ndim > 2: + rhs = jax.lax.reshape(rhs, rhs_2d_shape, dimensions=rhs_layout) + if jax_dtype_is_fp8(rhs.dtype): + rhs = jax.lax.transpose(rhs, (1, 0)) + contracting_dims_2d[1] = 1 + else: + contracting_dims_2d[1] = contracting_dims[1] + + # Invoke GEMM with guaranteed 2D inputs, so batched_output=False + ( + out, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + _, + ) = CollectiveGemmPrimitive.inner_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + batched_output=False, + contracting_dims=contracting_dims_2d, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # Recover batched dimensions in the output + if batched_output: + out_shape = (*lhs_batch_shape, out.shape[-2] // lhs_batch_size, out.shape[-1]) + out = jax.lax.reshape(out, out_shape) + + return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + out_dtype, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): + assert CollectiveGemmPrimitive.outer_primitive is not None + check_valid_batch_dims(batch_dims) + lhs_bdims, *_, bias_bdims, gelu_input_bdims, out_amax_bdims, out_scale_bdims = batch_dims + + return ( + CollectiveGemmPrimitive.outer_primitive.bind( + *batched_args, + out_dtype=out_dtype, + batched_output=batched_output, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ), + (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims), + ) + + @staticmethod + def infer_sharding_from_operands( + out_dtype, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + mesh, + arg_infos, + result_infos, + ): + del out_dtype, accumulate, 
use_split_accumulator, result_infos + lhs, _, rhs, *_ = arg_infos + lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) + + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim), + ) + + # Modify operand specs: + # - If contracting dimensions of both operands are sharded, force them to match. + # - If contracting dimensions of both operands are sharded, all-gather outer dimensions. + # - If contracting dimension of only one operand is sharded, all-gather the sharded + # operand. + # - Never scatter any operand. + lhs_spec_new = list(lhs_spec).copy() + rhs_spec_new = list(rhs_spec).copy() + lhs_spec_new[lhs_outer_dim] = None + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: + assert ( + lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim] + ), "Contracting dimensions of LHS and RHS operands must have the same sharding." + if lhs_spec_new[lhs_outer_dim] is not None: + warnings.warn( + "Outer dimension of the LHS operand must be all-gathered when both contracting " + + "dimensions are sharded. This will cause additional communication overhead." + ) + + if rhs_spec_new[rhs_outer_dim] is not None: + warnings.warn( + "Outer dimension of the RHS operand must be all-gathered when both contracting " + + "dimensions are sharded. This will cause additional communication overhead." + ) + rhs_spec_new[rhs_outer_dim] = None + else: + if lhs_spec_new[lhs_inner_dim] is None and rhs_spec_new[rhs_inner_dim] is not None: + warnings.warn( + "Contracting dimension of the RHS operand must be all-gathered when the " + + "contracting dimension of the LHS operand is unsharded. This will cause " + + "additional communication overhead." + ) + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is None: + if not grad: + # This is expected for sequence/context-parallel gradient in BWD (DGRAD) GEMM. + warnings.warn( + "Contracting dimension of the LHS operand must be all-gathered when the " + + "contracting dimension of the RHS operand is unsharded. This will cause " + + "additional communication overhead." 
+ ) + lhs_spec_new[lhs_inner_dim] = None + rhs_spec_new[rhs_inner_dim] = None + out_col_spec = rhs_spec_new[rhs_outer_dim] + + # Output sharding is conditional on output shape + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] + batch_spec = [lhs_spec_new[dim] for dim in lhs_bdims] + out_spec = [None, out_col_spec] + if batched_output: + out_spec = batch_spec + out_spec + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) + + # FP8 metas are always unsharded + fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) + + # Pre-GELU output is always 2D if GELU fusion is turned on, otherwise unsharded + gelu_spec = [None, out_col_spec] if fuse_gelu else [None] + gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) + + # Bias gradient spec matches outer dimension of output if bias fusion is turned on + bias_sharding = NamedSharding(mesh, PartitionSpec(out_col_spec if fuse_bias else None)) + + return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding) + + @staticmethod + def partition( + out_dtype, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + mesh, + arg_infos, + result_infos, + ): + del result_infos + lhs, _, rhs, *_ = arg_infos + lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) + + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim), + ) + + # Modify operand specs: + # - Always all-gather the outer dimension of LHS. + # - If contracting dimensions of both operands are sharded, all-gather RHS outer dimension. + # - If contracting dimension of only one operand is sharded, all-gather the sharded + # operand. + # - Never scatter any operand. 
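+ # For example: with contracting_dims=(2, 0), lhs_spec=(dp, None, tp) and rhs_spec=(tp, None),
+ # both contracting dimensions are sharded over `tp`, so the RHS outer dimension is kept
+ # unsharded and the partial GEMM output is all-reduced (psum over `tp`) in `sharded_impl`
+ # below (reduce_output=True).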
+ lhs_spec_new = list(lhs_spec).copy() + rhs_spec_new = list(rhs_spec).copy() + reduce_output = False + lhs_spec_new[lhs_outer_dim] = None + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: + rhs_spec_new[rhs_outer_dim] = None + reduce_output = True + else: + lhs_spec_new[lhs_inner_dim] = None + rhs_spec_new[rhs_inner_dim] = None + out_col_spec = rhs_spec_new[rhs_outer_dim] + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) + rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec_new)) + + # Bias is sharded to match outer dimension spec of the RHS operand (also the output) + bias_sharding = NamedSharding(mesh, PartitionSpec(out_col_spec if fuse_bias else None)) + + # FP8 metas are always unsharded + fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) + + # Output sharding is conditional on output shape + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] + batch_spec = [lhs_spec_new[dim] for dim in lhs_bdims] + out_spec = [None, out_col_spec] + if batched_output: + out_spec = batch_spec + out_spec + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) + + # Pre-GELU output is always 2D if GELU fusion is turned on, otherwise unsharded + gelu_spec = [None, out_col_spec] if fuse_gelu else [None] + gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) + + arg_shardings = ( + lhs_sharding, + fp8_meta_sharding, + rhs_sharding, + fp8_meta_sharding, + bias_sharding, + gelu_sharding, + fp8_meta_sharding, + fp8_meta_sharding, + ) + out_shardings = ( + out_sharding, + fp8_meta_sharding, + fp8_meta_sharding, + gelu_sharding, + bias_sharding, + ) + + def sharded_impl( + lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale + ): + ( + out, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + ) = CollectiveGemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + batched_output=batched_output, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # FP8 amax reduction + if jax_dtype_is_fp8(lhs.dtype): + out_amax_updated = all_reduce_max_along_all_axes_except_PP(out_amax_updated, mesh) + + # All-reduce sum GEMM output when contracting dimensions are sharded + if reduce_output: + out = jax.lax.psum(out, global_mesh_resource().tp_resource) + if fuse_gelu: + pre_gelu_out = jax.lax.psum(pre_gelu_out, global_mesh_resource().tp_resource) + + return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(CollectiveGemmPrimitive) + + +def collective_fp8_gemm_impl( + lhs: ArrayLike, + lhs_scale_inv: ArrayLike, + rhs_t: ArrayLike, + rhs_scale_inv: ArrayLike, + bias: Optional[ArrayLike] = None, + gelu_input: Optional[ArrayLike] = None, + out_amax: Optional[ArrayLike] = None, + out_scale: Optional[ArrayLike] = None, + out_dtype: jnp.dtype = jnp.bfloat16, + batched_output: bool = False, + fuse_gelu: bool = False, + fuse_bias: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> Tuple[ArrayLike, ...]: + """FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" + if out_dtype is not None and jax_dtype_is_fp8(out_dtype): + assert out_amax is not None and out_scale is not None, "Missing output amax and scale." 
+ else: + out_amax = jnp.zeros(0, dtype=jnp.float32) + out_scale = jnp.zeros(0, dtype=jnp.float32) + + if not fuse_bias: + bias = jnp.zeros(0, dtype=jnp.bfloat16) + else: + assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." + + if not fuse_gelu: + gelu_input = jnp.zeros(0, dtype=bias.dtype) + elif gelu_input is None: + gelu_shape = (reduce(operator.mul, lhs.shape[:-1]), rhs_t.shape[-1]) + gelu_input = jnp.zeros(gelu_shape, dtype=bias.dtype) + + out, out_amax, out_scale, pre_gelu_out, _ = CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + lhs_scale_inv, + rhs_t, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + batched_output=batched_output, + contracting_dims=(-1, -1), + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=False, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + return out, out_amax, out_scale, pre_gelu_out + + +def collective_gemm_impl( + lhs: ArrayLike, + rhs: ArrayLike, + bias: Optional[ArrayLike] = None, + gelu_input: Optional[ArrayLike] = None, + batched_output: bool = False, + contracting_dims: Tuple[int, int] = (-1, -2), + fuse_gelu: bool = False, + fuse_bias: bool = False, + grad: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> Tuple[ArrayLike, ...]: + """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim), + ) + + if not fuse_bias: + bias = jnp.zeros(0, dtype=lhs.dtype) + elif grad: + bias = jnp.zeros(rhs.shape[rhs_outer_dim], dtype=lhs.dtype) + else: + assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." + + if not fuse_gelu: + gelu_input = jnp.zeros(0, dtype=lhs.dtype) + elif grad: + assert ( + gelu_input is not None + ), "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." + elif gelu_input is None: + bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] + batch_size = reduce(operator.mul, [lhs.shape[dim] for dim in bdims], 1) + gelu_shape = (batch_size * lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + gelu_input = jnp.zeros(gelu_shape, dtype=lhs.dtypes) + + dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) + out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + dummy_fp8_meta, + rhs, + dummy_fp8_meta, + bias, + gelu_input, + dummy_fp8_meta, + dummy_fp8_meta, + out_dtype=lhs.dtype, + batched_output=batched_output, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if grad: + return out, pre_gelu_out, bias_grad + return out, pre_gelu_out, None diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 980ea556bb..d383e75143 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -161,6 +161,13 @@ def jax_version_meet_requirement(version: str): return jax_version >= jax_version_required +def jax_dtype_is_fp8(dtype): + """ + Check if the given jax.numpy.dtype is an FP8 dtype. 
+ """ + return dtypes.canonicalize_dtype(dtype) in [jnp.float8_e4m3fn, jnp.float8_e5m2] + + def get_xla_flag(flag: str, default=None, cast=str): """ Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value. diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 551b4b4bdb..9030387eb9 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -34,6 +34,14 @@ __all__ = ["quantize", "quantize_dbias"] +def _jax_cast_fp8(inputs, scale, amax, out_dtype): + """ + JAX native fp8 casting implementation + """ + casted_output = _jax_quantize(inputs, scale, dq_dtype=out_dtype) + updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype)) + return casted_output, updated_amax + class DBiasQuantizePrimitive(BasePrimitive): """ diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py new file mode 100644 index 0000000000..b7e6ac9ada --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -0,0 +1,855 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX/TE custom ops for transpose""" +from functools import partial, reduce +from typing import Tuple, Sequence, Union, Callable +import operator +from packaging import version + +import jax +import jax.numpy as jnp +from jax import dtypes +from jax.sharding import PartitionSpec, NamedSharding + +import transformer_engine_jax +from transformer_engine_jax import DType as TEDType + +from .base import BasePrimitive, register_primitive +from .misc import ( + check_valid_batch_dims, + jax_dtype_to_te_dtype, + te_dtype_to_jax_dtype, + get_padded_spec, + multidim_transpose, + normalize_axis_boundary, +) +from .activation import ActivationEnum +from .activation import _jax_act_lu +from .quantization import _jax_cast_fp8 +from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp + +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + + +__all__ = [ + "transpose", + "cast_transpose", + "dbias_cast_transpose", + "dact_lu_dbias_cast_transpose", +] + + +def _jax_transpose(inputs, static_axis_boundary, transpose_axis_boundary): + """ + JAX native transpose implementation + """ + axes = multidim_transpose(range(inputs.ndim), static_axis_boundary, transpose_axis_boundary) + return jnp.transpose(inputs, axes=axes) + + +def _jax_cast_transpose( + inputs, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary +): + """ + JAX native cast_transpose implementation + """ + casted_output, updated_amax = _jax_cast_fp8(inputs, scale, amax, out_dtype=out_dtype) + casted_transposed_output = _jax_transpose( + casted_output, static_axis_boundary, transpose_axis_boundary + ) + return casted_output, casted_transposed_output, updated_amax + + +def _jax_dbias_cast_transpose( + dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary +): + """ + JAX native dbias_cast_transpose implementation + """ + casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose( + dz, + scale, + amax, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + dbias = jnp.sum( 
+ dz, + axis=tuple( + range( + transpose_axis_boundary + if transpose_axis_boundary > 0 + else transpose_axis_boundary + dz.ndim + ) + ), + keepdims=False, + ) + dbias = dbias.ravel() # C++ function returns an 1D array for dbias + return casted_dz, cast_transposed_dz, dbias, updated_amax + + +class TransposePrimitive(BasePrimitive): + """ + Transpose Primitive + """ + + name = "te_transpose" + multiple_results = False + impl_static_args = (1, 2) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary): + """ + _transpose abstract + """ + transposed_x_shape = multidim_transpose( + x_aval.shape, static_axis_boundary, transpose_axis_boundary + ) + xt_aval = x_aval.update(shape=transposed_x_shape, dtype=x_aval.dtype) + + return xt_aval + + @staticmethod + def lowering(ctx, x, *, transpose_axis_boundary): + """ + _transpose cuda lowering + """ + + x_aval = ctx.avals_in[0] + assert x_aval.dtype in [ + jnp.float32, + jnp.float16, + jnp.bfloat16, + jnp.float8_e4m3fn, + jnp.float8_e5m2, + ] + + name = "te_transpose_ffi" + return ffi.ffi_lowering(name)(ctx, x, transpose_axis=transpose_axis_boundary) + + @staticmethod + def impl(x, static_axis_boundary, transpose_axis_boundary): + """ + tcast_transpose implementation + """ + assert TransposePrimitive.inner_primitive is not None + transposed_x = TransposePrimitive.inner_primitive.bind( + x, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + return transposed_x + + @staticmethod + def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary): + check_valid_batch_dims(batch_dims) + assert TransposePrimitive.outer_primitive is not None + assert static_axis_boundary < 0 + + (x,) = batched_args + (x_bdim,) = batch_dims + + # Minus batch dim. 
+ transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) + transpose_axis_boundary += 1 # Plus batch dim + + out_bdims = x_bdim + return ( + TransposePrimitive.outer_primitive.bind( + x, static_axis_boundary=x_bdim, transpose_axis_boundary=transpose_axis_boundary + ), + out_bdims, + ) + + @staticmethod + def infer_sharding_from_operands( + static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos + ): + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + return transposed_x_sharding + + @staticmethod + def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = transposed_x_sharding + + impl = partial( + TransposePrimitive.impl, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(TransposePrimitive) + + +def transpose( + x: jnp.ndarray, static_axis_boundary: int, transpose_axis_boundary: int +) -> jnp.ndarray: + """ + transpose wrapper + """ + if not TransposePrimitive.enabled(): + return _jax_transpose(x, static_axis_boundary, transpose_axis_boundary) + return TransposePrimitive.outer_primitive.bind( + x, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + + +class CastTransposePrimitive(BasePrimitive): + """ + Cast Transpose Primitive + """ + + name = "te_cast_transpose" + multiple_results = True + impl_static_args = (4, 5, 6) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + x_aval, + amax_aval, + scale_aval, + scale_inv_aval, + *, + out_dtype, + static_axis_boundary, + transpose_axis_boundary + ): + """ + te_cast_transpose_p abstract + """ + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + + transposed_x_shape = multidim_transpose( + x_aval.shape, static_axis_boundary, transpose_axis_boundary + ) + + casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) + casted_xt_aval = x_aval.update(shape=transposed_x_shape, dtype=out_dtype) + updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) + + return casted_x_aval, casted_xt_aval, updated_amax_aval + + @staticmethod + def lowering( + ctx, x, amax, scale, scale_inv, *, transpose_axis_boundary + ): + """ + te_cast_transpose_p lowering rules + """ + x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + name = "te_cast_transpose_ffi" + return ffi.ffi_lowering(name, operand_output_aliases={1: 2})( + ctx, x, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary + ) + + @staticmethod + def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, 
transpose_axis_boundary): + """ + te_cast_transpose implementation + """ + assert CastTransposePrimitive.inner_primitive is not None + casted_x, casted_transposed_x, updated_amax = CastTransposePrimitive.inner_primitive.bind( + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + return casted_x, casted_transposed_x, updated_amax + + @staticmethod + def batcher( + batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary + ): + check_valid_batch_dims(batch_dims) + assert CastTransposePrimitive.outer_primitive is not None + assert static_axis_boundary < 0 + + x, amax, scale, scale_inv = batched_args + x_bdim, amax_bdim, *_ = batch_dims + + # Minus batch dim. + transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) + transpose_axis_boundary += 1 # Plus batch dim + + out_bdims = x_bdim, x_bdim, amax_bdim + return ( + CastTransposePrimitive.outer_primitive.bind( + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=x_bdim, + transpose_axis_boundary=transpose_axis_boundary, + ), + out_bdims, + ) + + @staticmethod + def infer_sharding_from_operands( + out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos + ): + del out_dtype, result_infos + x_spec = get_padded_spec(arg_infos[0]) + casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + return (casted_x_sharding, casted_transposed_x_sharding, amax_sharding) + + @staticmethod + def partition( + out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos + ): + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding) + + def sharded_impl(x, amax, scale, scale_inv): + local_cx, local_cxt, local_updated_amax = CastTransposePrimitive.impl( + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh) + + return local_cx, local_cxt, global_updated_amax + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(CastTransposePrimitive) + + +def cast_transpose( + x: jnp.ndarray, + amax: jnp.ndarray, + scale: jnp.ndarray, + scale_inv: jnp.ndarray, + out_dtype: jnp.dtype, + static_axis_boundary: int, + transpose_axis_boundary: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + cast transpose wrapper + Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale` + """ + if not CastTransposePrimitive.enabled(): + return _jax_cast_transpose( + x, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary + ) + return 
CastTransposePrimitive.outer_primitive.bind( + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + + +class DBiasCastTransposePrimitive(BasePrimitive): + """ + DBias Cast Transpose Primitive + """ + + name = "te_dbias_cast_transpose" + multiple_results = True + # out_dtype, static_axis_boundary, transpose_axis_boundary + impl_static_args = (4, 5, 6) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + dz_aval, + amax_aval, + scale_aval, + scale_inv_aval, + *, + out_dtype, + static_axis_boundary, + transpose_axis_boundary + ): + """ + te_dbias_cast_transpose_p abstract + """ + dtype = dtypes.canonicalize_dtype(dz_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + gi_hidden_size = reduce(operator.mul, dz_aval.shape[transpose_axis_boundary:]) + t_shape = multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary) + out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype) + t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) + + dbias_shape = (*dz_aval.shape[: static_axis_boundary + 1], gi_hidden_size) + dbias = dz_aval.update(shape=dbias_shape, dtype=dtype) + + updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) + (wkspace_info,) = transformer_engine_jax.get_dbias_ct_workspace_sizes( + dz_aval.size // gi_hidden_size, + gi_hidden_size, + jax_dtype_to_te_dtype(dz_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + ) + wkspace_aval = dz_aval.update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) + + return out, t_out, dbias, updated_amax_aval, wkspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + te_dbias_cast_transpose_p outer abstract + """ + + out, t_out, dbias, updated_amax_aval, _ = DBiasCastTransposePrimitive.abstract( + *args, **kwargs + ) + return out, t_out, dbias, updated_amax_aval + + @staticmethod + def lowering( + ctx, dz, amax, scale, scale_inv, *, transpose_axis_boundary + ): + """ + te_dbias_cast_transpose_p lowering rules + """ + dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + name = "te_dbias_cast_transpose_ffi" + return ffi.ffi_lowering(name, operand_output_aliases={1: 3})( + ctx, dz, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary + ) + + @staticmethod + def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary): + """ + to describe implementation + """ + assert DBiasCastTransposePrimitive.inner_primitive is not None + out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind( + dz, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + return out, t_out, dbias, updated_amax + + @staticmethod + def batcher( + batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary + ): + """ + to describe batch rules for vmap + """ + del static_axis_boundary + check_valid_batch_dims(batch_dims) + assert DBiasCastTransposePrimitive.outer_primitive is not None + dz, amax, scale, scale_inv = 
batched_args + dz_bdim, amax_bdim, _, _ = batch_dims + + # Minus batch dim. + transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1) + transpose_axis_boundary += 1 # Plus batch dim + + out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim + return ( + DBiasCastTransposePrimitive.outer_primitive.bind( + dz, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=dz_bdim, + transpose_axis_boundary=transpose_axis_boundary, + ), + out_bdims, + ) + + @staticmethod + def infer_sharding_from_operands( + out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos + ): + del out_dtype, result_infos + x_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + dbias_shaprding = NamedSharding( + mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1]) + ) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding) + + @staticmethod + def partition( + out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos + ): + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + + dbias_shaprding = NamedSharding( + mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1]) + ) + + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = ( + casted_x_sharding, + casted_transposed_x_sharding, + dbias_shaprding, + amax_sharding, + ) + + def sharded_impl(dz, amax, scale, scale_inv): + local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl( + dz, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) + return local_out, local_t_out, global_dbias, global_updated_amax + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(DBiasCastTransposePrimitive) + + +def dbias_cast_transpose( + dz: jnp.ndarray, + amax: jnp.ndarray, + scale: jnp.ndarray, + scale_inv: jnp.ndarray, + out_dtype: TEDType, + static_axis_boundary: int, + transpose_axis_boundary: int = -1, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + cast transpose dbias partial fusion wrapper + Return FP8(inputs), dbias + """ + if static_axis_boundary < 0: + static_axis_boundary = -1 # means no static axes + + if not DBiasCastTransposePrimitive.enabled(): + return _jax_dbias_cast_transpose( + dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary + ) + + return DBiasCastTransposePrimitive.outer_primitive.bind( + dz, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + + +class DActLuDBiasCastTransposePrimitive(BasePrimitive): + """ + DActLu DBias Cast Transpose Primitive + """ + + 
name = "te_dact_lu_dbias_cast_transpose" + multiple_results = True + # out_dtype, static_axis_boundary, act_enum + impl_static_args = (5, 6, 7) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + dz_aval, + x_aval, + amax_aval, + scale_aval, + scale_inv_aval, + *, + out_dtype, + static_axis_boundary, + ): # pylint: disable=unused-argument + """ + te_dact_lu_dbais_cast_transpose_p abstract + """ + dtype = dtypes.canonicalize_dtype(dz_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dtype + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + ir_hidden_szie = dz_aval.shape[-1] + gi_hidden_size = x_aval.shape[-1] + assert ir_hidden_szie == gi_hidden_size + t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2) + out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype) + t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) + + dbias_shape = (*x_aval.shape[: static_axis_boundary + 1], gi_hidden_size) + dbias = dz_aval.update(shape=dbias_shape, dtype=dtype) + + updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) + + (wkspace_info,) = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes( + x_aval.size // gi_hidden_size, + gi_hidden_size, + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + ) + wkspace_aval = x_aval.update( + shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) + ) + + return out, t_out, dbias, updated_amax_aval, wkspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + te_dact_lu_dbais_cast_transpose_p outer abstract + """ + + out, t_out, dbias, updated_amax_aval, _ = DActLuDBiasCastTransposePrimitive.abstract( + *args, **kwargs + ) + return out, t_out, dbias, updated_amax_aval + + @staticmethod + def lowering(ctx, dz, x, amax, scale, scale_inv, *, act_enum): + """ + te_dgated_act_lu_cast_transpose_p lowering rules + """ + dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dz_aval.dtype + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + name = "te_dact_lu_dbias_cast_transpose_ffi" + return ffi.ffi_lowering(name, operand_output_aliases={2: 3})( + ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum) + ) + + @staticmethod + def impl( + dz, + x, + amax, + scale, + scale_inv, + out_dtype, + static_axis_boundary, + act_enum, + ): + """ + to describe implementation + """ + assert DActLuDBiasCastTransposePrimitive.inner_primitive is not None + out, t_out, dbias, updated_amax, _ = DActLuDBiasCastTransposePrimitive.inner_primitive.bind( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + act_enum=act_enum, + ) + return out, t_out, dbias, updated_amax + + @staticmethod + def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum): + """ + to describe batch rules for vmap + """ + del static_axis_boundary + check_valid_batch_dims(batch_dims) + assert DActLuDBiasCastTransposePrimitive.outer_primitive is not None + dz, x, amax, scale, scale_inv = batched_args + x_bdim, _, amax_bdim, _, _ = batch_dims + + out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim + return ( + DActLuDBiasCastTransposePrimitive.outer_primitive.bind( + dz, + x, + amax, + scale, + scale_inv, + 
out_dtype=out_dtype, + static_axis_boundary=x_bdim, + act_enum=act_enum, + ), + out_bdims, + ) + + @staticmethod + def infer_sharding_from_operands( + out_dtype, + static_axis_boundary, + act_enum, + mesh, + arg_infos, + result_infos, + ): + del out_dtype, result_infos, act_enum + x_spec = get_padded_spec(arg_infos[1]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2) + tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + dbias_shaprding = NamedSharding( + mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1]) + ) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding) + + @staticmethod + def partition( + out_dtype, + static_axis_boundary, + act_enum, + mesh, + arg_infos, + result_infos, + ): + del result_infos + x_spec = get_padded_spec(arg_infos[1]) + casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2) + casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + + dbias_shaprding = NamedSharding( + mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1]) + ) + + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = ( + casted_x_sharding, + casted_transposed_x_sharding, + dbias_shaprding, + amax_sharding, + ) + + def sharded_impl(dz, x, amax, scale, scale_inv): + local_out, local_t_out, local_dbias, local_amax = ( + DActLuDBiasCastTransposePrimitive.impl( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + act_enum=act_enum, + ) + ) + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) + return local_out, local_t_out, global_dbias, global_updated_amax + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(DActLuDBiasCastTransposePrimitive) + + +def dact_lu_dbias_cast_transpose( + dz: jnp.ndarray, + x: jnp.ndarray, + amax: jnp.ndarray, + scale: jnp.ndarray, + scale_inv: jnp.ndarray, + out_dtype: TEDType, + static_axis_boundary: int, + activation_type: Sequence[Union[str, Callable]] = ("gelu",), +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + cast transpose dact_lu and dbias fusion wrapper + Return FP8(dact_lu(inputs)), dbias + ONLY support non-gated activation type + """ + if static_axis_boundary < 0: + static_axis_boundary = -1 # means no static axes + + if not DActLuDBiasCastTransposePrimitive.enabled(): + _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x) + (dx,) = vjp_func(dz) + transpose_axis_boundary = -2 + return _jax_dbias_cast_transpose( + dx, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary + ) + + act_type_id = ActivationEnum[activation_type] + return DActLuDBiasCastTransposePrimitive.outer_primitive.bind( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + act_enum=act_type_id, + ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 1950d6cbab..e8d26710d8 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -119,6 +119,18 @@ 
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); // CuBLAS helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); +// GEMM + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out, + Result_Type out_amax_updated, Result_Type out_scale_updated, + Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type workspace, + bool lhs_trans, bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 74909319cc..50f7023c1c 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -210,5 +210,113 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("scaling_mode"), FFI_CudaGraph_Traits); +void GemmImpl(cudaStream_t stream, void *lhs, const std::vector &lhs_shape, + float *lhs_scale_inv, bool lhs_trans, void *rhs, const std::vector &rhs_shape, + float *rhs_scale_inv, bool rhs_trans, DType operand_dtype, void *bias, + DType bias_dtype, void *out, float *out_amax, float *out_scale, DType out_dtype, + void *pre_gelu_out, void *workspace, size_t workspace_size, bool fuse_gelu, + bool fuse_bias, bool grad, bool accumulate, bool use_split_accumulator) { + auto lhs_ = TensorWrapper(lhs, lhs_shape, operand_dtype, nullptr, nullptr, lhs_scale_inv); + auto rhs_ = TensorWrapper(rhs, rhs_shape, operand_dtype, nullptr, nullptr, rhs_scale_inv); + + std::vector out_shape(2, 0); + out_shape[0] = (lhs_trans) ? lhs_shape[1] : lhs_shape[0]; + out_shape[1] = (rhs_trans) ? rhs_shape[0] : rhs_shape[1]; + auto out_ = TensorWrapper(out, out_shape, out_dtype, out_amax, out_scale, nullptr); + + void *bias_ptr = (fuse_bias) ? bias : nullptr; + std::vector bias_shape = + (fuse_bias) ? std::vector{out_shape[1]} : std::vector{0}; + auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + + void *pre_gelu_ptr = (fuse_gelu) ? pre_gelu_out : nullptr; + std::vector pre_gelu_shape = (fuse_gelu) ? out_shape : std::vector{0}; + auto pre_gelu_out_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, bias_dtype); + auto workspace_ = TensorWrapper(workspace, std::vector{workspace_size}, DType::kByte); + + // cuBLAS is column-major, so we swap LHS and RHS in the arguments + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_out_.data(), + (rhs_trans) ? CUBLAS_OP_T : CUBLAS_OP_N, (lhs_trans) ? 
CUBLAS_OP_T : CUBLAS_OP_N, + grad, workspace_.data(), accumulate, use_split_accumulator, num_math_sm, stream); +} + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out, + Result_Type out_amax_updated, Result_Type out_scale_updated, + Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type workspace, + bool lhs_trans, bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator) { + // Inputs + auto lhs_ptr = lhs.untyped_data(); + auto lhs_scale_inv_ptr = reinterpret_cast(lhs_scale_inv.untyped_data()); + auto rhs_ptr = rhs.untyped_data(); + auto rhs_scale_inv_ptr = reinterpret_cast(rhs_scale_inv.untyped_data()); + auto operand_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + auto bias_ptr = bias.untyped_data(); + auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); + auto gelu_input_ptr = gelu_input.untyped_data(); + auto out_amax_ptr = reinterpret_cast(out_amax.untyped_data()); + auto out_scale_ptr = reinterpret_cast(out_scale.untyped_data()); + + // Outputs + auto out_ptr = out->untyped_data(); + auto out_amax_updated_ptr = reinterpret_cast(out_amax_updated->untyped_data()); + auto out_scale_updated_ptr = reinterpret_cast(out_scale_updated->untyped_data()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(out->element_type()); + auto pre_gelu_out_ptr = pre_gelu_out->untyped_data(); + auto bias_grad_ptr = bias_grad->untyped_data(); + auto workspace_ptr = workspace->untyped_data(); + auto workspace_size = workspace->dimensions().back(); + + // Operand aliasing + NVTE_CHECK(bias_ptr == bias_grad_ptr, "bias not bound to bias_grad in TE/JAX GEMM"); + NVTE_CHECK(gelu_input_ptr == pre_gelu_out_ptr, + "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); + NVTE_CHECK(out_amax_ptr == out_amax_updated_ptr, + "out_amax not bound to out_amax_updated in TE/JAX GEMM"); + NVTE_CHECK(out_scale_ptr == out_scale_updated_ptr, + "out_scale not bound to out_scale_updated in TE/JAX GEMM"); + + // GEMM sizing + std::vector lhs_shape(lhs.dimensions().begin(), lhs.dimensions().end()); + std::vector rhs_shape(rhs.dimensions().begin(), rhs.dimensions().end()); + + // Swap A and B argument locations to match what the TE/common kernel expects + GemmImpl(stream, lhs_ptr, lhs_shape, lhs_scale_inv_ptr, lhs_trans, rhs_ptr, rhs_shape, + rhs_scale_inv_ptr, rhs_trans, operand_dtype, bias_ptr, bias_dtype, out_ptr, out_amax_ptr, + out_scale_ptr, out_dtype, pre_gelu_out_ptr, workspace_ptr, workspace_size, fuse_gelu, + fuse_bias, grad, accumulate, use_split_accumulator); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Arg() // out_amax + .Arg() // out_scale + .Ret() // out + .Ret() // out_amax_updated + .Ret() // out_scale_updated + .Ret() // pre_gelu_out + .Ret() // bias_grad + .Ret() // workspace + .Attr("lhs_trans") + .Attr("rhs_trans") + .Attr("fuse_gelu") + .Attr("fuse_bias") + .Attr("grad") + .Attr("accumulate") + .Attr("use_split_accumulator"), + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 
c777a02c99..35c4ae6c03 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -59,6 +59,7 @@ pybind11::dict Registrations() { pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); + dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler); return dict; } diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py new file mode 100644 index 0000000000..53f5fc3d96 --- /dev/null +++ b/transformer_engine/jax/fp8.py @@ -0,0 +1,262 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" +Helper module for fp8 meta management +""" +from enum import Enum +from functools import partial +from typing import Dict, List, Tuple, Union + +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict +from flax.linen import fp8_ops + +from transformer_engine_jax import DType +from transformer_engine.common.recipe import Format + +Collection = Union[Dict, FrozenDict] + +def _format2dtypes(format_: Format): + if format_ == Format.E4M3: + return jnp.float8_e4m3fn, jnp.float8_e4m3fn + if format_ == Format.E5M2: + return jnp.float8_e5m2, jnp.float8_e5m2 + if format_ == Format.HYBRID: + return jnp.float8_e4m3fn, jnp.float8_e5m2 + return jnp.bfloat16, jnp.bfloat16 + + +# fm32 is a custom dtype to specify the "add" rules as max operation. +# This is typically used in Pipeline Parallelism + "MiconBatching > 1", +# which is implemented via nn.scan. Without this custom dtype, nn.scan +# would sum gradients from all micro-batches, and this is not the expected +# behavior for FP8 meta. Instead, the summation of FP8 meta gradients should +# be "MAX". +FlaxFloatMeta32 = fp8_ops.fm32 + + +class FP8MetaPackage: + """ + A container that contains all required meta data for FP8 + """ + + NUM_OF_META: int = 4 + INPUT_IDX: int = 0 + WEIGHT_IDX: int = 1 + GRAD_IDX: int = 2 + OUTPUT_IDX: int = 3 + + def __init__( + self, + input_amax: jnp.ndarray, + input_scale: jnp.ndarray, + weight_amax: jnp.ndarray, + weight_scale: jnp.ndarray, + grad_amax: jnp.ndarray, + grad_scale: jnp.ndarray, + output_amax: jnp.ndarray, + output_scale: jnp.ndarray, + ) -> None: + + self._amax_list = [None] * FP8MetaPackage.NUM_OF_META + self._scale_list = [None] * FP8MetaPackage.NUM_OF_META + + self._amax_list[FP8MetaPackage.INPUT_IDX] = input_amax + self._scale_list[FP8MetaPackage.INPUT_IDX] = input_scale + self._amax_list[FP8MetaPackage.WEIGHT_IDX] = weight_amax + self._scale_list[FP8MetaPackage.WEIGHT_IDX] = weight_scale + self._amax_list[FP8MetaPackage.GRAD_IDX] = grad_amax + self._scale_list[FP8MetaPackage.GRAD_IDX] = grad_scale + self._amax_list[FP8MetaPackage.OUTPUT_IDX] = output_amax + self._scale_list[FP8MetaPackage.OUTPUT_IDX] = output_scale + + @property + def amax_list(self) -> List[jnp.ndarray]: + """ + Get the amax list of this package. + """ + return self._amax_list + + @property + def scale_list(self) -> List[jnp.ndarray]: + """ + Get the scale list of this package. 
+ """ + return self._scale_list + + @staticmethod + def update_amax_list(amax_list: List[jnp.ndarray]) -> jnp.ndarray: + """ + Update the amax history list + """ + updated_amax_list = [FP8Helper.update_amax_history(amax) for amax in amax_list] + return updated_amax_list + + @staticmethod + def update_fp8_scale( + amax_list: List[jnp.ndarray], scale_list: List[jnp.ndarray], fp8_dtype_list: List[DType] + ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]: + """ + Get update scale and scale_inv list + """ + update_scale_list = [] + update_scale_inv_list = [] + for amax, scale, fp8_dtype in zip(amax_list, scale_list, fp8_dtype_list): + upadted_scale, updated_scale_inv = FP8Helper.update_fp8_scale(amax, scale, fp8_dtype) + update_scale_list.append(upadted_scale) + update_scale_inv_list.append(updated_scale_inv) + return update_scale_list, update_scale_inv_list + + +class AmaxComputeAlgo(Enum): + """AmaxComputeAlgo.""" + + MAX = "max" + MOST_RECENT = "most_recent" + + +NVTE_FP8_COLLECTION_NAME = "fp8_metas" + + +class FP8Helper: + """ + FP8 helper to manage the FP8 meta + """ + + INITIALIZED = False + MARGIN: float = 0.0 + FP8_FORMAT: Format = Format.HYBRID + FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0] + BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1] + AMAX_HISTORY_LEN: int = 1024 + AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX + FP8_COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME + FP8_AMAX_NAME: str = "amax" + FP8_SCALE_NAME: str = "scale" + FP8_2X_ACC_FPROP: bool = False + FP8_2X_ACC_DGRAD: bool = True + FP8_2X_ACC_WGRAD: bool = True + + @staticmethod + def is_fp8_enabled(): + """ + Indicate if fp8 training is enable or not. + """ + return FP8Helper.INITIALIZED + + @staticmethod + def initialize( + margin: float = 0.0, + fp8_format: Format = Format.HYBRID, + amax_history_len: int = 1, + amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX, + ) -> None: + """ + Initialize the FP8 meta + """ + FP8Helper.INITIALIZED = True + FP8Helper.MARGIN = margin + FP8Helper.FP8_FORMAT = fp8_format + FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = _format2dtypes(FP8Helper.FP8_FORMAT) + FP8Helper.AMAX_HISTORY_LEN = amax_history_len + FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo + FP8Helper.FP8_2X_ACC_FPROP = False + FP8Helper.FP8_2X_ACC_DGRAD = True + FP8Helper.FP8_2X_ACC_WGRAD = True + + @staticmethod + def finalize() -> None: + """ + FP8 helper finalize + """ + FP8Helper.INITIALIZED = False + FP8Helper.MARGIN = 0.0 + FP8Helper.FP8_FORMAT = Format.HYBRID + FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = _format2dtypes(FP8Helper.FP8_FORMAT) + FP8Helper.AMAX_HISTORY_LEN = 1024 + FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX + + @staticmethod + def update_collections(new: Collection, original: Collection) -> Collection: + """ + Update the collections + """ + assert isinstance(original, (dict, FrozenDict)) + assert isinstance(new, (dict, FrozenDict)) + frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original + for key in new: + if key in frozen_original: + frozen_original, _ = frozen_original.pop(key) + new_coll = FrozenDict({**new, **frozen_original}) + if not isinstance(original, FrozenDict): + new_coll = new_coll.unfreeze() + return new_coll + + @staticmethod + def generate_fp8_meta_dtype_converter_pair(*args): + """ + Generate a pair of conversion fun in-between fm32 and fp32. 
+ """ + + def identical_fun(*metas): + return list(metas) + + def fm32_to_fp32_fun(*metas): + for meta in metas: + assert meta.dtype == FlaxFloatMeta32 + return [jax.lax.convert_element_type(meta, jnp.float32) for meta in metas] + + def fp32_to_fm32_fun(*metas): + for meta in metas: + assert meta.dtype == jnp.float32 + return [jax.lax.convert_element_type(meta, FlaxFloatMeta32) for meta in metas] + + # Make functions to be a vaild JAX type + partial_identical_fun = jax.tree_util.Partial(identical_fun) + partial_fm32_to_fp32_fun = jax.tree_util.Partial(fm32_to_fp32_fun) + partial_fp32_to_fm32_fun = jax.tree_util.Partial(fp32_to_fm32_fun) + + if len(args) < 1: + return partial_identical_fun, partial_identical_fun + + original_dtype = args[0].dtype + for arg in args: + assert arg.dtype == original_dtype + + if original_dtype == FlaxFloatMeta32: + return partial_fm32_to_fp32_fun, partial_fp32_to_fm32_fun + + return partial_identical_fun, partial_identical_fun + + @staticmethod + @jax.jit + def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray: + """ + Update the amax history + """ + updated_amax = jnp.roll(amax, -1, -1) + updated_amax = updated_amax.at[0].set(0) + return updated_amax + + @staticmethod + @partial(jax.jit, static_argnums=(2,)) + def update_fp8_scale(amax: jnp.ndarray, scale: jnp.ndarray, fp8_dtype: DType) -> jnp.ndarray: + """ + Calculate fp8 scale and scale_inv based on given amax. + """ + fp8_max = jnp.astype(jnp.finfo(fp8_dtype).max, jnp.float32) + + if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: + amax = jnp.max(amax, axis=-1, keepdims=True) + else: + amax = amax[0:1] + + sf = (fp8_max / amax) / (2**FP8Helper.MARGIN) + sf = jnp.where(amax > 0.0, sf, scale) + sf = jnp.where(jnp.isfinite(amax), sf, scale) + scale = sf + scale_inv = 1 / sf + + return scale, scale_inv diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py new file mode 100644 index 0000000000..66c5ef158e --- /dev/null +++ b/transformer_engine/jax/gemm.py @@ -0,0 +1,484 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+from functools import partial +from typing import Optional, Tuple, Union + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike + +from .fp8 import FP8Helper, FP8MetaPackage +from .cpp_extensions import ( + collective_gemm_impl, + collective_fp8_gemm_impl, + cast_transpose, + dact_lu, + dbias_cast_transpose, + dact_lu_dbias_cast_transpose, +) +from .cpp_extensions.gemm import sanitize_dims, mirror_dim + + +__all__ = [ + "gemm", + "fp8_gemm", + "type_safe_gemm", +] + + +def gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Optional[ArrayLike] = None, + contracting_dims: Tuple[int, int] = (-1, -2), + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> ArrayLike: + """Non-FP8 collective/distributed `nvte_cublas_gemm()` with GELU and bias-add fusions.""" + return _gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) +def _gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Union[ArrayLike, None], + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> ArrayLike: + out, _ = _gemm_fwd_rule( + x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator + ) + return out + + +def _gemm_fwd_rule( + x: ArrayLike, + kernel: ArrayLike, + bias: ArrayLike, + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> Tuple[ArrayLike, ...]: + assert ( + kernel.ndim == 2 + ), "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." + + fuse_bias = bias is not None + + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) + # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) + # + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) + out, pre_gelu_out, _ = collective_gemm_impl( + x, + kernel, + bias=bias, + batched_output=(x.ndim > 2), + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + ctx = ( + x, + kernel, + pre_gelu_out if fuse_gelu else None, + fuse_bias, + ) + + return out, ctx + + +def _gemm_bwd_rule( + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ctx, + grad, +): + x, kernel, pre_gelu_out, fuse_bias = ctx + x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) + x_outer_dim, kernel_outer_dim = map( + mirror_dim, (x_inner_dim, kernel_inner_dim), (x.ndim, kernel.ndim) + ) + + # FWD MODE: + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) + # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) + # + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) + + # DGRAD: + # AG+GEMM: ([B], M, N/P) x (K, N/P)^T ----(AR)----> ([B], M, K) + # (DP, None, TP) x (None, TP)^T --(AR)--> (DP, None, None) + # + # GEMM+AR: ([B], M, N) x (K/P, N)^T ------> ([B], M, K/P) + # (DP, None, None) x (TP, None)^T --> (DP, None, TP) + dgrad, dgelu, _ = collective_gemm_impl( + grad, + kernel, + gelu_input=pre_gelu_out, + batched_output=(x.ndim > 2), + contracting_dims=(-1, kernel_outer_dim), + fuse_gelu=fuse_gelu, + fuse_bias=False, + grad=True, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + 
) + + # WGRAD: + # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) + # (DP, 'tp', None)^T --(AG)-->(DP, None, None)^T x (DP, None, 'tp') --> (None, 'tp') + # + # GEMM+AR: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N) ---------> (K/P, N) + # (DP, None, 'tp')^T --(AG)--> (DP, None, None)^T x (DP, None, None) ----> (None, None) + # Make XLA scatter output in first dim. + wgrad_rhs = dgelu if fuse_gelu else grad + wgrad, _, bgrad = collective_gemm_impl( + x, + wgrad_rhs, + gelu_input=pre_gelu_out, + batched_output=False, + contracting_dims=(x_outer_dim, wgrad_rhs.ndim - 2), + fuse_gelu=False, + fuse_bias=fuse_bias, + grad=True, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if not fuse_bias: + bgrad = None + + return dgrad, wgrad, bgrad + + +_gemm.defvjp(_gemm_fwd_rule, _gemm_bwd_rule) + + +def fp8_gemm( + x: ArrayLike, + kernel_t: ArrayLike, + fp8_meta: FP8MetaPackage, + bias: Optional[ArrayLike] = None, + out_dtype: jnp.dtype = jnp.bfloat16, + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> ArrayLike: + """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions.""" + return _fp8_gemm( + x, + kernel_t, + bias, + fp8_meta.amax_list, + fp8_meta.scale_list, + out_dtype, + fuse_gelu, + accumulate, + use_split_accumulator, + ) + + +@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) +def _fp8_gemm( + x: ArrayLike, + kernel_t: ArrayLike, + bias: ArrayLike, + amax_list: ArrayLike, + scale_list: ArrayLike, + out_dtype: jnp.dtype, + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> ArrayLike: + out, _ = _fp8_gemm_fwd_rule( + x, + kernel_t, + bias, + amax_list, + scale_list, + out_dtype, + fuse_gelu, + accumulate, + use_split_accumulator, + ) + return out + + +def _fp8_gemm_fwd_rule( + x: ArrayLike, + kernel_t: ArrayLike, + bias: ArrayLike, + amax_list: ArrayLike, + scale_list: ArrayLike, + out_dtype: jnp.dtype, + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> Tuple[ArrayLike, ...]: + assert ( + kernel_t.ndim == 2 + ), "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." 
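+    # Delayed-scaling bookkeeping performed below: the FP8 meta tensors are converted from
+    # the fm32 custom dtype to fp32, each scale is recomputed from its amax history as
+    # (fp8_max / amax) / 2**margin (falling back to the previous scale when amax is zero or
+    # non-finite), and the amax history is rolled with its leading entry reset to zero
+    # before the operands are cast (and transposed) to the forward FP8 dtype.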
+ + fuse_bias = bias is not None + + maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair( + *amax_list, + *scale_list, + ) + amax_list = maybe_fm32_to_fp32(*amax_list) + scale_list = maybe_fm32_to_fp32(*scale_list) + + fwd_dtype = FP8Helper.FWD_DTYPE + bwd_dtype = FP8Helper.BWD_DTYPE + fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype, fwd_dtype] + scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale( + amax_list, scale_list, fp8_dtype_list + ) + amax_list = FP8MetaPackage.update_amax_list(amax_list) + + x_amax = amax_list[FP8MetaPackage.INPUT_IDX][0:1] + x_scale = scale_list[FP8MetaPackage.INPUT_IDX] + x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] + if x.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + casted_x, casted_x_t, updated_x_amax = cast_transpose( + x, + x_amax, + x_scale, + x_scale_inv, + fwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + casted_x = x + casted_x_t = jnp.matrix_transpose(x) + updated_x_amax = x_amax + + kernel_amax = amax_list[FP8MetaPackage.WEIGHT_IDX][0:1] + kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX] + kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] + if kernel_t.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + casted_kernel_t, casted_kernel, updated_kernel_amax = cast_transpose( + kernel_t, + kernel_amax, + kernel_scale, + kernel_scale_inv, + fwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + casted_kernel = jnp.matrix_transpose(kernel_t) + casted_kernel_t = kernel_t + updated_kernel_amax = kernel_amax + + out_amax = ( + amax_list[FP8MetaPackage.OUTPUT_IDX][0:1] + if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + else None + ) + out_scale = ( + scale_list[FP8MetaPackage.OUTPUT_IDX][0:1] + if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + else None + ) + out, updated_out_amax, updated_out_scale, pre_gelu_out = collective_fp8_gemm_impl( + casted_x, + x_scale_inv, + casted_kernel_t, + kernel_scale_inv, + bias=bias, + out_amax=out_amax, + out_scale=out_scale, + out_dtype=out_dtype, + batched_output=(x.ndim > 2), + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + if out_dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + updated_out_amax = None + updated_out_scale = None + + ctx = ( + casted_x_t, + casted_kernel, + amax_list, + scale_list, + scale_inv_list, + updated_x_amax, + updated_kernel_amax, + updated_out_amax, + pre_gelu_out if fuse_gelu else None, + fuse_bias, + maybe_fp32_to_fm32, + (x.ndim > 2), + ) + + return (out, updated_out_scale), ctx + + +def _fp8_gemm_bwd_rule( + out_dtype, + fuse_gelu, + accumulate, + use_split_accumulator, + ctx, + grad, +): + ( + casted_x_t, + casted_kernel, + amax_list, + scale_list, + scale_inv_list, + updated_x_amax, + updated_kernel_amax, + updated_out_amax, + pre_gelu_out, + fuse_bias, + maybe_fp32_to_fm32, + batched_input, + ) = ctx + + bwd_dtype = FP8Helper.BWD_DTYPE + + grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1] + grad_scale = scale_list[FP8MetaPackage.GRAD_IDX] + grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_ID] + if fuse_gelu: + if fuse_bias: + # Fuse dbias into this dGELU. + casted_grad, casted_grad_t, bgrad, updated_grad_amax = dact_lu_dbias_cast_transpose( + grad, + pre_gelu_out, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + activation_type=("gelu",), + ) + else: + # No bias to fuse so we just do dGELU. 
+ casted_grad, casted_grad_t, updated_grad_amax = dact_lu(grad, pre_gelu_out, ("gelu",)) + bgrad = None + else: + if fuse_bias: + # Since there is no GELU fusion, we need to fuse dbias into this cast_transpose. + casted_grad, casted_grad_t, bgrad, updated_grad_amax = dbias_cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + # If both bias and GELU is fused into the forward pass, we will fuse dbias later with + # dGELU. No need to do it here. + casted_grad, casted_grad_t, updated_grad_amax = cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + bgrad = None + + kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] + dgrad, *_ = collective_fp8_gemm_impl( + casted_grad, + grad_scale_inv, + casted_kernel, + kernel_scale_inv, + batched_output=batched_input, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] + wgrad, *_ = collective_fp8_gemm_impl( + casted_x_t, + x_scale_inv, + casted_grad_t, + grad_scale_inv, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + amax_list[FP8MetaPackage.INPUT_IDX] = ( + amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0]) + ) + amax_list[FP8MetaPackage.WEIGHT_IDX] = ( + amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0]) + ) + amax_list[FP8MetaPackage.GRAD_IDX] = ( + amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0]) + ) + if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + amax_list[FP8MetaPackage.OUTPUT_IDX] = ( + amax_list[FP8MetaPackage.OUTPUT_IDX].at[0].set(updated_out_amax[0]) + ) + + amax_list = maybe_fp32_to_fm32(*amax_list) + scale_list = maybe_fp32_to_fm32(*scale_list) + + return dgrad, wgrad, bgrad, amax_list, scale_list + + +_fp8_gemm.defvjp(_fp8_gemm_fwd_rule, _fp8_gemm_bwd_rule) + + +def type_safe_gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Optional[ArrayLike] = None, + fp8_meta: Optional[FP8MetaPackage] = None, + out_dtype: Optional[jnp.dtype] = None, + contracting_dims: Tuple[int, int] = (-1, -2), + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> ArrayLike: + if x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] or kernel.dtype in [ + jnp.float8_e4m3fn, + jnp.float8_e5m2, + ]: + assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None." + + if fp8_meta is not None: + x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) + assert x_inner_dim == x.ndim - 1 and kernel_inner_dim == kernel.ndim - 1, ( + "FP8 GEMM requires non-transposed X (LHS) and transposed kernel (RHS), " + + "i.e. contracting_dims=(-1, -1)." + ) + return fp8_gemm( + x, + kernel, + bias, + fp8_meta, + out_dtype, + fuse_gelu, + accumulate, + use_split_accumulator, + ) + return gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator) From fb7a993b050f75c71dc24f841d3ab7d6718a826d Mon Sep 17 00:00:00 2001 From: Philipp Hack Date: Tue, 15 Apr 2025 19:21:28 +0000 Subject: [PATCH 2/3] JAX collective GEMM without compute/communication overlap. 
Signed-off-by: Philipp Hack --- transformer_engine/jax/cpp_extensions/gemm.py | 113 ++-- transformer_engine/jax/dense.py | 210 ++++++++ transformer_engine/jax/fp8.py | 262 ---------- transformer_engine/jax/gemm.py | 484 ------------------ 4 files changed, 245 insertions(+), 824 deletions(-) delete mode 100644 transformer_engine/jax/fp8.py delete mode 100644 transformer_engine/jax/gemm.py diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index ab35ad1f54..9e9c6e71ce 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -44,8 +44,7 @@ __all__ = ["gemm", "grouped_gemm", - "collective_fp8_gemm_impl", - "collective_gemm_impl"] + "collective_gemm"] num_cublas_streams = 4 @@ -545,6 +544,7 @@ def grouped_gemm( out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape)) return out_tensors + class CollectiveGemmPrimitive(BasePrimitive): """ cuBlasLt GEMM Primitive w/ support for distributed inputs @@ -795,8 +795,6 @@ def impl( rhs_scale_inv, bias, gelu_input, - out_amax, - out_scale, out_dtype, batched_output, contracting_dims, @@ -875,8 +873,6 @@ def impl( rhs_scale_inv, bias, gelu_input, - out_amax, - out_scale, out_dtype=out_dtype, batched_output=False, contracting_dims=contracting_dims_2d, @@ -1143,76 +1139,33 @@ def sharded_impl( register_primitive(CollectiveGemmPrimitive) -def collective_fp8_gemm_impl( - lhs: ArrayLike, - lhs_scale_inv: ArrayLike, - rhs_t: ArrayLike, - rhs_scale_inv: ArrayLike, - bias: Optional[ArrayLike] = None, +def collective_gemm( + lhs: Union[jnp.ndarray, ScaledTensor], + rhs: Union[jnp.ndarray, ScaledTensor], + bias: jnp.ndarray = None, gelu_input: Optional[ArrayLike] = None, - out_amax: Optional[ArrayLike] = None, - out_scale: Optional[ArrayLike] = None, - out_dtype: jnp.dtype = jnp.bfloat16, batched_output: bool = False, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), fuse_gelu: bool = False, - fuse_bias: bool = False, + grad: bool = False, accumulate: bool = False, use_split_accumulator: bool = False, ) -> Tuple[ArrayLike, ...]: - """FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" - if out_dtype is not None and jax_dtype_is_fp8(out_dtype): - assert out_amax is not None and out_scale is not None, "Missing output amax and scale." - else: - out_amax = jnp.zeros(0, dtype=jnp.float32) - out_scale = jnp.zeros(0, dtype=jnp.float32) + """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" - if not fuse_bias: - bias = jnp.zeros(0, dtype=jnp.bfloat16) + if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): + scaling_mode = lhs.scaling_mode + out_dtype = lhs.dq_dtype + # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal layout + if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + assert not ( + lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 + ), "FP8 GEMM does not support E5M2 * E5M2" else: - assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." 
- - if not fuse_gelu: - gelu_input = jnp.zeros(0, dtype=bias.dtype) - elif gelu_input is None: - gelu_shape = (reduce(operator.mul, lhs.shape[:-1]), rhs_t.shape[-1]) - gelu_input = jnp.zeros(gelu_shape, dtype=bias.dtype) + # For jnp.ndarray, only consider contracting_dims, layout is always NN + scaling_mode = ScalingMode.NVTE_NO_SCALING + out_dtype = lhs.dtype - out, out_amax, out_scale, pre_gelu_out, _ = CollectiveGemmPrimitive.outer_primitive.bind( - lhs, - lhs_scale_inv, - rhs_t, - rhs_scale_inv, - bias, - gelu_input, - out_amax, - out_scale, - out_dtype=out_dtype, - batched_output=batched_output, - contracting_dims=(-1, -1), - fuse_gelu=fuse_gelu, - fuse_bias=fuse_bias, - grad=False, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - ) - - return out, out_amax, out_scale, pre_gelu_out - - -def collective_gemm_impl( - lhs: ArrayLike, - rhs: ArrayLike, - bias: Optional[ArrayLike] = None, - gelu_input: Optional[ArrayLike] = None, - batched_output: bool = False, - contracting_dims: Tuple[int, int] = (-1, -2), - fuse_gelu: bool = False, - fuse_bias: bool = False, - grad: bool = False, - accumulate: bool = False, - use_split_accumulator: bool = False, -) -> Tuple[ArrayLike, ...]: - """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) lhs_outer_dim, rhs_outer_dim = map( mirror_dim, @@ -1220,12 +1173,10 @@ def collective_gemm_impl( (lhs.ndim, rhs.ndim), ) - if not fuse_bias: + if bias is None: bias = jnp.zeros(0, dtype=lhs.dtype) elif grad: bias = jnp.zeros(rhs.shape[rhs_outer_dim], dtype=lhs.dtype) - else: - assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." if not fuse_gelu: gelu_input = jnp.zeros(0, dtype=lhs.dtype) @@ -1239,21 +1190,27 @@ def collective_gemm_impl( gelu_shape = (batch_size * lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) gelu_input = jnp.zeros(gelu_shape, dtype=lhs.dtypes) - dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) + if scaling_mode == ScalingMode.NVTE_NO_SCALING: + lhs_scale_inv = jnp.ones(1, dtype=jnp.float32) + rhs_scale_inv = jnp.ones(1, dtype=jnp.float32) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + lhs_scale_inv = lhs.scale_inv.reshape(-1) + rhs_scale_inv = rhs.scale_inv.reshape(-1) + if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + lhs_scale_inv = lhs_scale_inv.reshape(-1) + rhs_scale_inv = rhs_scale_inv.reshape(-1) + out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind( lhs, - dummy_fp8_meta, + lhs_scale_inv, rhs, - dummy_fp8_meta, + rhs_scale_inv, bias, gelu_input, - dummy_fp8_meta, - dummy_fp8_meta, - out_dtype=lhs.dtype, + scaling_mode=scaling_mode, + out_dtype=out_dtype, batched_output=batched_output, contracting_dims=contracting_dims, - fuse_gelu=fuse_gelu, - fuse_bias=fuse_bias, grad=grad, accumulate=accumulate, use_split_accumulator=use_split_accumulator, diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 43336768cb..e0fdc567dc 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -168,6 +168,216 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu _dense.defvjp(_dense_fwd_rule, _dense_bwd_rule) +def collective_dense( + x, + kernel, + bias = None, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), + quantizer_set: QuantizerSet = None, + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = 
False, +): + """Non-FP8 collective/distributed `nvte_cublas_gemm()` with GELU and bias-add fusions.""" + return _collective_dense(x, + kernel, + bias, + contracting_dims, + quantizer_set, + fuse_gelu, + accumulate, + use_split_accumulator + ) + + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7)) +def _collective_dense( + x, + kernel, + bias, + contracting_dims, + quantizer_set, + fuse_gelu, + accumulate, + use_split_accumulator, +): + out, _ = _collective_dense_fwd_rule( + x, + kernel, + bias, + contracting_dims, + quantizer_set, + fuse_gelu, + accumulate, + use_split_accumulator, + ) + return out + + +def _collective_dense_fwd_rule( + x, + kernel, + bias, + contracting_dims, + quantizer_set, + fuse_gelu, + accumulate, + use_split_accumulator, +): + assert ( + kernel.ndim == 2 + ), "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." + + fuse_bias = bias is not None + + if quantizer_set is None: + x_rowwise = x + x_colwise = x + kernel_colwise = kernel + kernel_rowwise = kernel + x_shape = x.shape + kernel_shape = kernel.shape + else: + q_x = tex.quantize(x, quantizer_set.x) + q_kernel = tex.quantize(kernel, quantizer_set.kernel) + x_rowwise = q_x.get_rowwise_tensor() + x_colwise = q_x.get_colwise_tensor() + kernel_colwise = q_kernel.get_colwise_tensor() + kernel_rowwise = q_kernel.get_rowwise_tensor() + x_shape = x_rowwise.data.shape + kernel_shape = kernel_rowwise.data.shape + + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) + # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) + # + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) + out, pre_gelu_out, _ = tex.collective_gemm( + x_rowwise, + kernel_colwise, + bias=bias, + batched_output=(x.ndim > 2), + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + ctx = ( + x_colwise, + kernel_rowwise, + x_shape, + kernel_shape, + pre_gelu_out if fuse_gelu else None, + fuse_bias, + quantizer_set, + ) + + return out, ctx + + +def _collective_dense_bwd_rule( + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ctx, + grad, +): + ( + x_colwise, + kernel_rowwise, + x_shape, + kernel_shape, + pre_gelu_out, + fuse_bias, + quantizer_set, + ) = ctx + + if quantizer_set is None: + casted_grad = grad + bgrad = tex.quantization._jax_dbias(grad) + grad_rowwise = grad + grad_colwise = grad + else: + casted_grad, bgrad = tex.quantize_dbias( + grad, is_dbias=fuse_bias, quantizer=quantizer_set.dgrad + ) + grad_rowwise = casted_grad.get_rowwise_tensor() + grad_colwise = casted_grad.get_colwise_tensor() + + fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims + g_contracting_dim = tuple( + range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) + ) + k_contracting_dim = tuple( + dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims + ) + dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) + + # GEMM TN + g_contracting_dim = x_contracting_dim = tuple( + range(0, len(x_shape) - len(fwd_x_contracting_dims)) + ) + wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) + + # FWD MODE: + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) + # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) + # + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], 
M, N) + # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) + + # DGRAD: + # AG+GEMM: ([B], M, N/P) x (K, N/P)^T ----(AR)----> ([B], M, K) + # (DP, None, TP) x (None, TP)^T --(AR)--> (DP, None, None) + # + # GEMM+AR: ([B], M, N) x (K/P, N)^T ------> ([B], M, K/P) + # (DP, None, None) x (TP, None)^T --> (DP, None, TP) + dgrad, dgelu, _ = tex.collective_gemm( + grad_rowwise, + kernel_rowwise, + gelu_input=pre_gelu_out, + batched_output=(x_colwise.ndim > 2), + contracting_dims=dgrad_contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=False, + grad=True, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # WGRAD: + # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) + # (DP, 'tp', None)^T --(AG)-->(DP, None, None)^T x (DP, None, 'tp') --> (None, 'tp') + # + # GEMM+AR: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N) ---------> (K/P, N) + # (DP, None, 'tp')^T --(AG)--> (DP, None, None)^T x (DP, None, None) ----> (None, None) + # Make XLA scatter output in first dim. + wgrad_rhs = dgelu if fuse_gelu else grad_colwise + wgrad, _, bgrad = tex.collective_gemm( + x_colwise, + wgrad_rhs, + gelu_input=pre_gelu_out, + batched_output=False, + contracting_dims=wgrad_contracting_dims, + fuse_gelu=False, + fuse_bias=fuse_bias, + grad=True, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if not fuse_bias: + bgrad = None + + return dgrad, wgrad, bgrad + + +_collective_dense.defvjp(_collective_dense_fwd_rule, _collective_dense_bwd_rule) + + def grouped_dense( x_list, kernel_list, diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py deleted file mode 100644 index 53f5fc3d96..0000000000 --- a/transformer_engine/jax/fp8.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -""" -Helper module for fp8 meta management -""" -from enum import Enum -from functools import partial -from typing import Dict, List, Tuple, Union - -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict -from flax.linen import fp8_ops - -from transformer_engine_jax import DType -from transformer_engine.common.recipe import Format - -Collection = Union[Dict, FrozenDict] - -def _format2dtypes(format_: Format): - if format_ == Format.E4M3: - return jnp.float8_e4m3fn, jnp.float8_e4m3fn - if format_ == Format.E5M2: - return jnp.float8_e5m2, jnp.float8_e5m2 - if format_ == Format.HYBRID: - return jnp.float8_e4m3fn, jnp.float8_e5m2 - return jnp.bfloat16, jnp.bfloat16 - - -# fm32 is a custom dtype to specify the "add" rules as max operation. -# This is typically used in Pipeline Parallelism + "MiconBatching > 1", -# which is implemented via nn.scan. Without this custom dtype, nn.scan -# would sum gradients from all micro-batches, and this is not the expected -# behavior for FP8 meta. Instead, the summation of FP8 meta gradients should -# be "MAX". 
-FlaxFloatMeta32 = fp8_ops.fm32 - - -class FP8MetaPackage: - """ - A container that contains all required meta data for FP8 - """ - - NUM_OF_META: int = 4 - INPUT_IDX: int = 0 - WEIGHT_IDX: int = 1 - GRAD_IDX: int = 2 - OUTPUT_IDX: int = 3 - - def __init__( - self, - input_amax: jnp.ndarray, - input_scale: jnp.ndarray, - weight_amax: jnp.ndarray, - weight_scale: jnp.ndarray, - grad_amax: jnp.ndarray, - grad_scale: jnp.ndarray, - output_amax: jnp.ndarray, - output_scale: jnp.ndarray, - ) -> None: - - self._amax_list = [None] * FP8MetaPackage.NUM_OF_META - self._scale_list = [None] * FP8MetaPackage.NUM_OF_META - - self._amax_list[FP8MetaPackage.INPUT_IDX] = input_amax - self._scale_list[FP8MetaPackage.INPUT_IDX] = input_scale - self._amax_list[FP8MetaPackage.WEIGHT_IDX] = weight_amax - self._scale_list[FP8MetaPackage.WEIGHT_IDX] = weight_scale - self._amax_list[FP8MetaPackage.GRAD_IDX] = grad_amax - self._scale_list[FP8MetaPackage.GRAD_IDX] = grad_scale - self._amax_list[FP8MetaPackage.OUTPUT_IDX] = output_amax - self._scale_list[FP8MetaPackage.OUTPUT_IDX] = output_scale - - @property - def amax_list(self) -> List[jnp.ndarray]: - """ - Get the amax list of this package. - """ - return self._amax_list - - @property - def scale_list(self) -> List[jnp.ndarray]: - """ - Get the scale list of this package. - """ - return self._scale_list - - @staticmethod - def update_amax_list(amax_list: List[jnp.ndarray]) -> jnp.ndarray: - """ - Update the amax history list - """ - updated_amax_list = [FP8Helper.update_amax_history(amax) for amax in amax_list] - return updated_amax_list - - @staticmethod - def update_fp8_scale( - amax_list: List[jnp.ndarray], scale_list: List[jnp.ndarray], fp8_dtype_list: List[DType] - ) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]: - """ - Get update scale and scale_inv list - """ - update_scale_list = [] - update_scale_inv_list = [] - for amax, scale, fp8_dtype in zip(amax_list, scale_list, fp8_dtype_list): - upadted_scale, updated_scale_inv = FP8Helper.update_fp8_scale(amax, scale, fp8_dtype) - update_scale_list.append(upadted_scale) - update_scale_inv_list.append(updated_scale_inv) - return update_scale_list, update_scale_inv_list - - -class AmaxComputeAlgo(Enum): - """AmaxComputeAlgo.""" - - MAX = "max" - MOST_RECENT = "most_recent" - - -NVTE_FP8_COLLECTION_NAME = "fp8_metas" - - -class FP8Helper: - """ - FP8 helper to manage the FP8 meta - """ - - INITIALIZED = False - MARGIN: float = 0.0 - FP8_FORMAT: Format = Format.HYBRID - FWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[0] - BWD_DTYPE: DType = _format2dtypes(Format.HYBRID)[1] - AMAX_HISTORY_LEN: int = 1024 - AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX - FP8_COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME - FP8_AMAX_NAME: str = "amax" - FP8_SCALE_NAME: str = "scale" - FP8_2X_ACC_FPROP: bool = False - FP8_2X_ACC_DGRAD: bool = True - FP8_2X_ACC_WGRAD: bool = True - - @staticmethod - def is_fp8_enabled(): - """ - Indicate if fp8 training is enable or not. 
- """ - return FP8Helper.INITIALIZED - - @staticmethod - def initialize( - margin: float = 0.0, - fp8_format: Format = Format.HYBRID, - amax_history_len: int = 1, - amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX, - ) -> None: - """ - Initialize the FP8 meta - """ - FP8Helper.INITIALIZED = True - FP8Helper.MARGIN = margin - FP8Helper.FP8_FORMAT = fp8_format - FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = _format2dtypes(FP8Helper.FP8_FORMAT) - FP8Helper.AMAX_HISTORY_LEN = amax_history_len - FP8Helper.AMAX_COMPUTE_ALGO = amax_compute_algo - FP8Helper.FP8_2X_ACC_FPROP = False - FP8Helper.FP8_2X_ACC_DGRAD = True - FP8Helper.FP8_2X_ACC_WGRAD = True - - @staticmethod - def finalize() -> None: - """ - FP8 helper finalize - """ - FP8Helper.INITIALIZED = False - FP8Helper.MARGIN = 0.0 - FP8Helper.FP8_FORMAT = Format.HYBRID - FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = _format2dtypes(FP8Helper.FP8_FORMAT) - FP8Helper.AMAX_HISTORY_LEN = 1024 - FP8Helper.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX - - @staticmethod - def update_collections(new: Collection, original: Collection) -> Collection: - """ - Update the collections - """ - assert isinstance(original, (dict, FrozenDict)) - assert isinstance(new, (dict, FrozenDict)) - frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original - for key in new: - if key in frozen_original: - frozen_original, _ = frozen_original.pop(key) - new_coll = FrozenDict({**new, **frozen_original}) - if not isinstance(original, FrozenDict): - new_coll = new_coll.unfreeze() - return new_coll - - @staticmethod - def generate_fp8_meta_dtype_converter_pair(*args): - """ - Generate a pair of conversion fun in-between fm32 and fp32. - """ - - def identical_fun(*metas): - return list(metas) - - def fm32_to_fp32_fun(*metas): - for meta in metas: - assert meta.dtype == FlaxFloatMeta32 - return [jax.lax.convert_element_type(meta, jnp.float32) for meta in metas] - - def fp32_to_fm32_fun(*metas): - for meta in metas: - assert meta.dtype == jnp.float32 - return [jax.lax.convert_element_type(meta, FlaxFloatMeta32) for meta in metas] - - # Make functions to be a vaild JAX type - partial_identical_fun = jax.tree_util.Partial(identical_fun) - partial_fm32_to_fp32_fun = jax.tree_util.Partial(fm32_to_fp32_fun) - partial_fp32_to_fm32_fun = jax.tree_util.Partial(fp32_to_fm32_fun) - - if len(args) < 1: - return partial_identical_fun, partial_identical_fun - - original_dtype = args[0].dtype - for arg in args: - assert arg.dtype == original_dtype - - if original_dtype == FlaxFloatMeta32: - return partial_fm32_to_fp32_fun, partial_fp32_to_fm32_fun - - return partial_identical_fun, partial_identical_fun - - @staticmethod - @jax.jit - def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray: - """ - Update the amax history - """ - updated_amax = jnp.roll(amax, -1, -1) - updated_amax = updated_amax.at[0].set(0) - return updated_amax - - @staticmethod - @partial(jax.jit, static_argnums=(2,)) - def update_fp8_scale(amax: jnp.ndarray, scale: jnp.ndarray, fp8_dtype: DType) -> jnp.ndarray: - """ - Calculate fp8 scale and scale_inv based on given amax. 
- """ - fp8_max = jnp.astype(jnp.finfo(fp8_dtype).max, jnp.float32) - - if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: - amax = jnp.max(amax, axis=-1, keepdims=True) - else: - amax = amax[0:1] - - sf = (fp8_max / amax) / (2**FP8Helper.MARGIN) - sf = jnp.where(amax > 0.0, sf, scale) - sf = jnp.where(jnp.isfinite(amax), sf, scale) - scale = sf - scale_inv = 1 / sf - - return scale, scale_inv diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py deleted file mode 100644 index 66c5ef158e..0000000000 --- a/transformer_engine/jax/gemm.py +++ /dev/null @@ -1,484 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -from functools import partial -from typing import Optional, Tuple, Union - -import jax -import jax.numpy as jnp -from jax.typing import ArrayLike - -from .fp8 import FP8Helper, FP8MetaPackage -from .cpp_extensions import ( - collective_gemm_impl, - collective_fp8_gemm_impl, - cast_transpose, - dact_lu, - dbias_cast_transpose, - dact_lu_dbias_cast_transpose, -) -from .cpp_extensions.gemm import sanitize_dims, mirror_dim - - -__all__ = [ - "gemm", - "fp8_gemm", - "type_safe_gemm", -] - - -def gemm( - x: ArrayLike, - kernel: ArrayLike, - bias: Optional[ArrayLike] = None, - contracting_dims: Tuple[int, int] = (-1, -2), - fuse_gelu: bool = False, - accumulate: bool = False, - use_split_accumulator: bool = False, -) -> ArrayLike: - """Non-FP8 collective/distributed `nvte_cublas_gemm()` with GELU and bias-add fusions.""" - return _gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator) - - -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) -def _gemm( - x: ArrayLike, - kernel: ArrayLike, - bias: Union[ArrayLike, None], - contracting_dims: Tuple[int, int], - fuse_gelu: bool, - accumulate: bool, - use_split_accumulator: bool, -) -> ArrayLike: - out, _ = _gemm_fwd_rule( - x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator - ) - return out - - -def _gemm_fwd_rule( - x: ArrayLike, - kernel: ArrayLike, - bias: ArrayLike, - contracting_dims: Tuple[int, int], - fuse_gelu: bool, - accumulate: bool, - use_split_accumulator: bool, -) -> Tuple[ArrayLike, ...]: - assert ( - kernel.ndim == 2 - ), "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." 
- - fuse_bias = bias is not None - - # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) - # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) - # - # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) - # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) - out, pre_gelu_out, _ = collective_gemm_impl( - x, - kernel, - bias=bias, - batched_output=(x.ndim > 2), - contracting_dims=contracting_dims, - fuse_gelu=fuse_gelu, - fuse_bias=fuse_bias, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - ) - - ctx = ( - x, - kernel, - pre_gelu_out if fuse_gelu else None, - fuse_bias, - ) - - return out, ctx - - -def _gemm_bwd_rule( - contracting_dims, - fuse_gelu, - accumulate, - use_split_accumulator, - ctx, - grad, -): - x, kernel, pre_gelu_out, fuse_bias = ctx - x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) - x_outer_dim, kernel_outer_dim = map( - mirror_dim, (x_inner_dim, kernel_inner_dim), (x.ndim, kernel.ndim) - ) - - # FWD MODE: - # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) - # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) - # - # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) - # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) - - # DGRAD: - # AG+GEMM: ([B], M, N/P) x (K, N/P)^T ----(AR)----> ([B], M, K) - # (DP, None, TP) x (None, TP)^T --(AR)--> (DP, None, None) - # - # GEMM+AR: ([B], M, N) x (K/P, N)^T ------> ([B], M, K/P) - # (DP, None, None) x (TP, None)^T --> (DP, None, TP) - dgrad, dgelu, _ = collective_gemm_impl( - grad, - kernel, - gelu_input=pre_gelu_out, - batched_output=(x.ndim > 2), - contracting_dims=(-1, kernel_outer_dim), - fuse_gelu=fuse_gelu, - fuse_bias=False, - grad=True, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - ) - - # WGRAD: - # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) - # (DP, 'tp', None)^T --(AG)-->(DP, None, None)^T x (DP, None, 'tp') --> (None, 'tp') - # - # GEMM+AR: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N) ---------> (K/P, N) - # (DP, None, 'tp')^T --(AG)--> (DP, None, None)^T x (DP, None, None) ----> (None, None) - # Make XLA scatter output in first dim. 
- wgrad_rhs = dgelu if fuse_gelu else grad - wgrad, _, bgrad = collective_gemm_impl( - x, - wgrad_rhs, - gelu_input=pre_gelu_out, - batched_output=False, - contracting_dims=(x_outer_dim, wgrad_rhs.ndim - 2), - fuse_gelu=False, - fuse_bias=fuse_bias, - grad=True, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - ) - - if not fuse_bias: - bgrad = None - - return dgrad, wgrad, bgrad - - -_gemm.defvjp(_gemm_fwd_rule, _gemm_bwd_rule) - - -def fp8_gemm( - x: ArrayLike, - kernel_t: ArrayLike, - fp8_meta: FP8MetaPackage, - bias: Optional[ArrayLike] = None, - out_dtype: jnp.dtype = jnp.bfloat16, - fuse_gelu: bool = False, - accumulate: bool = False, - use_split_accumulator: bool = False, -) -> ArrayLike: - """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions.""" - return _fp8_gemm( - x, - kernel_t, - bias, - fp8_meta.amax_list, - fp8_meta.scale_list, - out_dtype, - fuse_gelu, - accumulate, - use_split_accumulator, - ) - - -@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) -def _fp8_gemm( - x: ArrayLike, - kernel_t: ArrayLike, - bias: ArrayLike, - amax_list: ArrayLike, - scale_list: ArrayLike, - out_dtype: jnp.dtype, - fuse_gelu: bool, - accumulate: bool, - use_split_accumulator: bool, -) -> ArrayLike: - out, _ = _fp8_gemm_fwd_rule( - x, - kernel_t, - bias, - amax_list, - scale_list, - out_dtype, - fuse_gelu, - accumulate, - use_split_accumulator, - ) - return out - - -def _fp8_gemm_fwd_rule( - x: ArrayLike, - kernel_t: ArrayLike, - bias: ArrayLike, - amax_list: ArrayLike, - scale_list: ArrayLike, - out_dtype: jnp.dtype, - fuse_gelu: bool, - accumulate: bool, - use_split_accumulator: bool, -) -> Tuple[ArrayLike, ...]: - assert ( - kernel_t.ndim == 2 - ), "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." 
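Both `_gemm` and `_fp8_gemm` follow the usual `jax.custom_vjp` wiring: the forward rule returns the primal output plus residuals, the backward rule receives the `nondiff_argnums` values first, then the residuals and the incoming cotangent, and returns one gradient per differentiable argument. A minimal self-contained sketch of that pattern, using a made-up `scaled_matmul` rather than anything from this patch:

from functools import partial

import jax

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def scaled_matmul(x, w, alpha):
    return alpha * (x @ w)

def scaled_matmul_fwd(x, w, alpha):
    # Save the residuals needed by the backward rule.
    return alpha * (x @ w), (x, w)

def scaled_matmul_bwd(alpha, res, g):
    # Non-differentiable args come first, then residuals, then the cotangent;
    # return one gradient per differentiable positional argument, in order.
    x, w = res
    return alpha * (g @ w.T), alpha * (x.T @ g)

scaled_matmul.defvjp(scaled_matmul_fwd, scaled_matmul_bwd)

Calling jax.grad(lambda x, w: scaled_matmul(x, w, 2.0).sum(), argnums=(0, 1)) then exercises the backward rule, just as training exercises `_gemm_bwd_rule` and `_fp8_gemm_bwd_rule` here.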
- - fuse_bias = bias is not None - - maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair( - *amax_list, - *scale_list, - ) - amax_list = maybe_fm32_to_fp32(*amax_list) - scale_list = maybe_fm32_to_fp32(*scale_list) - - fwd_dtype = FP8Helper.FWD_DTYPE - bwd_dtype = FP8Helper.BWD_DTYPE - fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype, fwd_dtype] - scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale( - amax_list, scale_list, fp8_dtype_list - ) - amax_list = FP8MetaPackage.update_amax_list(amax_list) - - x_amax = amax_list[FP8MetaPackage.INPUT_IDX][0:1] - x_scale = scale_list[FP8MetaPackage.INPUT_IDX] - x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] - if x.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: - casted_x, casted_x_t, updated_x_amax = cast_transpose( - x, - x_amax, - x_scale, - x_scale_inv, - fwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - ) - else: - casted_x = x - casted_x_t = jnp.matrix_transpose(x) - updated_x_amax = x_amax - - kernel_amax = amax_list[FP8MetaPackage.WEIGHT_IDX][0:1] - kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX] - kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] - if kernel_t.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: - casted_kernel_t, casted_kernel, updated_kernel_amax = cast_transpose( - kernel_t, - kernel_amax, - kernel_scale, - kernel_scale_inv, - fwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - ) - else: - casted_kernel = jnp.matrix_transpose(kernel_t) - casted_kernel_t = kernel_t - updated_kernel_amax = kernel_amax - - out_amax = ( - amax_list[FP8MetaPackage.OUTPUT_IDX][0:1] - if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] - else None - ) - out_scale = ( - scale_list[FP8MetaPackage.OUTPUT_IDX][0:1] - if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] - else None - ) - out, updated_out_amax, updated_out_scale, pre_gelu_out = collective_fp8_gemm_impl( - casted_x, - x_scale_inv, - casted_kernel_t, - kernel_scale_inv, - bias=bias, - out_amax=out_amax, - out_scale=out_scale, - out_dtype=out_dtype, - batched_output=(x.ndim > 2), - fuse_gelu=fuse_gelu, - fuse_bias=fuse_bias, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - ) - if out_dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: - updated_out_amax = None - updated_out_scale = None - - ctx = ( - casted_x_t, - casted_kernel, - amax_list, - scale_list, - scale_inv_list, - updated_x_amax, - updated_kernel_amax, - updated_out_amax, - pre_gelu_out if fuse_gelu else None, - fuse_bias, - maybe_fp32_to_fm32, - (x.ndim > 2), - ) - - return (out, updated_out_scale), ctx - - -def _fp8_gemm_bwd_rule( - out_dtype, - fuse_gelu, - accumulate, - use_split_accumulator, - ctx, - grad, -): - ( - casted_x_t, - casted_kernel, - amax_list, - scale_list, - scale_inv_list, - updated_x_amax, - updated_kernel_amax, - updated_out_amax, - pre_gelu_out, - fuse_bias, - maybe_fp32_to_fm32, - batched_input, - ) = ctx - - bwd_dtype = FP8Helper.BWD_DTYPE - - grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1] - grad_scale = scale_list[FP8MetaPackage.GRAD_IDX] - grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_ID] - if fuse_gelu: - if fuse_bias: - # Fuse dbias into this dGELU. - casted_grad, casted_grad_t, bgrad, updated_grad_amax = dact_lu_dbias_cast_transpose( - grad, - pre_gelu_out, - grad_amax, - grad_scale, - grad_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - activation_type=("gelu",), - ) - else: - # No bias to fuse so we just do dGELU. 
- casted_grad, casted_grad_t, updated_grad_amax = dact_lu(grad, pre_gelu_out, ("gelu",)) - bgrad = None - else: - if fuse_bias: - # Since there is no GELU fusion, we need to fuse dbias into this cast_transpose. - casted_grad, casted_grad_t, bgrad, updated_grad_amax = dbias_cast_transpose( - grad, - grad_amax, - grad_scale, - grad_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - ) - else: - # If both bias and GELU is fused into the forward pass, we will fuse dbias later with - # dGELU. No need to do it here. - casted_grad, casted_grad_t, updated_grad_amax = cast_transpose( - grad, - grad_amax, - grad_scale, - grad_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - ) - bgrad = None - - kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] - dgrad, *_ = collective_fp8_gemm_impl( - casted_grad, - grad_scale_inv, - casted_kernel, - kernel_scale_inv, - batched_output=batched_input, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - ) - - x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] - wgrad, *_ = collective_fp8_gemm_impl( - casted_x_t, - x_scale_inv, - casted_grad_t, - grad_scale_inv, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - ) - - amax_list[FP8MetaPackage.INPUT_IDX] = ( - amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0]) - ) - amax_list[FP8MetaPackage.WEIGHT_IDX] = ( - amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0]) - ) - amax_list[FP8MetaPackage.GRAD_IDX] = ( - amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0]) - ) - if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]: - amax_list[FP8MetaPackage.OUTPUT_IDX] = ( - amax_list[FP8MetaPackage.OUTPUT_IDX].at[0].set(updated_out_amax[0]) - ) - - amax_list = maybe_fp32_to_fm32(*amax_list) - scale_list = maybe_fp32_to_fm32(*scale_list) - - return dgrad, wgrad, bgrad, amax_list, scale_list - - -_fp8_gemm.defvjp(_fp8_gemm_fwd_rule, _fp8_gemm_bwd_rule) - - -def type_safe_gemm( - x: ArrayLike, - kernel: ArrayLike, - bias: Optional[ArrayLike] = None, - fp8_meta: Optional[FP8MetaPackage] = None, - out_dtype: Optional[jnp.dtype] = None, - contracting_dims: Tuple[int, int] = (-1, -2), - fuse_gelu: bool = False, - accumulate: bool = False, - use_split_accumulator: bool = False, -) -> ArrayLike: - if x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] or kernel.dtype in [ - jnp.float8_e4m3fn, - jnp.float8_e5m2, - ]: - assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None." - - if fp8_meta is not None: - x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) - assert x_inner_dim == x.ndim - 1 and kernel_inner_dim == kernel.ndim - 1, ( - "FP8 GEMM requires non-transposed X (LHS) and transposed kernel (RHS), " - + "i.e. contracting_dims=(-1, -1)." - ) - return fp8_gemm( - x, - kernel, - bias, - fp8_meta, - out_dtype, - fuse_gelu, - accumulate, - use_split_accumulator, - ) - return gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator) From 2c39c6194f92976cd01c11e44d145ca7dac5d5b2 Mon Sep 17 00:00:00 2001 From: Philipp Hack Date: Tue, 15 Apr 2025 19:27:30 +0000 Subject: [PATCH 3/3] JAX collective GEMM without compute/communication overlap. 
Signed-off-by: Philipp Hack --- .../jax/cpp_extensions/__init__.py | 1 - .../jax/cpp_extensions/transpose.py | 855 ------------------ 2 files changed, 856 deletions(-) delete mode 100644 transformer_engine/jax/cpp_extensions/transpose.py diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index 1afc172c9a..a9fe2e7d7b 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -8,4 +8,3 @@ from .normalization import * from .quantization import * from .softmax import * -from .transpose import * diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py deleted file mode 100644 index b7e6ac9ada..0000000000 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ /dev/null @@ -1,855 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""JAX/TE custom ops for transpose""" -from functools import partial, reduce -from typing import Tuple, Sequence, Union, Callable -import operator -from packaging import version - -import jax -import jax.numpy as jnp -from jax import dtypes -from jax.sharding import PartitionSpec, NamedSharding - -import transformer_engine_jax -from transformer_engine_jax import DType as TEDType - -from .base import BasePrimitive, register_primitive -from .misc import ( - check_valid_batch_dims, - jax_dtype_to_te_dtype, - te_dtype_to_jax_dtype, - get_padded_spec, - multidim_transpose, - normalize_axis_boundary, -) -from .activation import ActivationEnum -from .activation import _jax_act_lu -from .quantization import _jax_cast_fp8 -from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp - -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - - -__all__ = [ - "transpose", - "cast_transpose", - "dbias_cast_transpose", - "dact_lu_dbias_cast_transpose", -] - - -def _jax_transpose(inputs, static_axis_boundary, transpose_axis_boundary): - """ - JAX native transpose implementation - """ - axes = multidim_transpose(range(inputs.ndim), static_axis_boundary, transpose_axis_boundary) - return jnp.transpose(inputs, axes=axes) - - -def _jax_cast_transpose( - inputs, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary -): - """ - JAX native cast_transpose implementation - """ - casted_output, updated_amax = _jax_cast_fp8(inputs, scale, amax, out_dtype=out_dtype) - casted_transposed_output = _jax_transpose( - casted_output, static_axis_boundary, transpose_axis_boundary - ) - return casted_output, casted_transposed_output, updated_amax - - -def _jax_dbias_cast_transpose( - dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary -): - """ - JAX native dbias_cast_transpose implementation - """ - casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose( - dz, - scale, - amax, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, - ) - dbias = jnp.sum( - dz, - axis=tuple( - range( - transpose_axis_boundary - if transpose_axis_boundary > 0 - else transpose_axis_boundary + dz.ndim - ) - ), - keepdims=False, - ) - dbias = dbias.ravel() # C++ function returns an 1D array for dbias - return casted_dz, cast_transposed_dz, dbias, updated_amax - - -class 
TransposePrimitive(BasePrimitive): - """ - Transpose Primitive - """ - - name = "te_transpose" - multiple_results = False - impl_static_args = (1, 2) - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary): - """ - _transpose abstract - """ - transposed_x_shape = multidim_transpose( - x_aval.shape, static_axis_boundary, transpose_axis_boundary - ) - xt_aval = x_aval.update(shape=transposed_x_shape, dtype=x_aval.dtype) - - return xt_aval - - @staticmethod - def lowering(ctx, x, *, transpose_axis_boundary): - """ - _transpose cuda lowering - """ - - x_aval = ctx.avals_in[0] - assert x_aval.dtype in [ - jnp.float32, - jnp.float16, - jnp.bfloat16, - jnp.float8_e4m3fn, - jnp.float8_e5m2, - ] - - name = "te_transpose_ffi" - return ffi.ffi_lowering(name)(ctx, x, transpose_axis=transpose_axis_boundary) - - @staticmethod - def impl(x, static_axis_boundary, transpose_axis_boundary): - """ - tcast_transpose implementation - """ - assert TransposePrimitive.inner_primitive is not None - transposed_x = TransposePrimitive.inner_primitive.bind( - x, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, - ) - return transposed_x - - @staticmethod - def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary): - check_valid_batch_dims(batch_dims) - assert TransposePrimitive.outer_primitive is not None - assert static_axis_boundary < 0 - - (x,) = batched_args - (x_bdim,) = batch_dims - - # Minus batch dim. - transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) - transpose_axis_boundary += 1 # Plus batch dim - - out_bdims = x_bdim - return ( - TransposePrimitive.outer_primitive.bind( - x, static_axis_boundary=x_bdim, transpose_axis_boundary=transpose_axis_boundary - ), - out_bdims, - ) - - @staticmethod - def infer_sharding_from_operands( - static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos - ): - del result_infos - x_spec = get_padded_spec(arg_infos[0]) - xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) - transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) - return transposed_x_sharding - - @staticmethod - def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos): - del result_infos - x_spec = get_padded_spec(arg_infos[0]) - xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) - transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = transposed_x_sharding - - impl = partial( - TransposePrimitive.impl, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, - ) - - return mesh, impl, out_shardings, arg_shardings - - -register_primitive(TransposePrimitive) - - -def transpose( - x: jnp.ndarray, static_axis_boundary: int, transpose_axis_boundary: int -) -> jnp.ndarray: - """ - transpose wrapper - """ - if not TransposePrimitive.enabled(): - return _jax_transpose(x, static_axis_boundary, transpose_axis_boundary) - return TransposePrimitive.outer_primitive.bind( - x, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, - ) - - -class CastTransposePrimitive(BasePrimitive): - """ - Cast Transpose Primitive - """ - - name = "te_cast_transpose" - multiple_results = True - impl_static_args = (4, 5, 6) - inner_primitive = 
None - outer_primitive = None - - @staticmethod - def abstract( - x_aval, - amax_aval, - scale_aval, - scale_inv_aval, - *, - out_dtype, - static_axis_boundary, - transpose_axis_boundary - ): - """ - te_cast_transpose_p abstract - """ - dtype = dtypes.canonicalize_dtype(x_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - - transposed_x_shape = multidim_transpose( - x_aval.shape, static_axis_boundary, transpose_axis_boundary - ) - - casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) - casted_xt_aval = x_aval.update(shape=transposed_x_shape, dtype=out_dtype) - updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - - return casted_x_aval, casted_xt_aval, updated_amax_aval - - @staticmethod - def lowering( - ctx, x, amax, scale, scale_inv, *, transpose_axis_boundary - ): - """ - te_cast_transpose_p lowering rules - """ - x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - name = "te_cast_transpose_ffi" - return ffi.ffi_lowering(name, operand_output_aliases={1: 2})( - ctx, x, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary - ) - - @staticmethod - def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary): - """ - te_cast_transpose implementation - """ - assert CastTransposePrimitive.inner_primitive is not None - casted_x, casted_transposed_x, updated_amax = CastTransposePrimitive.inner_primitive.bind( - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, - ) - return casted_x, casted_transposed_x, updated_amax - - @staticmethod - def batcher( - batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary - ): - check_valid_batch_dims(batch_dims) - assert CastTransposePrimitive.outer_primitive is not None - assert static_axis_boundary < 0 - - x, amax, scale, scale_inv = batched_args - x_bdim, amax_bdim, *_ = batch_dims - - # Minus batch dim. 
- transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) - transpose_axis_boundary += 1 # Plus batch dim - - out_bdims = x_bdim, x_bdim, amax_bdim - return ( - CastTransposePrimitive.outer_primitive.bind( - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=x_bdim, - transpose_axis_boundary=transpose_axis_boundary, - ), - out_bdims, - ) - - @staticmethod - def infer_sharding_from_operands( - out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos - ): - del out_dtype, result_infos - x_spec = get_padded_spec(arg_infos[0]) - casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) - casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) - return (casted_x_sharding, casted_transposed_x_sharding, amax_sharding) - - @staticmethod - def partition( - out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos - ): - del result_infos - x_spec = get_padded_spec(arg_infos[0]) - casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) - casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding) - - def sharded_impl(x, amax, scale, scale_inv): - local_cx, local_cxt, local_updated_amax = CastTransposePrimitive.impl( - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, - ) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh) - - return local_cx, local_cxt, global_updated_amax - - return mesh, sharded_impl, out_shardings, arg_shardings - - -register_primitive(CastTransposePrimitive) - - -def cast_transpose( - x: jnp.ndarray, - amax: jnp.ndarray, - scale: jnp.ndarray, - scale_inv: jnp.ndarray, - out_dtype: jnp.dtype, - static_axis_boundary: int, - transpose_axis_boundary: int, -) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - """ - cast transpose wrapper - Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale` - """ - if not CastTransposePrimitive.enabled(): - return _jax_cast_transpose( - x, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary - ) - return CastTransposePrimitive.outer_primitive.bind( - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, - ) - - -class DBiasCastTransposePrimitive(BasePrimitive): - """ - DBias Cast Transpose Primitive - """ - - name = "te_dbias_cast_transpose" - multiple_results = True - # out_dtype, static_axis_boundary, transpose_axis_boundary - impl_static_args = (4, 5, 6) - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract( - dz_aval, - amax_aval, - scale_aval, - scale_inv_aval, - *, - out_dtype, - static_axis_boundary, - transpose_axis_boundary - ): - """ - te_dbias_cast_transpose_p abstract - """ - dtype = dtypes.canonicalize_dtype(dz_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - 
assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - gi_hidden_size = reduce(operator.mul, dz_aval.shape[transpose_axis_boundary:]) - t_shape = multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary) - out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype) - t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) - - dbias_shape = (*dz_aval.shape[: static_axis_boundary + 1], gi_hidden_size) - dbias = dz_aval.update(shape=dbias_shape, dtype=dtype) - - updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - (wkspace_info,) = transformer_engine_jax.get_dbias_ct_workspace_sizes( - dz_aval.size // gi_hidden_size, - gi_hidden_size, - jax_dtype_to_te_dtype(dz_aval.dtype), - jax_dtype_to_te_dtype(out_dtype), - ) - wkspace_aval = dz_aval.update( - shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) - ) - - return out, t_out, dbias, updated_amax_aval, wkspace_aval - - @staticmethod - def outer_abstract(*args, **kwargs): - """ - te_dbias_cast_transpose_p outer abstract - """ - - out, t_out, dbias, updated_amax_aval, _ = DBiasCastTransposePrimitive.abstract( - *args, **kwargs - ) - return out, t_out, dbias, updated_amax_aval - - @staticmethod - def lowering( - ctx, dz, amax, scale, scale_inv, *, transpose_axis_boundary - ): - """ - te_dbias_cast_transpose_p lowering rules - """ - dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - name = "te_dbias_cast_transpose_ffi" - return ffi.ffi_lowering(name, operand_output_aliases={1: 3})( - ctx, dz, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary - ) - - @staticmethod - def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary): - """ - to describe implementation - """ - assert DBiasCastTransposePrimitive.inner_primitive is not None - out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind( - dz, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, - ) - return out, t_out, dbias, updated_amax - - @staticmethod - def batcher( - batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary - ): - """ - to describe batch rules for vmap - """ - del static_axis_boundary - check_valid_batch_dims(batch_dims) - assert DBiasCastTransposePrimitive.outer_primitive is not None - dz, amax, scale, scale_inv = batched_args - dz_bdim, amax_bdim, _, _ = batch_dims - - # Minus batch dim. 
- transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1) - transpose_axis_boundary += 1 # Plus batch dim - - out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim - return ( - DBiasCastTransposePrimitive.outer_primitive.bind( - dz, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=dz_bdim, - transpose_axis_boundary=transpose_axis_boundary, - ), - out_bdims, - ) - - @staticmethod - def infer_sharding_from_operands( - out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos - ): - del out_dtype, result_infos - x_spec = get_padded_spec(arg_infos[0]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) - tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) - dbias_shaprding = NamedSharding( - mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1]) - ) - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) - return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding) - - @staticmethod - def partition( - out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos - ): - del result_infos - x_spec = get_padded_spec(arg_infos[0]) - casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) - casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) - - dbias_shaprding = NamedSharding( - mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1]) - ) - - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = ( - casted_x_sharding, - casted_transposed_x_sharding, - dbias_shaprding, - amax_sharding, - ) - - def sharded_impl(dz, amax, scale, scale_inv): - local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl( - dz, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, - ) - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) - return local_out, local_t_out, global_dbias, global_updated_amax - - return mesh, sharded_impl, out_shardings, arg_shardings - - -register_primitive(DBiasCastTransposePrimitive) - - -def dbias_cast_transpose( - dz: jnp.ndarray, - amax: jnp.ndarray, - scale: jnp.ndarray, - scale_inv: jnp.ndarray, - out_dtype: TEDType, - static_axis_boundary: int, - transpose_axis_boundary: int = -1, -) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - """ - cast transpose dbias partial fusion wrapper - Return FP8(inputs), dbias - """ - if static_axis_boundary < 0: - static_axis_boundary = -1 # means no static axes - - if not DBiasCastTransposePrimitive.enabled(): - return _jax_dbias_cast_transpose( - dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary - ) - - return DBiasCastTransposePrimitive.outer_primitive.bind( - dz, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, - ) - - -class DActLuDBiasCastTransposePrimitive(BasePrimitive): - """ - DActLu DBias Cast Transpose Primitive - """ - - name = "te_dact_lu_dbias_cast_transpose" - multiple_results = True - # 
out_dtype, static_axis_boundary, act_enum - impl_static_args = (5, 6, 7) - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract( - dz_aval, - x_aval, - amax_aval, - scale_aval, - scale_inv_aval, - *, - out_dtype, - static_axis_boundary, - ): # pylint: disable=unused-argument - """ - te_dact_lu_dbais_cast_transpose_p abstract - """ - dtype = dtypes.canonicalize_dtype(dz_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert x_aval.dtype == dtype - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - ir_hidden_szie = dz_aval.shape[-1] - gi_hidden_size = x_aval.shape[-1] - assert ir_hidden_szie == gi_hidden_size - t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2) - out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype) - t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) - - dbias_shape = (*x_aval.shape[: static_axis_boundary + 1], gi_hidden_size) - dbias = dz_aval.update(shape=dbias_shape, dtype=dtype) - - updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - - (wkspace_info,) = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes( - x_aval.size // gi_hidden_size, - gi_hidden_size, - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(out_dtype), - ) - wkspace_aval = x_aval.update( - shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) - ) - - return out, t_out, dbias, updated_amax_aval, wkspace_aval - - @staticmethod - def outer_abstract(*args, **kwargs): - """ - te_dact_lu_dbais_cast_transpose_p outer abstract - """ - - out, t_out, dbias, updated_amax_aval, _ = DActLuDBiasCastTransposePrimitive.abstract( - *args, **kwargs - ) - return out, t_out, dbias, updated_amax_aval - - @staticmethod - def lowering(ctx, dz, x, amax, scale, scale_inv, *, act_enum): - """ - te_dgated_act_lu_cast_transpose_p lowering rules - """ - dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert x_aval.dtype == dz_aval.dtype - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - name = "te_dact_lu_dbias_cast_transpose_ffi" - return ffi.ffi_lowering(name, operand_output_aliases={2: 3})( - ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum) - ) - - @staticmethod - def impl( - dz, - x, - amax, - scale, - scale_inv, - out_dtype, - static_axis_boundary, - act_enum, - ): - """ - to describe implementation - """ - assert DActLuDBiasCastTransposePrimitive.inner_primitive is not None - out, t_out, dbias, updated_amax, _ = DActLuDBiasCastTransposePrimitive.inner_primitive.bind( - dz, - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - act_enum=act_enum, - ) - return out, t_out, dbias, updated_amax - - @staticmethod - def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum): - """ - to describe batch rules for vmap - """ - del static_axis_boundary - check_valid_batch_dims(batch_dims) - assert DActLuDBiasCastTransposePrimitive.outer_primitive is not None - dz, x, amax, scale, scale_inv = batched_args - x_bdim, _, amax_bdim, _, _ = batch_dims - - out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim - return ( - DActLuDBiasCastTransposePrimitive.outer_primitive.bind( - dz, - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=x_bdim, - act_enum=act_enum, - 
), - out_bdims, - ) - - @staticmethod - def infer_sharding_from_operands( - out_dtype, - static_axis_boundary, - act_enum, - mesh, - arg_infos, - result_infos, - ): - del out_dtype, result_infos, act_enum - x_spec = get_padded_spec(arg_infos[1]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2) - tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) - dbias_shaprding = NamedSharding( - mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1]) - ) - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) - return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding) - - @staticmethod - def partition( - out_dtype, - static_axis_boundary, - act_enum, - mesh, - arg_infos, - result_infos, - ): - del result_infos - x_spec = get_padded_spec(arg_infos[1]) - casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2) - casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) - - dbias_shaprding = NamedSharding( - mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1]) - ) - - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = ( - casted_x_sharding, - casted_transposed_x_sharding, - dbias_shaprding, - amax_sharding, - ) - - def sharded_impl(dz, x, amax, scale, scale_inv): - local_out, local_t_out, local_dbias, local_amax = ( - DActLuDBiasCastTransposePrimitive.impl( - dz, - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - act_enum=act_enum, - ) - ) - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) - return local_out, local_t_out, global_dbias, global_updated_amax - - return mesh, sharded_impl, out_shardings, arg_shardings - - -register_primitive(DActLuDBiasCastTransposePrimitive) - - -def dact_lu_dbias_cast_transpose( - dz: jnp.ndarray, - x: jnp.ndarray, - amax: jnp.ndarray, - scale: jnp.ndarray, - scale_inv: jnp.ndarray, - out_dtype: TEDType, - static_axis_boundary: int, - activation_type: Sequence[Union[str, Callable]] = ("gelu",), -) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - """ - cast transpose dact_lu and dbias fusion wrapper - Return FP8(dact_lu(inputs)), dbias - ONLY support non-gated activation type - """ - if static_axis_boundary < 0: - static_axis_boundary = -1 # means no static axes - - if not DActLuDBiasCastTransposePrimitive.enabled(): - _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x) - (dx,) = vjp_func(dz) - transpose_axis_boundary = -2 - return _jax_dbias_cast_transpose( - dx, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary - ) - - act_type_id = ActivationEnum[activation_type] - return DActLuDBiasCastTransposePrimitive.outer_primitive.bind( - dz, - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - act_enum=act_type_id, - )
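When the fused primitive is disabled, the wrapper above falls back to `jax.vjp` of the activation followed by `_jax_dbias_cast_transpose`. A rough single-device reference for what the fused dGELU + dbias + cast-transpose path computes, assuming a 2-D gradient, a plain GELU activation, no static batch axes, and a hypothetical helper name that is not part of this patch:

import jax
import jax.numpy as jnp

def reference_dgelu_dbias_cast_transpose(dz, x, scale, fp8_dtype=jnp.float8_e4m3fn):
    # Backward through GELU at x, applied to the incoming gradient dz.
    _, gelu_vjp = jax.vjp(jax.nn.gelu, x)
    (dx,) = gelu_vjp(dz)
    # Bias gradient: reduce over all leading (token/batch) axes.
    dbias = jnp.sum(dx, axis=tuple(range(dx.ndim - 1)))
    # New amax for delayed scaling, plus FP8 casts in both layouts.
    amax = jnp.max(jnp.abs(dx)).astype(jnp.float32)
    fp8_max = float(jnp.finfo(fp8_dtype).max)
    casted = jnp.clip(dx * scale, -fp8_max, fp8_max).astype(fp8_dtype)
    return casted, jnp.matrix_transpose(casted), dbias, amax

The real primitive additionally returns a workspace buffer and, under sharding, all-reduces dbias across the DP/FSDP axes and amax across all non-pipeline-parallel axes, as in sharded_impl above.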