Flash Attention V2 #485

Closed
nivibilla opened this issue Jul 17, 2023 · 14 comments · Fixed by #877

Comments

@nivibilla

https://github.com/Dao-AILab/flash-attention

FlashAttention-2 was released, claiming 2x speedups. I'm opening this issue to remind myself to look into it, and in case anyone else wants to try implementing it.

@chenrui17

I used benchmarks/benchmark_throughput.py to test FlashAttention V2, but it doesn't seem to have any effect. My test steps were:

  • Update xformers to the latest version.
  • Change the attention op line to `self.attn_op = xops.fmha.flash.FwOp()`.
  • Run `python3 benchmark_throughput.py --dataset=./ShareGPT_V3_unfiltered_cleaned_split.json --model=/huggingface_data/llama-7b-hf/ --tokenizer=hf-internal-testing/llama-tokenizer --num-prompts=500`

The test timings were as follows:

Profiling further, I found that the replaced part (FlashAttention V2) accounts for only a small fraction of the cost and runs only at the beginning of execution. I'm confused: what can FlashAttention V2 actually do for vLLM?

@tmm1
Contributor

tmm1 commented Aug 3, 2023

> • update xformers to latest version

Hi, which version specifically? I don't think FlashAttention v2 support has been released yet, so you would have to install from git. There are also still some open PRs to bump xformers to the flash-attn v2.0.4 bugfix release (facebookresearch/xformers#816).
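Because "latest version" is ambiguous here, it can help to report exactly which packages are installed before benchmarking. A minimal sketch using only the standard library's `importlib.metadata`; the package list is an assumption based on the packages mentioned in this thread, and name matching for dashed names like `flash-attn` relies on the normalization done by recent Python versions.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as published on PyPI (assumed list, taken from this thread).
PACKAGES = ("torch", "xformers", "flash-attn", "vllm", "vllm-flash-attn")


def installed_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:16} {ver or 'not installed'}")
```

Pasting this output into a report pins down the versions far more precisely than "latest".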

@tmm1
Contributor

tmm1 commented Aug 8, 2023

> • modify the line to self.attn_op = xops.fmha.flash.FwOp()
> • python3 benchmark_throughput.py --dataset=./ShareGPT_V3_unfiltered_cleaned_split.json --model=/huggingface_data/llama-7b-hf/ --tokenizer=hf-internal-testing/llama-tokenizer --num-prompts=500

I tried this as well, and there was no improvement in the benchmarks after switching to flash-attn v2.

I will try to profile the benchmark script.

@Zhuqln

Zhuqln commented Aug 9, 2023

> modify the line to self.attn_op = xops.fmha.flash.FwOp()

I don't think this really works, because another important feature of flash-attn is reducing the high GPU memory usage for very long contexts (e.g., more than 5k tokens).
When I set that line and run inference, I don't see any change in memory usage.
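For scale on the memory point: a naive attention implementation materializes an n×n score matrix per head, while tiled kernels compute it block by block in on-chip memory. A back-of-the-envelope sketch, assuming fp16 scores and 32 heads (as in LLaMA-7B) and ignoring the KV cache and other activations. Note that the xformers CUTLASS kernel vLLM used by default is itself a memory-efficient tiled kernel, so swapping it for the flash op would likely not change this term, which may explain the unchanged memory usage.

```python
def score_matrix_bytes(seq_len, num_heads, batch=1, bytes_per_elem=2):
    """Bytes to materialize the full seq_len x seq_len attention score
    matrix for one layer, as a naive implementation would (fp16 = 2 bytes).
    Tiled kernels such as FlashAttention never hold this matrix in GPU memory."""
    return batch * num_heads * seq_len * seq_len * bytes_per_elem


if __name__ == "__main__":
    for n in (1_000, 5_000, 10_000):
        gib = score_matrix_bytes(n, num_heads=32) / 2**30
        print(f"seq_len={n:>6}: {gib:.2f} GiB per layer")
```

For a 5,000-token prompt with 32 heads this is 1.6 GB (about 1.49 GiB) per layer, and doubling the length quadruples it.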

@WoosukKwon
Collaborator

Hi @nivibilla, thanks for submitting the issue. The latest version of xformers now uses the FlashAttention-V2 algorithm, so vLLM also now takes advantage of it. Please upgrade vLLM to v0.1.4.

@tmm1 @Zhuqln As I understand it, the overall speedup depends on your workload. At inference time, FlashAttention is only used for the prompt inputs, never for the decoding inputs. For many workloads, the decoding stage takes the majority of the total execution time, so switching to FlashAttention V2 may not give a notable speedup. However, for other workloads such as text summarization, where the prompts are very long, I believe computing attention for the prompt inputs will take a significant portion of the execution time, and thus FlashAttention V2 will have a large impact on overall performance.
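The workload dependence described above is essentially Amdahl's law: only the share of runtime spent in prompt attention gets faster. A minimal sketch; the runtime fractions in the usage loop are hypothetical, and the 2x factor is the kernel speedup claimed for FlashAttention-2 at the top of this issue.

```python
def overall_speedup(prompt_attn_fraction, kernel_speedup):
    """Amdahl's law: end-to-end speedup when only the fraction of runtime
    spent in prompt attention is accelerated by kernel_speedup."""
    if not 0.0 <= prompt_attn_fraction <= 1.0:
        raise ValueError("prompt_attn_fraction must be in [0, 1]")
    if kernel_speedup <= 0:
        raise ValueError("kernel_speedup must be positive")
    rest = 1.0 - prompt_attn_fraction  # decoding and everything else, unchanged
    return 1.0 / (rest + prompt_attn_fraction / kernel_speedup)


if __name__ == "__main__":
    # Hypothetical fractions of total runtime spent in prompt attention.
    for frac in (0.05, 0.3, 0.6):
        print(f"{frac:.0%} of runtime in prompt attention -> "
              f"{overall_speedup(frac, 2.0):.2f}x overall")
```

Even if 30% of runtime were prompt attention, a 2x faster kernel would give only about 1.18x end to end.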

@nivibilla
Author

@WoosukKwon thanks for the explanation!

@tmm1
Contributor

tmm1 commented Aug 25, 2023

> The latest version of xformers now uses the FlashAttention-V2 algorithm, so vLLM also now takes advantage of it. Please upgrade vLLM to v0.1.4.

Hi, this is inaccurate, since the code still forces `xops.fmha.cutlass.FwOp` to be used. If you want to take advantage of FA2, you need to switch to `xops.fmha.flash.FwOp`.

See benchmark results in facebookresearch/xformers#832
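The switch described here amounts to changing the forward op passed to xformers. A guarded sketch, assuming an xformers build that ships both `xformers.ops.fmha.flash` and `xformers.ops.fmha.cutlass`; the helper name `select_fw_op` is hypothetical, and the call in the trailing comment only illustrates where the op would be passed (`q`, `k`, `v`, `attn_bias`, `scale` stand in for the layer's own values).

```python
def select_fw_op(prefer_flash=True):
    """Return an xformers forward-op class, or None if xformers is missing.

    prefer_flash=True selects the FlashAttention-2 kernel (fmha.flash.FwOp);
    False selects the CUTLASS kernel (fmha.cutlass.FwOp).
    """
    try:
        import xformers.ops as xops
    except ImportError:
        return None  # caller needs a different attention path
    return xops.fmha.flash.FwOp if prefer_flash else xops.fmha.cutlass.FwOp


# Sketch of use inside an attention layer (placeholders, not runnable alone):
#   op = select_fw_op()
#   out = xops.memory_efficient_attention_forward(
#       q, k, v, attn_bias=attn_bias, p=0.0, scale=scale, op=op)
```

This only changes the prompt (prefill) path; decoding used a separate kernel.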

@zhaoyang-star
Contributor

zhaoyang-star commented Aug 29, 2023


Thanks for the details, @WoosukKwon. I just have one question: why can't FlashAttention be used for the decoding phase?

@learning-chip

> Why FlashAttention could not be used for decoding phase?

Its tiling strategy is not optimized for a query (Q) with seqlen=1; see Dao-AILab/flash-attention#427 (comment).
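A rough way to see the tiling problem: if queries are split into tiles of `block_m` rows and the kernel launches one thread block per (query tile, batch, head), a decode step with a single query fills one row of each tile and launches very few blocks. A sketch assuming `block_m=128` and that grid layout, which approximates FlashAttention-2's forward kernel; the real tile size depends on head dimension and GPU.

```python
import math


def query_tile_utilization(q_len, block_m=128):
    """Fraction of rows in the query tiles that hold real queries when
    q_len queries are split into tiles of block_m rows."""
    tiles = math.ceil(q_len / block_m)
    return q_len / (tiles * block_m)


def num_thread_blocks(batch, heads, q_len, block_m=128):
    """Thread blocks launched if the grid is (query tiles, batch, heads)."""
    return math.ceil(q_len / block_m) * batch * heads


if __name__ == "__main__":
    for q_len in (2048, 1):  # prompt (prefill) vs. one decode step
        print(f"q_len={q_len:>4}: "
              f"utilization={query_tile_utilization(q_len):.4f}, "
              f"blocks={num_thread_blocks(1, 32, q_len)}")
```

With batch 1 and 32 heads, a decode step launches only 32 blocks, fewer than the 108 SMs on an A100, and each block uses 1 of its 128 query rows.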

@Lvjinhong


I'm delighted to engage in this discussion. Your report has been immensely helpful, but I do have some questions. For instance, I'm curious to know if there's a performance comparison available between trtLLM and vLLM. Such information would be greatly beneficial in guiding my decision on which framework to choose.

@matanhol

> However, for other workload like text summarization where the prompts are very long, I believe computing attention for the prompt inputs will take a significant portion of the execution time, and thus FlashAttention V2 will have a huge impact on the overall performance.

You assume that in a summarization task most of the workload comes from processing the input (the prompt). In my experiments, the generation part was much larger. If you generate only 1-5 tokens, most of the workload is processing the input, the cost depends on input length, and FlashAttention-2 is advantageous (its memory use grows linearly with input length, while a naive implementation's grows quadratically). But if you generate a considerable number of tokens, generation dominates, input processing becomes negligible, and FlashAttention-2 has little effect.
(Usually, when you have a long text you want a longer summary; it doesn't make sense to summarize a 1000-word article in 5 tokens.)
I've attached a link to the simulation.
Please let me know if you have any comments.

https://github.com/matanhol/summarization_with_flash_attn_2_simulation
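The argument above can be illustrated with a toy count of attention work (query-key pairs scored), assuming causal attention and treating each decode step as attending to every earlier position. This is a sketch, not the linked simulation: it ignores the weight reads and MLP layers that typically make each decode step memory-bound, which in practice shifts even more wall time toward decoding.

```python
def prefill_pairs(n):
    """Query-key pairs scored during causal prefill of an n-token prompt."""
    return n * (n + 1) // 2


def decode_pairs(n, m):
    """Query-key pairs scored while generating m tokens after an n-token
    prompt, where decode step i (0-based) attends to n + i + 1 positions."""
    return m * n + m * (m + 1) // 2


def prefill_share(n, m):
    """Fraction of all query-key pairs that come from the prompt."""
    p = prefill_pairs(n)
    return p / (p + decode_pairs(n, m))


if __name__ == "__main__":
    for m in (5, 100, 1000):  # tokens generated after a 1000-token prompt
        print(f"m={m:>4}: prefill share = {prefill_share(1000, m):.2%}")
```

With a 1000-token prompt, the prompt accounts for about 99% of the pairs when only 5 tokens are generated, but only about 25% when 1000 are generated.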

@brando90

I tried installing vLLM with flash-attn, but it didn't work. My attempts:

Install flash attention:
```bash
# my current vllm setup without flash
# pip install --upgrade pip
# pip install torch==2.2.1
# pip install vllm==0.4.1

# flash attn https://amzn-aws.slack.com/archives/C06Q26TNN8G/p1724182667464149
# flash-attn>=2.5.8
# pip install flash-attn
# Collabs's setup with flash
# vllm                              0.5.4
# vllm-flash-attn                   2.6.1
# flash-attn                        2.6.3
# torch                             2.4.0
# Python 3.10.8 

# try to install flash-attn in a new Python 3.11 env
# (create and activate the same venv path)
python3.11 -m venv ~/.virtualenvs/flash_attn_test
source ~/.virtualenvs/flash_attn_test/bin/activate
pip install --upgrade pip
pip install -e ~/snap-cluster-setup

pip list | grep vllm
pip list | grep torch
pip list | grep flash-attn
pip list | grep vllm-flash-attn

# didn't work:
# pip install torch==2.2.1
# pip install vllm==0.4.1
# MAX_JOBS=4 pip install flash-attn --no-build-isolation --force

# this installed flash-attn, but vLLM's output didn't say it was using it
pip install torch==2.4.0
pip install vllm==0.5.4
pip install flash-attn==2.6.3
pip install vllm-flash-attn==2.6.1

# LOG_FILE must be set; the default below is a hypothetical path
LOG_FILE=${LOG_FILE:-boxed_acc_eval.log}
python ~/snap-cluster-setup/py_src/evals/boxed_acc_eval.py --model internlm/internlm2_5-1_8b --hf_gen_type vllm --path_2_eval_dataset ~/snap-cluster-setup/data/MATH/test --max_tokens 2048 --batch_size 100 --end 100 -n 1 --shuffle True --mode dryrun 2>&1 | tee "$LOG_FILE" && echo "Log file created at: $LOG_FILE"

# later try with py 3.10
# python3xxx -m venv ~/.virtualenvs/flash_attn_test_py10
# source ~/.virtualenvs/flash_attn_test_py10/bin/activate
# pip install --upgrade pip
# pip install -e ~/snap-cluster-setup
# pip install torch==2.4.0
# pip install vllm==0.5.4
# pip install flash-attn==2.6.3
# pip install vllm-flash-attn==2.6.1
```

@brando90

My setup is Python 3.11; that's what I really want/need.

@brando90

Related general vLLM issue about vLLM versions: #2747
