diff --git a/tests/jax/test_distributed_gemm.py b/tests/jax/test_distributed_gemm.py new file mode 100644 index 0000000000..60f0d7b6f6 --- /dev/null +++ b/tests/jax/test_distributed_gemm.py @@ -0,0 +1,307 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +import pytest +from functools import partial +from collections.abc import Iterable + +import numpy as np + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils + +import transformer_engine.jax as te +from transformer_engine.jax.gemm import gemm +from transformer_engine.jax.quantize import helper + +from utils import assert_allclose + + +jax.config.update("jax_enable_compilation_cache", False) + + +# AG+GEMM: (4, 32/P, 128) ----(AG)----> (4, 32, 128) x (128, 256/P) ----------> (4, 32, 256/P) +# - DGRAD: (4, 32, 256/P) x (128, 256/P)^T --(AR)--> (4, 32, 128) +# - WGRAD: (4, 32/P, 128)^T --(AG)--> (4, 32, 128)^T x (4, 32, 256/P) --------> (128, 256/P) + +# GEMM+AR: (4, 32, 256/P) x (256/P, 128) --(AR)--> (4, 32, 128) +# - DGRAD: (4, 32, 128) x (256/P, 128)^T ------> (4, 32, 256/P) +# - WGRAD: (4, 32, 256/P)^T --(AG)--> (4, 32, 256)^T x (4, 32, 128) --------> (256, 128) + +BATCH = 4 +BASE_SIZE = 16 +SEQ_LEN = BASE_SIZE * 8 +HIDDEN_SIZE = BASE_SIZE * 6 +FFN_HIDDEN_SIZE = BASE_SIZE * 16 + +COMM_TYPES = ["ALL_GATHER", "ALL_REDUCE"] +MESH_TYPES = ["FSDP_TP", "DP_TP", "TP"] +NUM_DEVICES = 4 + +is_fp8_supported, no_fp8_reason = helper.is_fp8_available() + + +def _get_mesh(parallel_dist): + jax.clear_caches() + + batched = False + fsdp = False + mesh_shape = dict(tp=NUM_DEVICES) + resources = dict(cp_resource="tp", tp_resource="tp") + if parallel_dist in ["DP_TP", "FSDP_TP"]: + batched = True + tp = NUM_DEVICES // 2 + dp = NUM_DEVICES // tp + mesh_shape.update(dict(tp=tp, dp=dp)) + resources.update(dict(dp_resource="dp")) + if parallel_dist == "FSDP_TP": + fsdp = True + dp = 1 + zp = NUM_DEVICES // tp + mesh_shape.update(dict(tp=tp, dp=1, zp=zp)) + resources.update(dict(fsdp_resource="zp")) + mesh_resource = te.MeshResource(**resources) + + devices = mesh_utils.create_device_mesh((NUM_DEVICES,), devices=jax.devices()[:NUM_DEVICES]) + + mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) + + return mesh, mesh_resource, batched, fsdp + + +def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bwd=False): + fp8_gemm = dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + + # Operand and output shapes + lhs_shape = ( + [SEQ_LEN, HIDDEN_SIZE] if fwd_comm_type == "ALL_GATHER" else [SEQ_LEN, FFN_HIDDEN_SIZE] + ) + rhs_shape = ( + [HIDDEN_SIZE, FFN_HIDDEN_SIZE] + if fwd_comm_type == "ALL_GATHER" + else [FFN_HIDDEN_SIZE, HIDDEN_SIZE] + ) + out_shape = [lhs_shape[0], rhs_shape[1]] + + if batched: + lhs_shape = [BATCH] + lhs_shape + out_shape = [BATCH] + out_shape + + # Operand and output partition specs + lhs_spec = ( + [mesh_resource.tp_resource, None] + if fwd_comm_type == "ALL_GATHER" + else [None, mesh_resource.tp_resource] + ) + rhs_spec = ( + [None, mesh_resource.tp_resource] + if fwd_comm_type == "ALL_GATHER" + else [mesh_resource.tp_resource, None] + ) + out_spec = [None, rhs_spec[-1]] + + # Modify RHS operand for FP8 + fsdp_gathered_rhs_spec = rhs_spec.copy() + if fp8_gemm: + rhs_shape = list(reversed(rhs_shape)) + rhs_spec = list(reversed(rhs_spec)) + fsdp_gathered_rhs_spec = list(reversed(fsdp_gathered_rhs_spec)) + + # Add batch dimensions and specs + if batched: + if fsdp: + lhs_spec = [(mesh_resource.dp_resource, mesh_resource.fsdp_resource)] + lhs_spec + rhs_spec = [mesh_resource.fsdp_resource if spec is None else spec for spec in rhs_spec] + out_spec = [(mesh_resource.dp_resource, mesh_resource.fsdp_resource)] + out_spec + else: + lhs_spec = [mesh_resource.dp_resource] + lhs_spec + out_spec = [mesh_resource.dp_resource] + out_spec + + # Allocate global operands on device + key = jax.random.PRNGKey(42) + split_keys = jax.random.split(key, 3 if fwd_bwd else 2) + mu = 0.0 + sigma = 0.023 + shapes = (lhs_shape, rhs_shape) + if fwd_bwd: + shapes += (out_shape,) + global_operands = list( + map( + lambda key, shape: jax.device_put( + mu + (sigma * jax.random.normal(key, shape, dtype=dtype)), + NamedSharding(mesh, PartitionSpec(None)), + ), + split_keys, + shapes, + ) + ) + + # Allocate sharded operands on device + partition_axes = (lhs_spec, rhs_spec) + if fwd_bwd: + partition_axes += (out_spec,) + local_operands = list( + map( + lambda x, spec: jax.device_put(x, NamedSharding(mesh, PartitionSpec(*spec))), + global_operands, + partition_axes, + ) + ) + + # Tranpose global RHS back to non-transpoosed orientation if it was originally allocated + # for FP8 GEMM + if fp8_gemm: + rhs_global = jnp.matrix_transpose(global_operands[1]) + global_operands = (global_operands[0], rhs_global, *global_operands[2:]) + + return ( + local_operands, + global_operands, + (out_shape, out_spec), + fsdp_gathered_rhs_spec, + ) + + +def _check_output(mesh, expected_out_shape, expected_out_specs, *tensors, fwd_bwd=False): + num_operands = 3 if fwd_bwd else 2 + ref_operands = tensors[:num_operands] + test_outputs = tensors[num_operands:] + + # Check number of dimensions + assert test_outputs[0].ndim == len(expected_out_shape), ( + f"Output has different number of dimensions ({test_outputs[0].ndim}) than expected " + + f"({len(expected_out_shape)})" + ) + + # Pad test output spec for unsharded dimensions + test_spec = te.sharding.get_padded_spec(test_outputs[0].sharding.spec, test_outputs[0].ndim) + + for i in range(test_outputs[0].ndim): + # Check shape + assert test_outputs[0].shape[i] == expected_out_shape[i], ( + f"Output with shape {test_outputs[0].shape} does not match expected shape " + + f"{expected_out_shape} in dimension index {i}." + ) + + # Check shardings (with padded output spec) + spec_mismatch = False + if isinstance(expected_out_specs[i], str): + if test_spec[i] != expected_out_specs[i]: + spec_mismatch = True + elif isinstance(expected_out_specs[i], Iterable): + if not isinstance(test_spec[i], type(expected_out_specs[i])): + if test_spec[i] not in expected_out_specs[i]: + spec_mismatch = True + elif len(test_spec[i]) != len(expected_out_specs[i]): + spec_mismatch = True + else: + for j in range(len(expected_out_specs[i])): + if test_spec[i][j] != expected_out_specs[i][j]: + spec_mismatch = True + break + elif expected_out_specs[i] == None: + if test_spec[i] != None: + spec_mismatch = True + else: + raise RuntimeError("Internal TE error: Unrecognized reference partition spec type.") + if spec_mismatch: + raise AssertionError( + f"Output sharding {test_spec} does not match expected sharding " + + f"{expected_out_specs} in dimension index {i}." + ) + + def _native_gemm_fwd_bwd(lhs, rhs, grad): + fwd_out, vjp_fn = jax.vjp(jnp.dot, lhs, rhs) + lhs_grad, rhs_grad = vjp_fn(grad) + return fwd_out, lhs_grad, rhs_grad + + ref_fn = jax.jit(_native_gemm_fwd_bwd if fwd_bwd else jnp.dot) + + out_names = ["output"] + ref_outputs = ref_fn(*ref_operands) + if not fwd_bwd: + ref_outputs = [ref_outputs] + else: + out_names += ["dgrad", "wgrad"] + + for i, (test_out, ref_out) in enumerate(zip(test_outputs, ref_outputs)): + test_out_global = jax.lax.with_sharding_constraint( + test_out, NamedSharding(mesh, PartitionSpec(None)) + ) + try: + assert_allclose(ref_out, test_out_global) + except AssertionError as err: + raise AssertionError(f"Numerical mismatch in {out_names[i]}:\n" + str(err)) + + +@pytest.mark.parametrize("comm_type", COMM_TYPES) +@pytest.mark.parametrize("mesh_type", MESH_TYPES) +def test_gemm_impl(comm_type, mesh_type): + mesh, mesh_resource, batched, fsdp = _get_mesh(mesh_type) + + ( + local_operands, + global_operands, + output_info, + fsdp_gathered_rhs_spec, + ) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp) + + @jax.jit + def _test_fn(lhs, rhs): + rhs_no_fsdp = jax.lax.with_sharding_constraint( + rhs, NamedSharding(mesh, PartitionSpec(*fsdp_gathered_rhs_spec)) + ) + return te.cpp_extensions.collective_gemm_impl(lhs, rhs_no_fsdp, batched_output=batched) + + with te.sharding.global_shard_guard(mesh_resource): + output, *_ = _test_fn(*local_operands) + + _check_output(mesh, *output_info, *global_operands, output) + + +@pytest.mark.parametrize("comm_type", COMM_TYPES) +@pytest.mark.parametrize("mesh_type", MESH_TYPES) +def test_gemm_fwd_bwd(comm_type, mesh_type): + mesh, mesh_resource, batched, fsdp = _get_mesh(mesh_type) + + ( + local_operands, + global_operands, + output_info, + fsdp_gathered_rhs_spec, + ) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True) + + @jax.jit + def _test_fn(lhs, rhs, grad): + # Gather weights in FSDP axis + rhs_no_fsdp = jax.lax.with_sharding_constraint( + rhs, NamedSharding(mesh, PartitionSpec(*fsdp_gathered_rhs_spec)) + ) + + # FWD pass + fwd_out, vjp_fn = jax.vjp(gemm, lhs, rhs_no_fsdp) + + # BWD pass + lhs_grad, rhs_grad = vjp_fn(grad) + + return fwd_out, lhs_grad, rhs_grad + + print( + f"INPUTS: {local_operands[0].shape} x {local_operands[1].shape}\n" + + f" LHS sharding: {local_operands[0].sharding.spec}\n" + + f" RHS sharding: {local_operands[1].sharding.spec}\n" + ) + + with te.sharding.global_shard_guard(mesh_resource): + output, dgrad, wgrad = _test_fn(*local_operands) + + print( + f"{'AG + GEMM' if comm_type == 'AG' else 'GEMM + AR'} output: " + + f"{output.shape} | {output.sharding.spec}\n" + + f"DGRAD: {dgrad.shape} | {dgrad.sharding.spec}\n" + + f"WGRAD: {wgrad.shape} | {wgrad.sharding.spec}\n" + ) + + _check_output(mesh, *output_info, *global_operands, output, dgrad, wgrad, fwd_bwd=True) diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index b8c8df37ee..f481ff39ab 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -91,30 +91,30 @@ .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ - py::class_>(m, "CommOverlapCore", \ pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \ - py::call_guard()) \ + .def(pybind11::init([]() { return new transformer_engine::CommOverlapCore(); }), \ + pybind11::call_guard()) \ .def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \ - py::call_guard()) \ + pybind11::call_guard()) \ .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ - py::call_guard()) \ + pybind11::call_guard()) \ .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ - py::call_guard()); \ - py::class_()); \ + pybind11::class_, \ transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ - py::call_guard()); \ - py::class_()); \ + pybind11::class_, \ transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \ pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ - py::call_guard()); \ + .def(pybind11::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ + pybind11::call_guard()); \ m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ - py::call_guard(), py::arg("device_id") = -1); \ + pybind11::call_guard(), pybind11::arg("device_id") = -1); \ m.def( \ "get_stream_priority_range", \ [](int device_id = -1) { \ @@ -122,8 +122,8 @@ transformer_engine::cuda::stream_priority_range(&low_pri, &high_pri, device_id); \ return std::make_pair(low_pri, high_pri); \ }, \ - py::call_guard(), py::arg("device_id") = -1); \ + pybind11::call_guard(), pybind11::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ - py::call_guard()); + pybind11::call_guard()); #endif diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index ab56d60f59..cad5697daa 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -31,10 +31,10 @@ from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension +module_name = "transformer_engine_jax" def _load_library(): """Load shared library with Transformer Engine C extensions""" - module_name = "transformer_engine_jax" if is_package_installed(module_name): assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." @@ -79,7 +79,9 @@ def _load_library(): spec.loader.exec_module(solib) -_load_library() +if module_name not in sys.modules: + _load_library() + from . import flax from . import quantize @@ -101,7 +103,6 @@ def _load_library(): ) __all__ = [ - "fp8_autocast", "MeshResource", "MajorShardingType", "ShardingResource", diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index ef8d76cd05..a9fe2e7d7b 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -4,7 +4,7 @@ """Python interface for c++ extensions""" from .activation import * from .attention import * +from .gemm import * from .normalization import * from .quantization import * from .softmax import * -from .gemm import * diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0fad75817f..9e9c6e71ce 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -3,15 +3,31 @@ # See LICENSE for license information. """JAX te modules""" -from typing import Tuple, Sequence, Union, Dict, List -from functools import partial, reduce +import warnings import operator +from functools import partial, reduce +from typing import Optional, Tuple, Sequence, Union, Dict, List +from packaging import version + from transformer_engine_jax import get_device_compute_capability import jax import jax.numpy as jnp +from jax import dtypes +from jax.sharding import PartitionSpec, NamedSharding +from jax.typing import ArrayLike from .base import BasePrimitive, register_primitive +from .misc import ( + jax_dtype_is_fp8, + get_padded_spec, + check_valid_batch_dims, +) +from ..sharding import ( + global_mesh_resource, + all_reduce_max_along_all_axes_except_PP, +) + from ..quantize import ( ScaledTensor, ScalingMode, @@ -20,12 +36,26 @@ noop_quantizer_set, ) +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + -__all__ = ["gemm", "grouped_gemm"] +__all__ = ["gemm", + "grouped_gemm", + "collective_gemm"] num_cublas_streams = 4 +def sanitize_dims(dim, ndims): + return (ndims + dim) if dim < 0 else dim + + +def mirror_dim(dim, ndims): + return ndims - 2 if dim == ndims - 1 else ndims - 1 + def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" @@ -514,3 +544,678 @@ def grouped_gemm( out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape)) return out_tensors + +class CollectiveGemmPrimitive(BasePrimitive): + """ + cuBlasLt GEMM Primitive w/ support for distributed inputs + """ + + name = "te_gemm" + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15) + multiple_results = True + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + lhs_aval, + lhs_scale_inv_aval, + rhs_aval, + rhs_scale_inv_aval, + bias_aval, + gelu_input_aval, + out_amax_aval, + out_scale_aval, + out_dtype, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): + """ + cuBlasLt GEMM abstract + """ + del grad, accumulate, use_split_accumulator + + # Validate operand dtypes + lhs_dtype = dtypes.canonicalize_dtype(lhs_aval.dtype) + rhs_dtype = dtypes.canonicalize_dtype(rhs_aval.dtype) + assert lhs_dtype == rhs_dtype, "Mismatched matrix dtypes for GEMM." + is_fp8 = False + if jax_dtype_is_fp8(lhs_dtype): + assert ( + lhs_scale_inv_aval.size == 1 + and dtypes.canonicalize_dtype(lhs_scale_inv_aval.dtype) == jnp.float32 + ), "Missing LHS operand scale inverse in FP8 GEMM." + is_fp8 = True + if jax_dtype_is_fp8(rhs_dtype): + assert ( + rhs_scale_inv_aval.size == 1 + and dtypes.canonicalize_dtype(rhs_scale_inv_aval.dtype) == jnp.float32 + ), "Missing RHS operand scale inverse in FP8 GEMM." + + # Validate operand layouts + lhs_inner_dim, rhs_inner_dim = map( + sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) + ) + assert ( + lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim] + ), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}." + + lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 + rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 + assert not ( + lhs_trans and rhs_trans + ), "GEMM does not support transposed LHS and transposed RHS at the same time." + if is_fp8: + assert not lhs_trans, "FP8 GEMM does not support transposed LHS." + assert rhs_trans, "FP8 GEMM requires transposed RHS." + + # Validate output dtype + if jax_dtype_is_fp8(out_dtype): + assert jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8( + rhs_dtype + ), "FP8 GEMM output requires FP8 inputs." + assert ( + out_amax_aval.size == out_scale_aval.size == 1 + ), "Invalid/missing output amax and scale." + out_amax_updated_dtype = dtypes.canonicalize_dtype(out_amax_aval.dtype) + out_scale_updated_dtype = dtypes.canonicalize_dtype(out_scale_aval.dtype) + assert ( + out_amax_updated_dtype == out_scale_updated_dtype == jnp.float32 + ), "Invalid output amax or scale dtype." + else: + out_dtype = lhs_dtype + out_amax_updated_dtype = jnp.float32 + out_scale_updated_dtype = jnp.float32 + + # Make sure leading dimensions of RHS is broadcast-compatible with LHS + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, + (lhs_inner_dim, rhs_inner_dim), + (lhs_aval.ndim, rhs_aval.ndim), + ) + lhs_bdims = [ + dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim] + ] + lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] + lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) + + # Infer output shape + if batched_output: + assert ( + lhs_aval.ndim > 2 and rhs_aval.ndim == 2 + ), "Batched output requires batched LHS and non-batched RHS operands." + out_shape = ( + *lhs_batch_shape, + lhs_aval.shape[lhs_outer_dim], + rhs_aval.shape[rhs_outer_dim], + ) + else: + assert ( + lhs_aval.ndim == rhs_aval.ndim + ), "Non-batched output requires LHS and RHS operands with same number of dimensions." + if lhs_aval.ndim > 2: + rhs_bdims = [ + dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] + ] + rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims] + rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) + assert lhs_batch_size == rhs_batch_size, ( + f"Leading dimensins of RHS ({rhs_aval.shape=}) is not broadcast-compatible " + + f"with the leading dimensions of LHS ({lhs_aval.shape=})." + ) + out_shape = (lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) + + # Validate bias/bias_grad shape against inferred output + bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype + if fuse_bias: + assert ( + bias_aval.size > 0 and bias_aval.ndim == 1 and bias_aval.shape[0] == out_shape[-1] + ), "Incorrect bias shape." + bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + else: + assert bias_aval.size == 0, "Internal TE error." + + # Validate GELU input/output + gelu_shape = (0,) + if fuse_gelu: + gelu_shape = ( + (reduce(operator.mul, out_shape[:-1], 1), out_shape[-1]) + if len(out_shape) > 2 + else out_shape + ) + assert gelu_input_aval.ndim == 2 and all( + gelu_input_aval.shape[i] == gelu_shape[i] for i in range(len(gelu_shape)) + ), "Invalid GELU input shape." + assert gelu_input_aval.dtype == bias_dtype, "Invalid GELU dtype." + else: + assert gelu_input_aval.size == 0, "Internal TE error." + + # Create abstract arrays for all outputs + out_aval = lhs_aval.update(shape=out_shape, dtype=out_dtype) + out_amax_updated_aval = out_amax_aval.update( + shape=out_amax_aval.shape, dtype=out_amax_updated_dtype + ) + out_scale_updated_aval = out_scale_aval.update( + shape=out_scale_aval.shape, dtype=out_scale_updated_dtype + ) + pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_shape, dtype=bias_dtype) + bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) + workspace_aval = jax.core.ShapedArray( + shape=(get_cublas_workspace_size_bytes(),), dtype=jnp.uint8 + ) + + return ( + out_aval, + out_amax_updated_aval, + out_scale_updated_aval, + pre_gelu_out_aval, + bias_grad_aval, + workspace_aval, + ) + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + cuBlasLt GEMM outer abstract + """ + (out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval, _) = ( + CollectiveGemmPrimitive.abstract(*args, **kwargs) + ) + return out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval + @staticmethod + def lowering( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + *, + out_dtype, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): + """ + Fused attention fwd lowering rules + """ + del batched_output + lhs_aval, _, rhs_aval, _, _, *_ = ctx.avals_in + lhs_inner_dim, rhs_inner_dim = map( + sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) + ) + lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 + rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 + + operand_output_aliases = { + 4: 4, # bias <--> bias_grad + 5: 3, # gelu_input <--> pre_gelu_out + 6: 1, # out_amax <--> out_amax_updated + 7: 2, # out_scale <--> out_scale_updated + } + + name = "te_gemm_ffi" + return ffi.ffi_lowering(name, operand_output_aliases=operand_output_aliases)( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + lhs_trans=lhs_trans, + rhs_trans=rhs_trans, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + + @staticmethod + def impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): + assert CollectiveGemmPrimitive.inner_primitive is not None + + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, (lhs_inner_dim, rhs_inner_dim), (lhs.ndim, rhs.ndim) + ) + + # Infer output shape and collapse batch dimensions + lhs_2d_shape = rhs_2d_shape = None + lhs_layout = rhs_layout = None + lhs_batch_dims = [ + dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim] + ] + lhs_batch_shape = [lhs.shape[dim] for dim in lhs_batch_dims] + lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) + contracting_dims_2d = list(contracting_dims).copy() + if lhs.ndim > 2 and rhs.ndim > 2: + # If both LHS and RHS are batched, the batch dimensions collapse into the + # contracting dimensions for both operands + lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_inner_dim], lhs.shape[lhs_outer_dim]) + lhs_layout = (*lhs_batch_dims, lhs_inner_dim, lhs_outer_dim) + contracting_dims_2d[0] = 0 + + rhs_batch_dims = [ + dim for dim in range(rhs.ndim) if dim not in [rhs_inner_dim, rhs_outer_dim] + ] + rhs_batch_shape = [rhs.shape[dim] for dim in rhs_batch_dims] + rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) + rhs_2d_shape = (rhs_batch_size * rhs.shape[rhs_inner_dim], rhs.shape[rhs_outer_dim]) + rhs_layout = (*rhs_batch_dims, rhs_inner_dim, rhs_outer_dim) + contracting_dims_2d[1] = 0 + elif lhs.ndim > 2: + # If only the LHS is batched,the batch dimension collapses into the outer dimension + lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_outer_dim], lhs.shape[lhs_inner_dim]) + lhs_layout = (*lhs_batch_dims, lhs_outer_dim, lhs_inner_dim) + contracting_dims_2d[0] = 1 + + # Reshape LHS and RHS into 2D and fix layouts for FP8 GEMM + if lhs_2d_shape is not None and lhs.ndim > 2: + lhs = jax.lax.reshape(lhs, lhs_2d_shape, dimensions=lhs_layout) + if jax_dtype_is_fp8(lhs.dtype): + lhs = jax.lax.transpose(lhs, (1, 0)) + contracting_dims_2d[0] = 1 + else: + contracting_dims_2d[0] = contracting_dims[0] + + if rhs_2d_shape is not None and rhs.ndim > 2: + rhs = jax.lax.reshape(rhs, rhs_2d_shape, dimensions=rhs_layout) + if jax_dtype_is_fp8(rhs.dtype): + rhs = jax.lax.transpose(rhs, (1, 0)) + contracting_dims_2d[1] = 1 + else: + contracting_dims_2d[1] = contracting_dims[1] + + # Invoke GEMM with guaranteed 2D inputs, so batched_output=False + ( + out, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + _, + ) = CollectiveGemmPrimitive.inner_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype=out_dtype, + batched_output=False, + contracting_dims=contracting_dims_2d, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # Recover batched dimensions in the output + if batched_output: + out_shape = (*lhs_batch_shape, out.shape[-2] // lhs_batch_size, out.shape[-1]) + out = jax.lax.reshape(out, out_shape) + + return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + out_dtype, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): + assert CollectiveGemmPrimitive.outer_primitive is not None + check_valid_batch_dims(batch_dims) + lhs_bdims, *_, bias_bdims, gelu_input_bdims, out_amax_bdims, out_scale_bdims = batch_dims + + return ( + CollectiveGemmPrimitive.outer_primitive.bind( + *batched_args, + out_dtype=out_dtype, + batched_output=batched_output, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ), + (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims), + ) + + @staticmethod + def infer_sharding_from_operands( + out_dtype, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + mesh, + arg_infos, + result_infos, + ): + del out_dtype, accumulate, use_split_accumulator, result_infos + lhs, _, rhs, *_ = arg_infos + lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) + + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim), + ) + + # Modify operand specs: + # - If contracting dimensions of both operands are sharded, force them to match. + # - If contracting dimensions of both operands are sharded, all-gather outer dimensions. + # - If contracting dimension of only one operand is sharded, all-gather the sharded + # operand. + # - Never scatter any operand. + lhs_spec_new = list(lhs_spec).copy() + rhs_spec_new = list(rhs_spec).copy() + lhs_spec_new[lhs_outer_dim] = None + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: + assert ( + lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim] + ), "Contracting dimensions of LHS and RHS operands must have the same sharding." + if lhs_spec_new[lhs_outer_dim] is not None: + warnings.warn( + "Outer dimension of the LHS operand must be all-gathered when both contracting " + + "dimensions are sharded. This will cause additional communication overhead." + ) + + if rhs_spec_new[rhs_outer_dim] is not None: + warnings.warn( + "Outer dimension of the RHS operand must be all-gathered when both contracting " + + "dimensions are sharded. This will cause additional communication overhead." + ) + rhs_spec_new[rhs_outer_dim] = None + else: + if lhs_spec_new[lhs_inner_dim] is None and rhs_spec_new[rhs_inner_dim] is not None: + warnings.warn( + "Contracting dimension of the RHS operand must be all-gathered when the " + + "contracting dimension of the LHS operand is unsharded. This will cause " + + "additional communication overhead." + ) + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is None: + if not grad: + # This is expected for sequence/context-parallel gradient in BWD (DGRAD) GEMM. + warnings.warn( + "Contracting dimension of the LHS operand must be all-gathered when the " + + "contracting dimension of the RHS operand is unsharded. This will cause " + + "additional communication overhead." + ) + lhs_spec_new[lhs_inner_dim] = None + rhs_spec_new[rhs_inner_dim] = None + out_col_spec = rhs_spec_new[rhs_outer_dim] + + # Output sharding is conditional on output shape + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] + batch_spec = [lhs_spec_new[dim] for dim in lhs_bdims] + out_spec = [None, out_col_spec] + if batched_output: + out_spec = batch_spec + out_spec + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) + + # FP8 metas are always unsharded + fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) + + # Pre-GELU output is always 2D if GELU fusion is turned on, otherwise unsharded + gelu_spec = [None, out_col_spec] if fuse_gelu else [None] + gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) + + # Bias gradient spec matches outer dimension of output if bias fusion is turned on + bias_sharding = NamedSharding(mesh, PartitionSpec(out_col_spec if fuse_bias else None)) + + return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding) + + @staticmethod + def partition( + out_dtype, + batched_output, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + mesh, + arg_infos, + result_infos, + ): + del result_infos + lhs, _, rhs, *_ = arg_infos + lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) + + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim), + ) + + # Modify operand specs: + # - Always all-gather the outer dimension of LHS. + # - If contracting dimensions of both operands are sharded, all-gather RHS outer dimension. + # - If contracting dimension of only one operand is sharded, all-gather the sharded + # operand. + # - Never scatter any operand. + lhs_spec_new = list(lhs_spec).copy() + rhs_spec_new = list(rhs_spec).copy() + reduce_output = False + lhs_spec_new[lhs_outer_dim] = None + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: + rhs_spec_new[rhs_outer_dim] = None + reduce_output = True + else: + lhs_spec_new[lhs_inner_dim] = None + rhs_spec_new[rhs_inner_dim] = None + out_col_spec = rhs_spec_new[rhs_outer_dim] + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) + rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec_new)) + + # Bias is sharded to match outer dimension spec of the RHS operand (also the output) + bias_sharding = NamedSharding(mesh, PartitionSpec(out_col_spec if fuse_bias else None)) + + # FP8 metas are always unsharded + fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) + + # Output sharding is conditional on output shape + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] + batch_spec = [lhs_spec_new[dim] for dim in lhs_bdims] + out_spec = [None, out_col_spec] + if batched_output: + out_spec = batch_spec + out_spec + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) + + # Pre-GELU output is always 2D if GELU fusion is turned on, otherwise unsharded + gelu_spec = [None, out_col_spec] if fuse_gelu else [None] + gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) + + arg_shardings = ( + lhs_sharding, + fp8_meta_sharding, + rhs_sharding, + fp8_meta_sharding, + bias_sharding, + gelu_sharding, + fp8_meta_sharding, + fp8_meta_sharding, + ) + out_shardings = ( + out_sharding, + fp8_meta_sharding, + fp8_meta_sharding, + gelu_sharding, + bias_sharding, + ) + + def sharded_impl( + lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale + ): + ( + out, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + ) = CollectiveGemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + batched_output=batched_output, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # FP8 amax reduction + if jax_dtype_is_fp8(lhs.dtype): + out_amax_updated = all_reduce_max_along_all_axes_except_PP(out_amax_updated, mesh) + + # All-reduce sum GEMM output when contracting dimensions are sharded + if reduce_output: + out = jax.lax.psum(out, global_mesh_resource().tp_resource) + if fuse_gelu: + pre_gelu_out = jax.lax.psum(pre_gelu_out, global_mesh_resource().tp_resource) + + return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(CollectiveGemmPrimitive) + + +def collective_gemm( + lhs: Union[jnp.ndarray, ScaledTensor], + rhs: Union[jnp.ndarray, ScaledTensor], + bias: jnp.ndarray = None, + gelu_input: Optional[ArrayLike] = None, + batched_output: bool = False, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), + fuse_gelu: bool = False, + grad: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> Tuple[ArrayLike, ...]: + """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" + + if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): + scaling_mode = lhs.scaling_mode + out_dtype = lhs.dq_dtype + # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal layout + if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + assert not ( + lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 + ), "FP8 GEMM does not support E5M2 * E5M2" + else: + # For jnp.ndarray, only consider contracting_dims, layout is always NN + scaling_mode = ScalingMode.NVTE_NO_SCALING + out_dtype = lhs.dtype + + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim), + ) + + if bias is None: + bias = jnp.zeros(0, dtype=lhs.dtype) + elif grad: + bias = jnp.zeros(rhs.shape[rhs_outer_dim], dtype=lhs.dtype) + + if not fuse_gelu: + gelu_input = jnp.zeros(0, dtype=lhs.dtype) + elif grad: + assert ( + gelu_input is not None + ), "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." + elif gelu_input is None: + bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] + batch_size = reduce(operator.mul, [lhs.shape[dim] for dim in bdims], 1) + gelu_shape = (batch_size * lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + gelu_input = jnp.zeros(gelu_shape, dtype=lhs.dtypes) + + if scaling_mode == ScalingMode.NVTE_NO_SCALING: + lhs_scale_inv = jnp.ones(1, dtype=jnp.float32) + rhs_scale_inv = jnp.ones(1, dtype=jnp.float32) + if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: + lhs_scale_inv = lhs.scale_inv.reshape(-1) + rhs_scale_inv = rhs.scale_inv.reshape(-1) + if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: + lhs_scale_inv = lhs_scale_inv.reshape(-1) + rhs_scale_inv = rhs_scale_inv.reshape(-1) + + out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + scaling_mode=scaling_mode, + out_dtype=out_dtype, + batched_output=batched_output, + contracting_dims=contracting_dims, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if grad: + return out, pre_gelu_out, bias_grad + return out, pre_gelu_out, None diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 980ea556bb..d383e75143 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -161,6 +161,13 @@ def jax_version_meet_requirement(version: str): return jax_version >= jax_version_required +def jax_dtype_is_fp8(dtype): + """ + Check if the given jax.numpy.dtype is an FP8 dtype. + """ + return dtypes.canonicalize_dtype(dtype) in [jnp.float8_e4m3fn, jnp.float8_e5m2] + + def get_xla_flag(flag: str, default=None, cast=str): """ Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value. diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 551b4b4bdb..9030387eb9 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -34,6 +34,14 @@ __all__ = ["quantize", "quantize_dbias"] +def _jax_cast_fp8(inputs, scale, amax, out_dtype): + """ + JAX native fp8 casting implementation + """ + casted_output = _jax_quantize(inputs, scale, dq_dtype=out_dtype) + updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype)) + return casted_output, updated_amax + class DBiasQuantizePrimitive(BasePrimitive): """ diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 1950d6cbab..e8d26710d8 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -119,6 +119,18 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); // CuBLAS helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); +// GEMM + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out, + Result_Type out_amax_updated, Result_Type out_scale_updated, + Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type workspace, + bool lhs_trans, bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 74909319cc..50f7023c1c 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -210,5 +210,113 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("scaling_mode"), FFI_CudaGraph_Traits); +void GemmImpl(cudaStream_t stream, void *lhs, const std::vector &lhs_shape, + float *lhs_scale_inv, bool lhs_trans, void *rhs, const std::vector &rhs_shape, + float *rhs_scale_inv, bool rhs_trans, DType operand_dtype, void *bias, + DType bias_dtype, void *out, float *out_amax, float *out_scale, DType out_dtype, + void *pre_gelu_out, void *workspace, size_t workspace_size, bool fuse_gelu, + bool fuse_bias, bool grad, bool accumulate, bool use_split_accumulator) { + auto lhs_ = TensorWrapper(lhs, lhs_shape, operand_dtype, nullptr, nullptr, lhs_scale_inv); + auto rhs_ = TensorWrapper(rhs, rhs_shape, operand_dtype, nullptr, nullptr, rhs_scale_inv); + + std::vector out_shape(2, 0); + out_shape[0] = (lhs_trans) ? lhs_shape[1] : lhs_shape[0]; + out_shape[1] = (rhs_trans) ? rhs_shape[0] : rhs_shape[1]; + auto out_ = TensorWrapper(out, out_shape, out_dtype, out_amax, out_scale, nullptr); + + void *bias_ptr = (fuse_bias) ? bias : nullptr; + std::vector bias_shape = + (fuse_bias) ? std::vector{out_shape[1]} : std::vector{0}; + auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + + void *pre_gelu_ptr = (fuse_gelu) ? pre_gelu_out : nullptr; + std::vector pre_gelu_shape = (fuse_gelu) ? out_shape : std::vector{0}; + auto pre_gelu_out_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, bias_dtype); + auto workspace_ = TensorWrapper(workspace, std::vector{workspace_size}, DType::kByte); + + // cuBLAS is column-major, so we swap LHS and RHS in the arguments + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_out_.data(), + (rhs_trans) ? CUBLAS_OP_T : CUBLAS_OP_N, (lhs_trans) ? CUBLAS_OP_T : CUBLAS_OP_N, + grad, workspace_.data(), accumulate, use_split_accumulator, num_math_sm, stream); +} + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out, + Result_Type out_amax_updated, Result_Type out_scale_updated, + Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type workspace, + bool lhs_trans, bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator) { + // Inputs + auto lhs_ptr = lhs.untyped_data(); + auto lhs_scale_inv_ptr = reinterpret_cast(lhs_scale_inv.untyped_data()); + auto rhs_ptr = rhs.untyped_data(); + auto rhs_scale_inv_ptr = reinterpret_cast(rhs_scale_inv.untyped_data()); + auto operand_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + auto bias_ptr = bias.untyped_data(); + auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); + auto gelu_input_ptr = gelu_input.untyped_data(); + auto out_amax_ptr = reinterpret_cast(out_amax.untyped_data()); + auto out_scale_ptr = reinterpret_cast(out_scale.untyped_data()); + + // Outputs + auto out_ptr = out->untyped_data(); + auto out_amax_updated_ptr = reinterpret_cast(out_amax_updated->untyped_data()); + auto out_scale_updated_ptr = reinterpret_cast(out_scale_updated->untyped_data()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(out->element_type()); + auto pre_gelu_out_ptr = pre_gelu_out->untyped_data(); + auto bias_grad_ptr = bias_grad->untyped_data(); + auto workspace_ptr = workspace->untyped_data(); + auto workspace_size = workspace->dimensions().back(); + + // Operand aliasing + NVTE_CHECK(bias_ptr == bias_grad_ptr, "bias not bound to bias_grad in TE/JAX GEMM"); + NVTE_CHECK(gelu_input_ptr == pre_gelu_out_ptr, + "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); + NVTE_CHECK(out_amax_ptr == out_amax_updated_ptr, + "out_amax not bound to out_amax_updated in TE/JAX GEMM"); + NVTE_CHECK(out_scale_ptr == out_scale_updated_ptr, + "out_scale not bound to out_scale_updated in TE/JAX GEMM"); + + // GEMM sizing + std::vector lhs_shape(lhs.dimensions().begin(), lhs.dimensions().end()); + std::vector rhs_shape(rhs.dimensions().begin(), rhs.dimensions().end()); + + // Swap A and B argument locations to match what the TE/common kernel expects + GemmImpl(stream, lhs_ptr, lhs_shape, lhs_scale_inv_ptr, lhs_trans, rhs_ptr, rhs_shape, + rhs_scale_inv_ptr, rhs_trans, operand_dtype, bias_ptr, bias_dtype, out_ptr, out_amax_ptr, + out_scale_ptr, out_dtype, pre_gelu_out_ptr, workspace_ptr, workspace_size, fuse_gelu, + fuse_bias, grad, accumulate, use_split_accumulator); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Arg() // out_amax + .Arg() // out_scale + .Ret() // out + .Ret() // out_amax_updated + .Ret() // out_scale_updated + .Ret() // pre_gelu_out + .Ret() // bias_grad + .Ret() // workspace + .Attr("lhs_trans") + .Attr("rhs_trans") + .Attr("fuse_gelu") + .Attr("fuse_bias") + .Attr("grad") + .Attr("accumulate") + .Attr("use_split_accumulator"), + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index c777a02c99..35c4ae6c03 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -59,6 +59,7 @@ pybind11::dict Registrations() { pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); + dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler); return dict; } diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 43336768cb..e0fdc567dc 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -168,6 +168,216 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu _dense.defvjp(_dense_fwd_rule, _dense_bwd_rule) +def collective_dense( + x, + kernel, + bias = None, + contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), + quantizer_set: QuantizerSet = None, + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +): + """Non-FP8 collective/distributed `nvte_cublas_gemm()` with GELU and bias-add fusions.""" + return _collective_dense(x, + kernel, + bias, + contracting_dims, + quantizer_set, + fuse_gelu, + accumulate, + use_split_accumulator + ) + + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7)) +def _collective_dense( + x, + kernel, + bias, + contracting_dims, + quantizer_set, + fuse_gelu, + accumulate, + use_split_accumulator, +): + out, _ = _collective_dense_fwd_rule( + x, + kernel, + bias, + contracting_dims, + quantizer_set, + fuse_gelu, + accumulate, + use_split_accumulator, + ) + return out + + +def _collective_dense_fwd_rule( + x, + kernel, + bias, + contracting_dims, + quantizer_set, + fuse_gelu, + accumulate, + use_split_accumulator, +): + assert ( + kernel.ndim == 2 + ), "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." + + fuse_bias = bias is not None + + if quantizer_set is None: + x_rowwise = x + x_colwise = x + kernel_colwise = kernel + kernel_rowwise = kernel + x_shape = x.shape + kernel_shape = kernel.shape + else: + q_x = tex.quantize(x, quantizer_set.x) + q_kernel = tex.quantize(kernel, quantizer_set.kernel) + x_rowwise = q_x.get_rowwise_tensor() + x_colwise = q_x.get_colwise_tensor() + kernel_colwise = q_kernel.get_colwise_tensor() + kernel_rowwise = q_kernel.get_rowwise_tensor() + x_shape = x_rowwise.data.shape + kernel_shape = kernel_rowwise.data.shape + + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) + # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) + # + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) + out, pre_gelu_out, _ = tex.collective_gemm( + x_rowwise, + kernel_colwise, + bias=bias, + batched_output=(x.ndim > 2), + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + ctx = ( + x_colwise, + kernel_rowwise, + x_shape, + kernel_shape, + pre_gelu_out if fuse_gelu else None, + fuse_bias, + quantizer_set, + ) + + return out, ctx + + +def _collective_dense_bwd_rule( + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ctx, + grad, +): + ( + x_colwise, + kernel_rowwise, + x_shape, + kernel_shape, + pre_gelu_out, + fuse_bias, + quantizer_set, + ) = ctx + + if quantizer_set is None: + casted_grad = grad + bgrad = tex.quantization._jax_dbias(grad) + grad_rowwise = grad + grad_colwise = grad + else: + casted_grad, bgrad = tex.quantize_dbias( + grad, is_dbias=fuse_bias, quantizer=quantizer_set.dgrad + ) + grad_rowwise = casted_grad.get_rowwise_tensor() + grad_colwise = casted_grad.get_colwise_tensor() + + fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims + g_contracting_dim = tuple( + range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim) + ) + k_contracting_dim = tuple( + dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims + ) + dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim) + + # GEMM TN + g_contracting_dim = x_contracting_dim = tuple( + range(0, len(x_shape) - len(fwd_x_contracting_dims)) + ) + wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) + + # FWD MODE: + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) + # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) + # + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) + + # DGRAD: + # AG+GEMM: ([B], M, N/P) x (K, N/P)^T ----(AR)----> ([B], M, K) + # (DP, None, TP) x (None, TP)^T --(AR)--> (DP, None, None) + # + # GEMM+AR: ([B], M, N) x (K/P, N)^T ------> ([B], M, K/P) + # (DP, None, None) x (TP, None)^T --> (DP, None, TP) + dgrad, dgelu, _ = tex.collective_gemm( + grad_rowwise, + kernel_rowwise, + gelu_input=pre_gelu_out, + batched_output=(x_colwise.ndim > 2), + contracting_dims=dgrad_contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=False, + grad=True, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # WGRAD: + # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) + # (DP, 'tp', None)^T --(AG)-->(DP, None, None)^T x (DP, None, 'tp') --> (None, 'tp') + # + # GEMM+AR: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N) ---------> (K/P, N) + # (DP, None, 'tp')^T --(AG)--> (DP, None, None)^T x (DP, None, None) ----> (None, None) + # Make XLA scatter output in first dim. + wgrad_rhs = dgelu if fuse_gelu else grad_colwise + wgrad, _, bgrad = tex.collective_gemm( + x_colwise, + wgrad_rhs, + gelu_input=pre_gelu_out, + batched_output=False, + contracting_dims=wgrad_contracting_dims, + fuse_gelu=False, + fuse_bias=fuse_bias, + grad=True, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if not fuse_bias: + bgrad = None + + return dgrad, wgrad, bgrad + + +_collective_dense.defvjp(_collective_dense_fwd_rule, _collective_dense_bwd_rule) + + def grouped_dense( x_list, kernel_list,