diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py index d27e1831c9..03dc2523de 100644 --- a/test/prototype/mx_formats/test_custom_cast.py +++ b/test/prototype/mx_formats/test_custom_cast.py @@ -26,7 +26,10 @@ f32_to_f6_e3m2_unpacked, get_bits, pack_uint4, + pack_uint6, triton_f4_to_bf16, + triton_f6_e2m3_to_bf16, + triton_f6_e3m2_to_bf16, unpack_uint4, ) from torchao.prototype.mx_formats.fp_format_spec import ( @@ -411,3 +414,41 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device): f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(-f32_val, device=device)) assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") +@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4") +@pytest.mark.skipif( + is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0" +) +def test_fp6_e2m3_pack_unpack(): + orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to( + "cuda" + ) + orig_vals_f6_unpacked = f32_to_f6_e2m3_unpacked(orig_vals) + orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked) + assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4) + orig_vals_f6_packed_unpacked = triton_f6_e2m3_to_bf16(orig_vals_f6_packed).to( + torch.float32 + ) + assert torch.all(orig_vals_f6_packed_unpacked == orig_vals) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") +@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4") +@pytest.mark.skipif( + is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0" +) +def test_fp6_e3m2_pack_unpack(): + orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to( + "cuda" + ) + orig_vals_f6_unpacked = f32_to_f6_e3m2_unpacked(orig_vals) + orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked) + assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4) + orig_vals_f6_packed_unpacked = triton_f6_e3m2_to_bf16(orig_vals_f6_packed).to( + torch.float32 + ) + assert torch.all(orig_vals_f6_packed_unpacked == orig_vals) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 17a76a750d..a5e9218af6 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -53,13 +53,13 @@ def test_linear_eager(elem_dtype, bias, input_shape): """ # elem_dtype is a tuple of (input, weight, gradient) dtypes. 
grad_shape = list(input_shape) - grad_shape[-1] = 6 + grad_shape[-1] = 8 m = nn.Sequential( - nn.Linear(8, 6, bias=bias, device="cuda"), + nn.Linear(8, 8, bias=bias, device="cuda"), ) m_mx = copy.deepcopy(m) - block_size = 2 + block_size = 4 swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size) x_ref = torch.randn(*input_shape, device="cuda").requires_grad_() @@ -90,14 +90,14 @@ def test_linear_eager(elem_dtype, bias, input_shape): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_activation_checkpointing(): input_shape = (2, 4) - grad_shape = (2, 6) + grad_shape = (2, 8) elem_dtype = torch.float8_e4m3fn m = nn.Sequential( - nn.Linear(4, 6, bias=True, device="cuda"), - nn.Linear(6, 6, bias=True, device="cuda"), + nn.Linear(4, 8, bias=True, device="cuda"), + nn.Linear(8, 8, bias=True, device="cuda"), ) - block_size = 2 + block_size = 4 swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size) x = torch.randn(*input_shape, device="cuda").requires_grad_() @@ -127,13 +127,13 @@ def test_linear_compile(elem_dtype, bias, use_autocast): if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): if not is_sm_at_least_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") - M, K, N = 4, 8, 6 + M, K, N = 4, 8, 8 input_shape = (M, K) grad_shape = (M, N) m_mx = nn.Sequential( nn.Linear(K, N, bias=bias, device="cuda"), ) - block_size = 2 + block_size = 4 swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size) m_mx_c = copy.deepcopy(m_mx) m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor") @@ -178,10 +178,10 @@ def test_inference_linear(elem_dtype, bias, input_shape): """ Smoke test for inference linear module with mx weight """ - m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16)) + m = nn.Sequential(nn.Linear(4, 8, bias=bias, dtype=torch.bfloat16)) m = m.cuda() m_mx = copy.deepcopy(m) - block_size = 2 + block_size = 4 swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size) x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16) @@ -206,10 +206,10 @@ def test_inference_compile_simple(elem_dtype): if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): if not is_sm_at_least_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") - m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16)) + m = nn.Sequential(nn.Linear(4, 8, bias=False, dtype=torch.bfloat16)) m = m.cuda() m_mx = copy.deepcopy(m) - block_size = 2 + block_size = 4 swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size) m_mx = torch.compile(m_mx, fullgraph="true") diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 21cb49c064..d258888af6 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -14,7 +14,7 @@ DTYPE_FP6_E3M2, SUPPORTED_ELEM_DTYPES, ) -from torchao.prototype.mx_formats.custom_cast import pack_uint4 +from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6 from torchao.prototype.mx_formats.mx_tensor import ( E8M0_EXPONENT_NAN_VAL, MXTensor, @@ -70,7 +70,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold): @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_hello_world(elem_dtype): data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16) - block_size = 2 + block_size = 4 _test_mx(data, elem_dtype, block_size) @@ -78,7 +78,7 @@ def test_hello_world(elem_dtype): 
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_all_zeros(elem_dtype): data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16) - block_size = 2 + block_size = 4 _test_mx(data, elem_dtype, block_size) @@ -88,7 +88,7 @@ def test_some_zeros(elem_dtype): data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16) data[0, :] = 0.0 data[:, 2] = 0.0 - block_size = 2 + block_size = 4 _test_mx(data, elem_dtype, block_size) @@ -100,9 +100,9 @@ def test_exponent_nan_in(elem_dtype): value is set to is NaN """ tensor_hp = torch.tensor( - [float("nan"), 1, 2, 3, 4, 5], device="cuda", dtype=torch.bfloat16 + [float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16 ) - block_size = 2 + block_size = 4 tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size) assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL) assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL) @@ -115,24 +115,36 @@ def test_exponent_nan_out(elem_dtype): If block exponent value is NaN, the MX tensor block value is NaN """ scale_e8m0_bits = torch.tensor( - [E8M0_EXPONENT_NAN_VAL, 23, 42], dtype=torch.uint8, device="cuda" + [E8M0_EXPONENT_NAN_VAL, 23], dtype=torch.uint8, device="cuda" ) + + block_size = 4 + if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=elem_dtype, device="cuda") # noqa: E501 + data_bits = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda" + ) # noqa: E501 elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2): - data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda") # noqa: E501 + data_bits = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda" + ) # noqa: E501 + if config.pack_fp6: + data_bits = data_bits.reshape(-1, block_size) + data_bits = pack_uint6(data_bits) elif elem_dtype == DTYPE_FP4: - data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda") # noqa: E501 + data_bits = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda" + ) # noqa: E501 data_bits = pack_uint4(data_bits) else: raise AssertionError("unsupported") - block_size = 2 + tensor_mx = MXTensor( scale_e8m0_bits, data_bits, elem_dtype, block_size, torch.float ) tensor_hp = tensor_mx.to_dtype(torch.float) - assert torch.all(torch.isnan(tensor_hp[0:1])) - assert not torch.any(torch.isnan(tensor_hp[2:])) + assert torch.all(torch.isnan(tensor_hp.flatten()[0:4])) + assert not torch.any(torch.isnan(tensor_hp.flatten()[4:])) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -141,8 +153,8 @@ def test_ranks(elem_dtype): """ The reshaping logic works for various ranks """ - B = 2 - shapes = ((B * 4,), (B * 4, 2), (B * 4, 2, 2), (B * 4, 2, 2, 2)) + B = 4 + shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4)) for s in shapes: tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16) _test_mx(tensor_hp, elem_dtype, B) @@ -150,15 +162,17 @@ def test_ranks(elem_dtype): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) -def test_block_sizes(elem_dtype): +@pytest.mark.parametrize("B", [1, 4, 32]) +def test_block_sizes(elem_dtype, B): """ Smoke test for various block sizes """ - for B in (1, 2, 32): - if B == 1 and elem_dtype == DTYPE_FP4: - pytest.skip("unsupported configuration") - tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16) - _test_mx(tensor_hp, 
elem_dtype, B) + if B == 1 and elem_dtype == DTYPE_FP4: + pytest.skip("unsupported configuration") + elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]: + pytest.skip("unsupported configuration") + tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16) + _test_mx(tensor_hp, elem_dtype, B) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -202,14 +216,32 @@ def test_cast_autograd(elem_dtype): torch.testing.assert_close(grad, x.grad, atol=0, rtol=0) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_view(elem_dtype): - x = torch.randn(1, 2, 4) - block_size = 2 + x = torch.randn(1, 2, 4, device="cuda") + block_size = 4 x_mx = MXTensor.to_mx(x, elem_dtype, block_size) x_mx_2 = x_mx.view(2, 4) # noqa: F841 +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]) +@pytest.mark.parametrize("do_fp6_packing", [False, True]) +def test_fp6_packing(elem_dtype, do_fp6_packing): + config.pack_fp6 = do_fp6_packing + x = torch.randn(1, 2, 4, device="cuda") + block_size = 4 + x_mx = MXTensor.to_mx(x, elem_dtype, block_size) + if config.pack_fp6: + expected_packed_shape = torch.Size([*x.shape[:-1], 3 * x.shape[-1] // 4]) + else: + expected_packed_shape = x.shape + config.pack_fp6 = True + + assert x_mx._data.shape == expected_packed_shape + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0" @@ -231,7 +263,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): x = torch.randn(*shape, dtype=hp_dtype, device="cuda") else: x = torch.zeros(*shape, dtype=hp_dtype, device="cuda") - block_size = 2 + block_size = 4 to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True) x_mx = MXTensor.to_mx(x, elem_dtype, block_size) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 3e7e03d8f6..732f95d7aa 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -1,2 +1,3 @@ # If True, uses a custom triton kernel for fp4 dequantize use_fp4_custom_triton_dequant_kernel = False +pack_fp6 = True diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index cda946e285..cb915e4bd2 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -24,6 +24,8 @@ E8M0_EXPONENT_BIAS, E8M0_EXPONENT_NAN_VAL, F4_E2M1_EXP_BIAS, + F6_E2M3_EXP_BIAS, + F6_E3M2_EXP_BIAS, F32_EXP_BIAS, ) @@ -44,6 +46,12 @@ def get_bits(x: torch.Tensor) -> str: SIGN_MASK_F4 = 0x8 # 1000 MANTISSA_MASK_F4 = 0x1 # 0001 +SIGN_MASK_F6_E2M3 = 0x20 # 100000 +MANTISSA_MASK_F6_E2M3 = 0x7 # 000111 + +SIGN_MASK_F6_E3M2 = 0x20 # 100000 +MANTISSA_MASK_F6_E3M2 = 0x3 # 000011 + ZERO_BITS_F32 = 0x0 ZERO_POINT_FIVE_BITS_F32 = 0x3F000000 @@ -313,6 +321,300 @@ def triton_f4_to_scaled_bf16_kernel( tl.store(output_ptr + offsets_out, output, mask=mask_out) + @triton.jit + def _fp6_packed_to_bf16( + packed_4bits_a, + packed_4bits_b, + packed_2bits, + sign_mask_f6, + mbits_f6, + f6_exp_bias, + mbits_f32, + f32_exp_bias, + ): + """ + Input: a tensor of packed fp6 values + Output: a tensor of bfloat16 values + """ + + # L/R shift and combine back into uint8 with first 2 bits empty (i.e. 
unpacked) + x_0 = ((packed_4bits_a >> 2) & 0x3C) | ((packed_2bits & 0xC0) >> 6) + x_1 = ((packed_4bits_a << 2) & 0x3C) | ((packed_2bits & 0x30) >> 4) + x_2 = ((packed_4bits_b >> 2) & 0x3C) | ((packed_2bits & 0xC) >> 2) + x_3 = ((packed_4bits_b << 2) & 0x3C) | (packed_2bits & 0x3) + + # repeat_interleave not supported yet, see https://github.com/triton-lang/triton/issues/1426 + # instead we can interleave(interleave(4*i, 4*i+2), interleave(4*i+1, 4*i+3)) + # TODO: is there a more performant way? + # We could stack all 4, then transpose and ravel and do it that way? + x_02 = tl.interleave(x_0, x_2) # [x_0_0, x_2_0, x_0_1, x_2_1, ...] + x_13 = tl.interleave(x_1, x_3) # [x_1_0, x_3_0, x_1_1, x_3_1, ...] + x = tl.interleave(x_02, x_13) # [x_0_0, x_1_0, x_2_0, x_3_0, x_0_1, ...] + + # save the sign + sign_f6 = x & sign_mask_f6 + + # set everything to positive, will add sign back at the end + x_pos = x ^ sign_f6 + + # shift the exponent and mantissa + result = x_pos.to(tl.int32) << (mbits_f32 - mbits_f6) + + # add sign back + # left shift is always 26 regardless of fp6 variant + sign_f32 = sign_f6.to(tl.int32) << 26 + result = result | sign_f32 + + # The bit shifting above is for float32, so for now we + # bitcast to float32 and then regular cast to bfloat16 + # TODO(later): it should be pretty easy to cast directly to bf16, just + # need to adjust the mbits/ebits/special values. Perf impact is likely + # to be small as we would not be changing memory access patterns. + output = result.to(tl.float32, bitcast=True) + + # Scale the fp32 exponent afterwards, handles the denorms correctly + output *= 2.0 ** (f32_exp_bias - f6_exp_bias) + + output = output.to(tl.bfloat16) + return output + + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_IN": 2}, num_warps=1), + triton.Config({"BLOCK_SIZE_IN": 4}, num_warps=1), + triton.Config({"BLOCK_SIZE_IN": 8}, num_warps=1), + triton.Config({"BLOCK_SIZE_IN": 16}, num_warps=1), + ], + key=["n_mx_blocks"], + ) + @triton.jit + def triton_f6_to_bf16_kernel( + x_ptr, + output_ptr, + n_mx_blocks, + mx_block_size: tl.constexpr, + packed_mx_block_size: tl.constexpr, + sign_mask_f6: tl.constexpr, + mbits_f6: tl.constexpr, + f6_exp_bias: tl.constexpr, + mbits_f32: tl.constexpr, + f32_exp_bias: tl.constexpr, + BLOCK_SIZE_IN: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE_IN + + offsets_rows = block_start + tl.arange(0, BLOCK_SIZE_IN) + offsets_cols = tl.arange(0, packed_mx_block_size // 3) + mask_in = (offsets_rows[:, None] < n_mx_blocks) & ( + offsets_cols[None, :] < packed_mx_block_size // 3 + ) + offsets_in = ( + offsets_rows[:, None] * packed_mx_block_size + offsets_cols[None, :] + ) + + # packed 4 x fp6 into 3 x uint8 + packed_4bits_a = tl.load(x_ptr + offsets_in, mask=mask_in, other=0) + packed_4bits_b = tl.load( + x_ptr + offsets_in + (packed_mx_block_size // 3), mask=mask_in, other=0 + ) + packed_2bits = tl.load( + x_ptr + offsets_in + (2 * packed_mx_block_size // 3), mask=mask_in, other=0 + ) + + output = _fp6_packed_to_bf16( + packed_4bits_a, + packed_4bits_b, + packed_2bits, + sign_mask_f6, + mbits_f6, + f6_exp_bias, + mbits_f32, + f32_exp_bias, + ) + + # set up output offsets + offsets_rows_out = block_start + tl.arange(0, BLOCK_SIZE_IN) + offsets_cols_out = tl.arange(0, mx_block_size) + offsets_out = ( + offsets_rows_out[:, None] * mx_block_size + offsets_cols_out[None, :] + ) + mask_out = (offsets_rows_out[:, None] < n_mx_blocks) & ( + offsets_cols_out[None, :] < mx_block_size + ) + + tl.store(output_ptr + 
offsets_out, output, mask=mask_out) + + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_IN": 2}, num_warps=1), + triton.Config({"BLOCK_SIZE_IN": 4}, num_warps=1), + triton.Config({"BLOCK_SIZE_IN": 8}, num_warps=1), + triton.Config({"BLOCK_SIZE_IN": 16}, num_warps=1), + ], + key=["n_mx_blocks"], + ) + @triton.jit + def triton_f6_to_scaled_bf16_kernel( + x_ptr, + s_ptr, + output_ptr, + n_mx_blocks, + mx_block_size: tl.constexpr, + packed_mx_block_size: tl.constexpr, + sign_mask_f6: tl.constexpr, + mbits_f6: tl.constexpr, + f6_exp_bias: tl.constexpr, + mbits_f32: tl.constexpr, + f32_exp_bias: tl.constexpr, + e8m0_exponent_bias: tl.constexpr, + e8m0_exponent_nan_val: tl.constexpr, + BLOCK_SIZE_IN: tl.constexpr, + ): + pid = tl.program_id(axis=0) + + block_start = pid * BLOCK_SIZE_IN + + offsets_rows = block_start + tl.arange(0, BLOCK_SIZE_IN) + offsets_cols = tl.arange(0, packed_mx_block_size // 3) + mask_in = (offsets_rows[:, None] < n_mx_blocks) & ( + offsets_cols[None, :] < packed_mx_block_size // 3 + ) + offsets_in = ( + offsets_rows[:, None] * packed_mx_block_size + offsets_cols[None, :] + ) + + # packed 4 x fp6 into 3 x uint8 + packed_4bits_a = tl.load(x_ptr + offsets_in, mask=mask_in, other=0) + packed_4bits_b = tl.load( + x_ptr + offsets_in + (packed_mx_block_size // 3), mask=mask_in, other=0 + ) + packed_2bits = tl.load( + x_ptr + offsets_in + (2 * packed_mx_block_size // 3), mask=mask_in, other=0 + ) + + output = _fp6_packed_to_bf16( + packed_4bits_a, + packed_4bits_b, + packed_2bits, + sign_mask_f6, + mbits_f6, + f6_exp_bias, + mbits_f32, + f32_exp_bias, + ) + + # load scale + offsets_s = block_start + tl.arange(0, BLOCK_SIZE_IN) + mask_s = offsets_s < n_mx_blocks + s = tl.load(s_ptr + offsets_s, mask=mask_s) + + # create the scale in bf16 + s_offset = s.to(tl.float32) - e8m0_exponent_bias + s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16) + s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan")) + + # multiply output by scale + # TODO(later): see if manipulating the exponent instead of fp + # multiplication is going to give a significant speedup + output = tl.reshape(output, (BLOCK_SIZE_IN, mx_block_size)) # noqa: E501 + s_fp = tl.reshape(s_fp, (BLOCK_SIZE_IN // 1, 1)) + output = output * s_fp + output = tl.reshape(output, (BLOCK_SIZE_IN, mx_block_size)) + + # set up output offsets + offsets_rows_out = block_start + tl.arange(0, BLOCK_SIZE_IN) + offsets_cols_out = tl.arange(0, mx_block_size) + offsets_out = ( + offsets_rows_out[:, None] * mx_block_size + offsets_cols_out[None, :] + ) + mask_out = (offsets_rows_out[:, None] < n_mx_blocks) & ( + offsets_cols_out[None, :] < mx_block_size + ) + + tl.store(output_ptr + offsets_out, output, mask=mask_out) + + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_IN": 2}, num_warps=1), + triton.Config({"BLOCK_SIZE_IN": 4}, num_warps=1), + triton.Config({"BLOCK_SIZE_IN": 8}, num_warps=1), + triton.Config({"BLOCK_SIZE_IN": 16}, num_warps=1), + ], + key=["n_mx_blocks"], + ) + @triton.jit + def triton_pack_uint6_kernel( + input_ptr, + output_ptr, + n_mx_blocks, + MX_BLOCK_SIZE: tl.constexpr, + PACKED_MX_BLOCK_SIZE: tl.constexpr, + BLOCK_SIZE_IN: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE_IN + + # input_ptr is shape [n_mx_blocks, MX_BLOCK_SIZE] + # Load BLOCK_SIZE rows of input_ptr + offsets_rows = block_start + tl.arange(0, BLOCK_SIZE_IN) + offsets_cols = tl.arange(0, MX_BLOCK_SIZE // 4) + offsets = offsets_rows[:, None] * MX_BLOCK_SIZE + (4 * offsets_cols[None, :]) + mask = 
(offsets_rows[:, None] < n_mx_blocks) & ( + offsets_cols[None, :] < MX_BLOCK_SIZE // 4 + ) + + # x is shape [BLOCK_SIZE, MX_BLOCK_SIZE] + x_0 = tl.load(input_ptr + offsets, mask=mask) + x_1 = tl.load(input_ptr + offsets + 1, mask=mask) + x_2 = tl.load(input_ptr + offsets + 2, mask=mask) + x_3 = tl.load(input_ptr + offsets + 3, mask=mask) + + # OR between remainder 0/1, 2/3 elements to pack 2 x first-4-bit partial representations + # next to each other. These are the middle 4 bits of the uint8, so some gymnastics required. + # i.e. (00abcd00 >> 2) | (00wxyz00 << 2) = 0000abcd | wxyz0000 = wxyzabcd + bits_packed_4_a = (x_1 >> 2) | ((x_0 << 2) & 0xF0) + bits_packed_4_b = (x_3 >> 2) | ((x_2 << 2) & 0xF0) + # Similarly pack 4 remaining 2-bit partial representations into one uint8 + # e.g. 000000ab, 0000cd00, 00ef0000, gh000000 --> abcdefgh + bits_packed_2 = ( + (x_0 << 6) | ((x_1 << 4) & 0x30) | ((x_2 << 2) & 0xC) | (x_3 & 0x3) + ) + + # Store values in a uint8 tensor of length `3 * MX_BLOCK_SIZE / 4` + offsets_out_4_a = ( + offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + offsets_cols[None, :] + ) + offsets_out_4_b = ( + offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + + offsets_cols[None, :] + + (MX_BLOCK_SIZE // 4) + ) + offsets_out_2 = ( + offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + + offsets_cols[None, :] + + (MX_BLOCK_SIZE // 2) + ) + + # Store into output tensor + tl.store( + output_ptr + offsets_out_4_a, + bits_packed_4_a, + mask=mask, + ) + + tl.store( + output_ptr + offsets_out_4_b, + bits_packed_4_b, + mask=mask, + ) + + tl.store( + output_ptr + offsets_out_2, + bits_packed_2, + mask=mask, + ) + else: def triton_f4_to_bf16_kernel( @@ -355,6 +657,46 @@ def triton_f4_to_scaled_bf16_kernel( ): raise AssertionError("unsupported without triton") + def triton_f6_to_bf16_kernel( + x_ptr, + output_ptr, + n_elements_in, + sign_mask_f6, + mbits_f6, + f6_exp_bias, + mbits_f32, + f32_exp_bias, + BLOCK_SIZE_IN, + ): + raise AssertionError("unsupported without triton") + + def triton_f6_to_scaled_bf16_kernel( + x_ptr, + s_ptr, + output_ptr, + n_elements_in, + mx_block_size, + sign_mask_f6, + mbits_f6, + f6_exp_bias, + mbits_f32, + f32_exp_bias, + e8m0_exponent_bias, + e8m0_exponent_nan_val, + BLOCK_SIZE_IN, + ): + raise AssertionError("unsupported without triton") + + def triton_pack_uint6_kernel( + input_ptr, + output_ptr, + n_mx_blocks, + MX_BLOCK_SIZE: tl.constexpr, + PACKED_MX_BLOCK_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + ): + raise AssertionError("unsupported without triton") + def triton_f4_to_bf16(x: torch.Tensor): """ @@ -432,6 +774,178 @@ def triton_f4_to_scaled_bf16( return output +def triton_f6_e2m3_to_bf16(x: torch.Tensor) -> torch.Tensor: + """ + Input: a tensor of packed fp6 values + Output: a tensor of bfloat16 values + + Note: this function is only used in testing, so we can test + the numerical correctness of the cast without the scaling. 
+ """ + packed_mx_block_size = x.shape[-1] + mx_block_size = 4 * packed_mx_block_size // 3 + + x = x.view(-1, packed_mx_block_size) + new_shape = (x.shape[0], mx_block_size) + + output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) + + assert x.is_contiguous() + assert x.is_cuda and output.is_cuda + + n_mx_blocks = x.shape[0] + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) + triton_f6_to_bf16_kernel[grid]( + x, + output, + n_mx_blocks, + mx_block_size, + packed_mx_block_size, + sign_mask_f6=SIGN_MASK_F6_E2M3, + mbits_f6=MBITS_F6_E2M3, + f6_exp_bias=F6_E2M3_EXP_BIAS, + mbits_f32=MBITS_F32, + f32_exp_bias=F32_EXP_BIAS, + ) + return output + + +def triton_f6_e3m2_to_bf16(x: torch.Tensor) -> torch.Tensor: + """ + Input: a tensor of packed fp6 values + Output: a tensor of bfloat16 values + + Note: this function is only used in testing, so we can test + the numerical correctness of the cast without the scaling. + """ + packed_mx_block_size = x.shape[-1] + mx_block_size = 4 * packed_mx_block_size // 3 + + x = x.view(-1, packed_mx_block_size) + new_shape = (x.numel() // packed_mx_block_size, mx_block_size) + + output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) + + assert x.is_contiguous() + assert x.is_cuda and output.is_cuda + + n_mx_blocks = x.shape[0] + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) + triton_f6_to_bf16_kernel[grid]( + x, + output, + n_mx_blocks, + mx_block_size, + packed_mx_block_size, + sign_mask_f6=SIGN_MASK_F6_E3M2, + mbits_f6=MBITS_F6_E3M2, + f6_exp_bias=F6_E3M2_EXP_BIAS, + mbits_f32=MBITS_F32, + f32_exp_bias=F32_EXP_BIAS, + ) + return output + + +@torch.library.custom_op("ao::triton_f6_e2m3_to_scaled_bf16", mutates_args=()) +def triton_f6_e2m3_to_scaled_bf16( + x: torch.Tensor, + s_e8m0: torch.Tensor, + mx_block_size: int, +) -> torch.Tensor: + """ + Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block + size is currently assumed to be 32. + Output: a tensor of bfloat16 values, multiplied by the encoded scale + """ + assert TORCH_VERSION_AT_LEAST_2_4, "unsupported" + + packed_mx_block_size = 3 * mx_block_size // 4 + + x = x.view(-1, packed_mx_block_size) + new_shape = (x.numel() // packed_mx_block_size, mx_block_size) + + output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) + + assert x.is_contiguous() + assert x.is_cuda and output.is_cuda + + n_mx_blocks = x.shape[0] + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) + triton_f6_to_scaled_bf16_kernel[grid]( + x, + s_e8m0, + output, + n_mx_blocks, + mx_block_size, + packed_mx_block_size, + sign_mask_f6=SIGN_MASK_F6_E2M3, + mbits_f6=MBITS_F6_E2M3, + f6_exp_bias=F6_E2M3_EXP_BIAS, + mbits_f32=MBITS_F32, + f32_exp_bias=F32_EXP_BIAS, + e8m0_exponent_bias=E8M0_EXPONENT_BIAS, + e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, + ) + return output + + +@torch.library.custom_op("ao::triton_f6_e3m2_to_scaled_bf16", mutates_args=()) +def triton_f6_e3m2_to_scaled_bf16( + x: torch.Tensor, + s_e8m0: torch.Tensor, + mx_block_size: int, +) -> torch.Tensor: + """ + Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block + size is currently assumed to be 32. 
+ Output: a tensor of bfloat16 values, multiplied by the encoded scale + """ + assert TORCH_VERSION_AT_LEAST_2_4, "unsupported" + + packed_mx_block_size = 3 * mx_block_size // 4 + + x = x.view(-1, packed_mx_block_size) + new_shape = (x.numel() // packed_mx_block_size, mx_block_size) + + output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) + + assert x.is_contiguous() + assert x.is_cuda and output.is_cuda + + n_mx_blocks = x.numel() // packed_mx_block_size + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) + triton_f6_to_scaled_bf16_kernel[grid]( + x, + s_e8m0, + output, + n_mx_blocks, + mx_block_size, + packed_mx_block_size, + sign_mask_f6=SIGN_MASK_F6_E3M2, + mbits_f6=MBITS_F6_E3M2, + f6_exp_bias=F6_E3M2_EXP_BIAS, + mbits_f32=MBITS_F32, + f32_exp_bias=F32_EXP_BIAS, + e8m0_exponent_bias=E8M0_EXPONENT_BIAS, + e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, + ) + return output + + +@triton_f6_e3m2_to_scaled_bf16.register_fake +def _(x, s_e8m0, mx_block_size): + _padded_mx_block_size = 3 * mx_block_size // 4 + out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) + return torch.empty(*out_shape, device=x.device, dtype=torch.bfloat16) + + +@triton_f6_e2m3_to_scaled_bf16.register_fake +def _(x, s_e8m0, mx_block_size): + _padded_mx_block_size = 3 * mx_block_size // 4 + out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) + return torch.empty(*out_shape, device=x.device, dtype=torch.bfloat16) + + # pack/unpack code copy-pasted from # https://github.com/pytorch-labs/ao/blob/main/torchao/dtypes/uint4.py @@ -478,9 +992,47 @@ def unpack_uint4(uint8_data) -> torch.Tensor: return unpacked -def pack_uint4(uint8_data) -> torch.Tensor: +def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor: # converting to uint8 for operations shape = uint8_data.shape assert shape[-1] % 2 == 0 uint8_data = uint8_data.contiguous().view(-1) return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape)) + + +@torch.library.custom_op("ao::pack_uint6", mutates_args=()) +def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: + # ensure input data is contiguous before passing to kernel + assert uint8_data.is_contiguous() + + # tensor should already be of shape [..., mx_block_size] + mx_block_size = uint8_data.shape[-1] + assert mx_block_size % 4 == 0 + + # effective mx block size since we're packing 2 fp4 into 1 uint8 + packed_mx_block_size = 3 * mx_block_size // 4 + packed_shape = [uint8_data.shape[0], packed_mx_block_size] + n_mx_blocks = uint8_data.numel() // mx_block_size + + grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) + + # contiguous uint8 container in which we can store the unpacked tensor + packed_uint8_data = torch.empty( + packed_shape, dtype=torch.uint8, device=uint8_data.device + ) + + triton_pack_uint6_kernel[grid]( + uint8_data, + packed_uint8_data, + n_mx_blocks, + MX_BLOCK_SIZE=mx_block_size, + PACKED_MX_BLOCK_SIZE=packed_mx_block_size, + ) + + return packed_uint8_data + + +@pack_uint6.register_fake +def _(uint8_data): + out_shape = (*uint8_data.shape[:-1], 3 * uint8_data.shape[-1] // 4) + return torch.empty(*out_shape, device=uint8_data.device, dtype=torch.uint8) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 57fb0d54b4..9ef3ddad6b 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -22,10 +22,15 @@ import torch from torch.utils._pytree import tree_map -from torchao.prototype.mx_formats.constants import DTYPE_FP4 
+from torchao.prototype.mx_formats.constants import ( + DTYPE_FP4, + DTYPE_FP6_E2M3, + DTYPE_FP6_E3M2, +) from torchao.prototype.mx_formats.mx_tensor import ( # noqa: E501 MXTensor, tensor_size_hp_to_fp4x2, + tensor_size_hpx3_to_fp6x4, ) aten = torch.ops.aten @@ -113,6 +118,9 @@ def mx_view_op(aten_op, args, kwargs=None): if args[0]._elem_dtype == DTYPE_FP4: # special case fp4 as we pack two elements per byte new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous()) + elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3]: + # special case fp6 as we pack 4 elements in 3 bytes + new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous()) new_data = aten_op(data, new_size, *args[2:], **kwargs) return MXTensor( args[0]._scale_e8m0, @@ -131,9 +139,9 @@ def autocast_to_copy(aten_op, args, kwargs=None): """ assert isinstance(args[0], MXTensor) # print('before', args[0], args[0].dtype, args[0]._orig_dtype) - assert ( - len(kwargs) == 1 and "dtype" in kwargs - ), "Only support dtype kwarg for autocast" + assert len(kwargs) == 1 and "dtype" in kwargs, ( + "Only support dtype kwarg for autocast" + ) assert kwargs["dtype"] in { torch.float16, torch.bfloat16, diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 8eeeaf8bfd..dbd0e45816 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -49,7 +49,10 @@ f32_to_f6_e2m3_unpacked, f32_to_f6_e3m2_unpacked, pack_uint4, + pack_uint6, triton_f4_to_scaled_bf16, + triton_f6_e2m3_to_scaled_bf16, + triton_f6_e3m2_to_scaled_bf16, unpack_uint4, ) @@ -157,21 +160,30 @@ def to_mx( data_lp = torch.clamp( data_hp / scale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos ) - data_lp = data_lp.reshape(orig_shape) # cast to target dtype if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): data_lp = data_lp.to(elem_dtype) elif elem_dtype == DTYPE_FP6_E2M3: data_lp = f32_to_f6_e2m3_unpacked(data_lp) + if config.pack_fp6: + orig_shape = [*orig_shape[:-1], 3 * orig_shape[-1] // 4] + data_lp = pack_uint6(data_lp) elif elem_dtype == DTYPE_FP6_E3M2: data_lp = f32_to_f6_e3m2_unpacked(data_lp) + if config.pack_fp6: + orig_shape = [*orig_shape[:-1], 3 * orig_shape[-1] // 4] + data_lp = pack_uint6(data_lp) elif elem_dtype == DTYPE_FP4: data_lp = f32_to_f4_unpacked(data_lp) + orig_shape = [*orig_shape[:-1], orig_shape[-1] // 2] data_lp = pack_uint4(data_lp) else: raise AssertionError("unsupported") + # Moved the reshape to later to simplify fp6 packing + data_lp = data_lp.reshape(orig_shape) + return scale_e8m0_biased, data_lp @@ -204,11 +216,33 @@ def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype): if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): data_hp = data_lp.to(target_dtype) elif elem_dtype == DTYPE_FP6_E2M3: - data_hp = f6_e2m3_unpacked_to_f32(data_lp) - data_hp = data_hp.to(target_dtype) + if config.pack_fp6: + orig_shape = (*orig_shape[:-1], 4 * orig_shape[-1] // 3) + data_hp_rescaled = triton_f6_e2m3_to_scaled_bf16( + data_lp, + scale_e8m0, + block_size, + ).reshape(orig_shape) + if is_transposed: + data_hp_rescaled = data_hp_rescaled.t() + return data_hp_rescaled.to(target_dtype) + else: + data_hp = f6_e2m3_unpacked_to_f32(data_lp) + data_hp = data_hp.to(target_dtype).reshape(orig_shape) elif elem_dtype == DTYPE_FP6_E3M2: - data_hp = f6_e3m2_unpacked_to_f32(data_lp) - data_hp = data_hp.to(target_dtype) + if config.pack_fp6: + orig_shape = (*orig_shape[:-1], 4 * orig_shape[-1] // 3) + data_hp_rescaled = 
triton_f6_e3m2_to_scaled_bf16( + data_lp, + scale_e8m0, + block_size, + ).reshape(orig_shape) + if is_transposed: + data_hp_rescaled = data_hp_rescaled.t() + return data_hp_rescaled.to(target_dtype) + else: + data_hp = f6_e3m2_unpacked_to_f32(data_lp) + data_hp = data_hp.to(target_dtype).reshape(orig_shape) elif elem_dtype == DTYPE_FP4: if config.use_fp4_custom_triton_dequant_kernel: data_hp_rescaled = triton_f4_to_scaled_bf16( @@ -263,6 +297,24 @@ def tensor_size_fp4x2_to_hp(orig_size, is_contiguous): return new_size +def tensor_size_hpx3_to_fp6x4(orig_size, is_contiguous): + new_size = orig_size + if is_contiguous: + new_size = [*list(new_size[:-1]), 3 * new_size[-1] // 4] + else: + new_size = [3 * new_size[0] // 4, *list(new_size[1:])] + return new_size + + +def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous): + new_size = orig_size + if is_contiguous: + new_size = [*list(new_size[:-1]), 4 * new_size[-1] // 3] + else: + new_size = [4 * new_size[0] // 3, *list(new_size[1:])] + return new_size + + @torch._dynamo.allow_in_graph class ToMXConstrFunc(torch.autograd.Function): """ @@ -322,6 +374,12 @@ def __new__( new_size, data_bits.is_contiguous(), ) + elif elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]: + # set the tensor size to what it would be without 2x4 packing + new_size = tensor_size_fp6x4_to_hpx3( + new_size, + data_bits.is_contiguous(), + ) self = torch.Tensor._make_wrapper_subclass( cls, new_size, @@ -341,13 +399,16 @@ def __new__( if elem_dtype in ( torch.float8_e4m3fn, torch.float8_e5m2, - DTYPE_FP6_E2M3, - DTYPE_FP6_E3M2, ): target_numel = scale_e8m0_bits.numel() * block_size elif elem_dtype == DTYPE_FP4: assert data_bits.dtype is torch.uint8 # fp4 target_numel = scale_e8m0_bits.numel() * block_size / 2 + elif elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]: + assert data_bits.dtype is torch.uint8 # fp4 + target_numel = scale_e8m0_bits.numel() * block_size + if config.pack_fp6: + target_numel = 3 * target_numel // 4 else: raise AssertionError("unsupported") if not issubclass( @@ -356,9 +417,9 @@ def __new__( ): # this check is sometimes broken for FakeTensor # TODO investigate - assert ( - target_numel == data_bits.numel() - ), f"{target_numel} != {data_bits.numel()}" + assert target_numel == data_bits.numel(), ( + f"{target_numel} != {data_bits.numel()}" + ) # `_scale_e8m0` has rank 1 and applies to a row-major memory layout of # `_data`
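
Note on the packing layout: in `triton_pack_uint6_kernel` and `_fp6_packed_to_bf16`, each group of four 6-bit codes (one code per uint8, high two bits zero) contributes two "nibble" bytes (the high four bits of the first/second and third/fourth codes) and one "tail" byte (the four leftover 2-bit remainders). Within an MX block the kernel stores all first-nibble bytes, then all second-nibble bytes, then all tail bytes. The sketch below is a hypothetical pure-PyTorch mirror of that layout, written only to make the shifts and masks easier to follow; the function names are illustrative and not part of this PR.

```python
import torch

def pack_uint6_reference(x: torch.Tensor) -> torch.Tensor:
    # x: uint8 tensor of shape [n_blocks, block_size], one 6-bit code per byte
    # (high two bits zero); block_size must be a multiple of 4.
    assert x.dtype == torch.uint8 and x.shape[-1] % 4 == 0
    x0, x1, x2, x3 = x.reshape(x.shape[0], -1, 4).unbind(-1)
    # byte A: high 4 bits of x0 | high 4 bits of x1; byte B: same for x2, x3
    byte_a = ((x0 << 2) & 0xF0) | (x1 >> 2)
    byte_b = ((x2 << 2) & 0xF0) | (x3 >> 2)
    # byte C: the four leftover 2-bit tails, in order x0 x1 x2 x3
    byte_c = (x0 << 6) | ((x1 << 4) & 0x30) | ((x2 << 2) & 0x0C) | (x3 & 0x03)
    # kernel layout per block: all A bytes, then all B bytes, then all C bytes
    return torch.cat([byte_a, byte_b, byte_c], dim=-1)

def unpack_uint6_reference(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of the above: packed has shape [n_blocks, 3 * block_size // 4].
    n = packed.shape[-1] // 3
    byte_a, byte_b, byte_c = packed[..., :n], packed[..., n:2 * n], packed[..., 2 * n:]
    x0 = ((byte_a >> 2) & 0x3C) | ((byte_c & 0xC0) >> 6)
    x1 = ((byte_a << 2) & 0x3C) | ((byte_c & 0x30) >> 4)
    x2 = ((byte_b >> 2) & 0x3C) | ((byte_c & 0x0C) >> 2)
    x3 = ((byte_b << 2) & 0x3C) | (byte_c & 0x03)
    return torch.stack([x0, x1, x2, x3], dim=-1).reshape(packed.shape[0], 4 * n)
```

For any uint8 tensor of 6-bit codes, `unpack_uint6_reference(pack_uint6_reference(codes))` round-trips back to `codes`, which is the same property the new `test_fp6_*_pack_unpack` tests check through the Triton path.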
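
The dequantization in `_fp6_packed_to_bf16` relies on a standard bit trick: move the fp6 sign to the fp32 sign position, shift the remaining bits so the fp6 mantissa lines up with the top of the fp32 mantissa, bitcast to float32, then multiply by `2 ** (127 - fp6_exp_bias)` to correct the exponent bias (this multiply also rescales fp6 denormals correctly). A hypothetical eager-mode equivalent for e2m3 is sketched below, assuming the usual e2m3 parameters (3 mantissa bits, exponent bias 1); for e3m2, swap in 2 mantissa bits and bias 3.

```python
import torch

MBITS_F6_E2M3, F6_E2M3_EXP_BIAS = 3, 1   # assumed e2m3 parameters
MBITS_F32, F32_EXP_BIAS = 23, 127

def f6_e2m3_code_to_f32_reference(codes: torch.Tensor) -> torch.Tensor:
    # codes: uint8 tensor, one unpacked fp6 e2m3 code per byte (high 2 bits zero)
    sign = (codes & 0x20).to(torch.int32) << 26               # fp6 sign bit -> fp32 bit 31
    mag = (codes & 0x1F).to(torch.int32) << (MBITS_F32 - MBITS_F6_E2M3)
    f32 = (sign | mag).view(torch.float32)                    # reinterpret the bits
    # fix up the exponent bias; fp6 denormals land in the fp32 denormal range
    # and are rescaled correctly by the same multiply
    return f32 * 2.0 ** (F32_EXP_BIAS - F6_E2M3_EXP_BIAS)
```

For example, the code `0b001100` (exponent `01`, mantissa `100`) decodes to `2**0 * 1.5 = 1.5`, matching the direct interpretation of the e2m3 encoding.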
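
`triton_f6_to_scaled_bf16_kernel` then multiplies each decoded block by its shared e8m0 scale. An e8m0 byte encodes a pure power of two, `2 ** (byte - 127)`, with the all-ones value reserved for NaN (the `E8M0_EXPONENT_NAN_VAL` path exercised by `test_exponent_nan_out`). A small sketch of that per-block rescale outside of Triton, with illustrative names and the assumed e8m0 constants:

```python
import torch

E8M0_EXPONENT_BIAS = 127       # assumed e8m0 bias
E8M0_EXPONENT_NAN_VAL = 255    # assumed NaN encoding (all bits set)

def apply_e8m0_scale_reference(blocks: torch.Tensor, scale_e8m0: torch.Tensor) -> torch.Tensor:
    # blocks: [n_mx_blocks, block_size] bfloat16; scale_e8m0: [n_mx_blocks] uint8
    scale = torch.pow(2.0, scale_e8m0.to(torch.float32) - E8M0_EXPONENT_BIAS)
    scale = torch.where(
        scale_e8m0 == E8M0_EXPONENT_NAN_VAL,
        torch.full_like(scale, float("nan")),
        scale,
    )
    # broadcast one scale across each block, as the kernel does after its reshape
    return blocks * scale.to(torch.bfloat16).unsqueeze(-1)
```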
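
Finally, a usage sketch at the `MXTensor` level that mirrors `test_fp6_packing`: with `config.pack_fp6 = True` (the new default) the `_data` payload shrinks to 3/4 of the element count, and `to_dtype` dequantizes through the new scaled Triton kernels. This assumes a CUDA device with Triton available and the prototype import paths shown in the diff; the `config` import path in particular is an assumption, since it is not part of the shown hunks.

```python
import torch
from torchao.prototype.mx_formats import config              # assumed import path
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3
from torchao.prototype.mx_formats.mx_tensor import MXTensor

config.pack_fp6 = True                        # new default introduced by this PR
x = torch.randn(2, 32, device="cuda", dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(x, DTYPE_FP6_E2M3, 32)  # block_size = 32
print(x_mx._data.shape)                       # torch.Size([2, 24]): 4 fp6 codes per 3 bytes
x_hp = x_mx.to_dtype(torch.bfloat16)          # dequantize via triton_f6_e2m3_to_scaled_bf16
```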