Merged

Changes from 1 commit
6 changes: 1 addition & 5 deletions vllm/model_executor/models/config.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from math import lcm
from typing import TYPE_CHECKING

import vllm.envs as envs
@@ -399,11 +400,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
# easily by changing the way we layout chunks in the
# mamba2 kernels.

from math import gcd

def lcm(a, b):
return a * b // gcd(a, b)

base_chunk_size = model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)

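For context, a minimal standalone sketch (not part of the diff) showing that the stdlib math.lcm, available since Python 3.9, matches the removed gcd-based helper for positive page sizes and also accepts more than two arguments:

# Standalone sketch: math.lcm vs. the removed helper (illustrative values only).
from math import gcd, lcm

def lcm_removed(a: int, b: int) -> int:
    # The helper deleted by this diff.
    return a * b // gcd(a, b)

for a, b in [(16, 64), (80, 528), (256, 256)]:
    assert lcm(a, b) == lcm_removed(a, b)

print(lcm(8, 12, 30))  # stdlib lcm is variadic: prints 120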
162 changes: 81 additions & 81 deletions vllm/v1/worker/gpu_model_runner.py
@@ -3881,6 +3881,7 @@ def get_attn_backends_for_group(

def create_attn_groups(
attn_backends_map: dict[AttentionGroupKey, list[str]],
kv_cache_group_id: int,
) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = []
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
@@ -3890,6 +3891,7 @@ def create_attn_groups(
kv_cache_spec,
self.vllm_config,
self.device,
kv_cache_group_id,
num_metadata_builders=1
if not self.parallel_config.enable_dbo
else 2,
@@ -3898,9 +3900,9 @@ def create_attn_groups(
attn_groups.append(attn_group)
return attn_groups

for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
for i, kv_cache_group_spec in enumerate(kv_cache_config.kv_cache_groups):
attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
self.attn_groups.append(create_attn_groups(attn_backends))
self.attn_groups.append(create_attn_groups(attn_backends, i))

# Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold()
@@ -4051,104 +4053,93 @@ def calculate_reorder_batch_threshold(self) -> None:
else:
self.reorder_batch_threshold = reorder_batch_threshold_i

def _find_compatible_block_sizes(
self,
kv_manager_block_size: int,
backend_cls: type[AttentionBackend],
return_all: bool = False,
) -> list[int]:
"""
Find compatible block sizes for a backend.

Args:
kv_manager_block_size: Physical block size of KV cache
backend_cls: Attention backend class
return_all: Return all compatible sizes if True, max size if False

Returns:
Compatible block size(s) based on return_all parameter

Raises:
ValueError: If no compatible block size found
"""
supported_block_size = backend_cls.get_supported_kernel_block_size()
compatible_sizes = []

for block_size in supported_block_size:
if isinstance(block_size, int):
if kv_manager_block_size % block_size == 0:
compatible_sizes.append(block_size)
elif (
isinstance(block_size, MultipleOf)
and kv_manager_block_size % block_size.base == 0
):
compatible_sizes.append(kv_manager_block_size)

if not compatible_sizes:
raise ValueError(f"No compatible block size for {kv_manager_block_size}")

return compatible_sizes if return_all else [max(compatible_sizes)]

def _select_common_block_size(
self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
) -> int:
"""
Select common block size for all backends.
Select a block size that is supported by all backends and evenly divides
kv_manager_block_size.

If kv_manager_block_size is supported by all backends, return it directly.
Otherwise, return the max supported size.

Args:
kv_manager_block_size: Block size of KV cache
attn_groups: List of attention groups

Returns:
Block size supported by all backends,
prioritizing cache_config.block_size
The selected block size

Raises:
ValueError: If no common block size found
ValueError: If no valid block size found
"""
all_backend_supports = []

for attn_group in attn_groups:
compatible_sizes = self._find_compatible_block_sizes(
kv_manager_block_size, attn_group.backend, return_all=True
)
supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
all_backend_supports.append(set(supported_sizes))

common_supported_sizes = set.intersection(*all_backend_supports)

if not common_supported_sizes:
error_msg = f"No common block size for {kv_manager_block_size}. "
for i, attn_group in enumerate(attn_groups):
supported = all_backend_supports[i]
error_msg += (
f"Backend {attn_group.backend} supports: {sorted(supported)}. "
)
raise ValueError(error_msg)

if self.cache_config.block_size in common_supported_sizes:
return self.cache_config.block_size
def block_size_is_supported(
backends: list[type[AttentionBackend]], block_size: int
) -> bool:
"""
Check if the block size is supported by all backends.
"""
for backend in backends:
is_supported = False
for supported_size in backend.get_supported_kernel_block_size():
if isinstance(supported_size, int):
if block_size == supported_size:
is_supported = True
elif isinstance(supported_size, MultipleOf):
if block_size % supported_size.base == 0:
is_supported = True
else:
raise ValueError(f"Unknown supported size: {supported_size}")
if not is_supported:
return False
return True

backends = [group.backend for group in attn_groups]

# Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly
if block_size_is_supported(backends, kv_manager_block_size):
return kv_manager_block_size

# Case 2: otherwise, the block_size must be an `int`-format supported size of
# at least one backend. Iterate over all `int`-format supported sizes in
# descending order and return the first one that is supported by all backends.
# Simple proof:
# If a candidate size b is supported by every attention backend i only via a
# MultipleOf(x_i) entry, and kv_manager_block_size is divisible by b, then
# kv_manager_block_size also satisfies MultipleOf(x_i) for all i, so it would
# already have been returned in case 1.
all_int_supported_sizes = set(
supported_size
for backend in backends
for supported_size in backend.get_supported_kernel_block_size()
if isinstance(supported_size, int)
)

return max(common_supported_sizes)
for supported_size in sorted(all_int_supported_sizes, reverse=True):
if block_size_is_supported(backends, supported_size):
return supported_size
raise ValueError(f"No common block size for {kv_manager_block_size}. ")
Contributor

high

The new implementation raises a less informative error message compared to the previous version when no common block size is found. The old error message listed the supported block sizes for each backend, which is very helpful for debugging. It would be great to restore that level of detail in the error message.

Suggested change
raise ValueError(f"No common block size for {kv_manager_block_size}. ")
error_msg = f"No common block size for {kv_manager_block_size}. "
for backend in backends:
    supported_sizes = backend.get_supported_kernel_block_size()
    error_msg += f"Backend {backend.__name__} supports: {supported_sizes}. "
raise ValueError(error_msg)
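For readers skimming the new selection logic above, here is a self-contained sketch of the same two-case algorithm using stub types (the MultipleOf dataclass and the backend tables below are illustrative, not the vLLM classes):

from dataclasses import dataclass

@dataclass
class MultipleOf:
    # Stub for the MultipleOf marker: "any block size divisible by base".
    base: int

def supports(supported_sizes, block_size: int) -> bool:
    # A backend supports block_size if it matches an int entry exactly or is
    # a multiple of a MultipleOf entry's base.
    return any(
        block_size == s if isinstance(s, int) else block_size % s.base == 0
        for s in supported_sizes
    )

def select_common_block_size(kv_manager_block_size: int, backends: dict) -> int:
    # Case 1: the KV cache manager's block size already works for every backend.
    if all(supports(sizes, kv_manager_block_size) for sizes in backends.values()):
        return kv_manager_block_size
    # Case 2: fall back to the largest int-format size accepted by all backends.
    int_sizes = {s for sizes in backends.values() for s in sizes if isinstance(s, int)}
    for size in sorted(int_sizes, reverse=True):
        if all(supports(sizes, size) for sizes in backends.values()):
            return size
    raise ValueError(f"No common block size for {kv_manager_block_size}")

# Example: one backend accepts any multiple of 16, another only 64; with a
# manager block size of 256, case 1 fails and case 2 picks 64.
print(select_common_block_size(256, {"A": [MultipleOf(16)], "B": [64]}))  # 64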


def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
def may_reinitialize_input_batch(
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> None:
"""
Re-initialize the input batch if the block sizes are different from
`[self.cache_config.block_size]`. This usually happens when there
are multiple KV cache groups.

Args:
kv_cache_config: The KV cache configuration.
kernel_block_sizes: The kernel block sizes for each attention backend.
"""
block_sizes = [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
]

# Generate kernel_block_sizes that matches each block_size
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)

if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
self.cache_config.block_size
]:
@@ -4261,6 +4252,7 @@ def _reshape_kv_cache_tensors(
self,
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
kernel_block_sizes: list[int],
) -> dict[str, torch.Tensor]:
"""
Reshape the KV cache tensors to the desired shape and dtype.
@@ -4269,6 +4261,7 @@ def _reshape_kv_cache_tensors(
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer, with
correct size but uninitialized shape.
kernel_block_sizes: The kernel block sizes for attention backends.
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
@@ -4278,6 +4271,7 @@ def _reshape_kv_cache_tensors(
for group in self._kv_cache_spec_attn_group_iterator():
kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend
kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
for layer_name in group.layer_names:
Comment on lines 4385 to 4392

P1: Skip encoder-only groups before indexing kernel_block_sizes

Encoder-only cache groups are appended to kv_cache_config.kv_cache_groups but _prepare_kernel_block_sizes intentionally omits them when building kernel_block_sizes. _kv_cache_spec_attn_group_iterator() still yields AttentionGroups for those encoder-only layers, so the loop dereferences kernel_block_sizes[group.kv_cache_group_id] before the subsequent runner_only_attn_layers check. When an encoder-only group exists (e.g., encoder–decoder models), this index is past the end of the list and initialize_kv_cache will crash. Skip encoder-only specs before indexing or include placeholders in kernel_block_sizes so that the list length matches the group ids.
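A toy illustration of the hazard this comment describes (plain lists instead of vLLM objects, names are illustrative): kernel_block_sizes is only built for non-encoder groups, so indexing it by an unfiltered group id overruns the list unless encoder-only groups are skipped first.

kv_cache_groups = ["full_attention", "mamba", "encoder_only"]  # hypothetical group layout
kernel_block_sizes = [64, 256]  # prepared only for the two non-encoder groups

for group_id, group in enumerate(kv_cache_groups):
    if group == "encoder_only":
        continue  # the guard: skip before indexing kernel_block_sizes
    print(group, kernel_block_sizes[group_id])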

if layer_name in self.runner_only_attn_layers:
continue
@@ -4286,24 +4280,21 @@ def _reshape_kv_cache_tensors(
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
kv_manager_block_size = kv_cache_spec.block_size
kernel_size_list = self._find_compatible_block_sizes(
kv_manager_block_size, attn_backend, return_all=False
num_blocks_per_kv_block = (
kv_cache_spec.block_size // kernel_block_size
)
kernel_size = kernel_size_list[0]
num_blocks_per_kv_block = kv_manager_block_size // kernel_size
kernel_num_blocks = num_blocks * num_blocks_per_kv_block

kv_cache_shape = attn_backend.get_kv_cache_shape(
kernel_num_blocks,
kernel_size,
kernel_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=self.cache_config.cache_dtype,
)
dtype = kv_cache_spec.dtype
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() # noqa: E501
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
@@ -4386,13 +4377,15 @@ def _update_hybrid_attention_mamba_layout(
)

def initialize_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> dict[str, torch.Tensor]:
"""
Initialize the memory buffer for KV cache.

Args:
kv_cache_config: The KV cache config
kernel_block_sizes: The kernel block sizes for each attention backend.

Returns:
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
@@ -4401,7 +4394,7 @@ def initialize_kv_cache_tensors(
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
# Change the memory buffer to the desired shape
kv_caches = self._reshape_kv_cache_tensors(
kv_cache_config, kv_cache_raw_tensors
kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
)

# Set up cross-layer KV cache sharing
@@ -4460,9 +4453,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)
# The block size used by the attention kernels. For example, if the kv cache
# manager uses block_size 256 but an attention backend only supports
# block_size 64, we will use kernel_block_size 64 and split each 256-token
# block into 4 blocks of 64 tokens each.
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
# Reinitialization needs to happen after initialize_attn_backend
self.may_reinitialize_input_batch(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
kv_caches = self.initialize_kv_cache_tensors(
kv_cache_config, kernel_block_sizes
)

if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
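The comment in initialize_kv_cache above describes the manager-block/kernel-block split; a small arithmetic sketch with illustrative numbers (not taken from a real config), mirroring the names used in _reshape_kv_cache_tensors:

kv_manager_block_size = 256  # block size used by the KV cache manager
kernel_block_size = 64       # largest block size the attention kernel accepts
num_blocks = 1000            # blocks allocated by the manager

num_blocks_per_kv_block = kv_manager_block_size // kernel_block_size  # 4
kernel_num_blocks = num_blocks * num_blocks_per_kv_block              # 4000

# Total token capacity is unchanged; only the granularity seen by the kernel differs.
assert kernel_num_blocks * kernel_block_size == num_blocks * kv_manager_block_size
print(num_blocks_per_kv_block, kernel_num_blocks)  # 4 4000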
6 changes: 5 additions & 1 deletion vllm/v1/worker/utils.py
@@ -136,6 +136,7 @@ class AttentionGroup:
metadata_builders: list[AttentionMetadataBuilder]
layer_names: list[str]
kv_cache_spec: KVCacheSpec
kv_cache_group_id: int

@staticmethod
def create_with_metadata_builders(
@@ -144,13 +145,16 @@ def create_with_metadata_builders(
kv_cache_spec: KVCacheSpec,
vllm_config: VllmConfig,
device: torch.device,
kv_cache_group_id: int,
num_metadata_builders: int = 1,
) -> "AttentionGroup":
metadata_builders = [
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
for _ in range(num_metadata_builders)
]
return AttentionGroup(backend, metadata_builders, layer_names, kv_cache_spec)
return AttentionGroup(
backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
)

def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
assert len(self.metadata_builders) > ubatch_id
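To make the new field concrete, a toy mirror of how kv_cache_group_id is threaded through (a plain dataclass standing in for AttentionGroup; backend and layer names are illustrative):

from dataclasses import dataclass

@dataclass
class ToyAttentionGroup:
    backend: str
    layer_names: list[str]
    kv_cache_group_id: int  # the field added by this PR

kv_cache_groups = [["model.layers.0", "model.layers.1"], ["model.layers.2"]]
attn_groups = [
    [ToyAttentionGroup("FLASH_ATTN", layers, i)]
    for i, layers in enumerate(kv_cache_groups)  # mirrors the enumerate() change
]
assert attn_groups[1][0].kv_cache_group_id == 1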