diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 6e9af1e721bb..156456c92e63 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -107,10 +107,9 @@ fi if [[ $commands == *" kernels/attention"* ]]; then commands="${commands} \ - --ignore=kernels/attention/stest_attention_selector.py \ + --ignore=kernels/attention/test_attention_selector.py \ --ignore=kernels/attention/test_blocksparse_attention.py \ --ignore=kernels/attention/test_encoder_decoder_attn.py \ - --ignore=kernels/attention/test_attention_selector.py \ --ignore=kernels/attention/test_flash_attn.py \ --ignore=kernels/attention/test_flashinfer.py \ --ignore=kernels/attention/test_prefix_prefill.py \ diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 23d71fd44525..7ec91df98b28 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -626,9 +626,6 @@ Specified using `--task generate`. !!! note Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently. -!!! note - `h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80. - !!! note To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. @@ -671,11 +668,8 @@ Specified using `--task generate`. Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1. !!! note - To use Qwen2.5-Omni, you have to install Hugging Face Transformers library from source via - `pip install git+https://github.com/huggingface/transformers.git`. - - Read audio from video pre-processing is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1. - `--mm-processor-kwargs '{"use_audio_in_video": true}'`. + For Qwen2.5-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) + is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1. 
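For reference, a minimal offline-inference sketch of the flag described in the reworded note above, mirroring the style of `examples/offline_inference/vision_language.py`. The checkpoint name is an assumption, and the V0 engine is forced explicitly since the note states overlapping modalities are not yet supported on V1:

```python
# Minimal sketch (assumed model name); use_audio_in_video only takes effect
# on the V0 engine, per the docs note above.
import os

os.environ["VLLM_USE_V1"] = "0"  # overlapping modalities are V0-only for now

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-Omni-7B",  # assumed checkpoint name
    mm_processor_kwargs={"use_audio_in_video": True},
    limit_mm_per_prompt={"video": 1},
)
```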
#### Transcription diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index bf7be33107da..5bd75a78f2c4 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -98,7 +98,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData: # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa prompts = [f"Question: {question} Answer:" for question in questions] engine_args = EngineArgs( - model="Salesforce/blip2-opt-6.7b", + model="Salesforce/blip2-opt-2.7b", limit_mm_per_prompt={modality: 1}, ) @@ -971,7 +971,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData: ) -# Qwen +# Qwen-VL def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 0437bb8293ce..3722e0eb537f 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -171,7 +171,7 @@ def test_env( expected = "FLASHINFER_VLLM_V1" if use_v1 else name assert backend.get_name() == expected else: - backend = get_attn_backend(16, + backend = get_attn_backend(32, torch.float16, torch.float16, block_size, @@ -180,6 +180,17 @@ def test_env( expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name assert backend.get_name() == expected + if use_v1: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + assert backend.get_name() == "FLEX_ATTENTION", ( + "Should fallback to FlexAttention if head size is " + "not supported by FlashAttention") + @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("use_v1", [True, False]) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 6ecf6db56cb3..cbc2e9c87a64 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -33,9 +33,6 @@ os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" REQUIRES_V0_MODELS = [ - # V1 Test: no way to fall back for head_dim = 80 - # https://github.com/vllm-project/vllm/issues/14524 - "qwen_vl", # V1 Test: not enough KV cache space in C1. 
"fuyu", ] @@ -221,8 +218,7 @@ marks=[large_gpu_mark(min_gb=32)], ), "blip2": VLMTestInfo( - # TODO: Change back to 2.7b once head_dim = 80 is supported - models=["Salesforce/blip2-opt-6.7b"], + models=["Salesforce/blip2-opt-2.7b"], test_type=VLMTestType.IMAGE, prompt_formatter=lambda img_prompt: f"Question: {img_prompt} Answer:", img_idx_to_prompt=lambda idx: "", @@ -340,8 +336,7 @@ "h2ovl": VLMTestInfo( models = [ "h2oai/h2ovl-mississippi-800m", - # TODO: Re-enable once head_dim = 80 is supported - # "h2oai/h2ovl-mississippi-2b", + "h2oai/h2ovl-mississippi-2b", ], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501 diff --git a/tests/models/quantization/test_gguf.py b/tests/models/quantization/test_gguf.py index a424bd6798fd..3e77d3e71039 100644 --- a/tests/models/quantization/test_gguf.py +++ b/tests/models/quantization/test_gguf.py @@ -83,7 +83,7 @@ def gguf_model(self): QWEN2_CONFIG, PHI3_CONFIG, GPT2_CONFIG, - # STABLELM_CONFIG, # enable this when v1 support head_size=80 + STABLELM_CONFIG, DOLPHIN_CONFIG, # STARCODER_CONFIG, # broken ] diff --git a/tests/models/registry.py b/tests/models/registry.py index 728c18643a00..aba01cefe993 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -240,8 +240,9 @@ def check_available_online( "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat", trust_remote_code=True), "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), - "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2", v0_only=True), + "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"), "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"), + # Blocksparse attention not supported in V1 yet "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True, v0_only=True), @@ -258,10 +259,8 @@ def check_available_online( "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501 "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), - "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501 - v0_only=True), - "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t", - v0_only=True), + "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501 + "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"), "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"), "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B", @@ -330,8 +329,7 @@ def check_available_online( "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"), "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501 "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501 - extras={"6b": "Salesforce/blip2-opt-6.7b"}, # noqa: E501 - v0_only=True), + extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501 "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501 "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501 @@ -359,8 +357,7 @@ def check_available_online( trust_remote_code=True), "KimiVLForConditionalGeneration": 
_HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501 extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501 - trust_remote_code=True, - v0_only=True), + trust_remote_code=True), "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 max_model_len=10240), "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 0ac684fdd30d..25bc96bf3266 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -22,7 +22,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): model_info.check_transformers_version(on_fail="skip") # FIXME: Possible memory leak in the previous tests? - if model_arch == "GraniteSpeechForConditionalGeneration": + if model_arch in ("GraniteSpeechForConditionalGeneration", + "KimiVLForConditionalGeneration"): pytest.skip("Avoid OOM") # Avoid OOM and reduce initialization time by only using 1 layer diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 0c79aaf13551..f0ad68b16405 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -310,7 +310,8 @@ def __init__( # currently, only torch_sdpa is supported on rocm self.attn_backend = _Backend.TORCH_SDPA else: - if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}: + if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1, + _Backend.FLEX_ATTENTION): backend = _Backend.XFORMERS self.attn_backend = backend if backend in { diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index cb577fa67302..df14aea729f3 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -4,7 +4,7 @@ import os from contextlib import contextmanager from functools import cache -from typing import Generator, Optional, Type +from typing import Generator, Optional, Union import torch @@ -79,6 +79,33 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: return forced_attn_backend +def supports_head_size( + attn_backend: Union[str, type[AttentionBackend]], + head_size: int, +) -> bool: + if isinstance(attn_backend, str): + try: + attn_backend = resolve_obj_by_qualname(attn_backend) + except ImportError: + return False + + assert isinstance(attn_backend, type) + + # TODO: Update the interface once V0 is removed + if get_supported_head_sizes := getattr(attn_backend, + "get_supported_head_sizes", None): + return head_size in get_supported_head_sizes() + if validate_head_size := getattr(attn_backend, "validate_head_size", None): + try: + validate_head_size(head_size) + return True + except Exception: + return False + + raise NotImplementedError(f"{attn_backend.__name__} does not support " + "head size validation") + + def get_attn_backend( head_size: int, dtype: torch.dtype, @@ -87,7 +114,7 @@ def get_attn_backend( is_attention_free: bool, is_blocksparse: bool = False, use_mla: bool = False, -) -> Type[AttentionBackend]: +) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" # Accessing envs.* behind an @lru_cache decorator can cause the wrong # value to be returned from the cache if the value changes between calls. 
@@ -115,7 +142,7 @@ def _cached_get_attn_backend( is_blocksparse: bool = False, use_v1: bool = False, use_mla: bool = False, -) -> Type[AttentionBackend]: +) -> type[AttentionBackend]: if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") from vllm.attention.backends.blocksparse_attn import ( diff --git a/vllm/config.py b/vllm/config.py index a1d8c32953b0..724f69a3887f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2319,7 +2319,7 @@ def _verify_args(self) -> Self: if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len: logger.warning( - "max_num_batched_tokens (%d) exceeds max_num_seqs" + "max_num_batched_tokens (%d) exceeds max_num_seqs " "* max_model_len (%d). This may lead to unexpected behavior.", self.max_num_batched_tokens, self.max_num_seqs * self.max_model_len) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index f82c1e569977..0a5f4004e448 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -234,35 +234,44 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, return ("vllm.attention.backends." "flashmla.FlashMLABackend") if use_v1: + FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 + FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 + TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 + FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + if selected_backend == _Backend.FLASHINFER: logger.info_once("Using FlashInfer backend on V1 engine.") - return "vllm.v1.attention.backends.flashinfer.FlashInferBackend" + return FLASHINFER_V1 elif selected_backend == _Backend.FLEX_ATTENTION: - logger.info("Using FlexAttenion backend on V1 engine.") - return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 + logger.info_once("Using FlexAttention backend on V1 engine.") + return FLEX_ATTENTION_V1 elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1: logger.info_once("Using Triton backend on V1 engine.") - return ("vllm.v1.attention.backends." - "triton_attn.TritonAttentionBackend") + return TRITON_ATTN_VLLM_V1 elif selected_backend == _Backend.FLASH_ATTN: logger.info_once("Using Flash Attention backend on V1 engine.") - return ("vllm.v1.attention.backends." - "flash_attn.FlashAttentionBackend") + return FLASH_ATTN_V1 + + from vllm.attention.selector import supports_head_size # Default backends for V1 engine - # Prefer FlashInfer for Blackwell GPUs if installed + # FP32 is only supported by FlexAttention if dtype not in (torch.float16, torch.bfloat16): logger.info_once( - f"Using FlexAttenion backend for {dtype} on V1 engine.") - return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 - if cls.is_device_capability(100): + "Using FlexAttention backend for %s on V1 engine.", + dtype, + ) + return FLEX_ATTENTION_V1 + + # Prefer FlashInfer for Blackwell GPUs if installed + if cls.is_device_capability(100) and \ + supports_head_size(FLASHINFER_V1, head_size): try: import flashinfer # noqa: F401 logger.info_once( "Using FlashInfer backend on V1 engine by default for " "Blackwell (SM 10.0) GPUs.") - return ("vllm.v1.attention.backends." 
- "flashinfer.FlashInferBackend") + return FLASHINFER_V1 except ImportError: logger.info_once( "FlashInfer failed to import for V1 engine on " @@ -270,10 +279,13 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, "install FlashInfer for better performance.") pass # FlashAttention is the default for SM 8.0+ GPUs - if cls.has_device_capability(80): + if cls.has_device_capability(80) and \ + supports_head_size(FLASH_ATTN_V1, head_size): logger.info_once("Using Flash Attention backend on V1 engine.") - return ("vllm.v1.attention.backends." - "flash_attn.FlashAttentionBackend") + return FLASH_ATTN_V1 + + logger.info_once("Using FlexAttention backend on V1 engine.") + return FLEX_ATTENTION_V1 # Backends for V0 engine if selected_backend == _Backend.FLASHINFER: diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index e493f1b8088b..37c04c7a029e 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -3,7 +3,8 @@ import numpy as np import torch -from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl, TorchSDPAMetadata) from vllm.attention.backends.utils import CommonAttentionState @@ -17,9 +18,24 @@ from vllm.v1.worker.gpu_input_batch import InputBatch -class TorchSDPABackend: +class TorchSDPABackend(AttentionBackend): accept_output_buffer: bool = False + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return PagedAttention.get_supported_head_sizes() + + @classmethod + def validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + attn_type = cls.__name__.removesuffix("Backend") + raise ValueError( + f"Head size {head_size} is not supported by {attn_type}. " + f"Supported head sizes are: {supported_head_sizes}. " + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes.") + @staticmethod def get_name() -> str: return "TORCH_SDPA_VLLM_V1" diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 6182b2f9b2bd..fbc13c06c65a 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -44,10 +44,21 @@ class FlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - @staticmethod - def get_supported_head_sizes() -> list[int]: + @classmethod + def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] + @classmethod + def validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + attn_type = cls.__name__.removesuffix("Backend") + raise ValueError( + f"Head size {head_size} is not supported by {attn_type}. " + f"Supported head sizes are: {supported_head_sizes}. " + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes.") + @staticmethod def get_name() -> str: return "FLASH_ATTN_VLLM_V1" @@ -416,12 +427,7 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads - support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. 
" - f"Supported head sizes are: {support_head_sizes}. " - "Set VLLM_USE_V1=0 to use another attention backend.") + FlashAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 03a2ed7139c7..860309faa905 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -38,10 +38,22 @@ class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True - @staticmethod - def get_supported_head_sizes() -> list[int]: + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 return [64, 128, 256] + @classmethod + def validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + attn_type = cls.__name__.removesuffix("Backend") + raise ValueError( + f"Head size {head_size} is not supported by {attn_type}. " + f"Supported head sizes are: {supported_head_sizes}. " + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes.") + @staticmethod def get_name() -> str: return "FLASHINFER_VLLM_V1" @@ -207,14 +219,8 @@ def query_start_loc(self): return self.qo_indptr def __post_init__(self): - # Refer to - # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 - supported_head_sizes = FlashInferBackend.get_supported_head_sizes() - if self.head_dim is not None and self.head_dim \ - not in supported_head_sizes: - raise ValueError( - f"Only {supported_head_sizes} are supported for head_dim,", - f" received {self.head_dim}.") + if self.head_dim is not None: + FlashInferBackend.validate_head_size(self.head_dim) class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index ebd5914ee40a..a8c5f464aa32 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" - +from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional @@ -21,9 +21,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable -if current_platform.is_cuda(): - pass - logger = init_logger(__name__) if TYPE_CHECKING: @@ -45,9 +42,9 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: class FlexAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - @staticmethod - def get_supported_head_sizes() -> list[int]: - return [16, 32, 64, 96, 128, 160, 192, 224, 256] + @classmethod + def validate_head_size(cls, head_size: int) -> None: + return # FlexAttention supports any head size @staticmethod def get_name() -> str: @@ -384,12 +381,8 @@ def __init__( raise NotImplementedError( "FlexAttention does not support kv sharing yet.") - support_head_sizes = FlexAttentionBackend.get_supported_head_sizes() - if head_size not in support_head_sizes: - raise ValueError( - f"Head size 
{head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {support_head_sizes}. " - "Set VLLM_USE_V1=0 to use another attention backend.") + FlexAttentionBackend.validate_head_size(head_size) + if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "FlexAttention does not support quantized kv-cache. Yet") @@ -464,12 +457,20 @@ def forward( # Doesn't work for now -> constraint violation # torch._dynamo.try_mark_dynamic(query, 2) - # default M=64, N=64 may run out of shared memory on - # some GPUs with fp32, so we use smaller M and N. - extra_kernel_options = { - "BLOCK_M": 32, - "BLOCK_N": 32 - } if query.dtype == torch.float32 else {} + # default M=64, N=64 may run out of shared memory on some GPUs + # TODO: Explicit configs for each GPU? + # Not sure how to calculate the shared memory requirement + extra_kernel_options = defaultdict[str, int](lambda: 64) + if query.dtype == torch.float32: + extra_kernel_options["BLOCK_M"] //= 2 + extra_kernel_options["BLOCK_N"] //= 2 + if current_platform.is_cuda(): + device_props = torch.cuda.get_device_properties() + max_shared_memory = device_props.shared_memory_per_block_optin + if max_shared_memory < 144 * 1024: + extra_kernel_options["BLOCK_M"] //= 2 + extra_kernel_options["BLOCK_N"] //= 2 + out = flex_attention_compiled( query, key_cache, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 39379b11863c..f2aaf59a40f8 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -254,10 +254,21 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: return (num_blocks, block_size, head_size) - @staticmethod - def get_supported_head_sizes() -> list[int]: + @classmethod + def get_supported_head_sizes(cls) -> list[int]: return [576] + @classmethod + def validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + attn_type = cls.__name__.removesuffix("Backend") + raise ValueError( + f"Head size {head_size} is not supported by {attn_type}. " + f"Supported head sizes are: {supported_head_sizes}. 
" + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes.") + @dataclass class MLACommonPrefillMetadata: @@ -320,12 +331,8 @@ class MLACommonMetadata(Generic[D]): prefill: Optional[MLACommonPrefillMetadata] = None def __post_init__(self): - supported_head_sizes = MLACommonBackend.get_supported_head_sizes() - if self.head_dim is not None and self.head_dim \ - not in supported_head_sizes: - raise ValueError( - f"Only {supported_head_sizes} are supported for head_dim,", - f"received {self.head_dim}.") + if self.head_dim is not None: + MLACommonBackend.validate_head_size(self.head_dim) M = TypeVar("M", bound=MLACommonMetadata) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 63537384a1da..6a78b03dce86 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -314,10 +314,21 @@ class AiterFlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - @staticmethod - def get_supported_head_sizes() -> list[int]: + @classmethod + def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] + @classmethod + def validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + attn_type = cls.__name__.removesuffix("Backend") + raise ValueError( + f"Head size {head_size} is not supported by {attn_type}. " + f"Supported head sizes are: {supported_head_sizes}. " + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes.") + @staticmethod def get_name() -> str: return "FLASH_ATTN_VLLM_V1" @@ -428,14 +439,7 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - support_head_sizes = \ - AiterFlashAttentionBackend.get_supported_head_sizes() - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by " - "AiterFlashAttention. " - f"Supported head sizes are: {support_head_sizes}. " - "Set VLLM_USE_V1=0 to use another attention backend.") + AiterFlashAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 4c5a1a755c1a..cdaff2f6a40f 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -190,10 +190,21 @@ class TritonAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - @staticmethod - def get_supported_head_sizes() -> list[int]: + @classmethod + def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] + @classmethod + def validate_head_size(cls, head_size: int) -> None: + supported_head_sizes = cls.get_supported_head_sizes() + if head_size not in supported_head_sizes: + attn_type = cls.__name__.removesuffix("Backend") + raise ValueError( + f"Head size {head_size} is not supported by {attn_type}. " + f"Supported head sizes are: {supported_head_sizes}. 
" + "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "FlexAttention backend which supports all head sizes.") + @staticmethod def get_name() -> str: return "TRITON_ATTN_VLLM_V1" @@ -268,11 +279,7 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads - support_head_sizes = TritonAttentionBackend.get_supported_head_sizes() - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by TritonAttention. " - f"Supported head sizes are: {support_head_sizes}.") + TritonAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 156f5764e8dc..6661d984a771 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -12,8 +12,8 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM -from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata, - FlashAttentionMetadata) +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel