2 changes: 1 addition & 1 deletion tests/v1/attention/test_mla_backends.py
@@ -61,7 +61,7 @@

BACKEND_BLOCK_SIZES = {}
for backend in BACKENDS_TO_TEST:
-    supported_sizes = backend.get_class().supported_kernel_block_sizes
+    supported_sizes = backend.get_class().get_supported_kernel_block_sizes()
    if supported_sizes:
        default_size = supported_sizes[0]
        block_size = (
4 changes: 3 additions & 1 deletion tests/v1/worker/test_gpu_model_runner.py
@@ -185,7 +185,9 @@ def _make_mock_backend_for_kernel_block_size(
    supported_sizes: list[int | MultipleOf],
):
    class _MockBackend:
-        supported_kernel_block_sizes = supported_sizes
+        @staticmethod
+        def get_supported_kernel_block_sizes():
+            return supported_sizes

    return _MockBackend()

10 changes: 7 additions & 3 deletions vllm/attention/backends/abstract.py
@@ -46,9 +46,12 @@ class AttentionBackend(ABC):
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(1)]

    @staticmethod
    @abstractmethod
    def get_name() -> str:
@@ -142,10 +145,11 @@ def supports_block_size(cls, block_size: int | None) -> bool:
        if block_size not in valid_sizes:
            return False

-        if not cls.supported_kernel_block_sizes:
+        supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
+        if not supported_kernel_block_sizes:
            return True

-        for supported_size in cls.supported_kernel_block_sizes:
+        for supported_size in supported_kernel_block_sizes:
            if isinstance(supported_size, MultipleOf):
                supported_size = supported_size.base
            # With hybrid_blocks feature, the framework-level block size
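Note on the contract being refactored above: a backend advertises either exact kernel block sizes (e.g. [16, 32, 64]) or a MultipleOf(n) constraint, and supports_block_size accepts a framework-level block size that matches any entry. Below is a minimal standalone sketch of that matching logic, assuming MultipleOf is a thin wrapper exposing .base as the hunk suggests; the real method also handles the hybrid-blocks case truncated above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    """Stand-in for vLLM's MultipleOf: accepts any size divisible by `base`."""
    base: int


def supports_block_size(
    supported: list[int | MultipleOf], block_size: int
) -> bool:
    # An empty list means "no kernel-level constraint".
    if not supported:
        return True
    for entry in supported:
        if isinstance(entry, MultipleOf):
            if block_size % entry.base == 0:
                return True
        elif block_size == entry:
            return True
    return False


# Exact-size backend (FlashAttention-style) vs. a MultipleOf(16) backend.
assert supports_block_size([16, 32, 64], 32)
assert not supports_block_size([16, 32, 64], 128)
assert supports_block_size([MultipleOf(16)], 128)
```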
27 changes: 21 additions & 6 deletions vllm/v1/attention/backends/flash_attn.py
@@ -32,7 +32,7 @@
    get_scheduler_metadata,
    reshape_and_cache_flash,
)
-from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
@@ -56,11 +56,26 @@
class FlashAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    # NOTE(tdoublep): while in principle, FA supports
-    # MultipleOf(16), these are the block sizes that do not
-    # suffer from the NaN propagation problem described here:
-    # https://github.com/Dao-AILab/flash-attention/issues/1974
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        vllm_config = get_current_vllm_config()
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        if (
+            model_config
+            and model_config.is_hybrid
+            and (
+                cache_config.mamba_ssm_cache_dtype == "float32"
+                or cache_config.mamba_cache_dtype == "float32"
+            )
+        ):
+            # NOTE(tdoublep): while in principle, FA supports
+            # MultipleOf(16), these are the block sizes that do not
+            # suffer from the NaN propagation problem described here:
+            # https://github.com/Dao-AILab/flash-attention/issues/1974
+            return [16, 32, 64]
+        return [MultipleOf(16)]

    @staticmethod
    def get_name() -> str:
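FlashAttention is the one backend whose supported sizes now depend on runtime configuration: a hybrid (attention + mamba) model with a float32 mamba cache is restricted to exactly 16/32/64, everything else keeps MultipleOf(16). The sketch below replays that branch in isolation, using hypothetical stand-in config objects rather than the real VllmConfig read via get_current_vllm_config().

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:          # hypothetical stand-in
    is_hybrid: bool


@dataclass
class CacheConfig:          # hypothetical stand-in
    mamba_ssm_cache_dtype: str = "auto"
    mamba_cache_dtype: str = "auto"


class MultipleOf:
    def __init__(self, base: int) -> None:
        self.base = base


def fa_supported_kernel_block_sizes(
    model_config: ModelConfig | None, cache_config: CacheConfig
) -> list:
    # Mirrors the branch added in flash_attn.py: restrict to 16/32/64 only
    # when NaN propagation could occur (hybrid model + float32 mamba cache).
    if model_config and model_config.is_hybrid and (
        cache_config.mamba_ssm_cache_dtype == "float32"
        or cache_config.mamba_cache_dtype == "float32"
    ):
        return [16, 32, 64]
    return [MultipleOf(16)]


sizes = fa_supported_kernel_block_sizes(
    ModelConfig(is_hybrid=True), CacheConfig(mamba_cache_dtype="float32")
)
print(sizes)  # -> [16, 32, 64]
```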
12 changes: 6 additions & 6 deletions vllm/v1/attention/backends/flashinfer.py
@@ -16,7 +16,6 @@
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
-from typing_extensions import override

from vllm import envs
from vllm.attention.backends.abstract import (
@@ -275,17 +274,19 @@ def run(
class FlashInferBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    # Note: Not sure for all platforms,
-    # but on Blackwell, only support a page size of
-    # 16, 32, 64
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "fp8",
        "fp8_e4m3",
        "fp8_e5m2",
    ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        # Note: Not sure for all platforms, but on Blackwell,
+        # only support a page size of 16, 32, 64.
+        return [16, 32, 64]

    @staticmethod
    def get_name() -> str:
        return "FLASHINFER"
@@ -566,7 +567,6 @@ def __init__(
        )

    @classmethod
-    @override
Review comment from the PR author (Collaborator): pre-commit is rightfully complaining, as this is not an override.

    def get_cudagraph_support(
        cls: type["FlashInferMetadataBuilder"],
        vllm_config: VllmConfig,
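For context on the review comment above: @override from typing_extensions only tells static checkers that a method is expected to override one from a base class, so the check fires when no such base method exists. A small illustration with hypothetical classes (not vLLM code):

```python
from typing_extensions import override


class Base:
    def get_name(self) -> str:
        return "base"


class Child(Base):
    @override
    def get_name(self) -> str:  # fine: Base defines get_name
        return "child"

    @override
    def get_cudagraph_support(self) -> int:  # flagged: Base has no such method
        return 0
```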
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -36,13 +36,16 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):

class CutlassMLABackend(MLACommonBackend):
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "fp8",
        "fp8_e4m3",
    ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [128]

    @staticmethod
    def get_name() -> str:
        return "CUTLASS_MLA"
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -41,9 +41,12 @@

class FlashAttnMLABackend(MLACommonBackend):
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN_MLA"
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -35,13 +35,16 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):

class FlashInferMLABackend(MLACommonBackend):
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "fp8",
        "fp8_e4m3",
    ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [32, 64]

    @staticmethod
    def get_name() -> str:
        return "FLASHINFER_MLA"
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/mla/flashmla.py
@@ -39,13 +39,16 @@

class FlashMLABackend(MLACommonBackend):
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "fp8",
        "fp8_e4m3",
    ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [64]

    @staticmethod
    def get_name() -> str:
        return "FLASHMLA"
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -55,9 +55,12 @@
class FlashMLASparseBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [64]

    @staticmethod
    def get_name() -> str:
        return "FLASHMLA_SPARSE"
6 changes: 3 additions & 3 deletions vllm/v1/attention/backends/mla/indexer.py
@@ -24,9 +24,9 @@


class DeepseekV32IndexerBackend(AttentionBackend):
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [
-        1 if current_platform.is_rocm() else 64
-    ]
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [1 if current_platform.is_rocm() else 64]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -21,7 +21,9 @@


class AiterMLABackend(MLACommonBackend):
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [1]
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [1]

    @staticmethod
    def get_name() -> str:
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -447,7 +447,10 @@ def use_cascade_attention(self, *args, **kwargs) -> bool:
class AiterFlashAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/tree_attn.py
@@ -31,7 +31,10 @@
class TreeAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/triton_attn.py
@@ -154,14 +154,17 @@ class TritonAttentionBackend(AttentionBackend):
        torch.bfloat16,
        torch.float32,
    ]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "fp8",
        "fp8_e4m3",
        "fp8_e5m2",
    ]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]

    @staticmethod
    def get_name() -> str:
        return "TRITON_ATTN"
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/xformers.py
@@ -42,7 +42,10 @@
class XFormersAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]

+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
4 changes: 2 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
@@ -4618,7 +4618,7 @@ def block_size_is_supported(
    """
    for backend in backends:
        is_supported = False
-        for supported_size in backend.supported_kernel_block_sizes:
+        for supported_size in backend.get_supported_kernel_block_sizes():
            if isinstance(supported_size, int):
                if block_size == supported_size:
                    is_supported = True
@@ -4649,7 +4649,7 @@ def block_size_is_supported(
    all_int_supported_sizes = set(
        supported_size
        for backend in backends
-        for supported_size in backend.supported_kernel_block_sizes
+        for supported_size in backend.get_supported_kernel_block_sizes()
        if isinstance(supported_size, int)
    )

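Taken together, the runner-side check above only accepts a block size that every backend supports, either as an exactly listed size or as a multiple of a MultipleOf base, while the second hunk collects the union of exactly listed sizes as candidates. A simplified, self-contained sketch of the per-backend check (not the actual vLLM helper, whose full body is not shown in these hunks):

```python
class MultipleOf:
    def __init__(self, base: int) -> None:
        self.base = base


def block_size_is_supported(backends: list, block_size: int) -> bool:
    """True only if every backend accepts `block_size`."""
    for backend in backends:
        is_supported = False
        for supported_size in backend.get_supported_kernel_block_sizes():
            if isinstance(supported_size, int):
                if block_size == supported_size:
                    is_supported = True
            elif block_size % supported_size.base == 0:
                is_supported = True
        if not is_supported:
            return False
    return True


class _ExactBackend:           # hypothetical backend listing exact sizes
    @staticmethod
    def get_supported_kernel_block_sizes():
        return [16, 32, 64]


class _MultipleBackend:        # hypothetical backend with a divisibility rule
    @staticmethod
    def get_supported_kernel_block_sizes():
        return [MultipleOf(16)]


print(block_size_is_supported([_ExactBackend, _MultipleBackend], 32))   # True
print(block_size_is_supported([_ExactBackend, _MultipleBackend], 128))  # False
```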