diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py index dafafdff47..1df16cc016 100644 --- a/benchmarks/attention/benchmark_attention.py +++ b/benchmarks/attention/benchmark_attention.py @@ -9,11 +9,11 @@ import torch import nvtx import transformer_engine -from tests.pytorch.fused_attn.test_fused_attn import ( +from tests.pytorch.utils import ( ModelConfig, - _get_attention_backends, - _run_dot_product_attention, + get_available_attention_backends, ) +from tests.pytorch.attention.test_attention import _run_dot_product_attention pd.set_option("display.precision", 4) @@ -197,7 +197,7 @@ def main(): ) for model in model_configs.keys(): config = model_configs[model] - available_backends, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py index e9eec14d99..97f1bcd7ec 100644 --- a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py +++ b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py @@ -5,7 +5,7 @@ import os import torch from typing import Tuple -from tests.pytorch.fused_attn.test_fused_attn import ModelConfig +from tests.pytorch.utils import ModelConfig from transformer_engine.pytorch.attention import DotProductAttention # Initialize RNG state diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 53a5eede74..6cd56d23da 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -375,7 +375,7 @@ "\n", "Our [unit tests](https://github.com/NVIDIA/TransformerEngine/tree/main/tests) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n", "\n", - "For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts." + "For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts." ] }, { @@ -394,10 +394,10 @@ "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n", "\n", "Some unit tests are provided to serve as a starting point for integrating such features into users' models. 
For example,\n", - "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", - "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", - "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", - "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)" + "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n", + "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n", + "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n", + "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py)" ] }, { @@ -458,7 +458,7 @@ " \n", "\n", "\n", - "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", + "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", "\n", "
\n", "Note\n", @@ -548,7 +548,7 @@ "id": "dda4a589", "metadata": {}, "source": [ - "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n", + "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py).\n", "\n", "### 3.3 Attention Bias\n", "\n", @@ -594,7 +594,7 @@ "\n", "The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n", "\n", - "More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)." + "More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)." ] }, { @@ -612,7 +612,7 @@ "\n", "- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n", "\n", - "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`." + "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`." 
] } ], diff --git a/docs/examples/attention/example_attention.py b/docs/examples/attention/example_attention.py index 2c32e8b5f7..cf650265bc 100644 --- a/docs/examples/attention/example_attention.py +++ b/docs/examples/attention/example_attention.py @@ -9,11 +9,11 @@ import torch import nvtx import transformer_engine -from tests.pytorch.fused_attn.test_fused_attn import ( +from tests.pytorch.utils import ( ModelConfig, - _get_attention_backends, - _run_dot_product_attention, + get_available_attention_backends, ) +from tests.pytorch.attention.test_attention import _run_dot_product_attention # data type dtype = torch.bfloat16 @@ -90,7 +90,7 @@ def main(): models = ["test_0"] for model in models: config = model_configs[model] - available_backends, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 7fe439b37f..9a924282b5 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -45,8 +45,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 09ef661c4a..f0436d4ff8 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -28,7 +28,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml 
$TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 547849e950..7e9616cd03 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -41,6 +41,6 @@ do fi # Run tests - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py done diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index f9e5c8ad2e..29a9bc2b9f 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -372,7 +372,7 @@ def _check_configs(self): self.head_dim_v, (-1, -1) if self.window_size is None else self.window_size, ).get_fused_attn_backend() - if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: + if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: pytest.skip("Unsupported inputs combination or device compute capability.") if ( diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py similarity index 99% rename from tests/pytorch/fused_attn/run_fused_attn_with_cp.py rename to tests/pytorch/attention/run_attention_with_cp.py index f1db30d992..0ad64204f7 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -13,7 +13,7 @@ get_cu_seqlens_on_cp_rank, ) import transformer_engine_torch as tex -from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn +from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.common.recipe import DelayedScaling diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/attention/test_attention.py similarity index 91% rename from tests/pytorch/fused_attn/test_fused_attn.py rename to tests/pytorch/attention/test_attention.py index a05e64fca3..2c7d86d3d6 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/attention/test_attention.py @@ -4,8 +4,9 @@ import logging import math import os +import sys +import pathlib from typing import Any, Dict, List, Tuple, Union, Optional -from contextlib import contextmanager import pytest import torch @@ -21,7 +22,6 @@ FlashAttentionUtils, get_attention_backend, check_set_window_size, - AttentionParams, ) from 
transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import RotaryPositionEmbedding @@ -48,6 +48,17 @@ restore_from_saved, ) +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ( + _rng_states, + reset_rng_states, + ModelConfig, + dtype_tols, + logging_context, + get_available_attention_backends, +) + # Only run FP8 tests on H100 fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() @@ -55,171 +66,8 @@ seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() - - -def reset_rng_states() -> None: - """Revert back to initial RNG state""" - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) - - -@pytest.fixture(autouse=True) -def reset_global_fp8_state(): - yield - fp8.FP8GlobalStateManager.reset() - - -class ModelConfig: - def __init__( - self, - batch_size: int, - num_heads: int, - num_gqa_groups: int, - head_dim_qk: int, - max_seqlen_q: int, - max_seqlen_kv: int, - dropout_p: float, - attn_mask_type: str, - attn_bias_type: str, - head_dim_v: int = None, - alibi_type: str = "none", - num_layers: int = 1, - bias_shape: str = "1hss", - window_size: Tuple[int, int] = (-1, -1), - total_requests: int = None, - max_ctx_len: int = None, - ): - self.batch_size = batch_size - self.num_heads = num_heads - self.num_gqa_groups = num_gqa_groups - self.head_dim_qk = head_dim_qk - self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v - self.hidden_size = num_heads * head_dim_qk - self.hidden_size_kv = num_gqa_groups * self.head_dim_v - self.max_seqlen_q = max_seqlen_q - self.max_seqlen_kv = max_seqlen_kv - self.dropout_p = dropout_p - self.attn_mask_type = attn_mask_type - self.attn_bias_type = attn_bias_type - self.alibi_type = alibi_type - self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross" - self.num_layers = num_layers - self.bias_shape = bias_shape - self.window_size = window_size - self.total_requests = total_requests - self.max_ctx_len = max_ctx_len - - -@contextmanager -def logging_context(highest_level=logging.WARNING): - previous_level = logging.root.manager.disable - logging.disable(highest_level) - try: - yield - finally: - logging.disable(previous_level) - - -def _get_attention_backends( - config: ModelConfig, - qkv_dtype: torch.dtype, - qkv_layout: str, - window_size: Tuple[int, int] = (-1, -1), - pad_between_seqs: bool = False, - context_parallel: bool = False, - deterministic: bool = False, - fp8: bool = False, - fp8_meta: Optional[Dict[str, Any]] = None, - is_training: bool = True, - inference_params: Optional[InferenceParams] = None, -) -> Tuple[List, List]: - """Check if what attention backends support a model configuration""" - - os.environ["NVTE_FLASH_ATTN"] = "1" - os.environ["NVTE_FUSED_ATTN"] = "1" - os.environ["NVTE_UNFUSED_ATTN"] = "1" - _attention_backends["backend_selection_requires_update"] = True - - alibi_slopes_shape = None - if config.attn_bias_type == "alibi" and config.alibi_type == "custom": - if config.bias_shape == "1hss": - alibi_slopes_shape = [config.num_heads] - if config.bias_shape == "bhss": - alibi_slopes_shape = [config.batch_size, config.num_heads] - - core_attention_bias_shape = ( - config.bias_shape if config.attn_bias_type == "post_scale_bias" else None - ) - core_attention_bias_requires_grad = False - # d=256 is supported by cuDNN 9.0+ for inference but not 
training - if ( - config.attn_bias_type == "post_scale_bias" - and config.head_dim_qk <= 128 - and config.head_dim_v <= 128 - ): - core_attention_bias_requires_grad = True - - fused_attn_backends = [] - available_backends = None - flash_attention_backend = None - fused_attention_backend = None - - def test(): - attention_params = AttentionParams( - qkv_dtype=qkv_dtype, - qkv_layout=qkv_layout, - batch_size=config.batch_size, - num_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - max_seqlen_q=config.max_seqlen_q, - max_seqlen_kv=config.max_seqlen_kv, - head_dim_qk=config.head_dim_qk, - head_dim_v=config.head_dim_v, - attn_mask_type=config.attn_mask_type, - window_size=window_size, - alibi_slopes_shape=alibi_slopes_shape, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias_shape=core_attention_bias_shape, - core_attention_bias_requires_grad=core_attention_bias_requires_grad, - pad_between_seqs=pad_between_seqs, - attention_dropout=config.dropout_p, - context_parallel=context_parallel, - deterministic=deterministic, - fp8=fp8, - fp8_meta=fp8_meta, - is_training=is_training, - inference_params=inference_params, - ) - ( - use_flash_attention, - use_fused_attention, - flash_attention_backend, - fused_attention_backend, - use_unfused_attention, - available_backends, - ) = get_attention_backend(attention_params) - # Set attention.py _attention_backends var using return value - # from get_attention_backend() - _attention_backends["use_flash_attention"] = use_flash_attention - _attention_backends["use_fused_attention"] = use_fused_attention - _attention_backends["flash_attention_backend"] = flash_attention_backend - _attention_backends["fused_attention_backend"] = fused_attention_backend - _attention_backends["use_unfused_attention"] = use_unfused_attention - _attention_backends["backend_selection_requires_update"] = False - return available_backends, flash_attention_backend, fused_attention_backend - - backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} - with logging_context(): - for i in range(3): - os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) - _attention_backends["backend_selection_requires_update"] = True - available_backends, flash_attention_backend, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[backends[i]]: - fused_attn_backends.append(fused_attention_backend) - return available_backends, flash_attention_backend, fused_attn_backends - +_rng_states = None +reset_rng_states() model_configs_base = { # test: b, h, hg, d, sq, skv, p, mask, bias @@ -278,7 +126,7 @@ def test_dot_product_attention( config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) is_training = True - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -289,7 +137,7 @@ def test_dot_product_attention( flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: is_training = False - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -1168,7 +1016,7 @@ def test_transformer_layer( # Test backend availability is_training = True - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = 
get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=( @@ -1179,7 +1027,7 @@ def test_transformer_layer( flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: is_training = False - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=( @@ -1492,6 +1340,150 @@ def _run_transformer_layer( return out, inp.grad +model_configs_fp8_extra_state = { + "large": ModelConfig(2, 4, 4, 128, 128, 128, 0.0, "no_mask", "no_bias", num_layers=1), +} + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.") +@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") +@pytest.mark.parametrize("model", ["large"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_sanity_attention_extra_state(model, dtype): + config = model_configs_fp8_extra_state[model] + # Test backend availability + is_training = True + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout="sb3hd", + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported and not flash_attn_supported: + pytest.skip("No attention backend available.") + + outputs = _run_attention_extra_state(dtype, config, checkpoint=False) + outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True) + outputs_checkpoint_v1_6 = _run_attention_extra_state( + dtype, config, mimic_v1_6=True, checkpoint=True + ) + + # Check that results match + tols = dtype_tols(dtype) + if dtype in (torch.float16, torch.bfloat16): + tols.update(dict(rtol=2e-2, atol=2e-3)) + for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)): + torch.testing.assert_close( + test, + ref, + **tols, + ) + for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)): + torch.testing.assert_close( + test, + ref, + **tols, + ) + + +def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): + steps = 10 + path = "checkpoint.pt" + fp8_enabled = True + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=fp8_enabled, + fp8_mha=False, + ) + + reset_rng_states() + hidden_states = torch.randn( + (config.max_seqlen_q, config.batch_size, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + + def get_model(dtype, config): + sigma = 0.023 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) + + with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe): + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.0, + attention_dropout=0.0, + fuse_qkv_params=True, + params_dtype=dtype, + device="cuda", + ) + return block + + block = get_model(dtype, config) + for i in range(steps // 2): + with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + output = block(hidden_states, None) + loss = output.sum() + loss.backward() + + if checkpoint: + sd = block.state_dict() + if mimic_v1_6: + 
sd["self_attention.core_attention.fused_attention._extra_state"] = sd[ + "self_attention.core_attention._extra_state" + ] + del sd["self_attention.core_attention._extra_state"] + torch.save(sd, path) + + param_grads = [] + for p in block.parameters(): + if p.requires_grad: + param_grads.append(p.grad.clone()) + + _cpu_rng_state_new = torch.get_rng_state() + _cuda_rng_state_new = torch.cuda.get_rng_state() + + del block + block = get_model(dtype, config) + block.load_state_dict(torch.load(path, weights_only=False)) + torch.set_rng_state(_cpu_rng_state_new) + torch.cuda.set_rng_state(_cuda_rng_state_new) + + for p in block.parameters(): + if p.requires_grad: + p.grad = param_grads.pop(0) + + assert not param_grads, "Oops!" + + for i in range((steps + 1) // 2): + with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + output = block(hidden_states, None) + loss = output.sum() + loss.backward() + + torch.cuda.synchronize() + + if os.path.exists(path): + os.remove(path) + + outputs = [output, hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + outputs.append(p.grad) + + return outputs + + model_configs_fp8_vs_f16 = { # test: b, h, hg, d, sq, skv, p, mask, bias "fp8_9": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), @@ -1554,18 +1546,30 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] - if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < ( - 9, - 7, - 0, - ): - pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + # Test backend availability + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout=qkv_format.replace("hd", "h3d"), + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + # Skip if only unfused backend is supported + if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: + pytest.skip("Less than two backends to compare.") + if not fp8_dpa_bwd: + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_format.replace("hd", "h3d"), + is_training=is_training, + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("No attention backend available.") + + if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1591,11 +1595,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, rtol = 5e-1 rmse_tol = 0.15 logging.debug("========== {:^25s} ==========".format("forward output")) - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + if flash_attn_supported: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1768,23 +1768,34 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): # if get_device_compute_capability() >= (10, 0): # config.dropout_p = 0.1 - if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < ( - 9, - 7, - 0, 
- ): - pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") - if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: - pytest.skip("qkv_layout not applicable for MQA/GQA") - os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + # Test backend availability + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout=qkv_layout, + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + # Skip if only unfused backend is supported + if flash_attn_supported + fused_attn_supported < 1: + pytest.skip("No FP8 attention backend available.") + if not fp8_dpa_bwd: + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + is_training=is_training, + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("No attention backend available.") + if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: + pytest.skip("qkv_layout not applicable for MQA/GQA") + + if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1813,11 +1824,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): rmse_tol = 0.11 bwd_names = ["dq", "dk", "dv"] logging.debug("========== {:^25s} ==========".format("forward output")) - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + if flash_attn_supported: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -2027,6 +2034,18 @@ def test_custom_mha_fp8_vs_f16(dtype, model): config = model_configs_fp8[model] + # Test backend availability + is_training = True + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd", + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not (fused_attn_backends and unfused_attn_supported): + pytest.skip("Not enough backends to run this test with.") + fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention") unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention") diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py similarity index 92% rename from tests/pytorch/fused_attn/test_fused_attn_with_cp.py rename to tests/pytorch/attention/test_attention_with_cp.py index 458070c9b0..fe21568e55 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -4,6 +4,8 @@ import os import subprocess +import sys +import pathlib import pytest import torch @@ -12,7 +14,10 @@ get_cudnn_version, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils -from test_fused_attn import ModelConfig + +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ModelConfig, get_available_attention_backends 
model_configs_flash_attn = { # test: b, h, hg, d, sq, skv, p, mask, bias @@ -43,7 +48,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): "--nproc-per-node=" + str(num_gpus_per_node), ] te_path = os.getenv("TE_PATH", "/opt/transformerengine") - script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py") + script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py") args.append(script_path) for k, v in kwargs.items(): args.append(f"{k}={v}") @@ -175,6 +180,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("MLA CP currently only support KV P2P!") if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently does not support FP8 attention!") + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtypes[dtype], + qkv_layout="_".join([qkv_format] * 3), + window_size=config.window_size, + context_parallel=True, + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("No attention backend available.") subprocess.run( get_bash_arguments( diff --git a/tests/pytorch/fused_attn/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py similarity index 98% rename from tests/pytorch/fused_attn/test_kv_cache.py rename to tests/pytorch/attention/test_kv_cache.py index 9673094597..6ab538ddc2 100644 --- a/tests/pytorch/fused_attn/test_kv_cache.py +++ b/tests/pytorch/attention/test_kv_cache.py @@ -5,18 +5,14 @@ from collections import OrderedDict from typing import List import os +import sys +import pathlib import logging import math import pytest import torch -from test_fused_attn import ( - ModelConfig, - reset_rng_states, - _get_attention_backends, -) - from torch.distributions import Exponential from transformer_engine.pytorch import make_graphed_callables from transformer_engine.common import recipe @@ -34,13 +30,21 @@ is_bf16_compatible, ) +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ( + ModelConfig, + _rng_states, + reset_rng_states, + get_available_attention_backends, +) + # Initialize RNG state seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() - +_rng_states = None +reset_rng_states() param_types = [torch.float16] if is_bf16_compatible(): @@ -470,7 +474,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2) if is_paged: qkv_layout = "paged_kv_" + qkv_layout - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 87494f3c21..c4a3128d67 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -10,6 +10,8 @@ import transformer_engine.pytorch as te from transformer_engine.common import recipe from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from utils import ModelConfig, get_available_attention_backends # Check if FP8 is supported 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -22,10 +24,13 @@ recipe.DelayedScaling(), ] -SIZE = 512 -NUM_HEADS = 8 -NUM_LAYERS = 5 -EPSILON = 0.1 +model_config = { + "small": ModelConfig(8, 8, 8, 64, 512, 512, num_layers=5, eps=0.1), +} +SIZE = model_config["small"].hidden_size +NUM_HEADS = model_config["small"].num_heads +NUM_LAYERS = model_config["small"].num_layers +EPSILON = model_config["small"].eps # Flash attention saves some internal tensor for the backward pass # that cannot be offloaded to CPU. @@ -130,6 +135,18 @@ def test_cpu_offload(fp8_recipe, model_key) -> None: if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if model_key in ["multihead_attention", "transformer_layer"]: + available_backends, *_ = get_available_attention_backends( + model_config["small"], + qkv_dtype=torch.bfloat16, + qkv_layout="sbhd_sbhd_sbhd", + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("Fused attention backend not available.") + os.environ["NVTE_FLASH_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + without_offloading = _measure_memory_between_forward_and_backward( models_list, fp8_recipe, False ) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 7bfe506f26..5b4e8f8bbc 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -23,6 +23,7 @@ from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe +from utils import ModelConfig, _rng_states, reset_rng_states # Check if FP8 is supported. @@ -37,22 +38,12 @@ seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() +_rng_states = None +reset_rng_states() - -@dataclass -class ModelConfig: - """Data tensor dimensions within Transformer model""" - - sequence_length: int - batch_size: int - hidden_size: int - num_heads: int - kv_channels: int - - -model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)} +model_configs = { + "small": ModelConfig(32, 2, 2, 32, 2, 2), +} fp8_recipes = [ recipe.DelayedScaling(), @@ -67,18 +58,6 @@ class ModelConfig: dtypes.append(torch.bfloat16) -def reset_rng_states() -> None: - """Revert to initial RNG state.""" - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) - - -@pytest.fixture(autouse=True) -def reset_global_fp8_state(): - yield - FP8GlobalStateManager.reset() - - def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool: """Check that two lists of tensors match exactly.""" assert len(l1) == len(l2), "Unequal number of outputs." 
@@ -107,7 +86,7 @@ def generate_data( """Generate synthetic data.""" gen_func = torch.ones if warmup else torch.randn return gen_func( - model_config.sequence_length, + model_config.max_seqlen_q, model_config.batch_size, model_config.hidden_size, device="cuda", @@ -389,7 +368,7 @@ def generate_data_for_dot_product_attention( gen_func = torch.ones if warmup else torch.randn return [ gen_func( - model_config.sequence_length, + model_config.max_seqlen_q, model_config.batch_size, model_config.num_heads, model_config.kv_channels, @@ -483,8 +462,8 @@ def _test_cuda_graphs_with_kwargs( ( model_config.batch_size, 1, - model_config.sequence_length, - model_config.sequence_length, + model_config.max_seqlen_q, + model_config.max_seqlen_kv, ), dtype=torch.bool, device="cuda", @@ -510,8 +489,8 @@ def _test_cuda_graphs_with_kwargs( ( model_config.batch_size, 1, - model_config.sequence_length, - model_config.sequence_length, + model_config.max_seqlen_q, + model_config.max_seqlen_kv, ), dtype=torch.bool, device="cuda", diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 440be43a04..7934425d1f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -40,11 +40,13 @@ from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm +from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version from transformer_engine.common import recipe import transformer_engine_torch as tex +from utils import ModelConfig, _rng_states, reset_rng_states, get_available_attention_backends # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -59,30 +61,18 @@ torch.manual_seed(seed) torch.cuda.manual_seed(seed) # Record initial RNG state from script run. 
-_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() +_rng_states = None +reset_rng_states() torch._dynamo.config.recompile_limit = 16 -class ModelConfig: - def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len): - self.hidden_size = hidden_size - self.eps = eps - self.num_attention_heads = num_attention_heads - self.embed = embed - self.num_layers = num_layers - self.seq_len = seq_len - - model_configs = { - "small": ModelConfig(128, 1e-5, 8, 36, 4, 128), - "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048), + "small": ModelConfig(1, 8, 8, 16, 128, 128, num_layers=4), + "126m": ModelConfig(1, 12, 12, 64, 2048, 2048, num_layers=12), } - model_configs_inference = { - # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len - "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256), + "126m": ModelConfig(1, 12, 12, 64, 256, 256, num_layers=12), } backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"] module_inference = ["TransformerLayer", "MultiheadAttention"] @@ -124,6 +114,18 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq ] +def is_fused_attn_available( + config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True +): + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + is_training=is_training, + ) + return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends + + def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() @@ -173,18 +175,6 @@ def assert_allclose( raise AssertionError(msg) -def reset_rng_states() -> None: - """revert back to initial RNG state.""" - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) - - -@pytest.fixture(autouse=True) -def reset_global_fp8_state(): - yield - FP8GlobalStateManager.reset() - - class TorchScaledMaskedSoftmax(nn.Module): def __init__(self) -> None: super().__init__() @@ -531,13 +521,13 @@ def _test_e2e_selective_recompute( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -546,13 +536,13 @@ def _test_e2e_selective_recompute( ) te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_hidden_states.retain_grad() - te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): te_out = block( @@ -626,13 +616,13 @@ def _test_e2e_full_recompute( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -641,14 +631,14 @@ def _test_e2e_full_recompute( ) 
te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=use_reentrant, ) if use_reentrant: te_inp_hidden_states.retain_grad() - te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if recompute: @@ -757,13 +747,13 @@ def _test_e2e_checkpointing_get_model(config, dtype): return TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -775,7 +765,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= reset_rng_states() te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -805,14 +795,14 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= if p.requires_grad: param_grads.append(p.grad.clone()) - global _cpu_rng_state, _cuda_rng_state _cpu_rng_state = torch.get_rng_state() _cuda_rng_state = torch.cuda.get_rng_state() del block block = _test_e2e_checkpointing_get_model(config, dtype) block.load_state_dict(torch.load(path, weights_only=False)) - reset_rng_states() + torch.set_rng_state(_cpu_rng_state) + torch.cuda.set_rng_state(_cuda_rng_state) for p in block.parameters(): if p.requires_grad: @@ -845,6 +835,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= @pytest.mark.parametrize("model", ["126m"]) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] + if not is_fused_attn_available(config, dtype): + pytest.skip("No attention backend available.") outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) @@ -865,13 +857,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): reset_rng_states() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) inp_hidden_states.retain_grad() - inp_attn_mask = get_causal_attn_mask(config.seq_len) + inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) out = block(inp_hidden_states, attention_mask=inp_attn_mask) loss = out.sum() @@ -891,11 +883,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): config = model_configs[model] + if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False): + pytest.skip("No attention backend available.") te_gpt = TransformerLayer( hidden_size=config.hidden_size, ffn_hidden_size=4 * config.hidden_size, - num_attention_heads=config.num_attention_heads, + num_attention_heads=config.num_heads, layernorm_epsilon=config.eps, attention_dropout=0.1, hidden_dropout=0.1, @@ -910,7 +904,7 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): TorchGPT( config.hidden_size, config.eps, - config.num_attention_heads, + 
config.num_heads, parallel_attention_mlp=parallel_attention_mlp, ) .to(dtype=dtype) @@ -971,13 +965,13 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): reset_rng_states() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) inp_hidden_states.retain_grad() - inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None + inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) if mask_type == "causal" else None forward_kwargs = {} if te: @@ -1002,10 +996,12 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): @pytest.mark.parametrize("mask_type", mask_types) def test_mha_accuracy(dtype, bs, model, mask_type): config = model_configs[model] + if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False): + pytest.skip("No attention backend available.") te_mha = MultiheadAttention( config.hidden_size, - config.num_attention_heads, + config.num_heads, fuse_qkv_params=True, params_dtype=dtype, qkv_weight_interleaved=False, @@ -1016,7 +1012,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type): torch_mha = ( TorchMHA( config.hidden_size, - config.num_attention_heads, + config.num_heads, ) .to(dtype=dtype) .cuda() @@ -1062,7 +1058,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False, FP8GlobalStateManager.reset() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -1094,11 +1090,12 @@ def _test_dpa_accuracy(block, bs, dtype, config): reset_rng_states() mask = torch.triu( - torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1 + torch.ones(config.max_seqlen_q, config.max_seqlen_kv, dtype=torch.bool, device="cuda"), + diagonal=1, ) query, key, value = [ torch.randn( - (config.seq_len, bs, config.num_attention_heads, config.embed), + (config.max_seqlen_q, bs, config.num_heads, config.kv_channels), dtype=dtype, device="cuda", requires_grad=True, @@ -1127,8 +1124,8 @@ def test_dpa_accuracy(dtype, bs, model): te_dpa = ( DotProductAttention( - config.num_attention_heads, - config.embed, + config.num_heads, + config.kv_channels, attention_dropout=0.0, # disable dropout, FU uses rng differently ) .to(dtype=dtype) @@ -1137,7 +1134,7 @@ def test_dpa_accuracy(dtype, bs, model): torch_dpa = ( TorchDotProductAttention( - config.embed, + config.kv_channels, 0.0, # dropout ) .to(dtype=dtype) @@ -1726,7 +1723,7 @@ def _test_grouped_linear_accuracy( FP8GlobalStateManager.reset() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -1739,14 +1736,14 @@ def _test_grouped_linear_accuracy( split_size = 16 if recipe.mxfp8(): split_size = 128 - m = config.seq_len // split_size + m = config.max_seqlen_q // split_size dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() dist.append(dist[-1]) # Manually add a zero m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) m_splits = m_splits * split_size - assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms + assert m_splits.sum() == config.max_seqlen_q and len(m_splits) == num_gemms else: - m_splits = torch.tensor([config.seq_len]) + m_splits = torch.tensor([config.max_seqlen_q]) with 
fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, GroupedLinear): @@ -1812,7 +1809,7 @@ def test_grouped_linear_accuracy( pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -2064,14 +2061,14 @@ def _generate_random_numbers(n, total_sum): FP8GlobalStateManager.reset() inp_hidden_states = torch.randn( - (config.seq_len * bs, config.hidden_size), + (config.max_seqlen_q * bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) inp_hidden_states.retain_grad() - m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs) + m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, TorchGroupedLinearWithPadding): @@ -2124,7 +2121,7 @@ def test_padding_grouped_linear_accuracy( pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -2258,9 +2255,11 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph): # Placeholders used for graph capture. static_input = torch.randn( - config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True + config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True + ) + static_target = torch.randn( + config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype ) - static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype) real_input = torch.rand_like(static_input) real_target = torch.rand_like(static_target) @@ -2324,7 +2323,7 @@ def test_gpt_cuda_graph(dtype, bs, model): block_args = ( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, ) block_kwargs = dict( layernorm_epsilon=config.eps, @@ -2332,7 +2331,7 @@ def test_gpt_cuda_graph(dtype, bs, model): output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2367,13 +2366,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -2382,13 +2381,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): ) te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_hidden_states.retain_grad() - te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) with fp8_autocast(enabled=True, fp8_recipe=recipe): te_out = 
block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) @@ -2451,13 +2450,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): block_sbhd = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0, attention_dropout=0, - kv_channels=config.embed, + kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2472,13 +2471,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): block_bshd = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0, attention_dropout=0, - kv_channels=config.embed, + kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2490,13 +2489,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): block_thd = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0, attention_dropout=0, - kv_channels=config.embed, + kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2511,15 +2510,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical" x_sbhd = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) x_bshd = x_sbhd.transpose(0, 1).contiguous() - x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous() - x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len + x_thd = x_bshd.reshape(bs * config.max_seqlen_q, config.hidden_size).contiguous() + x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.max_seqlen_q # To make sure forward is also identical (just in case some module decides # to act fancy) @@ -2546,165 +2545,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): x_thd, cu_seqlens_q=x_thd_cumsum, cu_seqlens_kv=x_thd_cumsum, - max_seqlen_q=config.seq_len, - max_seqlen_kv=config.seq_len, + max_seqlen_q=config.max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, ) torch.testing.assert_close( y_bshd, - y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), - ) - - -@pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model_key", model_configs_inference.keys()) -@pytest.mark.parametrize("use_RoPE", all_boolean) -@pytest.mark.parametrize("input_format", input_formats_inference) -@pytest.mark.parametrize("module", module_inference) -@pytest.mark.parametrize("backend", backends_inference) -@pytest.mark.parametrize("is_paged", [False, True]) -def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged): - reset_rng_states() - - if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32: - pytest.skip("FusedAttention and FlashAttention do not 
support FP32") - if use_RoPE: - pytest.skip("KV cache does not support starting positions for RoPE") - if ( - backend == "FusedAttention" - and get_device_compute_capability() == (8, 9) - and get_cudnn_version() < (9, 12, 0) - ): - pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12") - - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "0" - os.environ["NVTE_UNFUSED_ATTN"] = "0" - - if backend == "FlashAttention": - os.environ["NVTE_FLASH_ATTN"] = "1" - elif backend == "FusedAttention": - os.environ["NVTE_FUSED_ATTN"] = "1" - elif backend == "UnfusedAttention": - os.environ["NVTE_UNFUSED_ATTN"] = "1" - - config = model_configs_inference[model_key] - - S = config.seq_len - B = bs - H = config.num_attention_heads - D = config.hidden_size - head_size = config.embed - layer_number = 1 - - # Limits the max size of KV-cache - B_max = B - S_max = S - - if module == "TransformerLayer": - model = TransformerLayer( - hidden_size=D, - ffn_hidden_size=4 * D, - num_attention_heads=H, - attn_input_format=input_format, - self_attn_mask_type="causal", - enc_dec_attn_mask_type="causal", - layer_number=layer_number, - attention_dropout=0.0, - params_dtype=dtype, - device="cuda", - ).eval() - else: - model = ( - MultiheadAttention( - hidden_size=D, - num_attention_heads=H, - qkv_format=input_format, - layer_number=layer_number, - attention_dropout=0.0, - attn_mask_type="causal", - params_dtype=dtype, - ) - .cuda() - .eval() - ) - - inference_params = InferenceParams( - max_batch_size=B_max, - max_sequence_length=S_max, - num_heads_kv=H, - head_dim_k=head_size, - dtype=dtype, - is_paged=is_paged, - total_num_pages=int(B_max * S_max / 256), - page_size=256, - ) - - rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda") - - input = torch.randn((S, B, D), dtype=dtype, device="cuda") - if input_format == "bshd": - input = input.transpose(0, 1).contiguous() - - incremental_output = torch.zeros_like(input) - - # Generate output for the entire sequence - full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None) - - # Incrementaly generate outputs using KV-cache - step_dict = OrderedDict(zip(list(range(B)), [1] * B)) - for i in range(S): - inference_params.pre_step(step_dict) - - if input_format == "sbhd": - incremental_input = input[i].view(1, B, D) - else: - incremental_input = input[:, i, :].view(B, 1, D) - - seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda") - cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda") - cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) - cu_seqlens_kv = cu_seqlens_q.clone() - - mask_type = "padding" - kwargs = {} - if module == "TransformerLayer": - kwargs["self_attn_mask_type"] = mask_type - else: - kwargs["attn_mask_type"] = mask_type - line_output = model( - hidden_states=incremental_input, - inference_params=inference_params, - rotary_pos_emb=rotary_freqs if use_RoPE else None, - **kwargs, - max_seqlen_q=1, - max_seqlen_kv=S, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, + y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(), ) - if input_format == "sbhd": - incremental_output[i, :, :] = line_output.view(B, D) - else: - incremental_output[:, i, :] = line_output.view(B, D) - - if module == "TransformerLayer": - atol = { - torch.float32: 5e-3, - torch.half: 5e-3, - torch.bfloat16: 5e-2, - } - else: - atol = { - torch.float32: 1e-3, - torch.half: 1e-3, - torch.bfloat16: 1e-2, - } - - # Check if the fully generated output matches the one 
generated incrementally - assert_allclose(full_output, incremental_output, atol[dtype]) - @pytest.mark.parametrize( "shape", diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 00dff53da0..473de29887 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -46,7 +46,7 @@ from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.distributed import checkpoint -from utils import dtype_tols +from utils import ModelConfig, dtype_tols # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -59,8 +59,6 @@ seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0")) @@ -105,37 +103,19 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor: return torch.min(amax_history, dim=0).values -def reset_rng_states() -> None: - """revert back to initial RNG state.""" - global _cpu_rng_state, _cuda_rng_state - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) - - -@dataclass -class ModelConfig: - """Transformer model configuration""" - - num_layers: int - seq_len: int - batch_size: int - hidden_size: int - num_attention_heads: int - kv_channels: Optional[int] = None - - def is_fp8_supported(self): - if self.seq_len * self.batch_size % 16: - return False - if self.hidden_size % 16: - return False - return True +def is_fp8_supported(config: ModelConfig): + if config.max_seqlen_q * config.batch_size % 16 or config.max_seqlen_kv * config.batch_size % 16: + return False + if config.hidden_size % 16 or config.hidden_size_kv % 16: + return False + return True model_configs = { - "126m": ModelConfig(12, 2048, 2, 768, 12), - "small": ModelConfig(2, 32, 2, 64, 2), - "weird": ModelConfig(2, 37, 3, 69, 3), - "large": ModelConfig(1, 128, 2, 512, 4, 128), + "126m": ModelConfig(2, 12, 12, 64, 2048, 2048, num_layers=12), + "small": ModelConfig(2, 2, 2, 32, 32, 32, num_layers=2), + "weird": ModelConfig(3, 3, 3, 23, 37, 37, num_layers=2), + "large": ModelConfig(2, 4, 4, 128, 128, 128, num_layers=1), } fp8_recipes = [ @@ -171,12 +151,6 @@ def _disable_wgrads(block): p.requires_grad = False -@pytest.fixture(autouse=True) -def reset_global_fp8_state(): - yield - FP8GlobalStateManager.reset() - - def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): # Initialize loss function and optimizer. loss_fn = torch.nn.MSELoss() @@ -184,7 +158,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): # Placeholders used for capture. 
static_input = torch.randn( - config.seq_len, + config.max_seqlen_q, config.batch_size, config.hidden_size, device="cuda", @@ -192,7 +166,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): requires_grad=True, ) static_target = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype + config.max_seqlen_q, config.batch_size, config.hidden_size, device="cuda", dtype=dtype ) real_input = torch.rand_like(static_input) @@ -236,7 +210,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=torch.float32, device="cuda", requires_grad=True, @@ -244,7 +218,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states.retain_grad() te_inp_attn_mask = torch.randint( 2, - (1, 1, config.seq_len, config.seq_len), + (1, 1, config.max_seqlen_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) @@ -271,14 +245,14 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_attn_mask = torch.randint( 2, - (1, 1, config.seq_len, config.seq_len), + (1, 1, config.max_seqlen_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) @@ -311,7 +285,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -337,7 +311,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -345,7 +319,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_attn_mask = torch.randint( 2, - (config.batch_size, 1, 1, config.seq_len), + (config.batch_size, 1, 1, config.max_seqlen_q), dtype=torch.bool, device="cuda", ) @@ -363,21 +337,21 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_attn_mask = torch.randint( 2, - (1, 1, config.seq_len, config.seq_len), + (1, 1, config.max_seqlen_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) enc_dec_attn_mask = torch.randint( 2, - (config.batch_size, 1, 1, config.seq_len), + (config.batch_size, 1, 1, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) @@ -405,7 +379,7 @@ def _test_sanity_common( pytest.skip("No gradient 
computation; Skipping to avoid PyTorch RuntimeError.") te_inp = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=not skip_dgrad, @@ -433,7 +407,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad) pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.") te_inp = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), device="cuda", requires_grad=True, ) @@ -494,7 +468,7 @@ def test_sanity_layernorm_linear( pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -528,7 +502,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -555,7 +529,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ pytest.skip("Quantized model parameters are not supported in debug mode.") config = model_configs[model] ffn_hidden_size = 4 * config.hidden_size - num_tokens = bs * config.seq_len + num_tokens = bs * config.max_seqlen_q if fp8_recipe is not None: if not fp8_available: @@ -564,7 +538,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") use_fp8 = fp8_recipe is not None @@ -600,7 +574,7 @@ def test_sanity_grouped_linear( ffn_hidden_size = 4 * config.hidden_size # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. 
bs = bs * 16 - num_tokens = bs * config.seq_len * (num_gemms - 1) + num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) if fp8_recipe is not None: if not fp8_available: @@ -609,7 +583,7 @@ def test_sanity_grouped_linear( pytest.skip(reason_for_no_mxfp8) if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") use_fp8 = fp8_recipe is not None @@ -621,7 +595,7 @@ def test_sanity_grouped_linear( inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True ).cuda() - m_splits = [bs * config.seq_len] * num_gemms + m_splits = [bs * config.max_seqlen_q] * num_gemms if empty_split == "first": m_splits[0] = 0 elif empty_split == "last": @@ -665,7 +639,7 @@ def test_sanity_layernorm_mlp( pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -719,7 +693,7 @@ def test_sanity_gpt( pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -729,7 +703,7 @@ def test_sanity_gpt( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -788,7 +762,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -798,7 +772,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -849,7 +823,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -859,7 +833,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -908,7 +882,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -918,7 +892,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): block = TransformerLayer( 
config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -945,7 +919,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -955,7 +929,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -985,7 +959,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -995,7 +969,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -1028,7 +1002,7 @@ def test_sanity_gradient_accumulation_fusion( pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -1038,7 +1012,7 @@ def test_sanity_gradient_accumulation_fusion( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -1074,7 +1048,7 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm pytest.skip(reason_for_no_mxfp8) if fp8_recipe.float8_block_scaling(): pytest.skip("cuda graph not supported for float8_block_scaling recipe") - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -1084,7 +1058,7 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -1156,133 +1130,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype): torch.cuda.synchronize() -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.") -@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") -@pytest.mark.parametrize("model", ["large"]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_sanity_attention_extra_state(model, dtype): - config = model_configs[model] - outputs = _run_attention_extra_state(dtype, config, checkpoint=False) - outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True) - outputs_checkpoint_v1_6 = 
_run_attention_extra_state( - dtype, config, mimic_v1_6=True, checkpoint=True - ) - - # Check that results match - tols = dtype_tols(dtype) - if dtype in (torch.float16, torch.bfloat16): - tols.update(dict(rtol=2e-2, atol=2e-3)) - for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)): - torch.testing.assert_close( - test, - ref, - **tols, - ) - for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)): - torch.testing.assert_close( - test, - ref, - **tols, - ) - - -def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): - steps = 10 - path = "checkpoint.pt" - fp8_enabled = True - fp8_recipe = recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.HYBRID, - amax_history_len=1, - amax_compute_algo="most_recent", - fp8_dpa=fp8_enabled, - fp8_mha=False, - ) - - reset_rng_states() - hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), - dtype=dtype, - device="cuda", - requires_grad=True, - ) - - def get_model(dtype, config): - sigma = 0.023 - init_method = init_method_normal(sigma) - output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - - with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe): - block = TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.0, - attention_dropout=0.0, - fuse_qkv_params=True, - params_dtype=dtype, - device="cuda", - ) - return block - - block = get_model(dtype, config) - for i in range(steps // 2): - with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): - output = block(hidden_states, None) - loss = output.sum() - loss.backward() - - if checkpoint: - sd = block.state_dict() - if mimic_v1_6: - sd["self_attention.core_attention.fused_attention._extra_state"] = sd[ - "self_attention.core_attention._extra_state" - ] - del sd["self_attention.core_attention._extra_state"] - torch.save(sd, path) - - param_grads = [] - for p in block.parameters(): - if p.requires_grad: - param_grads.append(p.grad.clone()) - - _cpu_rng_state_new = torch.get_rng_state() - _cuda_rng_state_new = torch.cuda.get_rng_state() - - del block - block = get_model(dtype, config) - block.load_state_dict(torch.load(path, weights_only=False)) - torch.set_rng_state(_cpu_rng_state_new) - torch.cuda.set_rng_state(_cuda_rng_state_new) - - for p in block.parameters(): - if p.requires_grad: - p.grad = param_grads.pop(0) - - assert not param_grads, "Oops!" 
- - for i in range((steps + 1) // 2): - with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): - output = block(hidden_states, None) - loss = output.sum() - loss.backward() - - torch.cuda.synchronize() - - if os.path.exists(path): - os.remove(path) - - outputs = [output, hidden_states.grad] - for p in block.parameters(): - if p.requires_grad: - outputs.append(p.grad) - - return outputs - - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_replace_raw_data_for_float8tensor(): """Test the functionality of replace_raw_data""" diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 61ccfc6f29..52b3f6f3ab 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -4,12 +4,29 @@ from __future__ import annotations +import logging +import os +from contextlib import contextmanager + +import pytest import torch import transformer_engine import transformer_engine.common.recipe import transformer_engine.pytorch as te import transformer_engine_torch as tex +from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from transformer_engine.pytorch.attention.dot_product_attention.utils import ( + get_attention_backend, + AttentionParams, + AttentionLogging, +) +from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend + +# Initialize RNG state +seed = 1234 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: @@ -106,3 +123,184 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]: if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling() raise ValueError(f"Unsupported quantization scheme ({name})") + + +# Cached RNG state +_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + +def reset_rng_states() -> None: + """Revert to deterministic RNG state""" + global _rng_states + if _rng_states is None: + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + _rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state()) + else: + cpu_rng_state, cuda_rng_state = _rng_states + torch.set_rng_state(cpu_rng_state) + torch.cuda.set_rng_state(cuda_rng_state) + + +@pytest.fixture(autouse=True) +def reset_global_fp8_state(): + yield + fp8.FP8GlobalStateManager.reset() + + +class ModelConfig: + def __init__( + self, + batch_size: int, + num_heads: int, + num_gqa_groups: int, + head_dim_qk: int, + max_seqlen_q: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + attn_mask_type: str = "no_mask", + attn_bias_type: str = "no_bias", + head_dim_v: int = None, + alibi_type: str = "none", + num_layers: int = 1, + bias_shape: str = "1hss", + window_size: Tuple[int, int] = (-1, -1), + total_requests: int = None, + max_ctx_len: int = None, + eps: float = 1e-5, + ): + self.batch_size = batch_size + self.num_heads = num_heads + self.num_gqa_groups = num_gqa_groups + self.head_dim_qk = head_dim_qk + self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v + if self.head_dim_qk == self.head_dim_v: + self.kv_channels = self.head_dim_qk + else: + self.kv_channels = (self.head_dim_qk, self.head_dim_v) + self.hidden_size = num_heads * head_dim_qk + self.hidden_size_kv = num_gqa_groups * self.head_dim_v + self.max_seqlen_q = max_seqlen_q + self.max_seqlen_kv = max_seqlen_kv + self.dropout_p = dropout_p + self.attn_mask_type = attn_mask_type + self.attn_bias_type = attn_bias_type + self.alibi_type = alibi_type + self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross" + self.num_layers = 
num_layers + self.bias_shape = bias_shape + self.window_size = window_size + self.total_requests = total_requests + self.max_ctx_len = max_ctx_len + self.eps = eps + + +@contextmanager +def logging_context(highest_level=logging.WARNING): + previous_level = logging.root.manager.disable + logging.disable(highest_level) + try: + yield + finally: + logging.disable(previous_level) + + +def get_available_attention_backends( + config: ModelConfig, + qkv_dtype: torch.dtype, + qkv_layout: str, + window_size: Tuple[int, int] = (-1, -1), + pad_between_seqs: bool = False, + context_parallel: bool = False, + deterministic: bool = False, + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + is_training: bool = True, + inference_params: Optional[InferenceParams] = None, +) -> Tuple[List, List]: + """Check for all available attention backends that support a model configuration""" + + os.environ["NVTE_FLASH_ATTN"] = "1" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "1" + _attention_backends["backend_selection_requires_update"] = True + + alibi_slopes_shape = None + if config.attn_bias_type == "alibi" and config.alibi_type == "custom": + if config.bias_shape == "1hss": + alibi_slopes_shape = [config.num_heads] + if config.bias_shape == "bhss": + alibi_slopes_shape = [config.batch_size, config.num_heads] + + core_attention_bias_shape = ( + config.bias_shape if config.attn_bias_type == "post_scale_bias" else None + ) + core_attention_bias_requires_grad = False + # d=256 is supported by cuDNN 9.0+ for inference but not training + if ( + config.attn_bias_type == "post_scale_bias" + and config.head_dim_qk <= 128 + and config.head_dim_v <= 128 + ): + core_attention_bias_requires_grad = True + + fused_attn_backends = [] + available_backends = None + flash_attention_backend = None + fused_attention_backend = None + + def test(): + attention_params = AttentionParams( + qkv_dtype=qkv_dtype, + qkv_layout=qkv_layout, + batch_size=config.batch_size, + num_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + max_seqlen_q=config.max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, + head_dim_qk=config.head_dim_qk, + head_dim_v=config.head_dim_v, + attn_mask_type=config.attn_mask_type, + window_size=window_size, + alibi_slopes_shape=alibi_slopes_shape, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias_shape=core_attention_bias_shape, + core_attention_bias_requires_grad=core_attention_bias_requires_grad, + pad_between_seqs=pad_between_seqs, + attention_dropout=config.dropout_p, + context_parallel=context_parallel, + deterministic=deterministic, + fp8=fp8, + fp8_meta=fp8_meta, + is_training=is_training, + inference_params=inference_params, + ) + ( + use_flash_attention, + use_fused_attention, + flash_attention_backend, + fused_attention_backend, + use_unfused_attention, + available_backends, + ) = get_attention_backend(attention_params) + # Set attention.py _attention_backends var using return value + # from get_attention_backend() + _attention_backends["use_flash_attention"] = use_flash_attention + _attention_backends["use_fused_attention"] = use_fused_attention + _attention_backends["flash_attention_backend"] = flash_attention_backend + _attention_backends["fused_attention_backend"] = fused_attention_backend + _attention_backends["use_unfused_attention"] = use_unfused_attention + _attention_backends["backend_selection_requires_update"] = False + return available_backends, flash_attention_backend, fused_attention_backend + + backends = {0: 
"F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} + if AttentionLogging._is_logging_setup is False: + AttentionLogging.setup_logging() + with logging_context(highest_level=AttentionLogging._log_level): + for i in range(3): + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) + _attention_backends["backend_selection_requires_update"] = True + available_backends, flash_attention_backend, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[backends[i]]: + fused_attn_backends.append(fused_attention_backend) + return available_backends, flash_attention_backend, fused_attn_backends diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index b512133efd..11344ecc10 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -183,7 +183,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - !requires_64bit_ragged_offset) { + !requires_64bit_ragged_offset && + // 9.10.0: known bugs with SDPA FP8 + (cudnn_runtime_version != 91000)) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else { @@ -239,10 +241,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || - // 9.10: any head_dim + any arch + fprop + paged - // 9.10: any head_dim + any arch + fprop + non_paged + sq > 1 - // 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} - (!is_training && cudnn_runtime_version >= 91000 && + // 9.10.2: any head_dim + any arch + fprop + paged + // 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1 + // 9.10.2: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} + (!is_training && cudnn_runtime_version >= 91002 && (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 || (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || @@ -354,7 +356,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)))) && // check 64-bit ragged offset support - (supported_ragged_offset_size)) { + (supported_ragged_offset_size) && + // 9.10.0/9.10.1: known bugs with SDPA F16 + (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {