
[JAX] GEMM custom op #1855


Open: wants to merge 41 commits into main from jax/nvte-cublas-gemm-op

Conversation

@denera (Collaborator) commented Jun 6, 2025

Description

XLA custom op + JAX primitive for nvte_cublas_gemm. The custom partitioning rules closely follow jax.nn.scaled_matmul, and the sharded implementation performs jax.lax.psum or jax.lax.psum_scatter under the same conditions. Also implements Shardy rules in preparation for JAX switching its default partitioner to Shardy in July.

IMPORTANT: Padded block scales are currently not compatible with Shardy because the padding breaks the size relationships we define with CompoundFactors in the sharding rules. Since Shardy rules are applied to the inner primitive, the only way to make these CompoundFactors work correctly is to lower unpadded scales into the XLA custom call and pad them in C++. This PR does not implement that approach. Instead, it removes the CompoundFactors entirely and decouples the scaling-factor rules from the related tensor rules. This should not cause any issues as long as the scales are sharded correctly outside the custom ops, but it would be safer in the long run to implement scale padding in C++ and restore CompoundFactor-based rules.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera force-pushed the jax/nvte-cublas-gemm-op branch 2 times, most recently from 1a3fdf3 to f571e0b Compare June 6, 2025 10:39
Comment on lines 666 to 707
@partial(jax.jit, static_argnums=(6, 7, 8, 9, 10, 11, 12, 13, 14))
def _te_gemm_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, contracting_dims,
                  scaling_mode, lhs_scaled_colwise, rhs_scaled_colwise, fuse_bias, fuse_gelu,
                  grad, accumulate, use_split_accumulator):
    return GemmPrimitive.outer_primitive.bind(
        lhs,
        lhs_scale_inv,
        rhs,
        rhs_scale_inv,
        bias,
        gelu_input,
        contracting_dims=contracting_dims,
        scaling_mode=scaling_mode,
        lhs_scaled_colwise=lhs_scaled_colwise,
        rhs_scaled_colwise=rhs_scaled_colwise,
        fuse_bias=fuse_bias,
        fuse_gelu=fuse_gelu,
        grad=grad,
        accumulate=accumulate,
        use_split_accumulator=use_split_accumulator,
    )


def te_gemm_impl(
    lhs: jax.Array,
    rhs: jax.Array,
    bias: jax.Array = None,
    gelu_input: jax.Array = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
    scaling_mode: ScalingMode = ScalingMode.NO_SCALING,
    lhs_scaled_colwise: bool = False,
    lhs_scale_inv: jax.Array = None,
    rhs_scaled_colwise: bool = False,
    rhs_scale_inv: jax.Array = None,
    fuse_bias: bool = False,
    fuse_gelu: bool = False,
    grad: bool = False,
    accumulate: bool = False,
    use_split_accumulator: bool = False,
):
    r"""
    cuBLAS GEMM custom op.
@phu0ngng (Collaborator) commented Jun 6, 2025

Hi,
I noticed that you made a separate Python API, te_gemm_impl().
Let's use the existing API gemm() and, when the primitive is not enabled, fall back to JAX dot as in the other primitives. You can have a look at the ActLuPrimitive for example: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/cpp_extensions/activation.py#L1010-L1011.
That way we don't need a separate code path to call this custom op in the VJPs, and no additional unit test either: one can add a with_te_gemm parameter and set an os.environ regex to disable the primitive in the other case.

@denera (Collaborator, Author):

I think the custom TE path still needs its own ScaledTensor-based wrapper on top of the plain jax.Array-based custom op in order to expose options specific to our custom implementation (e.g. accumulation in FP8, bias and gelu fusions, forward/backward mode switch).

We can call this wrapper from inside gemm() in order to enable easy access to the GEMM custom op, but it will be important to have some way of doing more sophisticated things with this custom op especially when it comes to unifying communication overlap features down the line.

@phu0ngng (Collaborator) commented Jun 6, 2025

(e.g. accumulation in FP8, bias and gelu fusions, forward/backward mode switch).

I think this logic belongs to the GEMM rather than the ScaledTensor, so we should not put it into the ScaledTensor.
We could have an additional config dict and pass it to gemm() as an optional argument, or use the kwargs.

@denera (Collaborator, Author) commented Jun 6, 2025

I wasn't suggesting we put it into ScaledTensor. I'm just saying that it will be convenient for there to be a dedicated te_gemm() wrapper API that supports ScaledTensor inputs because the GemmPrimitive is all based on jax.Arrays. We can then call te_gemm() from inside gemm() depending on an environment variable.

To be clear, the structure I have in mind is as follows:

  • GemmPrimitive with jax.Array inputs.
  • _te_gemm() wrapper with jax.Array inputs, decorated with jax.jit.
  • te_gemm() wrapper-on-wrapper that supports ScaledTensor inputs and calls _te_gemm() underneath.
  • Unified gemm() call that directs to either _jax_gemm() or te_gemm() depending on an environment variable.

We can merge _te_gemm and te_gemm together if JAX doesn't have any issues JITting with ScaledTensor inputs.
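For reference, a rough sketch of that layering (illustrative only; the bodies and the _unpack_scaled_tensor helper are placeholders, not this PR's actual code):

    # Illustrative sketch of the proposed layering -- hypothetical helper names, placeholder bodies.
    class GemmPrimitive(BasePrimitive):
        """Inner/outer JAX primitive operating purely on jax.Array inputs."""
        ...

    def _te_gemm(lhs, lhs_scale_inv, rhs, rhs_scale_inv, **kwargs):
        # Thin jax.Array-level wrapper that binds the outer primitive.
        return GemmPrimitive.outer_primitive.bind(lhs, lhs_scale_inv, rhs, rhs_scale_inv, **kwargs)

    def te_gemm(lhs, rhs, **kwargs):
        # ScaledTensor-aware wrapper: unpack data/scale_inv pairs and forward to _te_gemm().
        lhs_data, lhs_scale_inv = _unpack_scaled_tensor(lhs)  # hypothetical helper
        rhs_data, rhs_scale_inv = _unpack_scaled_tensor(rhs)
        return _te_gemm(lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, **kwargs)

    def gemm(lhs, rhs, **kwargs):
        # Unified entry point: dispatch to the custom op or the native JAX implementation.
        if GemmPrimitive.enabled():
            return te_gemm(lhs, rhs, **kwargs)
        return _jax_gemm(lhs, rhs, **kwargs)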

Collaborator:

Hi, any reason why we need to jit the _te_gemm() wrapper?

@denera (Collaborator, Author):

We don't have to have a jit decorator here, but we still need an interface with hashable static arguments for it to be jittable in user code, no?

Collaborator:

I don't think so.

All our custom calls work with JIT in user code without a separate interface with hashable static arguments.
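For instance, user code can simply jit around the wrapper; the non-array options are ordinary Python keywords that get baked in at trace time, so no extra hashable-static-argument interface is needed (illustrative snippet, assuming the te_gemm wrapper from this PR):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def forward(lhs, rhs):
        # Keyword options are Python constants at trace time.
        return te_gemm(lhs, rhs, contracting_dims=((-1,), (0,)), fuse_bias=False)

    out = forward(jnp.ones((128, 256), jnp.bfloat16), jnp.ones((256, 512), jnp.bfloat16))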

if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
  // Allocate new buffer for swizzled scaling factor
  void *swizzled_scale_inv;
  NVTE_CHECK_CUDA(cudaMallocAsync(&swizzled_scale_inv, product(scale_shape), stream));
Collaborator:

Do we want to use JAX's pooled allocator here instead of directly calling malloc? I'm unsure how PyTorch handles this. If JIT'd this may only be called once while making the CUDA graph so allocation perf may be less of an issue, but I'd think using JAX's allocator would still be preferred for managing memory limits, etc.

We could get JAX to allocate scratch space for these by adding additional unused outputs on the inner primitive, similar to our cuBLAS workspace output on the primitive.
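A hedged sketch of that idea (the shapes, dtypes, and abstract signature below are assumptions, not the real primitive's): declaring the swizzle buffer as an extra output of the primitive's abstract evaluation lets XLA allocate it from JAX's pooled allocator, and the C++ custom call can then write into that buffer instead of calling cudaMallocAsync.

    from jax.core import ShapedArray

    @staticmethod
    def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, **kwargs):
        # GEMM output (assumes a single trailing/leading contraction for brevity).
        out_aval = ShapedArray((*lhs_aval.shape[:-1], rhs_aval.shape[-1]), lhs_aval.dtype)
        # Extra output, unused by the caller, so XLA allocates the swizzle scratch buffer
        # for us, just like the existing cuBLAS workspace output on the primitive:
        swizzled_scale_aval = ShapedArray(lhs_scale_inv_aval.shape, lhs_scale_inv_aval.dtype)
        return out_aval, swizzled_scale_aval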

@phu0ngng (Collaborator) commented Jun 6, 2025

We could get JAX to allocate scratch space for these by adding additional unused outputs on the inner primitive, similar to our cuBLAS workspace output on the primitive.

We definitely should. We have already done this in the grouped_gemm PR (https://github.com/phu0ngng/TransformerEngine/blob/grouped_gemm/transformer_engine/jax/cpp_extensions/gemm.py#L107-L109).

@denera (Collaborator, Author):

Yeah, you're right, we routinely do this for cuBLAS workspaces so it was definitely an oversight that I didn't do it here. Thanks for catching that!

@ptrendx ptrendx added the 2.5.0 label Jun 6, 2025
@@ -305,10 +316,508 @@ def gemm(
If quantizer_set is provided:
A ScaledTensor containing the quantized matrix multiplication result.
"""
if bool(int(os.getenv("NVTE_JAX_WITH_TE_GEMM", "0"))):
@jberchtold-nvidia (Collaborator) commented Jun 6, 2025

Thanks for adding an env check! This can be simplified with some primitive logic we already have in place:

    if not GemmPrimitive.enabled():
        return te_gemm(...)
    return __jax_gemm(...)

All our primitives inherit from BasePrimitive in transformer_engine/jax/cpp_extensions/base.py, which has the following check that lets us control, with a single environment variable regex, whether each primitive uses the TE custom call or the TE pure-JAX implementation.

class BasePrimitive:
    ...
    @classmethod
    def enabled(cls):
        """
        A custom call is marked as disabled if the `cls.__name__` does not fully match the
        `NVTE_JAX_CUSTOM_CALLS_RE` pattern.
        This uses the Python class name of the primitive definitions that inherit from BasePrimitive.
        By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names.
        For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!DBiasQuantizePrimitive$).+$'` to disable `DBiasQuantizePrimitive`.
        """
        pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*")
        pattern = re.compile(pattern)
        is_enabled = pattern.fullmatch(cls.__name__) is not None
        return is_enabled

Collaborator:

The other way around ;)

    if not GemmPrimitive.enabled():
        return __jax_gemm(...)
    return te_gemm(...)

Collaborator:

Good catch, thanks!

@denera denera force-pushed the jax/nvte-cublas-gemm-op branch from 913ac52 to c2c88d8 Compare June 12, 2025 18:00
@phu0ngng (Collaborator):

Hi,

I suggest adding an option in the unit test to test both jax_gemm and te_gemm.

One should be able to add a with_te_gemm parameter, e.g.:

    if not with_te_gemm:
        os.environ['NVTE_JAX_CUSTOM_CALLS_RE'] = '^(?!GemmPrimitive$).+$'
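
For instance, in pytest this could look like the following (test name is hypothetical; monkeypatch keeps the environment change scoped to the test):

    import pytest

    @pytest.mark.parametrize("with_te_gemm", (True, False))
    def test_gemm_fwd_bwd(with_te_gemm, monkeypatch):
        if not with_te_gemm:
            # Disable only GemmPrimitive so the pure-JAX path is exercised.
            monkeypatch.setenv("NVTE_JAX_CUSTOM_CALLS_RE", "^(?!GemmPrimitive$).+$")
        ...  # existing GEMM test body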

Please trigger CI afterward.

@@ -48,7 +55,7 @@ def dense(
Transformed output tensor
"""
# Remove when tex.quantize() can handle quantizer=None
if quantizer_set == noop_quantizer_set:
if quantizer_set == noop_quantizer_set and not GemmPrimitive.enabled():
@phu0ngng (Collaborator) commented Jun 12, 2025

Hi, I would prefer not to use the Primitive outside of its Python binding function.

Besides, we are planning to add NoopQuantizer and Tensor classes to wrap non-FP8 inputs soon-ish. After this work is done, we should be able to have the same VJP code path for all inputs.

Since we don't offer Collective Gemm yet, I don't see a reason to use the GEMM custom call for BF16 now. But if there is a need for it, I suggest adding a helper function in cpp_extensions/gemm.py and importing the helper here instead.

@denera (Collaborator, Author) commented Jun 13, 2025

Besides, we are planning to add NoopQuantizer and Tensor classes to wrap non-FP8 inputs soon-ish. After this work is done, we should be able to have the same VJP code path for all inputs.

This PR effectively implements the equivalent of this behavior in order to get GemmPrimitive and JAX dot seamlessly working with the same dense and fused-dense FWD/BWD functions. It's a bit of a shortcut -- normalization, activation and quantization calls all have a new kwarg that enables wrapping the jax.Array output in a ScaledTensor2x when there is no quantizer. A NoopQuantizer would have been a nicer-looking approach, but I did not want to add more bloat to this PR.

Since we don't offer Collective Gemm yet, I don't see a reason to use the GEMM custom call for BF16 now

Unfortunately, if we don't avoid this BF16 code path when the custom GEMM call is enabled, the dense tests fail because the custom GEMM call doesn't have its own separate FWD/BWD implementation. The dense() call implements exactly that, so it would be redundant to do it all over again in the cpp_extensions.

I added a cached helper function though to avoid directly calling GemmPrimitive.enabled() outside cpp_extensions.

Collaborator:

It's bit a of a shortcut -- normalization, activation and quantization calls all have a new kwarg that enables wrapping the jax.Array output in a ScaledTensor2x when there is no quantizer. A NoopQuantizer would have been a nicer looking approach but I did not want to add more bloat to this PR.

I would prefer not to have unnecessary additional kwargs and avoid changing the API unnecessarily.
In our case, the quantizer should be able to carry the scaling mode, which indicates whether it is a noop_scaled_tensor.
We can leave them as they are for now and refactor them when we have NoopQuantizer in the future.

Besides, with our OO design, the scaling mode should carry enough info for usage. I find it not very nice that we hardcode the size in the ScaledTensor post-init:

if self.scaling_mode == ScalingMode.NO_SCALING:
    self.scale_inv = jnp.empty((1,), dtype=jnp.float32)

Comment on lines 304 to 315
assert self.rowwise_tensor.dtype == self.colwise_tensor.dtype, (
    "Row-wise and column-wise pair of `ScaledTensor1x`s forming this `ScaledTensor2x` have "
    "different quantized data types."
)
assert self.rowwise_tensor.dq_dtype == self.colwise_tensor.dq_dtype, (
    "Row-wise and column-wise pair of `ScaledTensor1x`s forming this `ScaledTensor2x` have "
    "different de-quantized data types."
)
assert self.rowwise_tensor.scaling_mode == self.colwise_tensor.scaling_mode, (
    "Row-wise and column-wise pair of `ScaledTensor1x`s forming this `ScaledTensor2x` have "
    "different scaling modes."
)
Collaborator:

As we want to support DeepSeek and other possible Mixed-ScalingMode recipes, I think we should not enforce these.

@denera (Collaborator, Author):

That's a fair point. I added these alongside the new ndim, dtype, shape, etc. properties in ScaledTensor2x that primarily pluck the information out of the rowwise tensor, so that there's at least a one-time sanity check of equivalency against the colwise tensor. In practice though, ndim is the only essential property used in the final working code so the rest is just bloat. I'll get rid of it in the next commit.

denera added 2 commits June 13, 2025 04:55 (all commits below signed off by Alp Dener <adener@nvidia.com>):

  • started GemmPrimitive, abstract done
  • gemm custom op working with BF16, needs testing for FP8/MXFP8
  • converted TE GEMM API to use ScaledTensor and added os ENV flag to use TE GEMM under general gemm() call
  • BF16 tests passing, FP8 tests should be passing but contracting_dims has a scoping issue
  • fp8 tests passing for E4M3, getting CUBLAS_STATUS_NOT_SUPPORTED for E5M2
  • updated GEMM API to use separate LHS and RHS quantizers instead of a QuantizerSet
  • new GemmPrimitive passing all Dense tests
  • import cleanup and reverted code chunk movement
  • removed unused .transpose() implementations from ScaledTensors
  • all custom call tests passing on Hopper, GEMM-related tests cover both GemmPrimitive and native JAX impl
  • removed direct calls to GemmPrimitive.enabled() from outside of cpp_extensions
  • removed unused changes to ScaledTensor classes and debug prints
@denera denera force-pushed the jax/nvte-cublas-gemm-op branch from 2fd4203 to da0709a Compare June 13, 2025 05:02
@denera (Collaborator, Author) commented Jun 13, 2025

/te-ci jax

denera and others added 9 commits June 13, 2025 06:55
Signed-off-by: Alp Dener <adener@nvidia.com>
… Blackwell, MXFP8 has issues with E5M2

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera (Collaborator, Author) commented Jun 18, 2025

/te-ci JAX L0

phu0ngng added 2 commits June 18, 2025 05:54
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
    """Test Transformer Engine with BF16"""
    actual = train_and_evaluate(self.args)
    assert actual[0] < 0.455 and actual[1] > 0.785
    assert actual[0] < 0.43 and actual[1] > 0.8
@phu0ngng (Collaborator) commented Jul 9, 2025

Hi,
I'm fine with the two-digit threshold for the training loss, but why only one digit for the training accuracy?

@denera (Collaborator, Author) commented Jul 9, 2025

In my testing, the loss at epoch 5 drops to ~0.42 and accuracy rises to ~0.81. The thresholds are set to check that with a +0.01 and -0.01 tolerance for loss and accuracy respectively. Would it be more descriptive if we just wrote 0.80 here instead of 0.8? Alternatively, do you prefer a tighter tolerance and if so, what's the appropriate value?

Collaborator:

I see. Then we could put 0.80, just to make it clear.

@@ -57,6 +58,7 @@ def layernorm_dense(
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
batch_first: Assume that X is batched in the first dimension.
@phu0ngng (Collaborator) commented Jul 9, 2025

Hi, why / in which case do we need this batch_first argument?

@denera (Collaborator, Author) commented Jul 9, 2025

We need batch_first here in order to be able to invoke gemm() with a complete dimension_numbers argument that includes the correct batch dimensions. And the answer to why we need dimension_numbers API in gemm() is in your next question below. :)

Our Flax modules already have transpose_batch_sequence options, presumably to match the native Flax API. This was not being used before, but it is used now and propagated down to the underlying function implementations with the new batch_first argument.

As a bonus, I personally like that it aligns our gemm API with native JAX dot_general too, though that's not the motivation behind the change.

Dtype: Same as input dtype
If quantizer_set is provided:
A ScaledTensor containing the quantized matrix multiplication result.
dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]] = (((-1,), (0,)), ((), ())),
@phu0ngng (Collaborator) commented Jul 9, 2025

I'm not sure if we need to expose dimension_numbers in the API instead of contracting_dims.
I don't think we ever need to do batched GEMM in TE, thus the batch_dims is always empty, i.e. ((), ()).

@denera (Collaborator, Author) commented Jul 9, 2025

We have to make a decision in GemmPrimitive.partition() about whether the _sharded_impl() wrapper needs to execute jax.lax.psum() or jax.lax.psum_scatter() on the GEMM output. This decision partially depends on whether the non-batched contracting dimensions are sharded, which then requires us to differentiate between batched and non-batched contracting dimensions. If we don't separate these, then we would end up incorrectly triggering a data-parallel reduction in the layer's backward pass instead of letting the data-parallel reduction happen in the optimizer, after the model's backward pass is finished.

jax.nn.scaled_matmul() does the exact same thing too, but it assumes normalized 3D operands that are always in the (batch, leading, contracting) layout, so it does not need the user to give it the dimension numbers. It already has the information. GemmPrimitive is written to handle arbitrary dimensional operands in arbitrary layouts, so there's no way for us to infer the contracting and batch dimensions automatically like jax.nn.scaled_matmul(). We need it to be given to us in a complete dimension_numbers API.

When I was implementing this, I wondered how jax.lax.dot_general() handles this correctly without being given the batch dimensions. I don't have a concrete answer, but my guess is that it does not actually handle it correctly entirely on its own. I think it inserts a reduction op for the GEMM output whenever there is any sharded contracting dimension (batched or not, doesn't matter), and then relies on XLA to delay the reduction until the optimizer's parameter update. This would work only because jax.lax.dot_general() is a native XLA operation. Custom op primitives do not get the benefit of this so the reduction would be done where it is originally inserted.
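
A simplified sketch of that decision (illustrative only; mesh_axis_of and _te_gemm_local are hypothetical helpers and the real partition rule is more involved):

    import jax

    def _sharded_impl(lhs, rhs, *, contracting_dims, batch_dims):
        out = _te_gemm_local(lhs, rhs, contracting_dims=contracting_dims)  # per-shard GEMM
        lhs_cdims, _ = contracting_dims
        lhs_bdims, _ = batch_dims
        # Reduce only over mesh axes that shard *non-batch* contracting dimensions.
        reduce_axes = tuple(
            mesh_axis_of(lhs, dim)                 # hypothetical: mesh axis sharding this dim
            for dim in lhs_cdims
            if dim not in lhs_bdims and mesh_axis_of(lhs, dim) is not None
        )
        if reduce_axes:
            # psum_scatter is used instead when the output sharding asks for a scattered result.
            out = jax.lax.psum(out, reduce_axes)
        return out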

@phu0ngng (Collaborator) commented Jul 9, 2025

I see, I was quite confused and I think we are mixing two concepts here.

In the jax.lax.dot_general(), the batch_dims in the dimension_numbers = (contract_dims, batch_dims) are the batch dimensions in the BatchedGEMM.

Here, you want an argument that indicates which input dimensions are the batch dimensions, but we don't mean to do any BatchedGEMM, so it does not share the same purpose as the batch_dims in the dimension_numbers of jax.lax.dot_general().

I would suggest using a separate argument, i.e. input_batch_dims, to avoid confusion.

@phu0ngng (Collaborator) commented Jul 9, 2025

GemmPrimitive is written to handle arbitrary-dimensional operands in arbitrary layouts.

Just a minor note here. Our dense VJP and DenseGeneral support arbitrary operand dimensions but do not support arbitrary layouts, i.e. only NN is supported. We don't have any plans to support that in the near future.

If someone uses batch_first = False, I think all the FP8 GEMM will proceed incorrectly.

@denera (Collaborator, Author) commented Jul 9, 2025

I would suggest using a separate argument, i.e. input_batch_dims, to avoid confusion.

That's a fair point about BatchedGEMM that I did not think about. I will separate this into contracting_dims and batch_dims kwargs, and document in the docstring that these batched dimensions are not for performing actual batched matrix-multiply.

Just a minor note here. Our dense VJP and DenseGeneral support arbitrary operands' dimensions but do not support any arbitrary layouts, i.e. only NN is supported. We don't have any plans to support that in the near future.

Yes, dense does not support arbitrary layouts, and I added an explicit check in the function for this.

However, GemmPrimitive independently does support arbitrary layouts even though we don't necessarily use all available layouts everywhere in TE/JAX. The custom op is effectively written to do an actual general matrix multiplication without any assumptions about specific use-cases that occur in TE.

If someone uses batch_first = False, I think all the FP8 GEMM will proceed incorrectly.

This may be true for the native-JAX implementations of our GEMM call, but the new custom op works correctly. We don't test for it in our CI pipelines but I did confirm this manually, at least for direct gemm() calls (not through layers).

Tensor scaling is inherently agnostic to logical axis layouts, and the correctness of block scales depends on using the correct flatten_axis for any given logical layout. GemmPrimitive independently infers this flatten_axis from the contracting dimensions, and the underlying XLA custom call uses it to construct a flattened 2D TensorWrapper view of the XLA buffers. So as long as the flatten_axis used to quantize a tensor and the flatten_axis we infer from the contracting dimensions in the GEMM custom op are consistent with each other, the operation supports either batch- or sequence-first inputs.
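
For illustration only, one way such a flatten_axis could be inferred from the contracting dimensions, assuming they form a contiguous leading or trailing block (a sketch of the idea, not necessarily the custom op's exact logic):

    def infer_flatten_axis(ndim, contracting_dims):
        cdims = tuple(dim % ndim for dim in contracting_dims)
        if max(cdims) == ndim - 1:
            # Contracting dims are trailing: flatten everything before them.
            return ndim - len(cdims)
        # Contracting dims are leading: flatten everything after them.
        return len(cdims)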

I did not test whether the fused ops and Flax modules work with sequence-first inputs, but at least at first glance, I don't see anything there that would break as long as the GEMM works (which does with the custom op). The flatten_axis we compute for quantization in these fused ops all look consistent to me with contracting dimensions.

Maybe not in this PR but we should consider eventually adding sequence-first input tests to our CI. Until then, we can perhaps keep the transpose_batch_sequence option in the Flax modules and batch_first option in the underlying fused ops, but either issue a warning or raise a NotImplementedError to make it clear that we do not officially support sequence-first inputs (yet). What do you think?

@phu0ngng (Collaborator) commented Jul 9, 2025

If someone uses batch_first = False, I think all the FP8 GEMM will proceed incorrectly.

To make it clear, the dense VJP with FP8 on Hopper and MXFP8 on Blackwell won't work correctly when batch_first = False, even if the GEMM custom call can handle it, as the quantization layouts will not be correct.

Right now, the quantization layouts are decided based on the recipes and GPU arch with the assumption that one only does NN in the forward. For example, we always have rowwise for x and colwise for kernel in the MXFP8 recipe. When batch_first = False and with these quantize layouts, the inputs are not quantized along the contracting dimension, thus producing wrong results.

I would suggest we keep this argument in the tex.gemm() but do not expose it in the dense() VJP API.

@denera (Collaborator, Author):

To make it clear, the dense VJP with FP8 on Hopper and MXFP8 on Blackwell won't work correctly when batch_first = False as the quantization layouts will not be correct. Right now, the quantization layouts are decided based on the recipes and GPU arch with the assumption that one only does NN in the forward.

Sure, but whether you have a batch- or sequence-first input does not change the GEMM layout. For example, a batch-first QKV projection like (B, S, H) x (H, 3H) = (B, S, 3H) and a sequence first QKV projection like (S, B, H) x (H, 3H) = (S, B, 3H) are both NN-layout GEMM operations (and quantizing it for FP8 on Hopper would transpose the kernel to comply with the NT-layout restriction).

Neither the transpose_batch_sequence option in our TE/Flax modules nor the batch_first option in the underlying fused ops imply complete freedom about where the batch dimension is. They simply tell the module or fused op whether the input tensor should be interpreted as (B, S, H) or (S, B, H).

In layernorm_dense() and layernorm_mlp(), we don't support user-chosen contracting dimensions, require 2D or 3D inputs, and in the case of 3D inputs, choose the batch index as either 0 or 1 depending on batch_first. The contracting dimension here is always 2 for the input and 0 for the kernel, thus maintaining an NN layout GEMM in the forward pass (with kernel transposed during quantization for NT-layout FP8 on Hopper).

In the case of the dense() op, we allow the user to specify contracting dimensions and support arbitrary-dimension inputs, but enforce that they comply with the NN layout. We consider the input to be batched only if its number of dimensions is two more than the number of contracting dimensions, and then choose the batch index from the non-contracting dimensions based on what batch_first is set to.
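
A minimal sketch of that batch-dimension inference (illustrative names, not the PR's exact code):

    def infer_batch_dim(input_ndim, contracting_dims, batch_first):
        # Batched only if there are exactly two leading non-contracting dims,
        # e.g. (B, S, H) or (S, B, H) contracting over H.
        if input_ndim != len(contracting_dims) + 2:
            return None
        return 0 if batch_first else 1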

I would suggest we keep this argument in the tex.gemm() but do not expose it in the dense() VJP API.

I think we should expose it simply because te.flax.DenseGeneral() already exposes a transpose_batch_sequence option that predates this PR. Even if batch_first=False does not work for the underlying fused ops (and I still don't see why it wouldn't work), we should still pass this down through the Flax module and then raise a warning/error inside the fused op for any unsupported behavior. Otherwise, we are just letting it fail silently, which is the worst possible choice.

Collaborator:

Thanks for the explanation. In this case, I agree that we should raise a warning for now and work on the transpose_batch_sequence in the following PR.

@denera (Collaborator, Author) commented Jul 9, 2025

/te-ci JAX L0 L1


@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
    """Test Transformer Engine with BF16"""
    actual = train_and_evaluate(self.args)
    assert actual[0] < 0.535 and actual[1] > 0.73
    assert actual[0] < 0.53 and actual[1] > 0.74
Collaborator:

With two extra epochs I'd expect the loss to improve by more than 0.005. Can you run this with the TE custom-op GEMM disabled to see if the loss and accuracy are in the same range with the JAX GEMMs?

@denera (Collaborator, Author):

I think these tolerances were incorrectly leftover from before I updated the epochs to 5. I updated them after checking again with both native JAX GEMM and our custom GEMM. Thanks for catching it!

import os

if enabled:
    os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
Collaborator:

Can we make this a context instead that stores the previous value of NVTE_JAX_CUSTOM_CALLS_RE and restores it once we exit the context?

If we want to test the whole suite with some TE custom ops disabled via NVTE_JAX_CUSTOM_CALLS_RE, this function will clear that filter for all tests that run after this function is called.

@denera (Collaborator, Author):

Good point, I will update this to a context in all the tests.
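
Something along these lines should do it (a standard-library-only sketch; the helper name is illustrative):

    import os
    from contextlib import contextmanager

    @contextmanager
    def use_jax_gemm(enabled=True):
        """Temporarily disable GemmPrimitive, restoring any pre-existing filter on exit."""
        key = "NVTE_JAX_CUSTOM_CALLS_RE"
        previous = os.environ.get(key)
        if enabled:
            os.environ[key] = "^(?!GemmPrimitive$).+$"
        try:
            yield
        finally:
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous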

)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

assert_allclose(primitive_out, ref_out, dtype=q_dtype)
assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
Collaborator:

Was the dtype here updated to a hardcoded e4m3 because we know that one of the operands to the gemm will always be e4m3 as e5m2 X e5m2 isn't supported?

@denera (Collaborator, Author) commented Jul 9, 2025

Since this is a plain GEMM test with no FWD/BWD mechanism, I hardcoded e4m3 simply because it results in a tighter tolerance. I wanted to test the pure GEMM result a bit more stringently than the layers. Otherwise, the general principle in tests involving FWD/BWD through layers is to test forward outputs with e4m3 tolerances and gradients with e5m2.

assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype)
assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
Collaborator:

Should these dtypes use the bwd_dtype from above instead of being hardcoded to e5m2? The quantizer set above will use e4m3 as the bwd dtype when we have MXFP8.

Collaborator:

Same for tests below where grad assertion is always e5m2

@denera (Collaborator, Author):

Forward output is hardcoded to e4m3 but gradients are hardcoded to e5m2 as I mentioned above, to align test tolerances with FWD/BWD FP8 types.

Collaborator:

But the dtype of the gradients is not always e5m2, right? In the logic above, there is bwd_dtype = jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn. So for MXFP8 this will compute gradients in e4m3 but use a tolerance derived from e5m2, right?

import os

if enabled:
    os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
Collaborator:

Same comment here as for the _use_jax_fp8_gemm defined in test_custom_call_compute.py.

@@ -666,7 +668,7 @@ def _quantize_dbias_impl(
is_outer=True,
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
if force_1x_quantization:
Collaborator:

Why was this line changed in this PR?

Given we mostly if not always force 1x quantization, I think in practice it does give the same result currently.

However, if we ever go back to 2x quantization this will be wrong. Even with 2x quantization, the scale buffers need to be linked, because the TE custom call only writes to one of the scale buffers for tensor scaling (since rowwise and colwise scale will be identical), so we should keep the original check here quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()

Was the original check causing an issue for you that needed fixing in this PR?

@denera (Collaborator, Author):

On line 642, I set force_1x_quantization = (quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() and is_1x_kernel_supported) so functionally it is still performing the same check, with an additional safeguard on whether the TE/common kernel supports 1x.

The original check without this safeguard was attempting to force 1x quantization with the fused quantize_dbias kernel, which is not supported on Hopper. In those situations, we want to follow the should_apply_1x_fused_dbias_war_for_arch_l_100() code path instead to compute dbias via the un-fused native-JAX implementation and invoke the TE/common quantize kernel without dbias fusion.

This only became an issue when dense() with the new GemmPrimitive() had to execute the custom FWD/BWD with BF16 inputs too instead of letting JAX natively differentiate its own jax.lax.dot_general(). The TE/JAX quantize calls needed to return dummy ScaledTensors when the quantizer did not exist or the scaling mode was set to NO_SCALING. That change exposed a particular case of forcing 1x without considering if it's supported by the underlying kernel.

Collaborator:

Good catch on the unsupported kernel that wasn't caught in our testing, that's a good change to improve this.

But separate from whether we force 1x quantization or not, we need to populate the colwise scales with the rowwise scale outputs for tensor scaling. With this code change, if we find out 2x quantization is more performant on certain setups, we would no longer be able to just disable force_1x_quantization; it would break.

The original less restrictive check of quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() is required here. TE common never writes the colwise scale for tensor scaling. It's only an output on our JAX primitive for MXFP8 support.

From cast_transpose.cu, the kernel only takes a single output scale_inv_ptr to write to, which is the rowwise one

template <size_t load_size, size_t store_size, typename IType, typename OType>
__global__ void __launch_bounds__(block_size) cast_transpose_general_kernel(
    const IType *__restrict__ const input, const CType *__restrict__ const noop,
    OType *__restrict__ const output_c, OType *__restrict__ const output_t,
    const CType *__restrict__ const scale_ptr, CType *__restrict__ const amax_ptr,
    CType *__restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) {

The original check covers two cases:

    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

  1. We want to do 2x quantization. In this case the TE kernel populates only the rowwise_scale_inv as output. We set the colwise_scale_inv to the rowwise_scale_inv so it is correct.
  2. 2x quantization was requested (quantizer.is_2x2x()) but we force 1x quantization. This is still covered because even if we update the local variable q_layout to only ROWWISE, the original quantizer.is_2x2x() stays unchanged and we set the colwise_scale_inv correctly.

The updated check in this PR only covers case 2.

@denera (Collaborator, Author):

To recap a conversation off GitHub about this, we've determined that the should_apply_1x_fused_dbias_war_for_arch_l_100() check further up the function was already modified in this PR to account for the failure case of fused-quantize-dbias on Hopper not supporting 1x quantization, so I'm reverting the change with force_1x_quantization later in the function.

denera and others added 3 commits July 9, 2025 19:37
…ernal custom op settings, tightened multi-GPU encoder test tolerances, changed gemm() API to use contracting_dims and batched_dims separately instead of dimension_numbers

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the jax/nvte-cublas-gemm-op branch from 663a574 to 4db82a0 Compare July 9, 2025 20:16
@denera denera added the 2.6.0 label Jul 9, 2025
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera (Collaborator, Author) commented Jul 9, 2025

/te-ci JAX L0 L1

@denera (Collaborator, Author) commented Jul 10, 2025

/te-ci JAX L0 L1

@denera (Collaborator, Author) commented Jul 11, 2025

/te-ci JAX L0 L1

@phu0ngng previously approved these changes Jul 11, 2025