Add support for qk hidden dim different from v hidden dim #1166

Open · wants to merge 54 commits into main
Conversation

@smallscientist1 commented Aug 20, 2024

We add support for:

  • a different hidden dimension for qk and v;
  • unequal num_heads_k and num_heads_v, e.g. (num_heads_q, num_heads_k, num_heads_v) = (32, 4, 16).

For different hidden dimensions between qk and v, the following configurations are currently supported:

  • FlashAttention-2 with QKHeadDim=32, VHeadDim=64
  • FlashAttention-2 with QKHeadDim=64, VHeadDim=128
  • FlashAttention-2 with QKHeadDim=96, VHeadDim=192
  • FlashAttention-2 with QKHeadDim=128, VHeadDim=256
  • FlashAttention-2 with QKHeadDim=192, VHeadDim=128

For head dimensions not listed above, you can use the autotuner to generate an implementation; see autotuner.md for details.
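As a rough illustration, here is a hedged usage sketch (the call pattern is assumed from the upstream flash_attn Python API and this PR's description, not copied from the PR's code): q and k share one head dimension while v uses another, and num_heads_k may differ from num_heads_v.

```python
import torch
from flash_attn import flash_attn_func

batch, seqlen = 2, 1024
# Head counts taken from the example in the PR description.
num_heads_q, num_heads_k, num_heads_v = 32, 4, 16
# One of the supported (QKHeadDim, VHeadDim) pairs.
qk_head_dim, v_head_dim = 64, 128

q = torch.randn(batch, seqlen, num_heads_q, qk_head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, num_heads_k, qk_head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seqlen, num_heads_v, v_head_dim, device="cuda", dtype=torch.bfloat16)

# With this PR, q/k and v no longer need to share a head dim or a head count.
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # expected: (batch, seqlen, num_heads_q, v_head_dim)
```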

Performance

We benchmark the speedup compared to padding the qk and v hidden dims to the same length.
(Benchmark screenshot attached to the PR, 2024-08-22.)
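For context, here is a minimal sketch of the padding baseline the speedup is measured against (an assumed setup, not taken from the PR's benchmark script): zero-pad q and k up to the v head dimension so a stock single-head-dim FlashAttention-2 kernel can be used, keeping the softmax scale of the original qk dimension.

```python
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

b, s, h = 2, 4096, 16
qk_dim, v_dim = 128, 256  # one of the supported pairs

q = torch.randn(b, s, h, qk_dim, device="cuda", dtype=torch.float16)
k = torch.randn(b, s, h, qk_dim, device="cuda", dtype=torch.float16)
v = torch.randn(b, s, h, v_dim, device="cuda", dtype=torch.float16)

# Padding baseline: zero-pad the qk head dim up to v_dim and run the stock kernel.
# Zero padding leaves the dot products unchanged; pass the scale of the
# unpadded qk dim so the math matches.
q_pad = F.pad(q, (0, v_dim - qk_dim))
k_pad = F.pad(k, (0, v_dim - qk_dim))
out_padded = flash_attn_func(q_pad, k_pad, v, softmax_scale=qk_dim ** -0.5)

# This PR's kernel works on the unpadded tensors directly.
out = flash_attn_func(q, k, v, softmax_scale=qk_dim ** -0.5)
```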

Test

We add unit tests in tests/test_flash_attn_headdim.py and tests/test_flash_attn_head.py.

@iqiancheng commented Aug 21, 2024

hi~ @smallscientist1
Regarding the combinations of qk and v dimensions you've implemented in FlashAttention-2, which configuration have you found to offer the best balance between performance and model effectiveness? Specifically, among the combinations:

QKHeadDim=32, VHeadDim=64
QKHeadDim=64, VHeadDim=128
QKHeadDim=96, VHeadDim=192
QKHeadDim=128, VHeadDim=256

Which one stands out in terms of computational efficiency and model quality?

@xiayuqing0622

> Regarding the combinations of qk and v dimensions you've implemented in FlashAttention-2, which configuration have you found to offer the best balance between performance and model effectiveness? … Which one stands out in terms of computational efficiency and model quality?

In terms of model quality, it's too early to make a definitive assessment since the work is still in progress. However, several teams we've collaborated with expressed a need for this combination, so we implemented it. Additionally, anticipating that others might find it useful, we created this PR to benefit the broader community.

smallscientist1 and others added 9 commits August 22, 2024 08:32
* create bench headdim

* update bench result

* update Readme

* reorg code to reduce compile time

* update (128,256) config

* add (192,128)

* add config (192,128)

* fix bug

* fix bug backward

* fix bug
@smallscientist1 marked this pull request as ready for review September 19, 2024 08:37