Added MXFP6 packing and fused unpack-dequantise kernel #1

Open · wants to merge 6 commits into main
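For orientation before the diffs: the PR adds a pack_uint6 helper that stores four 6-bit FP6 codes in three bytes (the tests below assert packed.numel() == 3 * orig.numel() // 4), plus fused Triton kernels (triton_f6_e2m3_to_bf16, triton_f6_e3m2_to_bf16) that unpack and dequantize in one pass. The following pure-PyTorch sketch only illustrates the 4-values-into-3-bytes idea; the helper names pack_uint6_sketch / unpack_uint6_sketch and the exact bit layout are assumptions for illustration and may not match the layout torchao's pack_uint6 actually uses.

import torch

def pack_uint6_sketch(codes: torch.Tensor) -> torch.Tensor:
    # Pack 6-bit codes (one per uint8) into 3 bytes per group of 4 codes.
    # Illustrative bit layout only; torchao's pack_uint6 may order bits differently.
    assert codes.dtype == torch.uint8 and codes.shape[-1] % 4 == 0
    g = codes.reshape(-1, 4).to(torch.int32)
    a, b, c, d = g[:, 0], g[:, 1], g[:, 2], g[:, 3]
    byte0 = (a | ((b & 0x3) << 6)) & 0xFF
    byte1 = ((b >> 2) | ((c & 0xF) << 4)) & 0xFF
    byte2 = ((c >> 4) | (d << 2)) & 0xFF
    packed = torch.stack([byte0, byte1, byte2], dim=-1).to(torch.uint8)
    return packed.reshape(*codes.shape[:-1], 3 * codes.shape[-1] // 4)

def unpack_uint6_sketch(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of the sketch above: expand 3 bytes back into four 6-bit codes.
    g = packed.reshape(-1, 3).to(torch.int32)
    b0, b1, b2 = g[:, 0], g[:, 1], g[:, 2]
    a = b0 & 0x3F
    b = ((b0 >> 6) | ((b1 & 0xF) << 2)) & 0x3F
    c = ((b1 >> 4) | ((b2 & 0x3) << 4)) & 0x3F
    d = (b2 >> 2) & 0x3F
    codes = torch.stack([a, b, c, d], dim=-1).to(torch.uint8)
    return codes.reshape(*packed.shape[:-1], 4 * packed.shape[-1] // 3)

codes = torch.randint(0, 64, (2, 8), dtype=torch.uint8)
assert torch.equal(unpack_uint6_sketch(pack_uint6_sketch(codes)), codes)
assert pack_uint6_sketch(codes).numel() == 3 * codes.numel() // 4

The round-trip and size asserts mirror the shape check used by the new tests below.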
41 changes: 41 additions & 0 deletions test/prototype/mx_formats/test_custom_cast.py
@@ -26,7 +26,10 @@
f32_to_f6_e3m2_unpacked,
get_bits,
pack_uint4,
pack_uint6,
triton_f4_to_bf16,
triton_f6_e2m3_to_bf16,
triton_f6_e3m2_to_bf16,
unpack_uint4,
)
from torchao.prototype.mx_formats.fp_format_spec import (
@@ -411,3 +414,41 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):

f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(-f32_val, device=device))
assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp6_e2m3_pack_unpack():
orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to(
"cuda"
)
orig_vals_f6_unpacked = f32_to_f6_e2m3_unpacked(orig_vals)
orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4)
orig_vals_f6_packed_unpacked = triton_f6_e2m3_to_bf16(orig_vals_f6_packed).to(
torch.float32
)
assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp6_e3m2_pack_unpack():
orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to(
"cuda"
)
orig_vals_f6_unpacked = f32_to_f6_e3m2_unpacked(orig_vals)
orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4)
orig_vals_f6_packed_unpacked = triton_f6_e3m2_to_bf16(orig_vals_f6_packed).to(
torch.float32
)
assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)
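The fused triton_f6_e2m3_to_bf16 / triton_f6_e3m2_to_bf16 kernels combine the unpacking with the FP6-to-float decode in a single pass. As a readability aid, here is a hedged pure-PyTorch sketch of what the decode step computes for the E2M3 variant, assuming the standard OCP-style encoding (1 sign, 2 exponent, 3 mantissa bits, exponent bias 1, exponent 0 treated as subnormal); the function name is hypothetical and this is a reference sketch, not the kernel's actual implementation.

import torch

def f6_e2m3_to_f32_sketch(codes: torch.Tensor) -> torch.Tensor:
    # Decode unpacked FP6 E2M3 codes (low 6 bits of each uint8) to float32.
    # Assumes 1 sign / 2 exponent / 3 mantissa bits with exponent bias 1;
    # exponent 0 is subnormal. Reference sketch, not the Triton kernel.
    c = codes.to(torch.int32)
    sign_bit = (c >> 5) & 0x1
    exp = (c >> 3) & 0x3
    mant = (c & 0x7).to(torch.float32)
    sign = 1.0 - 2.0 * sign_bit.to(torch.float32)
    subnormal = mant / 8.0                                   # exp == 0
    normal = (1.0 + mant / 8.0) * torch.exp2((exp - 1).to(torch.float32))
    return sign * torch.where(exp == 0, subnormal, normal)

# 0b011111 -> 7.5 (largest normal), 0b000001 -> 0.125 (smallest nonzero subnormal),
# matching the extreme values round-tripped in test_fp6_e2m3_pack_unpack above.
print(f6_e2m3_to_f32_sketch(torch.tensor([0b011111, 0b000001], dtype=torch.uint8)))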
26 changes: 13 additions & 13 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -53,13 +53,13 @@ def test_linear_eager(elem_dtype, bias, input_shape):
"""
# elem_dtype is a tuple of (input, weight, gradient) dtypes.
grad_shape = list(input_shape)
grad_shape[-1] = 6
grad_shape[-1] = 8

m = nn.Sequential(
nn.Linear(8, 6, bias=bias, device="cuda"),
nn.Linear(8, 8, bias=bias, device="cuda"),
)
m_mx = copy.deepcopy(m)
block_size = 2
block_size = 4
swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size)

x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
@@ -90,14 +90,14 @@ def test_linear_eager(elem_dtype, bias, input_shape):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_activation_checkpointing():
input_shape = (2, 4)
grad_shape = (2, 6)
grad_shape = (2, 8)
elem_dtype = torch.float8_e4m3fn

m = nn.Sequential(
nn.Linear(4, 6, bias=True, device="cuda"),
nn.Linear(6, 6, bias=True, device="cuda"),
nn.Linear(4, 8, bias=True, device="cuda"),
nn.Linear(8, 8, bias=True, device="cuda"),
)
block_size = 2
block_size = 4
swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size)

x = torch.randn(*input_shape, device="cuda").requires_grad_()
@@ -127,13 +127,13 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
M, K, N = 4, 8, 6
M, K, N = 4, 8, 8
input_shape = (M, K)
grad_shape = (M, N)
m_mx = nn.Sequential(
nn.Linear(K, N, bias=bias, device="cuda"),
)
block_size = 2
block_size = 4
swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size)
m_mx_c = copy.deepcopy(m_mx)
m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
@@ -178,10 +178,10 @@ def test_inference_linear(elem_dtype, bias, input_shape):
"""
Smoke test for inference linear module with mx weight
"""
m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16))
m = nn.Sequential(nn.Linear(4, 8, bias=bias, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
block_size = 2
block_size = 4
swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)

x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
@@ -206,10 +206,10 @@ def test_inference_compile_simple(elem_dtype):
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16))
m = nn.Sequential(nn.Linear(4, 8, bias=False, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
block_size = 2
block_size = 4
swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
m_mx = torch.compile(m_mx, fullgraph="true")

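The dimension changes in this file (feature sizes 6 to 8, block_size 2 to 4) follow from the packing granularity: with pack_fp6 enabled, each group of four FP6 values shares three bytes, so FP6 block sizes need to be multiples of 4 (test_block_sizes below skips other configurations). A purely illustrative check of that arithmetic, with a hypothetical helper name:

def fp6_packed_numel(block_numel: int) -> int:
    # Three bytes of packed storage per four 6-bit values; illustrative only.
    if block_numel % 4 != 0:
        raise ValueError("FP6 packing expects a multiple of 4 elements per block")
    return 3 * block_numel // 4

print(fp6_packed_numel(4))   # 3
print(fp6_packed_numel(32))  # 24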
80 changes: 56 additions & 24 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -14,7 +14,7 @@
DTYPE_FP6_E3M2,
SUPPORTED_ELEM_DTYPES,
)
from torchao.prototype.mx_formats.custom_cast import pack_uint4
from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
from torchao.prototype.mx_formats.mx_tensor import (
E8M0_EXPONENT_NAN_VAL,
MXTensor,
@@ -70,15 +70,15 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_hello_world(elem_dtype):
data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
block_size = 2
block_size = 4
_test_mx(data, elem_dtype, block_size)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_all_zeros(elem_dtype):
data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16)
block_size = 2
block_size = 4
_test_mx(data, elem_dtype, block_size)


@@ -88,7 +88,7 @@ def test_some_zeros(elem_dtype):
data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
data[0, :] = 0.0
data[:, 2] = 0.0
block_size = 2
block_size = 4
_test_mx(data, elem_dtype, block_size)


@@ -100,9 +100,9 @@ def test_exponent_nan_in(elem_dtype):
value is set to NaN
"""
tensor_hp = torch.tensor(
[float("nan"), 1, 2, 3, 4, 5], device="cuda", dtype=torch.bfloat16
[float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16
)
block_size = 2
block_size = 4
tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL)
assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL)
@@ -115,24 +115,36 @@ def test_exponent_nan_out(elem_dtype):
If block exponent value is NaN, the MX tensor block value is NaN
"""
scale_e8m0_bits = torch.tensor(
[E8M0_EXPONENT_NAN_VAL, 23, 42], dtype=torch.uint8, device="cuda"
[E8M0_EXPONENT_NAN_VAL, 23], dtype=torch.uint8, device="cuda"
)

block_size = 4

if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=elem_dtype, device="cuda") # noqa: E501
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda"
) # noqa: E501
elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda") # noqa: E501
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
) # noqa: E501
if config.pack_fp6:
data_bits = data_bits.reshape(-1, block_size)
data_bits = pack_uint6(data_bits)
elif elem_dtype == DTYPE_FP4:
data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda") # noqa: E501
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
) # noqa: E501
data_bits = pack_uint4(data_bits)
else:
raise AssertionError("unsupported")
block_size = 2

tensor_mx = MXTensor(
scale_e8m0_bits, data_bits, elem_dtype, block_size, torch.float
)
tensor_hp = tensor_mx.to_dtype(torch.float)
assert torch.all(torch.isnan(tensor_hp[0:1]))
assert not torch.any(torch.isnan(tensor_hp[2:]))
assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
assert not torch.any(torch.isnan(tensor_hp.flatten()[4:]))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -141,24 +153,26 @@ def test_ranks(elem_dtype):
"""
The reshaping logic works for various ranks
"""
B = 2
shapes = ((B * 4,), (B * 4, 2), (B * 4, 2, 2), (B * 4, 2, 2, 2))
B = 4
shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4))
for s in shapes:
tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16)
_test_mx(tensor_hp, elem_dtype, B)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_block_sizes(elem_dtype):
@pytest.mark.parametrize("B", [1, 4, 32])
def test_block_sizes(elem_dtype, B):
"""
Smoke test for various block sizes
"""
for B in (1, 2, 32):
if B == 1 and elem_dtype == DTYPE_FP4:
pytest.skip("unsupported configuration")
tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
_test_mx(tensor_hp, elem_dtype, B)
if B == 1 and elem_dtype == DTYPE_FP4:
pytest.skip("unsupported configuration")
elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]:
pytest.skip("unsupported configuration")
tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
_test_mx(tensor_hp, elem_dtype, B)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -202,14 +216,32 @@ def test_cast_autograd(elem_dtype):
torch.testing.assert_close(grad, x.grad, atol=0, rtol=0)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_view(elem_dtype):
x = torch.randn(1, 2, 4)
block_size = 2
x = torch.randn(1, 2, 4, device="cuda")
block_size = 4
x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
x_mx_2 = x_mx.view(2, 4) # noqa: F841


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2])
@pytest.mark.parametrize("do_fp6_packing", [False, True])
def test_fp6_packing(elem_dtype, do_fp6_packing):
config.pack_fp6 = do_fp6_packing
x = torch.randn(1, 2, 4, device="cuda")
block_size = 4
x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
if config.pack_fp6:
expected_packed_shape = torch.Size([*x.shape[:-1], 3 * x.shape[-1] // 4])
else:
expected_packed_shape = x.shape
config.pack_fp6 = True

assert x_mx._data.shape == expected_packed_shape


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
@@ -231,7 +263,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
x = torch.randn(*shape, dtype=hp_dtype, device="cuda")
else:
x = torch.zeros(*shape, dtype=hp_dtype, device="cuda")
block_size = 2
block_size = 4
to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)

x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
1 change: 1 addition & 0 deletions torchao/prototype/mx_formats/config.py
@@ -1,2 +1,3 @@
# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel = False
pack_fp6 = True
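The new pack_fp6 flag defaults to True, so MXTensor stores FP6 elements packed and the raw data's trailing dimension shrinks to three quarters of its size, as test_fp6_packing above verifies. A minimal usage sketch mirroring that test; the import paths for config and DTYPE_FP6_E2M3 are assumptions based on the surrounding torchao tree, and since the flag is module-level state it is restored afterwards.

import torch
from torchao.prototype.mx_formats import config
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.randn(1, 2, 4, device="cuda")

config.pack_fp6 = True  # default: packed FP6 storage
x_mx = MXTensor.to_mx(x, DTYPE_FP6_E2M3, 4)  # block_size = 4
assert x_mx._data.shape[-1] == 3 * x.shape[-1] // 4

config.pack_fp6 = False  # unpacked: one byte per FP6 element
x_mx = MXTensor.to_mx(x, DTYPE_FP6_E2M3, 4)
assert x_mx._data.shape == x.shape

config.pack_fp6 = True  # restore the default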