From 4f5eda4e6950895cae27fab83f5ff3620d74ba18 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 7 Jul 2025 10:32:20 -0700 Subject: [PATCH 1/7] Support flax sharding constraints Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/sharding.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index a62e4769ed..6ec1f1a37b 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -117,7 +117,9 @@ def with_sharding_constraint_by_logical_axes( x: jnp.array, logical_axis_names: Optional[tuple | list] ): """ - A wrapper function to jax.lax.with_sharding_constraint to accept logical axes. + A wrapper function to flax.linen.with_logical_constraint. + + DEPRECATED USE CASE: If no Flax logical axis rules are available, this function falls back to jax.lax.with_sharding_constraint using a hardcoded logical axis rule table from TE rules, such as BATCH_AXES. This functionality will be removed in the future. If logical_axis_names = None, this means no sharding constraint is applied. @@ -133,6 +135,16 @@ def with_sharding_constraint_by_logical_axes( if not logical_axis_names: return x + try: + # Check if Flax logical axis rules are available, if so use them + import flax + flax_rules = flax.linen.get_logical_axis_rules() + if len(flax_rules) > 0: + return flax.linen.with_logical_constraint(x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT) + except ImportError: + pass + + # If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table assert len(x.shape) == len(logical_axis_names) pspec = generate_pspec(logical_axis_names) return with_sharding_constraint(x, pspec) From 6658974871d621d10dec520ff204344bdf8a2be2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Jul 2025 23:58:37 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/sharding.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 6ec1f1a37b..85eabeafb1 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -138,9 +138,12 @@ def with_sharding_constraint_by_logical_axes( try: # Check if Flax logical axis rules are available, if so use them import flax + flax_rules = flax.linen.get_logical_axis_rules() if len(flax_rules) > 0: - return flax.linen.with_logical_constraint(x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT) + return flax.linen.with_logical_constraint( + x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT + ) except ImportError: pass From cfa29ebe0ad2f16b48e297ddeaf8aa82d7c12305 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 10 Jul 2025 12:59:05 -0700 Subject: [PATCH 3/7] Add warning for deprecated TE logical axes Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/sharding.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 85eabeafb1..452422bab5 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -14,6 +14,7 @@ from dataclasses import dataclass from enum import Enum from typing import Callable, Optional 
+import warnings from jax.interpreters import pxla import jax import jax.numpy as jnp @@ -147,6 +148,8 @@ def with_sharding_constraint_by_logical_axes( except ImportError: pass + warnings.warn("TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and will be removed in a future version. Please use Flax logical axes with a flax.linen.logical_axis_rules context instead.", DeprecationWarning) + # If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table assert len(x.shape) == len(logical_axis_names) pspec = generate_pspec(logical_axis_names) From 467357553bb45a6eb175baf68f7c2c4639790edb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Jul 2025 20:00:54 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/sharding.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 452422bab5..ad22a56393 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -148,7 +148,12 @@ def with_sharding_constraint_by_logical_axes( except ImportError: pass - warnings.warn("TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and will be removed in a future version. Please use Flax logical axes with a flax.linen.logical_axis_rules context instead.", DeprecationWarning) + warnings.warn( + "TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and" + " will be removed in a future version. Please use Flax logical axes with a" + " flax.linen.logical_axis_rules context instead.", + DeprecationWarning, + ) # If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table assert len(x.shape) == len(logical_axis_names) From 22f68c71b563155918507b3fa5c8775e910695b0 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 11 Jul 2025 10:59:39 -0700 Subject: [PATCH 5/7] Update transformer attention weight sharding axes and update example axis rule setup Signed-off-by: Jeremy Berchtold --- .../encoder/test_model_parallel_encoder.py | 49 +++--- examples/jax/encoder/test_multigpu_encoder.py | 161 +++++++++--------- .../encoder/test_multiprocessing_encoder.py | 41 ++--- transformer_engine/jax/flax/transformer.py | 3 +- transformer_engine/jax/sharding.py | 4 +- 5 files changed, 136 insertions(+), 122 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 7d5fefddaa..82f971e606 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -262,36 +262,37 @@ def train_and_evaluate(args): fp8_recipe = None device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) - with jax.sharding.Mesh( - devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) - ) as mesh, nn_partitioning.axis_rules( - ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) + with te.fp8_autocast( + enabled=args.use_fp8, + fp8_recipe=fp8_recipe, + mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), ): - rng = jax.random.PRNGKey(args.seed) - rng, params_rng = jax.random.split(rng) - rng, dropout_rng = jax.random.split(rng) - init_rngs = {PARAMS_KEY: params_rng, 
DROPOUT_KEY: dropout_rng} - - input_shape = [args.batch_size, args.max_seq_len] - mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] - label_shape = [args.batch_size] - - with te.fp8_autocast( - enabled=args.use_fp8, - fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), - ): + # Get the base axis rules and extend them with TE's rules. + axis_rules = flax.linen.get_logical_axis_rules() + axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) + te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) + + with jax.sharding.Mesh( + devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) + ) as mesh, flax.linen.logical_axis_rules(te_extended_axis_rules): + + print(f"Device mesh: {mesh}") + print(f"Axis rules: {te_extended_axis_rules}") + + rng = jax.random.PRNGKey(args.seed) + rng, params_rng = jax.random.split(rng) + rng, dropout_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + + input_shape = [args.batch_size, args.max_seq_len] + mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] + label_shape = [args.batch_size] + encoder = Net(num_embed, args.enable_sp) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) - # Get the base axis rules and extend them with TE's rules. - axis_rules = nn_partitioning.get_axis_rules() - te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) - print(f"Device mesh: {mesh}") - print(f"Axis rules: {te_extended_axis_rules}") - logical_partition_spec = nn.get_partition_spec(abs_var_collect) # Note that `nn.logical_to_mesh_sharding` returns a dict with an extra diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 40431b8cdd..94e2929289 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -275,89 +275,98 @@ def train_and_evaluate(args): fp8_recipe=fp8_recipe, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None), ): - encoder = Net(num_embed) - inputs = jnp.zeros(input_shape, dtype=jnp.int32) - masks = jnp.zeros(mask_shape, dtype=jnp.uint8) - abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) - sharding_rules = te_flax.extend_logical_axis_rules(tuple()) - params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) - inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None)) - masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None)) - - in_shardings = (None, inputs_sharding, masks_sharding) - out_shardings = { - key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect - } - jit_encoder_init = jax.jit( - encoder.init, in_shardings=in_shardings, out_shardings=out_shardings - ) - var_collect = jit_encoder_init(init_rngs, inputs, masks) - - optimizer = optax.adamw(args.lr) - var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) - state = train_state.TrainState.create( - apply_fn=encoder.apply, params=params, tx=optimizer - ) - state_sharding = get_state_sharding(state, params_sharding) - labels_sharding = NamedSharding( - mesh, - PartitionSpec( - DEVICE_DP_AXIS, - ), - ) - in_shardings = ( - state_sharding, - inputs_sharding, - masks_sharding, - labels_sharding, - None, - None, - ) - out_shardings = (state_sharding, None, None, 
None) - jit_train_step = jax.jit( - train_step, in_shardings=in_shardings, out_shardings=out_shardings - ) - - in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) - out_shardings = (None, None) - jit_eval_step = jax.jit( - eval_step, in_shardings=in_shardings, out_shardings=out_shardings - ) - - if args.use_fp8: - labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - check_fp8(state, var_collect, inputs, masks, labels) - - if args.dry_run: - labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - rngs = {DROPOUT_KEY: dropout_rng} - jit_train_step(state, inputs, masks, labels, var_collect, rngs) - print("PASSED") - return None - - for epoch in range(1, args.epochs + 1): - rng, input_rng = jax.random.split(rng) - rng, dropout_rng = jax.random.split(rng) - rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} - - state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect, jit_train_step + with flax.linen.logical_axis_rules(sharding_rules): + encoder = Net(num_embed) + inputs = jnp.zeros(input_shape, dtype=jnp.int32) + masks = jnp.zeros(mask_shape, dtype=jnp.uint8) + abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) + + params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) + inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None)) + masks_sharding = NamedSharding( + mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None) ) - test_loss, test_accuracy = eval_model( - state, test_ds, args.test_batch_size, var_collect, jit_eval_step + in_shardings = (None, inputs_sharding, masks_sharding) + out_shardings = { + key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect + } + jit_encoder_init = jax.jit( + encoder.init, in_shardings=in_shardings, out_shardings=out_shardings ) + var_collect = jit_encoder_init(init_rngs, inputs, masks) - print( - f"Epoch: {epoch:>2} " - f"Train Loss: {train_loss:.6f} " - f"Train Accuracy: {train_accuracy:.6f} " - f"Test Loss: {test_loss:.6f} " - f"Test Accuracy: {test_accuracy:.6f} " + optimizer = optax.adamw(args.lr) + var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) + state = train_state.TrainState.create( + apply_fn=encoder.apply, params=params, tx=optimizer + ) + state_sharding = get_state_sharding(state, params_sharding) + labels_sharding = NamedSharding( + mesh, + PartitionSpec( + DEVICE_DP_AXIS, + ), + ) + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, + ) + out_shardings = (state_sharding, None, None, None) + jit_train_step = jax.jit( + train_step, in_shardings=in_shardings, out_shardings=out_shardings + ) + + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + ) + out_shardings = (None, None) + jit_eval_step = jax.jit( + eval_step, in_shardings=in_shardings, out_shardings=out_shardings ) - return [train_loss, train_accuracy, test_loss, test_accuracy] + if args.use_fp8: + labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) + check_fp8(state, var_collect, inputs, masks, labels) + + if args.dry_run: + labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) + rngs = {DROPOUT_KEY: dropout_rng} + jit_train_step(state, inputs, masks, labels, var_collect, rngs) + print("PASSED") + return None + + for epoch in range(1, args.epochs + 1): + rng, input_rng = jax.random.split(rng) + rng, dropout_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + + 
state, train_loss, train_accuracy, var_collect = train_epoch( + state, train_ds, args.batch_size, rngs, var_collect, jit_train_step + ) + + test_loss, test_accuracy = eval_model( + state, test_ds, args.test_batch_size, var_collect, jit_eval_step + ) + + print( + f"Epoch: {epoch:>2} " + f"Train Loss: {train_loss:.6f} " + f"Train Accuracy: {train_accuracy:.6f} " + f"Test Loss: {test_loss:.6f} " + f"Test Accuracy: {test_accuracy:.6f} " + ) + + return [train_loss, train_accuracy, test_loss, test_accuracy] def encoder_parser(args): diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index c1da4db4a9..264f7d34e4 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -377,31 +377,32 @@ def train_and_evaluate(args): fp8_recipe = None device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) - with jax.sharding.Mesh( - devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) - ) as mesh: - - rng = jax.random.PRNGKey(args.seed) - rng, params_rng = jax.random.split(rng) - rng, dropout_rng = jax.random.split(rng) - init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} - - input_shape = [args.batch_size, args.max_seq_len] - mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] - label_shape = [args.batch_size] - - with te.fp8_autocast( - enabled=args.use_fp8, - fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), - ): + with te.fp8_autocast( + enabled=args.use_fp8, + fp8_recipe=fp8_recipe, + mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), + ): + customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) + sharding_rules = te_flax.extend_logical_axis_rules(customized_rules) + + with jax.sharding.Mesh( + devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) + ) as mesh, flax.linen.logical_axis_rules(sharding_rules): + + rng = jax.random.PRNGKey(args.seed) + rng, params_rng = jax.random.split(rng) + rng, dropout_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + + input_shape = [args.batch_size, args.max_seq_len] + mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] + label_shape = [args.batch_size] + encoder = Net(num_embed) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) - customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) - sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 97ab519b9c..f2c0bc2a1c 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -180,8 +180,9 @@ def __call__( attn_weights_without_groups_shape = (b, h * g, q, k) attn_weights = attn_weights.reshape(attn_weights_without_groups_shape) + # (b, h, q, k): Last two axes are always replicated attn_weights = with_sharding_constraint_by_logical_axes( - attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES) + attn_weights, (BATCH_AXES, 
HEAD_AXES, None, None) ) # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index ad22a56393..e59c9de12d 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -151,7 +151,9 @@ def with_sharding_constraint_by_logical_axes( warnings.warn( "TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and" " will be removed in a future version. Please use Flax logical axes with a" - " flax.linen.logical_axis_rules context instead.", + " flax.linen.logical_axis_rules context and optionally use" + " transformer_engine.jax.flax.extend_logical_axis_rules to add BATCH_AXES, etc. to your" + " rules.", DeprecationWarning, ) From a910e27ec4e1f574635b445037e257e947b72f89 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 15 Jul 2025 13:47:21 -0700 Subject: [PATCH 6/7] Revert "Update transformer attention weight sharding axes and update example axis rule setup" This reverts commit 062a58a6b1d9d4921b355fb737d04fb777d758c4. Signed-off-by: Jeremy Berchtold --- .../encoder/test_model_parallel_encoder.py | 49 +++--- examples/jax/encoder/test_multigpu_encoder.py | 161 +++++++++--------- .../encoder/test_multiprocessing_encoder.py | 41 +++-- transformer_engine/jax/flax/transformer.py | 3 +- transformer_engine/jax/sharding.py | 4 +- 5 files changed, 122 insertions(+), 136 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 82f971e606..7d5fefddaa 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -262,37 +262,36 @@ def train_and_evaluate(args): fp8_recipe = None device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) - with te.fp8_autocast( - enabled=args.use_fp8, - fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), + with jax.sharding.Mesh( + devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) + ) as mesh, nn_partitioning.axis_rules( + ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) ): - # Get the base axis rules and extend them with TE's rules. 
- axis_rules = flax.linen.get_logical_axis_rules() - axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) - te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) - - with jax.sharding.Mesh( - devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) - ) as mesh, flax.linen.logical_axis_rules(te_extended_axis_rules): - - print(f"Device mesh: {mesh}") - print(f"Axis rules: {te_extended_axis_rules}") - - rng = jax.random.PRNGKey(args.seed) - rng, params_rng = jax.random.split(rng) - rng, dropout_rng = jax.random.split(rng) - init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} - - input_shape = [args.batch_size, args.max_seq_len] - mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] - label_shape = [args.batch_size] - + rng = jax.random.PRNGKey(args.seed) + rng, params_rng = jax.random.split(rng) + rng, dropout_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + + input_shape = [args.batch_size, args.max_seq_len] + mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] + label_shape = [args.batch_size] + + with te.fp8_autocast( + enabled=args.use_fp8, + fp8_recipe=fp8_recipe, + mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), + ): encoder = Net(num_embed, args.enable_sp) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) + # Get the base axis rules and extend them with TE's rules. + axis_rules = nn_partitioning.get_axis_rules() + te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) + print(f"Device mesh: {mesh}") + print(f"Axis rules: {te_extended_axis_rules}") + logical_partition_spec = nn.get_partition_spec(abs_var_collect) # Note that `nn.logical_to_mesh_sharding` returns a dict with an extra diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 94e2929289..40431b8cdd 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -275,98 +275,89 @@ def train_and_evaluate(args): fp8_recipe=fp8_recipe, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None), ): + encoder = Net(num_embed) + inputs = jnp.zeros(input_shape, dtype=jnp.int32) + masks = jnp.zeros(mask_shape, dtype=jnp.uint8) + abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) + sharding_rules = te_flax.extend_logical_axis_rules(tuple()) - with flax.linen.logical_axis_rules(sharding_rules): - encoder = Net(num_embed) - inputs = jnp.zeros(input_shape, dtype=jnp.int32) - masks = jnp.zeros(mask_shape, dtype=jnp.uint8) - abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) - - params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) - inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None)) - masks_sharding = NamedSharding( - mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None) + params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) + inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None)) + masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None)) + + in_shardings = (None, inputs_sharding, masks_sharding) + out_shardings = { + key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect + } + jit_encoder_init = jax.jit( + encoder.init, 
in_shardings=in_shardings, out_shardings=out_shardings + ) + var_collect = jit_encoder_init(init_rngs, inputs, masks) + + optimizer = optax.adamw(args.lr) + var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) + state = train_state.TrainState.create( + apply_fn=encoder.apply, params=params, tx=optimizer + ) + state_sharding = get_state_sharding(state, params_sharding) + labels_sharding = NamedSharding( + mesh, + PartitionSpec( + DEVICE_DP_AXIS, + ), + ) + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, + ) + out_shardings = (state_sharding, None, None, None) + jit_train_step = jax.jit( + train_step, in_shardings=in_shardings, out_shardings=out_shardings + ) + + in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) + out_shardings = (None, None) + jit_eval_step = jax.jit( + eval_step, in_shardings=in_shardings, out_shardings=out_shardings + ) + + if args.use_fp8: + labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) + check_fp8(state, var_collect, inputs, masks, labels) + + if args.dry_run: + labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) + rngs = {DROPOUT_KEY: dropout_rng} + jit_train_step(state, inputs, masks, labels, var_collect, rngs) + print("PASSED") + return None + + for epoch in range(1, args.epochs + 1): + rng, input_rng = jax.random.split(rng) + rng, dropout_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + + state, train_loss, train_accuracy, var_collect = train_epoch( + state, train_ds, args.batch_size, rngs, var_collect, jit_train_step ) - in_shardings = (None, inputs_sharding, masks_sharding) - out_shardings = { - key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect - } - jit_encoder_init = jax.jit( - encoder.init, in_shardings=in_shardings, out_shardings=out_shardings + test_loss, test_accuracy = eval_model( + state, test_ds, args.test_batch_size, var_collect, jit_eval_step ) - var_collect = jit_encoder_init(init_rngs, inputs, masks) - optimizer = optax.adamw(args.lr) - var_collect, params = flax.core.pop(var_collect, PARAMS_KEY) - state = train_state.TrainState.create( - apply_fn=encoder.apply, params=params, tx=optimizer - ) - state_sharding = get_state_sharding(state, params_sharding) - labels_sharding = NamedSharding( - mesh, - PartitionSpec( - DEVICE_DP_AXIS, - ), - ) - in_shardings = ( - state_sharding, - inputs_sharding, - masks_sharding, - labels_sharding, - None, - None, - ) - out_shardings = (state_sharding, None, None, None) - jit_train_step = jax.jit( - train_step, in_shardings=in_shardings, out_shardings=out_shardings - ) - - in_shardings = ( - state_sharding, - inputs_sharding, - masks_sharding, - labels_sharding, - None, - ) - out_shardings = (None, None) - jit_eval_step = jax.jit( - eval_step, in_shardings=in_shardings, out_shardings=out_shardings + print( + f"Epoch: {epoch:>2} " + f"Train Loss: {train_loss:.6f} " + f"Train Accuracy: {train_accuracy:.6f} " + f"Test Loss: {test_loss:.6f} " + f"Test Accuracy: {test_accuracy:.6f} " ) - if args.use_fp8: - labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - check_fp8(state, var_collect, inputs, masks, labels) - - if args.dry_run: - labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - rngs = {DROPOUT_KEY: dropout_rng} - jit_train_step(state, inputs, masks, labels, var_collect, rngs) - print("PASSED") - return None - - for epoch in range(1, args.epochs + 1): - rng, input_rng = jax.random.split(rng) - rng, dropout_rng = jax.random.split(rng) - rngs 
= {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} - - state, train_loss, train_accuracy, var_collect = train_epoch( - state, train_ds, args.batch_size, rngs, var_collect, jit_train_step - ) - - test_loss, test_accuracy = eval_model( - state, test_ds, args.test_batch_size, var_collect, jit_eval_step - ) - - print( - f"Epoch: {epoch:>2} " - f"Train Loss: {train_loss:.6f} " - f"Train Accuracy: {train_accuracy:.6f} " - f"Test Loss: {test_loss:.6f} " - f"Test Accuracy: {test_accuracy:.6f} " - ) - - return [train_loss, train_accuracy, test_loss, test_accuracy] + return [train_loss, train_accuracy, test_loss, test_accuracy] def encoder_parser(args): diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 264f7d34e4..c1da4db4a9 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -377,32 +377,31 @@ def train_and_evaluate(args): fp8_recipe = None device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) - with te.fp8_autocast( - enabled=args.use_fp8, - fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), - ): - customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) - sharding_rules = te_flax.extend_logical_axis_rules(customized_rules) - - with jax.sharding.Mesh( - devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) - ) as mesh, flax.linen.logical_axis_rules(sharding_rules): - - rng = jax.random.PRNGKey(args.seed) - rng, params_rng = jax.random.split(rng) - rng, dropout_rng = jax.random.split(rng) - init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} - - input_shape = [args.batch_size, args.max_seq_len] - mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] - label_shape = [args.batch_size] - + with jax.sharding.Mesh( + devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) + ) as mesh: + + rng = jax.random.PRNGKey(args.seed) + rng, params_rng = jax.random.split(rng) + rng, dropout_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + + input_shape = [args.batch_size, args.max_seq_len] + mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] + label_shape = [args.batch_size] + + with te.fp8_autocast( + enabled=args.use_fp8, + fp8_recipe=fp8_recipe, + mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), + ): encoder = Net(num_embed) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) + customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) + sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index f2c0bc2a1c..97ab519b9c 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -180,9 +180,8 @@ def __call__( attn_weights_without_groups_shape = (b, h * g, q, k) attn_weights = attn_weights.reshape(attn_weights_without_groups_shape) - # (b, h, q, k): Last two axes are always replicated attn_weights = 
with_sharding_constraint_by_logical_axes( - attn_weights, (BATCH_AXES, HEAD_AXES, None, None) + attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES) ) # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index e59c9de12d..ad22a56393 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -151,9 +151,7 @@ def with_sharding_constraint_by_logical_axes( warnings.warn( "TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and" " will be removed in a future version. Please use Flax logical axes with a" - " flax.linen.logical_axis_rules context and optionally use" - " transformer_engine.jax.flax.extend_logical_axis_rules to add BATCH_AXES, etc. to your" - " rules.", + " flax.linen.logical_axis_rules context instead.", DeprecationWarning, ) From a46f60730703ec336891e9e8b0881939e8618c92 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 15 Jul 2025 14:47:45 -0700 Subject: [PATCH 7/7] Update examples Signed-off-by: Jeremy Berchtold --- .../encoder/test_model_parallel_encoder.py | 27 ++++++++++--------- examples/jax/encoder/test_multigpu_encoder.py | 17 +++++++----- .../encoder/test_multiprocessing_encoder.py | 21 ++++++++------- transformer_engine/jax/flax/transformer.py | 3 ++- transformer_engine/jax/sharding.py | 4 ++- 5 files changed, 41 insertions(+), 31 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 7d5fefddaa..34e029330c 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -264,8 +264,10 @@ def train_and_evaluate(args): device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) with jax.sharding.Mesh( devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) - ) as mesh, nn_partitioning.axis_rules( - ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) + ) as mesh, te.fp8_autocast( + enabled=args.use_fp8, + fp8_recipe=fp8_recipe, + mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), ): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) @@ -276,22 +278,21 @@ def train_and_evaluate(args): mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] - with te.fp8_autocast( - enabled=args.use_fp8, - fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), - ): + # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast + axis_rules = flax.linen.get_logical_axis_rules() + axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) + te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) + + with flax.linen.logical_axis_rules(te_extended_axis_rules): + + print(f"Device mesh: {mesh}") + print(f"Axis rules: {te_extended_axis_rules}") + encoder = Net(num_embed, args.enable_sp) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) - # Get the base axis rules and extend them with TE's rules. 
- axis_rules = nn_partitioning.get_axis_rules() - te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) - print(f"Device mesh: {mesh}") - print(f"Axis rules: {te_extended_axis_rules}") - logical_partition_spec = nn.get_partition_spec(abs_var_collect) # Note that `nn.logical_to_mesh_sharding` returns a dict with an extra diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 40431b8cdd..32263145f8 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -259,7 +259,13 @@ def train_and_evaluate(args): fp8_recipe = None device_mesh = mesh_utils.create_device_mesh((num_gpu,)) - with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh: + with jax.sharding.Mesh( + devices=device_mesh, axis_names=(DEVICE_DP_AXIS,) + ) as mesh, te.fp8_autocast( + enabled=args.use_fp8, + fp8_recipe=fp8_recipe, + mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None), + ): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) @@ -270,17 +276,14 @@ def train_and_evaluate(args): mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] - with te.fp8_autocast( - enabled=args.use_fp8, - fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None), - ): + # Add TE logical axis rules to our Flax logical axis rule context. This must be done inside fp8_autocast + sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + with flax.linen.logical_axis_rules(sharding_rules): encoder = Net(num_embed) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) - sharding_rules = te_flax.extend_logical_axis_rules(tuple()) params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None)) masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None)) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index c1da4db4a9..f112740a30 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -379,8 +379,11 @@ def train_and_evaluate(args): device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) with jax.sharding.Mesh( devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) - ) as mesh: - + ) as mesh, te.fp8_autocast( + enabled=args.use_fp8, + fp8_recipe=fp8_recipe, + mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), + ): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) @@ -390,18 +393,18 @@ def train_and_evaluate(args): mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] - with te.fp8_autocast( - enabled=args.use_fp8, - fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), - ): + # Create custom Flax logical axis rules for sharding. + customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) + # Extend the logical axis rules with TE's rules. This must be done inside fp8_autocast. 
+ sharding_rules = te_flax.extend_logical_axis_rules(customized_rules) + + with flax.linen.logical_axis_rules(sharding_rules): + encoder = Net(num_embed) inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks) - customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) - sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh) inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 97ab519b9c..f2c0bc2a1c 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -180,8 +180,9 @@ def __call__( attn_weights_without_groups_shape = (b, h * g, q, k) attn_weights = attn_weights.reshape(attn_weights_without_groups_shape) + # (b, h, q, k): Last two axes are always replicated attn_weights = with_sharding_constraint_by_logical_axes( - attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES) + attn_weights, (BATCH_AXES, HEAD_AXES, None, None) ) # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index ad22a56393..e59c9de12d 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -151,7 +151,9 @@ def with_sharding_constraint_by_logical_axes( warnings.warn( "TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and" " will be removed in a future version. Please use Flax logical axes with a" - " flax.linen.logical_axis_rules context instead.", + " flax.linen.logical_axis_rules context and optionally use" + " transformer_engine.jax.flax.extend_logical_axis_rules to add BATCH_AXES, etc. to your" + " rules.", DeprecationWarning, )
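For reference, a minimal end-to-end sketch of the usage pattern this series establishes: TE's logical-axis rules are built with te_flax.extend_logical_axis_rules inside te.fp8_autocast, installed via a flax.linen.logical_axis_rules context, and with_sharding_constraint_by_logical_axes then resolves logical names through Flax instead of the deprecated hardcoded TE table. The mesh axis names ("dp", "tp"), the custom "my_batch" logical axis, and the mesh shape below are illustrative assumptions, not part of the patches.

    import jax
    import jax.numpy as jnp
    import flax
    from jax.experimental import mesh_utils

    import transformer_engine.jax as te
    import transformer_engine.jax.flax as te_flax
    from transformer_engine.jax.sharding import with_sharding_constraint_by_logical_axes

    # Put all local devices on a data-parallel axis; keep tensor-parallel size 1.
    num_devices = len(jax.devices())
    device_mesh = mesh_utils.create_device_mesh((num_devices, 1))

    with jax.sharding.Mesh(
        devices=device_mesh, axis_names=("dp", "tp")
    ) as mesh, te.fp8_autocast(
        enabled=False,
        fp8_recipe=None,
        mesh_resource=te.MeshResource("dp", "tp", None, None),
    ):
        # Extend the caller's rules with TE's rules; as in the updated examples,
        # this is done inside fp8_autocast so the MeshResource is available.
        rules = te_flax.extend_logical_axis_rules((("my_batch", "dp"),))

        with flax.linen.logical_axis_rules(rules):
            x = jnp.zeros((8, 128), dtype=jnp.float32)
            # Flax rules are active, so this resolves "my_batch" -> "dp" through
            # flax.linen.with_logical_constraint; without an active rules context it
            # would fall back to TE's hardcoded table and emit the DeprecationWarning
            # added in this series.
            x = with_sharding_constraint_by_logical_axes(x, ("my_batch", None))

As in the example scripts touched by patches 5 and 7, the key ordering constraint is that extend_logical_axis_rules runs inside fp8_autocast, and the logical_axis_rules context wraps both model initialization and any code that applies logical-axis sharding constraints.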