3 changes: 1 addition & 2 deletions .buildkite/scripts/hardware_ci/run-amd-test.sh
@@ -107,10 +107,9 @@ fi

if [[ $commands == *" kernels/attention"* ]]; then
commands="${commands} \
--ignore=kernels/attention/stest_attention_selector.py \
--ignore=kernels/attention/test_attention_selector.py \
--ignore=kernels/attention/test_blocksparse_attention.py \
--ignore=kernels/attention/test_encoder_decoder_attn.py \
--ignore=kernels/attention/test_attention_selector.py \
--ignore=kernels/attention/test_flash_attn.py \
--ignore=kernels/attention/test_flashinfer.py \
--ignore=kernels/attention/test_prefix_prefill.py \
10 changes: 2 additions & 8 deletions docs/models/supported_models.md
@@ -626,9 +626,6 @@ Specified using `--task generate`.
!!! note
Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.

!!! note
`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80.

!!! note
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.

@@ -671,11 +668,8 @@ Specified using `--task generate`.
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.

!!! note
To use Qwen2.5-Omni, you have to install Hugging Face Transformers library from source via
`pip install git+https://github.com/huggingface/transformers.git`.

Read audio from video pre-processing is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
`--mm-processor-kwargs '{"use_audio_in_video": true}'`.
For Qwen2.5-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`)
is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
Review comment (Member Author): Fix the out-of-date documentation


#### Transcription

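For reference, a minimal offline sketch of the reworded Qwen2.5-Omni note above. The model id, `limit_mm_per_prompt` value, and use of the `VLLM_USE_V1` environment variable are illustrative assumptions, not part of this PR; the kwarg mirrors the documented `--mm-processor-kwargs '{"use_audio_in_video": true}'` CLI flag through the Python API.

```python
import os

# Overlapping modalities (audio decoded from the video track) only work on
# V0 per the note above, so pin the V0 engine before importing vLLM.
os.environ["VLLM_USE_V1"] = "0"

from vllm import LLM  # noqa: E402

# Model id and limits are illustrative; the kwarg below is the Python-API
# twin of --mm-processor-kwargs '{"use_audio_in_video": true}'.
llm = LLM(
    model="Qwen/Qwen2.5-Omni-7B",
    mm_processor_kwargs={"use_audio_in_video": True},
    limit_mm_per_prompt={"video": 1},
)
```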
4 changes: 2 additions & 2 deletions examples/offline_inference/vision_language.py
@@ -98,7 +98,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompts = [f"Question: {question} Answer:" for question in questions]
engine_args = EngineArgs(
model="Salesforce/blip2-opt-6.7b",
model="Salesforce/blip2-opt-2.7b",
limit_mm_per_prompt={modality: 1},
)

@@ -971,7 +971,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
)


# Qwen
# Qwen-VL
def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"

13 changes: 12 additions & 1 deletion tests/kernels/attention/test_attention_selector.py
@@ -171,7 +171,7 @@ def test_env(
expected = "FLASHINFER_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
else:
backend = get_attn_backend(16,
backend = get_attn_backend(32,
torch.float16,
torch.float16,
block_size,
@@ -180,6 +180,17 @@
expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected

if use_v1:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
assert backend.get_name() == "FLEX_ATTENTION", (
"Should fallback to FlexAttention if head size is "
"not supported by FlashAttention")


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("use_v1", [True, False])
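A rough interactive analogue of the new assertion above, assuming a CUDA-capable machine with the V1 engine enabled (e.g. `VLLM_USE_V1=1`); head size 80 is chosen because it is the case this PR targets and is outside FlashAttention's supported set.

```python
import torch

from vllm.attention.selector import get_attn_backend

# Head size 80 (e.g. BLIP-2 OPT-2.7B, StableLM) is not supported by
# FlashAttention, so on V1 the selector should now fall back to
# FlexAttention instead of raising.
backend = get_attn_backend(
    80,             # head_size
    torch.float16,  # dtype
    torch.float16,  # kv_cache_dtype (mirrors the test above)
    16,             # block_size
    False,          # is_attention_free
)
print(backend.get_name())  # expected: "FLEX_ATTENTION" on the V1 engine
```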
9 changes: 2 additions & 7 deletions tests/models/multimodal/generation/test_common.py
@@ -33,9 +33,6 @@
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"

REQUIRES_V0_MODELS = [
# V1 Test: no way to fall back for head_dim = 80
# https://github.com/vllm-project/vllm/issues/14524
"qwen_vl",
Review comment (Member Author): This model actually has head size 128

# V1 Test: not enough KV cache space in C1.
"fuyu",
]
@@ -221,8 +218,7 @@
marks=[large_gpu_mark(min_gb=32)],
),
"blip2": VLMTestInfo(
# TODO: Change back to 2.7b once head_dim = 80 is supported
models=["Salesforce/blip2-opt-6.7b"],
models=["Salesforce/blip2-opt-2.7b"],
test_type=VLMTestType.IMAGE,
prompt_formatter=lambda img_prompt: f"Question: {img_prompt} Answer:",
img_idx_to_prompt=lambda idx: "",
@@ -340,8 +336,7 @@
"h2ovl": VLMTestInfo(
models = [
"h2oai/h2ovl-mississippi-800m",
# TODO: Re-enable once head_dim = 80 is supported
# "h2oai/h2ovl-mississippi-2b",
"h2oai/h2ovl-mississippi-2b",
],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501
2 changes: 1 addition & 1 deletion tests/models/quantization/test_gguf.py
@@ -83,7 +83,7 @@ def gguf_model(self):
QWEN2_CONFIG,
PHI3_CONFIG,
GPT2_CONFIG,
# STABLELM_CONFIG, # enable this when v1 support head_size=80
STABLELM_CONFIG,
DOLPHIN_CONFIG,
# STARCODER_CONFIG, # broken
]
15 changes: 6 additions & 9 deletions tests/models/registry.py
@@ -240,8 +240,9 @@ def check_available_online(
"OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
trust_remote_code=True),
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2", v0_only=True),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
# Blocksparse attention not supported in V1 yet
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
trust_remote_code=True,
v0_only=True),
@@ -258,10 +259,8 @@
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
v0_only=True),
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t",
v0_only=True),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
@@ -330,8 +329,7 @@ def check_available_online(
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501
extras={"6b": "Salesforce/blip2-opt-6.7b"}, # noqa: E501
v0_only=True),
extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
@@ -359,8 +357,7 @@
trust_remote_code=True),
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
trust_remote_code=True,
v0_only=True),
trust_remote_code=True),
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
max_model_len=10240),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
3 changes: 2 additions & 1 deletion tests/models/test_initialization.py
@@ -22,7 +22,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
model_info.check_transformers_version(on_fail="skip")

# FIXME: Possible memory leak in the previous tests?
if model_arch == "GraniteSpeechForConditionalGeneration":
if model_arch in ("GraniteSpeechForConditionalGeneration",
"KimiVLForConditionalGeneration"):
pytest.skip("Avoid OOM")

# Avoid OOM and reduce initialization time by only using 1 layer
3 changes: 2 additions & 1 deletion vllm/attention/layer.py
@@ -310,7 +310,8 @@ def __init__(
# currently, only torch_sdpa is supported on rocm
self.attn_backend = _Backend.TORCH_SDPA
else:
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
_Backend.FLEX_ATTENTION):
backend = _Backend.XFORMERS

self.attn_backend = backend if backend in {
33 changes: 30 additions & 3 deletions vllm/attention/selector.py
@@ -4,7 +4,7 @@
import os
from contextlib import contextmanager
from functools import cache
from typing import Generator, Optional, Type
from typing import Generator, Optional, Union

import torch

@@ -79,6 +79,33 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
return forced_attn_backend


def supports_head_size(
attn_backend: Union[str, type[AttentionBackend]],
head_size: int,
) -> bool:
if isinstance(attn_backend, str):
try:
attn_backend = resolve_obj_by_qualname(attn_backend)
except ImportError:
return False

assert isinstance(attn_backend, type)

# TODO: Update the interface once V0 is removed
if get_supported_head_sizes := getattr(attn_backend,
"get_supported_head_sizes", None):
return head_size in get_supported_head_sizes()
if validate_head_size := getattr(attn_backend, "validate_head_size", None):
try:
validate_head_size(head_size)
return True
except Exception:
return False

raise NotImplementedError(f"{attn_backend.__name__} does not support "
"head size validation")


def get_attn_backend(
head_size: int,
dtype: torch.dtype,
@@ -87,7 +114,7 @@ def get_attn_backend(
is_attention_free: bool,
is_blocksparse: bool = False,
use_mla: bool = False,
) -> Type[AttentionBackend]:
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
# value to be returned from the cache if the value changes between calls.
@@ -115,7 +142,7 @@ def _cached_get_attn_backend(
is_blocksparse: bool = False,
use_v1: bool = False,
use_mla: bool = False,
) -> Type[AttentionBackend]:
) -> type[AttentionBackend]:
if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.")
from vllm.attention.backends.blocksparse_attn import (
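A small usage sketch of the new `supports_head_size` helper; the backend qualname matches the `FLASH_ATTN_V1` string used in `vllm/platforms/cuda.py` below, and head size 80 is the case this PR targets.

```python
from vllm.attention.selector import supports_head_size

FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"

# FlashAttention advertises a fixed set of head sizes; 80 is not among them,
# so the platform code can fall back to FlexAttention instead of failing.
if not supports_head_size(FLASH_ATTN_V1, 80):
    print("head size 80 unsupported by FlashAttention -> use FlexAttention")
```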
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -2319,7 +2319,7 @@ def _verify_args(self) -> Self:

if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
logger.warning(
"max_num_batched_tokens (%d) exceeds max_num_seqs"
"max_num_batched_tokens (%d) exceeds max_num_seqs "
"* max_model_len (%d). This may lead to unexpected behavior.",
self.max_num_batched_tokens,
self.max_num_seqs * self.max_model_len)
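A worked instance of the (now correctly spaced) warning condition, with made-up numbers.

```python
# Illustrative values only: 8 sequences * 1024 tokens = 8192 schedulable
# tokens, so a max_num_batched_tokens of 16384 can never be reached in a
# single batch and would trigger the warning above.
max_num_seqs = 8
max_model_len = 1024
max_num_batched_tokens = 16384

assert max_num_batched_tokens > max_num_seqs * max_model_len  # 16384 > 8192
```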
44 changes: 28 additions & 16 deletions vllm/platforms/cuda.py
@@ -234,46 +234,58 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
return ("vllm.attention.backends."
"flashmla.FlashMLABackend")
if use_v1:
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501

if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend on V1 engine.")
return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
return FLASHINFER_V1
elif selected_backend == _Backend.FLEX_ATTENTION:
logger.info("Using FlexAttenion backend on V1 engine.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
logger.info_once("Using FlexAttention backend on V1 engine.")
return FLEX_ATTENTION_V1
elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
return TRITON_ATTN_VLLM_V1
elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend")
return FLASH_ATTN_V1

from vllm.attention.selector import supports_head_size

# Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed
# FP32 is only supported by FlexAttention
if dtype not in (torch.float16, torch.bfloat16):
logger.info_once(
f"Using FlexAttenion backend for {dtype} on V1 engine.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
if cls.is_device_capability(100):
"Using FlexAttention backend for %s on V1 engine.",
dtype,
)
return FLEX_ATTENTION_V1

# Prefer FlashInfer for Blackwell GPUs if installed
if cls.is_device_capability(100) and \
supports_head_size(FLASHINFER_V1, head_size):
try:
import flashinfer # noqa: F401
logger.info_once(
"Using FlashInfer backend on V1 engine by default for "
"Blackwell (SM 10.0) GPUs.")
return ("vllm.v1.attention.backends."
"flashinfer.FlashInferBackend")
return FLASHINFER_V1
except ImportError:
logger.info_once(
"FlashInfer failed to import for V1 engine on "
"Blackwell (SM 10.0) GPUs; it is recommended to "
"install FlashInfer for better performance.")
pass
# FlashAttention is the default for SM 8.0+ GPUs
if cls.has_device_capability(80):
if cls.has_device_capability(80) and \
supports_head_size(FLASH_ATTN_V1, head_size):
logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend")
return FLASH_ATTN_V1

logger.info_once("Using FlexAttention backend on V1 engine.")
return FLEX_ATTENTION_V1

# Backends for V0 engine
if selected_backend == _Backend.FLASHINFER:
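The fallback above only applies when no backend is forced. A sketch of pinning the backend explicitly; the model choice is an illustrative assumption (`stabilityai/stablelm-3b-4e1t` uses head size 80, one of the sizes this PR re-enables on V1), and `VLLM_ATTENTION_BACKEND` is the existing override knob that the TorchSDPA error message below also points at.

```python
import os

# Bypass the capability/head-size selection logic in get_attn_backend_cls
# by forcing FlexAttention explicitly.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"

from vllm import LLM  # noqa: E402

# Illustrative model with head size 80, re-enabled on V1 by this PR.
llm = LLM(model="stabilityai/stablelm-3b-4e1t", max_model_len=2048)
```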
20 changes: 18 additions & 2 deletions vllm/v1/attention/backends/cpu_attn.py
@@ -3,7 +3,8 @@
import numpy as np
import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
TorchSDPAMetadata)
from vllm.attention.backends.utils import CommonAttentionState
@@ -17,9 +18,24 @@
from vllm.v1.worker.gpu_input_batch import InputBatch


class TorchSDPABackend:
class TorchSDPABackend(AttentionBackend):
accept_output_buffer: bool = False

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return PagedAttention.get_supported_head_sizes()

@classmethod
def validate_head_size(cls, head_size: int) -> None:
supported_head_sizes = cls.get_supported_head_sizes()
if head_size not in supported_head_sizes:
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {supported_head_sizes}. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")

@staticmethod
def get_name() -> str:
return "TORCH_SDPA_VLLM_V1"
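A quick illustration of the new validation hook on the CPU backend; head size 40 is assumed to be outside `PagedAttention.get_supported_head_sizes()` for this sketch.

```python
from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend

# 40 is assumed not to be in PagedAttention's supported list; the raised
# message points users at the FlexAttention escape hatch.
try:
    TorchSDPABackend.validate_head_size(40)
except ValueError as err:
    print(err)  # suggests VLLM_ATTENTION_BACKEND=FLEX_ATTENTION
```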