File tree Expand file tree Collapse file tree 1 file changed +9
-8
lines changed
Expand file tree Collapse file tree 1 file changed +9
-8
lines changed Original file line number Diff line number Diff line change @@ -264,14 +264,15 @@ def get_vit_attn_backend(
264264 cls , head_size : int , dtype : torch .dtype
265265 ) -> "AttentionBackendEnum" :
266266 # Try FlashAttention first
267- try :
268- backend_class = AttentionBackendEnum .FLASH_ATTN .get_class ()
269- if backend_class .supports_head_size (
270- head_size
271- ) and backend_class .supports_dtype (dtype ):
272- return AttentionBackendEnum .FLASH_ATTN
273- except ImportError :
274- pass
267+ if (cc := cls .get_device_capability ()) and cc .major >= 8 :
268+ try :
269+ backend_class = AttentionBackendEnum .FLASH_ATTN .get_class ()
270+ if backend_class .supports_head_size (
271+ head_size
272+ ) and backend_class .supports_dtype (dtype ):
273+ return AttentionBackendEnum .FLASH_ATTN
274+ except ImportError :
275+ pass
275276
276277 return AttentionBackendEnum .TORCH_SDPA
277278
You can’t perform that action at this time.
0 commit comments