diff --git a/.github/container/test-maxtext.sh b/.github/container/test-maxtext.sh index 164fa5912..0dc26c8c1 100755 --- a/.github/container/test-maxtext.sh +++ b/.github/container/test-maxtext.sh @@ -223,7 +223,7 @@ export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN} export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION} export CUDA_DEVICE_MAX_CONNECTIONS=1 -export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true +export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 @@ -232,8 +232,7 @@ export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_schedule --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true - --xla_gpu_enable_while_loop_double_buffering=true - --xla_gpu_enable_triton_softmax_fusion=false + --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization} diff --git a/README.md b/README.md index 66d9b2a4e..1764c5f00 100644 --- a/README.md +++ b/README.md @@ -300,6 +300,8 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33). +For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page. + ## Profiling JAX programs on GPU See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU. diff --git a/rosetta/docs/GPU_performance.md b/rosetta/docs/GPU_performance.md index c5456e3c4..fabbc6963 100644 --- a/rosetta/docs/GPU_performance.md +++ b/rosetta/docs/GPU_performance.md @@ -128,6 +128,10 @@ Fine-grain control to improve performance by initializing a NCCL communicator to - --xla_gpu_enable_cudnn_fmha=false (enables XLA pattern matcher to detect multi-headed attention pattern in JAX) - --xla_disable_hlo_passes=<> (turns off specific HLO passes; can be used for debugging) +## Previously used XLA Flags - +The following flags were used previously used but no longer required. +- --xla_gpu_enable_async_reduce_scatter, --xla_gpu_enable_async_all_reduce, --xla_gpu_enable_async_all_gather ; Turned on by default, no longer needed +- --xla_gpu_enable_highest_priority_async_stream ; Turned on by default +- --xla_gpu_enable_triton_softmax_fusion ; Deprecated, no longer used diff --git a/rosetta/docs/NATIVE_FP8.md b/rosetta/docs/NATIVE_FP8.md index dd3aa1bae..069b06fdd 100644 --- a/rosetta/docs/NATIVE_FP8.md +++ b/rosetta/docs/NATIVE_FP8.md @@ -111,13 +111,11 @@ Enabling this feature is effortless. Users only need to include the option `--fd In addition to the suggested XLA flags mentioned in [this section](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/rosetta/projects/pax/README.md#xla-flags), we also recommend setting these following XLA flags. The execution script should look like: ```bash export XLA_FLAGS=" \ - --xla_gpu_enable_reduction_epilogue_fusion=false \ --xla_gpu_enable_triton_gemm=false \ - --xla_gpu_enable_cudnn_fmha=false \ - --xla_gpu_enable_cudnn_layer_norm=true \ - --xla_gpu_enable_cublaslt=true \ - --xla_gpu_enable_latency_hiding_scheduler=true \ - --xla_gpu_all_reduce_combine_threshold_bytes=51200 " + --xla_gpu_enable_pipelined_all_reduce=false \ + --xla_gpu_enable_pipelined_all_gather=false \ + --xla_gpu_enable_pipelined_reduce_scatter=false \ +" export ENABLE_TE=0 python -m paxml.main \ ... @@ -125,8 +123,7 @@ python -m paxml.main \ ... ``` -Please ensure you include the first two flags, `--xla_gpu_enable_reduction_epilogue_fusion=false` and `--xla_gpu_enable_triton_gemm=false`, as they are essential for enabling the FP8 functionality. The additional flags primarily focus on performance enhancement and should also prove beneficial for non-FP8 executions. - +Please not that disabling the triton gemm and pipelined collectives is essential for enabling the FP8 functionality and performance. ## Transformer Engine vs Native FP8 Support Native XLA-FP8 specifically targets matrix multiplication operations. In contrast, the Transformer Engine focuses on enhancing the overall performance of the entire transformer layer. This encompasses not only the FP8 matrix multiplication but also attention mechanisms, layer normalizations, and other components. diff --git a/rosetta/docs/PGLE.md b/rosetta/docs/PGLE.md index 02e5f5294..2425ddffe 100644 --- a/rosetta/docs/PGLE.md +++ b/rosetta/docs/PGLE.md @@ -70,7 +70,6 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true ---xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization diff --git a/rosetta/rosetta/projects/maxtext/README.md b/rosetta/rosetta/projects/maxtext/README.md index fde5a9125..2320a7ed9 100644 --- a/rosetta/rosetta/projects/maxtext/README.md +++ b/rosetta/rosetta/projects/maxtext/README.md @@ -67,12 +67,9 @@ In order to obtain the best performance, please set the appropriate XLA flags. W The [GPU Performance document](../../../docs/GPU_performance.md) provides a detailed description of the XLA flags that can be set to optimize performance. These are the recommended XLA flags to get good performance for MaxText. ``` -XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true - --xla_gpu_enable_async_all_gather=true - --xla_gpu_enable_async_reduce_scatter=true +XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false - --xla_gpu_graph_level=0 - --xla_gpu_enable_async_all_reduce=true + --xla_gpu_graph_level=0 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 @@ -80,7 +77,6 @@ XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true - --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization" diff --git a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub index e96eaa781..0ca3fd802 100644 --- a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub +++ b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub @@ -53,11 +53,8 @@ export NCCL_IB_SL=1 # Set XLA Flags export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true - --xla_gpu_enable_async_all_gather=true - --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 - --xla_gpu_enable_async_all_reduce=true --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 @@ -65,12 +62,9 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true - --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false - --xla_disable_hlo_passes=rematerialization - --xla_gpu_enable_custom_fusions=false - --xla_gpu_enable_address_computation_fusion=false" + --xla_disable_hlo_passes=rematerialization" # Make directories that may not exist mkdir -p $BASE_WORKSPACE_DIR diff --git a/rosetta/rosetta/projects/pax/README.md b/rosetta/rosetta/projects/pax/README.md index 6ac4dc150..d1829b847 100644 --- a/rosetta/rosetta/projects/pax/README.md +++ b/rosetta/rosetta/projects/pax/README.md @@ -138,10 +138,10 @@ The [GPU Performance document](../../../docs/GPU_performance.md) provides a deta For the the 126M model, we recommend setting `--xla_gpu_all_reduce_combine_threshold_bytes=33554432`, which is different from the value recommended in `paxml/contrib/gpu/scripts_gpu/run_pile_multinode.sh`. To overwrite the default XLA flags set in the script, set the `BASE_XLA_FLAGS` environment variable prior to running `run_pile_multinode` as follows: ``` -BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false - --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true - --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_all_reduce_combine_threshold_bytes=33554432 - --xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true" bash run_pile_multinode.sh ... +BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_enable_triton_gemm=false + --xla_gpu_all_reduce_combine_threshold_bytes=33554432 + --xla_gpu_graph_level=0" bash run_pile_multinode.sh ... ``` # Configs