Added MXFP6 packing and fused unpack-dequantise kernel #1

Open · wants to merge 6 commits into main

Conversation

alex-titterton

Initial PR to support MXFP6 packing, whereby the bit representations of 4 x fp6 values are packed into 3 x uint8 containers via a custom Triton kernel.

The mx_formats pytests have been amended to ensure the trailing tensor dimension and M block size are both multiples of 4 whenever fp6 packing is performed, since this is required in order to pack 4N values into 3N elements. This shouldn't cause any issues, since any FP8-or-lower HW implementation (e.g. tensor cores) typically expects a minimum trailing dimension of 16/32/...
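
As a rough illustration (the helper name and bit layout below are only illustrative, not the actual Triton kernel's layout), the 4-into-3 packing amounts to something like the following in plain PyTorch:

import torch

def pack_uint6_reference(x: torch.Tensor) -> torch.Tensor:
    # x: uint8 tensor whose trailing dim is a multiple of 4; each element
    # holds an fp6 bit pattern in its low 6 bits.
    a, b, c, d = x.reshape(*x.shape[:-1], -1, 4).unbind(-1)
    # 4 x 6 bits = 24 bits = 3 bytes (one possible MSB-first layout):
    byte0 = (a << 2) | (b >> 4)           # a[5:0] | b[5:4]
    byte1 = ((b & 0xF) << 4) | (c >> 2)   # b[3:0] | c[5:2]
    byte2 = ((c & 0x3) << 6) | d          # c[1:0] | d[5:0]
    packed = torch.stack([byte0, byte1, byte2], dim=-1)
    return packed.reshape(*x.shape[:-1], -1)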

Main changes:

  • Added a custom Triton kernel to pack 4 x uint8 containers holding fp6 bits into 3 x uint8
  • Added a fused fp6 --> bfloat16 unpack-dequantise Triton kernel (see the unpacking sketch after this list)
  • Registered torch custom ops to call these kernels, with FakeTensor shape registrations to support torch.compile
  • Added a bool in the config to enable/disable fp6 packing
  • Amended pytests to use a multiple of 4 for the trailing tensor dimension
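
For completeness, the unpack half corresponds to the following sketch (same illustrative layout and helper naming as the packing sketch above; the actual fused kernel additionally maps the 6-bit e2m3/e3m2 patterns to bfloat16 and applies the MX scale in the same pass):

def unpack_uint6_reference(packed: torch.Tensor) -> torch.Tensor:
    # packed: uint8 tensor whose trailing dim is a multiple of 3; recovers
    # the 4 x 6-bit values per 3 bytes produced by the packing sketch above.
    b0, b1, b2 = packed.reshape(*packed.shape[:-1], -1, 3).unbind(-1)
    a = b0 >> 2
    b = ((b0 & 0x3) << 4) | (b1 >> 4)
    c = ((b1 & 0xF) << 2) | (b2 >> 6)
    d = b2 & 0x3F
    return torch.stack([a, b, c, d], dim=-1).reshape(*packed.shape[:-1], -1)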

@balancap left a comment

@alex-titterton Just a few small comments; I am keeping it at a high level since you'll get a more detailed code review on the upstream repo anyway (and they will know better what style they prefer for the integration).

A few additional things not directly in the code:

  • The ruff format linter is failing, worth fixing before opening the PR upstream;
  • The cutlass submodule hash seems to be different from the one on the main branch. They'll ask you to revert to the one on main;
  • It is not a small/simple PR, so I think it's worth motivating a bit more in the main comment why this additional complexity should be added to the repo. The main motivation for me: FP6 is as good as FP8 for accuracy while saving memory, so we should support FP6 packing to save memory.

Additionally, it would help to document in the PR which FP6 packing layout we are using here (as we know there are multiple options). They may ask whether it aligns with the Blackwell hardware specs.

@pack_uint6.register_fake
def _(uint8_data):
    # Packing 4 x fp6 into 3 x uint8 shrinks the trailing dim by a factor of 3/4.
    out_shape = (*uint8_data.shape[:-1], 3 * uint8_data.shape[-1] // 4)
    return torch.empty(*out_shape, device=uint8_data.device, dtype=torch.uint8)

You'll need some test coverage of the functions added to custom_cast.py to be accepted upstream, i.e. for pack_uint6, triton_f6_e3m2_to_scaled_bf16 and triton_f6_e2m3_to_scaled_bf16.
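
For instance, a minimal shape/dtype test for pack_uint6 could look like the sketch below (the import path is assumed; adjust it to wherever the PR actually places the function), plus a roundtrip through the two dequantise kernels compared against the existing reference path:

import pytest
import torch

# Import path assumed; adjust to the PR's actual module layout.
from torchao.prototype.mx_formats.custom_cast import pack_uint6

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Triton kernel requires CUDA")
@pytest.mark.parametrize("shape", [(32, 64), (4, 8, 128)])
def test_pack_uint6_shape(shape):
    # fp6 bit patterns occupy the low 6 bits of each uint8 container.
    data = torch.randint(0, 64, shape, dtype=torch.uint8, device="cuda")
    packed = pack_uint6(data)
    assert packed.dtype == torch.uint8
    assert packed.shape == (*shape[:-1], 3 * shape[-1] // 4)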

swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)

x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
y_ref = m(x)
y_mx = m_mx(x)
sqnr = compute_error(y_ref, y_mx)
print(sqnr)

To remove
