[JAX] GEMM custom op #1855
Conversation
Force-pushed from 1a3fdf3 to f571e0b.
@partial(jax.jit, static_argnums=(6, 7, 8, 9, 10, 11, 12, 13, 14))
def _te_gemm_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, contracting_dims,
                  scaling_mode, lhs_scaled_colwise, rhs_scaled_colwise, fuse_bias, fuse_gelu,
                  grad, accumulate, use_split_accumulator):
    return GemmPrimitive.outer_primitive.bind(
        lhs,
        lhs_scale_inv,
        rhs,
        rhs_scale_inv,
        bias,
        gelu_input,
        contracting_dims=contracting_dims,
        scaling_mode=scaling_mode,
        lhs_scaled_colwise=lhs_scaled_colwise,
        rhs_scaled_colwise=rhs_scaled_colwise,
        fuse_bias=fuse_bias,
        fuse_gelu=fuse_gelu,
        grad=grad,
        accumulate=accumulate,
        use_split_accumulator=use_split_accumulator,
    )

def te_gemm_impl(
    lhs: jax.Array,
    rhs: jax.Array,
    bias: jax.Array = None,
    gelu_input: jax.Array = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1, ), (0, )),
    scaling_mode: ScalingMode = ScalingMode.NO_SCALING,
    lhs_scaled_colwise: bool = False,
    lhs_scale_inv: jax.Array = None,
    rhs_scaled_colwise: bool = False,
    rhs_scale_inv: jax.Array = None,
    fuse_bias: bool = False,
    fuse_gelu: bool = False,
    grad: bool = False,
    accumulate: bool = False,
    use_split_accumulator: bool = False,
):
    r"""
    cuBLAS GEMM custom op.
Hi,
I noticed that you make a separate Python API te_gemm_impl().
Let's use the existing API gemm() and check if the Primitive is not enabled; then, we can go with JAX dot, as in other Primitives. You can have a look at the ActLuPrimitive for example: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/cpp_extensions/activation.py#L1010-L1011.
Then we don't need a separate code path to call this custom op in VJPs, and no additional unit test either, as one can have a parameter with_te_gemm and set up an os.env with RegEx to disable the primitive in the other case.
I think the custom TE path still needs its own ScaledTensor-based wrapper on top of the plain jax.Array-based custom op in order to expose options specific to our custom implementation (e.g. accumulation in FP8, bias and gelu fusions, forward/backward mode switch).
We can call this wrapper from inside gemm() in order to enable easy access to the GEMM custom op, but it will be important to have some way of doing more sophisticated things with this custom op, especially when it comes to unifying communication overlap features down the line.
(e.g. accumulation in FP8, bias and gelu fusions, forward/backward mode switch)
I think this logic belongs to the GEMM rather than the ScaledTensor, so we should not put it into the ScaledTensor.
We could have an additional config dict and pass it to gemm() as an optional argument, or use the kwargs.
I wasn't suggesting we put it into ScaledTensor. I'm just saying that it will be convenient for there to be a dedicated te_gemm() wrapper API that supports ScaledTensor inputs, because the GemmPrimitive is all based on jax.Arrays. We can then call te_gemm() from inside gemm() depending on an environment variable.
To be clear, the structure I have in mind is as follows:
1. GemmPrimitive with jax.Array inputs.
2. _te_gemm() wrapper with jax.Array inputs, decorated with jax.jit.
3. te_gemm() wrapper-on-wrapper that supports ScaledTensor inputs and calls _te_gemm() underneath.
4. Unified gemm() call that directs to either _jax_gemm() or te_gemm() depending on an environment variable.
We can merge _te_gemm and te_gemm together if JAX doesn't have any issues JITting with ScaledTensor inputs.
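A minimal sketch of that layering, using the names from this thread (the bodies are placeholders standing in for the real bind/unpack logic, not the actual implementation):

import os
from functools import partial
import jax


def _jax_gemm(lhs, rhs, contracting_dims=((-1,), (0,))):
    # Pure-JAX fallback built on jax.lax.dot_general; no batched-GEMM dims.
    lhs_c, rhs_c = contracting_dims
    lhs_c = tuple(d % lhs.ndim for d in lhs_c)
    rhs_c = tuple(d % rhs.ndim for d in rhs_c)
    return jax.lax.dot_general(lhs, rhs, ((lhs_c, rhs_c), ((), ())))


@partial(jax.jit, static_argnums=(2,))
def _te_gemm(lhs, rhs, contracting_dims=((-1,), (0,))):
    # jax.Array-level wrapper; the real version binds GemmPrimitive.outer_primitive.
    return _jax_gemm(lhs, rhs, contracting_dims)  # stand-in for the custom-op bind


def te_gemm(lhs, rhs, contracting_dims=((-1,), (0,))):
    # ScaledTensor-aware wrapper: would unpack .data / .scale_inv before calling _te_gemm().
    return _te_gemm(lhs, rhs, contracting_dims)


def gemm(lhs, rhs, contracting_dims=((-1,), (0,))):
    # Unified entry point that dispatches on an environment variable.
    if bool(int(os.getenv("NVTE_JAX_WITH_TE_GEMM", "0"))):
        return te_gemm(lhs, rhs, contracting_dims)
    return _jax_gemm(lhs, rhs, contracting_dims)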
Hi, any reason why we need to jit the _te_gemm() wrapper?
We don't have to have a jit decorator here but we still need an interface with hashable static arguments for it to be jittable in user code, no?
I don't think so.
All our custom calls work with JIT in user code without a separate interface with hashable static arguments.
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
    // Allocate new buffer for swizzled scaling factor
    void *swizzled_scale_inv;
    NVTE_CHECK_CUDA(cudaMallocAsync(&swizzled_scale_inv, product(scale_shape), stream));
Do we want to use JAX's pooled allocator here instead of directly calling malloc? I'm unsure how PyTorch handles this. If JIT'd this may only be called once while making the CUDA graph so allocation perf may be less of an issue, but I'd think using JAX's allocator would still be preferred for managing memory limits, etc.
We could get JAX to allocate scratch space for these by adding additional unused outputs on the inner primitive, similar to our cuBLAS workspace output on the primitive.
We could get JAX to allocate scratch space for these by adding additional unused outputs on the inner primitive, similar to our cuBLAS workspace output on the primitive.
We definitely should. We have already done this in the grouped_gemm PR (https://github.com/phu0ngng/TransformerEngine/blob/grouped_gemm/transformer_engine/jax/cpp_extensions/gemm.py#L107-L109).
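For reference, a hedged sketch of that pattern (the function name, shapes, and dtypes are illustrative, not the actual primitive in this PR): the abstract evaluation simply declares extra outputs that the C++ custom call can use as scratch space, so JAX's pooled allocator owns the buffers instead of cudaMallocAsync.

import jax.numpy as jnp
from jax import core


def gemm_abstract(lhs_aval, rhs_aval, scale_inv_aval, *, out_dtype):
    # Regular GEMM output.
    out_shape = (*lhs_aval.shape[:-1], rhs_aval.shape[-1])
    output = core.ShapedArray(out_shape, out_dtype)
    # Scratch outputs: never returned to the user, but their buffers are
    # allocated by JAX/XLA rather than malloc'd inside the custom call.
    cublas_workspace = core.ShapedArray((33_554_432,), jnp.uint8)  # e.g. 32 MiB
    swizzled_scale_inv = core.ShapedArray(scale_inv_aval.shape, scale_inv_aval.dtype)
    return output, cublas_workspace, swizzled_scale_inv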
Yeah, you're right, we routinely do this for cuBLAS workspaces so it was definitely an oversight that I didn't do it here. Thanks for catching that!
@@ -305,10 +316,508 @@ def gemm(
        If quantizer_set is provided:
            A ScaledTensor containing the quantized matrix multiplication result.
    """
    if bool(int(os.getenv("NVTE_JAX_WITH_TE_GEMM", "0"))):
Thanks for adding an env check! This can be simplified with some primitive logic we already have in place:
if not GemmPrimitive.enabled():
    return te_gemm(...)
return __jax_gem(...)
All our primitives inherit from BasePrimitive in transformer_engine/jax/cpp_extensions/base.py, which has the following check that lets us control whether each primitive uses its TE custom call or a pure-JAX implementation, with a single environment variable and a regex.
class BasePrimitive:
    ...

    @classmethod
    def enabled(cls):
        """
        A custom call is marked as disabled if the `cls.__name__` does not fully match the
        `NVTE_JAX_CUSTOM_CALLS_RE` pattern.
        This uses the Python class name of the primitive definitions that inherit from BasePrimitive.
        By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names.
        For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!DBiasQuantizePrimitive$).+$'` to disable `DBiasQuantizePrimitive`.
        """
        pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*")
        pattern = re.compile(pattern)
        is_enabled = pattern.fullmatch(cls.__name__) is not None
        return is_enabled
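For example, disabling only the GEMM custom call (the same negative-lookahead pattern the unit tests further down use) while keeping every other TE primitive enabled:

import os

# GemmPrimitive does not match the pattern, so only it falls back to pure JAX.
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = r"^(?!GemmPrimitive$).+$"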
Another way around ;)
if not GemmPrimitive.enabled():
    return __jax_gemm(...)
return te_gemm(...)
Good catch, thanks!
Force-pushed from 913ac52 to c2c88d8.
Hi, I suggest adding an option in the unit test to test both the TE custom op GEMM and the native JAX GEMM. One should be able to have a parameter such as with_te_gemm to switch between them.
Please trigger CI afterward.
transformer_engine/jax/dense.py (Outdated)
@@ -48,7 +55,7 @@ def dense(
        Transformed output tensor
    """
    # Remove when tex.quantize() can handle quantizer=None
-   if quantizer_set == noop_quantizer_set:
+   if quantizer_set == noop_quantizer_set and not GemmPrimitive.enabled():
Hi, I would prefer not to use the Primitive outside of its Python binding function.
Besides, we are planning to add NoopQuantizer and Tensor classes to wrap non-FP8 inputs soon-ish. After this work is done, we should be able to have the same VJP code path for all inputs.
Since we don't offer Collective Gemm yet, I don't see a reason to use the GEMM custom call for BF16 now. But if there is a need for it, I suggest adding a helper function in cpp_extensions/gemm.py and importing the helper here instead.
Besides, we are planning to add NoopQuantizer and Tensor classes to wrap non-FP8 inputs soon-ish. After this work is done, we should be able to have the same VJP code path for all inputs.
This PR effectively implements the equivalent of this behavior in order to get GemmPrimitive and JAX dot seamlessly working with the same dense and fused-dense FWD/BWD functions. It's a bit of a shortcut -- normalization, activation, and quantization calls all have a new kwarg that enables wrapping the jax.Array output in a ScaledTensor2x when there is no quantizer. A NoopQuantizer would have been a nicer-looking approach, but I did not want to add more bloat to this PR.
Since we don't offer Collective Gemm yet, I don't see a reason to use the GEMM custom call for BF16 now
Unfortunately, if we don't avoid this BF16 code path when the custom GEMM call is enabled, the dense tests fail because the custom GEMM call doesn't have its own separate FWD/BWD implementation. The dense() call implements exactly that, so it would be redundant to do it all over again in the cpp_extensions.
I added a cached helper function, though, to avoid directly calling GemmPrimitive.enabled() outside cpp_extensions.
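A minimal sketch of such a cached helper, assuming it lives next to the primitive in cpp_extensions/gemm.py (the name is hypothetical; the PR's actual helper may differ):

import os
import re
from functools import lru_cache


@lru_cache(maxsize=None)
def gemm_uses_jax_dot() -> bool:
    # Mirrors BasePrimitive.enabled() for GemmPrimitive without exposing the class,
    # so callers like dense.py never touch the primitive directly.
    pattern = re.compile(os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*"))
    return pattern.fullmatch("GemmPrimitive") is None

Note that caching means the environment variable must be set before the first call (or the cache cleared in tests that toggle it).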
It's a bit of a shortcut -- normalization, activation, and quantization calls all have a new kwarg that enables wrapping the jax.Array output in a ScaledTensor2x when there is no quantizer. A NoopQuantizer would have been a nicer-looking approach, but I did not want to add more bloat to this PR.
I would prefer not to have unnecessary additional kwargs and avoid changing the API unnecessarily.
In our case, the quantizer should be able to carry the scaling mode, which indicates whether it is a noop_scaled_tensor.
We can leave them as they are for now and refactor them when we have NoopQuantizer in the future.
Besides, with our OO design, the scaling mode should carry enough info for usage. I found it not very nice that in the ScaledTensor.post_init, we hardcode the size there.
TransformerEngine/transformer_engine/jax/quantize/tensor.py, lines 144 to 145 in 50d319b:
if self.scaling_mode == ScalingMode.NO_SCALING:
    self.scale_inv = jnp.empty((1,), dtype=jnp.float32)
assert self.rowwise_tensor.dtype == self.colwise_tensor.dtype, (
    "Row-wise and column-wise pair of `ScaledTensor1x`s forming this `ScaledTensor2x` have "
    "different quantized data types."
)
assert self.rowwise_tensor.dq_dtype == self.colwise_tensor.dq_dtype, (
    "Row-wise and column-wise pair of `ScaledTensor1x`s forming this `ScaledTensor2x` have "
    "different de-quantized data types."
)
assert self.rowwise_tensor.scaling_mode == self.colwise_tensor.scaling_mode, (
    "Row-wise and column-wise pair of `ScaledTensor1x`s forming this `ScaledTensor2x` have "
    "different scaling modes."
)
As we want to support DeepSeek and other possible Mixed-ScalingMode recipes, I think we should not enforce these.
That's a fair point. I added these alongside the new ndim, dtype, shape, etc. properties in ScaledTensor2x that primarily pluck the information out of the rowwise tensor, so that there's at least a one-time sanity check of equivalency against the colwise tensor. In practice though, ndim is the only essential property used in the final working code so the rest is just bloat. I'll get rid of it in the next commit.
Commits:
- started GemmPrimitive, abstract done
- gemm custom op working with BF16, needs testing for FP8/MXFP8
- converted TE GEMM API to use ScaledTensor and added os ENV flag to use TE GEMM under general gemm() call
- BF16 tests passing, FP8 tests should be passing but contracting_dims has a scoping issue
- fp8 tests passing for E4M3, getting CUBLAS_STATUS_NOT_SUPPORTED for E5M2
- updated GEMM API to use separate LHS and RHS quantizers instead of a QuantizerSet
- new GemmPrimitive passing all Dense tests
- import cleanup and reverted code chunk movement
- removed unused .transpose() implementations from ScaledTensors
- all custom call tests passing on Hopper, GEMM-related tests cover both GemmPrimitive and native JAX impl
- removed direct calls to GemmPrimitive.enabled() from outside of cpp_extensions
- removed unused changes to ScaledTensor classes and debug prints
Force-pushed from 2fd4203 to da0709a.
/te-ci jax
/te-ci JAX L0
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
    """Test Transformer Engine with BF16"""
    actual = train_and_evaluate(self.args)
-   assert actual[0] < 0.455 and actual[1] > 0.785
+   assert actual[0] < 0.43 and actual[1] > 0.8
Hi,
I'm fine with the 2-digit training loss, but why 1 digit for training accuracy?
In my testing, the loss at epoch 5 drops to ~0.42 and accuracy rises to ~0.81. The thresholds are set to check that with a +0.01 and -0.01 tolerance for loss and accuracy respectively. Would it be more descriptive if we just wrote 0.80 here instead of 0.8? Alternatively, do you prefer a tighter tolerance and if so, what's the appropriate value?
I see. Then we could put 0.80, just to make it clear.
@@ -57,6 +58,7 @@ def layernorm_dense(
        layernorm_input_axes: Logical axes for sharding the layernorm input
        dot_input_axes: Logical axes for sharding the matrix multiplication input
        kernel_axes: Logical axes for sharding the weight matrix
        batch_first: Assume that X is batched in the first dimension.
Hi, why / in which case do we need this batch_first argument?
We need batch_first here in order to be able to invoke gemm() with a complete dimension_numbers argument that includes the correct batch dimensions. And the answer to why we need the dimension_numbers API in gemm() is in your next question below. :)
Our Flax modules already have transpose_batch_sequence options, presumably to match the native Flax API. This was not being used before, but it is used now and propagated down to the underlying function implementations with the new batch_first argument.
As a bonus, I personally like that it aligns our gemm API with native JAX dot_general too, though that's not the motivation behind the change.
        Dtype: Same as input dtype
        If quantizer_set is provided:
            A ScaledTensor containing the quantized matrix multiplication result.
    dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]]] = (((-1,), (0,)), ((), ())),
I'm not sure if we need to expose dimension_numbers in the API instead of contracting_dims.
I don't think we ever need to do batched GEMM in TE, thus the batch_dims is always empty, i.e. ((), ()).
We have to make a decision in GemmPrimitive.partition() about whether the _sharded_impl() wrapper needs to execute jax.lax.psum() or jax.lax.psum_scatter() on the GEMM output. This decision partially depends on whether the non-batched contracting dimensions are sharded, which then requires us to differentiate between batched and non-batched contracting dimensions. If we don't separate these, then we would end up incorrectly triggering a data-parallel reduction in the layer's backward pass instead of letting the data-parallel reduction happen in the optimizer, after the model's backward pass is finished.
jax.nn.scaled_matmul() does the exact same thing too, but it assumes normalized 3D operands that are always in the (batch, leading, contracting) layout, so it does not need the user to give it the dimension numbers. It already has the information. GemmPrimitive is written to handle arbitrary-dimensional operands in arbitrary layouts, so there's no way for us to infer the contracting and batch dimensions automatically like jax.nn.scaled_matmul(). We need it to be given to us in a complete dimension_numbers API.
When I was implementing this, I wondered how jax.lax.dot_general() handles this correctly without being given the batch dimensions. I don't have a concrete answer, but my guess is that it does not actually handle it correctly entirely on its own. I think it inserts a reduction op for the GEMM output whenever there is any sharded contracting dimension (batched or not, doesn't matter), and then relies on XLA to delay the reduction until the optimizer's parameter update. This would work only because jax.lax.dot_general() is a native XLA operation. Custom op primitives do not get the benefit of this, so the reduction would be done where it is originally inserted.
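A hedged sketch of that decision, simplified to per-dimension mesh axis names rather than the sharding objects a real partition rule would work with (all names below are illustrative):

import jax


def _finalize_gemm_output(out, lhs_contracting_dims, lhs_batch_dims, lhs_dim_to_axis,
                          scatter_output=False):
    # lhs_dim_to_axis maps an LHS dimension index to the mesh axis it is sharded
    # over (or None). Only sharded contracting dims that are NOT batch dims mean
    # the local GEMM produced partial sums that must be reduced here.
    reduce_axes = {
        axis
        for d, axis in lhs_dim_to_axis.items()
        if d in lhs_contracting_dims and d not in lhs_batch_dims and axis is not None
    }
    if not reduce_axes:
        return out  # nothing sharded along the contraction: local result is complete
    axis = next(iter(reduce_axes))  # simplified: assume a single mesh axis
    if scatter_output:
        return jax.lax.psum_scatter(out, axis, scatter_dimension=out.ndim - 1, tiled=True)
    return jax.lax.psum(out, axis)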
I see, I was quite confused and I think we are mixing two concepts here.
In jax.lax.dot_general(), the batch_dims in dimension_numbers = (contract_dims, batch_dims) are the batch dimensions in the BatchedGEMM.
Here, you want an argument to indicate which dimensions in the inputs are the batch dimensions, but we don't mean to do any BatchedGEMM, which does not share the same purpose as the batch_dims in the dimension_numbers of jax.lax.dot_general().
I would suggest using a separate argument, i.e. input_batch_dims, to avoid confusion.
GemmPrimitive is written to handle arbitrary-dimensional operands in arbitrary layouts.
Just a minor note here. Our dense VJP and DenseGeneral support arbitrary operand dimensions but do not support arbitrary layouts, i.e. only NN is supported. We don't have any plans to support that in the near future.
If someone uses batch_first = False, I think all the FP8 GEMM will proceed incorrectly.
I would suggest using a separate argument, i.e. input_batch_dims, to avoid confusion.
That's a fair point about BatchedGEMM that I did not think about. I will separate this into contracting_dims and batch_dims kwargs, and document in the docstring that these batched dimensions are not for performing an actual batched matrix-multiply.
Just a minor note here. Our dense VJP and DenseGeneral support arbitrary operands' dimensions but do not support any arbitrary layouts, i.e. only NN is supported. We don't have any plans to support that in the near future.
Yes, dense does not support arbitrary layouts, and I added an explicit check in the function for this. However, GemmPrimitive independently does support arbitrary layouts, even though we don't necessarily use all available layouts everywhere in TE/JAX. The custom op is effectively written to do an actual general matrix multiplication without any assumptions about specific use-cases that occur in TE.
If someone uses batch_first = False, I think all the FP8 GEMM will proceed incorrectly.
This may be true for the native-JAX implementations of our GEMM call, but the new custom op works correctly. We don't test for it in our CI pipelines, but I did confirm this manually, at least for direct gemm() calls (not through layers).
Tensor scaling is inherently agnostic to logical axis layouts, and the correctness of block scales depends on using the correct flatten_axis for any given logical layout. GemmPrimitive independently infers this flatten_axis from the contracting dimensions, and the underlying XLA custom call uses it to construct a flattened 2D TensorWrapper view of the XLA buffers. So as long as the flatten_axis used to quantize a tensor and the flatten_axis we infer from the contracting dimensions in the GEMM custom op are consistent with each other, the operation supports either batch- or sequence-first inputs.
I did not test whether the fused ops and Flax modules work with sequence-first inputs, but at least at first glance, I don't see anything there that would break as long as the GEMM works (which it does with the custom op). The flatten_axis we compute for quantization in these fused ops all looks consistent to me with the contracting dimensions.
Maybe not in this PR, but we should consider eventually adding sequence-first input tests to our CI. Until then, we can perhaps keep the transpose_batch_sequence option in the Flax modules and the batch_first option in the underlying fused ops, but either issue a warning or raise a NotImplementedError to make it clear that we do not officially support sequence-first inputs (yet). What do you think?
If someone uses batch_first = False, I think all the FP8 GEMM will proceed incorrectly.
To make it clear, the dense VJP with FP8 on Hopper and MXFP8 on Blackwell won't work correctly when batch_first = False, even if the GEMM custom call can handle it, as the quantization layouts will not be correct.
Right now, the quantization layouts are decided based on the recipes and GPU arch with the assumption that one only does NN in the forward. For example, we always have rowwise for x and colwise for kernel in the MXFP8 recipe. When batch_first = False and with these quantize layouts, the inputs are not quantized along the contracting dimension, thus producing wrong results.
I would suggest we keep this argument in the tex.gemm() but do not expose it in the dense() VJP API.
To make it clear, the dense VJP with FP8 on Hopper and MXFP8 on Blackwell won't work correctly when batch_first = False as the quantization layouts will not be correct. Right now, the quantization layouts are decided based on the recipes and GPU arch with the assumption that one only does NN in the forward.
Sure, but whether you have a batch- or sequence-first input does not change the GEMM layout. For example, a batch-first QKV projection like (B, S, H) x (H, 3H) = (B, S, 3H) and a sequence-first QKV projection like (S, B, H) x (H, 3H) = (S, B, 3H) are both NN-layout GEMM operations (and quantizing them for FP8 on Hopper would transpose the kernel to comply with the NT-layout restriction).
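A small standalone illustration of that point with plain dot_general (shapes are arbitrary):

import jax
import jax.numpy as jnp

B, S, H = 2, 8, 16
w = jnp.ones((H, 3 * H), jnp.bfloat16)
# Contract the input's last dim with the kernel's first dim; no batched-GEMM dims.
dn = (((2,), (0,)), ((), ()))

y_bsh = jax.lax.dot_general(jnp.ones((B, S, H), jnp.bfloat16), w, dn)  # (B, S, 3H)
y_sbh = jax.lax.dot_general(jnp.ones((S, B, H), jnp.bfloat16), w, dn)  # (S, B, 3H)
assert y_bsh.shape == (B, S, 3 * H) and y_sbh.shape == (S, B, 3 * H)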
Neither the transpose_batch_sequence option in our TE/Flax modules nor the batch_first option in the underlying fused ops implies complete freedom about where the batch dimension is. They simply tell the module or fused op whether the input tensor should be interpreted as (B, S, H) or (S, B, H).
In layernorm_dense() and layernorm_mlp(), we don't support user-chosen contracting dimensions, require 2D or 3D inputs, and in the case of 3D inputs, choose the batch index as either 0 or 1 depending on batch_first. The contracting dimension here is always 2 for the input and 0 for the kernel, thus maintaining an NN-layout GEMM in the forward pass (with the kernel transposed during quantization for NT-layout FP8 on Hopper).
In the case of the dense() op, we allow the user to specify contracting dimensions and support arbitrary-dimension inputs, but enforce them to comply with NN layout. We consider the input to be batched only if the input's number of dimensions is at least 2 more than the number of contracting dimensions, and then choose the batch index from the non-contracting dimensions based on what batch_first is set to.
I would suggest we keep this argument in the tex.gemm() but do not expose it in the dense() VJP API.
I think we should expose it simply because te.flax.DenseGeneral() already exposes a transpose_batch_sequence option that predates this PR. Even if batch_first=False does not work for the underlying fused ops (and I still don't see why it wouldn't work), we should still pass this down through the Flax module and then raise a warning/error inside the fused op for any unsupported behavior. Otherwise, we are just letting it fail silently, which is the worst possible choice.
Thanks for the explanation. In this case, I agree that we should raise a warning for now and work on transpose_batch_sequence in the following PR.
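A minimal sketch of that warning, assuming it sits at the top of the fused-op entry points (exact message and placement to be decided in the follow-up):

import warnings


def _warn_if_sequence_first(batch_first: bool):
    if not batch_first:
        warnings.warn(
            "batch_first=False (sequence-first inputs) is not covered by CI yet and is "
            "not officially supported outside the TE GEMM custom op.",
            UserWarning,
        )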
/te-ci JAX L0 L1
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
    """Test Transformer Engine with BF16"""
    actual = train_and_evaluate(self.args)
-   assert actual[0] < 0.535 and actual[1] > 0.73
+   assert actual[0] < 0.53 and actual[1] > 0.74
With two extra epochs I'd expect the loss to improve by more than 0.005. Can you run this with the TE custom op GEMM disabled to see if the loss and accuracy are in the same range with the JAX GEMMs?
I think these tolerances were incorrectly leftover from before I updated the epochs to 5. I updated them after checking again with both native JAX GEMM and our custom GEMM. Thanks for catching it!
import os

if enabled:
    os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
Can we make this a context manager instead that stores the previous value of NVTE_JAX_CUSTOM_CALLS_RE and restores it once we exit the context?
If we want to test the whole suite with some TE custom ops disabled via NVTE_JAX_CUSTOM_CALLS_RE, this function will clear that filter for all tests that run after it is called.
Good point, I will update this to a context in all the tests.
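A sketch of that context manager (the helper name is illustrative; the actual tests may name it differently):

import os
from contextlib import contextmanager


@contextmanager
def use_jax_gemm(enabled: bool = True):
    # Temporarily filter out GemmPrimitive, then restore the previous regex on exit.
    var = "NVTE_JAX_CUSTOM_CALLS_RE"
    previous = os.environ.get(var)
    try:
        if enabled:
            os.environ[var] = "^(?!GemmPrimitive$).+$"
        yield
    finally:
        if previous is None:
            os.environ.pop(var, None)
        else:
            os.environ[var] = previous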
)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

-assert_allclose(primitive_out, ref_out, dtype=q_dtype)
+assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
Was the dtype here updated to a hardcoded e4m3 because we know that one of the operands to the gemm will always be e4m3 as e5m2 X e5m2 isn't supported?
Since this is a plain GEMM test with no FWD/BWD mechanism, I hardcoded e4m3 simply because it results in a tighter tolerance. I wanted to test the pure GEMM result a bit more stringently than the layers. Otherwise, the general principle in tests involving FWD/BWD through layers is to test forward outputs with e4m3 tolerances and gradients with e5m2.
assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype)
assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
Should these dtypes use bwd_dtype above instead of being hardcoded to e5m2? The quantizer set above will use e4m3 as the bwd dtype when we have MXFP8.
Same for tests below where grad assertion is always e5m2
Forward output is hardcoded to e4m3 but gradients are hardcoded to e5m2 as I mentioned above, to align test tolerances with FWD/BWD FP8 types.
But the dtype of the gradients is not always e5m2, right? In the logic above, there is bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn. So for MXFP8 this will compute gradients in e4m3 but have a tolerance from e5m2, right?
import os

if enabled:
    os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
Same comment here as for the _use_jax_fp8_gemm defined in test_custom_call_compute.py.
@@ -666,7 +668,7 @@ def _quantize_dbias_impl(
        is_outer=True,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
-   if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
+   if force_1x_quantization:
Why was this line changed in this PR?
Given we mostly if not always force 1x quantization, I think in practice it does give the same result currently.
However, if we ever go back to 2x quantization this will be wrong. Even with 2x quantization, the scale buffers need to be linked, because the TE custom call only writes to one of the scale buffers for tensor scaling (since the rowwise and colwise scales will be identical), so we should keep the original check here: quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x().
Was the original check causing an issue for you that needed fixing in this PR?
On line 642, I set force_1x_quantization = (quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() and is_1x_kernel_supported), so functionally it is still performing the same check, with an additional safeguard on whether the TE/common kernel supports 1x.
The original check without this safeguard was attempting to force 1x quantization with the fused quantize_dbias kernel, which is not supported on Hopper. In those situations, we want to follow the should_apply_1x_fused_dbias_war_for_arch_l_100() code path instead to compute dbias via the un-fused native-JAX implementation and invoke the TE/common quantize kernel without dbias fusion.
This only became an issue when dense() with the new GemmPrimitive() had to execute the custom FWD/BWD with BF16 inputs too, instead of letting JAX natively differentiate its own jax.lax.dot_general(). The TE/JAX quantize calls needed to return dummy ScaledTensors when the quantizer did not exist or the scaling mode was set to NO_SCALING. That change exposed a particular case of forcing 1x without considering if it's supported by the underlying kernel.
Good catch on the unsupported kernel that wasn't caught in our testing, that's a good change to improve this.
But separate from whether we force 1x quantization or not, we need to populate colwise scales with the rowwise scale outputs for tensor scaling. With this code change, if we find out 2x quantization is more performant on certain setups, we would no longer be able to just disable force_1x_quantization; it'd break.
The original, less restrictive check of quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() is required here. TE common never writes the colwise scale for tensor scaling. It's only an output on our JAX primitive for MXFP8 support.
From cast_transpose.cu, the kernel only takes a single output scale_inv_ptr to write to, which is the rowwise one:
template <size_t load_size, size_t store_size, typename IType, typename OType>
__global__ void __launch_bounds__(block_size) cast_transpose_general_kernel(
    const IType *__restrict__ const input, const CType *__restrict__ const noop,
    OType *__restrict__ const output_c, OType *__restrict__ const output_t,
    const CType *__restrict__ const scale_ptr, CType *__restrict__ const amax_ptr,
    CType *__restrict__ const scale_inv_ptr, const size_t row_length, const size_t num_rows) {
The original check covers two cases:
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
    colwise_scale_inv = rowwise_scale_inv
1. We want to do 2x quantization. In this case the TE kernel populates only the rowwise_scale_inv as output. We set the colwise_scale_inv to the rowwise_scale_inv so it is correct.
2. 2x quantization was requested (quantizer.is_2x2x()) but we force 1x quantization. This still covers this case because even if we update the local variable q_layout to only ROWWISE, the original quantizer.is_2x2x() stays unchanged and we set the colwise_scale_inv correctly.
The updated check in this PR only covers case 2.
To recap a conversation off GitHub about this, we've determined that the should_apply_1x_fused_dbias_war_for_arch_l_100() check further up the function was already modified in this PR to account for the failure case of fused-quantize-dbias on Hopper not supporting 1x quantization, so I'm reverting the change with force_1x_quantization later in the function.
…ernal custom op settings, tightened multi-GPU encoder test tolerances, changed gemm() API to use contracting_dims and batched_dims separately instead of dimension_numbers
Force-pushed from 663a574 to 4db82a0.
/te-ci JAX L0 L1
Description
XLA custom op + JAX primitive for nvte_cublas_gemm. Custom partitioning rules closely follow jax.nn.scaled_matmul, and the sharded implementation performs jax.lax.psum or jax.lax.psum_scatter under the same conditions. Also implements Shardy rules in preparation for JAX switching its default partitioner to Shardy in July.
IMPORTANT: Padded block scales are currently not compatible with Shardy because the padding breaks the size relationships we define with CompoundFactors in the sharding rules. Since Shardy rules are applied to the inner primitive, the only way to make these CompoundFactors work correctly is to lower unpadded scales into the XLA custom call and pad them in C++. This PR does not implement this approach. Instead, it removes the CompoundFactors entirely and de-couples the scaling factor rules from the related tensor rules. This should not cause any issues as long as the scales are sharded correctly outside the custom ops, but it would be safer in the long run to implement scale padding in C++ and restore CompoundFactor-based rules.
Type of change
Checklist: