Commit f9cc372

refactoring backends(pt3)
Signed-off-by: vnadathur <[email protected]>
1 parent 024bf5f commit f9cc372

File tree: 12 files changed, +114 −147 lines


tests/v1/attention/test_mla_backends.py

Lines changed: 1 addition & 1 deletion
@@ -829,4 +829,4 @@ def test_backend_correctness(
 
         summary = f"{len(failures)} backend(s) failed: {', '.join(backend_names)}"
         detailed_msg = "\n".join(failures)
-        pytest.fail(f"{summary}\n{detailed_msg}")
+        pytest.fail(f"{summary}\n{detailed_msg}")
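The hunk above only touches the final pytest.fail call, but the surrounding pattern is a collect-then-fail loop over all backends. Below is a hedged, self-contained sketch of that pattern; the helper name check_all_backends and the results dict are illustrative, not part of the vLLM test itself.

    from __future__ import annotations

    import pytest

    def check_all_backends(results: dict[str, Exception | None]) -> None:
        # Collect one human-readable entry per failing backend.
        failures = [f"{name}: {err}" for name, err in results.items() if err is not None]
        if failures:
            backend_names = [entry.split(":", 1)[0] for entry in failures]
            # Fail once at the end so a single run reports every broken backend.
            summary = f"{len(failures)} backend(s) failed: {', '.join(backend_names)}"
            detailed_msg = "\n".join(failures)
            pytest.fail(f"{summary}\n{detailed_msg}")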

tests/v1/attention/test_sparse_mla_backends.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
 from vllm import _custom_ops as ops
 from vllm.attention.ops import flashmla
 from vllm.model_executor.layers.linear import ColumnParallelLinear
-from vllm.utils import cdiv
+from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend
 from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
 

@@ -389,4 +389,4 @@ def test_sparse_backend_decode_correctness(
 )
 def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
     out = split_prefill_chunks(seq_lens, max_buf, start)
-    assert out == expected
+    assert out == expected
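The only substantive change here is the import path for cdiv, which now lives under vllm.utils.math_utils. For readers unfamiliar with the helper, this is a minimal sketch of a ceiling-division function with the same call shape; the exact upstream implementation may differ.

    def cdiv(a: int, b: int) -> int:
        # Ceiling division: smallest integer >= a / b for positive b,
        # e.g. the number of fixed-size pages needed to hold `a` tokens.
        return -(a // -b)

    assert cdiv(7, 4) == 2
    assert cdiv(8, 4) == 2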

vllm/attention/backends/abstract.py

Lines changed: 1 addition & 1 deletion
@@ -255,4 +255,4 @@ def forward(
 
 
 def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
-    return kv_cache_dtype != "auto"
+    return kv_cache_dtype != "auto"
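The visible change is whitespace-only; the helper itself simply treats any KV-cache dtype string other than "auto" as quantized. A hedged usage sketch (the "fp8" value below is illustrative):

    def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
        return kv_cache_dtype != "auto"

    # "auto" means "match the model dtype", so it is the only unquantized case.
    assert is_quantized_kv_cache("auto") is False
    assert is_quantized_kv_cache("fp8") is True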

vllm/attention/layer.py

Lines changed: 17 additions & 26 deletions
@@ -104,19 +104,6 @@ def maybe_get_vit_flash_attn_backend(
 
     if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
         use_upstream_fa = True
-    elif current_platform.is_cuda():
-        if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
-            torch.get_default_dtype()
-        ):
-            attn_backend = _Backend.FLASH_ATTN
-            use_upstream_fa = True
-    elif current_platform.is_xpu():
-        assert attn_backend == _Backend.FLASH_ATTN, (
-            "XPU platform only supports FLASH_ATTN as vision attention backend."
-        )
-        use_upstream_fa = False
-    else:
-        return _Backend.TORCH_SDPA, None
 
     if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
         if attn_backend == _Backend.ROCM_AITER_FA:
@@ -125,7 +112,7 @@ def maybe_get_vit_flash_attn_backend(
             if use_upstream_fa:
                 from flash_attn import flash_attn_varlen_func
             else:
-                from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
     else:
         flash_attn_varlen_func = None
 
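The functional change in this hunk is the fallback import: when upstream flash-attn is not requested, the varlen kernel now comes from vllm.vllm_flash_attn instead of vllm.attention.utils.fa_utils. A hedged sketch of that selection, factored into a helper purely for illustration (the real code runs inline in maybe_get_vit_flash_attn_backend):

    def resolve_flash_attn_varlen_func(use_upstream_fa: bool):
        # Illustrative helper mirroring the import choice in the diff above.
        if use_upstream_fa:
            # Upstream flash-attn package.
            from flash_attn import flash_attn_varlen_func
        else:
            # vLLM's bundled flash-attention build (new import path in this commit).
            from vllm.vllm_flash_attn import flash_attn_varlen_func
        return flash_attn_varlen_func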
@@ -482,18 +469,22 @@ def __init__(
         # If vllm native fa is selected, we use it directly.
         use_upstream_fa = False
 
-        self.attn_backend = (
-            backend
-            if backend
-            in {
-                _Backend.TORCH_SDPA,
-                _Backend.XFORMERS,
-                _Backend.PALLAS,
-                _Backend.ROCM_AITER_FA,
-                _Backend.FLASH_ATTN,
-            }
-            else _Backend.TORCH_SDPA
-        )
+        if current_platform.is_xpu():
+            # currently, only torch_sdpa is supported on xpu
+            self.attn_backend = _Backend.TORCH_SDPA
+        else:
+            self.attn_backend = (
+                backend
+                if backend
+                in {
+                    _Backend.TORCH_SDPA,
+                    _Backend.XFORMERS,
+                    _Backend.PALLAS,
+                    _Backend.ROCM_AITER_FA,
+                    _Backend.FLASH_ATTN,
+                }
+                else _Backend.TORCH_SDPA
+            )
 
         self.attn_backend, self._flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(