From 6e69cdd4722f20698c96f7721872ae112ccc6ed6 Mon Sep 17 00:00:00 2001
From: Hollow Man
Date: Mon, 30 Jun 2025 22:10:44 +0300
Subject: [PATCH] Fix import error when Flash Attention 3 is installed

As documented in
https://github.com/Dao-AILab/flash-attention/blob/7661781d001e0900121c000a0aaf21b3f94337d6/README.md?plain=1#L61-L62,
`flash_attn_interface` should be imported directly rather than from the
`flash_attn_3` package; otherwise an import error is raised.

Signed-off-by: Hollow Man
---
 .../attention/dot_product_attention/backends.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index b3b7630df3..6aa3e6fe8a 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -115,15 +115,15 @@
     flash_attn_with_kvcache_v3 = None
     # pass # only print warning if use_flash_attention_3 = True in get_attention_backend
 else:
-    from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
-    from flash_attn_3.flash_attn_interface import (
+    from flash_attn_interface import flash_attn_func as flash_attn_func_v3
+    from flash_attn_interface import (
         flash_attn_varlen_func as flash_attn_varlen_func_v3,
     )
-    from flash_attn_3.flash_attn_interface import (
+    from flash_attn_interface import (
         flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
     )
-    from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
-    from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
+    from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
+    from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
     fa_utils.set_flash_attention_3_params()
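
Note (not part of the patch): a minimal standalone sketch of the corrected import
pattern, using the same aliases as the diff above. The try/except guard is an
assumed stand-in for Transformer Engine's own Flash Attention 3 availability
check, not the actual logic in backends.py; it simply falls back to None
placeholders when the package is absent.

    # Sketch only: the try/except below is an assumed guard, not Transformer
    # Engine's actual version/availability check.
    try:
        # Flash Attention 3 exposes its Python API as `flash_attn_interface`,
        # imported directly rather than via a `flash_attn_3` package.
        from flash_attn_interface import flash_attn_func as flash_attn_func_v3
        from flash_attn_interface import (
            flash_attn_varlen_func as flash_attn_varlen_func_v3,
        )
        from flash_attn_interface import (
            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
        )
        from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
        from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
    except ImportError:
        # Flash Attention 3 is not installed; leave the hooks unset so callers
        # can detect the missing backend.
        flash_attn_func_v3 = None
        flash_attn_varlen_func_v3 = None
        flash_attn_with_kvcache_v3 = None
        _flash_attn_fwd_v3 = None
        _flash_attn_bwd_v3 = None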