diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index 5ef5132c8b..c74fc37592 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -46,6 +46,8 @@ def check_fp8_support() -> Tuple[bool, str]:
 
 def check_mxfp8_support() -> Tuple[bool, str]:
     """Return if fp8 support is available"""
+    if get_device_compute_capability() >= (12, 0):
+        return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
     if get_device_compute_capability() >= (10, 0):  # blackwell and above
         return True, ""
     return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
@@ -64,7 +66,11 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]:
 
 def get_default_fp8_recipe() -> Recipe:
     """FP8 recipe with default args."""
-    if get_device_compute_capability() >= (10, 0):  # blackwell and above
+    if check_mxfp8_support()[0]:
         return MXFP8BlockScaling()
+    # This is a temporary restriction until MXFP8 is supported for all
+    # gemm layouts.
+    if get_device_compute_capability() >= (12, 0):
+        return Float8BlockScaling()
     return DelayedScaling()
 
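For reference, a minimal standalone sketch of the recipe selection this patch produces, with the device query replaced by an explicit `(major, minor)` argument so the branching can be checked without a GPU. The helper name `select_default_recipe` and its string return values are illustrative only; in Transformer Engine the real function is `get_default_fp8_recipe()`, which returns the recipe objects shown in the diff.

```python
from typing import Tuple

def select_default_recipe(compute_capability: Tuple[int, int]) -> str:
    """Mirror the post-patch default-recipe selection for a (major, minor) capability."""
    # check_mxfp8_support() is now False on 12.0+ and True only for 10.x/11.x devices.
    mxfp8_supported = (10, 0) <= compute_capability < (12, 0)
    if mxfp8_supported:
        return "MXFP8BlockScaling"
    # Temporary restriction: 12.0+ devices fall back to FP8 block scaling
    # until MXFP8 is supported for all GEMM layouts.
    if compute_capability >= (12, 0):
        return "Float8BlockScaling"
    return "DelayedScaling"

assert select_default_recipe((9, 0)) == "DelayedScaling"       # Hopper (sm90)
assert select_default_recipe((10, 0)) == "MXFP8BlockScaling"   # Blackwell (sm100)
assert select_default_recipe((12, 0)) == "Float8BlockScaling"  # sm120
```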