@@ -164,12 +164,12 @@ def __new__(cls):
         return cls._instance
 
     def initialize(self):
-        self.packages_info = {
-            "has_aiter": self.check_aiter(),
-            "has_flash_attn": self.check_flash_attn(),
-            "has_long_ctx_attn": self.check_long_ctx_attn(),
-            "diffusers_version": self.check_diffusers_version(),
-        }
+        packages_info = {}
+        packages_info["has_aiter"] = self.check_aiter()
+        packages_info["has_flash_attn"] = self.check_flash_attn(packages_info)
+        packages_info["has_long_ctx_attn"] = self.check_long_ctx_attn()
+        packages_info["diffusers_version"] = self.check_diffusers_version()
+        self.packages_info = packages_info
 
     def check_aiter(self):
         """
@@ -188,7 +188,7 @@ def check_aiter(self):
         return False
 
 
-    def check_flash_attn(self):
+    def check_flash_attn(self, packages_info):
         if not torch.cuda.is_available():
             return False
         if _is_musa():
@@ -209,10 +209,11 @@ def check_flash_attn(self):
                 raise ImportError(f"install flash_attn >= 2.6.0")
             return True
         except ImportError:
-            logger.warning(
-                f'Flash Attention library "flash_attn" not found, '
-                f"using pytorch attention implementation"
-            )
+            if not packages_info.get("has_aiter", False):
+                logger.warning(
+                    f'Flash Attention library "flash_attn" not found, '
+                    f"using pytorch attention implementation"
+                )
             return False
 
     def check_long_ctx_attn(self):
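
Read together, the hunks make `check_aiter()` run before `check_flash_attn()`, which now receives the partially built `packages_info` dict, so the "flash_attn not found" warning is emitted only when the aiter backend is also unavailable. Below is a minimal, self-contained sketch of that flow; the class name `PackagesChecker` and the stub import checks are assumptions for illustration, and the real method's CUDA/MUSA and version checks are omitted.

```python
import logging

logger = logging.getLogger(__name__)


class PackagesChecker:
    """Sketch of the new check order: check_aiter() fills packages_info["has_aiter"]
    first, so check_flash_attn() can skip the fallback warning when aiter is usable."""

    def initialize(self):
        packages_info = {}
        packages_info["has_aiter"] = self.check_aiter()
        # Pass the partially built dict so the flash_attn check can see
        # whether aiter was already found.
        packages_info["has_flash_attn"] = self.check_flash_attn(packages_info)
        self.packages_info = packages_info

    def check_aiter(self):
        try:
            import aiter  # noqa: F401
            return True
        except ImportError:
            return False

    def check_flash_attn(self, packages_info):
        try:
            import flash_attn  # noqa: F401
            return True
        except ImportError:
            # Only warn about the pytorch fallback when aiter is missing too.
            if not packages_info.get("has_aiter", False):
                logger.warning(
                    'Flash Attention library "flash_attn" not found, '
                    "using pytorch attention implementation"
                )
            return False


checker = PackagesChecker()
checker.initialize()
print(checker.packages_info)
```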