-
Notifications
You must be signed in to change notification settings - Fork 451
[JAX] Collective GEMM custom op + primitive + minimal supporting functions #1846
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
denera
wants to merge
32
commits into
NVIDIA:main
Choose a base branch
from
denera:jax/collective-gemm-api
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This was referenced Jun 3, 2025
1a845e9
to
e92c81a
Compare
Signed-off-by: Alp Dener <adener@nvidia.com> started GemmPrimitive, abstract done Signed-off-by: Alp Dener <adener@nvidia.com> gemm custom op working with BF16, needs testing for FP8/MXFP8 Signed-off-by: Alp Dener <adener@nvidia.com> converted TE GEMM API to use ScaledTensor and added os ENV flag to use TE GEMM under general gemm() call Signed-off-by: Alp Dener <adener@nvidia.com> BF16 tests passing, FP8 tests should be passing but contracting_dims has a scoping issue Signed-off-by: Alp Dener <adener@nvidia.com> fp8 tests passing for E4M3, getting CUBLAS_STATUS_NOT_SUPPORTED for E5M2 Signed-off-by: Alp Dener <adener@nvidia.com> updated GEMM API to use separate LHS and RHS quantizers instead of a QuantizerSet Signed-off-by: Alp Dener <adener@nvidia.com> new GemmPrimitive passing all Dense tests Signed-off-by: Alp Dener <adener@nvidia.com> import cleanup and reverted code chunk movement Signed-off-by: Alp Dener <adener@nvidia.com> removed unused .transpose() implementations from ScaledTensors Signed-off-by: Alp Dener <adener@nvidia.com> all custom call tests passing on Hopper, GEMM-related tests cover both GemmPrimitive and native JAX impl Signed-off-by: Alp Dener <adener@nvidia.com> removed direct calls to GemmPrimitive.enabled() from outside of cpp_extensions Signed-off-by: Alp Dener <adener@nvidia.com> removed unused changes to ScaledTensor classes and debug prints Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <adener@nvidia.com>
…erEngine into jax/nvte-cublas-gemm-op
… Blackwell, MXFP8 has issues with E5M2 Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <adener@nvidia.com> all unit tests passing on H100x8 node Signed-off-by: Alp Dener <adener@nvidia.com> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci linting fixes Signed-off-by: Alp Dener <adener@nvidia.com> fixed batch dimension numbers Signed-off-by: Alp Dener <adener@nvidia.com> fixed FP8 scale sharding rule when there are no FP8 scales Signed-off-by: Alp Dener <adener@nvidia.com> added error message for unsupported Shardy partitioner Signed-off-by: Alp Dener <adener@nvidia.com> fixed test tolerances for FP8 cases Signed-off-by: Alp Dener <adener@nvidia.com> fixed shardy test skip cases Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
…rtitioning rules work correctly Signed-off-by: Alp Dener <adener@nvidia.com>
…d GemmPrimitive to accept unpadded scales and pad them after sharding Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <adener@nvidia.com>
4575b98
to
77eaa63
Compare
Signed-off-by: Alp Dener <adener@nvidia.com>
6b9dc0e
to
aeddd66
Compare
Signed-off-by: Alp Dener <adener@nvidia.com>
49536b2
to
74ab649
Compare
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
3f1214e
to
b4ff961
Compare
Signed-off-by: Alp Dener <adener@nvidia.com>
6762c45
to
95564fc
Compare
for more information, see https://pre-commit.ci
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
This PR introduces a new XLA custom op for calling
nvte_cublas_gemm
or related comm+GEMM overlap algorithms, the accompanying JAX primitive, and bare minimum Python wrappers required to work with the custom call.FWD/BWD autograd implementation will be tackled in a separate upcoming PR.
Type of change
Checklist: