
Commit e585d98

avjves and Aleksi Vesanto authored
Refactor packages info check and gate FA logging (#571)
Co-authored-by: Aleksi Vesanto <[email protected]>
1 parent cd06115 · commit e585d98

1 file changed (+12, −11 lines)

xfuser/envs.py

Lines changed: 12 additions & 11 deletions
@@ -164,12 +164,12 @@ def __new__(cls):
         return cls._instance

     def initialize(self):
-        self.packages_info = {
-            "has_aiter": self.check_aiter(),
-            "has_flash_attn": self.check_flash_attn(),
-            "has_long_ctx_attn": self.check_long_ctx_attn(),
-            "diffusers_version": self.check_diffusers_version(),
-        }
+        packages_info = {}
+        packages_info["has_aiter"] = self.check_aiter()
+        packages_info["has_flash_attn"] = self.check_flash_attn(packages_info)
+        packages_info["has_long_ctx_attn"] = self.check_long_ctx_attn()
+        packages_info["diffusers_version"] = self.check_diffusers_version()
+        self.packages_info = packages_info

     def check_aiter(self):
         """
@@ -188,7 +188,7 @@ def check_aiter(self):
         return False


-    def check_flash_attn(self):
+    def check_flash_attn(self, packages_info):
         if not torch.cuda.is_available():
             return False
         if _is_musa():
@@ -209,10 +209,11 @@ def check_flash_attn(self):
                 raise ImportError(f"install flash_attn >= 2.6.0")
             return True
         except ImportError:
-            logger.warning(
-                f'Flash Attention library "flash_attn" not found, '
-                f"using pytorch attention implementation"
-            )
+            if not packages_info.get("has_aiter", False):
+                logger.warning(
+                    f'Flash Attention library "flash_attn" not found, '
+                    f"using pytorch attention implementation"
+                )
             return False

     def check_long_ctx_attn(self):
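
The gist of the change, as a minimal standalone sketch (plain module-level functions instead of the singleton's methods, and a hypothetical check_aiter body based on import probing, since the diff does not show the real one): packages_info is now built key by key and handed to check_flash_attn, so the "flash_attn not found" warning is only emitted when aiter is also unavailable.

import importlib.util
import logging

logger = logging.getLogger(__name__)


def check_aiter():
    # Hypothetical stand-in: report whether the "aiter" package is importable.
    return importlib.util.find_spec("aiter") is not None


def check_flash_attn(packages_info):
    try:
        import flash_attn  # noqa: F401
        return True
    except ImportError:
        # Gated warning: if aiter is available, falling back from flash_attn
        # is expected, so stay quiet; otherwise warn about the PyTorch fallback.
        if not packages_info.get("has_aiter", False):
            logger.warning(
                'Flash Attention library "flash_attn" not found, '
                "using pytorch attention implementation"
            )
        return False


def initialize():
    # Build the dict incrementally so check_flash_attn can read "has_aiter".
    packages_info = {}
    packages_info["has_aiter"] = check_aiter()
    packages_info["has_flash_attn"] = check_flash_attn(packages_info)
    return packages_info

Note that the gating relies on assignment order: "has_aiter" must be filled in before check_flash_attn(packages_info) runs, otherwise .get("has_aiter", False) would always return False and the warning would fire exactly as before the refactor.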
