Added MXFP6 packing and fused unpack-dequantise kernel #1
base: main
Conversation
@alex-titterton Just a few small comments; I'm keeping it at a high level since you'll get a more detailed code review on the upstream repo anyway (and they will know better what style they prefer for integration).
A few additional things not directly in the code:
- The `ruff` format linter is failing, worth fixing before opening the PR upstream;
- The `cutlass` submodule hash seems to be different to the `main` branch one. They'll ask you to revert back to the one on `main`;
- It is not a small/simple PR, so I think it's worth motivating a bit more in the main comment why this additional complexity should be added to the repo. The main motivation for me is that FP6 is as good as FP8 for accuracy while saving memory. We should support FP6 packing to save memory.
Additionally, it would help to document in the PR which FP6 packing we are using here (as we know there are multiple options). They may ask you whether it aligns with Blackwell hardware specs.
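For reference, a pure-PyTorch sketch of one possible 4-to-3 byte layout for `fp6` values held in `uint8` containers. The function name `pack_uint6_reference`, the little-endian bit ordering, and the masks are illustrative assumptions; the Triton kernel in this PR may use a different layout, which is exactly the detail worth documenting:

```python
import torch

def pack_uint6_reference(x: torch.Tensor) -> torch.Tensor:
    """Illustrative only: pack 4 fp6 values (low 6 bits of each uint8)
    into 3 bytes by concatenating the 24 bits little-end first."""
    assert x.dtype == torch.uint8 and x.shape[-1] % 4 == 0
    a, b, c, d = x.reshape(*x.shape[:-1], -1, 4).unbind(-1)
    byte0 = (a & 0x3F) | ((b & 0x03) << 6)         # a[5:0] | b[1:0]
    byte1 = ((b >> 2) & 0x0F) | ((c & 0x0F) << 4)  # b[5:2] | c[3:0]
    byte2 = ((c >> 4) & 0x03) | ((d & 0x3F) << 2)  # c[5:4] | d[5:0]
    return torch.stack((byte0, byte1, byte2), dim=-1).reshape(*x.shape[:-1], -1)
```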
@pack_uint6.register_fake
def _(uint8_data):
    out_shape = (*uint8_data.shape[:-1], 3 * uint8_data.shape[-1] // 4)
    return torch.empty(*out_shape, device=uint8_data.device, dtype=torch.uint8)
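For what it's worth, the fake implementation can be sanity-checked without running the GPU kernel by tracing under `FakeTensorMode` (import path assumed to match where the op is added in `custom_cast.py` upstream; adjust as needed):

```python
import torch
from torch._subclasses import FakeTensorMode
# Assumed import path; in this PR the op lives in custom_cast.py.
from torchao.prototype.mx_formats.custom_cast import pack_uint6

with FakeTensorMode():
    data = torch.empty(16, 128, dtype=torch.uint8)
    packed = pack_uint6(data)
    # 4 uint8 containers -> 3 packed bytes, so 128 -> 96 on the last dim.
    assert packed.shape == (16, 96)
```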
You'll need some test coverage of the functions added to `custom_cast.py` to be accepted upstream, i.e. for `pack_uint6`, `triton_f6_e3m2_to_scaled_bf16`, `triton_f6_e2m3_to_scaled_bf16`.
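A minimal sketch of what such a test could look like (hypothetical test name, shapes, and import path; `triton_f6_e3m2_to_scaled_bf16` and `triton_f6_e2m3_to_scaled_bf16` would need analogous SQNR-style numerical checks):

```python
import pytest
import torch

# Assumed import path, matching where the ops are added in this PR.
from torchao.prototype.mx_formats.custom_cast import pack_uint6

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_pack_uint6_shape_and_dtype():
    # fp6 bit patterns occupy the low 6 bits of each uint8 container.
    data = torch.randint(0, 64, (16, 128), dtype=torch.uint8, device="cuda")
    packed = pack_uint6(data)
    assert packed.dtype == torch.uint8
    # 4N uint8 containers pack into 3N bytes.
    assert packed.shape == (16, 96)
```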
swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)

x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
y_ref = m(x)
y_mx = m_mx(x)
sqnr = compute_error(y_ref, y_mx)
print(sqnr)
To remove
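Presumably this refers to the bare `print(sqnr)` above; for illustration, it could be replaced by an explicit threshold check (the bound below is an assumption, existing `mx_formats` tests choose thresholds per element dtype):

```python
sqnr = compute_error(y_ref, y_mx)
assert sqnr >= 20.0, f"SQNR too low: {sqnr}"  # illustrative threshold only
```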
Initial PR to support MXFP6 packing, whereby the bit representations of 4 x `fp6` are packed into 3 x `uint8` containers via a custom Triton kernel. `mx_formats` pytests have been amended to ensure the trailing tensor dimension and M block size are both multiples of 4 whenever `fp6` packing is performed, since this is required in order to pack 4N values into 3N elements. This shouldn't cause any issues since any FP8-or-lower HW implementation (e.g. tensor cores) typically expects a minimum trailing dim size of 16/32/...

Main changes:
- Triton kernel packing 4 x `uint8` containing `fp6` bits into 3 x `uint8`
- `fp6` --> `bfloat16` unpack-dequantise fused Triton kernel
- `torch` custom ops to call these kernels and `FakeTensor` shapes in order to support `torch.compile`
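As a rough usage sketch of what the packing op plus the `FakeTensor` registration enable (import path and shapes are assumptions for illustration):

```python
import torch
# Assumed import path; the op is added to custom_cast.py in this PR.
from torchao.prototype.mx_formats.custom_cast import pack_uint6

# torch.compile can trace through the custom op because the registered
# FakeTensor implementation exposes the 4 -> 3 trailing-dim contraction.
@torch.compile
def pack(x: torch.Tensor) -> torch.Tensor:
    return pack_uint6(x)

x = torch.randint(0, 64, (4, 32, 128), dtype=torch.uint8, device="cuda")
packed = pack(x)
print(packed.shape)  # expected: torch.Size([4, 32, 96])
```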