[JAX] GEMM custom op #1855

Status: Open. Wants to merge 41 commits into base branch main.

Commits (41):
cf1774c: added XLA FFI custom op for TE/common nvte_cublas_gemm (denera, Jun 4, 2025)
da0709a: minor unit test cleanup (denera, Jun 13, 2025)
e5b933c: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 13, 2025)
92dec51: FP8 tests passing on Blackwell but MXFP8 outputs NaN (denera, Jun 13, 2025)
50d319b: Merge branch 'jax/nvte-cublas-gemm-op' of github.com:denera/Transform… (denera, Jun 13, 2025)
9eba586: reverted dense and fuseddense changes, FP8 test passing on Hopper and… (denera, Jun 14, 2025)
b80e284: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 14, 2025)
a7aa2f4: MXFP8 issue traced to scale factor padding with NaNs instead of zeros (denera, Jun 17, 2025)
1be8773: padding scale with 2^-127 instead of NaNs (phu0ngng, Jun 17, 2025)
75008de: fix bug in rhs_scale_inv usage (phu0ngng, Jun 17, 2025)
5b0c1f5: clean up E8M0 type converter, use it in gemm.cpp (phu0ngng, Jun 17, 2025)
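Commits a7aa2f4, 1be8773, and 5b0c1f5 above revolve around the E8M0 block-scale encoding used by MXFP8: each scale is one byte holding a biased power-of-two exponent (bias 127), with 0xFF reserved for NaN. Padding scale tensors with byte 0x00 therefore yields 2^-127 (the smallest representable scale) rather than NaN, which is why the NaN padding in a7aa2f4 corrupted outputs. A minimal sketch of the encoding (illustrative only, not TE's actual converter):

```python
import numpy as np

E8M0_BIAS = 127


def e8m0_decode(byte: int) -> float:
    """Decode an E8M0 scale byte as 2**(byte - 127); 0xFF is reserved for NaN."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - E8M0_BIAS)


def e8m0_encode(scale: float) -> int:
    """Encode a positive power-of-two scale to the nearest E8M0 byte."""
    exp = int(np.clip(np.round(np.log2(scale)), -E8M0_BIAS, 127))
    return exp + E8M0_BIAS


# Padding block-scale tensors with 0x00 (= 2**-127, the smallest scale)
# instead of 0xFF (NaN) keeps unused, padded-out blocks numerically inert.
pad_value = 0x00
assert e8m0_decode(pad_value) == 2.0 ** -127
```

This mirrors the fix in 1be8773: any arithmetic that touches a padded block multiplies by a tiny finite scale instead of propagating NaN.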
b49d586: segfault fixed, passing all unittests on Blackwell (denera, Jun 18, 2025)
b760460: merge with main (phu0ngng, Jun 18, 2025)
bd9bca3: fix for fuseddense tests (phu0ngng, Jun 18, 2025)
8fcb1bb: fix workspace alignment (phu0ngng, Jun 18, 2025)
b2b4159: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 18, 2025)
ae4828c: fixed GemmPrimitive custom partitioning to match jax.nn.scaled_matmul (denera, Jun 18, 2025)
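Commit ae4828c aligns GemmPrimitive's sharding rules with `jax.nn.scaled_matmul`. Semantically, a block-scaled GEMM dequantizes each block of the contracting dimension by its per-block scale before accumulating. The NumPy sketch below captures that contract; the 2D shapes, the helper name, and the scale layout are assumptions for illustration (the real op takes quantized operands with E8M0 scales):

```python
import numpy as np


def scaled_matmul_reference(lhs, rhs, lhs_scales, rhs_scales, block=32):
    """Reference semantics of a block-scaled GEMM: each `block`-wide slice of
    the contracting dimension is multiplied by its per-block scale before an
    ordinary matmul. Shapes assumed here: lhs (M, K), rhs (K, N),
    lhs_scales (M, K // block), rhs_scales (K // block, N)."""
    # Broadcast each per-block scale across its block of the K dimension.
    lhs_dq = lhs * np.repeat(lhs_scales, block, axis=1)
    rhs_dq = rhs * np.repeat(rhs_scales, block, axis=0)
    return lhs_dq @ rhs_dq


M, K, N, B = 4, 64, 8, 32
lhs = np.random.randn(M, K)
rhs = np.random.randn(K, N)
ones_l = np.ones((M, K // B))
ones_r = np.ones((K // B, N))
# With unit scales this reduces to a plain matmul.
np.testing.assert_allclose(
    scaled_matmul_reference(lhs, rhs, ones_l, ones_r), lhs @ rhs
)
```

Because the scales attach to blocks of the contracting dimension, a partitioning rule for the GEMM must shard the scale tensors in lockstep with K, which is the mismatch this commit resolves.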
17d7a51: Merge remote-tracking branch 'upstream/main' into jax/nvte-cublas-gem… (denera, Jun 24, 2025)
ddaaab9: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 24, 2025)
44e5b81: moved reshape of encoder output in encoder examples to make custom pa… (denera, Jun 25, 2025)
a281c97: Merge remote-tracking branch 'upstream/main' into jax/nvte-cublas-gem… (denera, Jun 25, 2025)
b8ca0b1: added helper functions for padding and unpadding block scales, change… (denera, Jun 27, 2025)
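Commit b8ca0b1 adds helpers for padding and unpadding block-scale tensors. cuBLAS block-scaled GEMMs expect the scale operands padded out to fixed tile multiples; the sketch below assumes (128, 4) tiles and hypothetical helper names, and reuses the 0x00 pad byte (2^-127 in E8M0) from the earlier commits:

```python
import numpy as np


def pad_block_scales(scales, row_mult=128, col_mult=4, pad_value=0x00):
    """Pad a 2D E8M0 scale tensor up to the next (row_mult, col_mult)
    multiple. Padding with 0x00 (= 2**-127) rather than an arbitrary byte
    keeps the padded blocks numerically inert."""
    rows, cols = scales.shape
    padded_rows = -(-rows // row_mult) * row_mult  # ceiling division
    padded_cols = -(-cols // col_mult) * col_mult
    out = np.full((padded_rows, padded_cols), pad_value, dtype=scales.dtype)
    out[:rows, :cols] = scales
    return out


def unpad_block_scales(padded, orig_shape):
    """Slice the original scales back out of a padded tensor."""
    rows, cols = orig_shape
    return padded[:rows, :cols]


s = np.arange(6, dtype=np.uint8).reshape(3, 2)
p = pad_block_scales(s)
assert p.shape == (128, 4)
assert np.array_equal(unpad_block_scales(p, s.shape), s)
```

The round trip is lossless by construction, so unpadding after the GEMM recovers exactly the scales that were quantized.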
3ee96ba: Merge remote-tracking branch 'upstream/main' into jax/nvte-cublas-gem… (denera, Jun 27, 2025)
7187582: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 27, 2025)
0230a5e: updated shardy rules for all custom ops to decouple block scale rules… (denera, Jul 8, 2025)
dedf5e9: Merge remote-tracking branch 'upstream/main' into jax/nvte-cublas-gem… (denera, Jul 8, 2025)
875f401: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 8, 2025)
995fb11: fixed linting errors (denera, Jul 9, 2025)
cb3613d: Merge remote-tracking branch 'upstream/main' into jax/nvte-cublas-gem… (denera, Jul 9, 2025)
d9e55a9: changed unit test use_jax_gemm option to be a context to preserve ext… (denera, Jul 9, 2025)
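Commit d9e55a9 turns the unit tests' use_jax_gemm switch into a context, so whatever configuration was in effect before a test is restored when the test exits. A minimal sketch of that pattern (the environment-variable name here is hypothetical; TE's actual mechanism may differ):

```python
import os
from contextlib import contextmanager


@contextmanager
def use_jax_gemm(enabled: bool = True):
    """Toggle a GEMM-backend switch for the duration of a `with` block,
    restoring the previous state on exit so one test cannot leak
    configuration into the next."""
    var = "TE_TEST_USE_JAX_GEMM"  # hypothetical variable name
    prev = os.environ.get(var)
    os.environ[var] = "1" if enabled else "0"
    try:
        yield
    finally:
        # Restore the exterior state, whether or not the body raised.
        if prev is None:
            del os.environ[var]
        else:
            os.environ[var] = prev


with use_jax_gemm():
    assert os.environ["TE_TEST_USE_JAX_GEMM"] == "1"
assert "TE_TEST_USE_JAX_GEMM" not in os.environ
```

The try/finally is what distinguishes a context from a plain setter: the exterior value survives even when an assertion inside the block fails.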
9c2a56c: Merge remote-tracking branch 'upstream/main' into jax/nvte-cublas-gem… (denera, Jul 9, 2025)
e850ab5: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 9, 2025)
4db82a0: fixed typo in test utils (denera, Jul 9, 2025)
14da5c8: added sequence-first input warnings (denera, Jul 9, 2025)
d5cb233: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 9, 2025)
8cbe0e2: fixed datasets version for JAX examples (denera, Jul 10, 2025)
3c82160: Merge remote-tracking branch 'upstream/main' into jax/nvte-cublas-gem… (denera, Jul 10, 2025)
66eab76: reverting modification to force_1x_quantization decision (denera, Jul 10, 2025)
9781ebf: corrected gemm function syntax in unit tests (denera, Jul 11, 2025)
bb174bb: Merge remote-tracking branch 'upstream/main' into jax/nvte-cublas-gem… (denera, Jul 11, 2025)
8532ad0: Merge remote-tracking branch 'upstream/main' into jax/nvte-cublas-gem… (denera, Jul 11, 2025)
2 changes: 1 addition & 1 deletion examples/jax/encoder/run_test_multiprocessing_encoder.sh

@@ -30,7 +30,7 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
     LOG_FILE="${TEST_CASE}_gpu_${i}.log"

     # Run pytest and redirect stdout and stderr to the log file
-    pytest -c "$TE_PATH/tests/jax/pytest.ini" \
+    pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
         -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
         --num-process=$NUM_GPUS \
         --process-id=$i > "$LOG_FILE" 2>&1 &
48 changes: 36 additions & 12 deletions examples/jax/encoder/test_model_parallel_encoder.py

@@ -25,6 +25,7 @@
     assert_params_sufficiently_sharded,
 )
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
 from transformer_engine.jax.quantize import is_fp8_available, ScalingMode

@@ -465,37 +466,37 @@ class TestEncoder(unittest.TestCase):
     is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

     def setUp(self):
-        """Run 3 epochs for testing"""
-        self.args = encoder_parser(["--epochs", "3"])
+        """Run 5 epochs for testing"""
+        self.args = encoder_parser(["--epochs", "5"])

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
         """Test Transformer Engine with MXFP8"""
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_with_sp(self):
         """Test Transformer Engine with BF16 + SP"""
         self.args.enable_sp = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_with_sp(self):
@@ -504,7 +505,7 @@ def test_te_delayed_scaling_fp8_with_sp(self):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8_with_sp(self):
@@ -513,14 +514,14 @@ def test_te_mxfp8_with_sp(self):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
         """Test Transformer Engine with BF16"""
         self.args.enable_shardy = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_shardy(self):
@@ -529,7 +530,7 @@ def test_te_delayed_scaling_fp8_shardy(self):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_with_sp_shardy(self):
@@ -539,9 +540,32 @@ def test_te_delayed_scaling_fp8_with_sp_shardy(self):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

-    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        self.args.enable_shardy = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.43 and actual[1] > 0.80
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_with_sp_shardy(self):
+        """Test Transformer Engine with MXFP8 + SP"""
+        self.args.enable_shardy = True
+        self.args.enable_sp = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.43 and actual[1] > 0.80


 if __name__ == "__main__":
33 changes: 22 additions & 11 deletions examples/jax/encoder/test_multigpu_encoder.py

@@ -21,6 +21,7 @@

 from common import is_bf16_supported, get_fp8_recipe_from_name_string
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
 from transformer_engine.jax.quantize import is_fp8_available, ScalingMode

@@ -430,45 +431,45 @@ class TestEncoder(unittest.TestCase):
     is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

     def setUp(self):
-        """Run 3 epochs for testing"""
-        self.args = encoder_parser(["--epochs", "3"])
+        """Run 6 epochs for testing"""
+        self.args = encoder_parser(["--epochs", "6"])

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_current_scaling_fp8(self):
         """Test Transformer Engine with CurrentScaling FP8"""
         self.args.use_fp8 = True
         self.args.fp8_recipe = "Float8CurrentScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
         """Test Transformer Engine with MXFP8"""
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
         """Test Transformer Engine with BF16"""
         self.args.enable_shardy = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_shardy(self):
@@ -477,9 +478,7 @@ def test_te_delayed_scaling_fp8_shardy(self):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
-
-    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_current_scaling_fp8_shardy(self):
@@ -488,7 +487,19 @@ def test_te_current_scaling_fp8_shardy(self):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "Float8CurrentScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        self.args.enable_shardy = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.50 and actual[1] > 0.75


 if __name__ == "__main__":
33 changes: 21 additions & 12 deletions examples/jax/encoder/test_multiprocessing_encoder.py

@@ -28,8 +28,8 @@
     get_fp8_recipe_from_name_string,
 )
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
-from transformer_engine.jax.quantize import is_fp8_available, ScalingMode


 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@@ -584,8 +584,8 @@ class TestEncoder(unittest.TestCase):
     """Encoder unittests"""

     def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
-        """Run 3 epochs for testing"""
-        args = encoder_parser([])
+        """Run 5 epochs for testing"""
+        args = encoder_parser(["--epochs", "5"])

         num_gpu = self.num_process
         tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
@@ -607,55 +607,64 @@ def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         result = self.exec(False, None)
-        assert result[0] < 0.505 and result[1] > 0.755
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
     )
     def test_te_delayed_scaling_fp8(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         result = self.exec(True, "DelayedScaling")
-        assert result[0] < 0.506 and result[1] > 0.753
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
     )
     def test_te_current_scaling_fp8(self):
         """Test Transformer Engine with CurrentScaling FP8"""
         result = self.exec(True, "Float8CurrentScaling")
-        assert result[0] < 0.507 and result[1] > 0.753
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
     )
     def test_te_mxfp8(self):
         """Test Transformer Engine with MXFP8"""
         result = self.exec(True, "MXFP8BlockScaling")
-        assert result[0] < 0.505 and result[1] > 0.754
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
         """Test Transformer Engine with BF16"""
         result = self.exec(False, None, enable_shardy=True)
-        assert result[0] < 0.505 and result[1] > 0.755
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
     )
     def test_te_delayed_scaling_fp8_shardy(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         result = self.exec(True, "DelayedScaling", enable_shardy=True)
-        assert result[0] < 0.506 and result[1] > 0.753
-
-    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
     )
     def test_te_current_scaling_fp8_shardy(self):
         """Test Transformer Engine with CurrentScaling FP8"""
         result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
-        assert result[0] < 0.507 and result[1] > 0.753
+        assert result[0] < 0.43 and result[1] > 0.80

+    @unittest.skipIf(
+        not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
+    )
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True)
+        assert result[0] < 0.43 and result[1] > 0.80


 if __name__ == "__main__":