diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py
index e5c84a9bda..9b56170fa1 100644
--- a/transformer_engine/debug/features/log_fp8_tensor_stats.py
+++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py
@@ -15,8 +15,12 @@
 from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
 from transformer_engine.pytorch.tensor import QuantizedTensor
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
 from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
+from transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base import (
+    Float8BlockwiseQTensorBase,
+)
 from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
 
@@ -110,7 +114,14 @@ def inspect_tensor_postquantize(
 
     API call used to collect the data about the tensor after process_tensor()/quantization.
     """
-    assert type(tensor) in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase], (
+    assert type(tensor) in [
+        Float8Tensor,
+        Float8TensorBase,
+        MXFP8Tensor,
+        MXFP8TensorBase,
+        Float8BlockwiseQTensor,
+        Float8BlockwiseQTensorBase,
+    ], (
         f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be a quantized tensor when using"
         " log_fp8_tensor_stats. Use log_tensor_stats for high precision tensors."
     )