Commit a995c14

extract mm encoder attention as custom op.
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: tjtanaa <[email protected]>
Signed-off-by: shen-shanshan <[email protected]>
1 parent 65ee972 commit a995c14

File tree

16 files changed: +559 -279 lines changed


vllm/attention/layer.py

Lines changed: 24 additions & 51 deletions
@@ -57,42 +57,29 @@
 
 
 def maybe_get_vit_flash_attn_backend(
-    attn_backend: AttentionBackendEnum,
-    attn_backend_override: AttentionBackendEnum | None = None,
-) -> tuple[AttentionBackendEnum, Callable | None]:
-    if current_platform.is_rocm():
-        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
-            attn_backend = AttentionBackendEnum.ROCM_AITER_FA
-        elif (
-            attn_backend_override is None
-            and on_gfx9()
-            and attn_backend == AttentionBackendEnum.FLASH_ATTN
-        ):
-            pass
-        else:
-            return AttentionBackendEnum.TORCH_SDPA, None
-    elif current_platform.is_cuda():
-        pass
-    elif current_platform.is_xpu():
-        assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
-            "XPU platform only supports FLASH_ATTN as vision attention backend."
-        )
-        pass
-    else:
-        return AttentionBackendEnum.TORCH_SDPA, None
-
-    if attn_backend in {
-        AttentionBackendEnum.FLASH_ATTN,
-        AttentionBackendEnum.ROCM_AITER_FA,
-    }:
-        if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
-            from aiter import flash_attn_varlen_func
-        else:
-            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+    attn_backend: AttentionBackendEnum | None,
+) -> Callable | None:
+    # At this point, we already have the attn_backend.
+    # The overriding logic is done in the platform-specific
+    # implementation, so we don't need to override the backend here.
+    # Just return the matching flash_attn_varlen_func
+    # (or None if the backend does not provide one).
+
+    if (
+        attn_backend == AttentionBackendEnum.FLASH_ATTN
+        and current_platform.is_cuda_alike()
+    ):
+        from flash_attn import flash_attn_varlen_func
+    elif attn_backend == AttentionBackendEnum.FLASH_ATTN and current_platform.is_xpu():
+        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+    elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
+        from aiter import flash_attn_varlen_func
     else:
         flash_attn_varlen_func = None
 
-    return attn_backend, flash_attn_varlen_func
+    # If attn_backend is TORCH_SDPA, execution reaches this point and
+    # flash_attn_varlen_func will be None.
+    return flash_attn_varlen_func
 
 
 def _init_kv_cache_quant(
@@ -467,29 +454,15 @@ def __init__(
         attn_backend_override = None
         if multimodal_config is not None:
             attn_backend_override = multimodal_config.mm_encoder_attn_backend
-        backend = get_vit_attn_backend(
+
+        self.backend = get_vit_attn_backend(
             head_size=head_size,
             dtype=dtype,
             attn_backend_override=attn_backend_override,
         )
 
-        self.attn_backend = (
-            backend
-            if backend
-            in {
-                AttentionBackendEnum.TORCH_SDPA,
-                AttentionBackendEnum.PALLAS,
-                AttentionBackendEnum.ROCM_AITER_FA,
-                AttentionBackendEnum.FLASH_ATTN,
-            }
-            else AttentionBackendEnum.TORCH_SDPA
-        )
-
-        self.attn_backend, self._flash_attn_varlen_func = (
-            maybe_get_vit_flash_attn_backend(
-                self.attn_backend,
-                attn_backend_override=attn_backend_override,
-            )
+        self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
+            self.attn_backend,
         )
 
         self.is_flash_attn_backend = self.attn_backend in {
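
For orientation, the sketch below shows how a vision attention layer typically consumes the two attributes wired up above: the boolean flag selects the code path, and the stored kernel handle is called on packed variable-length tensors. This is a generic illustration under stated assumptions, not code from this commit; the unpack/repack helpers are hypothetical placeholders, while the keyword names follow the upstream flash_attn varlen API.

# Hypothetical consumer (not part of this commit) of the attributes set in
# __init__ above: self.is_flash_attn_backend and self._flash_attn_varlen_func.
import torch.nn.functional as F


def _vit_attention(self, q, k, v, cu_seqlens, max_seqlen):
    if self.is_flash_attn_backend and self._flash_attn_varlen_func is not None:
        # Varlen flash attention on packed (total_tokens, num_heads, head_dim)
        # tensors; keyword names match flash_attn's flash_attn_varlen_func.
        return self._flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        )
    # Otherwise fall back to torch SDPA. Real code must first convert the
    # packed varlen layout into batched (B, H, L, D) tensors; the helpers
    # below are hypothetical placeholders for that conversion.
    qb, kb, vb = unpack_to_batched(q, k, v, cu_seqlens)
    out = F.scaled_dot_product_attention(qb, kb, vb)
    return repack_to_varlen(out, cu_seqlens)
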
