
[JAX] Support Flax sharding constraints #1933

Open · wants to merge 10 commits into main

Conversation

jberchtold-nvidia
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented Jul 7, 2025

Description

Updates our with_sharding_constraint_by_logical_axes helper function to support and prefer Flax logical axis rules when they exist in the current context. If no Flax logical axis rules exist in the current context, it falls back to TE's hardcoded logical axes, though this fallback is now deprecated and will be removed in the future.
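To make the preference order concrete, here is a plain-Python sketch of the dispatch logic; the function name and rule tables below are illustrative assumptions, not TE's actual internals:

```python
# Hypothetical sketch only: resolve_sharding and the rule tables are
# illustrative names, not part of TE's API.

def resolve_sharding(logical_axes, flax_rules, te_hardcoded_rules):
    """Map logical axes to mesh axes, preferring Flax rules when the
    current context provides any; otherwise fall back to TE's hardcoded
    table (the deprecated path)."""
    if flax_rules:  # Flax logical axis rules exist in the current context
        table = dict(flax_rules)
        return [table.get(axis) for axis in logical_axes]
    # Deprecated fallback: TE's hardcoded logical axes
    return [te_hardcoded_rules.get(axis) for axis in logical_axes]
```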

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Update with_sharding_constraint_by_logical_axes to check whether Flax logical axis rules exist in the current context and, if so, apply the sharding constraint via flax.linen.with_logical_constraint
  • Update unfused attention in transformers.py to remove a duplicate SEQLEN_AXIS usage within a single sharding constraint. This went unflagged before because TE's hardcoded axis system didn't check for duplicates and the axis was always mapped to None/replicated.
  • Corrected the logical axis rule context setup in the encoder examples
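As a plain-Python illustration of why the duplicate SEQLEN_AXIS constraint had to be removed: once logical axes resolve to real mesh axes, a mesh axis may appear at most once in a partition spec, so duplicates fail rather than being silently replicated. The helper below is hypothetical, not TE code:

```python
def has_duplicate_mesh_axes(resolved_axes):
    """Return True if any mesh axis (ignoring None/replicated entries)
    appears more than once; a real sharding constraint would reject this.
    Hypothetical helper for illustration only."""
    named = [axis for axis in resolved_axes if axis is not None]
    return len(named) != len(set(named))
```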

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

jberchtold-nvidia and others added 2 commits July 7, 2025 10:32
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia changed the title [DRAFT] [JAX] Support Flax sharding constraints [JAX] Support Flax sharding constraints Jul 10, 2025
@jberchtold-nvidia
Collaborator Author

/te-ci L0 L1

jberchtold-nvidia and others added 3 commits July 10, 2025 13:00
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L0 L1


@jberchtold-nvidia
Collaborator Author

/te-ci L0 L1 jax

@jberchtold-nvidia
Collaborator Author

/te-ci L0 L1

@phu0ngng
Collaborator

/te-ci L1

Comment on lines +274 to +276
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh, flax.linen.logical_axis_rules(te_extended_axis_rules):
Collaborator

Hi,
I'm curious to learn what the difference is between having the jax.sharding.Mesh context before vs after the te.fp8_autocast context.

Collaborator Author

Afaik the order of the Mesh vs. fp8_autocast doesn't matter, as long as both are entered before your model runs. But the order of te_flax.extend_logical_axis_rules and fp8_autocast does matter, according to the docstring of the former.

So the ordering needs to be:

  1. fp8_autocast context
  2. Create the Flax logical axis rule context. To support TE's hardcoded axis system in Flax (which UnfusedDotProductAttention requires), we extend the logical axis rules with TE's rule table via te_flax.extend_logical_axis_rules, which must be inside an fp8_autocast.
  3. Create and initialize the model, training loop, etc.

Afaik, the Mesh can come anywhere before item 3.

I just moved fp8_autocast to the top and merged the Mesh into the with block for the logical axis rule context in item 2 to reduce indentation. But if a smaller diff is preferred, I can do Mesh -> fp8_autocast -> logical axis rules.
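To make the required nesting concrete without pulling in TE, here is a minimal stand-in: the context names are placeholders for te.fp8_autocast, jax.sharding.Mesh, and flax.linen.logical_axis_rules wrapping te_flax.extend_logical_axis_rules, and the bookkeeping is purely illustrative:

```python
from contextlib import contextmanager

entered = []  # records which stand-in contexts are currently active

@contextmanager
def ctx(name):
    # Placeholder for the real context managers named above.
    entered.append(name)
    try:
        yield
    finally:
        entered.remove(name)

def build_and_run_model():
    # extend_logical_axis_rules reads FP8 state, so fp8_autocast must
    # already be active when the axis rules are installed; the Mesh only
    # has to be active before the model runs.
    assert entered.index("fp8_autocast") < entered.index("logical_axis_rules")
    assert "mesh" in entered
    return list(entered)

with ctx("fp8_autocast"):                         # 1. outermost
    with ctx("mesh"), ctx("logical_axis_rules"):  # 2. Mesh + axis rules
        order = build_and_run_model()             # 3. model
```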
