
[PyTorch] Support FA3 MLA CP feature #1907


Open: wants to merge 2 commits into main
Conversation

zhujian19891203

Description

Because flash-attention #1604 already supports hdimQK != hdimV in the backward pass, we can now support the FA3 (Flash Attention 3) backend for MLA (Multi-head Latent Attention). #1604 allows us to skip explicitly padding and unpadding hdimV (sketched below) when using FA3 as the attention backend, which brings performance benefits.
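
For context, here is a minimal sketch of the padding workaround that #1604 makes unnecessary. It uses plain PyTorch SDPA as a stand-in for the attention kernel; the tensor names and shapes are illustrative assumptions, not code from this PR:

```python
import torch
import torch.nn.functional as F

# Illustrative MLA-style shapes (batch, seqlen, heads, head_dim): hdimQK=192, hdimV=128.
q = torch.randn(2, 128, 16, 192)
k = torch.randn(2, 128, 16, 192)
v = torch.randn(2, 128, 16, 128)

# Old workaround: zero-pad V up to hdimQK so the attention kernel sees equal
# head dims, then slice the padded columns back off the output. Zero columns in
# V produce zero columns in the output, so the slice is lossless.
v_padded = F.pad(v, (0, q.shape[-1] - v.shape[-1]))  # 128 -> 192
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v_padded.transpose(1, 2), is_causal=True
).transpose(1, 2)
out = out[..., : v.shape[-1]]                         # 192 -> 128

# With flash-attention #1604, FA3 accepts hdimV != hdimQK directly in both the
# forward and backward passes, so this pad/unpad round trip (and its extra
# memory traffic) is no longer needed to route MLA through FA3.
```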

Test Results:

  1. I added unit tests for FA3 MLA CP; test_fused_attn.py and test_fused_attn_with_cp.py pass (see the invocation sketch after this list).
  2. I ran a 16B DeepSeek model with the Megatron-LM framework (TP=CP=2, hdimQK=192, hdimV=128): FA3's MFU is essentially the same as cuDNN attention's, and the overall loss curves of the two match exactly.
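
A sketch of how the new tests might be invoked locally; the test file paths and the -k filter are assumptions based on TransformerEngine's usual test layout, not taken from this PR:

```python
# Hypothetical local invocation of the touched test suites via pytest's Python API.
import pytest

pytest.main([
    "tests/pytorch/fused_attn/test_fused_attn.py",
    "tests/pytorch/fused_attn/test_fused_attn_with_cp.py",
    "-k", "mla",  # hypothetical keyword filter; actual MLA test names may differ
    "-v",
])
```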

[Figure: training loss curves for the FA3 and cuDNN attention backends, showing matching curves]

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  1. Update FA3 to commit 3ba6f82 (tag 2.8.0.post2 with a compile error fixed)
  2. Update the get_attention_backend method, since FA3 now supports MLA (see the sketch after this list)
  3. Support FA3 MLA in the CP (context parallelism) module
  4. Add unit tests for FA3 MLA CP
  5. Update the attention documentation
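
A simplified, hypothetical illustration of the kind of check that change 2 relaxes in get_attention_backend; the function name, arguments, and version threshold below are assumptions for illustration, not TransformerEngine's actual code:

```python
# Hypothetical sketch of the backend gate: before this PR, a filter along these
# lines would reject FlashAttention whenever hdimQK != hdimV; with the FA3 build
# at tag 2.8.0.post2 (flash-attention #1604 merged), the unequal-head-dim MLA
# case can stay enabled.
from packaging.version import Version

def allow_fa3_for_mla(head_dim_qk: int, head_dim_v: int, fa3_version: str) -> bool:
    if head_dim_qk == head_dim_v:
        return True
    # MLA-style unequal head dims need FA3 with dQK != dV backward support.
    return Version(fa3_version) >= Version("2.8.0.post2")

assert allow_fa3_for_mla(192, 128, "2.8.0.post2")      # DeepSeek MLA dims: allowed
assert not allow_fa3_for_mla(192, 128, "2.7.4.post1")  # older FA3: fall back elsewhere
```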

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
