|
57 | 57 |
|
58 | 58 |
|
59 | 59 | def maybe_get_vit_flash_attn_backend( |
60 | | -     attn_backend: AttentionBackendEnum,
61 | | -     attn_backend_override: AttentionBackendEnum | None = None,
62 | | - ) -> tuple[AttentionBackendEnum, Callable | None]:
63 | | -     if current_platform.is_rocm():
64 | | -         if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
65 | | -             attn_backend = AttentionBackendEnum.ROCM_AITER_FA
66 | | -         elif (
67 | | -             attn_backend_override is None
68 | | -             and on_gfx9()
69 | | -             and attn_backend == AttentionBackendEnum.FLASH_ATTN
70 | | -         ):
71 | | -             pass
72 | | -         else:
73 | | -             return AttentionBackendEnum.TORCH_SDPA, None
74 | | -     elif current_platform.is_cuda():
75 | | -         pass
76 | | -     elif current_platform.is_xpu():
77 | | -         assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
78 | | -             "XPU platform only supports FLASH_ATTN as vision attention backend."
79 | | -         )
80 | | -         pass
81 | | -     else:
82 | | -         return AttentionBackendEnum.TORCH_SDPA, None
83 | | -
84 | | -     if attn_backend in {
85 | | -         AttentionBackendEnum.FLASH_ATTN,
86 | | -         AttentionBackendEnum.ROCM_AITER_FA,
87 | | -     }:
88 | | -         if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
89 | | -             from aiter import flash_attn_varlen_func
90 | | -         else:
91 | | -             from vllm.attention.utils.fa_utils import flash_attn_varlen_func
| 60 | +     attn_backend: AttentionBackendEnum | None,
| 61 | + ) -> Callable | None:
| 62 | +     # At this point the attention backend has already been resolved;
| 63 | +     # any overriding is handled by the platform-specific implementation,
| 64 | +     # so there is no need to override the backend here.
| 65 | +     # This helper only maps the resolved backend to the matching
| 66 | +     # flash_attn_varlen_func (or None if no flash kernel applies).
| 67 | + |
| 68 | +     if (
| 69 | +         attn_backend == AttentionBackendEnum.FLASH_ATTN
| 70 | +         and current_platform.is_cuda_alike()
| 71 | +     ):
| 72 | +         from flash_attn import flash_attn_varlen_func
| 73 | +     elif attn_backend == AttentionBackendEnum.FLASH_ATTN and current_platform.is_xpu():
| 74 | +         from vllm.attention.utils.fa_utils import flash_attn_varlen_func
| 75 | +     elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
| 76 | +         from aiter import flash_attn_varlen_func
92 | 77 |     else:
93 | 78 |         flash_attn_varlen_func = None
94 | 79 |
|
95 | | -     return attn_backend, flash_attn_varlen_func
| 80 | +     # For TORCH_SDPA (or any other backend without a varlen flash
| 81 | +     # attention kernel), flash_attn_varlen_func remains None.
| 82 | +     return flash_attn_varlen_func
96 | 83 |
|
97 | 84 |
|
98 | 85 | def _init_kv_cache_quant( |
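
A hedged sketch of the new division of labor, for orientation: the backend is resolved upstream (including any override, see the `__init__` hunk below) and `maybe_get_vit_flash_attn_backend` is now a pure lookup. The wrapper function and both import paths below are illustrative assumptions, not something stated in this PR.

```python
from collections.abc import Callable

import torch

# Assumed import paths, for illustration only.
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.attention.selector import get_vit_attn_backend


def resolve_vit_attention(
    head_size: int,
    dtype: torch.dtype,
    attn_backend_override=None,
) -> Callable | None:
    # Step 1: resolve the backend, applying any per-model override. The
    # platform-specific selection (CUDA / ROCm / XPU / TPU) happens here now.
    attn_backend = get_vit_attn_backend(
        head_size=head_size,
        dtype=dtype,
        attn_backend_override=attn_backend_override,
    )
    # Step 2: the helper never changes attn_backend; it returns the varlen
    # flash-attention kernel for FLASH_ATTN / ROCM_AITER_FA, or None
    # otherwise (e.g. TORCH_SDPA).
    return maybe_get_vit_flash_attn_backend(attn_backend)
```
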
@@ -467,29 +454,15 @@ def __init__( |
467 | 454 |         attn_backend_override = None
468 | 455 |         if multimodal_config is not None:
469 | 456 |             attn_backend_override = multimodal_config.mm_encoder_attn_backend
470 | | -         backend = get_vit_attn_backend(
| 457 | + |
| 458 | +         self.attn_backend = get_vit_attn_backend(
471 | 459 |             head_size=head_size,
472 | 460 |             dtype=dtype,
473 | 461 |             attn_backend_override=attn_backend_override,
474 | 462 |         )
475 | 463 |
|
476 | | -         self.attn_backend = (
477 | | -             backend
478 | | -             if backend
479 | | -             in {
480 | | -                 AttentionBackendEnum.TORCH_SDPA,
481 | | -                 AttentionBackendEnum.PALLAS,
482 | | -                 AttentionBackendEnum.ROCM_AITER_FA,
483 | | -                 AttentionBackendEnum.FLASH_ATTN,
484 | | -             }
485 | | -             else AttentionBackendEnum.TORCH_SDPA
486 | | -         )
487 | | - |
488 | | -         self.attn_backend, self._flash_attn_varlen_func = (
489 | | -             maybe_get_vit_flash_attn_backend(
490 | | -                 self.attn_backend,
491 | | -                 attn_backend_override=attn_backend_override,
492 | | -             )
| 464 | +         self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
| 465 | +             self.attn_backend,
493 | 466 |         )
494 | 467 |
|
495 | 468 |         self.is_flash_attn_backend = self.attn_backend in {
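
To show how these `__init__` fields are consumed at attention time, here is a hedged sketch of the dispatch. The function, the packed `(total_tokens, num_heads, head_dim)` layout, and the per-sequence SDPA fallback are illustrative assumptions; the keyword arguments passed to `flash_fn` follow the upstream `flash_attn` varlen signature.

```python
import torch
import torch.nn.functional as F


def vit_attention(q, k, v, cu_seqlens, max_seqlen, flash_fn=None):
    # Hypothetical dispatch: q/k/v are packed as (total_tokens, heads, dim),
    # cu_seqlens holds cumulative sequence lengths, and flash_fn is whatever
    # maybe_get_vit_flash_attn_backend() returned (None for TORCH_SDPA).
    if flash_fn is not None:
        # The varlen flash-attention kernel consumes the packed layout directly.
        return flash_fn(
            q, k, v,
            cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        )
    # SDPA fallback: attend within each sequence separately, since SDPA has
    # no notion of cu_seqlens.
    outputs = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        # (tokens, heads, dim) -> (1, heads, tokens, dim) as expected by SDPA.
        qi, ki, vi = (t[start:end].transpose(0, 1).unsqueeze(0) for t in (q, k, v))
        out = F.scaled_dot_product_attention(qi, ki, vi)
        outputs.append(out.squeeze(0).transpose(0, 1))
    return torch.cat(outputs, dim=0)
```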
|