53 changes: 53 additions & 0 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -6,6 +6,7 @@
import torch

from vllm.attention import Attention
from vllm.attention.backends.abstract import MultipleOf
from vllm.config import (
CacheConfig,
ModelConfig,
@@ -34,6 +35,7 @@
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.utils import AttentionGroup

BLOCK_SIZE = 16
NUM_BLOCKS = 10
@@ -181,6 +183,57 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
).all()


def _make_mock_backend_for_kernel_block_size(
supported_sizes: list[int | MultipleOf],
):
class _MockBackend:
@staticmethod
def get_supported_kernel_block_size():
return supported_sizes

return _MockBackend()


def _make_kv_cache_spec() -> FullAttentionSpec:
return FullAttentionSpec(block_size=1, num_kv_heads=1, head_size=1, dtype="float16")


def test_select_common_block_size_prefers_manager_block_size():
backend_a = _make_mock_backend_for_kernel_block_size([MultipleOf(32)])
backend_b = _make_mock_backend_for_kernel_block_size([64, MultipleOf(16)])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]

selected_size = GPUModelRunner.select_common_block_size(128, attn_groups)
assert selected_size == 128


def test_select_common_block_size_uses_largest_shared_int():
backend_a = _make_mock_backend_for_kernel_block_size([128, 64])
backend_b = _make_mock_backend_for_kernel_block_size([64, 32])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]

selected_size = GPUModelRunner.select_common_block_size(256, attn_groups)
assert selected_size == 64


def test_select_common_block_size_no_valid_option():
backend_a = _make_mock_backend_for_kernel_block_size([64])
backend_b = _make_mock_backend_for_kernel_block_size([MultipleOf(16)])
attn_groups = [
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
]

with pytest.raises(ValueError):
GPUModelRunner.select_common_block_size(48, attn_groups)


def test_update_states_new_request(model_runner, dist_init):
req_id = "req_0"

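A note for readers skimming these tests: MultipleOf(n) (imported above from vllm.attention.backends.abstract) declares that a backend kernel accepts any block size that is a multiple of n, while a plain int entry declares one exact size. Below is a minimal standalone sketch of that contract; the MultipleOf stand-in and the accepts helper are local illustrations, not vLLM APIs.

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    # Local stand-in for vllm.attention.backends.abstract.MultipleOf
    # (assumed shape: a single integer `base` field).
    base: int


def accepts(supported: int | MultipleOf, block_size: int) -> bool:
    # An int entry matches only that exact size; a MultipleOf entry
    # matches any multiple of its base.
    if isinstance(supported, int):
        return block_size == supported
    return block_size % supported.base == 0


assert accepts(MultipleOf(16), 48)   # 48 is a multiple of 16
assert not accepts(64, 48)           # exact-size entry: 48 != 64
assert accepts(64, 64)
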
175 changes: 91 additions & 84 deletions vllm/v1/worker/gpu_model_runner.py
@@ -3978,6 +3978,7 @@ def get_attn_backends_for_group(

def create_attn_groups(
attn_backends_map: dict[AttentionGroupKey, list[str]],
kv_cache_group_id: int,
) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = []
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
@@ -3987,6 +3988,7 @@ def create_attn_groups(
kv_cache_spec,
self.vllm_config,
self.device,
kv_cache_group_id,
num_metadata_builders=1
if not self.parallel_config.enable_dbo
else 2,
@@ -4005,8 +4007,8 @@ def create_attn_groups(
# Resolve cudagraph_mode before actually initialize metadata_builders
self._check_and_update_cudagraph_mode(attention_backend_set)

for attn_backends_map in attention_backend_maps:
self.attn_groups.append(create_attn_groups(attn_backends_map))
for i, attn_backend_map in enumerate(attention_backend_maps):
self.attn_groups.append(create_attn_groups(attn_backend_map, i))

# Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold()
@@ -4156,104 +4158,96 @@ def calculate_reorder_batch_threshold(self) -> None:
return
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds)

def _find_compatible_block_sizes(
self,
kv_manager_block_size: int,
backend_cls: type[AttentionBackend],
return_all: bool = False,
) -> list[int]:
"""
Find compatible block sizes for a backend.

Args:
kv_manager_block_size: Physical block size of KV cache
backend_cls: Attention backend class
return_all: Return all compatible sizes if True, max size if False

Returns:
Compatible block size(s) based on return_all parameter

Raises:
ValueError: If no compatible block size found
"""
supported_block_size = backend_cls.get_supported_kernel_block_size()
compatible_sizes = []

for block_size in supported_block_size:
if isinstance(block_size, int):
if kv_manager_block_size % block_size == 0:
compatible_sizes.append(block_size)
elif (
isinstance(block_size, MultipleOf)
and kv_manager_block_size % block_size.base == 0
):
compatible_sizes.append(kv_manager_block_size)

if not compatible_sizes:
raise ValueError(f"No compatible block size for {kv_manager_block_size}")

return compatible_sizes if return_all else [max(compatible_sizes)]

def _select_common_block_size(
self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]
@staticmethod
def select_common_block_size(
kv_manager_block_size: int, attn_groups: list[AttentionGroup]
) -> int:
"""
Select common block size for all backends.
Select a block size that is supported by all backends and is a factor of
kv_manager_block_size.

If kv_manager_block_size is supported by all backends, return it directly.
Otherwise, return the largest int-format supported size that divides kv_manager_block_size and is accepted by all backends.

Args:
kv_manager_block_size: Block size of KV cache
attn_groups: List of attention groups

Returns:
Block size supported by all backends,
prioritizing cache_config.block_size
The selected block size

Raises:
ValueError: If no common block size found
ValueError: If no valid block size found
"""
all_backend_supports = []

for attn_group in attn_groups:
compatible_sizes = self._find_compatible_block_sizes(
kv_manager_block_size, attn_group.backend, return_all=True
)
supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
all_backend_supports.append(set(supported_sizes))

common_supported_sizes = set.intersection(*all_backend_supports)

if not common_supported_sizes:
error_msg = f"No common block size for {kv_manager_block_size}. "
for i, attn_group in enumerate(attn_groups):
supported = all_backend_supports[i]
error_msg += (
f"Backend {attn_group.backend} supports: {sorted(supported)}. "
)
raise ValueError(error_msg)

if self.cache_config.block_size in common_supported_sizes:
return self.cache_config.block_size
def block_size_is_supported(
backends: list[type[AttentionBackend]], block_size: int
) -> bool:
"""
Check if the block size is supported by all backends.
"""
for backend in backends:
is_supported = False
for supported_size in backend.get_supported_kernel_block_size():
if isinstance(supported_size, int):
if block_size == supported_size:
is_supported = True
elif isinstance(supported_size, MultipleOf):
if block_size % supported_size.base == 0:
is_supported = True
else:
raise ValueError(f"Unknown supported size: {supported_size}")
if not is_supported:
return False
return True

backends = [group.backend for group in attn_groups]

# Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly
if block_size_is_supported(backends, kv_manager_block_size):
return kv_manager_block_size

# Case 2: otherwise, the block_size must be an `int`-format supported size of
# at least one backend. Iterate over all `int`-format supported sizes in
# descending order and return the first one that is supported by all backends.
# Simple proof:
# If the supported size b is in MultipleOf(x_i) format for all attention
# backends i, and b is a factor of kv_manager_block_size, then
# kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
# return kv_manager_block_size in case 1.
all_int_supported_sizes = set(
supported_size
for backend in backends
for supported_size in backend.get_supported_kernel_block_size()
if isinstance(supported_size, int)
)

return max(common_supported_sizes)
for supported_size in sorted(all_int_supported_sizes, reverse=True):
if kv_manager_block_size % supported_size != 0:
continue
if block_size_is_supported(backends, supported_size):
return supported_size
raise ValueError(f"No common block size for {kv_manager_block_size}. ")
Review comment from a contributor (severity: high):

The new implementation raises a less informative error message compared to the previous version when no common block size is found. The old error message listed the supported block sizes for each backend, which is very helpful for debugging. It would be great to restore that level of detail in the error message.

Suggested change
raise ValueError(f"No common block size for {kv_manager_block_size}. ")
error_msg = f"No common block size for {kv_manager_block_size}. "
for backend in backends:
supported_sizes = backend.get_supported_kernel_block_size()
error_msg += f"Backend {backend.__name__} supports: {supported_sizes}. "
raise ValueError(error_msg)
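
To make the two-case selection above easier to follow, here is a standalone sketch of the same logic with the runner stripped away. It is illustrative only: MultipleOf is stubbed locally rather than imported from vLLM, and each backend is represented as a plain list of supported sizes instead of an AttentionGroup with a backend class.

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    # Local stand-in for vllm.attention.backends.abstract.MultipleOf.
    base: int


def select_common_block_size_sketch(
    kv_manager_block_size: int,
    backend_supported_sizes: list[list[int | MultipleOf]],
) -> int:
    def supported_by_all(block_size: int) -> bool:
        # A backend accepts block_size if any of its entries matches it.
        return all(
            any(
                block_size == s if isinstance(s, int) else block_size % s.base == 0
                for s in supports
            )
            for supports in backend_supported_sizes
        )

    # Case 1: the KV cache manager's block size already works for every backend.
    if supported_by_all(kv_manager_block_size):
        return kv_manager_block_size

    # Case 2: try each int-format size (largest first) that divides the manager
    # block size and is accepted by all backends.
    int_sizes = {
        s
        for supports in backend_supported_sizes
        for s in supports
        if isinstance(s, int)
    }
    for size in sorted(int_sizes, reverse=True):
        if kv_manager_block_size % size == 0 and supported_by_all(size):
            return size

    raise ValueError(f"No common block size for {kv_manager_block_size}.")


# Mirrors the unit tests added in this PR:
assert select_common_block_size_sketch(128, [[MultipleOf(32)], [64, MultipleOf(16)]]) == 128
assert select_common_block_size_sketch(256, [[128, 64], [64, 32]]) == 64

The descending sort in case 2 is what makes the helper prefer the largest kernel block size that still divides the manager block size.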


def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
def may_reinitialize_input_batch(
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> None:
"""
Re-initialize the input batch if the block sizes are different from
`[self.cache_config.block_size]`. This usually happens when there
are multiple KV cache groups.

Args:
kv_cache_config: The KV cache configuration.
kernel_block_sizes: The kernel block sizes for each KV cache group.
"""
block_sizes = [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
]

# Generate kernel_block_sizes that matches each block_size
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)

if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
self.cache_config.block_size
]:
Expand Down Expand Up @@ -4354,7 +4348,7 @@ def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[in
# all backends in the group.
attn_groups = self.attn_groups[kv_cache_group_id]
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
selected_kernel_size = self._select_common_block_size(
selected_kernel_size = self.select_common_block_size(
kv_manager_block_size, attn_groups
)
kernel_block_sizes.append(selected_kernel_size)
@@ -4372,6 +4366,7 @@ def _reshape_kv_cache_tensors(
self,
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
kernel_block_sizes: list[int],
) -> dict[str, torch.Tensor]:
"""
Reshape the KV cache tensors to the desired shape and dtype.
@@ -4380,6 +4375,7 @@ def _reshape_kv_cache_tensors(
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer, with
correct size but uninitialized shape.
kernel_block_sizes: The kernel block sizes for each KV cache group.
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
@@ -4389,6 +4385,10 @@ def _reshape_kv_cache_tensors(
for group in self._kv_cache_spec_attn_group_iterator():
kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend
if group.kv_cache_group_id == len(kernel_block_sizes):
# There may be a last group for layers without kv cache.
continue
kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
@@ -4397,24 +4397,21 @@ def _reshape_kv_cache_tensors(
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
kv_manager_block_size = kv_cache_spec.block_size
kernel_size_list = self._find_compatible_block_sizes(
kv_manager_block_size, attn_backend, return_all=False
num_blocks_per_kv_block = (
kv_cache_spec.block_size // kernel_block_size
)
kernel_size = kernel_size_list[0]
num_blocks_per_kv_block = kv_manager_block_size // kernel_size
kernel_num_blocks = num_blocks * num_blocks_per_kv_block

kv_cache_shape = attn_backend.get_kv_cache_shape(
kernel_num_blocks,
kernel_size,
kernel_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=self.cache_config.cache_dtype,
)
dtype = kv_cache_spec.dtype
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() # noqa: E501
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
@@ -4497,13 +4494,15 @@ def _update_hybrid_attention_mamba_layout(
)

def initialize_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> dict[str, torch.Tensor]:
"""
Initialize the memory buffer for KV cache.

Args:
kv_cache_config: The KV cache config
kernel_block_sizes: The kernel block sizes for each KV cache group.

Returns:
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
@@ -4512,7 +4511,7 @@ def initialize_kv_cache_tensors(
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
# Change the memory buffer to the desired shape
kv_caches = self._reshape_kv_cache_tensors(
kv_cache_config, kv_cache_raw_tensors
kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
)

# Set up cross-layer KV cache sharing
@@ -4571,9 +4570,17 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)
# The kernel block size for all KV cache groups. For example, if
# kv_cache_manager uses block_size 256 for a given group, but the attention
# backends for that group only support block_size 64, we will return
# kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
# tokens each.
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
# Reinitialization needs to happen after initialize_attn_backend
self.may_reinitialize_input_batch(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
kv_caches = self.initialize_kv_cache_tensors(
kv_cache_config, kernel_block_sizes
)

if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
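As a worked example of the splitting described in the kernel_block_sizes comment above (illustrative numbers, not from a real config): with a KV cache manager block size of 256, a selected kernel block size of 64, and 10 physical blocks for a layer, the reshape in _reshape_kv_cache_tensors exposes 40 kernel-level blocks.

# Illustrative arithmetic only; names mirror _reshape_kv_cache_tensors.
kv_manager_block_size = 256  # block size used by the KV cache manager
kernel_block_size = 64       # size returned by select_common_block_size
num_blocks = 10              # physical blocks allocated for the layer

num_blocks_per_kv_block = kv_manager_block_size // kernel_block_size  # 256 // 64 == 4
kernel_num_blocks = num_blocks * num_blocks_per_kv_block              # 10 * 4 == 40

assert (num_blocks_per_kv_block, kernel_num_blocks) == (4, 40)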