From ee01cafc19d4b5f07dc5d2a1e4ea723bb4b60ac2 Mon Sep 17 00:00:00 2001
From: Wei Heng
Date: Fri, 27 Jun 2025 16:40:53 -0700
Subject: [PATCH 1/2] [PyTorch Debug] Support log fp8 tensor stats for blockwise

Support log fp8 tensor stats for fp8 blockwise recipe

Signed-off-by: Wei Heng
---
 transformer_engine/debug/features/log_fp8_tensor_stats.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py
index e5c84a9bda..a2078b5465 100644
--- a/transformer_engine/debug/features/log_fp8_tensor_stats.py
+++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py
@@ -15,8 +15,10 @@
 from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
 from transformer_engine.pytorch.tensor import QuantizedTensor
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
 from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
+from transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
 
@@ -110,7 +112,7 @@ def inspect_tensor_postquantize(
 
         API call used to collect the data about the tensor after process_tensor()/quantization.
         """
-        assert type(tensor) in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase], (
+        assert type(tensor) in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase, Float8BlockwiseQTensor, Float8BlockwiseQTensorBase], (
             f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be a quantized tensor when using"
             " log_fp8_tensor_stats. Use log_tensor_stats for high precision tensors."
         )

From fea049f8d0c9db86afe7904eeea1a959098b6fa0 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 27 Jun 2025 23:49:04 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../debug/features/log_fp8_tensor_stats.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py
index a2078b5465..9b56170fa1 100644
--- a/transformer_engine/debug/features/log_fp8_tensor_stats.py
+++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py
@@ -18,7 +18,9 @@
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
 from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
-from transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
+from transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base import (
+    Float8BlockwiseQTensorBase,
+)
 from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
 
@@ -112,7 +114,14 @@ def inspect_tensor_postquantize(
 
         API call used to collect the data about the tensor after process_tensor()/quantization.
         """
-        assert type(tensor) in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase, Float8BlockwiseQTensor, Float8BlockwiseQTensorBase], (
+        assert type(tensor) in [
+            Float8Tensor,
+            Float8TensorBase,
+            MXFP8Tensor,
+            MXFP8TensorBase,
+            Float8BlockwiseQTensor,
+            Float8BlockwiseQTensorBase,
+        ], (
             f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be a quantized tensor when using"
             " log_fp8_tensor_stats. Use log_tensor_stats for high precision tensors."
         )