Skip to content

[JAX] Collective GEMM custom op + primitive + minimal supporting functions #1846

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 32 commits into
base: main
Choose a base branch
from

Conversation

denera
Copy link
Collaborator

@denera denera commented Jun 3, 2025

Description

This PR introduces a new XLA custom op for calling nvte_cublas_gemm or related comm+GEMM overlap algorithms, the accompanying JAX primitive, and bare minimum Python wrappers required to work with the custom call.

FWD/BWD autograd implementation will be tackled in a separate upcoming PR.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

denera and others added 22 commits June 13, 2025 04:55
Signed-off-by: Alp Dener <adener@nvidia.com>

started GemmPrimitive, abstract done

Signed-off-by: Alp Dener <adener@nvidia.com>

gemm custom op working with BF16, needs testing for FP8/MXFP8

Signed-off-by: Alp Dener <adener@nvidia.com>

converted TE GEMM API to use ScaledTensor and added os ENV flag to use TE GEMM under general gemm() call

Signed-off-by: Alp Dener <adener@nvidia.com>

BF16 tests passing, FP8 tests should be passing but contracting_dims has a scoping issue

Signed-off-by: Alp Dener <adener@nvidia.com>

fp8 tests passing for E4M3, getting CUBLAS_STATUS_NOT_SUPPORTED for E5M2

Signed-off-by: Alp Dener <adener@nvidia.com>

updated GEMM API to use separate LHS and RHS quantizers instead of a QuantizerSet

Signed-off-by: Alp Dener <adener@nvidia.com>

new GemmPrimitive passing all Dense tests

Signed-off-by: Alp Dener <adener@nvidia.com>

import cleanup and reverted code chunk movement

Signed-off-by: Alp Dener <adener@nvidia.com>

removed unused .transpose() implementations from ScaledTensors

Signed-off-by: Alp Dener <adener@nvidia.com>

all custom call tests passing on Hopper, GEMM-related tests cover both GemmPrimitive and native JAX impl

Signed-off-by: Alp Dener <adener@nvidia.com>

removed direct calls to GemmPrimitive.enabled() from outside of cpp_extensions

Signed-off-by: Alp Dener <adener@nvidia.com>

removed unused changes to ScaledTensor classes and debug prints

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
… Blackwell, MXFP8 has issues with E5M2

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>

all unit tests passing on H100x8 node

Signed-off-by: Alp Dener <adener@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

linting fixes

Signed-off-by: Alp Dener <adener@nvidia.com>

fixed batch dimension numbers

Signed-off-by: Alp Dener <adener@nvidia.com>

fixed FP8 scale sharding rule when there are no FP8 scales

Signed-off-by: Alp Dener <adener@nvidia.com>

added error message for unsupported Shardy partitioner

Signed-off-by: Alp Dener <adener@nvidia.com>

fixed test tolerances for FP8 cases

Signed-off-by: Alp Dener <adener@nvidia.com>

fixed shardy test skip cases

Signed-off-by: Alp Dener <adener@nvidia.com>
…rtitioning rules work correctly

Signed-off-by: Alp Dener <adener@nvidia.com>
…d GemmPrimitive to accept unpadded scales and pad them after sharding

Signed-off-by: Alp Dener <adener@nvidia.com>
denera and others added 4 commits June 27, 2025 16:39
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the jax/collective-gemm-api branch from 4575b98 to 77eaa63 Compare July 2, 2025 07:23
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the jax/collective-gemm-api branch from 6b9dc0e to aeddd66 Compare July 4, 2025 07:42
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the jax/collective-gemm-api branch from 49536b2 to 74ab649 Compare July 4, 2025 07:57
denera added 2 commits July 4, 2025 08:22
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the jax/collective-gemm-api branch from 3f1214e to b4ff961 Compare July 4, 2025 09:03
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the jax/collective-gemm-api branch from 6762c45 to 95564fc Compare July 4, 2025 10:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants