diff --git a/examples/jax/comm_overlap/flax_with_overlap.py b/examples/jax/comm_overlap/flax_with_overlap.py new file mode 100644 index 0000000000..801fc0bfd7 --- /dev/null +++ b/examples/jax/comm_overlap/flax_with_overlap.py @@ -0,0 +1,259 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Comm+GEMM Overlap with TE/JAX""" +import os +import argparse +from functools import partial + +from mpi4py import MPI + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils +from flax.linen import partitioning as nn_partitioning + +import transformer_engine.jax as te +import transformer_engine_jax as tex +from transformer_engine.jax.sharding import ( + get_padded_spec, + MeshResource, + HIDDEN_AXES, + HIDDEN_TP_AXES, + BATCH_AXES, + SEQLEN_TP_AXES, + SEQLEN_AXES, + W_NO_SHARD_AXES, + W_FSDP_AXES, + W_TP_AXES, + W_JOINED_AXES, +) +from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP +from transformer_engine.common import recipe + +# This script needs to be launched via `mpirun` with 1 process per GPU +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +myrank = MPI.COMM_WORLD.Get_rank() +numranks = MPI.COMM_WORLD.Get_size() +jax.clear_caches() +jax.distributed.initialize(cluster_detection_method="mpi4py") +assert ( + jax.local_device_count() == 1 +), f"[{myrank}|{numranks}] Expected 1 GPU per process, found {jax.local_device_count()}" + +# Parse script arguments +_supported_layers = (DenseGeneral, LayerNormDenseGeneral, LayerNormMLP) +_layer_map = dict((layer.__name__.lower(), layer) for layer in _supported_layers) + + +def _te_flax_layer(layer_name): + assert isinstance(layer_name, str) and layer_name.lower() in _layer_map + return _layer_map[layer_name.lower()] + + +parser = argparse.ArgumentParser() +parser.add_argument("-dp", "--dp-size", type=int, default=2) +parser.add_argument("-tp", "--tp-size", type=int, default=numranks // 2) +parser.add_argument("-np", "--num-gpus", type=int, default=numranks) +parser.add_argument("--batch-size", type=int, default=2) +parser.add_argument("--seq-length", type=int, default=8192) +parser.add_argument("--hidden-size", type=int, default=16384) +parser.add_argument("--activation-size", type=int, default=53248) +parser.add_argument("--no-batch", action="store_true") +parser.add_argument("--no-fsdp", action="store_true") +parser.add_argument( + "--layer-type", type=_te_flax_layer, default=DenseGeneral, choices=_supported_layers +) +parser.add_argument( + "--fp8-recipe", type=str.lower, default="none", choices=["none", "current", "delayed", "mxfp8"] +) +parser.add_argument("--check-result", action="store_true") +parser.add_argument("--seed", type=int, default=42) +args = parser.parse_args() + +# FP8 recipe +fp8_recipe = None +match args.fp8_recipe: + case "current": + fp8_recipe = recipe.Float8CurrentScaling() + case "delayed": + fp8_recipe = recipe.DelayedScaling() + case "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling() + case _: + fp8_recipe = None + +# Single GPU evaluation +layer_kwargs = {"use_bias": True} +match args.layer_type: + case DenseGeneral: + layer_kwargs.update({"features": args.hidden_size, "name": "proj"}) + case LayerNormDenseGeneral: + layer_kwargs.update( + {"features": 3 * args.hidden_size, "return_layernorm_output": False, "name": "qkv"} + ) + case LayerNormMLP: + layer_kwargs.update( + { + "intermediate_dim": args.activation_size, + "return_layernorm_output": False, + "name": "mlp", + } + ) + +rng = jax.random.PRNGKey(args.seed) +rng, params_rng = jax.random_split(rng) +init_rngs = {"params": params_rng} + +dtype = jnp.bfloat16 +input_shape = (args.seq_length, args.hidden_size) +if not args.no_batch: + input_shape = (args.batch_size,) + input_shape +x = jnp.random.normal(rng, input_shape, dtype=jnp.bfloat16) + +with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + model_single = partial(args.layer_type, **layer_kwargs) + params_single = model_single.init(init_rngs, x, deterministic=True) + output_single = model_single.apply(params_single, x, deterministic=True) + +# Resources and partition specs +DEVICE_DP_AXIS = "dp" +DEVICE_TP_AXIS = "tp" +mesh_shape = (args.dp_size, args.tp_size) +mesh_axes = (DEVICE_DP_AXIS, DEVICE_TP_AXIS) +mesh_resource = MeshResource( + dp_resource=DEVICE_DP_AXIS if args.no_fsdp else None, + fsdp_resource=None if args.no_fsdp else DEVICE_DP_AXIS, + tp_resource=DEVICE_TP_AXIS, +) + +INPUT_AXES = ( + SEQLEN_TP_AXES if args.layer_type != DenseGeneral else SEQLEN_AXES, + HIDDEN_AXES if args.layer_type != DenseGeneral else HIDDEN_TP_AXES, +) +INTERMEDIATE_AXES = (SEQLEN_AXES, HIDDEN_TP_AXES) +if not args.no_batch: + INPUT_AXES = (BATCH_AXES,) + INPUT_AXES + INTERMEDIATE_AXES = (BATCH_AXES,) + INTERMEDIATE_AXES + +LN_SCALE_AXES = LN_BIAS_AXES = (W_NO_SHARD_AXES,) + +KERNEL_AXES_ROW_PARALLEL = (W_TP_AXES, W_FSDP_AXES) +BIAS_AXES_ROW_PARALLEL = (W_NO_SHARD_AXES,) +KERNEL_AXES_COL_PARALLEL = (W_FSDP_AXES, W_TP_AXES) +BIAS_AXES_COL_PARALLEL = (W_TP_AXES,) +if args.layer_type == LayerNormMLP: + KERNEL_AXES_COL_PARALLEL = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES) + BIAS_AXES_COL_PARALLEL = (W_JOINED_AXES, W_NO_SHARD_AXES) + +# Multi GPU evaluation +layer_kwargs.update({"enable_comm_overlap": True}) +if args.layer_type in (DenseGeneral, LayerNormDenseGeneral): + layer_kwargs.update( + { + "kernel_axes": KERNEL_AXES_COL_PARALLEL, + "bias_axes": BIAS_AXES_COL_PARALLEL, + "comm_overlap_config": {"method": tex.CommOverlapMethod.RING_EXCHANGE}, + } + ) + if args.layer_type == LayerNormDenseGeneral: + layer_kwargs.update( + { + "layernorm_input_axes": INPUT_AXES, + "scale_axes": LN_SCALE_AXES, + "ln_bias_axes": LN_BIAS_AXES, + "dot_input_axes": INPUT_AXES, + } + ) +else: + layer_kwargs.update( + { + "layernorm_input_axes": INPUT_AXES, + "scale_axes": LN_SCALE_AXES, + "ln_bias_axes": LN_BIAS_AXES, + "dot_1_input_axes": INPUT_AXES, + "kernel_1_axes": KERNEL_AXES_COL_PARALLEL, + "bias_axes_1": BIAS_AXES_COL_PARALLEL, + "dot_2_input_axes": INTERMEDIATE_AXES, + "kernel_2_axes": KERNEL_AXES_ROW_PARALLEL, + "bias_axes_2": BIAS_AXES_ROW_PARALLEL, + "dot_1_comm_overlap_config": {"method": tex.CommOverlapMethod.RING_EXCHANGE}, + "dot_2_comm_overlap_config": {"method": tex.CommOverlapMethod.RING_EXCHANGE}, + } + ) + +device_mesh = mesh_utils.create_device_mesh((args.dp_size, args.tp_size)) +mesh = Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)) +axis_rules = nn_partitioning.axis_rules( + ( + (BATCH_AXES, DEVICE_DP_AXIS), + (SEQLEN_AXES, None), + (SEQLEN_TP_AXES, DEVICE_TP_AXIS), + (HIDDEN_AXES, None), + (HIDDEN_TP_AXES, DEVICE_TP_AXIS), + (W_NO_SHARD_AXES, None), + (W_JOINED_AXES, None), + (W_FSDP_AXES, None if args.no_fsdp else DEVICE_DP_AXIS), + (W_TP_AXES, DEVICE_TP_AXIS), + ) +) +with ( + mesh, + axis_rules, + te.fp8_autocast( + enabled=fp8_recipe is not None, + fp8_recipe=fp8_recipe, + mesh_resource=mesh_resource, + ), +): + model_sharded = partial(args.layer_type, **layer_kwargs) + params_sharded = model_sharded.init(init_rngs, x, deterministic=True) + output_sharded = model_sharded.apply(params_sharded, x, deterministic=True) + +if myrank == 0: + print( + f"{myrank}: {args.layer_type.__name__} OUTPUT {output_sharded.shape}\n" + + f" Sharding: {get_padded_spec(output_sharded.sharding.spec, output_sharded.ndim)}\n", + flush=True, + ) + +if args.check_result: + output_gathered = jax.lax.with_sharding_constraint( + output_sharded, NamedSharding(mesh, PartitionSpec(None)) + ) + jax.block_until_ready(output_gathered) + + diff = jnp.abs(output_single - output_gathered).flatten() + if myrank == 0: + print(f"{myrank}: Global output difference: {diff}\n", flush=True) + + m = jnp.argmax(diff).item() + abs_err = diff[m].item() + rel_err = abs_err / max(abs(output_single.flatten()[m]), 1e-5) + + rtol = 0.02 + atol = 0.001 + numerics_failed = False + if rel_err > rtol and abs_err > atol: + numerics_failed = True + numerics_info = ( + "NUMERICAL CHECK FAILED: " + + f"Outputs not close enough at index {m} " + + f"with {output_gathered.flatten()[m].item()} vs {output_single.flatten()[m].item()} " + + f"| rel. error = {rel_err} (tol = {rtol}) " + + f"| abs. error = {abs_err} (tol = {atol})" + ) + else: + numerics_info = "NUMERICAL CHECK PASSED: " + if rel_err <= rtol: + numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + ( + " | " if abs_err < atol else "" + ) + if abs_err <= atol: + numerics_info += f"abs. error = {abs_err} (tol = {atol})" + + if myrank == 0: + print(numerics_info + "\n", end="", flush=True) + +tex.destroy_all_comm_overlap_buffers() diff --git a/examples/jax/comm_overlap/gemm_with_overlap.py b/examples/jax/comm_overlap/gemm_with_overlap.py new file mode 100644 index 0000000000..615be69e15 --- /dev/null +++ b/examples/jax/comm_overlap/gemm_with_overlap.py @@ -0,0 +1,226 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Comm+GEMM Overlap with TE/JAX""" + +import argparse +from functools import partial +from pprint import pprint + +import numpy as np +from mpi4py import MPI + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils + +import transformer_engine.jax as te +import transformer_engine_jax as tex +from transformer_engine.jax.sharding import get_padded_spec +from transformer_engine.jax.cpp_extensions import ( + gemm, + CommOverlapHelper, +) + +jax.clear_caches() + +# This script needs to be launched via `mpirun` with 1 process per GPU +myrank = MPI.COMM_WORLD.Get_rank() +numranks = MPI.COMM_WORLD.Get_size() +jax.distributed.initialize(cluster_detection_method="mpi4py") + +parser = argparse.ArgumentParser() +parser.add_argument("-dp", "--dp-size", type=int, default=1) +parser.add_argument("-zp", "--fsdp-size", type=int, default=2) +parser.add_argument("-tp", "--tp-size", type=int, default=numranks // 2) +parser.add_argument("-np", "--num-gpus", type=int, default=numranks) +parser.add_argument("--batch-size", type=int, default=2) +parser.add_argument("--seq-length", type=int, default=8192) +parser.add_argument("--hidden-size", type=int, default=16384) +parser.add_argument("--activation-size", type=int, default=53248) +parser.add_argument("--no-batch", action="store_true") +parser.add_argument("--no-fsdp", action="store_true") +parser.add_argument("--comm-type", type=str.upper, default="AG", choices=["AG", "RS"]) +parser.add_argument("--check-result", action="store_true") +args = parser.parse_args() + +# Operand shapes +dtype = jnp.bfloat16 +lhs_shape = ( + [args.seq_length, args.hidden_size] + if args.comm_type == "AG" + else [args.seq_length, args.activation_size] +) +rhs_shape = ( + [args.hidden_size, args.activation_size] + if args.comm_type == "AG" + else [args.activation_size, args.hidden_size] +) + +# Operand partitioning +batched = not args.no_batch +fsdp = not args.no_fsdp +input_specs = [None] * len(lhs_shape) +weight_specs = [None] * len(rhs_shape) +if batched: + lhs_shape = [args.batch_size] + lhs_shape + if fsdp: + mesh_shape = {"dp": args.dp_size, "zp": args.fsdp_size, "tp": args.tp_size} + mesh_resource = te.MeshResource( + dp_resource="dp", tp_resource="tp", cp_resource="tp", fsdp_resource="zp" + ) + if args.comm_type == "AG": + input_specs = [("dp", "zp"), "tp", None] + weight_specs = ["zp", "tp"] + elif args.comm_type == "RS": + input_specs = [("dp", "zp"), None, "tp"] + weight_specs = ["tp", "zp"] + else: + mesh_shape = {"dp": args.dp_size, "tp": args.tp_size} + mesh_resource = te.MeshResource( + dp_resource="dp", + tp_resource="tp", + cp_resource="tp", + ) + if args.comm_type == "AG": + input_specs = ["dp", "tp", None] + weight_specs = [None, "tp"] + elif args.comm_type == "RS": + input_specs = ["dp", None, "tp"] + weight_specs = ["tp", None] +else: + if fsdp: + mesh_shape = {"zp": args.fsdp_size, "tp": args.tp_size} + mesh_resource = te.MeshResource(fsdp_resource="zp", tp_resource="tp", cp_resource="cp") + if args.comm_type == "AG": + input_specs = ["tp", None] + weight_specs = ["zp", "tp"] + elif args.comm_type == "RS": + input_specs = [None, "tp"] + weight_specs = ["tp", "zp"] + else: + mesh_shape = {"tp": args.tp_size} + mesh_resource = te.MeshResource(tp_resource="tp", cp_resource="cp") + if args.comm_type == "AG": + input_specs = ["tp", None] + weight_specs = [None, "tp"] + elif args.comm_type == "RS": + input_specs = [None, "tp"] + weight_specs = ["tp", None] + +# Mesh setup and sharding definitions +devices = mesh_utils.create_device_mesh((args.num_gpus,), devices=jax.devices()[: args.num_gpus]) +mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) +no_sharding = NamedSharding(mesh, PartitionSpec(None)) +input_sharding = NamedSharding(mesh, PartitionSpec(*input_specs)) +weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_specs)) + +# Operand initialization +key = jax.random.PRNGKey(0) +key1, key2 = jax.random.split(key, 2) +lhs_data = jax.random.normal(key1, lhs_shape, dtype=dtype) +rhs_data = jax.random.normal(key2, rhs_shape, dtype=dtype) +lhs = jax.device_put(lhs_data, input_sharding) +rhs = jax.device_put(rhs_data, weight_sharding) +dimension_numbers = (((-1,), (0,)), ((0,), ())) + +# Name of comm+GEMM overlap layer +overlap_method = tex.CommOverlapMethod.RING_EXCHANGE +comm_type = tex.CommOverlapType.AG if args.comm_type == "AG" else tex.CommOverlapType.RS + +# Bootstrap Userbuffers communicators and communication buffers +# NOTE: All-gather overlap requires buffer to be sized the LHS operand's global shape. +# Reduce-scatter overlap requires buffer to be sized to the GEMM output's global shape. +output_shape = (*lhs_shape[:-1], rhs_shape[-1]) +buffer_shape = list(lhs_shape if comm_type == tex.CommOverlapType.AG else output_shape).copy() +if batched: + # The only all-gathered dimension is sequence, batch is still sharded for the buffer + buffer_shape[0] = buffer_shape[0] // (args.dp_size * args.fsdp_size) +overlap_helper = CommOverlapHelper( + method=overlap_method, + comm_type=comm_type, + buffer_shape=buffer_shape, + buffer_dtype=dtype, + tp_size=args.tp_size, + tp_resource="tp", + sp_resource="tp", +) +if myrank == 0: + print(f"{myrank}: OVERLAP CONFIG:", flush=True) + pprint(overlap_helper) + print( + f"\n{myrank}: INPUTS {lhs.shape} x {rhs.shape}\n" + + f"{myrank}: LHS sharding: {lhs.sharding.spec}\n" + + f"{myrank}: RHS sharding: {rhs.sharding.spec}\n", + flush=True, + ) + + +@jax.jit +def _gemm_wrapper(x, y): + return partial( + gemm, + dimension_numbers=(((-1,), (0,)), ((0,), ())), + comm_overlap=overlap_helper, + )(x, y) + + +with te.sharding.global_shard_guard(mesh_resource): + output = _gemm_wrapper(lhs, rhs) + +jax.block_until_ready(output) +if myrank == 0: + print( + f"{myrank}: {'AG -> GEMM' if args.comm_type == 'AG' else 'GEMM -> RS'} OUTPUT " + + f"{output.shape}\n" + + f"{myrank}: Sharding: {get_padded_spec(output.sharding.spec, output.ndim)}\n", + flush=True, + ) + +if args.check_result: + ref_global = jnp.matmul( + jax.device_put(lhs_data, no_sharding), jax.device_put(rhs_data, no_sharding) + ) + jax.block_until_ready(ref_global) + if myrank == 0: + print(f"{myrank}: Global reference: {ref_global}\n", flush=True) + + output_global = jax.lax.with_sharding_constraint(output, no_sharding) + jax.block_until_ready(output_global) + if myrank == 0: + print(f"{myrank}: Global output: {output_global}\n", flush=True) + + diff = jnp.abs(ref_global - output_global).flatten() + if myrank == 0: + print(f"{myrank}: Global difference: {diff}\n", flush=True) + + m = jnp.argmax(diff).item() + abs_err = diff[m].item() + rel_err = abs_err / max(abs(ref_global.flatten()[m]), 1e-5) + + rtol = 0.02 + atol = 0.001 + numerics_failed = False + if rel_err > rtol and abs_err > atol: + numerics_failed = True + numerics_info = ( + "NUMERICAL CHECK FAILED: " + + f"Outputs not close enough at index {m} " + + f"with {output.flatten()[m].item()} vs {ref_global.flatten()[m].item()} | " + + f"rel. error = {rel_err} (tol = {rtol}) | " + + f"abs. error = {abs_err} (tol = {atol})" + ) + else: + numerics_info = "NUMERICAL CHECK PASSED: " + if rel_err <= rtol: + numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + ( + " | " if abs_err < atol else "" + ) + if abs_err <= atol: + numerics_info += f"abs. error = {abs_err} (tol = {atol})" + + if myrank == 0: + print(numerics_info + "\n", end="", flush=True) + +tex.destroy_all_comm_overlap_buffers() diff --git a/examples/jax/comm_overlap/layer_prim_with_overlap.py b/examples/jax/comm_overlap/layer_prim_with_overlap.py new file mode 100644 index 0000000000..790c1092ad --- /dev/null +++ b/examples/jax/comm_overlap/layer_prim_with_overlap.py @@ -0,0 +1,396 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Comm+GEMM Overlap with TE/JAX""" +import os +import argparse +from functools import partial + +from mpi4py import MPI + +import numpy as np + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils + +import transformer_engine.jax as te +from transformer_engine.common import recipe +from transformer_engine.jax.sharding import ( + MeshResource, + global_shard_guard, + generate_pspec, + BATCH_AXES, + SEQLEN_AXES, + SEQLEN_TP_AXES, + HIDDEN_AXES, + HIDDEN_TP_AXES, + JOINED_AXES, + W_FSDP_AXES, + W_NO_SHARD_AXES, + W_JOINED_AXES, + W_TP_AXES, +) +from transformer_engine.jax.dense import dense +from transformer_engine.jax.layernorm_dense import layernorm_dense +from transformer_engine.jax.layernorm_mlp import layernorm_mlp +from transformer_engine.jax.cpp_extensions import CommOverlapHelper, CommOverlapHelperSet + +import transformer_engine_jax as tex + +# This script needs to be launched via `mpirun` with 1 process per GPU +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +myrank = MPI.COMM_WORLD.Get_rank() +numranks = MPI.COMM_WORLD.Get_size() +jax.clear_caches() +jax.distributed.initialize(cluster_detection_method="mpi4py") +assert ( + jax.local_device_count() == 1 +), f"[{myrank}|{numranks}] Expected 1 GPU per process, found {jax.local_device_count()}" + +# Parse script arguments +_supported_prims = (dense, layernorm_dense, layernorm_mlp) +_prim_map = dict((prim.__name__.lower(), prim) for prim in _supported_prims) + + +def _te_layer_prim(prim_name): + assert isinstance(prim_name, str) and prim_name.lower() in _prim_map + return _prim_map[prim_name.lower()] + + +parser = argparse.ArgumentParser() +parser.add_argument("-dp", "--dp-size", type=int, default=1) +parser.add_argument("-zp", "--fsdp-size", type=int, default=2) +parser.add_argument("-tp", "--tp-size", type=int, default=numranks // 2) +parser.add_argument("-np", "--num-gpus", type=int, default=numranks) +parser.add_argument("--batch-size", type=int, default=2) +parser.add_argument("--seq-length", type=int, default=8192) +parser.add_argument("--hidden-size", type=int, default=16384) +parser.add_argument("--activation-size", type=int, default=53248) +parser.add_argument("--no-batch", action="store_true") +parser.add_argument("--no-fsdp", action="store_true") +parser.add_argument("--layer-type", type=_te_layer_prim, default=dense, choices=_supported_prims) +parser.add_argument( + "--fp8-recipe", type=str.lower, default="none", choices=["none", "current", "delayed", "mxfp8"] +) +parser.add_argument("--check-result", action="store_true") +parser.add_argument("--seed", type=int, default=42) +args = parser.parse_args() + +# FP8 recipe +fp8_recipe = None +match args.fp8_recipe: + case "current": + fp8_recipe = recipe.Float8CurrentScaling() + case "delayed": + fp8_recipe = recipe.DelayedScaling() + case "mxfp8": + fp8_recipe = recipe.MXFP8BlockScaling() + case _: + fp8_recipe = None + +# Declare inputs +dtype = jnp.bfloat16 +input_shape = (args.seq_length, args.hidden_size) +if not args.no_batch: + input_shape = (args.batch_size,) + input_shape +features = args.hidden_size # post-attention projection +if args.layer_type is layernorm_dense: + features *= 3 # QKV projection +kernel_shape = ( + (args.hidden_size, 1, args.activation_size) # MLP FFN1 + if args.layer_type is layernorm_mlp + else (args.hidden_size, features) +) +bias_shape = (1, args.activation_size) if args.layer_type is layernorm_mlp else (features,) + +rng = jax.random.PRNGKey(args.seed) +rng, params_rng = jax.random.split(rng) +params_rng, kernel_rng = jax.random.split(params_rng) +params_rng, bias_rng = jax.random.split(params_rng) +x = jax.random.normal(rng, input_shape, dtype=jnp.bfloat16) + +gamma = beta = None +if args.layer_type in (layernorm_dense, layernorm_mlp): + params_rng, gamma_rng = jax.random.split(params_rng) + gamma = jax.random.normal(gamma_rng, (args.hidden_size,), dtype=jnp.bfloat16) + params_rng, beta_rng = jax.random.split(params_rng) + beta = jax.random.normal(beta_rng, (args.hidden_size,), dtype=jnp.bfloat16) + +kernel_1 = jax.random.normal(kernel_rng, kernel_shape, dtype=jnp.bfloat16) +bias_1 = jax.random.normal(bias_rng, bias_shape, dtype=jnp.bfloat16) + +kernel_2 = bias_2 = None +if args.layer_type is layernorm_mlp: + kernel_rng, kernel_2_rng = jax.random.split(kernel_rng) + kernel_2 = jax.random.normal( + kernel_2_rng, (args.activation_size, args.hidden_size), dtype=jnp.bfloat16 + ) + bias_rng, bias_2_rng = jax.random.split(bias_rng) + bias_2 = jax.random.normal(bias_2_rng, (args.hidden_size,), dtype=jnp.bfloat16) + +if myrank == 0: + print( + f"[{myrank}|{numranks}] {args.layer_type.__name__} inputs:\n" + + f" x: {x.shape}\n" + + f" gamma: {gamma.shape if gamma is not None else None}\n" + + f" beta: {beta.shape if beta is not None else None}\n" + + f" kernel_1: {kernel_1.shape}\n" + + f" bias_1: {bias_1.shape}\n" + + f" kernel_2: {kernel_2.shape if kernel_2 is not None else None}\n" + + f" bias_2: {bias_2.shape if bias_2 is not None else None}\n" + ) + + +# Single GPU evaluation +def _eval_layer_serial(layer_type_, x_, gamma_, beta_, kernel_1_, bias_1_, kernel_2_, bias_2_): + layer_args = [] + layer_kwargs = {} + + if layer_type_ is dense: + layer_args = (x_, kernel_1_, bias_1_) + layer_kwargs = {"contracting_dims": ((x.ndim - 1,), (0,))} + + elif layer_type_ is layernorm_dense: + layer_args = (x_, kernel_1_, gamma_, beta_, bias_1_) + + elif layer_type_ is layernorm_mlp: + layer_args = (x_, gamma_, beta_, (kernel_1_, kernel_2_), (bias_1_, bias_2_)) + + return jnp.mean(layer_type_(*layer_args, **layer_kwargs)) + + +with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + fwd_bwd_serial = jax.jit( + jax.value_and_grad(partial(_eval_layer_serial, args.layer_type), argnums=range(7)) + ) + output_serial, grads_serial = fwd_bwd_serial(x, gamma, beta, kernel_1, bias_1, kernel_2, bias_2) + +# Device mesh and logical axis resources +DEVICE_FSDP_AXIS = "zp" +DEVICE_DP_AXIS = "dp" +DEVICE_TP_AXIS = "tp" +mesh_shape = {DEVICE_TP_AXIS: args.tp_size} +if not args.no_batch: + mesh_shape[DEVICE_DP_AXIS] = args.dp_size +if not args.no_fsdp: + mesh_shape[DEVICE_FSDP_AXIS] = args.fsdp_size +devices = mesh_utils.create_device_mesh((args.num_gpus,), devices=jax.devices()[: args.num_gpus]) +mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) +mesh_resource = MeshResource( + dp_resource=None if args.no_batch else DEVICE_DP_AXIS, + fsdp_resource=None if args.no_fsdp else DEVICE_FSDP_AXIS, + tp_resource=DEVICE_TP_AXIS, +) +if myrank == 0: + print(f"[{myrank}|{numranks}] Device mesh: {mesh}\n") + +# Logical axes +INPUT_AXES = ( + SEQLEN_AXES if args.layer_type is dense else SEQLEN_TP_AXES, + HIDDEN_TP_AXES if args.layer_type is dense else HIDDEN_AXES, +) +INTERMEDIATE_AXES = (SEQLEN_AXES, HIDDEN_TP_AXES) +if not args.no_batch: + INPUT_AXES = (BATCH_AXES,) + INPUT_AXES + INTERMEDIATE_AXES = (BATCH_AXES,) + INTERMEDIATE_AXES + +LN_SCALE_AXES = LN_BIAS_AXES = (W_NO_SHARD_AXES,) + +KERNEL_AXES_ROW_PARALLEL = (W_TP_AXES, W_FSDP_AXES) +BIAS_AXES_ROW_PARALLEL = (W_FSDP_AXES,) +KERNEL_AXES_COL_PARALLEL = (W_FSDP_AXES, W_TP_AXES) +BIAS_AXES_COL_PARALLEL = (W_TP_AXES,) +if args.layer_type is layernorm_mlp: + KERNEL_AXES_COL_PARALLEL = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES) + BIAS_AXES_COL_PARALLEL = (W_JOINED_AXES, W_TP_AXES) + +KERNEL_1_AXES = KERNEL_AXES_ROW_PARALLEL if args.layer_type is dense else KERNEL_AXES_COL_PARALLEL +BIAS_1_AXES = BIAS_AXES_ROW_PARALLEL if args.layer_type is dense else BIAS_AXES_COL_PARALLEL +KERNEL_2_AXES = KERNEL_AXES_ROW_PARALLEL if args.layer_type is layernorm_mlp else None +BIAS_2_AXES = BIAS_AXES_ROW_PARALLEL if args.layer_type is layernorm_mlp else None + + +# Multi GPU evaluation +def _eval_layer_sharded( + layer_type_, + comm_overlaps_, + x_, + gamma_, + beta_, + kernel_1_, + bias_1_, + kernel_2_, + bias_2_, +): + layer_args = [] + layer_kwargs = {} + + if layer_type_ is dense: + layer_args = (x_, kernel_1_, bias_1_) + layer_kwargs = { + "input_axes": INPUT_AXES, + "kernel_axes": KERNEL_AXES_ROW_PARALLEL, + "comm_overlaps": comm_overlaps_[0], + "contracting_dims": ((x.ndim - 1,), (0,)), + } + + elif layer_type_ is layernorm_dense: + layer_args = (x_, kernel_1_, gamma_, beta_, bias_1_) + layer_kwargs = { + "layernorm_input_axes": INPUT_AXES, + "dot_input_axes": INPUT_AXES, + "kernel_axes": KERNEL_AXES_COL_PARALLEL, + "comm_overlaps": comm_overlaps_[0], + } + + elif layer_type_ is layernorm_mlp: + layer_args = (x_, gamma_, beta_, (kernel_1_, kernel_2_), (bias_1_, bias_2_)) + layer_kwargs = { + "norm_input_axes": INPUT_AXES, + "dot_1_input_axes": INPUT_AXES, + "kernel_1_axes": KERNEL_AXES_COL_PARALLEL, + "dot_2_input_axes": INTERMEDIATE_AXES, + "kernel_2_axes": KERNEL_AXES_ROW_PARALLEL, + "ffn1_comm_overlaps": comm_overlaps_[0], + "ffn2_comm_overlaps": comm_overlaps_[1], + } + + return jnp.mean(layer_type_(*layer_args, **layer_kwargs)) + + +with ( + mesh, + global_shard_guard(mesh_resource), + te.fp8_autocast( + enabled=fp8_recipe is not None, + fp8_recipe=fp8_recipe, + mesh_resource=mesh_resource, + ), +): + # Comm+GEMM overlap configs + # NOTE: Need to set `tp_resource=` kwarg when *not* initializing under a `global_shard_guard()`. + # Also need `logical_tp_axis=` and `logical_sp_axis=` kwargs if they differ from TE's + # built-in logical axis names. + buffer_shape = list(input_shape).copy() + if not args.no_batch: + buffer_shape[0] = buffer_shape[0] // (args.dp_size * args.fsdp_size) + fprop_1_overlap = CommOverlapHelper( + comm_type=tex.CommOverlapType.RS if args.layer_type is dense else tex.CommOverlapType.AG, + method=tex.CommOverlapMethod.RING_EXCHANGE, + buffer_shape=buffer_shape, + ) + comm_overlaps = [CommOverlapHelperSet(fprop=fprop_1_overlap)] + if args.layer_type is layernorm_mlp: + fprop_2_overlap = CommOverlapHelper( + comm_type=tex.CommOverlapType.RS, + method=tex.CommOverlapMethod.RING_EXCHANGE, + buffer_shape=buffer_shape, + ) + comm_overlaps.append(CommOverlapHelperSet(fprop=fprop_2_overlap)) + + x_sharding = NamedSharding(mesh, generate_pspec(INPUT_AXES)) + x = jax.device_put(x, x_sharding) + + gamma_sharding = beta_sharding = None + if gamma is not None: + gamma_sharding = NamedSharding(mesh, generate_pspec(LN_SCALE_AXES)) + gamma = jax.device_put(gamma, gamma_sharding) + if beta is not None: + beta_sharding = NamedSharding(mesh, generate_pspec(LN_BIAS_AXES)) + beta = jax.device_put(beta, beta_sharding) + + kernel_1_sharding = NamedSharding(mesh, generate_pspec(KERNEL_1_AXES)) + bias_1_sharding = NamedSharding(mesh, generate_pspec(BIAS_1_AXES)) + + kernel_2_sharding = bias_2_sharding = None + if kernel_2 is not None: + kernel_2_sharding = NamedSharding(mesh, generate_pspec(KERNEL_2_AXES)) + kernel_2 = jax.device_put(kernel_2, kernel_2_sharding) + if bias_2 is not None: + bias_2_sharding = NamedSharding(mesh, generate_pspec(BIAS_2_AXES)) + bias_2 = jax.device_put(bias_2, bias_2_sharding) + + input_shardings = ( + x_sharding, + gamma_sharding, + beta_sharding, + kernel_1_sharding, + bias_1_sharding, + kernel_2_sharding, + bias_2_sharding, + ) + output_shardings = ( + NamedSharding(mesh, PartitionSpec()), + input_shardings, + ) + value_and_grad_sharded = jax.jit( + jax.value_and_grad( + partial(_eval_layer_sharded, args.layer_type, comm_overlaps), argnums=range(7) + ), + in_shardings=input_shardings, + out_shardings=output_shardings, + ) + + output_sharded, grads_sharded = value_and_grad_sharded( + x, gamma, beta, kernel_1, bias_1, kernel_2, bias_2 + ) + +if args.check_result: + diff = jnp.abs(output_serial - output_sharded) + if myrank == 0: + print( + f"[{myrank}|{numranks}] Output: serial = {output_serial} | sharded = {output_sharded}" + ) + rel_err = diff / max(abs(diff), 1e-5) + if rel_err > 0.02 and diff > 0.001: + if myrank == 0: + print("NUMERICAL CHECK_FAILED: Output not close enough!\n") + else: + if myrank == 0: + print("NUMERICAL CHECK PASSED\n") + + labels = ("dX", "dGamma", "dBeta", "dKernel_1", "dBias_1", "dKernel_2", "dBias_2") + for i, (serial, sharded) in enumerate(zip(grads_serial, grads_sharded)): + if serial is not None and sharded is not None: + if myrank == 0: + print( + f"[{myrank}|{numranks}] {labels[i]} : {sharded.shape}\n" + + f" Sharding: {sharded.sharding.spec}\n" + ) + gathered = jax.lax.with_sharding_constraint( + sharded, NamedSharding(mesh, PartitionSpec(None)) + ) + jax.block_until_ready(gathered) + diff = jnp.abs(serial - gathered).flatten() + if myrank == 0: + print(f"{myrank}: Global {labels[i]} difference: {diff}\n", flush=True) + + m = jnp.argmax(diff).item() + abs_err = diff[m].item() + rel_err = abs_err / max(abs(output_serial.flatten()[m]), 1e-5) + + rtol = 0.02 + atol = 0.001 + if rel_err > rtol and abs_err > atol: + numerics_info = ( + "NUMERICAL CHECK FAILED: " + + f"{labels[i]} not close enough at index {m} " + + f"with {gathered.flatten()[m].item()} vs {serial.flatten()[m].item()} " + + f"| rel. error = {rel_err} (tol = {rtol}) " + + f"| abs. error = {abs_err} (tol = {atol})" + ) + else: + numerics_info = "NUMERICAL CHECK PASSED: " + if rel_err <= rtol: + numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + ( + " | " if abs_err < atol else "" + ) + if abs_err <= atol: + numerics_info += f"abs. error = {abs_err} (tol = {atol})" + + if myrank == 0: + print(numerics_info + "\n", end="", flush=True) + +tex.destroy_all_comm_overlap_buffers() diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 1f45d10faf..00203a4537 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -25,6 +25,7 @@ assert_params_sufficiently_sharded, ) import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_fp8_available, ScalingMode @@ -465,8 +466,8 @@ class TestEncoder(unittest.TestCase): is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) def setUp(self): - """Run 3 epochs for testing""" - self.args = encoder_parser(["--epochs", "3"]) + """Run 5 epochs for testing""" + self.args = encoder_parser(["--epochs", "5"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): @@ -516,6 +517,9 @@ def test_te_mxfp8_with_sp(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True @@ -523,6 +527,9 @@ def test_te_bf16_shardy(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_fp8_supported, fp8_reason) + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" self.args.enable_shardy = True @@ -532,6 +539,9 @@ def test_te_delayed_scaling_fp8_shardy(self): assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_fp8_supported, fp8_reason) + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_delayed_scaling_fp8_with_sp_shardy(self): """Test Transformer Engine with DelayedScaling FP8 + SP""" self.args.enable_shardy = True diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 12148b0e29..44cafa7396 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -21,6 +21,7 @@ from common import is_bf16_supported, get_fp8_recipe_from_name_string import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_fp8_available, ScalingMode @@ -430,14 +431,14 @@ class TestEncoder(unittest.TestCase): is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) def setUp(self): - """Run 3 epochs for testing""" - self.args = encoder_parser(["--epochs", "3"]) + """Run 5 epochs for testing""" + self.args = encoder_parser(["--epochs", "5"]) @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -445,7 +446,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8(self): @@ -453,7 +454,7 @@ def test_te_current_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -461,34 +462,43 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 @unittest.skipIf(not is_fp8_supported, fp8_reason) + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" self.args.enable_shardy = True self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. @unittest.skipIf(not is_fp8_supported, fp8_reason) + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" self.args.enable_shardy = True self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.535 and actual[1] > 0.73 + assert actual[0] < 0.536 and actual[1] > 0.73 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 580824cefa..b5d03c0796 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -28,8 +28,8 @@ get_fp8_recipe_from_name_string, ) import transformer_engine.jax as te +import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -584,8 +584,8 @@ class TestEncoder(unittest.TestCase): """Encoder unittests""" def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): - """Run 3 epochs for testing""" - args = encoder_parser([]) + """Run 5 epochs for testing""" + args = encoder_parser(["--epochs", "5"]) num_gpu = self.num_process tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1 @@ -634,6 +634,9 @@ def test_te_mxfp8(self): assert result[0] < 0.505 and result[1] > 0.754 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" result = self.exec(False, None, enable_shardy=True) @@ -642,6 +645,9 @@ def test_te_bf16_shardy(self): @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" ) + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_delayed_scaling_fp8_shardy(self): """Test Transformer Engine with DelayedScaling FP8""" result = self.exec(True, "DelayedScaling", enable_shardy=True) @@ -652,6 +658,9 @@ def test_te_delayed_scaling_fp8_shardy(self): @unittest.skipIf( not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" ) + @unittest.skipIf( + not tex.gemm_uses_jax_dot(), "TE cuBLAS GEMM custom op does not support shardy" + ) def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling", enable_shardy=True) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 349916cafe..a50d5363ae 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -30,7 +30,6 @@ from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax.quantize import ( - DelayedScaleQuantizer, ScaledTensor, ScaledTensor1x, ScaledTensor2x, @@ -851,6 +850,22 @@ def test_quantize_dact_dbias_mxfp8_scaling( ) +valid_fp8_gemm_operand_types = [ + (jnp.float8_e4m3fn, jnp.float8_e4m3fn), + (jnp.float8_e5m2, jnp.float8_e4m3fn), + (jnp.float8_e4m3fn, jnp.float8_e5m2), +] + + +def _use_jax_fp8_gemm(enabled=False): + import os + + if enabled: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") + + class TestDense: def _ref_gemm_with_jnp_dot(self, a, b, data_layout): if data_layout[0] == "T": @@ -883,27 +898,46 @@ def _generate_gemm_input(self, m, n, k, data_layout): def test_gemm_bf16(self, m, n, k, data_layout): x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) - primitive_out = tex.gemm(x, w, contracting_dims) + primitive_out = tex.gemm(x, w, dimension_numbers=(contracting_dims, ((), ()))) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) + @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) - def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm): + if ( + not with_jax_gemm + and scaling_mode.is_1d_block_scaling() + and jnp.float8_e5m2 in (x_qtype, w_qtype) + ): + pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.") + + _use_jax_fp8_gemm(enabled=with_jax_gemm) + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False + scaling_mode=scaling_mode, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, + is_2x2x=False, ) primitive_out = tex.gemm( - x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set + x, + w, + dimension_numbers=(contracting_dims, ((), ())), + lhs_quantizer=quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad, + rhs_quantizer=( + quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad + ), ) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) - assert_allclose(primitive_out, ref_out, dtype=q_dtype) + assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) def test_dense_grad_bf16(self, m, n, k): @@ -932,9 +966,11 @@ def ref_func(x, w, data_layout): @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) - def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm): + _use_jax_fp8_gemm(enabled=with_jax_gemm) + data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -956,7 +992,10 @@ def ref_func(x, w, bias, data_layout): value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True + scaling_mode=scaling_mode, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, + is_2x2x=True, ) n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 @@ -969,10 +1008,10 @@ def ref_func(x, w, bias, data_layout): x, w, bias, data_layout ) - assert_allclose(primitive_out, ref_out, dtype=q_dtype) - assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype) - assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype) - assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype) + assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn) + assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) + assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2) + assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2) @pytest.fixture(name="random_inputs") @@ -996,19 +1035,14 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan class TestFusedDense: @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) - @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) - def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) + def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm): """ Test layernorm_dense VJP Rule """ - # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode in ( - ScalingMode.DELAYED_TENSOR_SCALING, - ScalingMode.CURRENT_TENSOR_SCALING, - ): - pytest.skip("E5M2 is not supported in normalization with TE Backend!") + _use_jax_fp8_gemm(enabled=with_jax_gemm) # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False @@ -1025,8 +1059,8 @@ def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): quantizer_set = QuantizerFactory.create_set( scaling_mode=scaling_mode, - fwd_dtype=q_dtype, - bwd_dtype=q_dtype, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, is_2x2x=True, ) @@ -1072,32 +1106,27 @@ def ref_func(x, w, gamma, beta): prim_beta_grad, ) = value_n_grad_prim_func(x, w, gamma, beta) - assert_allclose(prim_out, ref_out, dtype=q_dtype) - assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) - assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype) - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype) + assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) + assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2) if beta is not None: - assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype) + assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) - @pytest.mark.parametrize("use_bias", [True, False]) + @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( - self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias + self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm ): """ Test layernorm_mlp VJP Rule """ - # No Norm FWD E5M2 in TE backend - if q_dtype == jnp.float8_e5m2 and scaling_mode in ( - ScalingMode.DELAYED_TENSOR_SCALING, - ScalingMode.CURRENT_TENSOR_SCALING, - ): - pytest.skip("E5M2 is not supported in normalization with TE Backend!") + _use_jax_fp8_gemm(enabled=with_jax_gemm) # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False @@ -1123,8 +1152,8 @@ def test_layernorm_mlp_grad( quantizer_sets = QuantizerFactory.create_set( n_quantizer_sets=2, scaling_mode=scaling_mode, - fwd_dtype=q_dtype, - bwd_dtype=q_dtype, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, is_2x2x=True, ) @@ -1153,14 +1182,13 @@ def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ln_out = _ref_jax_norm_impl( x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None ) - # TODO: replace gemm with jnp.dot - linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,))) + linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ()))) if use_bias: bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape linear_1_out += jnp.reshape(bias_1, bias_1_shape) x = _jax_act_lu(linear_1_out, activation_type) - linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,))) + linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ()))) if use_bias: bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape linear_2_out += jnp.reshape(bias_2, bias_2_shape) @@ -1193,18 +1221,18 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ref_bias_2_grad, ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) - assert_allclose(prim_out, ref_out, dtype=q_dtype) + assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) - assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype) + assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2) if use_bias: - assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype) + assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2) - assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype) + assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2) if use_bias: - assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype) + assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2) - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype) - assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) # E5M2 * E5M2 is not supported diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index f16c84094d..a093ff5d91 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -44,6 +44,7 @@ SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling")) if is_mxfp8_supported: SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) +SUPPORTED_RECIPES_WITH_SHARDY = SUPPORTED_RECIPES[:-1] if is_mxfp8_supported else SUPPORTED_RECIPES DTYPES = [jnp.bfloat16, jnp.float16] INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in] @@ -74,6 +75,15 @@ def generate_fsdp_and_tp_configs(): return configs +def use_jax_fp8_gemm(enabled=False): + import os + + if enabled: + os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: + os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") + + class TestDistributedLayernormMLP: def generate_inputs(self, input_shape, activation_type, use_bias, dtype): @@ -146,8 +156,17 @@ def layernorm_fp8_mlp_prim_func( ) def _test_layernorm_mlp_grad( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + use_shardy, + with_jax_gemm, ): + use_jax_fp8_gemm(enabled=with_jax_gemm) jax.config.update("jax_use_shardy_partitioner", use_shardy) device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config layernorm_type = "rmsnorm" @@ -208,20 +227,25 @@ def _test_layernorm_mlp_grad( multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) - assert_allclose(multi_fwd, single_fwd, dtype=dtype) + fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn + bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2 + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) for i in range(len(inputs)): if multi_grads[i] is not None: if isinstance(multi_grads[i], list): assert isinstance(single_grads[i], list) for m_grad, s_grad in zip(multi_grads[i], single_grads[i]): assert_allclose( - m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close" + m_grad, + s_grad, + dtype=bwd_test_type, + err_msg=f"multi_grads[{i}] is not close", ) else: assert_allclose( multi_grads[i], single_grads[i], - dtype=dtype, + dtype=bwd_test_type, err_msg=f"multi_grads[{i}] is not close", ) @@ -232,8 +256,16 @@ def _test_layernorm_mlp_grad( @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + fp8_recipe, + with_jax_gemm, ): self._test_layernorm_mlp_grad( mesh_config, @@ -243,6 +275,7 @@ def test_layernorm_mlp_grad( dtype, fp8_recipe, use_shardy=False, + with_jax_gemm=with_jax_gemm, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -251,19 +284,22 @@ def test_layernorm_mlp_grad( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES_WITH_SHARDY) def test_layernorm_mlp_grad_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe ): - # We don't test block scaling with Shardy because at the time of writing, - # it is not supported in JAX's scaled_matmul_stablehlo. + # TE cuBLAS GEMM custom op does not implement shardy rules so we test shardy only with + # native JAX FP8 dot_general. We don't test block scaling with Shardy because at the + # time of writing, it is not supported in JAX's scaled_matmul_stablehlo. self._test_layernorm_mlp_grad( mesh_config, activation_type, use_bias, input_shape, dtype, - fp8_recipe=recipe.DelayedScaling(), + fp8_recipe=fp8_recipe, use_shardy=True, + with_jax_gemm=True, ) def _test_layernorm_mlp( @@ -276,7 +312,9 @@ def _test_layernorm_mlp( use_fp8, fp8_recipe, use_shardy, + with_jax_gemm, ): + use_jax_fp8_gemm(enabled=with_jax_gemm) jax.config.update("jax_use_shardy_partitioner", use_shardy) batch, seqlen, hidden_in = input_shape layernorm_type = "rmsnorm" @@ -340,9 +378,9 @@ def _test_layernorm_mlp( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("use_shardy", [False, True]) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer( - self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy + self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm ): self._test_layernorm_mlp( mesh_config, @@ -352,7 +390,8 @@ def test_layernorm_mlp_layer( dtype, use_fp8=False, fp8_recipe=None, - use_shardy=use_shardy, + use_shardy=False, + with_jax_gemm=with_jax_gemm, ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -361,9 +400,12 @@ def test_layernorm_mlp_layer( @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper( + "fp8_recipe", SUPPORTED_RECIPES[:-1] if is_mxfp8_supported else SUPPORTED_RECIPES + ) + @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm ): self._test_layernorm_mlp( mesh_config, @@ -374,4 +416,52 @@ def test_layernorm_mlp_layer_fp8( use_fp8=True, fp8_recipe=fp8_recipe, use_shardy=False, + with_jax_gemm=with_jax_gemm, + ) + + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("use_bias", [True, False]) + def test_layernorm_mlp_layer_shardy( + self, mesh_config, activation_type, use_bias, input_shape, dtype + ): + # TE cuBLAS GEMM custom op does not implement shardy rules so we test shardy only with + # native JAX dot_general. + self._test_layernorm_mlp( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8=False, + fp8_recipe=None, + use_shardy=True, + with_jax_gemm=True, + ) + + @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) + @pytest_parametrize_wrapper("use_bias", [True, False]) + @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES_WITH_SHARDY) + def test_layernorm_mlp_layer_fp8_shardy( + self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe + ): + # TE cuBLAS GEMM custom op does not implement shardy rules so we test shardy only with + # native JAX FP8 dot_general. We don't test block scaling with Shardy because at the + # time of writing, it is not supported in JAX's scaled_matmul_stablehlo. + self._test_layernorm_mlp( + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + use_fp8=True, + fp8_recipe=fp8_recipe, + use_shardy=True, + with_jax_gemm=True, ) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 40595ea988..3e3f9c6be9 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -53,17 +53,38 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl bool atomic_gemm) { // Initialize userbuf communicator if (!_comm_created) { - if (myrank == 0) { - printf("!!! [UB] Create Userbuffers Communicator\n"); - } -#ifdef NVTE_UB_WITH_MPI - create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); -#else create_communicator_grouped2(&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, allgather_handle, barrier_handle, 1, 1, tp_size, 1); -#endif _comm_created = true; + if (_ub_comm->myrank == 0) { + printf("!!! [UB] Initialized Userbuffers Communicator\n"); + } + } + + initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, + num_comm_sm, set_sm_margin, use_ce, atomic_gemm); +} + +CommOverlapCore::CommOverlapCore(int tp_size, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { + if (!_comm_created) { + create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); + _comm_created = true; + if (_ub_comm->myrank == 0) { + printf("!!! [UB] Initialized Userbuffers Communicator (w/ MPI Boostrapping)\n"); + } } + + initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, + num_comm_sm, set_sm_margin, use_ce, atomic_gemm); +} + +void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { _use_ce = static_cast(use_ce); _num_comm_sm = num_comm_sm; _cga_size = comm_cga_size; @@ -262,6 +283,21 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm); +} + +CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, + int tp_size, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool rs_overlap_first_gemm) + : CommOverlapCore(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, + comm_priority, num_comm_sm, set_sm_margin, false, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm); +} + +void CommOverlapBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm) { _rs_overlap_first_gemm = rs_overlap_first_gemm; _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, @@ -285,6 +321,39 @@ CommOverlapBase::~CommOverlapBase() { cudaStreamDestroy(_stream_comm); } +void CommOverlapBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, + bool local_chunk, bool rowwise) { + // Check element size + const size_t element_size = source.element_size(); + NVTE_CHECK(_ubuf.element_size() == element_size, + "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", + "(source dtype has ", element_size, "bytes, UB dtype has ", _ubuf.element_size(), + "bytes)"); + + // Input data + const size_t source_size = source.numel(); + const void *src_ptr = (rowwise) ? source.dptr() : source.columnwise_dptr(); + + // Userbuffers data + const size_t ubuf_size = _ubuf.numel(); + void *dst_ptr = _ubuf.dptr(); + if (local_chunk) { + NVTE_CHECK(source_size * _tp_size == ubuf_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(source_size=", source_size, ", tensor_parallel_size=", _tp_size, + ", ubuf_size=", ubuf_size, ")"); + dst_ptr = (reinterpret_cast(dst_ptr) + (ubuf_size / _tp_size) * _tp_id * element_size); + } else { + NVTE_CHECK(source_size == ubuf_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(source_size=", source_size, ", ubuf_size=", ubuf_size, ")"); + } + + // Copy data + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size, + cudaMemcpyDeviceToDevice, stream)); +} + /* ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf @@ -600,6 +669,21 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, comm_type, aggregate); +} + +CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, + int tp_size, CommOverlapType comm_type, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm, bool aggregate) + : CommOverlapCore(tp_size, tp_size, num_max_streams, comm_cga_size, gemm_priority, + comm_priority, num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, comm_type, aggregate); +} + +void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate) { _is_p2p = true; _is_reduce_scatter = comm_type == CommOverlapType::RS; _aggregate = aggregate; @@ -607,13 +691,13 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, // Create workspace tensor with userbuffer NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); - int buffer_chunk_bytes = buffer_bytes / tp_size; - _num_ubuf_chunks = tp_size; + int buffer_chunk_bytes = buffer_bytes / _tp_size; + _num_ubuf_chunks = _tp_size; if (_is_reduce_scatter) { // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk // outputs for reduction at the end of the pipelining. - buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1); - _num_ubuf_chunks = tp_size * 2 - 1; + buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1); + _num_ubuf_chunks = _tp_size * 2 - 1; } void *buffer_ptr; @@ -621,14 +705,14 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); _ubuf = TensorWrapper( buffer_ptr, - std::vector{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]}, buffer_dtype); // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); for (int i = 0; i < _num_ubuf_chunks; i++) { _ubufs.push_back(TensorWrapper(reinterpret_cast(ubuf_byte_ptr), - std::vector{buffer_shape[0] / tp_size, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size, buffer_shape[1]}, buffer_dtype)); ubuf_byte_ptr += buffer_chunk_bytes; } @@ -651,7 +735,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); } - for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) { + for (int i = 0; i < _stream_compute.size(); i++) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); _stream_send.push_back(std::move(stream)); @@ -669,6 +753,38 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]); } +void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, + bool local_chunk, bool rowwise) { + // Check element size + const size_t element_size = source.element_size(); + NVTE_CHECK(_ubuf.element_size() == element_size, + "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", + "(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), + " bytes)"); + + // Input data + const size_t source_size = source.numel(); + const void *src_ptr = (rowwise) ? source.dptr() : source.columnwise_dptr(); + + // Userbuffers data + void *dst_ptr; + if (local_chunk) { + NVTE_CHECK(_ubufs[_tp_id].numel() == source_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + dst_ptr = _ubufs[_tp_id].dptr(); + } else { + NVTE_CHECK(_ubuf.numel() == source_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")"); + dst_ptr = _ubuf.dptr(); + } + + // Copy data + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size, + cudaMemcpyDeviceToDevice, stream)); +} + TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, size_t chunk_id) { // Start with a chunk of the source tensor @@ -809,6 +925,15 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, const bool do_gelu = pre_gelu_out.numel() > 0; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + // Check B copy sizing + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ", + _ubuf.numel(), " elements but got ", B_copy.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(), + "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8, + "-bit data type but got ", B_copy.element_size() * 8, "-bit"); + } + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); @@ -877,12 +1002,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } else { @@ -930,16 +1049,16 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } + // Copy all-gathered B from communication buffer into auxiliary output + if (B_copy.numel() > 0) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), + cudaMemcpyDeviceToDevice, _stream_send[0])); + } + _ub_comm->sms = ori_sms; for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 293c57526d..77560a9482 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -26,17 +26,20 @@ namespace transformer_engine { */ bool ubuf_built_with_mpi(); -enum class CommOverlapType { RS = 0, AG = 1 }; - -enum class CommOverlapAlgo { - BULK_OVERLAP_AG = 0, - BULK_OVERLAP_RS = 1, - SPLIT_PIPELINED_AG_P2P = 2, - SPLIT_PIPELINED_RS = 3, - SPLIT_PIPELINED_RS_P2P = 4, - ATOMIC_GEMM_RS = 5, - ATOMIC_GEMM_AG_P2P = 6, - ATOMIC_GEMM_RS_P2P = 7 +enum class CommOverlapType : int64_t { NONE = 0, RS = 1, AG = 2 }; + +enum class CommOverlapMethod : int64_t { NONE = 0, BULK = 1, PIPELINE = 2, RING_EXCHANGE = 3 }; + +enum class CommOverlapAlgo : int64_t { + NO_OVERLAP = 0, + BULK_OVERLAP_AG = 1, + BULK_OVERLAP_RS = 2, + SPLIT_PIPELINED_AG_P2P = 3, + SPLIT_PIPELINED_RS = 4, + SPLIT_PIPELINED_RS_P2P = 5, + ATOMIC_GEMM_RS = 6, + ATOMIC_GEMM_AG_P2P = 7, + ATOMIC_GEMM_RS_P2P = 8 }; class CommOverlapCore { @@ -66,28 +69,48 @@ class CommOverlapCore { std::vector _stream_compute; cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; + private: + void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, + int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm); + public: CommOverlapCore() {} // dummy constructor for exposing type to Python + // External/framework collectives-based constructor CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, bool atomic_gemm); + // MPI-based constructor + CommOverlapCore(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, + int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm); + virtual ~CommOverlapCore(); + void *get_ubuf_dptr() { return _ubuf.dptr(); } + void set_ubuf_scale_inv(float *scale_inv) { _ubuf_scale_inv = scale_inv; _ubuf_scale_inv_initialized = true; } + virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) { + NVTE_ERROR("Operation is not implemented."); + } + TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, const std::vector &shape); TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, const std::vector &shape); + int get_tp_size() { return _tp_size; } + bool is_atomic_gemm() { return _atomic_gemm; } bool is_p2p_overlap() { return _is_p2p; } @@ -142,9 +165,14 @@ class CommOverlapBase : public CommOverlapCore { cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm); + public: CommOverlapBase() {} // dummy constructor for exposing type to Python + // External/framework collective-based constructor CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, @@ -153,8 +181,18 @@ class CommOverlapBase : public CommOverlapCore { bool set_sm_margin = true, bool atomic_gemm = false, bool rs_overlap_first_gemm = false); + // MPI-based constructor + CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int tp_size, + int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, + int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, + int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false, + bool rs_overlap_first_gemm = false); + virtual ~CommOverlapBase(); + void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) override; + /* ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf @@ -215,9 +253,14 @@ class CommOverlapP2PBase : public CommOverlapCore { cudaStream_t _stream_recv; cudaEvent_t _stop_send, _stop_recv; + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate); + public: CommOverlapP2PBase() {} // dummy constructor for exposing type to Python + // External/framework collective-based constructor CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, @@ -226,8 +269,18 @@ class CommOverlapP2PBase : public CommOverlapCore { int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); + // MPI-based constructor + CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int tp_size, + CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, + int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0, + int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, + bool atomic_gemm = false, bool aggregate = false); + virtual ~CommOverlapP2PBase(); + void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) override; + TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index a1cd85ba2a..beb96545a7 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -8,8 +8,11 @@ #define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ #include +#include #include #include +#include +#include #include #include "cuda_runtime.h" @@ -17,12 +20,25 @@ #define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ pybind11::enum_(m, "DType", pybind11::module_local()) \ .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt64", transformer_engine::DType::kInt64) \ .value("kInt32", transformer_engine::DType::kInt32) \ .value("kFloat32", transformer_engine::DType::kFloat32) \ .value("kFloat16", transformer_engine::DType::kFloat16) \ .value("kBFloat16", transformer_engine::DType::kBFloat16) \ .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ + .value("kFloat8E8M0", transformer_engine::DType::kFloat8E8M0); \ + pybind11::enum_(m, "NVTE_Activation_Type", pybind11::module_local()) \ + .value("GELU", NVTE_Activation_Type::GELU) \ + .value("GEGLU", NVTE_Activation_Type::GEGLU) \ + .value("SILU", NVTE_Activation_Type::SILU) \ + .value("SWIGLU", NVTE_Activation_Type::SWIGLU) \ + .value("RELU", NVTE_Activation_Type::RELU) \ + .value("REGLU", NVTE_Activation_Type::REGLU) \ + .value("QGELU", NVTE_Activation_Type::QGELU) \ + .value("QGEGLU", NVTE_Activation_Type::QGEGLU) \ + .value("SRELU", NVTE_Activation_Type::SRELU) \ + .value("SREGLU", NVTE_Activation_Type::SREGLU); \ pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ @@ -75,16 +91,27 @@ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_(m, "NVTE_Norm_Type", pybind11::module_local()) \ + .value("LayerNorm", NVTE_Norm_Type::LayerNorm) \ + .value("RMSNorm", NVTE_Norm_Type::RMSNorm); \ pybind11::enum_( \ m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \ .value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \ .value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \ pybind11::enum_(m, "CommOverlapType", \ pybind11::module_local()) \ + .value("NONE", transformer_engine::CommOverlapType::NONE) \ .value("RS", transformer_engine::CommOverlapType::RS) \ .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapMethod", \ + pybind11::module_local()) \ + .value("NONE", transformer_engine::CommOverlapMethod::NONE) \ + .value("BULK", transformer_engine::CommOverlapMethod::BULK) \ + .value("PIPELINE", transformer_engine::CommOverlapMethod::PIPELINE) \ + .value("RING_EXCHANGE", transformer_engine::CommOverlapMethod::RING_EXCHANGE); \ pybind11::enum_(m, "CommOverlapAlgo", \ pybind11::module_local()) \ + .value("NO_OVERLAP", transformer_engine::CommOverlapAlgo::NO_OVERLAP) \ .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ .value("SPLIT_PIPELINED_AG_P2P", \ @@ -95,30 +122,31 @@ .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ - py::class_>(m, "CommOverlapCore", \ - pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \ - py::call_guard()) \ + pybind11::class_>(m, "CommOverlapCore", \ + pybind11::module_local()) \ + .def(pybind11::init([]() { return new transformer_engine::CommOverlapCore(); }), \ + pybind11::call_guard()) \ .def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \ - py::call_guard()) \ + pybind11::call_guard()) \ .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ - py::call_guard()) \ + pybind11::call_guard()) \ .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ - py::call_guard()); \ - py::class_, \ - transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ - py::call_guard()); \ - py::class_, \ - transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \ - pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ - py::call_guard()); \ + pybind11::call_guard()); \ + pybind11::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapBase", \ + pybind11::module_local()) \ + .def(pybind11::init([]() { return new transformer_engine::CommOverlapBase(); }), \ + pybind11::call_guard()); \ + pybind11::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \ + pybind11::module_local()) \ + .def(pybind11::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ + pybind11::call_guard()); \ m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ - py::call_guard(), py::arg("device_id") = -1); \ + pybind11::call_guard(), pybind11::arg("device_id") = -1); \ m.def( \ "get_stream_priority_range", \ [](int device_id = -1) { \ @@ -126,8 +154,14 @@ transformer_engine::cuda::stream_priority_range(&low_pri, &high_pri, device_id); \ return std::make_pair(low_pri, high_pri); \ }, \ - py::call_guard(), py::arg("device_id") = -1); \ + pybind11::call_guard(), pybind11::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ - py::call_guard()); + pybind11::call_guard()); \ + m.def("get_qkv_format", &nvte_get_qkv_format, \ + pybind11::call_guard()); \ + m.def("get_num_compute_streams", &nvte_get_num_compute_streams, \ + pybind11::call_guard()); \ + m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported, \ + pybind11::call_guard()); #endif diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index fe4109cee8..4d1e8316c3 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -17,7 +17,7 @@ from transformer_engine_jax import NVTE_Mask_Type from transformer_engine_jax import NVTE_QKV_Layout from transformer_engine_jax import NVTE_QKV_Format -from transformer_engine_jax import nvte_get_qkv_format +from transformer_engine_jax import get_qkv_format from . import cpp_extensions as tex @@ -109,7 +109,7 @@ def get_qkv_format(self): """ Return the corresponding qkv_format (BSHD, SBHD, THD) """ - return QKVFormat(nvte_get_qkv_format(self.value)) + return QKVFormat(get_qkv_format(self.value)) def is_qkvpacked(self): """ diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index ce66bba3cf..341dcb0c8c 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -985,6 +985,7 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + noop_scaled_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -993,6 +994,7 @@ def act_lu( Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: If quantizer is None: @@ -1037,6 +1039,10 @@ def act_lu( is_outer=True, ) out = out.reshape(output_shape) + if noop_scaled_tensor: + return ScaledTensorFactory.create_2x( + out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype + ) return out if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -1090,6 +1096,7 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1100,6 +1107,7 @@ def quantize_dact_dbias( activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). is_dbias: If True, compute bias gradient. Defaults to True. quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: Tuple[ScaledTensor, jnp.ndarray]: A tuple containing: @@ -1113,13 +1121,49 @@ def quantize_dact_dbias( f" {x.shape} and act_len {act_len}" ) + scale = jnp.empty((), jnp.float32) + act_type_id = ActivationEnum[activation_type] PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive - if not PrimitiveClass.enabled(): + if not PrimitiveClass.enabled() or ( + quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE + ): return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) - # TE/common does not support colwise-only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: - return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) + if quantizer is None: + output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( + dz, + x, + scale, + # outputs float32 for dbias accumulation + out_dtype=(jnp.float32 if is_dbias else x.dtype), + # default value for no scaling, TE/common ignore this value when scale is unset + scaling_mode=ScalingMode.NO_SCALING.value, + is_2x=False, # unused + scale_dtype=jnp.float32, # unused + is_dbias=False, + act_enum=act_type_id, + act_len=act_len, + is_outer=True, + ) + output = output.astype(x.dtype) + dbias = None + if is_dbias: + dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) + + if noop_scaled_tensor: + return ( + ScaledTensorFactory.create_2x( + output, + None, + output, + None, + ScalingMode.NO_SCALING, + dq_dtype=output.dtype, + ), + dbias, + ) + + return output, dbias # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): @@ -1145,31 +1189,6 @@ def quantize_dact_dbias( if war_output is not None: return war_output - scale = jnp.empty((), jnp.float32) - - act_type_id = ActivationEnum[activation_type] - - if quantizer is None: - output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( - dz, - x, - scale, - # outputs float32 for dbias accumulation - out_dtype=(jnp.float32 if is_dbias else x.dtype), - # default value for no scaling, TE/common ignore this value when scale is unset - scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, # unused - scale_dtype=jnp.float32, # unused - is_dbias=False, - act_enum=act_type_id, - act_len=act_len, - is_outer=True, - ) - dbias = None - if is_dbias: - dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) - return output.astype(x.dtype), dbias - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = dact_lu( @@ -1183,7 +1202,7 @@ def quantize_dact_dbias( ) return out, dbias - if isinstance(quantizer, DelayedScaleQuantizer): + if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale # TE/common dact_dbias_quantize does not support gated act yet @@ -1243,6 +1262,7 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + noop_scale_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1252,6 +1272,7 @@ def dact_lu( x: Input tensor that was used in forward pass. activation_type: Type of activation function that was applied. quantizer: Optional quantizer for FP8 quantization of the output gradient. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: The gradient of the activation with respect to the input. @@ -1262,5 +1283,6 @@ def dact_lu( activation_type=activation_type, is_dbias=False, quantizer=quantizer, + noop_scaled_tensor=noop_scale_tensor, ) return output diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index a6c58edb4a..8fe5fb2d66 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -3,19 +3,25 @@ # See LICENSE for license information. """JAX te modules""" -from typing import Tuple, Sequence, Union, Dict -from functools import partial, reduce -import operator import math +import operator +from collections.abc import Iterable +from dataclasses import dataclass, field +from functools import partial, reduce +from typing import Tuple, Sequence, Union + import jax import jax.numpy as jnp -from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams +from jax import dtypes +from jax.sharding import NamedSharding, PartitionSpec + +import transformer_engine_jax as tex from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize - from ..quantize import ( ScaledTensor, + ScaledTensor2x, GroupedScaledTensor1x, ScalingMode, Quantizer, @@ -25,22 +31,1584 @@ QuantizeLayout, noop_quantizer_set, is_fp8_gemm_with_all_layouts_supported, + apply_padding_to_scale_inv, + remove_padding_from_scale_inv, +) +from .misc import get_padded_spec, jax_dtype_to_te_dtype +from ..sharding import ( + global_mesh_resource, + get_mesh_axis_size, + generate_pspec, + W_TP_AXES, + SEQLEN_TP_AXES, ) -__all__ = ["gemm", "grouped_gemm"] +__all__ = [ + "CommOverlapHelper", + "CommOverlapHelperSet", + "gemm", + "grouped_gemm", + "gemm_uses_jax_dot", + "sanitize_dims", + "get_non_contracting_dims", + "transpose_dims", +] + +num_cublas_streams = tex.get_num_compute_streams() -num_cublas_streams = get_num_compute_streams() +CUDA_STREAM_PRIORITY_LOWEST = None +CUDA_STREAM_PRIORITY_HIGHEST = None def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if get_device_compute_capability(0) >= 90: + if tex.get_device_compute_capability(0) >= 90: return 33_554_432 return 4_194_304 +def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]: + """Convert relative (negative) indexes to absolute dimension numbers.""" + dims_ = dims if isinstance(dims, Iterable) else (dims,) + if len(dims_) == 0: + return dims_ + return tuple(ndim + dim if dim < 0 else dim for dim in dims_ if dim is not None) + + +def get_non_contracting_dims(ndim, contracting_dims): + """Return a tuple of dimensions not included in the contracting dimensions.""" + contracting_dims = sanitize_dims(ndim, contracting_dims) + return tuple(dim for dim in range(ndim) if dim not in contracting_dims) + + +def transpose_dims(ndim, dims_to_transpose, flatten_axis=-1): + """Compute the new dimension numbers after transpose.""" + if len(dims_to_transpose) == 0: + return dims_to_transpose + flatten_axis = ndim - flatten_axis if flatten_axis > 0 else flatten_axis + transposed_dims = (*range(flatten_axis, ndim), *range(flatten_axis)) + return tuple(transposed_dims.index(dim) for dim in dims_to_transpose) + + +def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool: + lhs, rhs, e4m3, e5m2 = map( + dtypes.canonicalize_dtype, + ( + lhs_dtype, + rhs_dtype, + jnp.float8_e4m3fn, + jnp.float8_e5m2, + ), + ) + + # FP8 GEMM supports (e4m3 x e4m3), (e4m3 x e5m2) and (e5m2 x e4m3) + if (lhs is e4m3 and rhs in (e4m3, e5m2)) or (lhs in (e4m3, e5m2) and rhs is e4m3): + return True + + # Any other combination of data types is not supported + return False + + +def _get_gemm_layout( + operand_ndims: Tuple[int, int], contracting_dims: Tuple[Sequence[int], Sequence[int]] +) -> Tuple[bool, bool]: + lhs_contracting, rhs_contracting = map(sanitize_dims, operand_ndims, contracting_dims) + lhs_is_transposed = operand_ndims[0] - 1 not in lhs_contracting + rhs_is_transposed = operand_ndims[1] - 1 in rhs_contracting + return lhs_is_transposed, rhs_is_transposed + + +def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims): + lhs_q = lhs + rhs_q = rhs + + if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None: + lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0]) + lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims + need_lhs_colwise = lhs_is_transposed and ( + lhs_quantizer.scaling_mode.is_1d_block_scaling() + or not is_fp8_gemm_with_all_layouts_supported() + ) + flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims) + lhs_q = lhs_quantizer.quantize( + lhs, + is_rowwise=not need_lhs_colwise, + is_colwise=need_lhs_colwise, + flatten_axis=flatten_axis, + ) + + if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None: + rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1]) + rhs_is_transposed = rhs.ndim - 1 in rhs_cdims + need_rhs_colwise = not rhs_is_transposed and ( + rhs_quantizer.scaling_mode.is_1d_block_scaling() + or not is_fp8_gemm_with_all_layouts_supported() + ) + flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1 + rhs_q = rhs_quantizer.quantize( + rhs, + is_rowwise=not need_rhs_colwise, + is_colwise=need_rhs_colwise, + flatten_axis=flatten_axis, + ) + + assert not isinstance(lhs_q, ScaledTensor2x) + assert not isinstance(rhs_q, ScaledTensor2x) + + return lhs_q, rhs_q + + +@dataclass(frozen=True) +class CommOverlapHelper: + """ + Helper object that carries comm+GEMM overlap configuration, initializes the internal + communication buffer, and generates lowering arguments and partitioning rules for + the GemmPrimitive. + """ + + # Core init arguments + comm_type: tex.CommOverlapType = field(default=tex.CommOverlapType.NONE) + method: tex.CommOverlapMethod = field(default=tex.CommOverlapMethod.NONE) + buffer_shape: Sequence[int] = field(default=None) + buffer_dtype: jnp.dtype = field(default=jnp.bfloat16) + tp_size: int = field( + default_factory=lambda: get_mesh_axis_size(global_mesh_resource().tp_resource) + ) + + # Userbuffers bootstrap kwargs + num_splits: int = field(default=None, kw_only=True) + num_max_streams: int = field(default=3, kw_only=True) + comm_cga_size: int = field(default=None, kw_only=True) + gemm_priority: int = field(default=CUDA_STREAM_PRIORITY_LOWEST, kw_only=True) + comm_priority: int = field(default=CUDA_STREAM_PRIORITY_HIGHEST, kw_only=True) + num_comm_sm: int = field(default=None, kw_only=True) + set_sm_margin: bool = field(default=None, kw_only=True) + use_ce: bool = field(default=None, kw_only=True) + atomic_gemm: bool = field(default=False, kw_only=True) + rs_overlap_first_gemm: bool = field(default=False, kw_only=True) + aggregate_ag: bool = field(default=False, kw_only=True) + + # Other kwargs not passed to Userbuffers + tp_resource: str = field(default_factory=lambda: global_mesh_resource().tp_resource) + logical_tp_axis: str = field(default=W_TP_AXES, kw_only=True) + logical_sp_axis: str = field(default=SEQLEN_TP_AXES, kw_only=True) + output_all_gathered_lhs: bool = field(default=False, kw_only=True) + flatten_axis: int = field(default=-1, kw_only=True) + + # Internal attributes + is_enabled: bool = field(default=False, init=False) + unique_id: int = field(default=-1, init=False, compare=False) + sharded_impl: bool = field(default=False, init=False, compare=False) + gather_dim: int = field(default=-2, init=False, compare=False) + scatter_dim: int = field(default=-2, init=False, compare=False) + + def __post_init__(self): + # Update global min/max CUDA stream priority values if not already done + global CUDA_STREAM_PRIORITY_LOWEST, CUDA_STREAM_PRIORITY_HIGHEST + if CUDA_STREAM_PRIORITY_LOWEST is None or CUDA_STREAM_PRIORITY_HIGHEST is None: + ( + CUDA_STREAM_PRIORITY_LOWEST, + CUDA_STREAM_PRIORITY_HIGHEST, + ) = tex.get_stream_priority_range() + if self.gemm_priority is None: + object.__setattr__(self, "gemm_priority", CUDA_STREAM_PRIORITY_LOWEST) + if self.comm_priority is None: + object.__setattr__(self, "comm_priority", CUDA_STREAM_PRIORITY_HIGHEST) + + if self.method != tex.CommOverlapMethod.NONE or self.comm_type != tex.CommOverlapType.NONE: + assert self.method != tex.CommOverlapMethod.NONE, ( + f"CommOverlapHelper: {self.comm_type} is not a valid collective type for " + f"{self.method}." + ) + assert self.comm_type != tex.CommOverlapType.NONE, ( + f"CommOverlapHelper: {self.method} is not a valid overlap method for " + f"{self.comm_type}." + ) + assert ( + self.buffer_shape is not None and len(self.buffer_shape) >= 2 + ), f"CommOverlapHelper: {self.buffer_shape} is not a valid buffer shape." + assert self.tp_resource is not None, ( + "CommOverlapHelper: Communication + GEMM overlap requires a valid TP resource. " + "This must either be specified via the `tp_resource=` keyword, or " + "`CommOverlapHelper` needs to be initialized under a " + "`te.sharding.global_shard_guard()` using a `te.sharding.MeshResource()` with a " + "valid tensor-parallel mesh axis name." + ) + assert ( + self.tp_size % 2 == 0 + ), f"CommOverlapHelper: Tensor-parallel axis of {self.tp_size} is not divisible by 2." + if not self.is_bulk() and not self.is_p2p(): + # Pipelined overlap is only for reduce-scatter + assert not self.is_all_gather(), ( + f"CommOverlapHelper: {self.method} is not a valid overlap method for " + f"{self.comm_type}." + ) + + # Collapse buffer shape to 2D + if len(self.buffer_shape) > 2: + if self.flatten_axis < 0: + object.__setattr__( + self, "flatten_axis", self.flatten_axis + len(self.buffer_shape) + ) + object.__setattr__( + self, + "buffer_shape", + ( + reduce(operator.mul, self.buffer_shape[: self.flatten_axis]), + reduce(operator.mul, self.buffer_shape[self.flatten_axis :]), + ), + ) + + # Num splits for P2P overlap is always fixed to TP size + if self.is_p2p(): + object.__setattr__(self, "num_splits", self.tp_size) + elif self.num_splits is None: + object.__setattr__(self, "num_splits", self.tp_size) + + # Set conditional defaults for config options not specified at init time + if self.comm_cga_size is None: + object.__setattr__(self, "comm_cga_size", 1 if self.is_p2p() else 2) + if self.num_comm_sm is None: + object.__setattr__(self, "num_comm_sm", 1 if self.is_p2p() else 16) + if self.set_sm_margin is None: + object.__setattr__(self, "set_sm_margin", not self.is_p2p()) + if self.use_ce is None: + object.__setattr__(self, "use_ce", self.is_p2p()) + + # Allocate the communication buffer + args, kwargs = self.get_bootstrap_args_kwargs() + object.__setattr__(self, "unique_id", tex.create_comm_overlap_buffer(*args, **kwargs)) + object.__setattr__(self, "is_enabled", True) + + def _set_sharded_impl(self, value): + assert isinstance(value, bool) + object.__setattr__(self, "sharded_impl", value) + + def _set_gather_dim(self, value): + assert isinstance(value, int) + object.__setattr__(self, "gather_dim", value) + + def _set_scatter_dim(self, value): + assert isinstance(value, int) + object.__setattr__(self, "scatter_dim", value) + + def is_bulk(self): + """Check if this is a bulk overlap.""" + return self.method == tex.CommOverlapMethod.BULK + + def is_p2p(self): + """Check if this is a peer-to-peer (ring-exchange) overlap.""" + return self.method == tex.CommOverlapMethod.RING_EXCHANGE + + def is_all_gather(self): + """Check if the overlapped collective is an all-gather.""" + return self.comm_type == tex.CommOverlapType.AG + + def is_reduce_scatter(self): + """Check if the overlapped collective is a reduce-scatter.""" + return self.comm_type == tex.CommOverlapType.RS + + def has_aux_output(self): + """Check if the comm+GEMM overlap has an auxiliary output.""" + return self.is_enabled and ( + self.is_bulk() or (self.is_all_gather() and self.output_all_gathered_lhs) + ) + + def get_bootstrap_args_kwargs(self): + """Generate positional and keyword arguments to bootstrap Userbuffers.""" + args = ( + self.comm_type, + self.method, + self.buffer_shape, + jax_dtype_to_te_dtype(self.buffer_dtype), + self.tp_size, + ) + kwargs = { + "num_splits": self.num_splits, + "num_max_streams": self.num_max_streams, + "comm_cga_size": self.comm_cga_size, + "gemm_priority": self.gemm_priority, + "comm_priority": self.comm_priority, + "num_comm_sm": self.num_comm_sm, + "set_sm_margin": self.set_sm_margin, + "use_ce": self.use_ce, + "atomic_gemm": self.atomic_gemm, + "rs_overlap_first_gemm": self.rs_overlap_first_gemm, + "aggregate_ag": self.aggregate_ag, + } + return args, kwargs + + def get_lowering_kwargs(self): + """Generate a dictionary of keyword arguments used in GemmPrimitive.lowering().""" + aux_axis_boundary = -1 + if self.is_enabled and self.sharded_impl: + if self.is_all_gather(): + assert self.gather_dim >= 0, ( + "Internal TE error: CommOverlapHelper.gather_dim is not set correctly in " + "GemmPrimitive." + ) + aux_axis_boundary = self.gather_dim + 1 + elif self.is_reduce_scatter(): + assert self.scatter_dim >= 0, ( + "Internal TE error: CommOverlapHelper.scatter_dim is not set correctly in " + "GemmPrimitive." + ) + aux_axis_boundary = self.scatter_dim + 1 + + return { + "comm_overlap_id": self.unique_id, + "comm_overlap_method": int(self.method.value), + "comm_type": int(self.comm_type.value), + "aux_axis_boundary": aux_axis_boundary, + } + + @staticmethod + def _check_operand_specs(lhs_specs, rhs_specs, dimension_numbers): + (lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = dimension_numbers + lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) + + def _split_specs(specs, contracting_dims, batch_dims): + ndims = len(specs) + cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims)) + + # Batch specs + bspecs = tuple(specs[i] for i in bdims) + + # Non-batch leading dimension specs + lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims) + + # Non-batch contracting dimension specs + cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims) + + return bspecs, lspecs, cspecs + + ( + (lhs_bspecs, lhs_lspecs, lhs_cspecs), + (rhs_bspecs, rhs_lspecs, rhs_cspecs), + ) = map( + _split_specs, + (lhs_specs, rhs_specs), + (lhs_cdims, rhs_cdims), + (lhs_bdims, rhs_bdims), + ) + + # Batched dimensions must have the same sharding + if len(lhs_bdims) > 0 and len(rhs_bdims) > 0: + assert all( + lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs) + ), ( + "cuBLAS GEMM operand batch dimensions must have the same sharding: " + f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}." + ) + + # Only one each of the non-batched leading dimensions and non-batched contracting + # dimensions can be sharded + lhs_ldims, rhs_ldims = map( + lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude), + (lhs_ndim, rhs_ndim), + (lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims), + ) + (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map( + lambda specs: tuple(spec for spec in specs if spec is not None), + (lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs), + ) + assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, ( + "cuBLAS GEMM operands can have only one sharded non-batched leading dimension: " + f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}." + ) + assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, ( + "cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: " + f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}." + ) + + # Extract single leading and contracting dimension specs + (lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map( + lambda specs: None if len(specs) == 0 else specs[0], + (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none), + ) + return (lhs_lspec, lhs_cspec), (rhs_lspec, rhs_cspec) + + def _get_no_overlap_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): + (lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = dimension_numbers + lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) + + (lhs_lspec, lhs_cspec), (rhs_lspec, rhs_cspec) = self._check_operand_specs( + lhs_specs, rhs_specs, ((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)) + ) + + # Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts + # with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands. + # 1. K1 == K2 != None and N == None + # LHS: (B, M, K) + # RHS: (B, None, K) + # OUT: (B, M, None) --(AR)-> (B, M, None) + # 2. K1 == K2 != None and M == N != None + # LHS: (B, M, K) + # RHS: (B, N, K)--(AG)->(B, None, K) + # OUT: (B, M, None) --(RS)--> (B, M, N) + # 3. M == N + # LHS: (B, M, K)--(AG)->(B, M, None) + # RHS: (B, M, K)--(AG)->(B, None, None) + # OUT: (B, M, None) + # 4. M != N + # LHS: (B, M, K)--(AG)->(B, M, None) + # RHS: (B, N, K)--(AG)->(B, N, None) + # OUT: (B, M, N) + reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec + all_reduce_output = reduce_flag and rhs_lspec is None + reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec + all_reduce_spec = reduce_scatter_spec = scatter_dim = None + + lhs_non_contracting_specs, rhs_non_contracting_specs = map( + lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims), + (lhs_specs, rhs_specs), + (lhs_cdims, rhs_cdims), + ) + out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs) + if reduce_scatter_output: + # All-gather (if necessary) the non-batch non-contracting dimension of RHS + # LHS: (B, M, K) + # RHS: (B, N, K) --(AG)-> (B, None, K) + # OUT: (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N) + rhs_spec = tuple( + rhs_spec[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim) + ) + reduce_scatter_spec = lhs_cspec + scatter_dim = out_specs.index(rhs_lspec) + + elif all_reduce_output: + # Set all output trailing dimensions to zero + out_specs = ( + *lhs_non_contracting_specs, + *[None for _ in range(len(rhs_non_contracting_specs))], + ) + all_reduce_spec = lhs_cspec + else: + # All-gather (if necessary) the non-batch contracting dimensions + # LHS: (B, M, K) --(AG)-> (B, M, None) + # RHS: (B, N, K) --(AG)-> (B, N, None) + # OUT: (B, M, None) x (B, N, None)^T = (B, M, N) + lhs_specs = tuple( + None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i] + for i in range(lhs_ndim) + ) + rhs_specs = tuple( + None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i] + for i in range(rhs_ndim) + ) + # Check if RHS non-contracting spec also appears in the LHS non-contracting specs + if rhs_lspec is not None and rhs_lspec in tuple( + lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims + ): + # All-gather (if necessary) the non-batch non-contracting dimensions of RHS + # LHS: (B, M, None) + # RHS: (B, N, None) --(AG)-> (B, None, None) + # OUT: (B, M, None) x (B, None, None)^T = (B, M, None) + rhs_specs = tuple( + None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i] + for i in range(rhs_ndim) + ) + # Set all output trailing dimensions to zero + out_specs = ( + *lhs_non_contracting_specs, + *[None for _ in range(len(rhs_non_contracting_specs))], + ) + + # Bias and Pre-GeLU sharding is based on GEMM output + bias_specs = out_specs[len(lhs_non_contracting_specs) :] + gelu_specs = out_specs + + return ( + (lhs_specs, rhs_specs, bias_specs, gelu_specs, aux_in_specs), + (out_specs, bias_specs, gelu_specs, (None,)), + (all_reduce_spec, reduce_scatter_spec, scatter_dim), + ) + + def _get_bulk_overlap_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): + assert self.tp_resource in aux_in_specs, ( + "CommOverlapHelper: Auxiliary input for bulk all-gather overlap is not sharded " + f"over the tensor-parallel mesh resource '{self.tp_resource}' in any dimension." + ) + + aux_out_specs = (None,) + bulk_comm_dim = aux_in_specs.index(self.tp_resource) + aux_in_specs_batch = aux_in_specs[:bulk_comm_dim] + aux_in_specs_tail = aux_in_specs[bulk_comm_dim + 1 :] + if self.is_all_gather(): + assert all(spec is None for spec in aux_in_specs_tail), ( + "CommOverlapHelper: Trailing dimensions of the auxiliary input for bulk all-gather " + "overlap cannot be sharded." + ) + self._set_gather_dim(bulk_comm_dim) + aux_out_specs = ( + *aux_in_specs_batch, + None, # all-gathered dimension + *[None for _ in range(len(aux_in_specs_tail))], + ) + else: + assert all(spec is None for spec in aux_in_specs[bulk_comm_dim:]), ( + "CommOverlapHelper: Non-batch dimensions of the auxiliary input for bulk " + "reduce-scatter overlap cannot be sharded." + ) + self._set_scatter_dim(bulk_comm_dim) + aux_out_specs = ( + *aux_in_specs_batch, + self.tp_resource, + *[None for _ in range(len(aux_in_specs_tail))], + ) + + # GEMM is independent of communication so specs are as if there is no overlap + operand_specs, output_specs, xla_reduce_info = self._get_no_overlap_rules( + lhs_specs, rhs_specs, aux_in_specs, dimension_numbers + ) + + return ( + operand_specs, + (*output_specs[:-1], aux_out_specs), + xla_reduce_info, + ) + + def _get_all_gather_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): + contracting_dims, batch_dims = dimension_numbers + lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) + lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( + sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batch_dims + ) + + (lhs_lspec, _), _ = self._check_operand_specs( + lhs_specs, rhs_specs, ((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)) + ) + assert lhs_lspec == self.tp_resource, ( + "CommOverlapHelper: Non-batch leading dimension of the LHS operand for AG->GEMM " + f"overlap must be sharded over the tensor-parallel mesh resource {self.tp_resource}, " + f"but got {lhs_lspec} sharding instead." + ) + + # AG->GEMM overlap: Require non-batched contracting dimensions to be unsharded (e.g. FSDP) + # LHS: (B, M, None) + # RHS: (None, N) + # OUT: (B, M, None) --(AG)-> (B, None, None) x (None, N) = (B, None, N) + lhs_specs = tuple( + None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i] for i in range(lhs_ndim) + ) + rhs_specs = tuple( + None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i] for i in range(rhs_ndim) + ) + + # GEMM output spec keeps LHS batch spec and RHS non-contracting specs, but is None + # in the non-batched leading dimensions. + lhs_non_cspecs_gathered = list( + lhs_specs[i] if i in lhs_bdims else None for i in range(lhs_ndim) if i not in lhs_cdims + ) + rhs_non_cspecs = tuple(rhs_specs[i] for i in range(rhs_ndim) if i not in rhs_cdims) + out_specs = (*lhs_non_cspecs_gathered, *rhs_non_cspecs) + self._set_gather_dim(lhs_specs.index(lhs_lspec)) + + # Bias and Pre-GeLU sharding is based on GEMM output + bias_specs = out_specs[len(lhs_non_cspecs_gathered) :] + gelu_specs = out_specs + + # Auxiliary input/output specs depend on bulk vs. non-bulk overlap + aux_out_specs = (None,) + if self.output_all_gathered_lhs: + # Auxiliary output is the same as the LHS spec, except the gathered dimension unsharded + aux_out_specs = list(lhs_specs).copy() + aux_out_specs[self.gather_dim] = None + + return ( + (lhs_specs, rhs_specs, bias_specs, gelu_specs, aux_in_specs), + (out_specs, bias_specs, gelu_specs, aux_out_specs), + (None, None, None), + ) + + def _get_reduce_scatter_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): + contracting_dims, batch_dims = dimension_numbers + lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims) + lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), batch_dims) + + (_, lhs_cspec), (_, rhs_cspec) = self._check_operand_specs( + lhs_specs, rhs_specs, ((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)) + ) + assert lhs_cspec == rhs_cspec == self.tp_resource, ( + "CommOverlapHelper: Non-batched contracting dimensions of LHS and RHS operands for " + "GEMM->RS overlap must be sharded over the tensor-parallel resource " + f"{self.tp_resource}, but got LHS:{lhs_cspec} and RHS:{rhs_cspec} sharding instead." + ) + + # GEMM->RS overlap: Require non-contracting non-batch dimensions to be unsharded (e.g. FSDP) + # LHS: (B, None, K) + # RHS: (K, None) + # OUT: (B, None, K) x (K, None) = (B, None, None) --(UB-RS)-> (B, M, None) + lhs_specs = tuple( + None if i not in lhs_bdims + lhs_cdims else lhs_specs[i] for i in range(lhs_ndim) + ) + rhs_specs = tuple( + None if i not in rhs_bdims + rhs_cdims else rhs_specs[i] for i in range(rhs_ndim) + ) + + # GEMM output is the internal communication buffer, but we will use the XLA output buffer + # as the final reduce-scattered output so we shard it accordingly here. + lhs_bspecs = tuple( + lhs_specs[i] for i in range(lhs_ndim) if i in lhs_bdims and i not in lhs_cdims + ) + lhs_lspecs = tuple(lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_bdims + lhs_cdims) + rhs_non_cspecs = tuple(rhs_specs[i] for i in range(rhs_ndim) if i not in rhs_cdims) + out_specs = ( + *lhs_bspecs, + self.tp_resource, + *[None for _ in range(len(lhs_lspecs) - 1)], + *rhs_non_cspecs, + ) + self._set_scatter_dim(out_specs.index(self.tp_resource)) + + # Bias and Pre-GeLU sharding is based on GEMM output + bias_specs = out_specs[len(lhs_bspecs) + len(lhs_lspecs) :] + gelu_specs = out_specs + + return ( + (lhs_specs, rhs_specs, bias_specs, gelu_specs, aux_in_specs), + (out_specs, bias_specs, gelu_specs, (None,)), + (None, None, None), + ) + + def get_partitioning_rules(self, lhs_specs, rhs_specs, aux_in_specs, dimension_numbers): + """ + Correct operand specs to partititions suitable for the GemmPrimitive, and infer the + partition specs of the outputs. + """ + if self.is_bulk(): + return self._get_bulk_overlap_rules( + lhs_specs, rhs_specs, aux_in_specs, dimension_numbers + ) + + impl_map = { + tex.CommOverlapType.NONE: self._get_no_overlap_rules, + tex.CommOverlapType.AG: self._get_all_gather_rules, + tex.CommOverlapType.RS: self._get_reduce_scatter_rules, + } + return impl_map[self.comm_type](lhs_specs, rhs_specs, aux_in_specs, dimension_numbers) + + def get_logical_output_axes(self, lhs_axes, rhs_axes, dimension_numbers): + """ + Compute the logical axis names for the GEMM output axes based on LHS and RHS operands' + logical axis names. + """ + if not lhs_axes or not rhs_axes: + assert not lhs_axes and not rhs_axes, ( + "CommOverlapHelper: Logical axes must either be defined or not defined for both " + "forward operands." + ) + return None + + contracting_dims, batch_dims = dimension_numbers + lhs_ndim, rhs_ndim = map(len, (lhs_axes, rhs_axes)) + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims) + lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), batch_dims) + + lhs_batch_axes = tuple( + lhs_axes[i] for i in range(lhs_ndim) if i in lhs_bdims and i not in lhs_cdims + ) + lhs_leading_axes = tuple( + lhs_axes[i] for i in range(lhs_ndim) if i not in lhs_bdims + lhs_cdims + ) + rhs_non_contracting_axes = tuple(rhs_axes[i] for i in range(rhs_ndim) if i not in rhs_cdims) + + out_axes = (*lhs_batch_axes, *lhs_leading_axes, *rhs_non_contracting_axes) + if self.is_enabled and not self.is_bulk(): + if self.is_all_gather(): + out_axes = ( + *lhs_batch_axes, + *[None for _ in range(len(lhs_leading_axes))], + *rhs_non_contracting_axes, + ) + elif self.is_reduce_scatter(): + out_axes = ( + *lhs_batch_axes, + self.logical_sp_axis, + *[None for _ in range(len(lhs_leading_axes) - 1)], + *[None for _ in range(len(rhs_non_contracting_axes))], + ) + else: + # Generate grad axes without any communication overlap + lhs_specs = generate_pspec(lhs_axes) + lhs_lspec = tuple( + lhs_specs[i] + for i in range(lhs_ndim) + if i not in lhs_bdims + lhs_cdims and lhs_specs[i] is not None + ) + lhs_lspec = None if len(lhs_lspec) == 0 else lhs_lspec[0] + lhs_cspec = tuple( + lhs_specs[i] for i in lhs_cdims if i not in lhs_bdims and lhs_specs[i] is not None + ) + lhs_cspec = None if len(lhs_cspec) == 0 else lhs_cspec[0] + + rhs_specs = generate_pspec(rhs_axes) + rhs_lspec = tuple( + rhs_specs[i] + for i in range(rhs_ndim) + if i not in rhs_bdims + rhs_cdims and rhs_specs[i] is not None + ) + rhs_lspec = None if len(rhs_lspec) == 0 else rhs_lspec[0] + rhs_cspec = tuple( + rhs_specs[i] for i in rhs_cdims if i not in rhs_bdims and rhs_specs[i] is not None + ) + rhs_cspec = None if len(rhs_cspec) == 0 else rhs_cspec[0] + + if not ( + lhs_cspec is not None + and lhs_cspec == rhs_cspec + and lhs_lspec is not None + and lhs_lspec == rhs_lspec + ): + # Trailing dimension is not scattered (i.e. not doing jax.lax.psum_scatter) + out_axes = ( + *lhs_batch_axes, + *lhs_leading_axes, + *[None for _ in range(len(rhs_non_contracting_axes))], + ) + + return out_axes + + +@dataclass(frozen=True) +class CommOverlapHelperSet: + """ + A set of CommOverlapHelper objects that provide complementary comm+GEMM overlap configurations + for FPROP, DGRAD and WGRAD GEMMs in FWD/BWD passes through Dense-layers. + """ + + fprop: CommOverlapHelper = field(default=None) + dgrad: CommOverlapHelper = field(default=None) + wgrad: CommOverlapHelper = field(default=None) + + def _sanity_check(self): + # Require any argument that exists to be a `CommOverlapHelper` instance + for overlap, name in zip((self.fprop, self.dgrad, self.wgrad), ("fprop", "dgrad", "wgrad")): + if overlap is not None: + assert isinstance(overlap, CommOverlapHelper), ( + f"CommOverlapHelperSet: Expected `{name}` to be a {CommOverlapHelper} but got " + f"{type(overlap)} instead." + ) + + # If FPROP overlap is not defined or not enabled, require DGRAD and WGRAD to also not be + # be defined or not enabled + if self.fprop is None or not self.fprop.is_enabled: + assert (self.dgrad is None or not self.dgrad.is_enabled) and ( + self.wgrad is None or not self.wgrad.is_enabled + ), ( + "CommOverlapHelperSet: Cannot do communication overlap for DGRAD and/or WGRAD when " + "there is no communication overlap for FPROP." + ) + return + + assert ( + not self.fprop.is_bulk() + ), "CommOverlapHelperSet: Cannot overlap bulk collectives with FPROP." + + if self.fprop.is_all_gather(): + if self.dgrad is not None and self.dgrad.is_enabled: + if self.dgrad.is_bulk() and self.dgrad.is_all_gather(): + assert not self.fprop.output_all_gathered_lhs, ( + "CommOverlapHelperSet: AG->GEMM FPROP does not support BULK-AG overlap for " + "DGRAD when the all-gathered LHS is already saved in the forward pass." + ) + assert ( + self.wgrad is not None + and self.wgrad.is_enabled + and self.wgrad.is_bulk() + and self.wgrad.is_reduce_scatter() + ), ( + "CommOverlapHelperSet: AG->GEMM FPROP with BULK-AG overlap for DGRAD " + "requires BULK-RS overlap for WGRAD." + ) + + elif not self.dgrad.is_bulk() and self.dgrad.is_reduce_scatter(): + assert self.wgrad is None or not self.wgrad.is_enabled, ( + "CommOverlapHelperSet: AG->GEMM FPROP with GEMM->RS DGRAD does not support " + "communication overlap for WGRAD." + ) + + else: + raise AssertionError( + "CommOverlapHelperSet: AG->GEMM FPROP requires communication overlap for " + "DGRAD to be either BULK-AG or GEMM->RS." + ) + else: + assert self.wgrad is None or not self.wgrad.is_enabled, ( + "CommOverlapHelperSet: AG->GEMM FPROP with no communication overlap for DGRAD" + "does not support communication overlap for WGRAD." + ) + + elif self.fprop.is_reduce_scatter(): + if self.dgrad is not None and self.dgrad.is_enabled: + assert not self.dgrad.is_bulk() and self.dgrad.is_all_gather(), ( + "CommOverlapHelperSet: GEMM->RS FPROP requires communication overlap for DGRAD " + "to be AG->GEMM." + ) + + assert self.wgrad is None or not self.wgrad.is_enabled, ( + "CommOverlapHelperSet: GEMM->RS FPROP does not support communication overlap " + "for WGRAD." + ) + + else: + raise RuntimeError( + "CommOverlapHelperSet: Internal TE error, unrecognized collective type " + f"{self.fprop.comm_type} in communication overlap for FPROP." + ) + + def __post_init__(self): + self._sanity_check() + + if self.fprop is None: + object.__setattr__(self, "fprop", CommOverlapHelper()) + + # Column-parallel layers: QKV projection and MLP FFN1 + # FPROP with AG->GEMM: + # LHS:(B, M, None)--(AG)->(B, None, None) x RHS:(None, N) = OUT:(B, None, N) + # DGRAD w/ BULK-AG for LHS: + # GRAD:(B, None, N) x RHS:(None, N)^T = DGRAD:(B, None, None) + # LHS:(B, M, None)--(BULK-AG)->(B, None, None) + # WGRAD w/ BULK-RS for DGRAD: + # LHS:(B, None, None)^T x GRAD:(B, None, N) = WGRAD:(None, N) + # DGRAD:(B, None, None)--(BULK-RS)->(B, M, None) + # + # Row-parallel layers: Post-attention projection and MLP FFN2 + # FPROP with GEMM->RS: + # LHS:(B, None, K) x RHS:(K, None) = (B, None, None)--(RS)->(B, M, None) + # DGRAD with AG->GEMM (all-gathered GRAD saved for WGRAD): + # GRAD:(B, M, None)--(AG)->(B, None, None) x RHS:(K, None)^T = (B, None, K) + # WGRAD with NO OVERLAP: + # LHS:(B, None, K)^T x GRAD:(B, None, None) = (K, None) + if self.dgrad is None: + dgrad_overlap = None + + if self.fprop.is_all_gather() and not self.fprop.output_all_gathered_lhs: + # FPROP AG->GEMM and DGRAD GEMM->RS + dgrad_overlap = CommOverlapHelper( + method=tex.CommOverlapMethod.RING_EXCHANGE, + comm_type=tex.CommOverlapType.RS, + buffer_shape=self.fprop.buffer_shape, + buffer_dtype=self.fprop.buffer_dtype, + tp_size=self.fprop.tp_size, + logical_tp_axis=self.fprop.logical_tp_axis, + logical_sp_axis=self.fprop.logical_sp_axis, + ) + + elif self.fprop.is_reduce_scatter(): + # FPROP GEMM->RS and DGRAD AG->GEMM + dgrad_overlap = CommOverlapHelper( + method=tex.CommOverlapMethod.RING_EXCHANGE, + comm_type=tex.CommOverlapType.AG, + buffer_shape=self.fprop.buffer_shape, + buffer_dtype=self.fprop.buffer_dtype, + tp_size=self.fprop.tp_size, + logical_tp_axis=self.fprop.logical_tp_axis, + logical_sp_axis=self.fprop.logical_sp_axis, + output_all_gathered_lhs=True, + ) + + else: + dgrad_overlap = CommOverlapHelper() + + object.__setattr__(self, "dgrad", dgrad_overlap) + + if self.wgrad is None: + wgrad_overlap = self.wgrad + + if ( + self.fprop.is_all_gather() + and self.dgrad.is_enabled + and self.dgrad.is_bulk() + and self.dgrad.is_all_gather() + ): + # FPROP AG->GEMM, DGRAD BULK-AG for LHS and WGRAD BULK-RS for DGRAD + wgrad_overlap = CommOverlapHelper( + method=tex.CommOverlapMethod.BULK, + comm_type=tex.CommOverlapType.RS, + buffer_shape=self.fprop.buffer_shape, + buffer_dtype=self.fprop.buffer_dtype, + tp_size=self.fprop.tp_size, + logical_tp_axis=self.fprop.logical_tp_axis, + logical_sp_axis=self.fprop.logical_sp_axis, + ) + + else: + wgrad_overlap = CommOverlapHelper() + + object.__setattr__(self, "wgrad", wgrad_overlap) + + +class GemmPrimitive(BasePrimitive): + """ + Primitive for cuBLAS GEMM + """ + + name = "te_gemm_ffi" + multiple_results = True + impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + aux_in, + out_dtype, + dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + comm_overlap, + ): + del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator + + # Sanity-check operand layouts and types + operand_ndims = (lhs.ndim, rhs.ndim) + contracting_dims, _ = dimension_numbers + ( + lhs_contracting_dims, + rhs_contracting_dims, + ) = map(sanitize_dims, operand_ndims, contracting_dims) + lhs_contracting_size, rhs_contracting_size = map( + lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), + (lhs.shape, rhs.shape), + (lhs_contracting_dims, rhs_contracting_dims), + ) + assert lhs_contracting_size == rhs_contracting_size, ( + "cuBLAS GEMM operands have incompatible contracting dimensions: " + f"{lhs.shape} @ idx {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}." + ) + + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims) + if scaling_mode != ScalingMode.NO_SCALING: + assert _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype), ( + "cuBLAS GEMM quantized operands have incompatible data types: " + f"{lhs.dtype} x {rhs.dtype}." + ) + assert ( + lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0 + ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." + if ( + scaling_mode != ScalingMode.MXFP8_1D_SCALING + and not tex.is_non_nt_fp8_gemm_supported() + ): + assert not lhs_is_transposed and rhs_is_transposed, ( + "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) " + "require non-transposed LHS and transposed RHS operands " + "(`contracting_dims=((-1, ), (-1, ))`)." + ) + + # Determine output shape and dtype + assert ( + dtypes.canonicalize_dtype(out_dtype).itemsize > 1 + ), "cuBLAS GEMM custom op does not support 8-bit quantized output types." + lhs_non_contracting_shape, rhs_non_contracting_shape = map( + lambda shape, dims: [shape[dim] for dim in range(len(shape)) if dim not in dims], + (lhs.shape, rhs.shape), + (lhs_contracting_dims, rhs_contracting_dims), + ) + out_shape = [*lhs_non_contracting_shape, *rhs_non_contracting_shape] + output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + + # Auxiliary output for comm+GEMM overlap + aux_out_shape = (0,) + aux_out_dtype = jnp.bfloat16 + if comm_overlap.is_enabled: + if comm_overlap.is_bulk(): + # Bulk overlap will all-gather or reduce-scatter the tensor in the auxiliary input + # and return the result of the collective in the auxiliary output + assert ( + aux_in.size > 0 + ), "cuBLAS GEMM w/ bulk collective overlap requires an auxiliary input." + assert aux_in.ndim > 1, ( + "cuBLAS GEMM w/ bulk collective overlap only supports multidimensional " + "auxiliary inputs." + ) + + aux_out_shape = list(aux_in.shape).copy() + aux_out_dtype = aux_in.dtype + if comm_overlap.sharded_impl: + if comm_overlap["comm_type"] == tex.CommOverlapType.AG: + aux_out_shape[comm_overlap.gather_dim] *= comm_overlap.tp_size + else: + assert aux_in.shape[comm_overlap.scatter_dim] % comm_overlap.tp_size, ( + "cuBLAS GEMM w/ bulk reduce-scatter overlap requires the auxiliary " + "input to be divisible by tensor-parallel size in dimension index " + f"{comm_overlap.scatter_dim}." + ) + aux_out_shape[comm_overlap.scatter_dim] = ( + aux_out_shape[comm_overlap.scatter_dim] // comm_overlap.tp_size + ) + + elif comm_overlap.is_all_gather(): + # Sharded abstract multiplies gathered dimension by TP size + if comm_overlap.sharded_impl: + out_shape[comm_overlap.gather_dim] *= comm_overlap.tp_size + output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + + # AG->GEMM overlap can copy all-gathered LHS into the auxiliary buffer + if comm_overlap.output_all_gathered_lhs: + aux_out_shape = list(lhs.shape).copy() + aux_out_dtype = lhs.dtype + + # Sharded abstract multiplies gathered dimension by TP size + if comm_overlap.sharded_impl: + aux_out_shape[comm_overlap.gather_dim] *= comm_overlap.tp_size + elif comm_overlap.is_reduce_scatter(): + # GEMM->RS auxiliary output is the reduce-scattered output + rs_out_shape = list(out_shape).copy() + + # Sharded abstract divides scattered dimension by TP size + if comm_overlap.sharded_impl: + rs_out_shape[comm_overlap.scatter_dim] = ( + rs_out_shape[comm_overlap.scatter_dim] // comm_overlap.tp_size + ) + + output = jax.core.ShapedArray(shape=rs_out_shape, dtype=out_dtype) + + aux_out = jax.core.ShapedArray(shape=aux_out_shape, dtype=aux_out_dtype) + + # Validate bias -- shape always depends on pure GEMM output even for GEMM->RS overlap + bias_shape = (0,) + bias_dtype = out_dtype + if fuse_bias: + expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape) + if not grad: + assert bias.size == expected_bias_size, ( + "cuBLAS GEMM bias tensor has incorrect shape, " + f"expected ({expected_bias_size}, ) but found {bias.shape}." + ) + assert bias.dtype == out_dtype, ( + "cuBLAS GEMM bias tensor has incorrect data type, " + f"expected {bias_dtype} but found {bias.dtype}." + ) + bias_shape = bias.shape + else: + bias_shape = rhs_non_contracting_shape + bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype) + + # Validate pre-GeLU -- shape always depends on pure GEMM output even for GEMM->RS overlap + pre_gelu_shape = (0,) + pre_gelu_dtype = out_dtype + if fuse_gelu: + pre_gelu_shape = out_shape + if grad: + pre_gelu_ndim = len(pre_gelu_shape) + assert gelu_input.ndim == pre_gelu_shape and all( + gelu_input.shape[i] == pre_gelu_shape[i] for i in range(pre_gelu_ndim) + ), ( + "cuBLAS GEMM pre-GeLU tensor has incorrect shape, " + f"expected {pre_gelu_shape} but found {gelu_input.shape}." + ) + assert gelu_input.dtype == out_dtype, ( + "cuBLAS GEMM pre-GeLU tensor has incorrect data type, " + f"expected {pre_gelu_dtype} but found {gelu_input.dtype}." + ) + pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) + + # Need extra workspace for swizzled scale factors + lhs_swizzle_size = 0 + rhs_swizzle_size = 0 + swizzle_dtype = jnp.uint8 + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + lhs_swizzle_size = lhs_scale_inv.size + rhs_swizzle_size = rhs_scale_inv.size + lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype) + rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype) + + # Size cuBLAS workspace -- multiplied by number of comm+GEMM overlap compute streams + workspace_size = get_cublas_workspace_size_bytes() + if comm_overlap.is_enabled: + workspace_size *= comm_overlap.num_max_streams + + # cuBLAS requires workspace pointers aligned to 256 bytes but XLA does not guarantee that + # so we add to the size here and align the pointer in the C++ custom call. + workspace_size += 256 + workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) + + return output, bias_grad, pre_gelu_out, aux_out, lhs_swizzle, rhs_swizzle, workspace + + @staticmethod + def outer_abstract(*args, **kwargs): + outputs = GemmPrimitive.abstract(*args, **kwargs) + return outputs[:-3] # discard workspace arrays + + @staticmethod + def lowering( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + aux_in, + out_dtype, + dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + comm_overlap, + ): + del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype + contracting_dims, _ = dimension_numbers + lhs_aval, _, rhs_aval, *_ = ctx.avals_in + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) + lhs_transposed, rhs_transposed = _get_gemm_layout( + (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) + ) + + args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, aux_in) + kwargs = { + "scaling_mode": int(scaling_mode.value), + "lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), + "rhs_axis_boundary": min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, + "lhs_transposed": lhs_transposed, + "rhs_transposed": rhs_transposed, + "fuse_bias": fuse_bias, + "fuse_gelu": fuse_gelu, + "grad": grad, + "use_split_accumulator": use_split_accumulator, + } + kwargs.update(comm_overlap.get_lowering_kwargs()) + + operand_output_aliases = {} + if fuse_bias and not grad: + operand_output_aliases.update({4: 1}) # bias <-> bias_grad + if fuse_gelu and grad: + operand_output_aliases.update({5: 2}) # gelu_input <-> pre_gelu_out + + return jax.ffi.ffi_lowering( + GemmPrimitive.name, + operand_output_aliases=operand_output_aliases, + )(ctx, *args, **kwargs) + + @staticmethod + def impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + aux_in, + out_dtype, + dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + comm_overlap, + ): + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), dimension_numbers[0]) + lhs_transposed, rhs_transposed = _get_gemm_layout( + (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) + ) + + lhs_scale_inv = apply_padding_to_scale_inv( + lhs_scale_inv, + scaling_mode, + lhs.shape, + is_colwise=lhs_quantized_colwise, + flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), + ) + rhs_scale_inv = apply_padding_to_scale_inv( + rhs_scale_inv, + scaling_mode, + rhs.shape, + is_colwise=rhs_quantized_colwise, + flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, + ) + + outputs = GemmPrimitive.inner_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + aux_in, + out_dtype=out_dtype, + dimension_numbers=dimension_numbers, + lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + comm_overlap=comm_overlap, + ) + return outputs[:-3] # discard workspace arrays + + @staticmethod + def batcher( + batched_args, + batch_dims, + out_dtype, + dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + comm_overlap, + ): + assert GemmPrimitive.outer_primitive is not None + lhs, _, rhs, *_, aux_in_bdims = batched_args + lhs_bdims, _, rhs_bdims, *_ = batch_dims + contracting_dims, batch_dims = dimension_numbers + arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batch_dims) + arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims + assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), ( + "User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch " + f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}." + ) + arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims + assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), ( + "User-specified batch dimension(s) for cuBLAS GEMM RHS operand does not match batch " + f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}." + ) + + # Output is batched like the non-contracting batch dimensions of the LHS operand + lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims) + lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims) + out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims + + # Bias gradient is never batched + bias_bdims = (None,) + + # Pre-GeLU output, if exists, is batched like GEMM output + pre_gelu_bdims = (None,) + if fuse_gelu and not grad: + pre_gelu_bdims = out_bdims + + aux_out_bdims = (None,) + if comm_overlap.is_enabled: + if comm_overlap.is_bulk(): + # Bulk overlap auxiliary output must have the same batch dims as the auxiliary input + aux_out_bdims = aux_in_bdims + elif comm_overlap.is_all_gather() and comm_overlap.output_all_gathered_lhs: + # AG->GEMM overlap with all-gathered LHS output must have same batch dims as + # sharded LHS input + aux_out_bdims = arg_lhs_bdims + + return ( + GemmPrimitive.outer_primitive.bind( + *batched_args, + out_dtype=out_dtype, + dimension_numbers=dimension_numbers, + lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + comm_overlap=comm_overlap, + ), + (out_bdims, bias_bdims, pre_gelu_bdims, aux_out_bdims), + ) + + @staticmethod + def infer_sharding_from_operands( + out_dtype, + dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + comm_overlap, + mesh, + arg_infos, + result_infos, + ): + del ( + out_dtype, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + grad, + ) + del use_split_accumulator, result_infos + + lhs_specs, _, rhs_specs, *_, aux_in_specs = map(get_padded_spec, arg_infos) + (_, (out_specs, bias_grad_specs, pre_gelu_specs, aux_out_specs), *_) = ( + comm_overlap.get_partitioning_rules( + lhs_specs, rhs_specs, aux_in_specs, dimension_numbers + ) + ) + + # Discard bias gradient and pre-GeLU output specs based on fusion choices + if not fuse_bias: + bias_grad_specs = (None,) + if not fuse_gelu: + pre_gelu_specs = (None,) + + # Assemble output shardings + out_shardings = list( + map( + lambda specs: NamedSharding(mesh, PartitionSpec(*specs)), + (out_specs, bias_grad_specs, pre_gelu_specs, aux_out_specs), + ) + ) + + return out_shardings + + @staticmethod + def partition( + out_dtype, + dimension_numbers, + lhs_quantized_colwise, + rhs_quantized_colwise, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + comm_overlap, + mesh, + arg_infos, + result_infos, + ): + del result_infos + + lhs_specs, _, rhs_specs, *_, aux_in_specs = map(get_padded_spec, arg_infos) + ( + (lhs_specs, rhs_specs, bias_specs, gelu_input_specs, aux_in_specs), + (out_specs, bias_grad_specs, pre_gelu_specs, aux_out_specs), + (all_reduce_spec, reduce_scatter_spec, scatter_dim), + ) = comm_overlap.get_partitioning_rules( + lhs_specs, rhs_specs, aux_in_specs, dimension_numbers + ) + + # Block scale inverses match their operands, but tensor scale inverses are unsharded. + lhs_scale_specs = (None,) + rhs_scale_specs = (None,) + if scaling_mode.is_1d_block_scaling() and not comm_overlap.is_enabled: + lhs_scale_specs = lhs_specs + rhs_scale_specs = rhs_specs + + # Discard bias and pre-GeLU specs based on fusion choices + if not fuse_bias: + bias_specs = (None,) + bias_grad_specs = (None,) + if not fuse_gelu: + gelu_input_specs = (None,) + pre_gelu_specs = (None,) + + # Assemble argument shardings + arg_shardings = tuple( + map( + lambda specs: NamedSharding(mesh, PartitionSpec(*specs)), + ( + lhs_specs, + lhs_scale_specs, + rhs_specs, + rhs_scale_specs, + bias_specs, + gelu_input_specs, + aux_in_specs, + ), + ) + ) + + # Assemble output shardings + out_shardings = list( + map( + lambda specs: NamedSharding(mesh, PartitionSpec(*specs)), + (out_specs, bias_grad_specs, pre_gelu_specs, aux_out_specs), + ) + ) + + def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, aux_in): + comm_overlap._set_sharded_impl(True) + outputs = GemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + aux_in, + out_dtype=out_dtype, + dimension_numbers=dimension_numbers, + lhs_quantized_colwise=lhs_quantized_colwise, + rhs_quantized_colwise=rhs_quantized_colwise, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + comm_overlap=comm_overlap, + ) + comm_overlap._set_sharded_impl(False) + + # All-Reduce/Reduce-Scatter GEMM output + if all_reduce_spec is not None: + outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec) + if fuse_gelu and not grad: + outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec) + elif reduce_scatter_spec is not None: + outputs[0] = jax.lax.psum_scatter( + outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True + ) + if fuse_gelu and not grad: + outputs[2] = jax.lax.psum_scatter( + outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True + ) + + return outputs + + return mesh, _sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args, **kwargs): + del args, kwargs + raise NotImplementedError( + "TE cuBLAS GEMM custom op does not support the Shardy partitioner. You can disable the " + 'custom op by setting `NVTE_JAX_CUSTOM_CALLS_RE="^(?!GemmPrimitive$).+$"` in the ' + "environment, which will make GEMM operations in TE will execute with native " + "`jax.lax.dot_general` and `jax.nn.scaled_matmul` calls." + ) + + +register_primitive(GemmPrimitive) + + +def gemm_uses_jax_dot() -> bool: + """Check if the GEMM call directs to the TE custom cuBLAS call or native JAX dot.""" + return not GemmPrimitive.enabled() + + +def _get_scale_inv_without_padding(scaled_tensor): + return remove_padding_from_scale_inv( + scaled_tensor.scale_inv, + scaled_tensor.scaling_mode, + scaled_tensor.data.shape, + is_colwise=scaled_tensor.is_colwise, + flatten_axis=scaled_tensor.flatten_axis, + ) + + +def _te_gemm( + lhs: Union[jax.Array, ScaledTensor], + rhs: Union[jax.Array, ScaledTensor], + bias: jax.Array = None, + gelu_input: jax.Array = None, + aux_in: jax.Array = None, + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, + dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]] = (((-1,), (0,)), ((), ())), + fuse_bias: bool = False, + fuse_gelu: bool = False, + grad: bool = False, + use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, + comm_overlap: CommOverlapHelper = CommOverlapHelper(), +) -> Tuple[jax.Array, ...]: + # Prepare non-quantized GEMM operands + lhs_data = lhs + rhs_data = rhs + lhs_scale_inv = jnp.empty(0, dtype=jnp.float32) + rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) + scaling_mode = ScalingMode.NO_SCALING + contracting_dims, batch_dims = dimension_numbers + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) + lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batch_dims) + + # Quantize operands (if necessary) + lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) + + # Extract GEMM custom op inputs from quantized operands + if isinstance(lhs_q, ScaledTensor): + assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, ( + "cuBLAS GEMM with quantized LHS and non-quantized RHS operands requires a valid " + "`Quantizer` object to quantize the RHS operand." + ) + if isinstance(lhs_q, ScaledTensor2x): + # Choose the quantization of the contracting dimension(s) + lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor() + scaling_mode = lhs_q.scaling_mode + lhs_data = lhs_q.data + lhs_scale_inv = _get_scale_inv_without_padding(lhs_q) + if lhs_q.data_layout == "T": + lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) + lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis) + + if isinstance(rhs_q, ScaledTensor): + assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( + "cuBLAS GEMM with non-quantized LHS and quantized RHS operands requires a valid " + "`Quantizer` object to quantize the LHS operand." + ) + if isinstance(rhs_q, ScaledTensor2x): + # Choose the quantization of the contracting dimension(s) + rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() + assert rhs_q.scaling_mode == lhs_q.scaling_mode, ( + "cuBLAS GEMM quantized operands have mismatched scaling types, " + f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." + ) + rhs_data = rhs_q.data + rhs_scale_inv = _get_scale_inv_without_padding(rhs_q) + if rhs_q.data_layout == "T": + rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) + rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis) + + # Dummy empties for bias, gelu and aux_in + out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype + if bias is None or not (fuse_bias and not grad): + bias = jnp.empty(0, dtype=out_dtype) + if gelu_input is None or not (fuse_gelu and grad): + gelu_input = jnp.empty(0, dtype=out_dtype) + if aux_in is None or not comm_overlap.is_enabled: + aux_in = jnp.empty(0, dtype=jnp.bfloat16) + + return GemmPrimitive.outer_primitive.bind( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + gelu_input, + aux_in, + out_dtype=out_dtype, + dimension_numbers=((lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims)), + lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False, + rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False, + scaling_mode=scaling_mode, + fuse_bias=fuse_bias, + fuse_gelu=fuse_gelu, + grad=grad, + use_split_accumulator=use_split_accumulator, + comm_overlap=comm_overlap, + ) + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM @@ -218,11 +1786,8 @@ def _shape_normalization(x, dimension_numbers, already_transposed: bool = False) def _calculate_remaining_shape(shape, contracting_dims): - return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims) - - -def _transpose_contract_dims(ndim, contracting_dims): - return tuple(ndim - i - 1 for i in contracting_dims)[::-1] + contracting_dims_ = sanitize_dims(len(shape), contracting_dims) + return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims_) # Apply jit to guarantee correctness of FP8 GEMM. @@ -230,9 +1795,11 @@ def _transpose_contract_dims(ndim, contracting_dims): def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums if lhs.data_layout == "T": - lhs_contract = _transpose_contract_dims(lhs.data.ndim, lhs_contract) + lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis) + lhs_batch = transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis) if rhs.data_layout == "T": - rhs_contract = _transpose_contract_dims(rhs.data.ndim, rhs_contract) + rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis) + rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis) dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) @@ -303,12 +1870,12 @@ def _jax_gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - quantizer_set: Dict["str", Quantizer] = noop_quantizer_set, + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, ) -> jnp.ndarray: """ FP8 GEMM via JAX """ - dim_nums = (contracting_dims, ((), ())) def _jax_gemm_fp8_impl(lhs, rhs): @@ -328,37 +1895,16 @@ def _jax_gemm_fp8_impl(lhs, rhs): raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") - if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): - return _jax_gemm_fp8_impl(lhs, rhs) + lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) - if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor): - if quantizer_set != noop_quantizer_set: - assert type(quantizer_set.x) is type(quantizer_set.kernel) - if ( - quantizer_set.x.scaling_mode.is_tensor_scaling() - and is_fp8_gemm_with_all_layouts_supported() - ): - lhs_is_rowwise = rhs_is_rowwise = True - else: - (((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums - lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1 - rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1 - lhs_q = quantizer_set.x.quantize( - lhs, - is_rowwise=lhs_is_rowwise, - is_colwise=not lhs_is_rowwise, - ) - rhs_q = quantizer_set.kernel.quantize( - rhs, - is_rowwise=rhs_is_rowwise, - is_colwise=not rhs_is_rowwise, - ) - return _jax_gemm_fp8_impl(lhs_q, rhs_q) + if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor): + return _jax_gemm_fp8_impl(lhs_q, rhs_q) if ( isinstance(lhs, jnp.ndarray) and isinstance(rhs, jnp.ndarray) - and quantizer_set == noop_quantizer_set + and lhs_quantizer is None + and rhs_quantizer is None ): return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) @@ -368,30 +1914,112 @@ def _jax_gemm_fp8_impl(lhs, rhs): def gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], - contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - quantizer_set: QuantizerSet = noop_quantizer_set, -) -> jnp.ndarray: - """General matrix multiplication with optional quantization. - - Args: - lhs: First input matrix. - rhs: Second input matrix. - contracting_dims: Tuple of two sequences representing the contracting dimensions. - The first sequence represents the contracting dimensions of the first matrix, - and the second sequence represents the contracting dimensions of the second matrix. - quantizer_set: Set of quantizers for FP8 quantization of the output. - If None, no quantization is applied and the output has the same dtype as the inputs. - - Returns: - If quantizer_set is None: - The matrix multiplication result. - Shape: (M, N) - Dtype: Same as input dtype - If quantizer_set is provided: - A ScaledTensor containing the quantized matrix multiplication result. + dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]] = (((-1,), (0,)), ((), ())), + lhs_quantizer: Quantizer = None, + rhs_quantizer: Quantizer = None, + **kwargs, +) -> Tuple[jnp.ndarray, ...]: + r"""General matrix multiplication with optional quantization. + + Parameters + ---------- + lhs: Union[jax.Array, ScaledTensor] + Left-hand side operand in the matrix multiplication. + rhs: Union[jax.Array, ScaledTensor] + Right-hand side operand in the matrix multiplication. + lhs_quantizer: Quantizer, default = None + Object for down-casting the LHS operand for quantized GEMM. + rhs_quantizer: Quantizer, default = None + Object for down-casting the RHS operand for quantized GEMM. + dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]], default = (((-1, ), (0, )), ((), ())) + Tuple of two tuples of sequences representing the contracting and batched dimensions, + respectively. The first sequence in each tuple represents the contracting/batched + dimensions of the LHS operand, and the second sequence represents the contracting/batched + dimensions of the RHS operand. + bias: jax.Array, default = None + Optional additive bias term, required for forward GEMM with bias fusion. Only supported + with TE's custom call to cuBLAS GEMM. + gelu_input: jax.Array, default = None + Pre-GeLU output from forward GEMM, required for backward/grad GEMM with dGeLU fusion. Only + supported with TE's custom call to cuBLAS GEMM. + fuse_bias: bool, default = False + Enable bias addition in forward GEMM or bias gradient in backward GEMM. Only supported with + TE's custom call to cuBLAS GEMM. + fuse_gelu: bool, default = False + Enable GeLU activation in forward GEMM or GeLU gradient in backward GEMM. Only supported + with TE's custom call to cuBLAS GEMM. + grad: bool, default = False + Flag for switching bias and GeLU fusions from forward to backward mode. Only supported with + TE's custom call to cuBLAS GEMM. + use_split_accumulator: bool, default = True + Enable promoting some intermediate sums to higher precision when accumulating the result in + the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. + comm_overlap: CommOverlapHelper, default = None + Helper object that manages comm+GEMM overlap options. + + Returns + ------- + jax.Array: + Result of the operation. For TE's custom call to cuBLAS GEMM, this result can include the + GeLU application when `fuse_gelu=True` and `grad=False`, the GeLU gradient contribution + when `fuse_gelu=True` and `grad=True`, and the additive bias when `fuse_bias=True` and + `grad=False`. + Optional[jax.Array]: + Bias gradient when `fuse_bias=True` and `grad=True`. Only supported with TE's custom call + to cuBLAS GEMM. + Optional[jax.Array]: + Pre-GeLU GEMM output when `fuse_gelu=True` and `grad=False`. This is required as an input + to `_te_gemm()` with `fuse_gelu=True` and `grad=True` in the backward pass in order to + compute the GeLU contribution to the gradient. Only supported with TE's custom call to + cuBLAS GEMM. """ + # Try to get LHS and RHS quantizers from a quantizer set for backward compatibility + if lhs_quantizer is None or rhs_quantizer is None: + quantizer_set = kwargs.get("quantizer_set", None) + if quantizer_set is not None: + lhs_quantizer = quantizer_set.x + rhs_quantizer = quantizer_set.kernel + + # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled + fuse_bias = kwargs.get("fuse_bias", False) + fuse_gelu = kwargs.get("fuse_gelu", False) + if not GemmPrimitive.enabled(): + assert kwargs.get("bias", None) is None and not fuse_gelu, ( + "TE GEMM was invoked with bias fusion options that are not supported by the " + "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " + "GEMM primitive is disabled." + ) + assert kwargs.get("gelu_input", None) is None and not fuse_bias, ( + "TE GEMM was invoked with GeLU fusion options that are not supported by the " + "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " + "GEMM primitive is disabled." + ) + return _jax_gemm(lhs, rhs, dimension_numbers[0], lhs_quantizer, rhs_quantizer) + + outputs = _te_gemm( + lhs, + rhs, + lhs_quantizer=lhs_quantizer, + rhs_quantizer=rhs_quantizer, + dimension_numbers=dimension_numbers, + **kwargs, + ) - return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set) + # Discard empty outputs + grad = kwargs.get("grad", False) + comm_overlap = kwargs.get("comm_overlap", CommOverlapHelper()) + clean_outputs = outputs[0] # first output is the final result and is never empty + if (fuse_bias and grad) or (fuse_gelu and not grad) or comm_overlap.has_aux_output(): + clean_outputs = (outputs[0],) + if fuse_bias and grad: # only return bias gradient if it exists + clean_outputs += (outputs[1],) + if fuse_gelu and not grad: # only return pre-GeLU output if it exists + clean_outputs += (outputs[2],) + if comm_overlap.has_aux_output(): + # only return aux output for bulk overlap or non-bulk all-gather overlap + # with gathered LHS output + clean_outputs += (outputs[3],) + return clean_outputs def grouped_gemm( diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 87f1c1913a..94dfaa45a4 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -198,14 +198,19 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to calculate dbias separately. This function checks if the workaround should be applied. """ + if quantizer is None: + return False + arch_l_100 = False for local_gpu_id in range(len(jax.local_devices())): if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100: arch_l_100 = True break + # _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE, + # but this fails when bias fusion is turned on with arch < 100. + force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() return ( - quantizer is not None - and quantizer.q_layout == QuantizeLayout.ROWWISE + (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE) and arch_l_100 and is_dbias ) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 07ebb33114..bf5c257d7b 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -1276,6 +1276,7 @@ def normalization_fwd( epsilon: float, norm_type: str, quantizer: Optional[Quantizer], + noop_scaled_tensor: bool = False, ): """Common wrapper for normalization forward pass. @@ -1292,6 +1293,7 @@ def normalization_fwd( - 'layernorm': Layer normalization - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. + noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: A tuple containing: @@ -1319,6 +1321,15 @@ def normalization_fwd( else: raise ValueError(f"{norm_type=} is not supported.") + if quantizer is None and noop_scaled_tensor: + return ( + ScaledTensorFactory.create_2x( + output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype + ), + mu, + rsigma, + ) + return output, mu, rsigma diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 11b3cdc2a3..3cb0e1cdfb 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -36,7 +36,6 @@ Quantizer, GroupedQuantizer, QuantizeLayout, - DelayedScaleQuantizer, ScalingMode, compute_scale_from_amax, ) @@ -538,11 +537,12 @@ def _jax_quantize( def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1): - assert flatten_axis < 0 + sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis + assert sum_axis < dx.ndim, "Flatten axis out of bounds!" dtype = dtype or dx.dtype dbias = jnp.sum( dx.astype(jnp.float32), - axis=tuple(range(dx.ndim + flatten_axis)), + axis=tuple(range(sum_axis)), keepdims=False, ) return dbias.astype(dtype) @@ -568,6 +568,7 @@ def _quantize_dbias_impl( is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -577,24 +578,34 @@ def _quantize_dbias_impl( quantizer is not None ), "quantizer must be provided if dq_dtype is provided" + # Early-exit for non-quantized call dq_dtype = dq_dtype or x.dtype - - PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive - if not PrimitiveClass.enabled(): + if quantizer is None: + dbias = None if is_dbias: - return _jax_quantize_dbias( - x, - quantizer=quantizer, - dq_dtype=dq_dtype, - flatten_axis=flatten_axis, + dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) + if noop_scaled_tensor: + # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor() + # always works. + return ( + ScaledTensorFactory.create_2x( + x, + None, + x, + None, + ScalingMode.NO_SCALING, + dq_dtype=x.dtype, + data_layout="NN", + flatten_axis=flatten_axis, + ), + dbias, ) - return ( - _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), - None, - ) + return x, dbias - # TE/common doesn't support colwise only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: + # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, + # fall back on the native-JAX quantize implementation + PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive + if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled(): if is_dbias: return _jax_quantize_dbias( x, @@ -606,9 +617,8 @@ def _quantize_dbias_impl( _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), None, ) - scale = jnp.empty((), jnp.float32) - # TE/common dbias_quantize does not support 1x on arch < 100 + # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100 if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): out, _ = _quantize_dbias_impl( x=x, @@ -620,29 +630,23 @@ def _quantize_dbias_impl( dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias - if quantizer is None: - if is_dbias: - return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) - return x, None - + scale = jnp.empty((), jnp.float32) if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Globally reduce amax across all devices for current scaling so we have a single global scale. # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this # until the tensor is dequantized (e.g. in the GEMM). amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32) scale = compute_scale_from_amax(amax, quantizer.q_dtype) - - if isinstance(quantizer, DelayedScaleQuantizer): + elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale - is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) # It is faster to use 1x quantization for tensor scaling + is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) force_1x_quantization = ( quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() and is_1x_kernel_supported ) - q_layout = quantizer.q_layout if force_1x_quantization: q_layout = QuantizeLayout.ROWWISE @@ -666,7 +670,7 @@ def _quantize_dbias_impl( is_outer=True, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise - if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): + if force_1x_quantization: colwise_scale_inv = rowwise_scale_inv if q_layout == QuantizeLayout.ROWWISE: @@ -698,6 +702,7 @@ def quantize( x: jnp.ndarray, quantizer: Quantizer, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -707,6 +712,8 @@ def quantize( quantizer: Quantizer for FP8 quantization of the output. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. + noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer + is None. Returns: A ScaledTensor containing the quantized input tensor. @@ -715,6 +722,7 @@ def quantize( x, quantizer=quantizer, flatten_axis=flatten_axis, + noop_scaled_tensor=noop_scaled_tensor, ) return out @@ -724,6 +732,7 @@ def quantize_dbias( quantizer: Quantizer, is_dbias: bool = True, flatten_axis: int = -1, + noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -734,6 +743,8 @@ def quantize_dbias( is_dbias: If True, compute bias gradient. Defaults to True. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. + noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when + quantizer is None. Returns: A tuple containing: @@ -743,7 +754,11 @@ def quantize_dbias( Shape: (K,) or empty if is_dbias is False. """ return _quantize_dbias_impl( - dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis + dz, + quantizer=quantizer, + is_dbias=is_dbias, + flatten_axis=flatten_axis, + noop_scaled_tensor=noop_scaled_tensor, ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 0789478348..6432eb5f77 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -33,6 +34,8 @@ #include "transformer_engine/multi_stream.h" // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::CommOverlapMethod); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::CommOverlapType); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); namespace transformer_engine { @@ -119,6 +122,21 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right); +// GEMM +XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); + +int64_t CreateCommOverlapBuffer(CommOverlapType comm_type, CommOverlapMethod method, + const std::vector &buffer_shape, DType buffer_dtype, + int tp_size, int num_splits = 3, int num_max_streams = 3, + int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, + int num_comm_sm = 16, int set_sm_margin = false, bool use_ce = true, + bool atomic_gemm = false, bool rs_overlap_first_gemm = false, + bool aggregate_ag = false); + +void DestroyCommOverlapBuffer(size_t unique_id); + +void DestroyAllCommOverlapBuffers(); + // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); diff --git a/transformer_engine/jax/csrc/extensions/ffi.cpp b/transformer_engine/jax/csrc/extensions/ffi.cpp index a760df4a79..e77c38e990 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.cpp +++ b/transformer_engine/jax/csrc/extensions/ffi.cpp @@ -38,12 +38,11 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { case xla::ffi::DataType::F8E4M3FN: return DType::kFloat8E4M3; break; - // case xla::ffi::DataType::F8E8M0FNU: - // return DType::kFloat8E8M0; - // break; + case xla::ffi::DataType::F8E8M0FNU: + return DType::kFloat8E8M0; + break; default: auto type_num = static_cast(type); - if (type_num == 33) return DType::kFloat8E8M0; NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d", static_cast(type_num)); break; diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index c03f7f7751..35f84af543 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -6,11 +6,13 @@ #include "transformer_engine/gemm.h" #include +#include +#include #include "../extensions.h" #include "common/util/cuda_runtime.h" +#include "common/util/string.h" #include "common/util/system.h" -#include "transformer_engine/multi_stream.h" #include "transformer_engine/swizzle.h" #include "xla/ffi/api/c_api.h" @@ -25,6 +27,323 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { ~static_cast(255)); } +std::tuple> xla_buffer_to_nvte_gemm_operand( + cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv, + JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) { + // Set tensor data with collapsed 2D shape + auto buffer_dims = buffer.dimensions(); + std::vector input_shape = {product(buffer_dims, 0, axis_boundary), + product(buffer_dims, axis_boundary, buffer_dims.size())}; + auto input_dtype = convert_ffi_datatype_to_te_dtype(buffer.element_type()); + TensorWrapper input(get_nvte_scaling_mode(scaling_mode)); + + if (rowwise) { + input.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); + } else { + input.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); + } + + // Set scaling factor for quantized tensors + if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { + NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands."); + NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM."); + + std::vector scale_shape = {1}; + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + // Block scaling also needs to be collapsed to match 2D data + scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary), + product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())}; + } + + auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + if (rowwise) { + input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } else { + input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } + + // Swizzle scaling factors for MXFP8 + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + // Get the swizzle buffer + NVTE_CHECK(swizzled_scale_inv->element_count() > 0, + "Missing swizzled inverse scale buffer in the JAX primitive."); + auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + auto swizzled_scale_inv_dtype = + convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type()); + NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1, + "Inverse scale factors need to have an 8-bit data type."); + + // Create tensor to hold swizzled scale factor + TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); + if (rowwise) { + output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); + } else { + output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, + scale_shape); + } + + // Launch swizzle kernel + nvte_swizzle_scaling_factors(input.data(), output.data(), stream); + + // Set swizzled scales into the input tensor + if (rowwise) { + input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); + } else { + input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, + scale_shape); + } + } + } + + return std::make_tuple(std::move(input), input_shape); +} + +static std::unordered_map comm_overlaps; + +int64_t CreateCommOverlapBuffer(CommOverlapType comm_type, CommOverlapMethod method, + const std::vector &buffer_shape, DType buffer_dtype, + int tp_size, int num_splits, int num_max_streams, int comm_cga_size, + int gemm_priority, int comm_priority, int num_comm_sm, + int set_sm_margin, bool use_ce, bool atomic_gemm, + bool rs_overlap_first_gemm, bool aggregate_ag) { + int64_t unique_id = 0; + hash_combine(unique_id, static_cast(comm_type), static_cast(method), buffer_shape[0], + buffer_shape[0], static_cast(buffer_dtype), tp_size, num_splits, + num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, + set_sm_margin, use_ce, atomic_gemm, rs_overlap_first_gemm, aggregate_ag); + + auto it = comm_overlaps.find(unique_id); + if (it == comm_overlaps.end()) { + if (method == CommOverlapMethod::RING_EXCHANGE) { + comm_overlaps[unique_id] = reinterpret_cast( + new CommOverlapP2PBase(buffer_shape, buffer_dtype, tp_size, comm_type, num_max_streams, + comm_cga_size, gemm_priority, comm_priority, num_comm_sm, + set_sm_margin, use_ce, atomic_gemm, aggregate_ag)); + } else { + comm_overlaps[unique_id] = reinterpret_cast( + new CommOverlapBase(buffer_shape, buffer_dtype, tp_size, num_splits, num_max_streams, + comm_cga_size, gemm_priority, comm_priority, num_comm_sm, + set_sm_margin, atomic_gemm, rs_overlap_first_gemm)); + } + } + + return unique_id; +} + +void DestroyCommOverlapBuffer(size_t unique_id) { + auto it = comm_overlaps.find(unique_id); + if (it != comm_overlaps.end()) { + delete it->second; + comm_overlaps.erase(it); + } +} + +void DestroyAllCommOverlapBuffers() { + for (auto it = comm_overlaps.begin(); it != comm_overlaps.end();) { + delete it->second; + it = comm_overlaps.erase(it); + } +} + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type aux_in, Result_Type output, Result_Type bias_grad, + Result_Type pre_gelu_out, Result_Type aux_out, Result_Type lhs_swizzle, + Result_Type rhs_swizzle, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, + CommOverlapMethod comm_overlap_method, CommOverlapType comm_type, + int64_t comm_overlap_id, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, + int64_t aux_axis_boundary, bool lhs_transposed, bool rhs_transposed, + bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { + // Operands (this includes swizzling MXFP8 scaling factors) + // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when + // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) + bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); + bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; + bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed; + auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( + stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise); + auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand( + stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); + + // Output tensor -- create with nullptr for GEMM->RS overlap because GEMM output goes into + // the communication buffer. We can use the XLA output buffer for the reduce-scattered + // auxiliary output tensor later. + std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], + (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; + auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + void *out_ptr = + (comm_type == CommOverlapType::RS && comm_overlap_method != CommOverlapMethod::BULK) + ? comm_overlaps[comm_overlap_id]->get_ubuf_dptr() + : output->untyped_data(); + auto out_ = TensorWrapper(out_ptr, out_shape, out_dtype); + + // Bias input to forward pass or bias gradient output from backward pass + void *bias_ptr = nullptr; + std::vector bias_shape = {0}; + DType bias_dtype = out_dtype; + if (fuse_bias) { + if (!grad) { + NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(), + "Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad"); + } + bias_ptr = bias_grad->untyped_data(); + bias_shape.at(0) = bias_grad->dimensions().front(); + bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type()); + } + auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + + // Pre-GeLU output from forward pass or input to backward pass + void *pre_gelu_ptr = nullptr; + std::vector pre_gelu_shape = {0}; + DType pre_gelu_dtype = out_dtype; + if (gelu_input.element_count() > 0) { + if (grad) { + NVTE_CHECK(pre_gelu_out->untyped_data() == gelu_input.untyped_data(), + "Missing operand-output aliasing in GemmPrimitive: gelu_input <-> pre_gelu_out"); + } + pre_gelu_ptr = pre_gelu_out->untyped_data(); + pre_gelu_shape = {product(pre_gelu_out->dimensions(), 0, pre_gelu_out->dimensions().size() - 1), + static_cast(pre_gelu_out->dimensions().back())}; + pre_gelu_dtype = convert_ffi_datatype_to_te_dtype(pre_gelu_out->element_type()); + } + auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype); + + // cuBLAS workspace + 256 alignment enforcement + auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); + workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); + std::vector workspace_shape = {static_cast(workspace->element_count()) - 256}; + auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte); + + // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + if (comm_type == CommOverlapType::NONE) { + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ", + to_string_like(out_shape), " but got ", output->element_count(), " elements ", + to_string_like(output->dimensions())); + + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), + rhs_transposed, lhs_transposed, grad, workspace_.data(), false, + use_split_accumulator, num_math_sm, stream); + } else { + auto executor = comm_overlaps[comm_overlap_id]; + auto tp_size = executor->get_tp_size(); + if (comm_overlap_method == CommOverlapMethod::BULK) { + // Prepare the auxiliary output tensor + auto aux_out_dims = aux_out->dimensions(); + std::vector aux_out_shape = {0}; + auto aux_out_dtype = convert_ffi_datatype_to_te_dtype(aux_out->element_type()); + if ((comm_type == CommOverlapType::AG && aux_out->element_count() > 0) || + comm_type == CommOverlapType::RS) { + std::vector aux_out_shape = { + product(aux_out_dims, 0, aux_axis_boundary), + product(aux_out_dims, aux_axis_boundary, aux_out_dims.size())}; + } + auto aux_out_ = TensorWrapper(aux_out->untyped_data(), aux_out_shape, aux_out_dtype); + + // Copy the auxiliary data into the communications buffer + auto aux_in_dims = aux_in.dimensions(); + std::vector aux_in_shape = { + product(aux_in_dims, 0, aux_axis_boundary), + product(aux_in_dims, aux_axis_boundary, aux_in_dims.size())}; + auto aux_in_dtype = convert_ffi_datatype_to_te_dtype(aux_in.element_type()); + auto aux_in_ = TensorWrapper(aux_in.untyped_data(), aux_in_shape, aux_in_dtype); + if (comm_type == CommOverlapType::AG && aux_out->element_count() > 0) { + NVTE_CHECK(aux_in_shape[0] == tp_size * aux_out_shape[0], + "cuBLAS GEMM w/ bulk AG overlap auxiliary output is sized incorrectly, ", + "expected (", aux_in_shape[0] / tp_size, ",", aux_in_shape[1], ") but got ", + to_string_like(aux_out_dims)); + } else if (comm_type == CommOverlapType::RS) { + NVTE_CHECK(tp_size * aux_in_shape[0] == aux_out_shape[0], + "cuBLAS GEMM w/ bulk RS overlap auxiliary output is sized incorrectly, ", + "expected (", aux_in_shape[0] * tp_size, ",", aux_in_shape[1], ") but got ", + to_string_like(aux_out_dims)); + } + executor->copy_into_buffer(stream, aux_in_, (comm_type == CommOverlapType::AG)); + + // Launch GEMM w/ bulk overlap + executor->bulk_overlap(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, + workspace_, grad, false, use_split_accumulator, comm_type, aux_out_, + stream); + } else if (comm_type == CommOverlapType::RS) { + // Prepare the auxiliary buffer for the reduce-scattered GEMM output + auto rs_out_shape = std::vector(out_shape); + rs_out_shape.at(0) /= tp_size; + auto rs_out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + auto rs_out_ = TensorWrapper(output->untyped_data(), rs_out_shape, rs_out_dtype); + NVTE_CHECK(rs_out_.numel() == output->element_count(), + "cuBLAS GEMM->RS overlap output buffer is sized incorrectly, expected ", + rs_out_.numel(), " elements ", to_string_like(rs_out_shape), " but got ", + output->element_count(), " elements ", to_string_like(output->dimensions())); + + // Launch GEMM+RS + executor->split_overlap_rs(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, + workspace_, grad, false, use_split_accumulator, rs_out_, stream); + } else if (comm_type == CommOverlapType::AG) { + // Prepare the auxiliary buffer for all-gathered LHS + std::vector aux_out_shape = {0}; + auto aux_out_dtype = convert_ffi_datatype_to_te_dtype(aux_out->element_type()); + if (aux_out->element_count() > 0) { + aux_out_shape = std::vector(lhs_shape); + aux_out_shape.at(0) *= tp_size; + auto aux_out_numel = aux_out_shape[0] * aux_out_shape[1]; + NVTE_CHECK(aux_out_numel == aux_out->element_count(), + "cuBLAS AG->GEMM overlap auxiliary buffer is sized incorrectly, expected ", + aux_out_numel, " elements ", to_string_like(aux_out_shape), " but got ", + aux_out->element_count(), " elements ", to_string_like(aux_out->dimensions())); + } + auto aux_out_ = TensorWrapper(aux_out->untyped_data(), aux_out_shape, aux_out_dtype); + + // Copy the distributed LHS operand into the local chunk of the communication buffer + executor->copy_into_buffer(stream, lhs_, true, make_lhs_rowwise); + + // Launch AG+GEMM + executor->split_overlap_ag(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, + workspace_, grad, false, use_split_accumulator, aux_out_, stream); + } else { + NVTE_ERROR("cuBLAS GEMM w/ comm. overlap invoked with invalid collective type (", + static_cast(comm_type), ")"); + } + } + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Arg() // aux_in + .Ret() // output + .Ret() // bias_grad + .Ret() // pre_gelu_out + .Ret() // aux_out + .Ret() // lhs_swizzled + .Ret() // rhs_swizzled + .Ret() // workspace + .Attr("scaling_mode") + .Attr("comm_overlap_method") + .Attr("comm_type") + .Attr("comm_overlap_id") + .Attr("lhs_axis_boundary") + .Attr("rhs_axis_boundary") + .Attr("aux_axis_boundary") + .Attr("lhs_transposed") + .Attr("rhs_transposed") + .Attr("fuse_bias") + .Attr("fuse_gelu") + .Attr("grad") + .Attr("use_split_accumulator"), + FFI_CudaGraph_Traits); + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 03194e9d72..4578a09391 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -47,6 +47,15 @@ enum class JAXX_Scaling_Mode : int64_t { CURRENT_TENSOR_SCALING = 3, }; +inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) { + return (mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING || + mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING); +} + +inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) { + return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING); +} + static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { switch (mode) { case JAXX_Scaling_Mode::NO_SCALING: @@ -78,5 +87,11 @@ constexpr struct Alignment { std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise); +template +void hash_combine(int64_t &seed, const T &v, Rest... rest) { + seed ^= std::hash{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + (hash_combine(seed, rest), ...); +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 2d7801cc20..29516a5e54 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -5,7 +5,7 @@ ************************************************************************/ #include "../extensions.h" - +#include "common/util/pybind_helper.h" namespace transformer_engine { namespace jax { @@ -55,6 +55,11 @@ pybind11::dict Registrations() { pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler)); + // GEMM + dict["te_gemm_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); + // Grouped GEMM dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), @@ -63,97 +68,55 @@ pybind11::dict Registrations() { return dict; } +} // namespace jax +} // namespace transformer_engine + PYBIND11_MODULE(transformer_engine_jax, m) { - m.def("registrations", &Registrations); - m.def("get_fused_attn_backend", &GetFusedAttnBackend); - m.def("get_cuda_version", &GetCudaRuntimeVersion); - m.def("get_cudnn_version", &GetCudnnRuntimeVersion); - m.def("get_device_compute_capability", &GetDeviceComputeCapability); - m.def("get_num_compute_streams", &nvte_get_num_compute_streams); + NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) + + m.def("registrations", &transformer_engine::jax::Registrations); + m.def("get_fused_attn_backend", &transformer_engine::jax::GetFusedAttnBackend); + m.def("get_cuda_version", &transformer_engine::jax::GetCudaRuntimeVersion); + m.def("get_cudnn_version", &transformer_engine::jax::GetCudnnRuntimeVersion); + m.def("get_device_compute_capability", &transformer_engine::jax::GetDeviceComputeCapability); m.def("get_cublasLt_version", &cublasLtGetVersion); - m.def("get_dact_dbias_quantize_workspace_sizes", &GetDActDBiasQuantizeWorkspaceSizes); - m.def("get_dbias_quantize_workspace_sizes", &GetDBiasQuantizeWorkspaceSizes); - m.def("get_norm_fwd_workspace_sizes", &GetNormForwardWorkspaceSizes); - m.def("get_norm_bwd_workspace_sizes", &GetNormBackwardWorkspaceSizes); - m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); - m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); - m.def("nvte_get_qkv_format", &nvte_get_qkv_format); - - pybind11::enum_(m, "DType", pybind11::module_local()) - .value("kByte", DType::kByte) - .value("kInt32", DType::kInt32) - .value("kInt64", DType::kInt64) - .value("kFloat32", DType::kFloat32) - .value("kFloat16", DType::kFloat16) - .value("kBFloat16", DType::kBFloat16) - .value("kFloat8E4M3", DType::kFloat8E4M3) - .value("kFloat8E5M2", DType::kFloat8E5M2); - - pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - - pybind11::enum_(m, "NVTE_Mask_Type", pybind11::module_local()) - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - - pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); - - pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) - .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) - .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) - .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD); - - pybind11::enum_(m, "NVTE_Activation_Type", pybind11::module_local()) - .value("GELU", NVTE_Activation_Type::GELU) - .value("GEGLU", NVTE_Activation_Type::GEGLU) - .value("SILU", NVTE_Activation_Type::SILU) - .value("SWIGLU", NVTE_Activation_Type::SWIGLU) - .value("RELU", NVTE_Activation_Type::RELU) - .value("REGLU", NVTE_Activation_Type::REGLU) - .value("QGELU", NVTE_Activation_Type::QGELU) - .value("QGEGLU", NVTE_Activation_Type::QGEGLU) - .value("SRELU", NVTE_Activation_Type::SRELU) - .value("SREGLU", NVTE_Activation_Type::SREGLU) - .export_values(); - - pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8); - - pybind11::enum_(m, "NVTE_Norm_Type", pybind11::module_local()) - .value("LayerNorm", NVTE_Norm_Type::LayerNorm) - .value("RMSNorm", NVTE_Norm_Type::RMSNorm) - .export_values(); - - pybind11::enum_(m, "JAXX_Scaling_Mode", pybind11::module_local()) - .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING) - .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) - .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING) - .value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) - .export_values(); + m.def("get_dact_dbias_quantize_workspace_sizes", + &transformer_engine::jax::GetDActDBiasQuantizeWorkspaceSizes); + m.def("get_dbias_quantize_workspace_sizes", + &transformer_engine::jax::GetDBiasQuantizeWorkspaceSizes); + m.def("get_norm_fwd_workspace_sizes", &transformer_engine::jax::GetNormForwardWorkspaceSizes); + m.def("get_norm_bwd_workspace_sizes", &transformer_engine::jax::GetNormBackwardWorkspaceSizes); + m.def("get_fused_attn_fwd_workspace_sizes", + &transformer_engine::jax::GetFusedAttnForwardWorkspaceSizes); + m.def("get_fused_attn_bwd_workspace_sizes", + &transformer_engine::jax::GetFusedAttnBackwardWorkspaceSizes); + m.def("create_comm_overlap_buffer", &transformer_engine::jax::CreateCommOverlapBuffer, + pybind11::arg("comm_type"), pybind11::arg("method"), pybind11::arg("buffer_shape"), + pybind11::arg("buffer_dtype"), pybind11::arg("tp_size"), pybind11::pos_only(), + pybind11::kw_only(), pybind11::arg("num_splits") = 4, pybind11::arg("num_max_streams") = 3, + pybind11::arg("comm_cga_size") = 2, pybind11::arg("gemm_priority") = 0, + pybind11::arg("comm_priority") = 0, pybind11::arg("num_comm_sm") = 16, + pybind11::arg("set_sm_margin") = true, pybind11::arg("use_ce") = true, + pybind11::arg("atomic_gemm") = false, pybind11::arg("rs_overlap_first_gemm") = false, + pybind11::arg("aggregate_ag") = false, + pybind11::call_guard()); + m.def("destroy_comm_overlap_buffer", &transformer_engine::jax::DestroyCommOverlapBuffer, + pybind11::call_guard()); + m.def("destroy_all_comm_overlap_buffers", &transformer_engine::jax::DestroyAllCommOverlapBuffers, + pybind11::call_guard()); + + pybind11::enum_(m, "JAXX_Scaling_Mode", + pybind11::module_local()) + .value("NO_SCALING", transformer_engine::jax::JAXX_Scaling_Mode::NO_SCALING) + .value("DELAYED_TENSOR_SCALING", + transformer_engine::jax::JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) + .value("MXFP8_1D_SCALING", transformer_engine::jax::JAXX_Scaling_Mode::MXFP8_1D_SCALING) + .value("CURRENT_TENSOR_SCALING", + transformer_engine::jax::JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING); pybind11::enum_(m, "QuantizeLayout", pybind11::module_local()) .value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE) .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE) - .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE) - .export_values(); + .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE); } - -} // namespace jax -} // namespace transformer_engine diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 57170e85be..a1e149abe8 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -30,6 +30,8 @@ def dense( contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, + batch_first: bool = True, + comm_overlaps: tex.CommOverlapHelperSet = tex.CommOverlapHelperSet(), quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. @@ -43,25 +45,47 @@ def dense( kernel: Weight matrix for the dense layer transformation bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract + batch_first: Assume that X is batched in the first dimension. + comm_overlaps: A set of CommOverlapHelper objecst for FPROP, DGRAD and WGRAD GEMMs. quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor """ # Remove when tex.quantize() can handle quantizer=None - if quantizer_set == noop_quantizer_set: + if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot(): x = with_sharding_constraint_by_logical_axes(x, input_axes) - output = tex.gemm(x, kernel, contracting_dims) + output = tex.gemm(x, kernel, dimension_numbers=(contracting_dims, ((), ()))) if bias is not None: bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) else: - output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set) + output = _dense( + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + batch_first, + comm_overlaps, + quantizer_set, + ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) -def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7)) +def _dense( + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + batch_first, + comm_overlaps, + quantizer_set, +): """Internal implementation of dense layer transformation with custom VJP. This function implements the core dense layer transformation logic with support @@ -74,45 +98,93 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer contracting_dims: Contracting dimensions specification input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix + batch_first: Assume that X is batched in the first dimension. + comm_overlaps: A set of CommOverlapHelper objecst for FPROP, DGRAD and WGRAD GEMMs. quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor """ output, _ = _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + batch_first, + comm_overlaps, + quantizer_set, ) return output -def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): +def _dense_fwd_rule( + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + batch_first, + comm_overlaps, + quantizer_set, +): """Forward pass rule for dense layer transformation. Returns: Tuple of (output, context) for backward pass """ - x_contracting_dims, k_contracting_dims = contracting_dims + x_contracting_dims, k_contracting_dims = map( + tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims + ) + + # Check supported input layout + x_is_transposed = x.ndim - 1 not in x_contracting_dims + k_is_transposed = kernel.ndim - 1 in k_contracting_dims + assert ( + not x_is_transposed and not k_is_transposed + ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel." + + # Determine X batch dimension + # - If `batch_first=True` -> (batch, leading..., contracting...) + # - Otherwise -> (leading..., batch, contracting...) + # NOTE: Always assume a single batch dimension + x_bdim = None + num_cdims = len(x_contracting_dims) + if x.ndim >= num_cdims + 2: + # Assume X is batched if it has at least +2 dimensions more than the number of contracting + # dimensions. + x_bdim = 0 if batch_first else x.ndim - num_cdims - 1 flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) - casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x) + casted_x = tex.quantize( + x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True + ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) casted_kernel = tex.quantize( - kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel + kernel, + flatten_axis=flatten_axis_k, + quantizer=quantizer_set.kernel, + noop_scaled_tensor=True, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) # GEMM NN + use_bias = bias is not None output = tex.gemm( casted_x.get_tensor(usage=TensorUsage.LHS), casted_kernel.get_tensor(usage=TensorUsage.RHS), - (x_contracting_dims, k_contracting_dims), + dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), + bias=bias if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + comm_overlap=comm_overlaps.fprop, ) - use_bias = bias is not None - if use_bias: + if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) @@ -124,20 +196,19 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, use_bias, quantizer_set, flatten_axis_k, + x_bdim, ) return output, ctx def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, ctx, grad + contracting_dims, input_axes, kernel_axes, batch_first, comm_overlaps, ctx, grad ): # pylint: disable=unused-argument """Backward pass rule for dense layer transformation. Returns: Tuple of gradients with respect to inputs """ - fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims - ( casted_x_lhs, casted_kernel_rhs, @@ -146,12 +217,42 @@ def _dense_bwd_rule( use_bias, quantizer_set, flatten_axis_k, + x_bdim, ) = ctx + fwd_x_contracting_dims, fwd_k_contracting_dims = map( + tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims + ) + casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad + grad, + is_dbias=use_bias, + flatten_axis=flatten_axis_k, + quantizer=quantizer_set.dgrad, + noop_scaled_tensor=True, + ) + casted_grad = with_sharding_constraint_by_logical_axes( + casted_grad, + comm_overlaps.fprop.get_logical_output_axes( + input_axes, kernel_axes, (contracting_dims, ((x_bdim,), ())) + ), ) + # If casted_x has transposed data-layout, we need to untranspose it here, and then transpose + # it back after the bulk-AG. This should ideally never be necessary if the data layouts are + # handled correctly in the tensor usages. + dgrad_aux_in = None + dgrad_aux_transposed_axes = ( + *tuple(range(casted_x_lhs.flatten_axis, casted_x_lhs.ndim)), + *tuple(range(casted_x_lhs.flatten_axis)), + ) + if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_all_gathered_lhs: + dgrad_aux_in = ( + casted_x_lhs.data.transpose(dgrad_aux_transposed_axes) + if casted_x_lhs.data_layout == "T" + else casted_x_lhs.data + ) + # GEMM NT # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim g_contracting_dim = tuple( @@ -164,9 +265,10 @@ def _dense_bwd_rule( dgrad = tex.gemm( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, - (g_contracting_dim, k_contracting_dim), + dimension_numbers=((g_contracting_dim, k_contracting_dim), ((x_bdim,), ())), + comm_overlap=comm_overlaps.dgrad, + aux_in=dgrad_aux_in, ) - dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) # GEMM TN # x_non_contracting_dims @@ -174,13 +276,42 @@ def _dense_bwd_rule( range(0, len(x_shape) - len(fwd_x_contracting_dims)) ) + casted_grad_rhs = casted_grad.get_tensor(usage=TensorUsage.RHS) + if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_all_gathered_lhs: + # LHS was bulk all-gathered during DGRAD and returned as auxiliary input + casted_x_lhs.data = ( + dgrad[-1].transpose(dgrad_aux_transposed_axes) + if casted_x_lhs.data_layout == "T" + else dgrad[-1] + ) + # DGRAD output will need to be bulk reduce-scattered during WGRAD + dgrad = dgrad[0] + elif comm_overlaps.dgrad.is_all_gather() and comm_overlaps.dgrad.output_all_gathered_lhs: + # GRAD was all-gathered for DGRAD and a copy of the gathered GRAD is in the auxiliary output + casted_grad_rhs.data = ( + dgrad[-1].transpose( + *range(casted_grad_rhs.flatten_axis, casted_grad_rhs.ndim), + *range(casted_grad_rhs.flatten_axis), + ) + if casted_grad_rhs.data_layout == "T" + else dgrad[-1] + ) + dgrad = dgrad[0] + wgrad = tex.gemm( casted_x_lhs, - casted_grad.get_tensor(usage=TensorUsage.RHS), - (x_contracting_dim, g_contracting_dim), + casted_grad_rhs, + dimension_numbers=((x_contracting_dim, g_contracting_dim), ((x_bdim,), (x_bdim,))), + comm_overlap=comm_overlaps.wgrad, + aux_in=(dgrad if comm_overlaps.wgrad.is_bulk() else None), ) - wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) + if comm_overlaps.wgrad.is_bulk(): + # DGRAD was bulk reduce-scattered during WGRAD and returned as auxiliary output + dgrad = wgrad[-1] + wgrad = wgrad[0] + dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) + wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) return dgrad, wgrad, dbias, quantizer_set diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index bd311472f0..4aa0c75c25 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -7,6 +7,7 @@ from functools import reduce import operator from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union +from dataclasses import field import numpy as np import jax.numpy as jnp @@ -29,18 +30,24 @@ jax_scaled_softmax, jax_scaled_masked_softmax, jax_scaled_upper_triang_masked_softmax, + CommOverlapHelper, + CommOverlapHelperSet, ) from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode -from ..sharding import get_non_contracting_logical_axes +from ..sharding import ( + get_non_contracting_logical_axes, + global_mesh_resource, +) + +import transformer_engine_jax as tex PRNGKey = Any Shape = Tuple[int, ...] -DType = jnp.dtype -Array = jnp.ndarray +jnp.dtype = jnp.dtype PrecisionLike = Union[ None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] ] -Initializer = Callable[[PRNGKey, Shape, DType], Array] +Initializer = Callable[[PRNGKey, Shape, jnp.dtype], jnp.ndarray] def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: @@ -108,7 +115,7 @@ def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Calla raise ValueError(f"don't know how to convert {fn_or_string} to an activation function") -def _combine_biases(*masks: List[Array]): +def _combine_biases(*masks: List[jnp.ndarray]): """Combine attention biases.""" masks = [m for m in masks if m is not None] if not masks: @@ -149,6 +156,46 @@ def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, return output +def _generate_comm_overlap_meta( + input_shape: Sequence[int], + input_axes: Sequence[str], + param_shape: Sequence[int], + param_axes: Sequence[str], + config: dict, +): + method = config.pop("method", tex.CommOverlapMethod.RING_EXCHANGE) + if method == tex.CommOverlapMethod.NONE: + return CommOverlapHelperSet() + + tp_resource = config.pop("tp_resource", global_mesh_resource().tp_resource) + + input_sp_dim = list(nn.logical_to_mesh_axes(input_axes)).index(tp_resource) + logical_sp_axis = config.pop("logical_sp_axis", input_axes[input_sp_dim]) + + param_tp_dim = list(nn.logical_to_mesh_axes(param_axes)).index(tp_resource) + logical_tp_axis = config.pop("logical_tp_axis", param_axes[param_tp_dim]) + + row_parallel = param_tp_dim == 0 + comm_type = tex.CommOverlapType.RS if row_parallel else tex.CommOverlapType.AG + _ = config.pop("comm_type") + + buffer_shape = config.pop( + "buffer_shape", (*input_shape[:-1], param_shape[-1]) if row_parallel else input_shape + ) + + return CommOverlapHelperSet( + fprop=CommOverlapHelper( + comm_type=comm_type, + method=method, + buffer_shape=buffer_shape, + tp_resource=tp_resource, + logical_tp_axis=logical_tp_axis, + logical_sp_axis=logical_sp_axis, + **config, + ) + ) + + class Softmax(nn.Module): # pylint: disable=too-few-public-methods r""" Applies softmax over a mini-batch of inputs. @@ -172,7 +219,9 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods softmax_type: SoftmaxType = SoftmaxType.SCALED @nn.compact - def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp.ndarray: + def __call__( + self, inputs: jnp.ndarray, mask: jnp.ndarray = None, bias: jnp.ndarray = None + ) -> jnp.ndarray: batch = inputs.shape[0] heads = inputs.shape[1] q_seqlen = inputs.shape[2] @@ -287,7 +336,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods scale_axes: Tuple[str, ...] = ("embed",) bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] = ("embed",) - dtype: DType = jnp.float32 + dtype: jnp.dtype = jnp.float32 transpose_batch_sequence: bool = False def __post_init__(self): @@ -415,12 +464,17 @@ class DenseGeneral(TransformerEngineBase): Indicate the logical axes of sharding constraint to the input, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint. + enable_comm_overlap: bool, default = False + Enable fine-grained All-Gather or Reduce-Scatter overlap with GEMM for sequence-parallel + inputs. + comm_overlap_config: dict, default = {} + Optional config dictionary for controlling communication overlap options. Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. - transpose_batch_sequence : bool, default = True + transpose_batch_sequence : bool, default = False Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). @@ -436,9 +490,11 @@ class DenseGeneral(TransformerEngineBase): low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 - dtype: DType = jnp.float32 + dtype: jnp.dtype = jnp.float32 transpose_batch_sequence: bool = False input_axes: Tuple[str, ...] = () + enable_comm_overlap: bool = False + comm_overlap_config: dict = field(default_factory=dict) # pylint: disable=invalid-field-call def __post_init__(self): if self.kernel_init is None: @@ -448,7 +504,7 @@ def __post_init__(self): super().__post_init__() @nn.compact - def __call__(self, inputs: Array) -> Array: + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: """ Apply the dense layer transformation to the input. @@ -476,9 +532,15 @@ def __call__(self, inputs: Array) -> Array: "Expected len(kernel_shape) to match len(kernel_axes)," f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}" ) + else: + assert not self.enable_comm_overlap, ( + "Communication + GEMM overlap requires the dot kernel sharding to be defined in " + "`kernel_axes`." + ) + kernel_partitioning = nn.with_logical_partitioning(self.kernel_init, self.kernel_axes) kernel = self.param( "kernel", - nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_partitioning, kernel_shape, self.dtype, ) @@ -498,6 +560,9 @@ def __call__(self, inputs: Array) -> Array: quantizer_set = self.generate_quantizer_set() contract_ind = tuple(range(0, len(axis))) + + if not self.enable_comm_overlap: + self.comm_overlap_config.update({"method": tex.CommOverlapMethod.NONE}) y = dense( inputs, kernel, @@ -505,6 +570,14 @@ def __call__(self, inputs: Array) -> Array: input_axes=self.input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, + comm_overlaps=_generate_comm_overlap_meta( + inputs.shape, + self.input_axes, + kernel.shape, + self.kernel_axes, + self.comm_overlap_method, + ), + batch_first=not self.transpose_batch_sequence, ) if self.enable_low_rank_adaptation: @@ -617,12 +690,16 @@ class LayerNormDenseGeneral(TransformerEngineBase): Indicate the logical axes of sharding constraint to the input of dot, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint. + enable_comm_overlap: bool, default = False + Enable fine-grained All-Gather overlap with GEMM for sequence-parallel inputs. + comm_overlap_config: dict, default = {} + Optional config dictionary for controlling communication overlap options. Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. - transpose_batch_sequence : bool, default = True + transpose_batch_sequence : bool, default = False Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). @@ -650,11 +727,13 @@ class LayerNormDenseGeneral(TransformerEngineBase): low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 - dtype: DType = jnp.float32 - transpose_batch_sequence: bool = True + dtype: jnp.dtype = jnp.float32 + transpose_batch_sequence: bool = False layernorm_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] = None depth_scaling: float = None + enable_comm_overlap: bool = False + comm_overlap_config: dict = field(default_factory=dict) # pylint: disable=invalid-field-call def __post_init__(self): if self.kernel_init is None: @@ -672,7 +751,7 @@ def __post_init__(self): super().__post_init__() @nn.compact - def __call__(self, inputs: Array) -> Array: + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: """ Apply layer normalization to the input followed by a dense layer transformation. @@ -742,9 +821,21 @@ def __call__(self, inputs: Array) -> Array: axis = _normalize_axes(axis, y.ndim) kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features + + if self.kernel_axes: + assert len(kernel_shape) == len(self.kernel_axes), ( + "Expected len(kernel_shape) to match len(kernel_axes)," + f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}" + ) + else: + assert not self.enable_comm_overlap, ( + "Communication + GEMM overlap requires the dot kernel sharding to be defined in " + "`kernel_axes`." + ) + kernel_partitioning = nn.with_logical_partitioning(self.kernel_init, self.kernel_axes) kernel = self.param( "kernel", - nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_partitioning, kernel_shape, self.dtype, ) @@ -753,6 +844,16 @@ def __call__(self, inputs: Array) -> Array: contract_ind = tuple(range(0, len(axis))) + # All-Gather is the only supported collective to overlap in LayerNormDenseGeneral + if not self.enable_comm_overlap: + self.comm_overlap_config.update({"method": tex.CommOverlapMethod.NONE}) + comm_overlaps = _generate_comm_overlap_meta( + inputs.shape, + self.layernorm_input_axes, + kernel_shape, + self.kernel_axes, + self.comm_overlap_config, + ) if fuse_layernorm: z = layernorm_dense( y, @@ -765,7 +866,9 @@ def __call__(self, inputs: Array) -> Array: layernorm_input_axes=self.layernorm_input_axes, dot_input_axes=self.dot_input_axes, kernel_axes=self.kernel_axes, + batch_first=not self.transpose_batch_sequence, quantizer_set=quantizer_set, + comm_overlaps=comm_overlaps, ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) @@ -775,7 +878,9 @@ def __call__(self, inputs: Array) -> Array: contracting_dims=(axis, contract_ind), input_axes=self.dot_input_axes, kernel_axes=self.kernel_axes, + batch_first=not self.transpose_batch_sequence, quantizer_set=quantizer_set, + comm_overlaps=comm_overlaps, ) if self.enable_low_rank_adaptation: @@ -924,12 +1029,25 @@ class LayerNormMLP(TransformerEngineBase): Indicate the logical axes of sharding constraint to the input of 2nd dot, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint. + enable_comm_overlap: bool, default = False + Enable fine-grained All-Gather overlap with the 1st dot and Reduce-Scatter overlap with + the 2nd dot. + enable_dot_1_comm_overlap: bool, default = False + Enable fine-grained All-Gather overlap with the 1st dot. This option is overriden by + `enable_comm_overlap=True`. + enable_dot_2_comm_overlap: bool, default = False + Enable fine-grained Reduce-Scatter overlap with the 2nd dot. This option is overriden by + `enable_comm_overlap=True`. + dot_1_comm_overlap_config: dict, default = {} + Optional config dictionary for controlling communication overlap options for the 1st dot. + dot_2_comm_overlap_config: dict, default = {} + Optional config dictionary for controlling communication overlap options for the 2nd dot. Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. - transpose_batch_sequence : bool, default = True + transpose_batch_sequence : bool, default = False Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). @@ -960,11 +1078,20 @@ class LayerNormMLP(TransformerEngineBase): low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 - dtype: DType = jnp.float32 - transpose_batch_sequence: bool = True + dtype: jnp.dtype = jnp.float32 + transpose_batch_sequence: bool = False layernorm_input_axes: Tuple[str, ...] = None dot_1_input_axes: Tuple[str, ...] = None dot_2_input_axes: Tuple[str, ...] = None + enable_comm_overlap: bool = False + enable_dot_1_comm_overlap: bool = False + enable_dot_2_comm_overlap: bool = False + dot_1_comm_overlap_config: dict = field( + default_factory=dict + ) # pylint: disable=invalid-field-call + dot_2_comm_overlap_config: dict = field( + default_factory=dict + ) # pylint: disable=invalid-field-call def __post_init__(self): if self.kernel_init is None: @@ -978,7 +1105,7 @@ def __post_init__(self): super().__post_init__() @nn.compact - def __call__(self, inputs: Array, deterministic: bool = False) -> Array: + def __call__(self, inputs: jnp.ndarray, deterministic: bool = False) -> jnp.ndarray: """ Apply layer normalization to the input followed by a feedforward network (MLP Block). @@ -1082,9 +1209,21 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): axis = _canonicalize_tuple(self.axis) axis = _normalize_axes(axis, y.ndim) kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim) + self.enable_dot_1_comm_overlap = self.enable_dot_1_comm_overlap or self.enable_comm_overlap + if self.kernel_1_axes: + assert len(kernel_1_each_shape) == len(self.kernel_axes), ( + "Expected len(kernel_1_shape) to match len(kernel_1_axes)," + f"got kernel_shape {kernel_1_each_shape} and kernel_axes {self.kernel_1_axes}" + ) + else: + assert not self.enable_dot_1_comm_overlap, ( + "Communication + GEMM overlap for the 1st dot requires the kernel sharding to be " + "defined in `kernel_1_axes`." + ) + kernel_1_partitioning = nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1) kernel_1 = self.param( "wi_kernel", - nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1), + kernel_1_partitioning, num_activations, -2, kernel_1_each_shape, @@ -1097,9 +1236,21 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): hidden_size = inputs.shape[-1] hidden_size_tuple = _canonicalize_tuple(hidden_size) kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple + self.enable_dot_2_comm_overlap = self.enable_dot_2_comm_overlap or self.enable_comm_overlap + if self.kernel_2_axes: + assert len(kernel_2_shape) == len(self.kernel_2_axes), ( + "Expected len(kernel_2_shape) to match len(kernel_2_axes)," + f"got kernel_shape {kernel_2_shape} and kernel_axes {self.kernel_2_axes}" + ) + else: + assert not self.enable_dot_2_comm_overlap, ( + "Communication + GEMM overlap for the 2nd dot requires the kernel sharding to be " + "defined in `kernel_2_axes`." + ) + kernel_2_partitioning = nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2) kernel_2 = self.param( "wo_kernel", - nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2), + kernel_2_partitioning, kernel_2_shape, self.dtype, ) @@ -1131,6 +1282,25 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name = "ffn1" ffn2_ckpt_name = "ffn2" + if not self.enable_dot_1_comm_overlap: + self.dot_1_comm_overlap_config.update({"method": tex.CommOverlapMethod.NONE}) + ffn1_comm_overlaps = _generate_comm_overlap_meta( + inputs.shape, + self.layernorm_input_axes, + kernel_1.shape, + self.kernel_axes_1, + self.dot_1_comm_overlap_config, + ) + + if not self.enable_dot_2_comm_overlap: + self.dot_2_comm_overlap_config.update({"method": tex.CommOverlapMethod.NONE}) + ffn2_comm_overlaps = _generate_comm_overlap_meta( + inputs.shape, + self.dot_2_input_axes, + kernel_2.shape, + self.kernel_axes_2, + self.dot_2_comm_overlap_config, + ) if use_fused_layernorm_mlp: out = layernorm_mlp( y, @@ -1149,7 +1319,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name=ffn1_ckpt_name, ffn2_ckpt_name=ffn2_ckpt_name, activation_type=normalized_acts, + batch_first=not self.transpose_batch_sequence, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), + ffn1_comm_overlaps=ffn1_comm_overlaps, + ffn2_comm_overlaps=ffn2_comm_overlaps, ) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) @@ -1167,7 +1340,9 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): layernorm_input_axes=self.layernorm_input_axes, dot_input_axes=self.dot_1_input_axes, kernel_axes=self.kernel_axes_1, + batch_first=not self.transpose_batch_sequence, quantizer_set=ffn1_quantizer_set, + comm_overlaps=ffn1_comm_overlaps, ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes) @@ -1177,7 +1352,9 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): contracting_dims=(axis, contract_ind), input_axes=self.dot_1_input_axes, kernel_axes=self.kernel_axes_1, + batch_first=not self.transpose_batch_sequence, quantizer_set=ffn1_quantizer_set, + comm_overlaps=ffn1_comm_overlaps, ) if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None: @@ -1259,6 +1436,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): input_axes=self.dot_2_input_axes, kernel_axes=self.kernel_axes_2, quantizer_set=ffn2_quantizer_set, + comm_overlaps=ffn2_comm_overlaps, ) if self.enable_low_rank_adaptation: diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index ea66e78302..09bb0cfb9a 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -37,6 +37,8 @@ def layernorm_dense( layernorm_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, + batch_first: bool = True, + comm_overlaps: tex.CommOverlapHelperSet = tex.CommOverlapHelperSet(), quantizer_set: QuantizerSet = noop_quantizer_set, ) -> jnp.ndarray: """Apply layer normalization followed by dense layer transformation. @@ -57,6 +59,8 @@ def layernorm_dense( layernorm_input_axes: Logical axes for sharding the layernorm input dot_input_axes: Logical axes for sharding the matrix multiplication input kernel_axes: Logical axes for sharding the weight matrix + batch_first: Assume that X is batched in the first dimension. + comm_overlaps: A set of CommOverlapHelper objecst for FPROP, DGRAD and WGRAD GEMMs. quantizer_set: Set of quantizers for different tensor types Returns: @@ -80,6 +84,8 @@ def layernorm_dense( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, + comm_overlaps, quantizer_set, ) return output @@ -94,6 +100,8 @@ def layernorm_dense( 8, 9, 10, + 11, + 12, ), ) def _layernorm_dense( @@ -108,6 +116,8 @@ def _layernorm_dense( layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...], + batch_first: bool, + comm_overlaps: tex.CommOverlapHelperSet, quantizer_set, ): """Internal implementation of layernorm_dense with custom VJP. @@ -127,6 +137,8 @@ def _layernorm_dense( epsilon: Small constant for numerical stability layernorm_input_axes: Logical axes for layernorm sharding dot_input_axes: Logical axes for matrix multiplication sharding + batch_first: Assume that X is batched in the first dimension. + comm_overlaps: A set of CommOverlapHelper objecst for FPROP, DGRAD and WGRAD GEMMs. quantizer_set: Set of quantizers Returns: @@ -144,6 +156,8 @@ def _layernorm_dense( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, + comm_overlaps, quantizer_set, ) return output @@ -161,6 +175,8 @@ def _layernorm_dense_fwd_rule( layernorm_input_axes, dot_input_axes, kernel_axes, + batch_first, + comm_overlaps, quantizer_set, ): """Forward pass rule for layernorm_dense. @@ -178,6 +194,10 @@ def _layernorm_dense_fwd_rule( k_contracting_dims = (0,) assert x.shape[-1] == kernel.shape[0] + x_bdim = None + if x.ndim > 2: + x_bdim = 0 if batch_first else x.ndim - 2 + x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) casted_ln_out, mu, rsigma = tex.normalization_fwd( @@ -187,30 +207,54 @@ def _layernorm_dense_fwd_rule( zero_centered_gamma, epsilon, norm_type, - quantizer_set.x, + quantizer=quantizer_set.x, + noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) # Kernel in (hidden_in, hidden_out...) flatten_axis = 1 - len(kernel.shape) - casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel) + casted_kernel = tex.quantize( + kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True + ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) # NN GEMM - # (batch..., hidden_in) x (hidden_in, hidden_out...) + # (batch..., sequence, hidden_in) x (hidden_in, hidden_out...) + # NOTE: Comm+GEMM overlap can only do AG->GEMM here to all-gather a sequence-parallel layernorm + # output because the weights for a QKV projection is always column-parallel. + use_bias = bias is not None output = tex.gemm( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel.get_tensor(TensorUsage.RHS), - (x_contracting_dims, k_contracting_dims), + dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), + bias=bias if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + comm_overlap=comm_overlaps.fprop, ) - use_bias = bias is not None - if use_bias: + # If Comm+GEMM overlap for FPROP was configured to return the all-gathered layernorm output + # as the auxiliary output, we may need to transpose it here to match the expected data + # layout in the backward pass. Otherwise, the + casted_ln_out_for_bwd = casted_ln_out.get_tensor(TensorUsage.LHS_TRANS) + ln_out_transposed_dims = ( + *tuple(range(casted_ln_out_for_bwd.flatten_axis, casted_ln_out_for_bwd.ndim)), + *tuple(range(casted_ln_out_for_bwd.flatten_axis)), + ) + if comm_overlaps.fprop.output_all_gathered_lhs: + casted_ln_out_for_bwd.data = ( + output[-1].transpose(ln_out_transposed_dims) + if casted_ln_out_for_bwd.data_layout == "T" + else output[-1] + ) + output = output[0] + + if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) ctx = ( - casted_ln_out.get_tensor(TensorUsage.LHS_TRANS), + casted_ln_out_for_bwd, casted_kernel.get_tensor(TensorUsage.RHS_TRANS), x.shape, kernel.shape, @@ -224,6 +268,7 @@ def _layernorm_dense_fwd_rule( use_bias, quantizer_set, flatten_axis, + x_bdim, ) return output, ctx @@ -236,6 +281,8 @@ def _layernorm_dense_bwd_rule( layernorm_input_axes, dot_input_axes, # pylint: disable=unused-argument kernel_axes, + batch_first, # pylint: disable=unused-argument + comm_overlaps, ctx, grad, ): @@ -265,10 +312,23 @@ def _layernorm_dense_bwd_rule( use_bias, quantizer_set, flatten_axis, + x_bdim, ) = ctx casted_grad, dbias = tex.quantize_dbias( - grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad + grad, + is_dbias=use_bias, + flatten_axis=flatten_axis, + quantizer=quantizer_set.dgrad, + noop_scaled_tensor=True, + ) + casted_grad = with_sharding_constraint_by_logical_axes( + casted_grad, + comm_overlaps.fprop.get_logical_output_axes( + dot_input_axes, + kernel_axes, + ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim,), ())), + ), ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim @@ -280,26 +340,60 @@ def _layernorm_dense_bwd_rule( dim for dim in range(len(kernel_shape)) if dim not in k_contracting_dims_in_fwd ) + # If casted_ln_out has transposed data-layout, we need to untranspose it here, and then + # transpose it back after the bulk-AG. This should ideally never be necessary if the data + # layouts are handled correctly in the tensor usages. + dgrad_aux_in = None + casted_ln_out_transposed_axes = ( + *tuple(range(casted_ln_out.flatten_axis, casted_ln_out.ndim)), + *tuple(range(casted_ln_out.flatten_axis)), + ) + casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) + if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_all_gathered_lhs: + dgrad_aux_in = ( + casted_ln_out.data.transpose(casted_ln_out_transposed_axes) + if casted_ln_out.data_layout == "T" + else casted_ln_out.data + ) + # NT GEMM dgrad = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, - (g_constracting_dim, k_constracting_dim), + dimension_numbers=((g_constracting_dim, k_constracting_dim), ((x_bdim,), ())), + comm_overlap=comm_overlaps.dgrad, + aux_in=dgrad_aux_in, ) - dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) - g_constracting_dim = x_constracting_dim = tuple( range(0, len(x_shape) - len(x_contracting_dims_in_fwd)) ) # TN GEMM + casted_grad_rhs = casted_grad.get_tensor(usage=TensorUsage.RHS) + if comm_overlaps.dgrad.is_bulk() and not comm_overlaps.fprop.output_all_gathered_lhs: + # LHS was bulk all-gathered during DGRAD and returned as auxiliary input + casted_ln_out.data = ( + dgrad[-1].transpose(casted_ln_out_transposed_axes) + if casted_ln_out.data_layout == "T" + else dgrad[-1] + ) + # DGRAD output will need to be bulk reduce-scattered during WGRAD + dgrad = dgrad[0] + wgrad = tex.gemm( casted_ln_out, - casted_grad.get_tensor(TensorUsage.RHS), - (x_constracting_dim, g_constracting_dim), + casted_grad_rhs, + dimension_numbers=((x_constracting_dim, g_constracting_dim), ((x_bdim,), (x_bdim,))), + comm_overlap=comm_overlaps.wgrad, + aux_in=(dgrad if comm_overlaps.wgrad.is_bulk() else None), ) + if comm_overlaps.wgrad.is_bulk(): + # DGRAD was bulk reduce-scattered during WGRAD and returned as auxiliary output + dgrad = wgrad[-1] + wgrad = wgrad[0] + dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) dx, dgamma, dbeta = tex.normalization_bwd( diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 18563fd255..84df3e29f1 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -37,7 +37,7 @@ def layernorm_mlp( beta: jnp.ndarray, kernels: List[jnp.ndarray], biases: List[jnp.ndarray], - norm_type: str, + norm_type: str = "layernorm", zero_centered_gamma: bool = False, epsilon: float = 1e-6, norm_input_axes: Tuple[str, ...] = None, @@ -48,6 +48,9 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + batch_first: bool = True, + ffn1_comm_overlaps: tex.CommOverlapHelperSet = tex.CommOverlapHelperSet(), + ffn2_comm_overlaps: tex.CommOverlapHelperSet = tex.CommOverlapHelperSet(), quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> jnp.ndarray: """Apply layer normalization followed by MLP block. @@ -79,6 +82,9 @@ def layernorm_mlp( ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation + ffn1_comm_overlaps: A set of CommOverlapHelper objects for FFN1 FPROP, DGRAD and WGRAD. + ffn2_comm_overlaps: A set of CommOverlapHelper objects for FFN2 FPROP, DGRAD and WGRAD. + batch_first: Assume that X is batched in the first dimension. quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations Returns: @@ -124,12 +130,15 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, + ffn1_comm_overlaps, + ffn2_comm_overlaps, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -149,6 +158,9 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + batch_first: bool, + ffn1_comm_overlaps: tex.CommOverlapHelperSet, + ffn2_comm_overlaps: tex.CommOverlapHelperSet, quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -174,6 +186,9 @@ def _layernorm_mlp( ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing activation_type: Activation function(s) + batch_first: Assume that X is batched in the first dimension. + ffn1_comm_overlaps: A set of CommOverlapHelper objects for FFN1 FPROP, DGRAD and WGRAD. + ffn2_comm_overlaps: A set of CommOverlapHelper objects for FFN2 FPROP, DGRAD and WGRAD. quantizer_sets: Tuple of quantizer sets Returns: @@ -198,6 +213,9 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, + ffn1_comm_overlaps, + ffn2_comm_overlaps, quantizer_sets, ) return output @@ -222,6 +240,9 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, + ffn1_comm_overlaps, + ffn2_comm_overlaps, quantizer_sets, ): """Forward pass rule for layernorm_mlp. @@ -238,8 +259,6 @@ def _layernorm_mlp_fwd_rule( Returns: Tuple of (output, context) for automatic differentiation """ - del kernel_2_axes - ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets # x should be in shape of (batch..., hidden) @@ -254,6 +273,10 @@ def _layernorm_mlp_fwd_rule( assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] + x_bdim = None + if x.ndim > 2: + x_bdim = 0 if batch_first else x.ndim - 2 + use_bias_1 = bias_1 is not None use_bias_2 = bias_1 is not None @@ -267,27 +290,37 @@ def _layernorm_mlp_fwd_rule( epsilon, norm_type, quantizer=ffn1_quantizer_set.x, + noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) - casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel) + casted_kernel_1 = tex.quantize( + kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True + ) + casted_kernel_1 = with_sharding_constraint_by_logical_axes(casted_kernel_1, kernel_1_axes) # NN GEMM - # (batch..., hidden_in) x (hidden_in, hidden_out) + # (batch..., sequence, hidden_in) x (hidden_in, hidden_out) + # NOTE: Comm+GEMM overlap can only do AG->GEMM here to all-gather a sequence-parallel layernorm + # output. dot_1_output = tex.gemm( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel_1.get_tensor(TensorUsage.RHS), - (x_contracting_dims, k_contracting_dims), + dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), + bias=bias_1 if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, + comm_overlap=ffn1_comm_overlaps.fprop, + ) + dot_1_output = with_sharding_constraint_by_logical_axes( + dot_1_output, + ffn1_comm_overlaps.fprop.get_logical_output_axes( + dot_1_input_axes, + kernel_1_axes, + ((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), + ), ) - if dot_1_input_axes is not None and kernel_1_axes is not None: - dot_1_output_axes = ( - *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims), - *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims), - ) - dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) - - if use_bias_1: + if use_bias_1 and tex.gemm_uses_jax_dot(): bias_1_shape = bias_1.shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape dot_1_output += jnp.reshape(bias_1, bias_1_new_shape) @@ -295,21 +328,32 @@ def _layernorm_mlp_fwd_rule( dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) # (batch..., hidden_in) -> (batch..., hidden) - casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x) + casted_act_out = tex.act_lu( + dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True + ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) - casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel) + casted_kernel_2 = tex.quantize( + kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True + ) + casted_kernel_2 = with_sharding_constraint_by_logical_axes(casted_kernel_2, kernel_2_axes) # NN GEMM # (batch..., hidden_in) x (hidden_out, hidden_in) + # NOTE: Comm+GEMM overlap can only do GEMM->RS to reduce-scatter the FFN2 output. We don't need + # an auxiliary input/output here for this because it's already handled in the custom op + # and the returned array is the final reduce-scattered result. dot_2_output = tex.gemm( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), - (x_contracting_dims, k_contracting_dims), + dimension_numbers=((x_contracting_dims, k_contracting_dims), ((x_bdim,), ())), + bias=bias_2 if not tex.gemm_uses_jax_dot() else None, + fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, + comm_overlap=ffn2_comm_overlaps.fprop, ) - if use_bias_2: + if use_bias_2 and tex.gemm_uses_jax_dot(): bias_2_shape = bias_2.shape bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape dot_2_output += jnp.reshape(bias_2, bias_2_new_shape) @@ -334,6 +378,7 @@ def _layernorm_mlp_fwd_rule( use_bias_1, use_bias_2, quantizer_sets, + x_bdim, ) return dot_2_output, ctx @@ -351,6 +396,9 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + batch_first, + ffn1_comm_overlaps, + ffn2_comm_overlaps, ctx, grad, ): @@ -367,7 +415,7 @@ def _layernorm_mlp_bwd_rule( Returns: Tuple of gradients for all input parameters """ - del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name + del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first ( x, mu, @@ -386,15 +434,21 @@ def _layernorm_mlp_bwd_rule( use_bias_1, use_bias_2, quantizer_sets, + x_bdim, ) = ctx ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets - # Since the sharding of outputs should be the same as dot_1's input - grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) - casted_grad, dbias_2 = tex.quantize_dbias( - grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad + grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True + ) + casted_grad = with_sharding_constraint_by_logical_axes( + casted_grad, + ffn2_comm_overlaps.fprop.get_logical_output_axes( + dot_2_input_axes, + kernel_2_axes, + ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim,), ())), + ), ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -408,25 +462,43 @@ def _layernorm_mlp_bwd_rule( # NT GEMM # (batch..., hidden_out) x (hidden_in, hidden_out) + # NOTE: The only possible comm. overlap with FFN2 DGRAD is an AG+GEMM with all-gathered + # gradient returned in the auxiliary output to be re-used in the FFN2 WGRAD GEMM. dgrad_2 = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel_2, - (g_contracting_dims_2, k_contracting_dims_2), + dimension_numbers=((g_contracting_dims_2, k_contracting_dims_2), ((x_bdim,), ())), + comm_overlap=ffn2_comm_overlaps.dgrad, ) - dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) - x_contracting_dims = g_contracting_dims = tuple( range(0, len(x.shape) - len(x_contracting_dims_in_fwd)) ) # TN GEMM # (hidden, batch...,) x (hidden, batch...) + # NOTE: There is no possible comm. overlap with FFN2 WGRAD, but we need to re-use the + # all-gathered gradient returned in the auxiliary output of FFN2 DGRAD. + casted_grad_rhs = casted_grad.get_tensor(usage=TensorUsage.RHS) + if ffn2_comm_overlaps.dgrad.is_enabled: + casted_grad_rhs.data = ( + dgrad_2[-1].transpose( + *range(casted_grad_rhs.flatten_axis, casted_grad_rhs.ndim), + *range(casted_grad_rhs.flatten_axis) + ) + if casted_grad_rhs.data_layout == "T" + else dgrad_2[-1] + ) + dgrad_2 = dgrad_2[0] + wgrad_2 = tex.gemm( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), - (x_contracting_dims, g_contracting_dims), + dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim,))), + comm_overlap=ffn2_comm_overlaps.wgrad, ) + + dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) casted_dact_out, dbias_1 = tex.quantize_dact_dbias( @@ -435,6 +507,15 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, + noop_scaled_tensor=True, + ) + casted_dact_out = with_sharding_constraint_by_logical_axes( + casted_dact_out, + ffn1_comm_overlaps.fprop.get_logical_output_axes( + dot_1_input_axes, + kernel_1_axes, + ((x_contracting_dims_in_fwd, k_contracting_dims_in_fwd), ((x_bdim,), ())), + ), ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -447,23 +528,56 @@ def _layernorm_mlp_bwd_rule( dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd ) + # If FFN1 DGRAD is bulk all-gathering the layernorm output, but the layernorm output + # has transposed data layout, we need to un-transpose it here before the all-gather and + # transpose it again before using it in FFN1 WGRAD. Also make sure we do not already have the + # the gathered layernorm output from FPROP. + # NOTE: This transpose should not be necessary if the tensor usages work correctly! + dgrad_1_aux_in = None + ln_out_transposed_dims = ( + *tuple(range(casted_ln_out.flatten_axis, casted_ln_out.ndim)), + *tuple(range(casted_ln_out.flatten_axis)), + ) + casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) + if ffn1_comm_overlaps.dgrad.is_bulk() and not ffn1_comm_overlaps.fprop.output_all_gathered_lhs: + dgrad_1_aux_in = ( + casted_ln_out.data.transpose(ln_out_transposed_dims) + if casted_ln_out.data_layout == "T" + else casted_ln_out.data + ) + # NT GEMM dgrad_1 = tex.gemm( casted_dact_out.get_tensor(TensorUsage.LHS), casted_kernel_1, - (g_contracting_dims_1, k_contracting_dims_1), + dimension_numbers=((g_contracting_dims_1, k_contracting_dims_1), ((x_bdim,), ())), + comm_overlap=ffn1_comm_overlaps.dgrad, + aux_in=dgrad_1_aux_in, ) - dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) + if ffn1_comm_overlaps.dgrad.is_bulk() and not ffn1_comm_overlaps.fprop.output_all_gathered_lhs: + casted_ln_out.data = ( + dgrad_1[-1].transpose(ln_out_transposed_dims) + if casted_ln_out.data_layout == "T" + else dgrad_1[-1] + ) + dgrad_1 = dgrad_1[0] # TN GEMM # (hidden, batch...) x (hidden, batch...) wgrad_1 = tex.gemm( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), - (x_contracting_dims, g_contracting_dims), + dimension_numbers=((x_contracting_dims, g_contracting_dims), ((x_bdim,), (x_bdim,))), + comm_overlap=ffn1_comm_overlaps.wgrad, + aux_in=(dgrad_1 if ffn1_comm_overlaps.wgrad.is_bulk() else None), ) + if ffn1_comm_overlaps.wgrad.is_bulk(): + # FFN1 DGRAD was bulk reduce-scattered during FFN2 WGRAD and returned as auxiliary output + dgrad_1 = wgrad_1[-1] + wgrad_1 = wgrad_1[0] + dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) dx, dgamma, dbeta = tex.normalization_bwd( diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 06a2562fb1..2459190f1a 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -36,6 +36,22 @@ def dequantize(scaled_tensor): """Dequantizing given tensor to higher precision.""" +@dataclass +class NoopDequantizer(Dequantizer): + """No-op Dequantizer Class""" + + @staticmethod + def _dequantize_func(data, *args, **kwargs): + """A no-op dequantize function that returns the data without any changes.""" + del args, kwargs + return data + + @staticmethod + def dequantize(scaled_tensor): + """A no-op dequantize function that simply returns the data array in the ScaledTensor.""" + return scaled_tensor.data + + class TensorScaleDequantizer(Dequantizer): """ TensorScaling Dequantizer Class @@ -152,6 +168,7 @@ def dequantize(scaled_tensor): ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer, + ScalingMode.NO_SCALING: NoopDequantizer, } diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index c0617eafbb..122265ea27 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -9,7 +9,9 @@ """ from contextlib import contextmanager from enum import Enum -from typing import Optional, Tuple, Dict, Union +from typing import Optional, Tuple, Dict, Union, Sequence +from functools import reduce +import operator import jax import jax.numpy as jnp @@ -29,6 +31,8 @@ "is_fp8_available", "update_collections", "get_delayed_scaling", + "apply_padding_to_scale_inv", + "remove_padding_from_scale_inv", "NVTE_FP8_COLLECTION_NAME", ] @@ -471,4 +475,115 @@ def update_collections(new: Collection, original: Collection) -> Collection: return new_coll +def remove_padding_from_scale_inv( + scale_inv: jax.Array, + scaling_mode: ScalingMode, + data_shape: Sequence[int], + is_colwise: bool = False, + flatten_axis: int = -1, +): + """ + Slice padding out of padded inverse scale factors. + + Args: + scale_inv: Inverse scale factor. + data_shape: Shape of the quantized data the inverse scale belongs to. + scaling_mode: ScalingMode representing the quantization method. + is_colwise: Whether the data was quantized column-wise. + flatten_axis: The axis along with the data could be flattened to 2D. + + Returns: + Inverse scale factor without padding. + """ + # Get expected unpadded scale shape and check if inverse scale already matches + unpadded_scale_shape = scaling_mode.get_scale_shape( + data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis + ) + if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == unpadded_scale_shape: + return scale_inv + + # Get the padded scale shape and make sure inverse scale matches + padded_scale_shape = scaling_mode.get_scale_shape( + data_shape, + is_colwise=is_colwise, + is_padded=True, + flatten_axis=flatten_axis, + ) + assert scale_inv.shape == padded_scale_shape, ( + f"Padded inverse scale factor has wrong shape, expected {padded_scale_shape} but got " + f"{scale_inv.shape} instead." + ) + + # Reshape scale inverse to 2D in two stages to preserve the flatten axis + padded_scale_shape_2d = ( + reduce(operator.mul, padded_scale_shape[:flatten_axis]), + reduce(operator.mul, padded_scale_shape[flatten_axis:]), + ) + scale_inv_2d = jnp.reshape( + jnp.reshape(scale_inv, (padded_scale_shape_2d[0], *scale_inv.shape[flatten_axis:])), + padded_scale_shape_2d, + ) + + # Slice reshaped 2D scale inverse using collapsed 2D unpadded_scale_shape + unpadded_scale_shape_2d = ( + reduce(operator.mul, unpadded_scale_shape[:flatten_axis]), + reduce(operator.mul, unpadded_scale_shape[flatten_axis:]), + ) + scale_inv_2d_unpadded = jnp.asarray( + scale_inv_2d[: unpadded_scale_shape_2d[0], : unpadded_scale_shape_2d[1]] + ) + + # Reshape 2D scale inverse back in two stages in order to preserve the flatten axis + scale_inv_unpadded = jnp.reshape( + jnp.reshape( + scale_inv_2d_unpadded, + (*unpadded_scale_shape[:flatten_axis], scale_inv_2d_unpadded.shape[1]), + ), + unpadded_scale_shape, + ) + return scale_inv_unpadded + + +def apply_padding_to_scale_inv( + scale_inv: jax.Array, + scaling_mode: ScalingMode, + data_shape: Sequence[int], + is_colwise: bool = False, + flatten_axis: int = -1, +): + """ + Pad the scale inverse with zeros to match the necessary padded shape for this scaling + mode. + + Args: + scale_inv: Inverse scale factor. + data_shape: Shape of the quantized data the inverse scale belongs to. + scaling_mode: ScalingMode representing the quantization method. + is_colwise: Whether the data was quantized column-wise. + flatten_axis: The axis along with the data could be flattened to 2D. + + Returns: + Padded inverse scale factor. + """ + # Get the expected padded scale shape and check if inverse scale already matches + padded_scale_shape = scaling_mode.get_scale_shape( + data_shape, is_colwise=is_colwise, is_padded=True, flatten_axis=flatten_axis + ) + if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == padded_scale_shape: + return scale_inv + + # Get the expected unpadded scale shape and make sure inverse scales match + unpadded_scale_shape = scaling_mode.get_scale_shape( + data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis + ) + assert scale_inv.shape == unpadded_scale_shape, ( + f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got " + f"{scale_inv.shape}." + ) + + # Pad the scales with the lowest representable value (2^-127) and return + pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape)) + return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127) + + NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 633be237f9..a1ea83152b 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -17,6 +17,7 @@ from transformer_engine_jax import QuantizeLayout +from .helper import apply_padding_to_scale_inv from .scaling_modes import ScalingMode, TensorUsage from .dequantizer import ScalingModeToDequantizerMap from ..sharding import ( @@ -56,6 +57,11 @@ def tree_unflatten(cls, aux_data, children): """ return cls(*children, *aux_data) + @property + @abstractmethod + def ndim(self): + """Number of dimensions of the underlying quantized array.""" + @abstractmethod def dequantize(self): """Dequantizes the tensor back to its original precision. @@ -127,24 +133,16 @@ def __post_init__(self): 0 < self.flatten_axis < len(self.data.shape) ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}" - expected_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis - ) - expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( - self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis - ) - if self.scale_inv.shape != expected_scale_shape: - assert self.scale_inv.shape == expected_unpadded_scale_shape, ( - f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded" - f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got" - f" {self.scale_inv.shape}" - ) - pad_width = tuple( - (0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape) - ) - # This actually pad scale_inv with nan, should we pad it with 127 directly instead? - self.scale_inv = jnp.pad( - self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0 + if self.scaling_mode == ScalingMode.NO_SCALING: + self.scale_inv = jnp.empty((0,), dtype=jnp.float32) + + else: + self.scale_inv = apply_padding_to_scale_inv( + self.scale_inv, + self.scaling_mode, + self.data.shape, + is_colwise=self.is_colwise, + flatten_axis=self.flatten_axis, ) def tree_flatten(self): @@ -164,6 +162,10 @@ def tree_flatten(self): ) return (children, aux_data) + @property + def ndim(self): + return self.data.ndim + def dequantize(self): """Dequantizes the tensor using the stored dequantization function. @@ -199,6 +201,7 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st return self # axis_names were given for N layout, so needs to be transpose for T layout + axis_names = logical_axis_names if self.data_layout == "T": assert self.flatten_axis > 0 assert len(logical_axis_names) == self.data.ndim @@ -207,8 +210,6 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st *logical_axis_names[flatten_axis:], *logical_axis_names[:flatten_axis], ) - else: - axis_names = logical_axis_names data = with_sharding_constraint_by_logical_axes(self.data, axis_names) @@ -347,6 +348,11 @@ def tree_flatten(self): aux_data = () return (children, aux_data) + @property + def ndim(self): + """Number of dimensions of the underlying row-wise tensor.""" + return self.rowwise_tensor.ndim + def dequantize(self): """Dequantizes the tensor using the row-wise component's dequantization.