diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py
index 1f45d10faf..4f09cef30f 100644
--- a/examples/jax/encoder/test_model_parallel_encoder.py
+++ b/examples/jax/encoder/test_model_parallel_encoder.py
@@ -261,36 +261,37 @@ def train_and_evaluate(args):
         fp8_recipe = None
 
     device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
-    with jax.sharding.Mesh(
-        devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
-    ) as mesh, nn_partitioning.axis_rules(
-        ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
+    with te.fp8_autocast(
+        enabled=args.use_fp8,
+        fp8_recipe=fp8_recipe,
+        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
     ):
-        rng = jax.random.PRNGKey(args.seed)
-        rng, params_rng = jax.random.split(rng)
-        rng, dropout_rng = jax.random.split(rng)
-        init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
-
-        input_shape = [args.batch_size, args.max_seq_len]
-        mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
-        label_shape = [args.batch_size]
-
-        with te.fp8_autocast(
-            enabled=args.use_fp8,
-            fp8_recipe=fp8_recipe,
-            mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
-        ):
+        # Get the base axis rules and extend them with TE's rules.
+        axis_rules = flax.linen.get_logical_axis_rules()
+        axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
+        te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
+
+        with jax.sharding.Mesh(
+            devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
+        ) as mesh, flax.linen.logical_axis_rules(te_extended_axis_rules):
+
+            print(f"Device mesh: {mesh}")
+            print(f"Axis rules: {te_extended_axis_rules}")
+
+            rng = jax.random.PRNGKey(args.seed)
+            rng, params_rng = jax.random.split(rng)
+            rng, dropout_rng = jax.random.split(rng)
+            init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
+
+            input_shape = [args.batch_size, args.max_seq_len]
+            mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
+            label_shape = [args.batch_size]
+
             encoder = Net(num_embed, args.enable_sp)
             inputs = jnp.zeros(input_shape, dtype=jnp.int32)
             masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
             abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
 
-            # Get the base axis rules and extend them with TE's rules.
-            axis_rules = nn_partitioning.get_axis_rules()
-            te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
-            print(f"Device mesh: {mesh}")
-            print(f"Axis rules: {te_extended_axis_rules}")
-
             logical_partition_spec = nn.get_partition_spec(abs_var_collect)
 
             # Note that `nn.logical_to_mesh_sharding` returns a dict with an extra
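Note on the restructuring above (the two encoder examples that follow apply the same pattern): `te.fp8_autocast` becomes the outermost context so that the `te.MeshResource` mapping is in place before any sharding rules are built, the current Flax logical-axis rules are extended first with the example's own rules and then with TE's via `te_flax.extend_logical_axis_rules`, and the mesh is entered together with `flax.linen.logical_axis_rules` around everything that initializes or runs the model. The sketch below condenses that ordering; the axis-name strings, the 2x4 mesh shape, and `enabled=False` are illustrative placeholders rather than values taken from the patch.

```python
import flax
import jax
from jax.experimental import mesh_utils

import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

# Placeholder axis names standing in for the constants the examples define.
DEVICE_DP_AXIS, DEVICE_TP_AXIS = "data", "model"
NAMED_BROADCAST_AXIS, NAMED_TP_AXIS = "my_broadcast_axis", "my_tp_axis"

num_dp, num_tp = 2, 4  # assumes 8 visible devices; adjust to your setup
device_mesh = mesh_utils.create_device_mesh((num_dp, num_tp))

# 1) Enter fp8_autocast first so TE knows which mesh axes it may shard over
#    (the real examples pass args.use_fp8 and an fp8_recipe here).
with te.fp8_autocast(
    enabled=False,
    mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
):
    # 2) Start from Flax's current logical-axis rules, append the script's own
    #    rules, then let TE append the rules its modules rely on.
    axis_rules = flax.linen.get_logical_axis_rules()
    axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
    axis_rules = te_flax.extend_logical_axis_rules(axis_rules)

    # 3) Activate the mesh and the combined rules for everything that follows
    #    (init, jit, the train/eval loop).
    with jax.sharding.Mesh(device_mesh, (DEVICE_DP_AXIS, DEVICE_TP_AXIS)) as mesh, \
            flax.linen.logical_axis_rules(axis_rules):
        print(f"Device mesh: {mesh}")
        print(f"Axis rules: {axis_rules}")
        # ...build the model, jit the train/eval steps, and run the loop here.
```

Keeping `flax.linen.logical_axis_rules` active at this level is what later lets `with_sharding_constraint_by_logical_axes` (see the `sharding.py` hunks at the end of this diff) resolve logical axes through Flax instead of falling back to TE's hardcoded table.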
diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py
index 12148b0e29..95056b4165 100644
--- a/examples/jax/encoder/test_multigpu_encoder.py
+++ b/examples/jax/encoder/test_multigpu_encoder.py
@@ -274,89 +274,98 @@ def train_and_evaluate(args):
             fp8_recipe=fp8_recipe,
             mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
         ):
-            encoder = Net(num_embed)
-            inputs = jnp.zeros(input_shape, dtype=jnp.int32)
-            masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
-            abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
-            sharding_rules = te_flax.extend_logical_axis_rules(tuple())
-            params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
-            inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
-            masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))
-
-            in_shardings = (None, inputs_sharding, masks_sharding)
-            out_shardings = {
-                key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
-            }
-            jit_encoder_init = jax.jit(
-                encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
-            )
-            var_collect = jit_encoder_init(init_rngs, inputs, masks)
-
-            optimizer = optax.adamw(args.lr)
-            var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
-            state = train_state.TrainState.create(
-                apply_fn=encoder.apply, params=params, tx=optimizer
-            )
-            state_sharding = get_state_sharding(state, params_sharding)
-            labels_sharding = NamedSharding(
-                mesh,
-                PartitionSpec(
-                    DEVICE_DP_AXIS,
-                ),
-            )
-            in_shardings = (
-                state_sharding,
-                inputs_sharding,
-                masks_sharding,
-                labels_sharding,
-                None,
-                None,
-            )
-            out_shardings = (state_sharding, None, None, None)
-            jit_train_step = jax.jit(
-                train_step, in_shardings=in_shardings, out_shardings=out_shardings
-            )
-
-            in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
-            out_shardings = (None, None)
-            jit_eval_step = jax.jit(
-                eval_step, in_shardings=in_shardings, out_shardings=out_shardings
-            )
-
-            if args.use_fp8:
-                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
-                check_fp8(state, var_collect, inputs, masks, labels)
-
-            if args.dry_run:
-                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
-                rngs = {DROPOUT_KEY: dropout_rng}
-                jit_train_step(state, inputs, masks, labels, var_collect, rngs)
-                print("PASSED")
-                return None
-
-            for epoch in range(1, args.epochs + 1):
-                rng, input_rng = jax.random.split(rng)
-                rng, dropout_rng = jax.random.split(rng)
-                rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
-
-                state, train_loss, train_accuracy, var_collect = train_epoch(
-                    state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
+            with flax.linen.logical_axis_rules(sharding_rules):
+                encoder = Net(num_embed)
+                inputs = jnp.zeros(input_shape, dtype=jnp.int32)
+                masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
+                abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
+
+                params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
+                inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
+                masks_sharding = NamedSharding(
+                    mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None)
                 )
 
-                test_loss, test_accuracy = eval_model(
-                    state, test_ds, args.test_batch_size, var_collect, jit_eval_step
+                in_shardings = (None, inputs_sharding, masks_sharding)
+                out_shardings = {
+                    key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
+                }
+                jit_encoder_init = jax.jit(
+                    encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
                 )
+                var_collect = jit_encoder_init(init_rngs, inputs, masks)
 
-                print(
-                    f"Epoch: {epoch:>2} "
-                    f"Train Loss: {train_loss:.6f} "
-                    f"Train Accuracy: {train_accuracy:.6f} "
-                    f"Test Loss: {test_loss:.6f} "
-                    f"Test Accuracy: {test_accuracy:.6f} "
+                optimizer = optax.adamw(args.lr)
+                var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
+                state = train_state.TrainState.create(
+                    apply_fn=encoder.apply, params=params, tx=optimizer
+                )
+                state_sharding = get_state_sharding(state, params_sharding)
+                labels_sharding = NamedSharding(
+                    mesh,
+                    PartitionSpec(
+                        DEVICE_DP_AXIS,
+                    ),
+                )
+                in_shardings = (
+                    state_sharding,
+                    inputs_sharding,
+                    masks_sharding,
+                    labels_sharding,
+                    None,
+                    None,
+                )
+                out_shardings = (state_sharding, None, None, None)
+                jit_train_step = jax.jit(
+                    train_step, in_shardings=in_shardings, out_shardings=out_shardings
+                )
+
+                in_shardings = (
+                    state_sharding,
+                    inputs_sharding,
+                    masks_sharding,
+                    labels_sharding,
+                    None,
+                )
+                out_shardings = (None, None)
+                jit_eval_step = jax.jit(
+                    eval_step, in_shardings=in_shardings, out_shardings=out_shardings
                 )
 
-            return [train_loss, train_accuracy, test_loss, test_accuracy]
+                if args.use_fp8:
+                    labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
+                    check_fp8(state, var_collect, inputs, masks, labels)
+
+                if args.dry_run:
+                    labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
+                    rngs = {DROPOUT_KEY: dropout_rng}
+                    jit_train_step(state, inputs, masks, labels, var_collect, rngs)
+                    print("PASSED")
+                    return None
+
+                for epoch in range(1, args.epochs + 1):
+                    rng, input_rng = jax.random.split(rng)
+                    rng, dropout_rng = jax.random.split(rng)
+                    rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
+
+                    state, train_loss, train_accuracy, var_collect = train_epoch(
+                        state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
+                    )
+
+                    test_loss, test_accuracy = eval_model(
+                        state, test_ds, args.test_batch_size, var_collect, jit_eval_step
+                    )
+
+                    print(
+                        f"Epoch: {epoch:>2} "
+                        f"Train Loss: {train_loss:.6f} "
+                        f"Train Accuracy: {train_accuracy:.6f} "
+                        f"Test Loss: {test_loss:.6f} "
+                        f"Test Accuracy: {test_accuracy:.6f} "
+                    )
+
+                return [train_loss, train_accuracy, test_loss, test_accuracy]
 
 
 def encoder_parser(args):
diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py
index 580824cefa..d12d31d671 100644
--- a/examples/jax/encoder/test_multiprocessing_encoder.py
+++ b/examples/jax/encoder/test_multiprocessing_encoder.py
@@ -377,31 +377,32 @@ def train_and_evaluate(args):
         fp8_recipe = None
 
     device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
-    with jax.sharding.Mesh(
-        devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
-    ) as mesh:
-
-        rng = jax.random.PRNGKey(args.seed)
-        rng, params_rng = jax.random.split(rng)
-        rng, dropout_rng = jax.random.split(rng)
-        init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
-
-        input_shape = [args.batch_size, args.max_seq_len]
-        mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
-        label_shape = [args.batch_size]
-
-        with te.fp8_autocast(
-            enabled=args.use_fp8,
-            fp8_recipe=fp8_recipe,
-            mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
-        ):
+    with te.fp8_autocast(
+        enabled=args.use_fp8,
+        fp8_recipe=fp8_recipe,
+        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
+    ):
+        customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
+        sharding_rules = te_flax.extend_logical_axis_rules(customized_rules)
+
+        with jax.sharding.Mesh(
+            devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
+        ) as mesh, flax.linen.logical_axis_rules(sharding_rules):
+
+            rng = jax.random.PRNGKey(args.seed)
+            rng, params_rng = jax.random.split(rng)
+            rng, dropout_rng = jax.random.split(rng)
+            init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
+
+            input_shape = [args.batch_size, args.max_seq_len]
+            mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
+            label_shape = [args.batch_size]
+
             encoder = Net(num_embed)
             inputs = jnp.zeros(input_shape, dtype=jnp.int32)
             masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
             abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
 
-            customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
-            sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
             params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
             inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
             masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py
index 97ab519b9c..f2c0bc2a1c 100644
--- a/transformer_engine/jax/flax/transformer.py
+++ b/transformer_engine/jax/flax/transformer.py
@@ -180,8 +180,9 @@ def __call__(
             attn_weights_without_groups_shape = (b, h * g, q, k)
             attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
 
+        # (b, h, q, k): Last two axes are always replicated
         attn_weights = with_sharding_constraint_by_logical_axes(
-            attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)
+            attn_weights, (BATCH_AXES, HEAD_AXES, None, None)
         )
 
         # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias)
diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py
index a62e4769ed..e59c9de12d 100644
--- a/transformer_engine/jax/sharding.py
+++ b/transformer_engine/jax/sharding.py
@@ -14,6 +14,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from typing import Callable, Optional
+import warnings
 from jax.interpreters import pxla
 import jax
 import jax.numpy as jnp
@@ -117,7 +118,9 @@ def with_sharding_constraint_by_logical_axes(
     x: jnp.array, logical_axis_names: Optional[tuple | list]
 ):
     """
-    A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
+    A wrapper function to flax.linen.with_logical_constraint.
+
+    DEPRECATED USE CASE: If no Flax logical axis rules are available, this function falls back to jax.lax.with_sharding_constraint using a hardcoded logical axis rule table from TE rules, such as BATCH_AXES. This functionality will be removed in the future.
 
     If logical_axis_names = None, this means no sharding constraint is applied.
 
@@ -133,6 +136,28 @@ def with_sharding_constraint_by_logical_axes(
     if not logical_axis_names:
         return x
 
+    try:
+        # Check if Flax logical axis rules are available, if so use them
+        import flax
+
+        flax_rules = flax.linen.get_logical_axis_rules()
+        if len(flax_rules) > 0:
+            return flax.linen.with_logical_constraint(
+                x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT
+            )
+    except ImportError:
+        pass
+
+    warnings.warn(
+        "TransformerEngine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated and"
+        " will be removed in a future version. Please use Flax logical axes with a"
+        " flax.linen.logical_axis_rules context and optionally use"
+        " transformer_engine.jax.flax.extend_logical_axis_rules to add BATCH_AXES, etc. to your"
+        " rules.",
+        DeprecationWarning,
+    )
+
+    # If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table
     assert len(x.shape) == len(logical_axis_names)
     pspec = generate_pspec(logical_axis_names)
     return with_sharding_constraint(x, pspec)
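For reference, the control flow added to `with_sharding_constraint_by_logical_axes` can be exercised as in the sketch below. This is a hand-written usage example, not part of the patch: the single-device mesh, the "data"/"model" axis names, and the tensor shape are assumptions chosen so the snippet runs anywhere, and `BATCH_AXES`, `HEAD_AXES`, and `SEQLEN_AXES` are assumed to be importable from `transformer_engine.jax.sharding`, where the function itself lives.

```python
import warnings

import flax
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils

import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.sharding import (
    BATCH_AXES,
    HEAD_AXES,
    SEQLEN_AXES,
    with_sharding_constraint_by_logical_axes,
)

x = jnp.zeros((8, 16, 128, 128))  # (batch, heads, q, k); sizes are arbitrary
mesh = jax.sharding.Mesh(
    mesh_utils.create_device_mesh((1, 1), devices=jax.devices()[:1]),
    ("data", "model"),
)  # single-device mesh so the sketch runs anywhere

with te.fp8_autocast(
    enabled=False, mesh_resource=te.MeshResource("data", "model", None, None)
):
    # TE's extended rules map its logical axes (BATCH_AXES, HEAD_AXES, ...)
    # onto the mesh axes named in MeshResource.
    rules = te_flax.extend_logical_axis_rules(tuple())

    # Preferred path: with Flax rules in scope, the call delegates to
    # flax.linen.with_logical_constraint, mirroring the new transformer.py axes.
    with mesh, flax.linen.logical_axis_rules(rules):
        y = with_sharding_constraint_by_logical_axes(
            x, (BATCH_AXES, HEAD_AXES, None, None)
        )

    # Deprecated path: no Flax rules in scope, so the call emits a
    # DeprecationWarning and falls back to TE's hardcoded table (here with the
    # axes the pre-patch transformer.py call used).
    with mesh, warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        y = with_sharding_constraint_by_logical_axes(
            x, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)
        )
        assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```

With `fallback=RulesFallback.NO_CONSTRAINT`, logical axes that have no matching Flax rule are simply left unconstrained rather than raising an error.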