[JAX] Support Flax sharding constraints #1933
base: main
Conversation
```python
with jax.sharding.Mesh(
    devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh, flax.linen.logical_axis_rules(te_extended_axis_rules):
```
Hi, I'm curious to learn what the difference is between having the `jax.sharding.Mesh` context before vs. after the `te.fp8_autocast` context.
Afaik the order of the `Mesh` vs. `fp8_autocast` doesn't matter, as long as both come before your model. But the order of `te_flax.extend_logical_axis_rules` and `fp8_autocast` does matter, according to the docstring of the former. So the ordering needs to be:

1. `fp8_autocast` context
2. Create the Flax logical axis rule context. To support TE's hardcoded axis system in Flax (which `UnfusedDotProductAttention` requires), we extend the logical axis rules with TE's rule table via `te_flax.extend_logical_axis_rules`, which must be inside an `fp8_autocast`.
3. Create and initialize the model, training loop, etc.

Afaik, the `Mesh` can come anywhere before item 3. I just pulled the `fp8_autocast` up to the top and merged the `Mesh` with the `with` block for the logical axis rule context in item 2 to reduce indentation (see the sketch below). But if a smaller diff is preferred, I can do `Mesh` -> `fp8_autocast` -> logical axis rules.
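A minimal sketch of that ordering, assuming 8 local devices; the axis names, mesh shape, and user rule table (`"batch"`/`"embed"`) are illustrative placeholders, not taken from this PR:

```python
import jax
import numpy as np
import flax.linen as nn
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

DEVICE_DP_AXIS, DEVICE_TP_AXIS = "data", "model"
device_mesh = np.asarray(jax.devices()).reshape((4, 2))  # assumes 8 devices

# 1. Enter fp8_autocast first; per its docstring,
#    te_flax.extend_logical_axis_rules must run inside it.
with te.fp8_autocast(enabled=True):
    # 2. Extend the user's logical axis rules with TE's rule table, then
    #    enter the Mesh and the Flax rule context in one `with` statement.
    te_extended_axis_rules = te_flax.extend_logical_axis_rules(
        (("batch", DEVICE_DP_AXIS), ("embed", DEVICE_TP_AXIS))
    )
    with jax.sharding.Mesh(
        devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
    ) as mesh, nn.logical_axis_rules(te_extended_axis_rules):
        # 3. Create and initialize the model, training loop, etc.
        ...
```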
Description
Updates our `with_sharding_constraint_by_logical_axes` helper function to support and prefer Flax logical axis rules when they exist in the current context. If no Flax logical axis rules exist in the current context, it falls back to TE's hardcoded logical axes, though this fallback is now deprecated and will be removed in the future.

Type of change
Changes
- Updates `with_sharding_constraint_by_logical_axes` to check whether Flax logical axis rules exist in the current context, and if so apply the sharding constraint with `flax.linen.with_logical_constraint` (see the sketch below)
- Updates `transformers.py` to remove a duplicate SEQLEN_AXIS usage in a single sharding constraint. This was not flagged before since our TE hardcoded axis system didn't check for this and it was always mapped to `None`/replicated.
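A minimal sketch of the preference/fallback logic described above, assuming Flax's public `flax.linen.get_logical_axis_rules` and `flax.linen.with_logical_constraint` helpers; `_te_hardcoded_sharding_constraint` is a hypothetical stand-in for the deprecated TE path, and none of this is the PR's exact code:

```python
import flax.linen as nn

def with_sharding_constraint_by_logical_axes(x, logical_axis_names):
    # No axes given: nothing to constrain.
    if logical_axis_names is None:
        return x
    # Prefer Flax logical axis rules when the caller has entered a
    # flax.linen.logical_axis_rules(...) context (non-empty rule table).
    if nn.get_logical_axis_rules():
        return nn.with_logical_constraint(x, logical_axis_names)
    # Deprecated fallback: TE's hardcoded logical-axis table
    # (hypothetical helper standing in for the existing TE code).
    return _te_hardcoded_sharding_constraint(x, logical_axis_names)
```

Under real Flax rules, a repeated logical axis can resolve to the same mesh axis twice, which JAX rejects in a `PartitionSpec`; with TE's old table the duplicate mapped to `None` and went unnoticed, hence the `transformers.py` fix above.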
Checklist: