remove deprecated XLA flag #1010

Merged 13 commits on Sep 5, 2024
Changes from 8 commits
1 change: 0 additions & 1 deletion .github/container/Dockerfile.jax
@@ -84,7 +84,6 @@ ENV BUILD_DATE=${BUILD_DATE}
# The following environment variables tune performance
ENV XLA_FLAGS=""
ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_latency_hiding_scheduler=true"
-ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_triton_gemm=false"
ENV CUDA_DEVICE_MAX_CONNECTIONS=1
ENV NCCL_NVLS_ENABLE=0

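The Dockerfile assembles `XLA_FLAGS` through repeated `ENV` appends, so this change simply shrinks the baked-in default. For anyone who still wants the retired flag, here is a minimal sketch of a runtime override; the image tag and training command are illustrative placeholders, not taken from this PR:

```bash
# Hypothetical override at container start: `docker run -e` replaces the
# image's baked-in XLA_FLAGS entirely, so list every flag you want to keep.
docker run --gpus all \
  -e XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false" \
  ghcr.io/nvidia/jax:latest \
  python train.py
```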
@@ -637,7 +637,7 @@ index 89974dd..388d2ec 100755
-MODEL_DIR_LOCAL=${7:-"model_dir"}
-MODEL_DIR=${PWD}/${MODEL_DIR_LOCAL}
-NUM_MICROBATCHES=${8:-0}
+export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"
+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"
+
+#! Change these values !#
+FT_TASK=${FT_TASK:=mnli2} # currently supported: mnli2, squad1
@@ -751,7 +751,7 @@ index 18bb722..f807105 100755
-MODEL_DIR=${PWD}/${MODEL_DIR_LOCAL}
-NUM_MICROBATCHES=${6:-0}
-MP=${7:-1}
+export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"
+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"

-echo Model Parallel partitions: ${MP}
+#! Change these values !#
Expand Down Expand Up @@ -3323,8 +3323,8 @@ index cd563ec..e075df3 100755
-set -x
+set -eoux pipefail

-export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"
+export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592}"
-export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"
+export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592}"
+export XLA_FLAGS="${BASE_XLA_FLAGS} ${XLA_FLAGS:-}"

#! Change these values !#
@@ -3442,8 +3442,8 @@ index d083540..56919a5 100755
#BENCHMARK_MODE=True
STAT_PERIOD=100 #only used if BENCHMARK_MODE is set

-export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"
+export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592}"
-export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"
+export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592}"
+export XLA_FLAGS="${BASE_XLA_FLAGS} ${XLA_FLAGS:-}"

#! Change these values !#
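The rewritten scripts above replace a hard-coded `export XLA_FLAGS=...` with a two-step pattern built on bash default expansion. A minimal sketch of the same pattern, with an illustrative flag value:

```bash
#!/usr/bin/env bash
set -eoux pipefail

# ${BASE_XLA_FLAGS:-<default>} keeps the built-in flags unless the caller
# already exported BASE_XLA_FLAGS, in which case the caller's value wins.
export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true}"

# ${XLA_FLAGS:-} expands to an empty string when XLA_FLAGS is unset; a bare
# ${XLA_FLAGS} would abort the script under `set -u`.
export XLA_FLAGS="${BASE_XLA_FLAGS} ${XLA_FLAGS:-}"
```

Callers can then swap the defaults without editing the script, e.g. `BASE_XLA_FLAGS="--xla_gpu_graph_level=0" bash run.sh`.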
1 change: 0 additions & 1 deletion .github/container/test-maxtext.sh
@@ -224,7 +224,6 @@ export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION}
export CUDA_DEVICE_MAX_CONNECTIONS=1

export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
- --xla_gpu_enable_triton_gemm=false
--xla_gpu_graph_level=0
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
--xla_gpu_all_gather_combine_threshold_bytes=1073741824
3 changes: 2 additions & 1 deletion README.md
@@ -291,7 +291,6 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb
| XLA Flags | Value | Explanation |
| --------- | ----- | ----------- |
| `--xla_gpu_enable_latency_hiding_scheduler` | `true` | allows XLA to move communication collectives to increase overlap with compute kernels |
-| `--xla_gpu_enable_triton_gemm` | `false` | use cuBLAS instead of Triton GeMM kernels |

| Environment Variable | Value | Explanation |
| -------------------- | ----- | ----------- |
@@ -300,6 +299,8 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb

There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33).

+For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page.

## Profiling JAX programs on GPU
See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU.

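As the README paragraph above notes, XLA flags can be tuned per workflow. A hedged sketch of what that looks like in practice; the script name and threshold value are illustrative, not taken from this repository:

```bash
# Per-workflow tuning: export workload-specific flags before launching.
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true \
  --xla_gpu_all_reduce_combine_threshold_bytes=33554432"
python train.py
```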
4 changes: 3 additions & 1 deletion rosetta/docs/GPU_performance.md
@@ -128,6 +128,8 @@ Fine-grain control to improve performance by initializing a NCCL communicator to
- --xla_gpu_enable_cudnn_fmha=false (enables XLA pattern matcher to detect multi-headed attention pattern in JAX)
- --xla_disable_hlo_passes=<> (turns off specific HLO passes; can be used for debugging)

## Previously used XLA Flags

The following flags were previously used but are no longer required.
- --xla_gpu_enable_triton_gemm=false (use cuBLAS instead of Triton GeMM kernels); as of JAX 0.4.32 it is no longer needed.

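Older launch scripts may still inject the retired flag. A small sketch, assuming a POSIX shell, of scrubbing it from an inherited environment:

```bash
# Strip the no-longer-needed flag from XLA_FLAGS if something upstream set it.
export XLA_FLAGS="$(printf '%s' "${XLA_FLAGS:-}" | sed 's/--xla_gpu_enable_triton_gemm=false//g')"
```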
1 change: 0 additions & 1 deletion rosetta/docs/NATIVE_FP8.md
@@ -127,7 +127,6 @@ python -m paxml.main \

Please ensure you include the first two flags, `--xla_gpu_enable_reduction_epilogue_fusion=false` and `--xla_gpu_enable_triton_gemm=false`, as they are essential for enabling the FP8 functionality. The additional flags primarily focus on performance enhancement and should also prove beneficial for non-FP8 executions.


## Transformer Engine vs Native FP8 Support
Native XLA-FP8 specifically targets matrix multiplication operations. In contrast, the Transformer Engine focuses on enhancing the overall performance of the entire transformer layer. This encompasses not only the FP8 matrix multiplication but also attention mechanisms, layer normalizations, and other components.

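Per the NATIVE_FP8.md paragraph above, `--xla_gpu_enable_triton_gemm=false` remains essential for FP8 even though general workloads no longer need it. A minimal sketch of setting the two required flags; appending any pre-existing flags at the end is an assumption, not part of the doc:

```bash
# FP8 prerequisites per the doc above: both flags must be present.
export XLA_FLAGS="--xla_gpu_enable_reduction_epilogue_fusion=false \
  --xla_gpu_enable_triton_gemm=false \
  ${XLA_FLAGS:-}"
```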
1 change: 0 additions & 1 deletion rosetta/docs/PGLE.md
@@ -61,7 +61,6 @@ PGLE found latency for async op custom-call-start.1 and (assumed)custom-call-don
In order to get the best performance with PGLE, here is a list of all recommended XLA flags:
```
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
- --xla_gpu_enable_triton_gemm=false
--xla_gpu_graph_level=0
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
--xla_gpu_all_gather_combine_threshold_bytes=1073741824
6 changes: 1 addition & 5 deletions rosetta/rosetta/projects/maxtext/README.md
@@ -68,11 +68,7 @@ The [GPU Performance document](../../../docs/GPU_performance.md) provides a deta

```
XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
- --xla_gpu_enable_async_all_gather=true
- --xla_gpu_enable_async_reduce_scatter=true
- --xla_gpu_enable_triton_gemm=false
- --xla_gpu_graph_level=0
- --xla_gpu_enable_async_all_reduce=true
+ --xla_gpu_graph_level=0
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
--xla_gpu_all_gather_combine_threshold_bytes=1073741824
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728
8 changes: 1 addition & 7 deletions rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub
@@ -53,11 +53,7 @@ export NCCL_IB_SL=1

# Set XLA Flags
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
- --xla_gpu_enable_async_all_gather=true
- --xla_gpu_enable_async_reduce_scatter=true
- --xla_gpu_enable_triton_gemm=false
--xla_gpu_graph_level=0
- --xla_gpu_enable_async_all_reduce=true
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
--xla_gpu_all_gather_combine_threshold_bytes=1073741824
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728
@@ -68,9 +64,7 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false
--xla_gpu_enable_reduce_scatter_combine_by_dim=false
- --xla_disable_hlo_passes=rematerialization
- --xla_gpu_enable_custom_fusions=false
- --xla_gpu_enable_address_computation_fusion=false"
+ --xla_disable_hlo_passes=rematerialization"

# Make directories that may not exist
mkdir -p $BASE_WORKSPACE_DIR
8 changes: 4 additions & 4 deletions rosetta/rosetta/projects/pax/README.md
@@ -138,10 +138,10 @@ The [GPU Performance document](../../../docs/GPU_performance.md) provides a deta
For the 126M model, we recommend setting `--xla_gpu_all_reduce_combine_threshold_bytes=33554432`, which is different from the value recommended in `paxml/contrib/gpu/scripts_gpu/run_pile_multinode.sh`. To overwrite the default XLA flags set in the script, set the `BASE_XLA_FLAGS` environment variable prior to running `run_pile_multinode` as follows:

```
BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
--xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true
--xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_all_reduce_combine_threshold_bytes=33554432
--xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true" bash run_pile_multinode.sh ...
BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_all_reduce_combine_threshold_bytes=33554432
--xla_gpu_graph_level=0" bash run_pile_multinode.sh ...
```

# Configs