[PyTorch Debug] Support log fp8 tensor stats for blockwise recipe #1905

Open: wants to merge 2 commits into base: main
13 changes: 12 additions & 1 deletion transformer_engine/debug/features/log_fp8_tensor_stats.py
@@ -15,8 +15,12 @@
 from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
 from transformer_engine.pytorch.tensor import QuantizedTensor
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
 from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
+from transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base import (
+    Float8BlockwiseQTensorBase,
+)
 from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from transformer_engine.debug.pytorch.debug_state import TEDebugState

@@ -110,7 +114,14 @@ def inspect_tensor_postquantize(
         API call used to collect the data about the tensor after process_tensor()/quantization.
         """
 
-        assert type(tensor) in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase], (
+        assert type(tensor) in [
+            Float8Tensor,
+            Float8TensorBase,
+            MXFP8Tensor,
+            MXFP8TensorBase,
+            Float8BlockwiseQTensor,
+            Float8BlockwiseQTensorBase,
+        ], (
             f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be a quantized tensor when using"
             " log_fp8_tensor_stats. Use log_tensor_stats for high precision tensors."
         )
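
Note on the check above: `type(tensor) in [...]` is an exact-type match, so a subclass of a listed class fails the assert unless it is listed itself. This is presumably why both the tensor class and its `...Base` counterpart appear for each recipe. The sketch below illustrates that behaviour with stand-in classes; the class names, the inheritance relationship, and the `is_supported` helper are hypothetical and are not the Transformer Engine implementations.

```python
# Stand-in classes for illustration only (hypothetical); they do not
# reproduce the real Transformer Engine tensor classes.
class BlockwiseBaseStub:
    pass


class BlockwiseTensorStub(BlockwiseBaseStub):
    pass


class CustomSubclassStub(BlockwiseTensorStub):
    """A subclass that is not in the allow-list."""


ALLOWED_TYPES = [BlockwiseTensorStub, BlockwiseBaseStub]


def is_supported(tensor) -> bool:
    # Same shape as the assert in the diff: exact-type membership,
    # which ignores inheritance.
    return type(tensor) in ALLOWED_TYPES


print(is_supported(BlockwiseTensorStub()))   # listed type
print(is_supported(BlockwiseBaseStub()))     # listed type
print(is_supported(CustomSubclassStub()))    # subclass, not listed
print(isinstance(CustomSubclassStub(), tuple(ALLOWED_TYPES)))  # isinstance follows inheritance
```

An `isinstance` check would accept subclasses as well; whether that is desirable here depends on how the debug feature treats wrapped or derived tensor types, which this diff does not show.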