diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py
index dafafdff47..1df16cc016 100644
--- a/benchmarks/attention/benchmark_attention.py
+++ b/benchmarks/attention/benchmark_attention.py
@@ -9,11 +9,11 @@
import torch
import nvtx
import transformer_engine
-from tests.pytorch.fused_attn.test_fused_attn import (
+from tests.pytorch.utils import (
ModelConfig,
- _get_attention_backends,
- _run_dot_product_attention,
+ get_available_attention_backends,
)
+from tests.pytorch.attention.test_attention import _run_dot_product_attention
pd.set_option("display.precision", 4)
@@ -197,7 +197,7 @@ def main():
)
for model in model_configs.keys():
config = model_configs[model]
- available_backends, fused_attn_backends = _get_attention_backends(
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
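The hunks above capture the new helper API: `_get_attention_backends` has moved into `tests/pytorch/utils.py` as `get_available_attention_backends` (alongside `ModelConfig`) and returns `(available_backends, flash_attention_backend, fused_attn_backends)`, so call sites now discard the middle value. A minimal sketch of the updated caller pattern, assuming the repository root is importable and a CUDA device with Transformer Engine is available; the config values are illustrative, not taken from the benchmark:

```python
import torch
from tests.pytorch.utils import ModelConfig, get_available_attention_backends

# Illustrative config: batch, heads, GQA groups, head_dim, seqlen_q, seqlen_kv, dropout, mask, bias
config = ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias")

# The helper now returns three values; the middle one is the flash-attention backend.
available_backends, _, fused_attn_backends = get_available_attention_backends(
    config,
    qkv_dtype=torch.bfloat16,
    qkv_layout="bshd_bshd_bshd",
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
```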
diff --git a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py
index e9eec14d99..97f1bcd7ec 100644
--- a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py
+++ b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py
@@ -5,7 +5,7 @@
import os
import torch
from typing import Tuple
-from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
+from tests.pytorch.utils import ModelConfig
from transformer_engine.pytorch.attention import DotProductAttention
# Initialize RNG state
diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb
index 53a5eede74..6cd56d23da 100644
--- a/docs/examples/attention/attention.ipynb
+++ b/docs/examples/attention/attention.ipynb
@@ -375,7 +375,7 @@
"\n",
"Our [unit tests](https://github.com/NVIDIA/TransformerEngine/tree/main/tests) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n",
"\n",
- "For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
+ "For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
]
},
{
@@ -394,10 +394,10 @@
"| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n",
"\n",
"Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",
- "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
- "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
- "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
- "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)"
+ "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
+ "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
+ "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
+ "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py)"
]
},
{
@@ -458,7 +458,7 @@
" \n",
"\n",
"\n",
- "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
+ "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"\n",
"
\n",
"Note\n",
@@ -548,7 +548,7 @@
"id": "dda4a589",
"metadata": {},
"source": [
- "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n",
+ "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py).\n",
"\n",
"### 3.3 Attention Bias\n",
"\n",
@@ -594,7 +594,7 @@
"\n",
"The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n",
"\n",
- "More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)."
+ "More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)."
]
},
{
@@ -612,7 +612,7 @@
"\n",
"- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n",
"\n",
- "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`."
+ "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`."
]
}
],
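The notebook cell above describes enabling FP8 attention through `DelayedScaling(fp8_dpa=True)`, the experimental `fp8_mha=True` option, and the `NVTE_FP8_DPA_BWD` environment variable. Below is a hedged sketch of that recipe usage; the recipe fields mirror `_run_attention_extra_state` later in this diff, while the module, tensor shapes, and dtype are illustrative, and FP8 attention requires Hopper-or-newer GPUs with a recent cuDNN per the notebook.

```python
import os
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Optional: keep the backward pass out of FP8 and use FP8 only for the forward pass.
os.environ["NVTE_FP8_DPA_BWD"] = "0"

fp8_recipe = recipe.DelayedScaling(
    margin=0,
    fp8_format=recipe.Format.HYBRID,
    fp8_dpa=True,   # run dot-product attention in FP8
    fp8_mha=False,  # experimental: also remove the casts around FusedAttention
)

mha = te.MultiheadAttention(1024, 16, params_dtype=torch.bfloat16, device="cuda")
x = torch.randn(512, 2, 1024, dtype=torch.bfloat16, device="cuda")  # [seq, batch, hidden]
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = mha(x)
```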
diff --git a/docs/examples/attention/example_attention.py b/docs/examples/attention/example_attention.py
index 2c32e8b5f7..cf650265bc 100644
--- a/docs/examples/attention/example_attention.py
+++ b/docs/examples/attention/example_attention.py
@@ -9,11 +9,11 @@
import torch
import nvtx
import transformer_engine
-from tests.pytorch.fused_attn.test_fused_attn import (
+from tests.pytorch.utils import (
ModelConfig,
- _get_attention_backends,
- _run_dot_product_attention,
+ get_available_attention_backends,
)
+from tests.pytorch.attention.test_attention import _run_dot_product_attention
# data type
dtype = torch.bfloat16
@@ -90,7 +90,7 @@ def main():
models = ["test_0"]
for model in models:
config = model_configs[model]
- available_backends, fused_attn_backends = _get_attention_backends(
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index 7fe439b37f..9a924282b5 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -45,8 +45,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh
index 09ef661c4a..f0436d4ff8 100644
--- a/qa/L1_pytorch_distributed_unittest/test.sh
+++ b/qa/L1_pytorch_distributed_unittest/test.sh
@@ -28,7 +28,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh
index 547849e950..7e9616cd03 100644
--- a/qa/L3_pytorch_FA_versions_test/test.sh
+++ b/qa/L3_pytorch_FA_versions_test/test.sh
@@ -41,6 +41,6 @@ do
fi
# Run tests
- NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
+ NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py
done
diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index f9e5c8ad2e..29a9bc2b9f 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -372,7 +372,7 @@ def _check_configs(self):
self.head_dim_v,
(-1, -1) if self.window_size is None else self.window_size,
).get_fused_attn_backend()
- if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
+ if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
pytest.skip("Unsupported inputs combination or device compute capability.")
if (
diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py
similarity index 99%
rename from tests/pytorch/fused_attn/run_fused_attn_with_cp.py
rename to tests/pytorch/attention/run_attention_with_cp.py
index f1db30d992..0ad64204f7 100644
--- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py
+++ b/tests/pytorch/attention/run_attention_with_cp.py
@@ -13,7 +13,7 @@
get_cu_seqlens_on_cp_rank,
)
import transformer_engine_torch as tex
-from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
+from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.common.recipe import DelayedScaling
diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/attention/test_attention.py
similarity index 91%
rename from tests/pytorch/fused_attn/test_fused_attn.py
rename to tests/pytorch/attention/test_attention.py
index a05e64fca3..2c7d86d3d6 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -4,8 +4,9 @@
import logging
import math
import os
+import sys
+import pathlib
from typing import Any, Dict, List, Tuple, Union, Optional
-from contextlib import contextmanager
import pytest
import torch
@@ -21,7 +22,6 @@
FlashAttentionUtils,
get_attention_backend,
check_set_window_size,
- AttentionParams,
)
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
@@ -48,6 +48,17 @@
restore_from_saved,
)
+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent.parent))
+from utils import (
+ _rng_states,
+ reset_rng_states,
+ ModelConfig,
+ dtype_tols,
+ logging_context,
+ get_available_attention_backends,
+)
+
# Only run FP8 tests on H100
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
@@ -55,171 +66,8 @@
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
-
-
-def reset_rng_states() -> None:
- """Revert back to initial RNG state"""
- torch.set_rng_state(_cpu_rng_state)
- torch.cuda.set_rng_state(_cuda_rng_state)
-
-
-@pytest.fixture(autouse=True)
-def reset_global_fp8_state():
- yield
- fp8.FP8GlobalStateManager.reset()
-
-
-class ModelConfig:
- def __init__(
- self,
- batch_size: int,
- num_heads: int,
- num_gqa_groups: int,
- head_dim_qk: int,
- max_seqlen_q: int,
- max_seqlen_kv: int,
- dropout_p: float,
- attn_mask_type: str,
- attn_bias_type: str,
- head_dim_v: int = None,
- alibi_type: str = "none",
- num_layers: int = 1,
- bias_shape: str = "1hss",
- window_size: Tuple[int, int] = (-1, -1),
- total_requests: int = None,
- max_ctx_len: int = None,
- ):
- self.batch_size = batch_size
- self.num_heads = num_heads
- self.num_gqa_groups = num_gqa_groups
- self.head_dim_qk = head_dim_qk
- self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
- self.hidden_size = num_heads * head_dim_qk
- self.hidden_size_kv = num_gqa_groups * self.head_dim_v
- self.max_seqlen_q = max_seqlen_q
- self.max_seqlen_kv = max_seqlen_kv
- self.dropout_p = dropout_p
- self.attn_mask_type = attn_mask_type
- self.attn_bias_type = attn_bias_type
- self.alibi_type = alibi_type
- self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
- self.num_layers = num_layers
- self.bias_shape = bias_shape
- self.window_size = window_size
- self.total_requests = total_requests
- self.max_ctx_len = max_ctx_len
-
-
-@contextmanager
-def logging_context(highest_level=logging.WARNING):
- previous_level = logging.root.manager.disable
- logging.disable(highest_level)
- try:
- yield
- finally:
- logging.disable(previous_level)
-
-
-def _get_attention_backends(
- config: ModelConfig,
- qkv_dtype: torch.dtype,
- qkv_layout: str,
- window_size: Tuple[int, int] = (-1, -1),
- pad_between_seqs: bool = False,
- context_parallel: bool = False,
- deterministic: bool = False,
- fp8: bool = False,
- fp8_meta: Optional[Dict[str, Any]] = None,
- is_training: bool = True,
- inference_params: Optional[InferenceParams] = None,
-) -> Tuple[List, List]:
- """Check if what attention backends support a model configuration"""
-
- os.environ["NVTE_FLASH_ATTN"] = "1"
- os.environ["NVTE_FUSED_ATTN"] = "1"
- os.environ["NVTE_UNFUSED_ATTN"] = "1"
- _attention_backends["backend_selection_requires_update"] = True
-
- alibi_slopes_shape = None
- if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
- if config.bias_shape == "1hss":
- alibi_slopes_shape = [config.num_heads]
- if config.bias_shape == "bhss":
- alibi_slopes_shape = [config.batch_size, config.num_heads]
-
- core_attention_bias_shape = (
- config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
- )
- core_attention_bias_requires_grad = False
- # d=256 is supported by cuDNN 9.0+ for inference but not training
- if (
- config.attn_bias_type == "post_scale_bias"
- and config.head_dim_qk <= 128
- and config.head_dim_v <= 128
- ):
- core_attention_bias_requires_grad = True
-
- fused_attn_backends = []
- available_backends = None
- flash_attention_backend = None
- fused_attention_backend = None
-
- def test():
- attention_params = AttentionParams(
- qkv_dtype=qkv_dtype,
- qkv_layout=qkv_layout,
- batch_size=config.batch_size,
- num_heads=config.num_heads,
- num_gqa_groups=config.num_gqa_groups,
- max_seqlen_q=config.max_seqlen_q,
- max_seqlen_kv=config.max_seqlen_kv,
- head_dim_qk=config.head_dim_qk,
- head_dim_v=config.head_dim_v,
- attn_mask_type=config.attn_mask_type,
- window_size=window_size,
- alibi_slopes_shape=alibi_slopes_shape,
- core_attention_bias_type=config.attn_bias_type,
- core_attention_bias_shape=core_attention_bias_shape,
- core_attention_bias_requires_grad=core_attention_bias_requires_grad,
- pad_between_seqs=pad_between_seqs,
- attention_dropout=config.dropout_p,
- context_parallel=context_parallel,
- deterministic=deterministic,
- fp8=fp8,
- fp8_meta=fp8_meta,
- is_training=is_training,
- inference_params=inference_params,
- )
- (
- use_flash_attention,
- use_fused_attention,
- flash_attention_backend,
- fused_attention_backend,
- use_unfused_attention,
- available_backends,
- ) = get_attention_backend(attention_params)
- # Set attention.py _attention_backends var using return value
- # from get_attention_backend()
- _attention_backends["use_flash_attention"] = use_flash_attention
- _attention_backends["use_fused_attention"] = use_fused_attention
- _attention_backends["flash_attention_backend"] = flash_attention_backend
- _attention_backends["fused_attention_backend"] = fused_attention_backend
- _attention_backends["use_unfused_attention"] = use_unfused_attention
- _attention_backends["backend_selection_requires_update"] = False
- return available_backends, flash_attention_backend, fused_attention_backend
-
- backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
- with logging_context():
- for i in range(3):
- os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
- _attention_backends["backend_selection_requires_update"] = True
- available_backends, flash_attention_backend, fused_attention_backend = test()
- if fused_attention_backend == FusedAttnBackend[backends[i]]:
- fused_attn_backends.append(fused_attention_backend)
- return available_backends, flash_attention_backend, fused_attn_backends
-
+_rng_states = None
+reset_rng_states()
model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias
@@ -278,7 +126,7 @@ def test_dot_product_attention(
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
is_training = True
- available_backends, _, fused_attn_backends = _get_attention_backends(
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
@@ -289,7 +137,7 @@ def test_dot_product_attention(
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
- available_backends, _, fused_attn_backends = _get_attention_backends(
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
@@ -1168,7 +1016,7 @@ def test_transformer_layer(
# Test backend availability
is_training = True
- available_backends, _, fused_attn_backends = _get_attention_backends(
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=(
@@ -1179,7 +1027,7 @@ def test_transformer_layer(
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
- available_backends, _, fused_attn_backends = _get_attention_backends(
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=(
@@ -1492,6 +1340,150 @@ def _run_transformer_layer(
return out, inp.grad
+model_configs_fp8_extra_state = {
+ "large": ModelConfig(2, 4, 4, 128, 128, 128, 0.0, "no_mask", "no_bias", num_layers=1),
+}
+
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
+@pytest.mark.parametrize("model", ["large"])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_sanity_attention_extra_state(model, dtype):
+ config = model_configs_fp8_extra_state[model]
+ # Test backend availability
+ is_training = True
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
+ config,
+ qkv_dtype=torch.float8_e4m3fn,
+ qkv_layout="sb3hd",
+ is_training=is_training,
+ )
+ flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+ if not fused_attn_supported and not flash_attn_supported:
+ pytest.skip("No attention backend available.")
+
+ outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
+ outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
+ outputs_checkpoint_v1_6 = _run_attention_extra_state(
+ dtype, config, mimic_v1_6=True, checkpoint=True
+ )
+
+ # Check that results match
+ tols = dtype_tols(dtype)
+ if dtype in (torch.float16, torch.bfloat16):
+ tols.update(dict(rtol=2e-2, atol=2e-3))
+ for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
+ torch.testing.assert_close(
+ test,
+ ref,
+ **tols,
+ )
+ for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
+ torch.testing.assert_close(
+ test,
+ ref,
+ **tols,
+ )
+
+
+def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
+ steps = 10
+ path = "checkpoint.pt"
+ fp8_enabled = True
+ fp8_recipe = recipe.DelayedScaling(
+ margin=0,
+ fp8_format=recipe.Format.HYBRID,
+ amax_history_len=1,
+ amax_compute_algo="most_recent",
+ fp8_dpa=fp8_enabled,
+ fp8_mha=False,
+ )
+
+ reset_rng_states()
+ hidden_states = torch.randn(
+ (config.max_seqlen_q, config.batch_size, config.hidden_size),
+ dtype=dtype,
+ device="cuda",
+ requires_grad=True,
+ )
+
+ def get_model(dtype, config):
+ sigma = 0.023
+ init_method = init_method_normal(sigma)
+ output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
+
+ with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
+ block = TransformerLayer(
+ config.hidden_size,
+ 4 * config.hidden_size,
+ config.num_heads,
+ init_method=init_method,
+ output_layer_init_method=output_layer_init_method,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ fuse_qkv_params=True,
+ params_dtype=dtype,
+ device="cuda",
+ )
+ return block
+
+ block = get_model(dtype, config)
+ for i in range(steps // 2):
+ with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
+ output = block(hidden_states, None)
+ loss = output.sum()
+ loss.backward()
+
+ if checkpoint:
+ sd = block.state_dict()
+ if mimic_v1_6:
+ sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
+ "self_attention.core_attention._extra_state"
+ ]
+ del sd["self_attention.core_attention._extra_state"]
+ torch.save(sd, path)
+
+ param_grads = []
+ for p in block.parameters():
+ if p.requires_grad:
+ param_grads.append(p.grad.clone())
+
+ _cpu_rng_state_new = torch.get_rng_state()
+ _cuda_rng_state_new = torch.cuda.get_rng_state()
+
+ del block
+ block = get_model(dtype, config)
+ block.load_state_dict(torch.load(path, weights_only=False))
+ torch.set_rng_state(_cpu_rng_state_new)
+ torch.cuda.set_rng_state(_cuda_rng_state_new)
+
+ for p in block.parameters():
+ if p.requires_grad:
+ p.grad = param_grads.pop(0)
+
+ assert not param_grads, "Oops!"
+
+ for i in range((steps + 1) // 2):
+ with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
+ output = block(hidden_states, None)
+ loss = output.sum()
+ loss.backward()
+
+ torch.cuda.synchronize()
+
+ if os.path.exists(path):
+ os.remove(path)
+
+ outputs = [output, hidden_states.grad]
+ for p in block.parameters():
+ if p.requires_grad:
+ outputs.append(p.grad)
+
+ return outputs
+
+
model_configs_fp8_vs_f16 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_9": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
@@ -1554,18 +1546,30 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]
- if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < (
- 9,
- 7,
- 0,
- ):
- pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
- if (
- FlashAttentionUtils.v3_is_installed
- and not is_training
- and "padding" not in config.attn_mask_type
- ):
+ # Test backend availability
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
+ config,
+ qkv_dtype=torch.float8_e4m3fn,
+ qkv_layout=qkv_format.replace("hd", "h3d"),
+ is_training=is_training,
+ )
+ flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+ # Skip if only unfused backend is supported
+ if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
+ pytest.skip("Less than two backends to compare.")
+ if not fp8_dpa_bwd:
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
+ config,
+ qkv_dtype=dtype,
+ qkv_layout=qkv_format.replace("hd", "h3d"),
+ is_training=is_training,
+ )
+ _, fused_attn_supported, _ = available_backends
+ if not fused_attn_supported:
+ pytest.skip("No attention backend available.")
+
+ if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
@@ -1591,11 +1595,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
rtol = 5e-1
rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output"))
- if (
- FlashAttentionUtils.v3_is_installed
- and not is_training
- and "padding" not in config.attn_mask_type
- ):
+ if flash_attn_supported:
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
@@ -1768,23 +1768,34 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
# if get_device_compute_capability() >= (10, 0):
# config.dropout_p = 0.1
- if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < (
- 9,
- 7,
- 0,
- ):
- pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
- if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
- pytest.skip("qkv_layout not applicable for MQA/GQA")
-
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
- if (
- FlashAttentionUtils.v3_is_installed
- and not is_training
- and "padding" not in config.attn_mask_type
- ):
+ # Test backend availability
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
+ config,
+ qkv_dtype=torch.float8_e4m3fn,
+ qkv_layout=qkv_layout,
+ is_training=is_training,
+ )
+ flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+ # Skip if only unfused backend is supported
+ if flash_attn_supported + fused_attn_supported < 1:
+ pytest.skip("No FP8 attention backend available.")
+ if not fp8_dpa_bwd:
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
+ config,
+ qkv_dtype=dtype,
+ qkv_layout=qkv_layout,
+ is_training=is_training,
+ )
+ _, fused_attn_supported, _ = available_backends
+ if not fused_attn_supported:
+ pytest.skip("No attention backend available.")
+ if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
+ pytest.skip("qkv_layout not applicable for MQA/GQA")
+
+ if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
@@ -1813,11 +1824,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
rmse_tol = 0.11
bwd_names = ["dq", "dk", "dv"]
logging.debug("========== {:^25s} ==========".format("forward output"))
- if (
- FlashAttentionUtils.v3_is_installed
- and not is_training
- and "padding" not in config.attn_mask_type
- ):
+ if flash_attn_supported:
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
@@ -2027,6 +2034,18 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
config = model_configs_fp8[model]
+ # Test backend availability
+ is_training = True
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
+ config,
+ qkv_dtype=torch.float8_e4m3fn,
+ qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
+ is_training=is_training,
+ )
+ flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
+ if not (fused_attn_backends and unfused_attn_supported):
+ pytest.skip("Not enough backends to run this test with.")
+
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
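Because the attention tests now live one directory below the shared helpers, the relocated files bootstrap `tests/pytorch` onto `sys.path` before importing from `utils`. The pattern as it appears in the hunks above, shown here as a standalone sketch of the import block only (it assumes the file sits under `tests/pytorch/attention/`):

```python
# Import-bootstrap used by tests/pytorch/attention/*.py after the move:
# append the parent tests/pytorch directory, then import the shared helpers.
import sys
import pathlib

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))

from utils import (  # noqa: E402  (import after the sys.path update)
    ModelConfig,
    reset_rng_states,
    get_available_attention_backends,
)
```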
diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py
similarity index 92%
rename from tests/pytorch/fused_attn/test_fused_attn_with_cp.py
rename to tests/pytorch/attention/test_attention_with_cp.py
index 458070c9b0..fe21568e55 100644
--- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
+++ b/tests/pytorch/attention/test_attention_with_cp.py
@@ -4,6 +4,8 @@
import os
import subprocess
+import sys
+import pathlib
import pytest
import torch
@@ -12,7 +14,10 @@
get_cudnn_version,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
-from test_fused_attn import ModelConfig
+
+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent.parent))
+from utils import ModelConfig, get_available_attention_backends
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
@@ -43,7 +48,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
"--nproc-per-node=" + str(num_gpus_per_node),
]
te_path = os.getenv("TE_PATH", "/opt/transformerengine")
- script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
+ script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py")
args.append(script_path)
for k, v in kwargs.items():
args.append(f"{k}={v}")
@@ -175,6 +180,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("MLA CP currently only support KV P2P!")
if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently does not support FP8 attention!")
+ dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
+ config,
+ qkv_dtype=dtypes[dtype],
+ qkv_layout="_".join([qkv_format] * 3),
+ window_size=config.window_size,
+ context_parallel=True,
+ )
+ _, fused_attn_supported, _ = available_backends
+ if not fused_attn_supported:
+ pytest.skip("No attention backend available.")
subprocess.run(
get_bash_arguments(
diff --git a/tests/pytorch/fused_attn/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py
similarity index 98%
rename from tests/pytorch/fused_attn/test_kv_cache.py
rename to tests/pytorch/attention/test_kv_cache.py
index 9673094597..6ab538ddc2 100644
--- a/tests/pytorch/fused_attn/test_kv_cache.py
+++ b/tests/pytorch/attention/test_kv_cache.py
@@ -5,18 +5,14 @@
from collections import OrderedDict
from typing import List
import os
+import sys
+import pathlib
import logging
import math
import pytest
import torch
-from test_fused_attn import (
- ModelConfig,
- reset_rng_states,
- _get_attention_backends,
-)
-
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
@@ -34,13 +30,21 @@
is_bf16_compatible,
)
+_current_file = pathlib.Path(__file__).resolve()
+sys.path.append(str(_current_file.parent.parent))
+from utils import (
+ ModelConfig,
+ _rng_states,
+ reset_rng_states,
+ get_available_attention_backends,
+)
+
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
-
+_rng_states = None
+reset_rng_states()
param_types = [torch.float16]
if is_bf16_compatible():
@@ -470,7 +474,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2)
if is_paged:
qkv_layout = "paged_kv_" + qkv_layout
- available_backends, _, fused_attn_backends = _get_attention_backends(
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
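The relocated test files also replace their module-local `_cpu_rng_state`/`_cuda_rng_state` capture with `_rng_states = None` followed by `reset_rng_states()` imported from `tests/pytorch/utils.py`. That helper's body is not part of this diff; the sketch below is only an assumed shape consistent with the call sites above (capture the RNG state on first use, restore it on later calls), not the actual utils implementation.

```python
# Assumed sketch, NOT the code in tests/pytorch/utils.py: a reset_rng_states()
# consistent with the "_rng_states = None; reset_rng_states()" call sites above.
import torch

_rng_states = None  # populated on the first call

def reset_rng_states() -> None:
    """Capture CPU/CUDA RNG states on first call; revert to them afterwards."""
    global _rng_states
    if _rng_states is None:
        _rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state())
    else:
        cpu_state, cuda_state = _rng_states
        torch.set_rng_state(cpu_state)
        torch.cuda.set_rng_state(cuda_state)
```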
diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py
index 87494f3c21..c4a3128d67 100644
--- a/tests/pytorch/test_cpu_offloading.py
+++ b/tests/pytorch/test_cpu_offloading.py
@@ -10,6 +10,8 @@
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
+from utils import ModelConfig, get_available_attention_backends
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -22,10 +24,13 @@
recipe.DelayedScaling(),
]
-SIZE = 512
-NUM_HEADS = 8
-NUM_LAYERS = 5
-EPSILON = 0.1
+model_config = {
+ "small": ModelConfig(8, 8, 8, 64, 512, 512, num_layers=5, eps=0.1),
+}
+SIZE = model_config["small"].hidden_size
+NUM_HEADS = model_config["small"].num_heads
+NUM_LAYERS = model_config["small"].num_layers
+EPSILON = model_config["small"].eps
# Flash attention saves some internal tensor for the backward pass
# that cannot be offloaded to CPU.
@@ -130,6 +135,18 @@ def test_cpu_offload(fp8_recipe, model_key) -> None:
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
+ if model_key in ["multihead_attention", "transformer_layer"]:
+ available_backends, *_ = get_available_attention_backends(
+ model_config["small"],
+ qkv_dtype=torch.bfloat16,
+ qkv_layout="sbhd_sbhd_sbhd",
+ )
+ _, fused_attn_supported, _ = available_backends
+ if not fused_attn_supported:
+ pytest.skip("Fused attention backend not available.")
+ os.environ["NVTE_FLASH_ATTN"] = "0"
+ _attention_backends["backend_selection_requires_update"] = True
+
without_offloading = _measure_memory_between_forward_and_backward(
models_list, fp8_recipe, False
)
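The CPU-offloading test pins the fused backend the same way several hunks in `test_attention.py` do: flip the `NVTE_*` environment variables and mark the cached backend selection as stale so the next attention call re-runs backend selection. The pattern, assembled from the lines above:

```python
import os
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends

os.environ["NVTE_FLASH_ATTN"] = "0"   # disable flash attention
os.environ["NVTE_FUSED_ATTN"] = "1"   # keep fused (cuDNN) attention enabled
# Invalidate the cached choice; it is recomputed on the next attention call.
_attention_backends["backend_selection_requires_update"] = True
```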
diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py
index 7bfe506f26..5b4e8f8bbc 100644
--- a/tests/pytorch/test_cuda_graphs.py
+++ b/tests/pytorch/test_cuda_graphs.py
@@ -23,6 +23,7 @@
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe
+from utils import ModelConfig, _rng_states, reset_rng_states
# Check if FP8 is supported.
@@ -37,22 +38,12 @@
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
+_rng_states = None
+reset_rng_states()
-
-@dataclass
-class ModelConfig:
- """Data tensor dimensions within Transformer model"""
-
- sequence_length: int
- batch_size: int
- hidden_size: int
- num_heads: int
- kv_channels: int
-
-
-model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
+model_configs = {
+ "small": ModelConfig(32, 2, 2, 32, 2, 2),
+}
fp8_recipes = [
recipe.DelayedScaling(),
@@ -67,18 +58,6 @@ class ModelConfig:
dtypes.append(torch.bfloat16)
-def reset_rng_states() -> None:
- """Revert to initial RNG state."""
- torch.set_rng_state(_cpu_rng_state)
- torch.cuda.set_rng_state(_cuda_rng_state)
-
-
-@pytest.fixture(autouse=True)
-def reset_global_fp8_state():
- yield
- FP8GlobalStateManager.reset()
-
-
def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool:
"""Check that two lists of tensors match exactly."""
assert len(l1) == len(l2), "Unequal number of outputs."
@@ -107,7 +86,7 @@ def generate_data(
"""Generate synthetic data."""
gen_func = torch.ones if warmup else torch.randn
return gen_func(
- model_config.sequence_length,
+ model_config.max_seqlen_q,
model_config.batch_size,
model_config.hidden_size,
device="cuda",
@@ -389,7 +368,7 @@ def generate_data_for_dot_product_attention(
gen_func = torch.ones if warmup else torch.randn
return [
gen_func(
- model_config.sequence_length,
+ model_config.max_seqlen_q,
model_config.batch_size,
model_config.num_heads,
model_config.kv_channels,
@@ -483,8 +462,8 @@ def _test_cuda_graphs_with_kwargs(
(
model_config.batch_size,
1,
- model_config.sequence_length,
- model_config.sequence_length,
+ model_config.max_seqlen_q,
+ model_config.max_seqlen_kv,
),
dtype=torch.bool,
device="cuda",
@@ -510,8 +489,8 @@ def _test_cuda_graphs_with_kwargs(
(
model_config.batch_size,
1,
- model_config.sequence_length,
- model_config.sequence_length,
+ model_config.max_seqlen_q,
+ model_config.max_seqlen_kv,
),
dtype=torch.bool,
device="cuda",
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 440be43a04..7934425d1f 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -40,11 +40,13 @@
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
+from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
from transformer_engine.common import recipe
import transformer_engine_torch as tex
+from utils import ModelConfig, _rng_states, reset_rng_states, get_available_attention_backends
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -59,30 +61,18 @@
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
+_rng_states = None
+reset_rng_states()
torch._dynamo.config.recompile_limit = 16
-class ModelConfig:
- def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
- self.hidden_size = hidden_size
- self.eps = eps
- self.num_attention_heads = num_attention_heads
- self.embed = embed
- self.num_layers = num_layers
- self.seq_len = seq_len
-
-
model_configs = {
- "small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
- "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
+ "small": ModelConfig(1, 8, 8, 16, 128, 128, num_layers=4),
+ "126m": ModelConfig(1, 12, 12, 64, 2048, 2048, num_layers=12),
}
-
model_configs_inference = {
- # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
- "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
+ "126m": ModelConfig(1, 12, 12, 64, 256, 256, num_layers=12),
}
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
@@ -124,6 +114,18 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq
]
+def is_fused_attn_available(
+ config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
+):
+ available_backends, _, fused_attn_backends = get_available_attention_backends(
+ config,
+ qkv_dtype=dtype,
+ qkv_layout=qkv_layout,
+ is_training=is_training,
+ )
+ return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends
+
+
def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
@@ -173,18 +175,6 @@ def assert_allclose(
raise AssertionError(msg)
-def reset_rng_states() -> None:
- """revert back to initial RNG state."""
- torch.set_rng_state(_cpu_rng_state)
- torch.cuda.set_rng_state(_cuda_rng_state)
-
-
-@pytest.fixture(autouse=True)
-def reset_global_fp8_state():
- yield
- FP8GlobalStateManager.reset()
-
-
class TorchScaledMaskedSoftmax(nn.Module):
def __init__(self) -> None:
super().__init__()
@@ -531,13 +521,13 @@ def _test_e2e_selective_recompute(
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
- kv_channels=config.embed,
+ kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
@@ -546,13 +536,13 @@ def _test_e2e_selective_recompute(
)
te_inp_hidden_states = torch.randn(
- (config.seq_len, bs, config.hidden_size),
+ (config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad()
- te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+ te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
te_out = block(
@@ -626,13 +616,13 @@ def _test_e2e_full_recompute(
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
- kv_channels=config.embed,
+ kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
@@ -641,14 +631,14 @@ def _test_e2e_full_recompute(
)
te_inp_hidden_states = torch.randn(
- (config.seq_len, bs, config.hidden_size),
+ (config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=use_reentrant,
)
if use_reentrant:
te_inp_hidden_states.retain_grad()
- te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+ te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
if recompute:
@@ -757,13 +747,13 @@ def _test_e2e_checkpointing_get_model(config, dtype):
return TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
- kv_channels=config.embed,
+ kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
@@ -775,7 +765,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
reset_rng_states()
te_inp_hidden_states = torch.randn(
- (config.seq_len, bs, config.hidden_size),
+ (config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
@@ -805,14 +795,14 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
if p.requires_grad:
param_grads.append(p.grad.clone())
- global _cpu_rng_state, _cuda_rng_state
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
del block
block = _test_e2e_checkpointing_get_model(config, dtype)
block.load_state_dict(torch.load(path, weights_only=False))
- reset_rng_states()
+ torch.set_rng_state(_cpu_rng_state)
+ torch.cuda.set_rng_state(_cuda_rng_state)
for p in block.parameters():
if p.requires_grad:
@@ -845,6 +835,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model]
+ if not is_fused_attn_available(config, dtype):
+ pytest.skip("No attention backend available.")
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
@@ -865,13 +857,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
reset_rng_states()
inp_hidden_states = torch.randn(
- (config.seq_len, bs, config.hidden_size),
+ (config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
- inp_attn_mask = get_causal_attn_mask(config.seq_len)
+ inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
out = block(inp_hidden_states, attention_mask=inp_attn_mask)
loss = out.sum()
@@ -891,11 +883,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model]
+ if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+ pytest.skip("No attention backend available.")
te_gpt = TransformerLayer(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
- num_attention_heads=config.num_attention_heads,
+ num_attention_heads=config.num_heads,
layernorm_epsilon=config.eps,
attention_dropout=0.1,
hidden_dropout=0.1,
@@ -910,7 +904,7 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
TorchGPT(
config.hidden_size,
config.eps,
- config.num_attention_heads,
+ config.num_heads,
parallel_attention_mlp=parallel_attention_mlp,
)
.to(dtype=dtype)
@@ -971,13 +965,13 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
reset_rng_states()
inp_hidden_states = torch.randn(
- (config.seq_len, bs, config.hidden_size),
+ (config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
- inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None
+ inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) if mask_type == "causal" else None
forward_kwargs = {}
if te:
@@ -1002,10 +996,12 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
config = model_configs[model]
+ if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+ pytest.skip("No attention backend available.")
te_mha = MultiheadAttention(
config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
fuse_qkv_params=True,
params_dtype=dtype,
qkv_weight_interleaved=False,
@@ -1016,7 +1012,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
torch_mha = (
TorchMHA(
config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
)
.to(dtype=dtype)
.cuda()
@@ -1062,7 +1058,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False,
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
- (config.seq_len, bs, config.hidden_size),
+ (config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
@@ -1094,11 +1090,12 @@ def _test_dpa_accuracy(block, bs, dtype, config):
reset_rng_states()
mask = torch.triu(
- torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1
+ torch.ones(config.max_seqlen_q, config.max_seqlen_kv, dtype=torch.bool, device="cuda"),
+ diagonal=1,
)
query, key, value = [
torch.randn(
- (config.seq_len, bs, config.num_attention_heads, config.embed),
+ (config.max_seqlen_q, bs, config.num_heads, config.kv_channels),
dtype=dtype,
device="cuda",
requires_grad=True,
@@ -1127,8 +1124,8 @@ def test_dpa_accuracy(dtype, bs, model):
te_dpa = (
DotProductAttention(
- config.num_attention_heads,
- config.embed,
+ config.num_heads,
+ config.kv_channels,
attention_dropout=0.0, # disable dropout, FU uses rng differently
)
.to(dtype=dtype)
@@ -1137,7 +1134,7 @@ def test_dpa_accuracy(dtype, bs, model):
torch_dpa = (
TorchDotProductAttention(
- config.embed,
+ config.kv_channels,
0.0, # dropout
)
.to(dtype=dtype)
@@ -1726,7 +1723,7 @@ def _test_grouped_linear_accuracy(
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
- (config.seq_len, bs, config.hidden_size),
+ (config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
@@ -1739,14 +1736,14 @@ def _test_grouped_linear_accuracy(
split_size = 16
if recipe.mxfp8():
split_size = 128
- m = config.seq_len // split_size
+ m = config.max_seqlen_q // split_size
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero
m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist)
m_splits = m_splits * split_size
- assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms
+ assert m_splits.sum() == config.max_seqlen_q and len(m_splits) == num_gemms
else:
- m_splits = torch.tensor([config.seq_len])
+ m_splits = torch.tensor([config.max_seqlen_q])
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
if isinstance(block, GroupedLinear):
@@ -1812,7 +1809,7 @@ def test_grouped_linear_accuracy(
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
- if config.seq_len % 16 != 0 and fp8:
+ if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
@@ -2064,14 +2061,14 @@ def _generate_random_numbers(n, total_sum):
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn(
- (config.seq_len * bs, config.hidden_size),
+ (config.max_seqlen_q * bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_hidden_states.retain_grad()
- m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs)
+ m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs)
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
if isinstance(block, TorchGroupedLinearWithPadding):
@@ -2124,7 +2121,7 @@ def test_padding_grouped_linear_accuracy(
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
- if config.seq_len % 16 != 0 and fp8:
+ if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
@@ -2258,9 +2255,11 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
# Placeholders used for graph capture.
static_input = torch.randn(
- config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
+ config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
+ )
+ static_target = torch.randn(
+ config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype
)
- static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype)
real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target)
@@ -2324,7 +2323,7 @@ def test_gpt_cuda_graph(dtype, bs, model):
block_args = (
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
)
block_kwargs = dict(
layernorm_epsilon=config.eps,
@@ -2332,7 +2331,7 @@ def test_gpt_cuda_graph(dtype, bs, model):
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
- kv_channels=config.embed,
+ kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
@@ -2367,13 +2366,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
attention_dropout=0.1,
- kv_channels=config.embed,
+ kv_channels=config.kv_channels,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
params_dtype=dtype,
@@ -2382,13 +2381,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
)
te_inp_hidden_states = torch.randn(
- (config.seq_len, bs, config.hidden_size),
+ (config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_hidden_states.retain_grad()
- te_inp_attn_mask = get_causal_attn_mask(config.seq_len)
+ te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q)
with fp8_autocast(enabled=True, fp8_recipe=recipe):
te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
@@ -2451,13 +2450,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
block_sbhd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
- kv_channels=config.embed,
+ kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
@@ -2472,13 +2471,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
block_bshd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
- kv_channels=config.embed,
+ kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
@@ -2490,13 +2489,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
block_thd = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0,
attention_dropout=0,
- kv_channels=config.embed,
+ kv_channels=config.kv_channels,
params_dtype=dtype,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
@@ -2511,15 +2510,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical"
x_sbhd = torch.randn(
- (config.seq_len, bs, config.hidden_size),
+ (config.max_seqlen_q, bs, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
x_bshd = x_sbhd.transpose(0, 1).contiguous()
- x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous()
- x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len
+ x_thd = x_bshd.reshape(bs * config.max_seqlen_q, config.hidden_size).contiguous()
+ x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.max_seqlen_q
# To make sure forward is also identical (just in case some module decides
# to act fancy)
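
The sbhd/bshd/thd conversions exercised above can be hard to follow at a glance. Below is a minimal, self-contained sketch of the same reshapes, assuming (as in this test) that every sequence in the batch has the same length; all names are illustrative and independent of the test harness.

import torch

bs, seqlen, hidden = 2, 8, 16
x_sbhd = torch.randn(seqlen, bs, hidden)

# bshd: batch-major layout.
x_bshd = x_sbhd.transpose(0, 1).contiguous()

# thd: all tokens flattened; cu_seqlens marks each sequence's start offset.
x_thd = x_bshd.reshape(bs * seqlen, hidden).contiguous()
cu_seqlens = torch.arange(bs + 1, dtype=torch.int32) * seqlen

assert x_thd.shape == (bs * seqlen, hidden)
assert cu_seqlens.tolist() == [0, seqlen, 2 * seqlen]
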
@@ -2546,165 +2545,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
x_thd,
cu_seqlens_q=x_thd_cumsum,
cu_seqlens_kv=x_thd_cumsum,
- max_seqlen_q=config.seq_len,
- max_seqlen_kv=config.seq_len,
+ max_seqlen_q=config.max_seqlen_q,
+ max_seqlen_kv=config.max_seqlen_kv,
)
torch.testing.assert_close(
y_bshd,
- y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
- )
-
-
-@pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("bs", batch_sizes)
-@pytest.mark.parametrize("model_key", model_configs_inference.keys())
-@pytest.mark.parametrize("use_RoPE", all_boolean)
-@pytest.mark.parametrize("input_format", input_formats_inference)
-@pytest.mark.parametrize("module", module_inference)
-@pytest.mark.parametrize("backend", backends_inference)
-@pytest.mark.parametrize("is_paged", [False, True])
-def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
- reset_rng_states()
-
- if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
- pytest.skip("FusedAttention and FlashAttention do not support FP32")
- if use_RoPE:
- pytest.skip("KV cache does not support starting positions for RoPE")
- if (
- backend == "FusedAttention"
- and get_device_compute_capability() == (8, 9)
- and get_cudnn_version() < (9, 12, 0)
- ):
- pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12")
-
- os.environ["NVTE_FLASH_ATTN"] = "0"
- os.environ["NVTE_FUSED_ATTN"] = "0"
- os.environ["NVTE_UNFUSED_ATTN"] = "0"
-
- if backend == "FlashAttention":
- os.environ["NVTE_FLASH_ATTN"] = "1"
- elif backend == "FusedAttention":
- os.environ["NVTE_FUSED_ATTN"] = "1"
- elif backend == "UnfusedAttention":
- os.environ["NVTE_UNFUSED_ATTN"] = "1"
-
- config = model_configs_inference[model_key]
-
- S = config.seq_len
- B = bs
- H = config.num_attention_heads
- D = config.hidden_size
- head_size = config.embed
- layer_number = 1
-
- # Limits the max size of KV-cache
- B_max = B
- S_max = S
-
- if module == "TransformerLayer":
- model = TransformerLayer(
- hidden_size=D,
- ffn_hidden_size=4 * D,
- num_attention_heads=H,
- attn_input_format=input_format,
- self_attn_mask_type="causal",
- enc_dec_attn_mask_type="causal",
- layer_number=layer_number,
- attention_dropout=0.0,
- params_dtype=dtype,
- device="cuda",
- ).eval()
- else:
- model = (
- MultiheadAttention(
- hidden_size=D,
- num_attention_heads=H,
- qkv_format=input_format,
- layer_number=layer_number,
- attention_dropout=0.0,
- attn_mask_type="causal",
- params_dtype=dtype,
- )
- .cuda()
- .eval()
- )
-
- inference_params = InferenceParams(
- max_batch_size=B_max,
- max_sequence_length=S_max,
- num_heads_kv=H,
- head_dim_k=head_size,
- dtype=dtype,
- is_paged=is_paged,
- total_num_pages=int(B_max * S_max / 256),
- page_size=256,
- )
-
- rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
-
- input = torch.randn((S, B, D), dtype=dtype, device="cuda")
- if input_format == "bshd":
- input = input.transpose(0, 1).contiguous()
-
- incremental_output = torch.zeros_like(input)
-
- # Generate output for the entire sequence
- full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)
-
- # Incrementaly generate outputs using KV-cache
- step_dict = OrderedDict(zip(list(range(B)), [1] * B))
- for i in range(S):
- inference_params.pre_step(step_dict)
-
- if input_format == "sbhd":
- incremental_input = input[i].view(1, B, D)
- else:
- incremental_input = input[:, i, :].view(B, 1, D)
-
- seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda")
- cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda")
- cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
- cu_seqlens_kv = cu_seqlens_q.clone()
-
- mask_type = "padding"
- kwargs = {}
- if module == "TransformerLayer":
- kwargs["self_attn_mask_type"] = mask_type
- else:
- kwargs["attn_mask_type"] = mask_type
- line_output = model(
- hidden_states=incremental_input,
- inference_params=inference_params,
- rotary_pos_emb=rotary_freqs if use_RoPE else None,
- **kwargs,
- max_seqlen_q=1,
- max_seqlen_kv=S,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_kv=cu_seqlens_kv,
+ y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(),
)
- if input_format == "sbhd":
- incremental_output[i, :, :] = line_output.view(B, D)
- else:
- incremental_output[:, i, :] = line_output.view(B, D)
-
- if module == "TransformerLayer":
- atol = {
- torch.float32: 5e-3,
- torch.half: 5e-3,
- torch.bfloat16: 5e-2,
- }
- else:
- atol = {
- torch.float32: 1e-3,
- torch.half: 1e-3,
- torch.bfloat16: 1e-2,
- }
-
- # Check if the fully generated output matches the one generated incrementally
- assert_allclose(full_output, incremental_output, atol[dtype])
-
@pytest.mark.parametrize(
"shape",
diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index 00dff53da0..473de29887 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -46,7 +46,7 @@
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint
-from utils import dtype_tols
+from utils import ModelConfig, dtype_tols
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -59,8 +59,6 @@
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
-_cpu_rng_state = torch.get_rng_state()
-_cuda_rng_state = torch.cuda.get_rng_state()
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))
@@ -105,37 +103,19 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
return torch.min(amax_history, dim=0).values
-def reset_rng_states() -> None:
- """revert back to initial RNG state."""
- global _cpu_rng_state, _cuda_rng_state
- torch.set_rng_state(_cpu_rng_state)
- torch.cuda.set_rng_state(_cuda_rng_state)
-
-
-@dataclass
-class ModelConfig:
- """Transformer model configuration"""
-
- num_layers: int
- seq_len: int
- batch_size: int
- hidden_size: int
- num_attention_heads: int
- kv_channels: Optional[int] = None
-
- def is_fp8_supported(self):
- if self.seq_len * self.batch_size % 16:
- return False
- if self.hidden_size % 16:
- return False
- return True
+def is_fp8_supported(config: ModelConfig):
+ if config.max_seqlen_q * config.batch_size % 16 or config.max_seqlen_kv * config.batch_size % 16:
+ return False
+ if config.hidden_size % 16 or config.hidden_size_kv % 16:
+ return False
+ return True
model_configs = {
- "126m": ModelConfig(12, 2048, 2, 768, 12),
- "small": ModelConfig(2, 32, 2, 64, 2),
- "weird": ModelConfig(2, 37, 3, 69, 3),
- "large": ModelConfig(1, 128, 2, 512, 4, 128),
+ "126m": ModelConfig(2, 12, 12, 64, 2048, 2048, num_layers=12),
+ "small": ModelConfig(2, 2, 2, 32, 32, 32, num_layers=2),
+ "weird": ModelConfig(3, 3, 3, 23, 37, 37, num_layers=2),
+ "large": ModelConfig(2, 4, 4, 128, 128, 128, num_layers=1),
}
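
The rule that is_fp8_supported() encodes (token count and hidden sizes must be multiples of 16) can be checked against the configs above with a purely illustrative sketch; the hidden sizes below are recomputed as num_heads * head_dim_qk.

# Illustrative check of the %16 divisibility rule encoded by is_fp8_supported().
def fp8_ok(batch_size, max_seqlen, hidden_size):
    return (batch_size * max_seqlen) % 16 == 0 and hidden_size % 16 == 0

configs = {
    "126m": (2, 2048, 12 * 64),
    "small": (2, 32, 2 * 32),
    "weird": (3, 37, 3 * 23),
    "large": (2, 128, 4 * 128),
}
for name, (bs, s, h) in configs.items():
    print(name, fp8_ok(bs, s, h))  # "weird" fails both checks by design
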
fp8_recipes = [
@@ -171,12 +151,6 @@ def _disable_wgrads(block):
p.requires_grad = False
-@pytest.fixture(autouse=True)
-def reset_global_fp8_state():
- yield
- FP8GlobalStateManager.reset()
-
-
def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
# Initialize loss function and optimizer.
loss_fn = torch.nn.MSELoss()
@@ -184,7 +158,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
# Placeholders used for capture.
static_input = torch.randn(
- config.seq_len,
+ config.max_seqlen_q,
config.batch_size,
config.hidden_size,
device="cuda",
@@ -192,7 +166,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
requires_grad=True,
)
static_target = torch.randn(
- config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
+ config.max_seqlen_q, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
)
real_input = torch.rand_like(static_input)
@@ -236,7 +210,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
- (config.seq_len, config.batch_size, config.hidden_size),
+ (config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=torch.float32,
device="cuda",
requires_grad=True,
@@ -244,7 +218,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states.retain_grad()
te_inp_attn_mask = torch.randint(
2,
- (1, 1, config.seq_len, config.seq_len),
+ (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
dtype=torch.bool,
device="cuda",
)
@@ -271,14 +245,14 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
- (config.seq_len, config.batch_size, config.hidden_size),
+ (config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_attn_mask = torch.randint(
2,
- (1, 1, config.seq_len, config.seq_len),
+ (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
dtype=torch.bool,
device="cuda",
)
@@ -311,7 +285,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
te_inp_hidden_states = torch.randn(
- (config.seq_len, config.batch_size, config.hidden_size),
+ (config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
@@ -337,7 +311,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
- (config.seq_len, config.batch_size, config.hidden_size),
+ (config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
@@ -345,7 +319,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_attn_mask = torch.randint(
2,
- (config.batch_size, 1, 1, config.seq_len),
+ (config.batch_size, 1, 1, config.max_seqlen_q),
dtype=torch.bool,
device="cuda",
)
@@ -363,21 +337,21 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
- (config.seq_len, config.batch_size, config.hidden_size),
+ (config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=True,
)
te_inp_attn_mask = torch.randint(
2,
- (1, 1, config.seq_len, config.seq_len),
+ (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
dtype=torch.bool,
device="cuda",
)
enc_dec_attn_mask = torch.randint(
2,
- (config.batch_size, 1, 1, config.seq_len),
+ (config.batch_size, 1, 1, config.max_seqlen_kv),
dtype=torch.bool,
device="cuda",
)
@@ -405,7 +379,7 @@ def _test_sanity_common(
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn(
- (config.seq_len, config.batch_size, config.hidden_size),
+ (config.max_seqlen_q, config.batch_size, config.hidden_size),
dtype=dtype,
device="cuda",
requires_grad=not skip_dgrad,
@@ -433,7 +407,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
te_inp = torch.randn(
- (config.seq_len, config.batch_size, config.hidden_size),
+ (config.max_seqlen_q, config.batch_size, config.hidden_size),
device="cuda",
requires_grad=True,
)
@@ -494,7 +468,7 @@ def test_sanity_layernorm_linear(
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
- if not config.is_fp8_supported():
+ if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
@@ -528,7 +502,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
- if not config.is_fp8_supported():
+ if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
@@ -555,7 +529,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
pytest.skip("Quantized model parameters are not supported in debug mode.")
config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size
- num_tokens = bs * config.seq_len
+ num_tokens = bs * config.max_seqlen_q
if fp8_recipe is not None:
if not fp8_available:
@@ -564,7 +538,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
- if not config.is_fp8_supported():
+ if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
use_fp8 = fp8_recipe is not None
@@ -600,7 +574,7 @@ def test_sanity_grouped_linear(
ffn_hidden_size = 4 * config.hidden_size
# Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
bs = bs * 16
- num_tokens = bs * config.seq_len * (num_gemms - 1)
+ num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
if fp8_recipe is not None:
if not fp8_available:
@@ -609,7 +583,7 @@ def test_sanity_grouped_linear(
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
- if not config.is_fp8_supported():
+ if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
use_fp8 = fp8_recipe is not None
@@ -621,7 +595,7 @@ def test_sanity_grouped_linear(
inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
- m_splits = [bs * config.seq_len] * num_gemms
+ m_splits = [bs * config.max_seqlen_q] * num_gemms
if empty_split == "first":
m_splits[0] = 0
elif empty_split == "last":
@@ -665,7 +639,7 @@ def test_sanity_layernorm_mlp(
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
- if not config.is_fp8_supported():
+ if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
@@ -719,7 +693,7 @@ def test_sanity_gpt(
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
- if not config.is_fp8_supported():
+ if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
@@ -729,7 +703,7 @@ def test_sanity_gpt(
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
@@ -788,7 +762,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
- if not config.is_fp8_supported():
+ if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
@@ -798,7 +772,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
@@ -849,7 +823,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
- if not config.is_fp8_supported():
+ if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
@@ -859,7 +833,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
@@ -908,7 +882,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
- if not config.is_fp8_supported():
+ if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
@@ -918,7 +892,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
@@ -945,7 +919,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
- if not config.is_fp8_supported():
+ if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
@@ -955,7 +929,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
@@ -985,7 +959,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
- if not config.is_fp8_supported():
+ if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
@@ -995,7 +969,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
@@ -1028,7 +1002,7 @@ def test_sanity_gradient_accumulation_fusion(
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
- if not config.is_fp8_supported():
+ if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
@@ -1038,7 +1012,7 @@ def test_sanity_gradient_accumulation_fusion(
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
@@ -1074,7 +1048,7 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling():
pytest.skip("cuda graph not supported for float8_block_scaling recipe")
- if not config.is_fp8_supported():
+ if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
sigma = 0.023
@@ -1084,7 +1058,7 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
- config.num_attention_heads,
+ config.num_heads,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
hidden_dropout=0.1,
@@ -1156,133 +1130,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
torch.cuda.synchronize()
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
-@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
-@pytest.mark.parametrize("model", ["large"])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-def test_sanity_attention_extra_state(model, dtype):
- config = model_configs[model]
- outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
- outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
- outputs_checkpoint_v1_6 = _run_attention_extra_state(
- dtype, config, mimic_v1_6=True, checkpoint=True
- )
-
- # Check that results match
- tols = dtype_tols(dtype)
- if dtype in (torch.float16, torch.bfloat16):
- tols.update(dict(rtol=2e-2, atol=2e-3))
- for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)):
- torch.testing.assert_close(
- test,
- ref,
- **tols,
- )
- for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)):
- torch.testing.assert_close(
- test,
- ref,
- **tols,
- )
-
-
-def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
- steps = 10
- path = "checkpoint.pt"
- fp8_enabled = True
- fp8_recipe = recipe.DelayedScaling(
- margin=0,
- fp8_format=recipe.Format.HYBRID,
- amax_history_len=1,
- amax_compute_algo="most_recent",
- fp8_dpa=fp8_enabled,
- fp8_mha=False,
- )
-
- reset_rng_states()
- hidden_states = torch.randn(
- (config.seq_len, config.batch_size, config.hidden_size),
- dtype=dtype,
- device="cuda",
- requires_grad=True,
- )
-
- def get_model(dtype, config):
- sigma = 0.023
- init_method = init_method_normal(sigma)
- output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
-
- with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe):
- block = TransformerLayer(
- config.hidden_size,
- 4 * config.hidden_size,
- config.num_attention_heads,
- init_method=init_method,
- output_layer_init_method=output_layer_init_method,
- hidden_dropout=0.0,
- attention_dropout=0.0,
- fuse_qkv_params=True,
- params_dtype=dtype,
- device="cuda",
- )
- return block
-
- block = get_model(dtype, config)
- for i in range(steps // 2):
- with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
- output = block(hidden_states, None)
- loss = output.sum()
- loss.backward()
-
- if checkpoint:
- sd = block.state_dict()
- if mimic_v1_6:
- sd["self_attention.core_attention.fused_attention._extra_state"] = sd[
- "self_attention.core_attention._extra_state"
- ]
- del sd["self_attention.core_attention._extra_state"]
- torch.save(sd, path)
-
- param_grads = []
- for p in block.parameters():
- if p.requires_grad:
- param_grads.append(p.grad.clone())
-
- _cpu_rng_state_new = torch.get_rng_state()
- _cuda_rng_state_new = torch.cuda.get_rng_state()
-
- del block
- block = get_model(dtype, config)
- block.load_state_dict(torch.load(path, weights_only=False))
- torch.set_rng_state(_cpu_rng_state_new)
- torch.cuda.set_rng_state(_cuda_rng_state_new)
-
- for p in block.parameters():
- if p.requires_grad:
- p.grad = param_grads.pop(0)
-
- assert not param_grads, "Oops!"
-
- for i in range((steps + 1) // 2):
- with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
- output = block(hidden_states, None)
- loss = output.sum()
- loss.backward()
-
- torch.cuda.synchronize()
-
- if os.path.exists(path):
- os.remove(path)
-
- outputs = [output, hidden_states.grad]
- for p in block.parameters():
- if p.requires_grad:
- outputs.append(p.grad)
-
- return outputs
-
-
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_replace_raw_data_for_float8tensor():
"""Test the functionality of replace_raw_data"""
diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py
index 61ccfc6f29..52b3f6f3ab 100644
--- a/tests/pytorch/utils.py
+++ b/tests/pytorch/utils.py
@@ -4,12 +4,29 @@
from __future__ import annotations
+import logging
+import os
+from contextlib import contextmanager
+
+import pytest
import torch
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
+from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
+from transformer_engine.pytorch.attention.dot_product_attention.utils import (
+ get_attention_backend,
+ AttentionParams,
+ AttentionLogging,
+)
+from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
+
+# Initialize RNG state
+seed = 1234
+torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
@@ -106,3 +123,184 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
if name == "fp8_block_scaling":
return transformer_engine.common.recipe.Float8BlockScaling()
raise ValueError(f"Unsupported quantization scheme ({name})")
+
+
+# Cached RNG state
+_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+
+
+def reset_rng_states() -> None:
+ """Revert to deterministic RNG state"""
+ global _rng_states
+ if _rng_states is None:
+ torch.manual_seed(1234)
+ torch.cuda.manual_seed(1234)
+ _rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state())
+ else:
+ cpu_rng_state, cuda_rng_state = _rng_states
+ torch.set_rng_state(cpu_rng_state)
+ torch.cuda.set_rng_state(cuda_rng_state)
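
reset_rng_states() caches the generator state on its first call and restores that snapshot on every later call. A standalone, CPU-only sketch of the same pattern (the helper above additionally snapshots the CUDA generator) is:

import torch

_rng_state = None

def reset_rng_state():
    global _rng_state
    if _rng_state is None:
        torch.manual_seed(1234)
        _rng_state = torch.get_rng_state()
    else:
        torch.set_rng_state(_rng_state)

reset_rng_state()
a = torch.randn(4)
reset_rng_state()
b = torch.randn(4)
assert torch.equal(a, b)  # identical draws after each reset
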
+
+
+@pytest.fixture(autouse=True)
+def reset_global_fp8_state():
+ yield
+ fp8.FP8GlobalStateManager.reset()
+
+
+class ModelConfig:
+ def __init__(
+ self,
+ batch_size: int,
+ num_heads: int,
+ num_gqa_groups: int,
+ head_dim_qk: int,
+ max_seqlen_q: int,
+ max_seqlen_kv: int,
+ dropout_p: float = 0.0,
+ attn_mask_type: str = "no_mask",
+ attn_bias_type: str = "no_bias",
+ head_dim_v: int = None,
+ alibi_type: str = "none",
+ num_layers: int = 1,
+ bias_shape: str = "1hss",
+ window_size: Tuple[int, int] = (-1, -1),
+ total_requests: int = None,
+ max_ctx_len: int = None,
+ eps: float = 1e-5,
+ ):
+ self.batch_size = batch_size
+ self.num_heads = num_heads
+ self.num_gqa_groups = num_gqa_groups
+ self.head_dim_qk = head_dim_qk
+ self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
+ if self.head_dim_qk == self.head_dim_v:
+ self.kv_channels = self.head_dim_qk
+ else:
+ self.kv_channels = (self.head_dim_qk, self.head_dim_v)
+ self.hidden_size = num_heads * head_dim_qk
+ self.hidden_size_kv = num_gqa_groups * self.head_dim_v
+ self.max_seqlen_q = max_seqlen_q
+ self.max_seqlen_kv = max_seqlen_kv
+ self.dropout_p = dropout_p
+ self.attn_mask_type = attn_mask_type
+ self.attn_bias_type = attn_bias_type
+ self.alibi_type = alibi_type
+ self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
+ self.num_layers = num_layers
+ self.bias_shape = bias_shape
+ self.window_size = window_size
+ self.total_requests = total_requests
+ self.max_ctx_len = max_ctx_len
+ self.eps = eps
+
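
As a quick illustration of the derived attributes, assuming the ModelConfig class defined above:

cfg = ModelConfig(
    batch_size=2, num_heads=16, num_gqa_groups=4,
    head_dim_qk=64, max_seqlen_q=128, max_seqlen_kv=256,
)
assert cfg.hidden_size == 16 * 64    # num_heads * head_dim_qk
assert cfg.hidden_size_kv == 4 * 64  # num_gqa_groups * head_dim_v
assert cfg.kv_channels == 64         # scalar when head_dim_qk == head_dim_v
assert cfg.attn_type == "cross"      # max_seqlen_q != max_seqlen_kv
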
+
+@contextmanager
+def logging_context(highest_level=logging.WARNING):
+ previous_level = logging.root.manager.disable
+ logging.disable(highest_level)
+ try:
+ yield
+ finally:
+ logging.disable(previous_level)
+
+
+def get_available_attention_backends(
+ config: ModelConfig,
+ qkv_dtype: torch.dtype,
+ qkv_layout: str,
+ window_size: Tuple[int, int] = (-1, -1),
+ pad_between_seqs: bool = False,
+ context_parallel: bool = False,
+ deterministic: bool = False,
+ fp8: bool = False,
+ fp8_meta: Optional[Dict[str, Any]] = None,
+ is_training: bool = True,
+ inference_params: Optional[InferenceParams] = None,
+) -> Tuple[List, Any, List]:
+ """Check for all available attention backends that support a model configuration"""
+
+ os.environ["NVTE_FLASH_ATTN"] = "1"
+ os.environ["NVTE_FUSED_ATTN"] = "1"
+ os.environ["NVTE_UNFUSED_ATTN"] = "1"
+ _attention_backends["backend_selection_requires_update"] = True
+
+ alibi_slopes_shape = None
+ if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
+ if config.bias_shape == "1hss":
+ alibi_slopes_shape = [config.num_heads]
+ if config.bias_shape == "bhss":
+ alibi_slopes_shape = [config.batch_size, config.num_heads]
+
+ core_attention_bias_shape = (
+ config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
+ )
+ core_attention_bias_requires_grad = False
+ # d=256 is supported by cuDNN 9.0+ for inference but not training
+ if (
+ config.attn_bias_type == "post_scale_bias"
+ and config.head_dim_qk <= 128
+ and config.head_dim_v <= 128
+ ):
+ core_attention_bias_requires_grad = True
+
+ fused_attn_backends = []
+ available_backends = None
+ flash_attention_backend = None
+ fused_attention_backend = None
+
+ def test():
+ attention_params = AttentionParams(
+ qkv_dtype=qkv_dtype,
+ qkv_layout=qkv_layout,
+ batch_size=config.batch_size,
+ num_heads=config.num_heads,
+ num_gqa_groups=config.num_gqa_groups,
+ max_seqlen_q=config.max_seqlen_q,
+ max_seqlen_kv=config.max_seqlen_kv,
+ head_dim_qk=config.head_dim_qk,
+ head_dim_v=config.head_dim_v,
+ attn_mask_type=config.attn_mask_type,
+ window_size=window_size,
+ alibi_slopes_shape=alibi_slopes_shape,
+ core_attention_bias_type=config.attn_bias_type,
+ core_attention_bias_shape=core_attention_bias_shape,
+ core_attention_bias_requires_grad=core_attention_bias_requires_grad,
+ pad_between_seqs=pad_between_seqs,
+ attention_dropout=config.dropout_p,
+ context_parallel=context_parallel,
+ deterministic=deterministic,
+ fp8=fp8,
+ fp8_meta=fp8_meta,
+ is_training=is_training,
+ inference_params=inference_params,
+ )
+ (
+ use_flash_attention,
+ use_fused_attention,
+ flash_attention_backend,
+ fused_attention_backend,
+ use_unfused_attention,
+ available_backends,
+ ) = get_attention_backend(attention_params)
+ # Set attention.py _attention_backends var using return value
+ # from get_attention_backend()
+ _attention_backends["use_flash_attention"] = use_flash_attention
+ _attention_backends["use_fused_attention"] = use_fused_attention
+ _attention_backends["flash_attention_backend"] = flash_attention_backend
+ _attention_backends["fused_attention_backend"] = fused_attention_backend
+ _attention_backends["use_unfused_attention"] = use_unfused_attention
+ _attention_backends["backend_selection_requires_update"] = False
+ return available_backends, flash_attention_backend, fused_attention_backend
+
+ backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
+ if AttentionLogging._is_logging_setup is False:
+ AttentionLogging.setup_logging()
+ with logging_context(highest_level=AttentionLogging._log_level):
+ for i in range(3):
+ os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
+ _attention_backends["backend_selection_requires_update"] = True
+ available_backends, flash_attention_backend, fused_attention_backend = test()
+ if fused_attention_backend == FusedAttnBackend[backends[i]]:
+ fused_attn_backends.append(fused_attention_backend)
+ return available_backends, flash_attention_backend, fused_attn_backends
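
A hypothetical call site for the new helper might look like the following; the config values and qkv_layout are illustrative only, and a CUDA device with Transformer Engine installed is assumed.

import torch
from tests.pytorch.utils import ModelConfig, get_available_attention_backends

config = ModelConfig(2, 16, 16, 64, 512, 512, attn_mask_type="causal")
available_backends, flash_backend, fused_backends = get_available_attention_backends(
    config,
    qkv_dtype=torch.bfloat16,
    qkv_layout="sbh3d",
)
if not fused_backends:
    print("No fused-attention backend supports this configuration")
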
diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index b512133efd..11344ecc10 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -183,7 +183,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
- !requires_64bit_ragged_offset) {
+ !requires_64bit_ragged_offset &&
+ // 9.10.0: known bugs with SDPA FP8
+ (cudnn_runtime_version != 91000)) {
if (cudnn_runtime_version >= 8900) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else {
@@ -239,10 +241,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1
(!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
- // 9.10: any head_dim + any arch + fprop + paged
- // 9.10: any head_dim + any arch + fprop + non_paged + sq > 1
- // 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
- (!is_training && cudnn_runtime_version >= 91000 &&
+ // 9.10.2: any head_dim + any arch + fprop + paged
+ // 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1
+ // 9.10.2: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
+ (!is_training && cudnn_runtime_version >= 91002 &&
(layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 ||
(max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
@@ -354,7 +356,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
dropout == 0.0)))) &&
// check 64-bit ragged offset support
- (supported_ragged_offset_size)) {
+ (supported_ragged_offset_size) &&
+ // 9.10.0/9.10.1: known bugs with SDPA F16
+ (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
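
The version gates added above exclude cuDNN 9.10.0 from the FP8 path and 9.10.0/9.10.1 from the F16 arbitrary-seqlen path. A Python sketch of the same checks (not the library's API), using the MAJOR*10000 + MINOR*100 + PATCH encoding of cudnn_runtime_version, is:

# Sketch of the cuDNN version gates introduced in the C++ change above.
def fp8_backend_allowed(cudnn_runtime_version: int) -> bool:
    return cudnn_runtime_version != 91000  # 9.10.0: known bugs with SDPA FP8

def f16_arbitrary_seqlen_allowed(cudnn_runtime_version: int) -> bool:
    return cudnn_runtime_version not in (91000, 91001)  # 9.10.0/9.10.1: known bugs with SDPA F16

assert not fp8_backend_allowed(91000)
assert f16_arbitrary_seqlen_allowed(91002)
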