From 4a3c7b3b382ed37869fe02c93a6b6fd2401c3aa3 Mon Sep 17 00:00:00 2001
From: Sudhakar Singh
Date: Wed, 9 Jul 2025 09:05:56 -0700
Subject: [PATCH 1/3] mxfp8 is not supported on 120+ arch yet

Signed-off-by: Sudhakar Singh
---
 transformer_engine/pytorch/fp8.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index 5ef5132c8b..8da7f0d76e 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -46,6 +46,8 @@ def check_fp8_support() -> Tuple[bool, str]:
 
 def check_mxfp8_support() -> Tuple[bool, str]:
     """Return if fp8 support is available"""
+    if get_device_compute_capability() >= (12, 0):
+        return False, "MXFP8 is not supported on 12.0+ architectures yet."
     if get_device_compute_capability() >= (10, 0):  # blackwell and above
         return True, ""
     return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
@@ -64,7 +66,7 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]:
 
 def get_default_fp8_recipe() -> Recipe:
     """FP8 recipe with default args."""
-    if get_device_compute_capability() >= (10, 0):  # blackwell and above
+    if check_mxfp8_support()[0]:
         return MXFP8BlockScaling()
     return DelayedScaling()
 

From 3d7cd923ae6064fd106c3100b633fd363d805393 Mon Sep 17 00:00:00 2001
From: Sudhakar Singh
Date: Wed, 9 Jul 2025 13:39:53 -0700
Subject: [PATCH 2/3] change the default recipe for arch 120

Signed-off-by: Sudhakar Singh
---
 transformer_engine/pytorch/fp8.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index 8da7f0d76e..7432956e43 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -47,7 +47,8 @@ def check_fp8_support() -> Tuple[bool, str]:
 def check_mxfp8_support() -> Tuple[bool, str]:
     """Return if fp8 support is available"""
     if get_device_compute_capability() >= (12, 0):
-        return False, "MXFP8 is not supported on 12.0+ architectures yet."
+        return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ " \
+            "architectures yet."
     if get_device_compute_capability() >= (10, 0):  # blackwell and above
         return True, ""
     return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
@@ -67,6 +68,10 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]:
 def get_default_fp8_recipe() -> Recipe:
     """FP8 recipe with default args."""
     if check_mxfp8_support()[0]:
+        # This is a temporary restriction until MXFP8 is supported for all
+        # gemm layouts.
+        if get_device_compute_capability() >= (12, 0):
+            return Float8BlockScaling()
         return MXFP8BlockScaling()
     return DelayedScaling()
 

From 308694498b7d5d37ae188c3916fca7fc8aa52340 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 9 Jul 2025 20:40:52 +0000
Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/pytorch/fp8.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index 7432956e43..c74fc37592 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -47,8 +47,7 @@ def check_fp8_support() -> Tuple[bool, str]:
 def check_mxfp8_support() -> Tuple[bool, str]:
     """Return if fp8 support is available"""
     if get_device_compute_capability() >= (12, 0):
-        return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ " \
-            "architectures yet."
+        return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
     if get_device_compute_capability() >= (10, 0):  # blackwell and above
         return True, ""
     return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
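
For reference, here is a minimal standalone sketch of the recipe-selection flow
these patches leave in place. It is an illustrative stand-in, not the
transformer_engine module itself: the recipe classes are empty placeholders and
select_default_recipe is a hypothetical helper that mirrors the patched
check_mxfp8_support() / get_default_fp8_recipe() logic, parameterized by
compute capability so it can run without a GPU.

    from typing import Tuple


    # Placeholder stand-ins for the recipe classes referenced in fp8.py.
    class DelayedScaling: ...
    class MXFP8BlockScaling: ...
    class Float8BlockScaling: ...


    def check_mxfp8_support(cc: Tuple[int, int]) -> Tuple[bool, str]:
        # Mirrors the patched check: sm_120+ is rejected for now, sm_100+
        # (Blackwell) is accepted, older architectures are rejected.
        if cc >= (12, 0):
            return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
        if cc >= (10, 0):
            return True, ""
        return False, "Device compute capability 10.0 or higher required for MXFP8 execution."


    def select_default_recipe(cc: Tuple[int, int]):
        # Mirrors the patched get_default_fp8_recipe(): the compute-capability
        # 12.0 branch returning Float8BlockScaling sits inside the
        # MXFP8-support guard.
        if check_mxfp8_support(cc)[0]:
            if cc >= (12, 0):
                return Float8BlockScaling()
            return MXFP8BlockScaling()
        return DelayedScaling()


    if __name__ == "__main__":
        for cc in [(9, 0), (10, 0), (12, 0)]:
            print(cc, type(select_default_recipe(cc)).__name__)
        # Prints:
        # (9, 0) DelayedScaling
        # (10, 0) MXFP8BlockScaling
        # (12, 0) DelayedScaling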