[JAX] Support Flax sharding constraints #1933

Open · wants to merge 10 commits into main
49 changes: 25 additions & 24 deletions examples/jax/encoder/test_model_parallel_encoder.py
@@ -261,36 +261,37 @@ def train_and_evaluate(args):
fp8_recipe = None

device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh, nn_partitioning.axis_rules(
((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
with te.fp8_autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}

input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]

with te.fp8_autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
):
# Get the base axis rules and extend them with TE's rules.
axis_rules = flax.linen.get_logical_axis_rules()
axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)

with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh, flax.linen.logical_axis_rules(te_extended_axis_rules):
Comment on lines +274 to +276
Collaborator


Hi,
I'm curious to learn what the difference is between having the jax.sharding.Mesh context before vs after the te.fp8_autocast context.

Collaborator Author


Afaik the order of Mesh vs. fp8_autocast doesn't matter, as long as both come before your model. But the order of te_flax.extend_logical_axis_rules and fp8_autocast does matter, according to the docstring of the former.

So the ordering needs to be:

  1. fp8_autocast context
  2. Create the Flax logical axis rule context. To support TE's hardcoded axis names in Flax (which UnfusedDotProductAttention requires), we extend the logical axis rules with TE's rule table via te_flax.extend_logical_axis_rules, which must be called inside an fp8_autocast.
  3. Create and initialize the model, training loop, etc.

Afaik, the Mesh can come anywhere before item 3.

I just pulled the fp8_autocast up to the top and merged the Mesh with the with block for the logical axis rule context in item 2 to reduce indentation. But if a smaller diff is preferred, I can do Mesh -> fp8_autocast -> logical axis rules instead.
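A minimal sketch of that ordering, assuming the example's axis-name constants and an 8-GPU (4x2) mesh; the concrete axis names, mesh sizes, and the enabled=False flag are illustrative, not taken from this PR:

```python
import jax
from flax import linen as nn
from jax.experimental import mesh_utils
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

# Axis names and mesh sizes mirroring the example script (illustrative values).
DEVICE_DP_AXIS, DEVICE_TP_AXIS = "data", "model"
NAMED_BROADCAST_AXIS, NAMED_TP_AXIS = "my_broadcast_axis", "my_tp_axis"
num_gpu_dp, num_gpu_tp = 4, 2

device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))

# 1. fp8_autocast first, so TE's rule table is available for step 2.
with te.fp8_autocast(
    enabled=False,
    mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
):
    # 2. Extend the Flax logical axis rules with TE's rules (inside fp8_autocast),
    #    then open the Mesh and the logical-axis-rules contexts together.
    rules = te_flax.extend_logical_axis_rules(
        ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
    )
    with jax.sharding.Mesh(
        devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
    ) as mesh, nn.logical_axis_rules(rules):
        # 3. Create and initialize the model, build the train/eval steps, run the loop.
        ...
```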


print(f"Device mesh: {mesh}")
print(f"Axis rules: {te_extended_axis_rules}")

rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}

input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]

encoder = Net(num_embed, args.enable_sp)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)

# Get the base axis rules and extend them with TE's rules.
axis_rules = nn_partitioning.get_axis_rules()
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
print(f"Device mesh: {mesh}")
print(f"Axis rules: {te_extended_axis_rules}")

logical_partition_spec = nn.get_partition_spec(abs_var_collect)

# Note that `nn.logical_to_mesh_sharding` returns a dict with an extra
161 changes: 85 additions & 76 deletions examples/jax/encoder/test_multigpu_encoder.py
@@ -274,89 +274,98 @@ def train_and_evaluate(args):
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)

sharding_rules = te_flax.extend_logical_axis_rules(tuple())
params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))

in_shardings = (None, inputs_sharding, masks_sharding)
out_shardings = {
key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
}
jit_encoder_init = jax.jit(
encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
)
var_collect = jit_encoder_init(init_rngs, inputs, masks)

optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(
apply_fn=encoder.apply, params=params, tx=optimizer
)
state_sharding = get_state_sharding(state, params_sharding)
labels_sharding = NamedSharding(
mesh,
PartitionSpec(
DEVICE_DP_AXIS,
),
)
in_shardings = (
state_sharding,
inputs_sharding,
masks_sharding,
labels_sharding,
None,
None,
)
out_shardings = (state_sharding, None, None, None)
jit_train_step = jax.jit(
train_step, in_shardings=in_shardings, out_shardings=out_shardings
)

in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
out_shardings = (None, None)
jit_eval_step = jax.jit(
eval_step, in_shardings=in_shardings, out_shardings=out_shardings
)

if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
check_fp8(state, var_collect, inputs, masks, labels)

if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
jit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None

for epoch in range(1, args.epochs + 1):
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}

state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
with flax.linen.logical_axis_rules(sharding_rules):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)

params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
masks_sharding = NamedSharding(
mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None)
)

test_loss, test_accuracy = eval_model(
state, test_ds, args.test_batch_size, var_collect, jit_eval_step
in_shardings = (None, inputs_sharding, masks_sharding)
out_shardings = {
key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
}
jit_encoder_init = jax.jit(
encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
)
var_collect = jit_encoder_init(init_rngs, inputs, masks)

print(
f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} "
optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(
apply_fn=encoder.apply, params=params, tx=optimizer
)
state_sharding = get_state_sharding(state, params_sharding)
labels_sharding = NamedSharding(
mesh,
PartitionSpec(
DEVICE_DP_AXIS,
),
)
in_shardings = (
state_sharding,
inputs_sharding,
masks_sharding,
labels_sharding,
None,
None,
)
out_shardings = (state_sharding, None, None, None)
jit_train_step = jax.jit(
train_step, in_shardings=in_shardings, out_shardings=out_shardings
)

in_shardings = (
state_sharding,
inputs_sharding,
masks_sharding,
labels_sharding,
None,
)
out_shardings = (None, None)
jit_eval_step = jax.jit(
eval_step, in_shardings=in_shardings, out_shardings=out_shardings
)

return [train_loss, train_accuracy, test_loss, test_accuracy]
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
check_fp8(state, var_collect, inputs, masks, labels)

if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
jit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None

for epoch in range(1, args.epochs + 1):
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}

state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
)

test_loss, test_accuracy = eval_model(
state, test_ds, args.test_batch_size, var_collect, jit_eval_step
)

print(
f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} "
)

return [train_loss, train_accuracy, test_loss, test_accuracy]


def encoder_parser(args):
41 changes: 21 additions & 20 deletions examples/jax/encoder/test_multiprocessing_encoder.py
@@ -377,31 +377,32 @@ def train_and_evaluate(args):
fp8_recipe = None

device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh:

rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}

input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]

with te.fp8_autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
):
with te.fp8_autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
):
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te_flax.extend_logical_axis_rules(customized_rules)

with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh, flax.linen.logical_axis_rules(sharding_rules):

rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}

input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]

encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)

customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
3 changes: 2 additions & 1 deletion transformer_engine/jax/flax/transformer.py
@@ -180,8 +180,9 @@ def __call__(
attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

# (b, h, q, k): Last two axes are always replicated
attn_weights = with_sharding_constraint_by_logical_axes(
attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)
attn_weights, (BATCH_AXES, HEAD_AXES, None, None)
)

# When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
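For context, a toy sketch of how the old and new logical-axis tuples resolve; the rule mapping below is a hypothetical stand-in for TE's real table, and the point is only that None dims are never constrained:

```python
from jax.sharding import PartitionSpec

# Hypothetical logical-to-mesh rules standing in for TE's actual rule table.
rules = {"batch_axes": "dp", "head_axes": "tp", "seqlen_axes": "sp"}

def resolve(logical_axes):
    """Toy resolver: map logical names to mesh axes; None leaves a dim unconstrained."""
    return PartitionSpec(*(rules.get(a) if a is not None else None for a in logical_axes))

# Old tuple: the q/k dims pick up whatever SEQLEN_AXES maps to under the active rules.
print(resolve(("batch_axes", "head_axes", "seqlen_axes", "seqlen_axes")))
# New tuple: the q/k dims are always left replicated.
print(resolve(("batch_axes", "head_axes", None, None)))  # PartitionSpec('dp', 'tp', None, None)
```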
27 changes: 26 additions & 1 deletion transformer_engine/jax/sharding.py
@@ -14,6 +14,7 @@
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
import warnings
from jax.interpreters import pxla
import jax
import jax.numpy as jnp
@@ -117,7 +118,9 @@ def with_sharding_constraint_by_logical_axes(
x: jnp.array, logical_axis_names: Optional[tuple | list]
):
"""
A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
A wrapper function to flax.linen.with_logical_constraint.

DEPRECATED USE CASE: If no Flax logical axis rules are available, this function falls back to jax.lax.with_sharding_constraint using a hardcoded logical axis rule table from TE rules, such as BATCH_AXES. This functionality will be removed in the future.

If logical_axis_names = None, this means no sharding constraint is applied.

@@ -133,6 +136,28 @@ def with_sharding_constraint_by_logical_axes(
if not logical_axis_names:
return x

try:
# Check if Flax logical axis rules are available, if so use them
import flax

flax_rules = flax.linen.get_logical_axis_rules()
if len(flax_rules) > 0:
return flax.linen.with_logical_constraint(
x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT
)
except ImportError:
pass

warnings.warn(
"TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and"
" will be removed in a future version. Please use Flax logical axes with a"
" flax.linen.logical_axis_rules context and optionally use"
" transformer_engine.jax.flax.extend_logical_axis_rules to add BATCH_AXES, etc. to your"
" rules.",
DeprecationWarning,
)

# If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table
assert len(x.shape) == len(logical_axis_names)
pspec = generate_pspec(logical_axis_names)
return with_sharding_constraint(x, pspec)
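For callers, a minimal usage sketch of the new Flax path; the mesh shape, axis names, and enabled=False flag are illustrative assumptions, and it presumes BATCH_AXES/HEAD_AXES are importable from this module as the warning text suggests:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn
from jax.experimental import mesh_utils
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.sharding import (
    BATCH_AXES,
    HEAD_AXES,
    with_sharding_constraint_by_logical_axes,
)

# Illustrative 2-D mesh: 2-way data parallel x 4-way tensor parallel (8 devices assumed).
mesh = jax.sharding.Mesh(mesh_utils.create_device_mesh((2, 4)), ("dp", "tp"))
attn_weights = jnp.zeros((8, 16, 128, 128))  # (b, h, q, k)

with te.fp8_autocast(
    enabled=False, mesh_resource=te.MeshResource("dp", "tp", None, None)
):
    rules = te_flax.extend_logical_axis_rules(tuple())
    with mesh, nn.logical_axis_rules(rules):
        # Flax rules are in scope, so the wrapper dispatches to
        # flax.linen.with_logical_constraint and emits no DeprecationWarning.
        attn_weights = with_sharding_constraint_by_logical_axes(
            attn_weights, (BATCH_AXES, HEAD_AXES, None, None)
        )

# Outside any logical_axis_rules context, the same call falls back to TE's
# hardcoded rule table and emits the DeprecationWarning above.
```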