diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index b3b7630df3..6aa3e6fe8a 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -115,15 +115,15 @@ flash_attn_with_kvcache_v3 = None
 #     pass  # only print warning if use_flash_attention_3 = True in get_attention_backend
 else:
-    from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
-    from flash_attn_3.flash_attn_interface import (
+    from flash_attn_interface import flash_attn_func as flash_attn_func_v3
+    from flash_attn_interface import (
         flash_attn_varlen_func as flash_attn_varlen_func_v3,
     )
-    from flash_attn_3.flash_attn_interface import (
+    from flash_attn_interface import (
         flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
     )
-    from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
-    from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
+    from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
+    from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
     fa_utils.set_flash_attention_3_params()
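
For context (not part of the patch): the hunk only drops the "flash_attn_3." package prefix, so the FlashAttention-3 symbols are now imported from a top-level flash_attn_interface module instead of flash_attn_3.flash_attn_interface. A minimal sketch of an import guard in the same spirit follows; the try/except structure and the fallback-to-None convention here are illustrative assumptions for readers unfamiliar with this pattern, not the file's actual surrounding code.

# Hypothetical sketch: import the FlashAttention-3 entry point from the
# top-level module used by the patch, leaving the symbol as None when the
# package is not installed so callers can check availability before use.
flash_attn_func_v3 = None
try:
    from flash_attn_interface import flash_attn_func as flash_attn_func_v3
except ImportError:
    # Package absent: keep flash_attn_func_v3 = None and let the backend
    # selection logic report that FlashAttention-3 is unavailable.
    pass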