[JAX] Add collective GEMM without compute/communication overlap #1675

Closed
wants to merge 3 commits into from
307 changes: 307 additions & 0 deletions tests/jax/test_distributed_gemm.py
@@ -0,0 +1,307 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
from functools import partial
from collections.abc import Iterable

import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.experimental import mesh_utils

import transformer_engine.jax as te
from transformer_engine.jax.gemm import gemm
from transformer_engine.jax.quantize import helper

from utils import assert_allclose


jax.config.update("jax_enable_compilation_cache", False)


# AG+GEMM: (4, 32/P, 128) ----(AG)----> (4, 32, 128) x (128, 256/P) ----------> (4, 32, 256/P)
# - DGRAD: (4, 32, 256/P) x (128, 256/P)^T --(AR)--> (4, 32, 128)
# - WGRAD: (4, 32/P, 128)^T --(AG)--> (4, 32, 128)^T x (4, 32, 256/P) --------> (128, 256/P)

# GEMM+AR: (4, 32, 256/P) x (256/P, 128) --(AR)--> (4, 32, 128)
# - DGRAD: (4, 32, 128) x (256/P, 128)^T ------> (4, 32, 256/P)
# - WGRAD: (4, 32, 256/P)^T --(AG)--> (4, 32, 256)^T x (4, 32, 128) --------> (256, 128)

BATCH = 4
BASE_SIZE = 16
SEQ_LEN = BASE_SIZE * 8
HIDDEN_SIZE = BASE_SIZE * 6
FFN_HIDDEN_SIZE = BASE_SIZE * 16

COMM_TYPES = ["ALL_GATHER", "ALL_REDUCE"]
MESH_TYPES = ["FSDP_TP", "DP_TP", "TP"]
NUM_DEVICES = 4
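
# With SEQ_LEN = 128, HIDDEN_SIZE = 96, FFN_HIDDEN_SIZE = 256 and a pure tensor-parallel mesh
# over all NUM_DEVICES = 4 devices (P = 4, no batch dim), the schematic above becomes
# (illustrative numbers only, not an additional test case):
#   AG+GEMM: (128/4, 96) --(AG)--> (128, 96) x (96, 256/4) --> (128, 256/4)
#   GEMM+AR: (128, 256/4) x (256/4, 96) --(AR)--> (128, 96)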

is_fp8_supported, no_fp8_reason = helper.is_fp8_available()


def _get_mesh(parallel_dist):
jax.clear_caches()

batched = False
fsdp = False
mesh_shape = dict(tp=NUM_DEVICES)
resources = dict(cp_resource="tp", tp_resource="tp")
if parallel_dist in ["DP_TP", "FSDP_TP"]:
batched = True
tp = NUM_DEVICES // 2
dp = NUM_DEVICES // tp
mesh_shape.update(dict(tp=tp, dp=dp))
resources.update(dict(dp_resource="dp"))
if parallel_dist == "FSDP_TP":
fsdp = True
dp = 1
zp = NUM_DEVICES // tp
mesh_shape.update(dict(tp=tp, dp=1, zp=zp))
resources.update(dict(fsdp_resource="zp"))
mesh_resource = te.MeshResource(**resources)

devices = mesh_utils.create_device_mesh((NUM_DEVICES,), devices=jax.devices()[:NUM_DEVICES])

mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys()))

return mesh, mesh_resource, batched, fsdp


def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bwd=False):
fp8_gemm = dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]

# Operand and output shapes
lhs_shape = (
[SEQ_LEN, HIDDEN_SIZE] if fwd_comm_type == "ALL_GATHER" else [SEQ_LEN, FFN_HIDDEN_SIZE]
)
rhs_shape = (
[HIDDEN_SIZE, FFN_HIDDEN_SIZE]
if fwd_comm_type == "ALL_GATHER"
else [FFN_HIDDEN_SIZE, HIDDEN_SIZE]
)
out_shape = [lhs_shape[0], rhs_shape[1]]

if batched:
lhs_shape = [BATCH] + lhs_shape
out_shape = [BATCH] + out_shape

# Operand and output partition specs
lhs_spec = (
[mesh_resource.tp_resource, None]
if fwd_comm_type == "ALL_GATHER"
else [None, mesh_resource.tp_resource]
)
rhs_spec = (
[None, mesh_resource.tp_resource]
if fwd_comm_type == "ALL_GATHER"
else [mesh_resource.tp_resource, None]
)
out_spec = [None, rhs_spec[-1]]

# Modify RHS operand for FP8
fsdp_gathered_rhs_spec = rhs_spec.copy()
if fp8_gemm:
rhs_shape = list(reversed(rhs_shape))
rhs_spec = list(reversed(rhs_spec))
fsdp_gathered_rhs_spec = list(reversed(fsdp_gathered_rhs_spec))

# Add batch dimensions and specs
if batched:
if fsdp:
lhs_spec = [(mesh_resource.dp_resource, mesh_resource.fsdp_resource)] + lhs_spec
rhs_spec = [mesh_resource.fsdp_resource if spec is None else spec for spec in rhs_spec]
out_spec = [(mesh_resource.dp_resource, mesh_resource.fsdp_resource)] + out_spec
else:
lhs_spec = [mesh_resource.dp_resource] + lhs_spec
out_spec = [mesh_resource.dp_resource] + out_spec

# Allocate global operands on device
key = jax.random.PRNGKey(42)
split_keys = jax.random.split(key, 3 if fwd_bwd else 2)
mu = 0.0
sigma = 0.023
shapes = (lhs_shape, rhs_shape)
if fwd_bwd:
shapes += (out_shape,)
global_operands = list(
map(
lambda key, shape: jax.device_put(
mu + (sigma * jax.random.normal(key, shape, dtype=dtype)),
NamedSharding(mesh, PartitionSpec(None)),
),
split_keys,
shapes,
)
)

# Allocate sharded operands on device
partition_axes = (lhs_spec, rhs_spec)
if fwd_bwd:
partition_axes += (out_spec,)
local_operands = list(
map(
lambda x, spec: jax.device_put(x, NamedSharding(mesh, PartitionSpec(*spec))),
global_operands,
partition_axes,
)
)

# Transpose global RHS back to non-transposed orientation if it was originally allocated
# for FP8 GEMM
if fp8_gemm:
rhs_global = jnp.matrix_transpose(global_operands[1])
global_operands = (global_operands[0], rhs_global, *global_operands[2:])

return (
local_operands,
global_operands,
(out_shape, out_spec),
fsdp_gathered_rhs_spec,
)


def _check_output(mesh, expected_out_shape, expected_out_specs, *tensors, fwd_bwd=False):
num_operands = 3 if fwd_bwd else 2
ref_operands = tensors[:num_operands]
test_outputs = tensors[num_operands:]

# Check number of dimensions
assert test_outputs[0].ndim == len(expected_out_shape), (
f"Output has different number of dimensions ({test_outputs[0].ndim}) than expected "
+ f"({len(expected_out_shape)})"
)

# Pad test output spec for unsharded dimensions
test_spec = te.sharding.get_padded_spec(test_outputs[0].sharding.spec, test_outputs[0].ndim)

for i in range(test_outputs[0].ndim):
# Check shape
assert test_outputs[0].shape[i] == expected_out_shape[i], (
f"Output with shape {test_outputs[0].shape} does not match expected shape "
+ f"{expected_out_shape} in dimension index {i}."
)

# Check shardings (with padded output spec)
spec_mismatch = False
if isinstance(expected_out_specs[i], str):
if test_spec[i] != expected_out_specs[i]:
spec_mismatch = True
elif isinstance(expected_out_specs[i], Iterable):
if not isinstance(test_spec[i], type(expected_out_specs[i])):
if test_spec[i] not in expected_out_specs[i]:
spec_mismatch = True
elif len(test_spec[i]) != len(expected_out_specs[i]):
spec_mismatch = True
else:
for j in range(len(expected_out_specs[i])):
if test_spec[i][j] != expected_out_specs[i][j]:
spec_mismatch = True
break
elif expected_out_specs[i] is None:
if test_spec[i] is not None:
spec_mismatch = True
else:
raise RuntimeError("Internal TE error: Unrecognized reference partition spec type.")
if spec_mismatch:
raise AssertionError(
f"Output sharding {test_spec} does not match expected sharding "
+ f"{expected_out_specs} in dimension index {i}."
)

def _native_gemm_fwd_bwd(lhs, rhs, grad):
fwd_out, vjp_fn = jax.vjp(jnp.dot, lhs, rhs)
lhs_grad, rhs_grad = vjp_fn(grad)
return fwd_out, lhs_grad, rhs_grad

ref_fn = jax.jit(_native_gemm_fwd_bwd if fwd_bwd else jnp.dot)

out_names = ["output"]
ref_outputs = ref_fn(*ref_operands)
if not fwd_bwd:
ref_outputs = [ref_outputs]
else:
out_names += ["dgrad", "wgrad"]

for i, (test_out, ref_out) in enumerate(zip(test_outputs, ref_outputs)):
test_out_global = jax.lax.with_sharding_constraint(
test_out, NamedSharding(mesh, PartitionSpec(None))
)
try:
assert_allclose(ref_out, test_out_global)
except AssertionError as err:
raise AssertionError(f"Numerical mismatch in {out_names[i]}:\n" + str(err))


@pytest.mark.parametrize("comm_type", COMM_TYPES)
@pytest.mark.parametrize("mesh_type", MESH_TYPES)
def test_gemm_impl(comm_type, mesh_type):
mesh, mesh_resource, batched, fsdp = _get_mesh(mesh_type)

(
local_operands,
global_operands,
output_info,
fsdp_gathered_rhs_spec,
) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp)

@jax.jit
def _test_fn(lhs, rhs):
rhs_no_fsdp = jax.lax.with_sharding_constraint(
rhs, NamedSharding(mesh, PartitionSpec(*fsdp_gathered_rhs_spec))
)
return te.cpp_extensions.collective_gemm_impl(lhs, rhs_no_fsdp, batched_output=batched)

with te.sharding.global_shard_guard(mesh_resource):
output, *_ = _test_fn(*local_operands)

_check_output(mesh, *output_info, *global_operands, output)


@pytest.mark.parametrize("comm_type", COMM_TYPES)
@pytest.mark.parametrize("mesh_type", MESH_TYPES)
def test_gemm_fwd_bwd(comm_type, mesh_type):
mesh, mesh_resource, batched, fsdp = _get_mesh(mesh_type)

(
local_operands,
global_operands,
output_info,
fsdp_gathered_rhs_spec,
) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True)

@jax.jit
def _test_fn(lhs, rhs, grad):
# Gather weights in FSDP axis
rhs_no_fsdp = jax.lax.with_sharding_constraint(
rhs, NamedSharding(mesh, PartitionSpec(*fsdp_gathered_rhs_spec))
)

# FWD pass
fwd_out, vjp_fn = jax.vjp(gemm, lhs, rhs_no_fsdp)

# BWD pass
lhs_grad, rhs_grad = vjp_fn(grad)

return fwd_out, lhs_grad, rhs_grad

print(
f"INPUTS: {local_operands[0].shape} x {local_operands[1].shape}\n"
+ f" LHS sharding: {local_operands[0].sharding.spec}\n"
+ f" RHS sharding: {local_operands[1].sharding.spec}\n"
)

with te.sharding.global_shard_guard(mesh_resource):
output, dgrad, wgrad = _test_fn(*local_operands)

print(
f"{'AG + GEMM' if comm_type == 'AG' else 'GEMM + AR'} output: "
+ f"{output.shape} | {output.sharding.spec}\n"
+ f"DGRAD: {dgrad.shape} | {dgrad.sharding.spec}\n"
+ f"WGRAD: {wgrad.shape} | {wgrad.sharding.spec}\n"
)

_check_output(mesh, *output_info, *global_operands, output, dgrad, wgrad, fwd_bwd=True)
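
A minimal usage sketch of the collective GEMM path these tests exercise, assuming only the symbols the test file already imports (`transformer_engine.jax as te`, `transformer_engine.jax.gemm.gemm`) and a 4-GPU tensor-parallel mesh; it mirrors the "TP" / "ALL_GATHER" case above and is not part of this PR's diff:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

import transformer_engine.jax as te
from transformer_engine.jax.gemm import gemm

# 1D tensor-parallel mesh over 4 devices, mirroring _get_mesh("TP") above.
mesh = Mesh(np.array(jax.devices()[:4]), ("tp",))
resource = te.MeshResource(tp_resource="tp", cp_resource="tp")

# AG+GEMM layout: LHS sharded along the sequence dim, RHS sharded along the output dim.
lhs = jax.device_put(jnp.zeros((128, 96), jnp.bfloat16),
                     NamedSharding(mesh, PartitionSpec("tp", None)))
rhs = jax.device_put(jnp.zeros((96, 256), jnp.bfloat16),
                     NamedSharding(mesh, PartitionSpec(None, "tp")))

with te.sharding.global_shard_guard(resource):
    # The output is expected to come back sharded as (None, "tp"), i.e. (128, 256/4) per device.
    out = jax.jit(gemm)(lhs, rhs)
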
30 changes: 15 additions & 15 deletions transformer_engine/common/util/pybind_helper.h
@@ -91,39 +91,39 @@
.value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \
.value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \
.value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \
py::class_<transformer_engine::CommOverlapCore, \
pybind11::class_<transformer_engine::CommOverlapCore, \
std::shared_ptr<transformer_engine::CommOverlapCore>>(m, "CommOverlapCore", \
pybind11::module_local()) \
.def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \
py::call_guard<py::gil_scoped_release>()) \
.def(pybind11::init([]() { return new transformer_engine::CommOverlapCore(); }), \
pybind11::call_guard<pybind11::gil_scoped_release>()) \
.def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \
py::call_guard<py::gil_scoped_release>()) \
pybind11::call_guard<pybind11::gil_scoped_release>()) \
.def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \
py::call_guard<py::gil_scoped_release>()) \
pybind11::call_guard<pybind11::gil_scoped_release>()) \
.def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \
py::call_guard<py::gil_scoped_release>()); \
py::class_<transformer_engine::CommOverlapBase, \
pybind11::call_guard<pybind11::gil_scoped_release>()); \
pybind11::class_<transformer_engine::CommOverlapBase, \
std::shared_ptr<transformer_engine::CommOverlapBase>, \
transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \
.def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \
py::call_guard<py::gil_scoped_release>()); \
py::class_<transformer_engine::CommOverlapP2PBase, \
.def(pybind11::init([]() { return new transformer_engine::CommOverlapBase(); }), \
pybind11::call_guard<pybind11::gil_scoped_release>()); \
pybind11::class_<transformer_engine::CommOverlapP2PBase, \
std::shared_ptr<transformer_engine::CommOverlapP2PBase>, \
transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \
pybind11::module_local()) \
.def(py::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \
py::call_guard<py::gil_scoped_release>()); \
.def(pybind11::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \
pybind11::call_guard<pybind11::gil_scoped_release>()); \
m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \
py::call_guard<py::gil_scoped_release>(), py::arg("device_id") = -1); \
pybind11::call_guard<pybind11::gil_scoped_release>(), pybind11::arg("device_id") = -1); \
m.def( \
"get_stream_priority_range", \
[](int device_id = -1) { \
int low_pri, high_pri; \
transformer_engine::cuda::stream_priority_range(&low_pri, &high_pri, device_id); \
return std::make_pair(low_pri, high_pri); \
}, \
py::call_guard<py::gil_scoped_release>(), py::arg("device_id") = -1); \
pybind11::call_guard<pybind11::gil_scoped_release>(), pybind11::arg("device_id") = -1); \
m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \
py::call_guard<py::gil_scoped_release>());
pybind11::call_guard<pybind11::gil_scoped_release>());

#endif
7 changes: 4 additions & 3 deletions transformer_engine/jax/__init__.py
@@ -31,10 +31,10 @@
from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension

module_name = "transformer_engine_jax"

def _load_library():
"""Load shared library with Transformer Engine C extensions"""
module_name = "transformer_engine_jax"

if is_package_installed(module_name):
assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
@@ -79,7 +79,9 @@ def _load_library():
spec.loader.exec_module(solib)


_load_library()
if module_name not in sys.modules:
_load_library()

Collaborator
Hi,
Any reasons for these changes?

from . import flax
from . import quantize

@@ -101,7 +103,6 @@ def _load_library():
)

__all__ = [
"fp8_autocast",
Collaborator
I think we do need to export fp8_autocast.

"MeshResource",
"MajorShardingType",
"ShardingResource",