From 4f81b65750a1d5163c34a1c25aaf2900cd23507d Mon Sep 17 00:00:00 2001
From: Chen Zhang
Date: Sun, 11 May 2025 07:15:21 -0700
Subject: [PATCH 01/44] hybrid allocator

Signed-off-by: Chen Zhang
---
 vllm/config.py                               |  21 +++
 vllm/engine/arg_utils.py                     |   3 +
 vllm/v1/core/block_pool.py                   |  98 +++++++----
 vllm/v1/core/kv_cache_coordinator.py         | 136 +++++++++++++++
 vllm/v1/core/kv_cache_manager.py             | 130 ++++++++-------
 vllm/v1/core/kv_cache_utils.py               | 165 ++++++++++++++++++-
 vllm/v1/core/sched/scheduler.py              |   6 +-
 vllm/v1/core/single_type_kv_cache_manager.py |  78 ++++-----
 vllm/v1/engine/core.py                       |   2 +
 vllm/v1/kv_cache_interface.py                |  23 ++-
 vllm/v1/worker/gpu_model_runner.py           | 113 +++++++++----
 vllm/v1/worker/tpu_model_runner.py           |   5 +-
 12 files changed, 589 insertions(+), 191 deletions(-)
 create mode 100644 vllm/v1/core/kv_cache_coordinator.py

diff --git a/vllm/config.py b/vllm/config.py
index 09e89c1116f1..6bfe89d2fe68 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2036,6 +2036,12 @@ class SchedulerConfig:
     default scheduler. Can be a class directly or the path to a class of form
     "mod.custom_class"."""
 
+    disable_hybrid_kv_cache_manager: bool = False
+    """If set to True, the KV cache manager will allocate the same size of KV
+    cache for all attention layers, even if the model has multiple types of
+    attention layers such as full attention and sliding window attention.
+    """
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -4347,6 +4353,21 @@ def __post_init__(self):
         if not self.instance_id:
             self.instance_id = random_uuid()[:5]
 
+        if (envs.VLLM_USE_V1
+                and not self.scheduler_config.disable_hybrid_kv_cache_manager):
+            # The warning should only be logged for hybrid models. Since we
+            # can't tell here whether the model is hybrid, we don't log the
+            # warning here and will log it later.
+            if not (current_platform.is_cuda() or current_platform.is_rocm()):
+                # Hybrid KV cache manager is not supported on non-GPU platforms.
+                self.disable_hybrid_kv_cache_manager = True
+            if self.kv_transfer_config is not None:
+                # Hybrid KV cache manager is not compatible with KV transfer.
+                self.disable_hybrid_kv_cache_manager = True
+            if self.kv_events_config is not None:
+                # Hybrid KV cache manager is not compatible with KV events.
+ self.disable_hybrid_kv_cache_manager = True + def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: # remove the sizes that not multiple of tp_size when diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 240142a1c5d1..0f18af5d8559 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -368,6 +368,9 @@ class EngineArgs: bool] = SchedulerConfig.enable_chunked_prefill disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input + disable_hybrid_allocator: bool = ( + SchedulerConfig.disable_hybrid_kv_cache_manager) + guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback guided_decoding_disable_any_whitespace: bool = \ diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index f2ed183b68fc..7ee4d8b26e6d 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -7,7 +7,7 @@ BlockStored, KVCacheEvent) from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, + GroupedKVCacheBlock, KVCacheBlock, generate_block_hash_extra_keys, hash_block_tokens) from vllm.v1.request import Request @@ -26,12 +26,15 @@ class BlockPool: Args: num_gpu_blocks: The number of blocks in the pool. enable_caching: Whether to enable prefix caching. + num_single_type_managers: The number of single_type_managers. + enable_kv_cache_events: Whether to enable kv cache events. """ def __init__( self, num_gpu_blocks: int, enable_caching: bool, + num_single_type_managers: int, enable_kv_cache_events: bool = False, ): assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 @@ -46,8 +49,10 @@ def __init__( # enabled). self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) - # {block_hash: {block ID: block}}. A cached block is - # a full block with a block hash that can be used for prefix caching. + # TODO: update comment + # {manager_id: {block_hash: {block ID: GroupedKVCacheBlock}}}. A cached + # block is a full block with a block hash that can be used for prefix + # caching. # The cached block may be used by running requests or in the # free_block_queue that could potentially be evicted. # NOTE: We currently don't de-duplicate the blocks in the cache, @@ -55,29 +60,33 @@ def __init__( # if there is already an identical block in the cache. This is because # we want to make sure the allocated block IDs won't change so that # block tables are append-only. - self.cached_block_hash_to_block: dict[BlockHashType, dict[ - int, KVCacheBlock]] = defaultdict(dict) - + self.cached_block_hash_to_block: list[dict[BlockHashType, dict[ + int, GroupedKVCacheBlock]]] = [ + defaultdict(dict) for _ in range(num_single_type_managers) + ] # To represent a placeholder block with block_id=0. # The ref_cnt of null_block is not maintained, needs special care to # avoid freeing it. self.null_block = self.free_block_queue.popleft() + self.num_single_type_managers = num_single_type_managers self.enable_kv_cache_events = enable_kv_cache_events self.kv_event_queue: list[KVCacheEvent] = [] - def get_cached_block(self, - block_hash: BlockHashType) -> Optional[KVCacheBlock]: + def get_cached_block(self, block_hash: BlockHashType, + manager_id: int) -> Optional[GroupedKVCacheBlock]: """Get a cached block by the block hash, or None if cache miss. If there are duplicated blocks, we return the first block in the cache. Args: block_hash: The hash value of the block. 
+ manager_id: The id of the single_type_manager. Returns: The cached block if it exists, or None. """ - cached_blocks = self.cached_block_hash_to_block.get(block_hash) + cached_blocks = self.cached_block_hash_to_block[manager_id].get( + block_hash) if not cached_blocks: return None first_block_id = next(iter(cached_blocks)) @@ -86,11 +95,12 @@ def get_cached_block(self, def cache_full_blocks( self, request: Request, - blocks: list[KVCacheBlock], + blocks: list[GroupedKVCacheBlock], block_hashes: list[BlockHashType], num_cached_blocks: int, num_full_blocks: int, block_size: int, + manager_id: int, hash_fn: Callable, ) -> None: """Cache a list of full blocks for prefix caching. @@ -110,6 +120,7 @@ def cache_full_blocks( num_full_blocks: The number of blocks that are full and should be cached after this function. block_size: Number of tokens in each block. + manager_id: The id of the single_type_manager. hash_fn: The hash function to use for block hashes. """ if num_cached_blocks == num_full_blocks: @@ -130,14 +141,16 @@ def cache_full_blocks( new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events else None) for i, blk in enumerate(new_full_blocks): + assert all(b.block_hash is None for b in blk.blocks) assert blk.block_hash is None if i < len(new_block_hashes): # The block hash may already be computed in # "get_computed_blocks" if the tokens are not generated by # this request (either the prompt tokens or the previously - # generated tokens with preemption). In this case we simply - # reuse the block hash. + # generated tokens with preemption). + # TODO: or other groups with the same block_size + # In this case we simply reuse the block hash. block_hash = new_block_hashes[i] else: # Otherwise compute the block hash and cache it in the request @@ -164,8 +177,12 @@ def cache_full_blocks( block_hashes.append(block_hash) # Update and added the full block to the cache. + for b in blk.blocks: + b.block_hash = block_hash + b.manager_id = manager_id blk.block_hash = block_hash - self.cached_block_hash_to_block[block_hash][blk.block_id] = blk + self.cached_block_hash_to_block[manager_id][block_hash][ + blk.master_block_id] = blk if new_hashes is not None: new_hashes.append(block_hash.hash_value) prev_block_hash_value = block_hash.hash_value @@ -227,20 +244,23 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: True if the block is evicted, False otherwise. """ block_hash = block.block_hash - if block_hash and block_hash in self.cached_block_hash_to_block: - block.reset_hash() - del self.cached_block_hash_to_block[block_hash][block.block_id] - - if len(self.cached_block_hash_to_block[block_hash]) == 0: - del self.cached_block_hash_to_block[block_hash] - + manager_id = block.manager_id + if block_hash and block_hash in self.cached_block_hash_to_block[ + manager_id]: + cached_blocks = ( + self.cached_block_hash_to_block[manager_id][block_hash]) + assert block.block_id in cached_blocks + cached_blocks[block.block_id].reset_hash() + del cached_blocks[block.block_id] + if len(cached_blocks) == 0: + del self.cached_block_hash_to_block[manager_id][block_hash] if self.enable_kv_cache_events: self.kv_event_queue.append( BlockRemoved(block_hashes=[block_hash.hash_value])) return True return False - def touch(self, blocks: list[KVCacheBlock]) -> None: + def touch(self, blocks: list[list[GroupedKVCacheBlock]]) -> None: """Touch a block increases its reference count by 1, and may remove the block from the free queue. 
This is used when a block is hit by another request with the same prefix. @@ -248,14 +268,18 @@ def touch(self, blocks: list[KVCacheBlock]) -> None: Args: blocks: A list of blocks to touch. """ - for block in blocks: - # ref_cnt=0 means this block is in the free list (i.e. eviction - # candidate), so remove it. - if block.ref_cnt == 0 and block != self.null_block: - self.free_block_queue.remove(block) - block.incr_ref() - - def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: + # TODO: check whether we should manage ref_cnt at grouped_block level + for blocks_one_manager in blocks: + for grouped_block in blocks_one_manager: + for block in grouped_block.blocks: + # ref_cnt=0 means this block is in the free list (i.e. + # eviction candidate), so remove it. + if block.ref_cnt == 0 and block != self.null_block: + self.free_block_queue.remove(block) + block.incr_ref() + + def free_blocks(self, + ordered_blocks: Iterable[GroupedKVCacheBlock]) -> None: """Free a list of blocks. The blocks should be ordered by their eviction priority, where the first block will be evicted first. @@ -263,11 +287,13 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: ordered_blocks: A list of blocks to free ordered by their eviction priority. """ - for block in ordered_blocks: - block.decr_ref() - # null_block should not be added to the free list. - if block.ref_cnt == 0 and block != self.null_block: - self.free_block_queue.append(block) + # TODO: make sure blocks in the first group are evicted first + for blk in ordered_blocks: + for block in blk.blocks: + block.decr_ref() + # null_block should not be added to the free list. + if block.ref_cnt == 0 and block != self.null_block: + self.free_block_queue.append(block) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -286,7 +312,9 @@ def reset_prefix_cache(self) -> bool: return False # Remove all hashes so that no new blocks will hit. - self.cached_block_hash_to_block = defaultdict(dict) + self.cached_block_hash_to_block = [ + defaultdict(dict) for _ in range(self.num_single_type_managers) + ] # Remove all hashes from all blocks. for block in self.blocks: diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py new file mode 100644 index 000000000000..a91142ea3263 --- /dev/null +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Callable + +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_utils import BlockHashType, GroupedKVCacheBlock +from vllm.v1.core.single_type_kv_cache_manager import ( + SingleTypeKVCacheManager, get_manager_for_kv_cache_spec) +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.request import Request + + +class KVCacheCoordinator: + """ + Coordinator the KV cache of different KV cache groups. 
+ # TODO: docstring for this class + """ + + def __init__(self, kv_cache_config: KVCacheConfig, block_pool: BlockPool, + max_model_len: int, use_eagle: bool, + caching_hash_fn: Callable): + self.block_pool = block_pool + self.kv_cache_config = kv_cache_config + self.max_model_len = max_model_len + + # the kv cache groups managed by the each manager + # manager_id -> list[kv_cache_group_id] + self.manager_to_group, self.group_to_manager = ( + self.generate_group_manager_map()) + self.num_single_type_manager = len(self.manager_to_group) + + self.single_type_managers: list[SingleTypeKVCacheManager] = [] + for i in range(len(self.manager_to_group)): + group_ids = self.manager_to_group[i] + kv_cache_spec = kv_cache_config.kv_cache_groups[ + group_ids[0]].kv_cache_spec + self.single_type_managers.append( + get_manager_for_kv_cache_spec( + kv_cache_spec=kv_cache_spec, + block_pool=self.block_pool, + use_eagle=use_eagle, + num_kv_cache_groups=len(self.manager_to_group[i]), + manager_id=i, + caching_hash_fn=caching_hash_fn, + )) + + def find_longest_cache_hit( + self, request: Request, block_hashes_dict: dict[int, + list[BlockHashType]], + max_cache_hit_length: int + ) -> tuple[list[list[GroupedKVCacheBlock]], int]: + """Find the longest cache hit for each kv cache group. + TODO: add more notes + """ + # TODO: implement this + raise NotImplementedError("Not implemented") + + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: + for manager in self.single_type_managers: + manager.remove_skipped_blocks(request_id, num_computed_tokens) + + def get_num_blocks_to_allocate( + self, request_id: str, num_tokens: int, + new_computed_blocks: list[list[GroupedKVCacheBlock]]) -> int: + num_blocks_to_allocate = 0 + for i, manager in enumerate(self.single_type_managers): + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks[i]) + return num_blocks_to_allocate + + def save_new_computed_blocks( + self, request_id: str, + new_computed_blocks: list[list[GroupedKVCacheBlock]]) -> None: + for i, manager in enumerate(self.single_type_managers): + manager.save_new_computed_blocks(request_id, + new_computed_blocks[i]) + + def cache_blocks(self, request: Request, + block_hashes: dict[int, list[BlockHashType]], + num_computed_tokens: int) -> None: + for manager in self.single_type_managers: + manager.cache_blocks(request, block_hashes[manager.block_size], + num_computed_tokens) + + def allocate_new_blocks( + self, request_id: str, + num_tokens: int) -> list[list[GroupedKVCacheBlock]]: + new_blocks = [] + for manager in self.single_type_managers: + new_blocks.append( + manager.allocate_new_blocks(request_id, num_tokens)) + return new_blocks + + def free(self, request_id: str) -> None: + for manager in self.single_type_managers: + manager.free(request_id) + + def get_num_common_prefix_blocks( + self, + request_id: str, + num_running_requests: int, + ) -> list[int]: + num_blocks_per_manager = [ + manager.get_num_common_prefix_blocks(request_id, + num_running_requests) + for manager in self.single_type_managers + ] + num_blocks_per_group = [ + num_blocks_per_manager[manager_id] + for manager_id, _ in self.group_to_manager + ] + return num_blocks_per_group + + def generate_group_manager_map( + self) -> tuple[list[list[int]], list[tuple[int, int]]]: + # TODO: refactor this function to ensure full attention is the first + # group + type_ids = [ + g.kv_cache_spec.type_id + for g in self.kv_cache_config.kv_cache_groups + ] + assert 
sorted(type_ids) == type_ids, "type_ids must be sorted" + manager_to_group: list[list[int]] = [] + for i, type_id in enumerate(type_ids): + if i == 0: + manager_to_group.append([i]) + else: + if type_id == type_ids[i - 1]: + manager_to_group[-1].append(i) + else: + manager_to_group.append([i]) + print("manager_to_group", manager_to_group) + group_to_manager = [(i, j) for i in range(len(manager_to_group)) + for j in range(len(manager_to_group[i]))] + print("group_to_manager", group_to_manager) + return manager_to_group, group_to_manager diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index da18ece7555a..396de64cd7c5 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -8,10 +8,9 @@ from vllm.logger import init_logger from vllm.utils import sha256 from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, +from vllm.v1.core.kv_cache_coordinator import KVCacheCoordinator +from vllm.v1.core.kv_cache_utils import (BlockHashType, GroupedKVCacheBlock, hash_request_tokens) -from vllm.v1.core.single_type_kv_cache_manager import ( - get_manager_for_kv_cache_spec) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus @@ -21,16 +20,22 @@ @dataclass class KVCacheBlocks: - blocks: list[KVCacheBlock] + blocks: list[list[GroupedKVCacheBlock]] + group_to_manager: list[tuple[int, int]] def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": """Adds two KVCacheBlocks instances.""" - return KVCacheBlocks(self.blocks + other.blocks) + assert self.group_to_manager is other.group_to_manager + return KVCacheBlocks( + [blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)], + self.group_to_manager) @classmethod - def create_empty(cls) -> "KVCacheBlocks": + def create_empty( + cls, group_to_manager: list[tuple[int, int]]) -> "KVCacheBlocks": """Creates a new KVCacheBlocks instance with no blocks.""" - return cls([]) + return cls([[] for _ in range(len(group_to_manager))], + group_to_manager) def get_block_ids(self) -> list[list[int]]: """ @@ -38,15 +43,23 @@ def get_block_ids(self) -> list[list[int]]: Returns: list[list[int]]: A two-level list where - * the outer list corresponds to KV cache groups (only 1 group now) + * the outer list corresponds to KV cache groups * each inner list contains the block_ids of the blocks in that group """ - return [[block.block_id for block in self.blocks]] + block_ids = [] + for manager_id, group_id_in_manager in self.group_to_manager: + block_ids.append([ + blk.blocks[group_id_in_manager].block_id + for blk in self.blocks[manager_id] + ]) + return block_ids def get_unhashed_block_ids(self) -> list[int]: """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" + assert len(self.group_to_manager) == 1, "Only one group is supported" return [ - block.block_id for block in self.blocks if block.block_hash is None + block.master_block_id for block in self.blocks[0] + if block.block_hash is None ] @@ -62,11 +75,6 @@ def __init__( log_stats: bool = False, enable_kv_cache_events: bool = False, ) -> None: - assert len(kv_cache_config.kv_cache_groups) == 1, ( - "KVCacheManager does not support hybrid models with more than 1 " - "kv cache group") - kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec - self.block_size = kv_cache_spec.block_size self.num_gpu_blocks = kv_cache_config.num_blocks self.max_model_len = max_model_len @@ 
-76,23 +84,27 @@ def __init__( self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - - self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching, + # TODO: remove hardcode num_managers + self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching, 2, enable_kv_cache_events) - self.single_type_manager = get_manager_for_kv_cache_spec( - kv_cache_spec=kv_cache_spec, + self.coordinator = KVCacheCoordinator( + kv_cache_config=kv_cache_config, block_pool=self.block_pool, + max_model_len=self.max_model_len, use_eagle=self.use_eagle, - num_kv_cache_groups=1, caching_hash_fn=self.caching_hash_fn, ) + self.group_to_manager = self.coordinator.group_to_manager # Mapping from request ID to kv block hashes. # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. - self.req_to_block_hashes: defaultdict[ - str, list[BlockHashType]] = defaultdict(list) + # TODO: update comment + self.req_to_block_hashes: defaultdict[str, dict[ + int, list[BlockHashType]]] = defaultdict(dict) + self.all_block_sizes = set(g.kv_cache_spec.block_size + for g in kv_cache_config.kv_cache_groups) @property def usage(self) -> float: @@ -132,14 +144,17 @@ def get_computed_blocks(self, # When the request requires prompt logprobs, we skip prefix caching. if (not self.enable_caching or request.sampling_params.prompt_logprobs is not None): - return KVCacheBlocks.create_empty(), 0 + return KVCacheBlocks.create_empty(self.group_to_manager), 0 # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. block_hashes = self.req_to_block_hashes[request.request_id] if not block_hashes: - block_hashes = hash_request_tokens(self.caching_hash_fn, - self.block_size, request) + block_hashes = { + block_size: + hash_request_tokens(self.caching_hash_fn, block_size, request) + for block_size in self.all_block_sizes + } self.req_to_block_hashes[request.request_id] = block_hashes if self.log_stats: @@ -154,19 +169,17 @@ def get_computed_blocks(self, # could slightly improve performance in the future. max_cache_hit_length = request.num_tokens - 1 - computed_blocks = self.single_type_manager.find_longest_cache_hit( - block_hashes, max_cache_hit_length) - # NOTE(woosuk): Since incomplete blocks are not eligible for - # sharing, `num_computed_tokens` is always a multiple of - # `block_size`. - num_computed_tokens = len(computed_blocks) * self.block_size + computed_blocks, num_new_computed_tokens = ( + self.coordinator.find_longest_cache_hit(request, block_hashes, + max_cache_hit_length)) if self.log_stats: assert self.prefix_cache_stats is not None - self.prefix_cache_stats.queries += request.num_tokens - self.prefix_cache_stats.hits += num_computed_tokens + self.prefix_cache_stats.queries += len(request.all_token_ids) + self.prefix_cache_stats.hits += num_new_computed_tokens - return KVCacheBlocks(computed_blocks), num_computed_tokens + return KVCacheBlocks(computed_blocks, + self.group_to_manager), num_new_computed_tokens def allocate_slots( self, @@ -218,7 +231,9 @@ def allocate_slots( if new_computed_blocks is not None: new_computed_block_list = new_computed_blocks.blocks else: - new_computed_block_list = [] + new_computed_block_list = [ + [] for _ in range(self.coordinator.num_single_type_manager) + ] # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). 
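To make the per-block-size hashing in get_computed_blocks above concrete, here is a rough standalone sketch (hash_prefix_blocks, the toy prompt, and the 16/32 block sizes are illustrative stand-ins, not the real hash_request_tokens): each distinct block size in all_block_sizes gets its own chain of prefix-block hashes, and every KV cache group using that block size reuses the same list.

    from hashlib import sha256

    def hash_prefix_blocks(token_ids: list[int], block_size: int) -> list[bytes]:
        # Chain-hash each full block so a hash commits to the entire prefix,
        # in the spirit of hash_request_tokens / hash_block_tokens.
        hashes: list[bytes] = []
        parent = b""
        num_full_tokens = len(token_ids) // block_size * block_size
        for start in range(0, num_full_tokens, block_size):
            block = tuple(token_ids[start:start + block_size])
            parent = sha256(repr((parent, block)).encode()).digest()
            hashes.append(parent)
        return hashes

    token_ids = list(range(100))  # toy prompt
    all_block_sizes = {16, 32}    # hypothetical: managers with different block sizes
    block_hashes = {
        size: hash_prefix_blocks(token_ids, size) for size in all_block_sizes
    }
    assert {size: len(h) for size, h in block_hashes.items()} == {16: 6, 32: 3}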
@@ -226,8 +241,8 @@ def allocate_slots( # insufficient free blocks. # Should call this function before allocating new blocks to reduce # the number of evicted blocks. - self.single_type_manager.remove_skipped_blocks( - request.request_id, request.num_computed_tokens) + self.coordinator.remove_skipped_blocks(request.request_id, + request.num_computed_tokens) # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits @@ -236,12 +251,12 @@ def allocate_slots( num_tokens_need_slot = min( num_computed_tokens + num_new_tokens + num_lookahead_tokens, self.max_model_len) - num_blocks_to_allocate = ( - self.single_type_manager.get_num_blocks_to_allocate( - request_id=request.request_id, - num_tokens=num_tokens_need_slot, - new_computed_blocks=new_computed_block_list, - )) + + num_blocks_to_allocate = (self.coordinator.get_num_blocks_to_allocate( + request_id=request.request_id, + num_tokens=num_tokens_need_slot, + new_computed_blocks=new_computed_block_list, + )) if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): # Cannot allocate new blocks @@ -257,25 +272,25 @@ def allocate_slots( # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - self.single_type_manager.save_new_computed_blocks( - request.request_id, new_computed_block_list) + self.coordinator.save_new_computed_blocks(request.request_id, + new_computed_block_list) - new_blocks = self.single_type_manager.allocate_new_blocks( + new_blocks = self.coordinator.allocate_new_blocks( request.request_id, num_tokens_need_slot) # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. if not self.enable_caching or delay_cache_blocks: - return KVCacheBlocks(new_blocks) + return KVCacheBlocks(new_blocks, self.group_to_manager) # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. - self.single_type_manager.cache_blocks( + self.coordinator.cache_blocks( request, self.req_to_block_hashes[request.request_id], num_computed_tokens + num_new_tokens - len(request.spec_token_ids)) - return KVCacheBlocks(new_blocks) + return KVCacheBlocks(new_blocks, self.group_to_manager) def free(self, request: Request) -> None: """Free the blocks allocated for the request. @@ -285,7 +300,7 @@ def free(self, request: Request) -> None: Args: request: The request to free the blocks. """ - self.single_type_manager.free(request.request_id) + self.coordinator.free(request.request_id) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -343,10 +358,8 @@ def get_num_common_prefix_blocks( group. """ assert request.status == RequestStatus.RUNNING - return [ - self.single_type_manager.get_num_common_prefix_blocks( - request.request_id, num_running_requests) - ] + return self.coordinator.get_num_common_prefix_blocks( + request.request_id, num_running_requests) def free_block_hashes(self, request: Request) -> None: """Discard the block hashes for the request. 
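The coordinator methods used above (get_num_blocks_to_allocate, allocate_new_blocks, get_num_common_prefix_blocks, ...) all fan out over the single-type managers, and per-group results are recovered through the group_to_manager map built by generate_group_manager_map. A simplified standalone sketch of that mapping (the type_id strings are invented for illustration; the real ones come from KVCacheSpec.type_id):

    def build_maps(type_ids: list[str]):
        # Consecutive KV cache groups with the same type_id share one manager.
        manager_to_group: list[list[int]] = []
        for group_id, type_id in enumerate(type_ids):
            if group_id > 0 and type_id == type_ids[group_id - 1]:
                manager_to_group[-1].append(group_id)
            else:
                manager_to_group.append([group_id])
        # group_to_manager[group_id] = (manager_id, index of the group within
        # that manager), mirroring the list comprehension in the coordinator.
        group_to_manager = [(m, i) for m, groups in enumerate(manager_to_group)
                            for i in range(len(groups))]
        return manager_to_group, group_to_manager

    manager_to_group, group_to_manager = build_maps(
        ["full_attention", "sliding_window", "sliding_window"])
    assert manager_to_group == [[0], [1, 2]]
    assert group_to_manager == [(0, 0), (1, 0), (1, 1)]

    # Expanding a per-manager result (e.g. common-prefix block counts) back to
    # one entry per KV cache group, as get_num_common_prefix_blocks does:
    num_blocks_per_manager = [7, 3]
    num_blocks_per_group = [num_blocks_per_manager[m] for m, _ in group_to_manager]
    assert num_blocks_per_group == [7, 3, 3]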
@@ -366,6 +379,9 @@ def take_events(self) -> list[KVCacheEvent]: def get_block_ids(self, request_id: str) -> list[list[int]]: """Get the block ids of a request.""" - assert request_id in self.single_type_manager.req_to_blocks - return KVCacheBlocks(self.single_type_manager.req_to_blocks[request_id] - ).get_block_ids() + # TODO: implement this + return [] + # assert request_id in self.single_type_manager.req_to_blocks + # return KVCacheBlocks(self.single_type_manager.req_to_blocks + # [request_id] + # ).get_block_ids() diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 403b5401be75..1afe24e20653 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """KV-Cache Utilities.""" import os -from collections import deque +from collections import defaultdict, deque from collections.abc import Sequence +from contextlib import contextmanager from dataclasses import dataclass from typing import Any, Callable, NamedTuple, Optional @@ -10,8 +11,9 @@ from vllm.logger import init_logger from vllm.utils import GiB_bytes, sha256 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - KVCacheTensor, SlidingWindowSpec) + KVCacheGroupSpec, KVCacheNewTensor, + KVCacheReuseTensor, KVCacheSpec, + SlidingWindowSpec) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -124,6 +126,8 @@ class KVCacheBlock: prev_free_block: Optional["KVCacheBlock"] = None next_free_block: Optional["KVCacheBlock"] = None + manager_id: int = -1 + def incr_ref(self): self.ref_cnt += 1 @@ -143,6 +147,7 @@ def block_hash(self, block_hash: BlockHashType): def reset_hash(self): """Reset the block hash when the block is evicted.""" self._block_hash = None + self.manager_id = -1 def __repr__(self) -> str: # Use block_id instead of KVCacheBlock object to avoid calling __repr__ @@ -648,7 +653,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, kv_cache_config = KVCacheConfig( num_blocks=num_blocks, tensors={ - layer_name: KVCacheTensor(size=per_layer_size) + layer_name: KVCacheNewTensor(size=per_layer_size) for layer_name in kv_cache_spec }, kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec, @@ -657,17 +662,106 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, return kv_cache_config +def is_kv_cache_page_size_uniform( + kv_cache_spec: dict[str, KVCacheSpec]) -> bool: + """ + Whether all layers in the given KVCacheSpec have the same page size. + Args: + kv_cache_spec: The KVCacheSpec of each attention layer in the model + + Returns: + True if all layers have the same page size, False otherwise. + """ + + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + return len(page_sizes) == 1 + + +def _get_kv_cache_config_uniform_page_size( + vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: + """ + Generates the KV cache configuration for a model with one page size. + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The KVCacheSpec of each attention layer in the model + available_memory: Memory available for KV cache in bytes. + Returns: + The generated KVCacheConfig + """ + # Group all layers by type_id. + # E.g., 2 full attention layers and 3 sliding window attention layers, + # -> (full.0, full.1), (sw.0, sw.1, sw.2). 
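As a standalone illustration of the grouping described in the comment above (hypothetical layer names; the real code keys on KVCacheSpec.type_id and only logs a warning for the implicit padding rather than storing placeholder entries):

    from collections import defaultdict

    # Layer name -> type_id: 2 full-attention and 3 sliding-window layers.
    layer_types = {
        "full.0": "full_attention", "full.1": "full_attention",
        "sw.0": "sliding_window", "sw.1": "sliding_window",
        "sw.2": "sliding_window",
    }

    same_type_layers: dict[str, list[str]] = defaultdict(list)
    for layer_name, type_id in layer_types.items():
        same_type_layers[type_id].append(layer_name)

    # Split every type into chunks of the smallest type's size; the short last
    # chunk of a type is conceptually padded up to group_size.
    group_size = min(len(layers) for layers in same_type_layers.values())
    grouped_layers = [
        layers[i:i + group_size]
        for layers in same_type_layers.values()
        for i in range(0, len(layers), group_size)
    ]
    assert grouped_layers == [["full.0", "full.1"], ["sw.0", "sw.1"], ["sw.2"]]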
+ same_type_layers: dict[str, list[str]] = defaultdict(list) + for layer_name, layer_spec in kv_cache_spec.items(): + same_type_layers[layer_spec.type_id].append(layer_name) + + # Split each group into smaller groups, to make the number of layers in each + # group identical. Add padding to the last group of each type if necessary. + # E.g., (full.0, full.1), (sw.0, sw.1, sw.2) + # split to 3 groups with 2 layers each: + # (full.0, full.1), (sw.0, sw.1), (sw.2, padding). + group_size = min([len(layers) for layers in same_type_layers.values()]) + grouped_layers = [] + for layers in same_type_layers.values(): + num_padding_layers = len(layers) % group_size + if num_padding_layers > 0: + logger.warning( + "Add %d padding layers, may waste at most %.2f%% KV cache memory", # noqa + num_padding_layers, + num_padding_layers / len(layers) * 100) + for i in range(0, len(layers), group_size): + grouped_layers.append(layers[i:i + group_size]) + + # Divide the available memory equally among all layers in the first group. + # The memory layout in the example will be: + # full.0: Tensor with size=available_memory//2 + # full.1: Tensor with size=available_memory//2 + kv_cache_spec_first_group = { + layer_name: kv_cache_spec[layer_name] + for layer_name in grouped_layers[0] + } + kv_cache_config = _get_kv_cache_config_uniform_type( + vllm_config, kv_cache_spec_first_group, available_memory) + + # Reuse the KV cache tensors of the first group for the other groups. + # The memory layout in the example will be: + # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2 + # full.1, sw.1: share another Tensor with size=available_memory//2 + # Layers of different groups have different block table, so they will + # use different parts of the shared Tensor. + for layers in grouped_layers[1:]: + for layer_name, layer_name_first_group in zip( + layers, grouped_layers[0][:len(layers)]): + kv_cache_config.tensors[layer_name] = KVCacheReuseTensor( + reused_layer_name=layer_name_first_group) + + kv_cache_config.kv_cache_groups = create_kv_cache_group_specs( + kv_cache_spec, grouped_layers) + return kv_cache_config + + def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ - Only models with one type of KV cache are supported yet. This function tries - to convert the KV cache specs to one type if the model is a hybrid model - with multiple type of KV cache. It will convert all SlidingWindowSpec to - FullAttentionSpec if both types are present. + This function tries to convert the KV cache specs to one type if the model + is a hybrid model with multiple type of KV cache. It will convert all + SlidingWindowSpec to FullAttentionSpec if both types are present. 
Args: kv_cache_spec: The kv cache spec of each attention layer in the model """ + def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: + type_ids = set(layer_spec.type_id + for layer_spec in kv_cache_spec.values()) + return len(type_ids) > 1 + + if not is_hybrid(kv_cache_spec): + return + # TODO: better warning message + logger.warning("Hybrid KV cache manager is disabled for this hybrid model," + "There can be some waste of KV cache memory.") + has_full_attention = any( isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values()) has_sliding_window = any( @@ -684,6 +778,12 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): sliding_window=spec.sliding_window, ) + if not is_hybrid(kv_cache_spec): + # TODO: better error message + raise ValueError( + "Hybrid KV cache manager is disabled but we failed to " + "convert the KV cache specs to one type.") + def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], @@ -701,13 +801,21 @@ def get_kv_cache_config(vllm_config: VllmConfig, The generated KVCacheConfigs """ check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) - unify_hybrid_kv_cache_specs(kv_cache_spec) + + if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: + unify_hybrid_kv_cache_specs(kv_cache_spec) + if is_kv_cache_type_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for # most models. Allocate the same amount of memory for # each layer. return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, available_memory) + elif is_kv_cache_page_size_uniform(kv_cache_spec): + # KV cache of all layers have the same page size. TODO more notes + return _get_kv_cache_config_uniform_page_size(vllm_config, + kv_cache_spec, + available_memory) raise NotImplementedError @@ -746,3 +854,42 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): kv_cache_config.num_blocks = min_num_blocks return kv_cache_configs + + +@contextmanager +def remove_last_block_hash_for_divisible_prompt_length( + block_hashes: dict[int, list[BlockHashType]], num_tokens: int): + """ + Remove the last block hash for the case where the prompt length is divisible + by the block size and all blocks are cached. + """ + last_block_hashs: dict[int, BlockHashType] = {} + for block_size in block_hashes: + if len(block_hashes[block_size]) * block_size == num_tokens: + last_block_hashs[block_size] = block_hashes[block_size].pop() + yield + for block_size, block_hash in last_block_hashs.items(): + block_hashes[block_size].append(block_hash) + + +# KVCacheBlocks for the same block of all kv cache groups with the same kv cache +# spec (and belongs to the same manager) +# TODO: more notes +# TODO: optimize the creation of GroupedKVCacheBlock +@dataclass +class GroupedKVCacheBlock: + blocks: tuple[KVCacheBlock, ...] 
+ block_hash: Optional[BlockHashType] = None + master_block_id: int = -1 + ref_cnt: int = 0 + + @staticmethod + def from_kv_cache_blocks(blocks: tuple[KVCacheBlock, ...]): + return GroupedKVCacheBlock(blocks=blocks, + block_hash=blocks[0].block_hash, + master_block_id=blocks[0].block_id) + + def reset_hash(self): + for block in self.blocks: + block.reset_hash() + self.block_hash = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5ad05485e8f3..5061cbc81203 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -352,7 +352,9 @@ def schedule(self) -> SchedulerOutput: request) else: # P/D: skip checking prefix cache if loaded from remote kvs. - new_computed_blocks = KVCacheBlocks.create_empty() + # TODO: add util function to create empty blocks + new_computed_blocks = KVCacheBlocks.create_empty( + self.kv_cache_manager.group_to_manager) num_native_computed_tokens = 0 # Get externally-cached tokens if using a KVConnector. @@ -966,7 +968,7 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: num_computed_tokens = len(block_ids) * self.block_size if num_computed_tokens == request.num_tokens: num_computed_tokens -= 1 - self.kv_cache_manager.single_type_manager.cache_blocks( + self.kv_cache_manager.coordinator.cache_blocks( request, self.kv_cache_manager.req_to_block_hashes[request.request_id], num_computed_tokens, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 0223c9ceec8d..32d814f3a2e8 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -5,7 +5,7 @@ from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.core.kv_cache_utils import BlockHashType, GroupedKVCacheBlock from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, SlidingWindowSpec) from vllm.v1.request import Request @@ -23,6 +23,7 @@ def __init__( block_pool: BlockPool, use_eagle: bool, num_kv_cache_groups: int, + manager_id: int, caching_hash_fn: Callable, ) -> None: """ @@ -33,6 +34,7 @@ def __init__( use_eagle: Whether to use eagle. num_kv_cache_groups: The number of kv cache groups managed by this manager. + manager_id: The id of this manager. caching_hash_fn: The caching hash function. """ @@ -46,8 +48,8 @@ def __init__( # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: defaultdict[str, - list[KVCacheBlock]] = defaultdict(list) + self.req_to_blocks: defaultdict[ + str, list[GroupedKVCacheBlock]] = defaultdict(list) # {req_id: The number of cached blocks for this given request} # This is used to track the number of cached blocks for each request. @@ -57,10 +59,11 @@ def __init__( self.num_kv_cache_groups = num_kv_cache_groups self.caching_hash_fn = caching_hash_fn + self.manager_id = manager_id def get_num_blocks_to_allocate( self, request_id: str, num_tokens: int, - new_computed_blocks: list[KVCacheBlock]) -> int: + new_computed_blocks: list[GroupedKVCacheBlock]) -> int: """ Get the number of blocks needed to be allocated for the request. 
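A GroupedKVCacheBlock ties together one physical KVCacheBlock per KV cache group served by a manager, with the first block's id used as the master_block_id. Allocation then works on these bundles: a manager serving N groups takes N physical blocks per logical block and folds them together, as allocate_new_blocks does below. A rough sketch with toy ids and a simplified stand-in Block type (not the real KVCacheBlock):

    from dataclasses import dataclass

    @dataclass
    class Block:  # minimal stand-in for KVCacheBlock
        block_id: int

    def bundle(flat_blocks: list[Block],
               num_kv_cache_groups: int) -> list[tuple[Block, ...]]:
        # Chunk a flat allocation into one tuple of blocks per logical block.
        assert len(flat_blocks) % num_kv_cache_groups == 0
        return [
            tuple(flat_blocks[i:i + num_kv_cache_groups])
            for i in range(0, len(flat_blocks), num_kv_cache_groups)
        ]

    flat = [Block(b) for b in (10, 11, 12, 13, 14, 15)]
    groups = bundle(flat, num_kv_cache_groups=3)
    assert [[b.block_id for b in g] for g in groups] == [[10, 11, 12], [13, 14, 15]]
    masters = [g[0].block_id for g in groups]  # the bundles' master_block_ids
    assert masters == [10, 13]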
@@ -89,7 +92,7 @@ def get_num_blocks_to_allocate( def save_new_computed_blocks( self, request_id: str, - new_computed_blocks: list[KVCacheBlock]) -> None: + new_computed_blocks: list[GroupedKVCacheBlock]) -> None: """ Add the new computed blocks to the request. @@ -109,7 +112,7 @@ def save_new_computed_blocks( assert len(new_computed_blocks) == 0 def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[KVCacheBlock]: + num_tokens: int) -> list[GroupedKVCacheBlock]: """ Allocate new blocks for the request to give it at least `num_tokens` token slots. @@ -128,8 +131,16 @@ def allocate_new_blocks(self, request_id: str, if num_new_blocks <= 0: return [] else: - new_blocks = self.block_pool.get_new_blocks( + flat_new_blocks = self.block_pool.get_new_blocks( num_new_blocks * self.num_kv_cache_groups) + # TODO: accelerate for num_blocks=1 + new_blocks = [] + for i in range(num_new_blocks): + blocks = flat_new_blocks[i * self.num_kv_cache_groups:(i + 1) * + self.num_kv_cache_groups] + grouped_block = GroupedKVCacheBlock.from_kv_cache_blocks( + tuple(blocks)) + new_blocks.append(grouped_block) req_blocks.extend(new_blocks) return new_blocks @@ -154,6 +165,7 @@ def cache_blocks(self, request: Request, block_hashes: list[BlockHashType], num_cached_blocks=num_cached_blocks, num_full_blocks=num_full_blocks, block_size=self.block_size, + manager_id=self.manager_id, hash_fn=self.caching_hash_fn, ) @@ -188,7 +200,7 @@ def get_num_common_prefix_blocks(self, request_id: str, @abstractmethod def find_longest_cache_hit(self, block_hashes: list[BlockHashType], - max_length: int) -> list[KVCacheBlock]: + max_length: int) -> list[GroupedKVCacheBlock]: """ Get the longest cache hit prefix of the blocks that is not longer than `max_length`. If no cache hit is found, return an empty list. @@ -229,15 +241,16 @@ def remove_skipped_blocks(self, request_id: str, class FullAttentionManager(SingleTypeKVCacheManager): def find_longest_cache_hit(self, block_hashes: list[BlockHashType], - max_length: int) -> list[KVCacheBlock]: - computed_blocks: list[KVCacheBlock] = [] + max_length: int) -> list[GroupedKVCacheBlock]: + computed_blocks: list[GroupedKVCacheBlock] = [] max_num_blocks = max_length // self.block_size for i in range(max_num_blocks): block_hash = block_hashes[i] # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. - if cached_block := self.block_pool.get_cached_block(block_hash): + if cached_block := self.block_pool.get_cached_block( + block_hash, self.manager_id): computed_blocks.append(cached_block) else: break @@ -278,43 +291,14 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, # contiguous blocks needed for prefix cache hit by one and dropping # the last matched block. self.sliding_window_contiguous_blocks += 1 - self._null_block = block_pool.null_block + single_null_block = block_pool.null_block + self._null_block = GroupedKVCacheBlock.from_kv_cache_blocks( + tuple([single_null_block] * self.num_kv_cache_groups)) def find_longest_cache_hit(self, block_hashes: list[BlockHashType], - max_length: int) -> list[KVCacheBlock]: - # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to - # optimize the time complexity from O(max_num_blocks) to - # O(max_num_blocks / sliding_window_contiguous_blocks + - # sliding_window_contiguous_blocks), - # which is good for low cache hit rate scenarios. 
- max_num_blocks = max_length // self.block_size - computed_blocks = [self._null_block] * max_num_blocks - num_contiguous_blocks = 0 - - match_found = False - # Search from right to left and early stop when a match is found. - for i in range(max_num_blocks - 1, -1, -1): - if cached_block := self.block_pool.get_cached_block( - block_hashes[i]): - computed_blocks[i] = cached_block - num_contiguous_blocks += 1 - if (num_contiguous_blocks - >= self.sliding_window_contiguous_blocks): - # Trim the trailing blocks. - # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] - # when sliding_window_contiguous_blocks=2. - del computed_blocks[i + num_contiguous_blocks:] - match_found = True - break - else: - num_contiguous_blocks = 0 - if not match_found: - # The first `num_contiguous_blocks` is a cache hit even if - # `num_contiguous_blocks < sliding_window_contiguous_blocks`. - del computed_blocks[num_contiguous_blocks:] - if self.use_eagle and len(computed_blocks) > 0: - computed_blocks.pop() - return computed_blocks + max_length: int) -> list[GroupedKVCacheBlock]: + # TODO + return [] def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: @@ -323,7 +307,7 @@ def remove_skipped_blocks(self, request_id: str, last_useful_token = num_computed_tokens - self.sliding_window + 1 last_useful_block = last_useful_token // self.block_size blocks = self.req_to_blocks[request_id] - removed_blocks: list[KVCacheBlock] = [] + removed_blocks: list[GroupedKVCacheBlock] = [] for i in range(last_useful_block - 1, -1, -1): if blocks[i] == self._null_block: # If the block is already a null block, the blocks before it diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0cf2383af1c9..f281a186e6dd 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -155,6 +155,8 @@ def _initialize_kv_caches( num_gpu_blocks = kv_cache_configs[0].num_blocks num_cpu_blocks = 0 scheduler_kv_cache_config = kv_cache_configs[0] + # TODO: remove this debug print + print("kv_cache_config", scheduler_kv_cache_config) # Initialize kv cache and warmup the execution self.model_executor.initialize_from_config(kv_cache_configs) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 2747fc7fabd1..821c08a78dbe 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -154,15 +154,30 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: @dataclass -class KVCacheTensor: +class KVCacheTensorBase: """ A dataclass for specifying how the workers should initialize the KV cache - for a layer. Only contains the size of KV cache for that layer for now. Will - be extended to support multiple layers sharing the same memory pool. + for a layer. + """ + pass + + +@dataclass +class KVCacheNewTensor(KVCacheTensorBase): + """ + Initialize the KV cache with a tensor of `size` bytes. """ size: int # The size of KV cache Tensor in bytes +@dataclass +class KVCacheReuseTensor(KVCacheTensorBase): + """ + Reuse the KV cache tensor of `layer_name` for the current layer. + """ + reused_layer_name: str + + @dataclass class KVCacheGroupSpec: """ @@ -183,7 +198,7 @@ class KVCacheConfig: """The number of KV cache blocks""" num_blocks: int """layer_name -> how to initialize KV cache for that layer""" - tensors: dict[str, KVCacheTensor] + tensors: dict[str, KVCacheTensorBase] """ The kv cache groups of the model. 
The layers in the models are repeated with some patterns, e.g., a model diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1b34a9fb0616..f4cb5cab7950 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -39,7 +39,8 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, - KVCacheConfig, KVCacheSpec, + KVCacheConfig, KVCacheNewTensor, + KVCacheReuseTensor, KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) @@ -1867,6 +1868,81 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) + def _initialize_kv_cache_buffer( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + """ + Initializes the KV cache buffer with the correct size. The buffer needs + to be reshaped to the desired shape before being used by the models. + Args: + kv_cache_config: The KV cache config + Returns: + dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + kv_cache_raw_tensors: dict[str, torch.Tensor] = {} + for layer_name, tensor_config in kv_cache_config.tensors.items(): + if isinstance(tensor_config, KVCacheNewTensor): + # A new tensor with `tensor_config.size` bytes + kv_cache_raw_tensors[layer_name] = torch.zeros( + tensor_config.size, dtype=torch.int8, device=self.device) + for layer_name, tensor_config in kv_cache_config.tensors.items(): + if isinstance(tensor_config, KVCacheReuseTensor): + # Reuse a tensor from `kv_cache_raw_tensors` + kv_cache_raw_tensors[layer_name] = kv_cache_raw_tensors[ + tensor_config.reused_layer_name] + assert len(kv_cache_raw_tensors) == len( + kv_cache_config.tensors), "Some layers are not initialized" + return kv_cache_raw_tensors + + def _setup_kv_cache_shapes( + self, + kv_cache_config: KVCacheConfig, + kv_cache_raw_tensors: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + """ + Reshape the KV cache tensors to the desired shape. + Args: + kv_cache_config: The KV cache config + kv_cache_raw_tensors: The KV cache buffer of each layer, with + correct size but uninitialized shape. + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. 
+ """ + kv_caches: dict[str, torch.Tensor] = {} + for i, kv_cache_group_spec in enumerate( + kv_cache_config.kv_cache_groups): + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + for layer_name in kv_cache_group_spec.layer_names: + raw_tensor = kv_cache_raw_tensors[layer_name] + assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 + num_blocks = (raw_tensor.numel() // + kv_cache_spec.page_size_bytes) + if isinstance(kv_cache_spec, AttentionSpec): + kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + kv_caches[layer_name] = kv_cache_raw_tensors[ + layer_name].view(dtype).view(kv_cache_shape) + else: + raise NotImplementedError + return kv_caches + + def initialize_kv_cache_tensors( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + # TODO: docstring + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._initialize_kv_cache_buffer( + kv_cache_config) + # Change the memory buffer to the desired shape + kv_caches = self._setup_kv_cache_shapes(kv_cache_config, + kv_cache_raw_tensors) + bind_kv_cache( + kv_caches, + self.vllm_config.compilation_config.static_forward_context, []) + return kv_caches + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -1885,40 +1961,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config=kv_cache_config, ) self.initialize_attn_backend(kv_cache_config) - - kv_caches: dict[str, torch.Tensor] = {} - - for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): - kv_cache_spec = kv_cache_group.kv_cache_spec - for layer_name in kv_cache_group.layer_names: - tensor_config = kv_cache_config.tensors[layer_name] - assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes - # `num_blocks` is the number of blocks the model runner can use. - # `kv_cache_config.num_blocks` is the number of blocks that - # KVCacheManager may allocate. - # Since different GPUs may have different number of layers and - # different memory capacities, `num_blocks` can be different on - # different GPUs, and `kv_cache_config.num_blocks` is set to - # the min of all `num_blocks`. Verify it here. - assert num_blocks >= kv_cache_config.num_blocks - if isinstance(kv_cache_spec, AttentionSpec): - kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - dtype = kv_cache_spec.dtype - kv_caches[layer_name] = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) - else: - # TODO: add new branches when introducing more types of - # KV cache specs. 
- raise ValueError("Unknown KV cache spec type.") - - bind_kv_cache( - kv_caches, - self.vllm_config.compilation_config.static_forward_context, - self.kv_caches) + kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 2da99696445e..cf701f7c1fe9 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -31,8 +31,8 @@ PallasMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, - KVCacheConfig, KVCacheSpec, - SlidingWindowSpec) + KVCacheConfig, KVCacheNewTensor, + KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata @@ -1272,6 +1272,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_spec = kv_cache_group.kv_cache_spec for layer_name in kv_cache_group.layer_names: tensor_config = kv_cache_config.tensors[layer_name] + assert isinstance(tensor_config, KVCacheNewTensor) assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes if isinstance(kv_cache_spec, AttentionSpec): From ec550211c4acc122ed20201a80481187100dc319 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 12 May 2025 06:59:22 -0700 Subject: [PATCH 02/44] refactor Signed-off-by: Chen Zhang --- .../v1/e2e/test_correctness_sliding_window.py | 4 +- vllm/v1/core/block_pool.py | 95 ++++--- vllm/v1/core/kv_cache_coordinator.py | 254 ++++++++++++++---- vllm/v1/core/kv_cache_manager.py | 6 +- vllm/v1/core/kv_cache_utils.py | 57 ++-- vllm/v1/core/single_type_kv_cache_manager.py | 98 ++++--- 6 files changed, 350 insertions(+), 164 deletions(-) diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index a125d3fb7975..92bb87aeaa59 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -17,7 +17,7 @@ class TestConfig: model_config = { "bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)), - "google/gemma-2-2b-it": TestConfig(4096, (400, 800)), + "google/gemma-3-1b-it": TestConfig(4096, (400, 800)), } @@ -25,7 +25,7 @@ class TestConfig: "model", [ "bigcode/starcoder2-3b", # sliding window only - "google/gemma-2-2b-it", # sliding window + full attention + "google/gemma-3-1b-it", # sliding window + full attention ]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 7ee4d8b26e6d..6c72060b2d83 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -7,7 +7,7 @@ BlockStored, KVCacheEvent) from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - GroupedKVCacheBlock, KVCacheBlock, + KVCacheBlock, KVCacheBlockBundle, generate_block_hash_extra_keys, hash_block_tokens) from vllm.v1.request import Request @@ -49,19 +49,20 @@ def __init__( # enabled). self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) - # TODO: update comment - # {manager_id: {block_hash: {block ID: GroupedKVCacheBlock}}}. A cached - # block is a full block with a block hash that can be used for prefix - # caching. 
+ # {manager_id: {block_hash: {block ID: KVCacheBlockBundle}}}. + # A cached block is a full block with a block hash that can be used for + # prefix caching. # The cached block may be used by running requests or in the # free_block_queue that could potentially be evicted. + # Use KVCacheBlockBundle to make sure different kv cache groups managed + # by the same single_type_manager are cached & evicted together. # NOTE: We currently don't de-duplicate the blocks in the cache, # meaning that if a block becomes full and is cached, we don't check # if there is already an identical block in the cache. This is because # we want to make sure the allocated block IDs won't change so that # block tables are append-only. self.cached_block_hash_to_block: list[dict[BlockHashType, dict[ - int, GroupedKVCacheBlock]]] = [ + int, KVCacheBlockBundle]]] = [ defaultdict(dict) for _ in range(num_single_type_managers) ] # To represent a placeholder block with block_id=0. @@ -74,7 +75,7 @@ def __init__( self.kv_event_queue: list[KVCacheEvent] = [] def get_cached_block(self, block_hash: BlockHashType, - manager_id: int) -> Optional[GroupedKVCacheBlock]: + manager_id: int) -> Optional[KVCacheBlockBundle]: """Get a cached block by the block hash, or None if cache miss. If there are duplicated blocks, we return the first block in the cache. @@ -95,7 +96,7 @@ def get_cached_block(self, block_hash: BlockHashType, def cache_full_blocks( self, request: Request, - blocks: list[GroupedKVCacheBlock], + blocks: list[KVCacheBlockBundle], block_hashes: list[BlockHashType], num_cached_blocks: int, num_full_blocks: int, @@ -141,15 +142,14 @@ def cache_full_blocks( new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events else None) for i, blk in enumerate(new_full_blocks): - assert all(b.block_hash is None for b in blk.blocks) - assert blk.block_hash is None + assert blk.block_hash_is_none() if i < len(new_block_hashes): # The block hash may already be computed in # "get_computed_blocks" if the tokens are not generated by # this request (either the prompt tokens or the previously - # generated tokens with preemption). - # TODO: or other groups with the same block_size + # generated tokens with preemption), or by other + # single_type_managers with the same block_size. # In this case we simply reuse the block hash. block_hash = new_block_hashes[i] else: @@ -177,10 +177,7 @@ def cache_full_blocks( block_hashes.append(block_hash) # Update and added the full block to the cache. - for b in blk.blocks: - b.block_hash = block_hash - b.manager_id = manager_id - blk.block_hash = block_hash + blk.init_block_hash(block_hash, manager_id) self.cached_block_hash_to_block[manager_id][block_hash][ blk.master_block_id] = blk if new_hashes is not None: @@ -200,37 +197,46 @@ def cache_full_blocks( if request.lora_request else None, )) - def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: + def get_new_blocks(self, num_block_bundle: int, + bundle_size: int) -> list[KVCacheBlockBundle]: """Get new blocks from the free block pool. Note that we do not check block cache in this function. Args: - num_blocks: The number of blocks to allocate. + num_block_bundle: The number of KVCacheBlockBundle to allocate. + bundle_size: The number of blocks in each KVCacheBlockBundle. Returns: A list of new block. 
""" - if num_blocks > self.get_num_free_blocks(): + num_total_blocks = num_block_bundle * bundle_size + if num_total_blocks > self.get_num_free_blocks(): raise ValueError( - f"Cannot get {num_blocks} free blocks from the pool") + f"Cannot get {num_total_blocks} free blocks from the pool") - ret: list[KVCacheBlock] = [] + flat_new_blocks: list[KVCacheBlock] = [] idx = 0 - while idx < num_blocks: + while idx < num_total_blocks: # First allocate blocks. curr_block = self.free_block_queue.popleft() - assert curr_block.ref_cnt == 0 # If the block is cached, evict it. if self.enable_caching: self._maybe_evict_cached_block(curr_block) - curr_block.incr_ref() - ret.append(curr_block) + assert curr_block.block_hash is None + flat_new_blocks.append(curr_block) idx += 1 - return ret + new_blocks = [] + for i in range(num_block_bundle): + blocks = flat_new_blocks[i * bundle_size:(i + 1) * bundle_size] + block_bundle = KVCacheBlockBundle.from_kv_cache_blocks( + tuple(blocks)) + block_bundle.incr_ref() + new_blocks.append(block_bundle) + return new_blocks def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: """ @@ -249,8 +255,11 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: manager_id]: cached_blocks = ( self.cached_block_hash_to_block[manager_id][block_hash]) - assert block.block_id in cached_blocks - cached_blocks[block.block_id].reset_hash() + cached_block = cached_blocks[block.block_id] + # TODO: add notes + assert cached_block.master_block_id == block.block_id + assert cached_block.ref_cnt == 0 + cached_block.reset_hash() del cached_blocks[block.block_id] if len(cached_blocks) == 0: del self.cached_block_hash_to_block[manager_id][block_hash] @@ -260,7 +269,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: return True return False - def touch(self, blocks: list[list[GroupedKVCacheBlock]]) -> None: + def touch(self, blocks: list[list[KVCacheBlockBundle]]) -> None: """Touch a block increases its reference count by 1, and may remove the block from the free queue. This is used when a block is hit by another request with the same prefix. @@ -268,18 +277,18 @@ def touch(self, blocks: list[list[GroupedKVCacheBlock]]) -> None: Args: blocks: A list of blocks to touch. """ - # TODO: check whether we should manage ref_cnt at grouped_block level for blocks_one_manager in blocks: - for grouped_block in blocks_one_manager: - for block in grouped_block.blocks: - # ref_cnt=0 means this block is in the free list (i.e. - # eviction candidate), so remove it. - if block.ref_cnt == 0 and block != self.null_block: - self.free_block_queue.remove(block) - block.incr_ref() + for block_bundle in blocks_one_manager: + if block_bundle.ref_cnt == 0: + # ref_cnt=0 means the blocks are in the free list (i.e. + # eviction candidate), so remove them. + for block in block_bundle.blocks: + if block != self.null_block: + self.free_block_queue.remove(block) + block_bundle.incr_ref() def free_blocks(self, - ordered_blocks: Iterable[GroupedKVCacheBlock]) -> None: + ordered_blocks: Iterable[KVCacheBlockBundle]) -> None: """Free a list of blocks. The blocks should be ordered by their eviction priority, where the first block will be evicted first. @@ -288,11 +297,13 @@ def free_blocks(self, priority. 
""" # TODO: make sure blocks in the first group are evicted first - for blk in ordered_blocks: - for block in blk.blocks: - block.decr_ref() + for block_bundle in ordered_blocks: + block_bundle.decr_ref() + if block_bundle.ref_cnt > 0: + continue + for block in block_bundle.blocks: # null_block should not be added to the free list. - if block.ref_cnt == 0 and block != self.null_block: + if block != self.null_block: self.free_block_queue.append(block) def reset_prefix_cache(self) -> bool: diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index a91142ea3263..ca57bc47244d 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -1,18 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 +from collections import defaultdict from typing import Callable from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHashType, GroupedKVCacheBlock +from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlockBundle from vllm.v1.core.single_type_kv_cache_manager import ( - SingleTypeKVCacheManager, get_manager_for_kv_cache_spec) -from vllm.v1.kv_cache_interface import KVCacheConfig + FullAttentionManager, SingleTypeKVCacheManager, + get_manager_for_kv_cache_spec) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig from vllm.v1.request import Request class KVCacheCoordinator: """ Coordinator the KV cache of different KV cache groups. - # TODO: docstring for this class """ def __init__(self, kv_cache_config: KVCacheConfig, block_pool: BlockPool, @@ -22,11 +23,8 @@ def __init__(self, kv_cache_config: KVCacheConfig, block_pool: BlockPool, self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len - # the kv cache groups managed by the each manager - # manager_id -> list[kv_cache_group_id] self.manager_to_group, self.group_to_manager = ( self.generate_group_manager_map()) - self.num_single_type_manager = len(self.manager_to_group) self.single_type_managers: list[SingleTypeKVCacheManager] = [] for i in range(len(self.manager_to_group)): @@ -42,26 +40,24 @@ def __init__(self, kv_cache_config: KVCacheConfig, block_pool: BlockPool, manager_id=i, caching_hash_fn=caching_hash_fn, )) + self.verify_support_find_longest_cache_hit() - def find_longest_cache_hit( - self, request: Request, block_hashes_dict: dict[int, - list[BlockHashType]], - max_cache_hit_length: int - ) -> tuple[list[list[GroupedKVCacheBlock]], int]: - """Find the longest cache hit for each kv cache group. - TODO: add more notes + def get_num_blocks_to_allocate( + self, request_id: str, num_tokens: int, + new_computed_blocks: list[list[KVCacheBlockBundle]]) -> int: """ - # TODO: implement this - raise NotImplementedError("Not implemented") + Get the number of blocks needed to be allocated for the request. - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: - for manager in self.single_type_managers: - manager.remove_skipped_blocks(request_id, num_computed_tokens) + Args: + request_id: The request ID. + num_tokens: The total number of tokens that need a slot (including + tokens that are already allocated). + new_computed_blocks: The new computed blocks just hitting the + prefix caching. - def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, - new_computed_blocks: list[list[GroupedKVCacheBlock]]) -> int: + Returns: + The number of blocks. 
+ """ num_blocks_to_allocate = 0 for i, manager in enumerate(self.single_type_managers): num_blocks_to_allocate += manager.get_num_blocks_to_allocate( @@ -70,28 +66,62 @@ def get_num_blocks_to_allocate( def save_new_computed_blocks( self, request_id: str, - new_computed_blocks: list[list[GroupedKVCacheBlock]]) -> None: + new_computed_blocks: list[list[KVCacheBlockBundle]]) -> None: + """ + Add the new computed blocks to the request. + + Args: + request_id: The request ID. + new_computed_blocks: The new computed blocks just hitting the + prefix cache. + """ for i, manager in enumerate(self.single_type_managers): manager.save_new_computed_blocks(request_id, new_computed_blocks[i]) - def cache_blocks(self, request: Request, - block_hashes: dict[int, list[BlockHashType]], - num_computed_tokens: int) -> None: - for manager in self.single_type_managers: - manager.cache_blocks(request, block_hashes[manager.block_size], - num_computed_tokens) + def allocate_new_blocks(self, request_id: str, + num_tokens: int) -> list[list[KVCacheBlockBundle]]: + """ + Allocate new blocks for the request to give it at least `num_tokens` + token slots. - def allocate_new_blocks( - self, request_id: str, - num_tokens: int) -> list[list[GroupedKVCacheBlock]]: + Args: + request_id: The request ID. + num_tokens: The total number of tokens that need a slot (including + tokens that are already allocated). + + Returns: + The new allocated blocks. + """ new_blocks = [] for manager in self.single_type_managers: new_blocks.append( manager.allocate_new_blocks(request_id, num_tokens)) return new_blocks + def cache_blocks(self, request: Request, + block_hashes: dict[int, list[BlockHashType]], + num_computed_tokens: int) -> None: + """ + Cache the blocks for the request. + + Args: + request: The request. + block_hashes: The block hashes of the request. + num_tokens: The total number of tokens that need to be cached + (including tokens that are already cached). + """ + for manager in self.single_type_managers: + manager.cache_blocks(request, block_hashes[manager.block_size], + num_computed_tokens) + def free(self, request_id: str) -> None: + """ + Free the blocks for the request. + + Args: + request_id: The request ID. + """ for manager in self.single_type_managers: manager.free(request_id) @@ -100,6 +130,16 @@ def get_num_common_prefix_blocks( request_id: str, num_running_requests: int, ) -> list[int]: + """ + Get the number of common prefix blocks for a request. + + Args: + request_id: The request ID. + block_hashes: The block hashes of the request. + + Returns: + The number of common prefix blocks. + """ num_blocks_per_manager = [ manager.get_num_common_prefix_blocks(request_id, num_running_requests) @@ -111,26 +151,136 @@ def get_num_common_prefix_blocks( ] return num_blocks_per_group + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: + """ + Remove the blocks that are no longer needed from `blocks` and replace + the removed blocks with null_block. + + Args: + request_id: The request ID. + num_computed_tokens: The number of tokens that have been computed. + """ + for manager in self.single_type_managers: + manager.remove_skipped_blocks(request_id, num_computed_tokens) + + def find_longest_cache_hit( + self, request: Request, block_hashes_dict: dict[int, + list[BlockHashType]], + max_cache_hit_length: int + ) -> tuple[list[list[KVCacheBlockBundle]], int]: + """ + Find the longest cache hit for the request. + + Args: + request: The request. 
+ block_hashes_dict: The block hashes of the request. + max_cache_hit_length: TODO(Chen): docstring + + Returns: + A tuple containing: + - A list of the cache hit blocks for each single type manager. + - The number of tokens of the longest cache hit. + """ + if len(self.single_type_managers) == 1: + # Return the cache hit blocks for the only kv cache group. + block_size = self.kv_cache_config.kv_cache_groups[ + 0].kv_cache_spec.block_size + hit_blocks = self.single_type_managers[0].find_longest_cache_hit( + block_hashes_dict[block_size], max_length=max_cache_hit_length) + return [hit_blocks], len(hit_blocks) * block_size + + elif len(self.single_type_managers) == 2: + # For simplicity, we assume the first manager is for full + # attention layers, and the block_size of full attention layers + # is divisible by other attention layers. This has been verified + # in verify_support_find_longest_cache_hit(). + + block_size_0 = self.single_type_managers[0].block_size + block_size_1 = self.single_type_managers[1].block_size + + # First, find the longest cache hit for full attention. + hit_blocks_full_attn = self.single_type_managers[ + 0].find_longest_cache_hit(block_hashes_dict[block_size_0], + max_length=max_cache_hit_length) + hit_length = len(hit_blocks_full_attn) * block_size_0 + + # Next, find the cache hit for the other attention WITHIN + # the cache hit of full attention. + hit_blocks_other_attn = self.single_type_managers[ + 1].find_longest_cache_hit(block_hashes_dict[block_size_1], + max_length=hit_length) + hit_length = len(hit_blocks_other_attn) * block_size_1 + assert hit_length % block_size_0 == 0 + + # Truncate the full attention cache hit to the length of the + # cache hit of the other attention. + del hit_blocks_full_attn[hit_length // block_size_0:] + + return [hit_blocks_full_attn, hit_blocks_other_attn], hit_length + + else: + raise NotImplementedError( + "KVCacheCoordinator does not support more than 2 different" + "types of layers yet.") + def generate_group_manager_map( self) -> tuple[list[list[int]], list[tuple[int, int]]]: - # TODO: refactor this function to ensure full attention is the first - # group - type_ids = [ - g.kv_cache_spec.type_id - for g in self.kv_cache_config.kv_cache_groups + """ + Generate the mapping between kv cache groups and managers. + + Returns: + manager_to_group: list[list[int]], the kv cache groups managed by + each manager. + group_to_manager: list[tuple[int, int]], the manager id and the + index of the group in the manager for each kv cache group. 
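The two-manager branch of find_longest_cache_hit above can be traced with concrete numbers; everything here is assumed for illustration (both managers use block_size 16, full attention reports a 4-block hit, the other manager then finds 3 blocks within those 64 tokens):

    block_size_0 = block_size_1 = 16
    hit_blocks_full_attn = ["F0", "F1", "F2", "F3"]          # 4 * 16 = 64 tokens
    hit_blocks_other_attn = [None, "S1", "S2"]               # hit found within 64 tokens
    hit_length = len(hit_blocks_other_attn) * block_size_1   # 48 tokens
    assert hit_length % block_size_0 == 0
    del hit_blocks_full_attn[hit_length // block_size_0:]    # keep the first 3 blocks
    assert hit_blocks_full_attn == ["F0", "F1", "F2"]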
+ """ + gathered = defaultdict(list) + full_attention_type_ids = set() + for i, g in enumerate(self.kv_cache_config.kv_cache_groups): + gathered[g.kv_cache_spec.type_id].append(i) + if isinstance(g.kv_cache_spec, FullAttentionSpec): + full_attention_type_ids.add(g.kv_cache_spec.type_id) + + manager_to_group = [] + for type_id in full_attention_type_ids: + manager_to_group.append(gathered[type_id]) + for type_id in gathered.keys() - full_attention_type_ids: + manager_to_group.append(gathered[type_id]) + + group_to_manager_dict = { + group_id: (manager_id, group_id_in_manager) + for manager_id, group_ids in enumerate(manager_to_group) + for group_id_in_manager, group_id in enumerate(group_ids) + } + group_to_manager = [ + group_to_manager_dict[i] + for i in range(len(self.kv_cache_config.kv_cache_groups)) ] - assert sorted(type_ids) == type_ids, "type_ids must be sorted" - manager_to_group: list[list[int]] = [] - for i, type_id in enumerate(type_ids): - if i == 0: - manager_to_group.append([i]) - else: - if type_id == type_ids[i - 1]: - manager_to_group[-1].append(i) - else: - manager_to_group.append([i]) - print("manager_to_group", manager_to_group) - group_to_manager = [(i, j) for i in range(len(manager_to_group)) - for j in range(len(manager_to_group[i]))] - print("group_to_manager", group_to_manager) return manager_to_group, group_to_manager + + def verify_support_find_longest_cache_hit(self) -> None: + """ + For simplicity, find_longest_cache_hit makes some assumptions on the + model architecture instead of provides a general solution. This function + checks if the assumptions hold. + NOTE(Chen): Please open an issue to discuss if you need other cases. + """ + if len(self.single_type_managers) == 1: + return + if len(self.single_type_managers) == 2: + if not isinstance(self.single_type_managers[0], + FullAttentionManager): + raise NotImplementedError( + "KVCacheCoordinator assumes hybrid models have at least one" + " full attention layer now") + block_size_0 = self.single_type_managers[0].block_size + block_size_1 = self.single_type_managers[1].block_size + if block_size_1 % block_size_0 != 0: + raise NotImplementedError( + "KVCacheCoordinator assumes the block_size of the full " + "attention layer is divisible by other layers now.") + else: + raise NotImplementedError( + "KVCacheCoordinator does not support more than 2 different " + "types of layers yet.") diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 396de64cd7c5..0ffc967c46b0 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -9,7 +9,7 @@ from vllm.utils import sha256 from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_coordinator import KVCacheCoordinator -from vllm.v1.core.kv_cache_utils import (BlockHashType, GroupedKVCacheBlock, +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlockBundle, hash_request_tokens) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats @@ -20,7 +20,7 @@ @dataclass class KVCacheBlocks: - blocks: list[list[GroupedKVCacheBlock]] + blocks: list[list[KVCacheBlockBundle]] group_to_manager: list[tuple[int, int]] def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": @@ -232,7 +232,7 @@ def allocate_slots( new_computed_block_list = new_computed_blocks.blocks else: new_computed_block_list = [ - [] for _ in range(self.coordinator.num_single_type_manager) + [] for _ in self.coordinator.single_type_managers ] # Free the blocks that are skipped 
during the attention computation diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 1afe24e20653..8fc885a5484e 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -3,7 +3,6 @@ import os from collections import defaultdict, deque from collections.abc import Sequence -from contextlib import contextmanager from dataclasses import dataclass from typing import Any, Callable, NamedTuple, Optional @@ -115,8 +114,6 @@ class KVCacheBlock: """KV-cache block metadata.""" # Block ID, ranging from 0 to num_gpu_blocks - 1. block_id: int - # Reference count. - ref_cnt: int = 0 # The hash of the block composed of (block hash, tuple of token IDs). # It is only available when the block is full. _block_hash: Optional[BlockHashType] = None @@ -128,12 +125,6 @@ class KVCacheBlock: manager_id: int = -1 - def incr_ref(self): - self.ref_cnt += 1 - - def decr_ref(self): - self.ref_cnt -= 1 - @property def block_hash(self) -> Optional[BlockHashType]: return self._block_hash @@ -157,7 +148,6 @@ def __repr__(self) -> str: next_block_id = self.next_free_block.block_id \ if self.next_free_block else None return (f"KVCacheBlock(block_id={self.block_id}, " - f"ref_cnt={self.ref_cnt}, " f"_block_hash={self._block_hash}, " f"prev_free_block={prev_block_id}, " f"next_free_block={next_block_id})") @@ -856,40 +846,43 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): return kv_cache_configs -@contextmanager -def remove_last_block_hash_for_divisible_prompt_length( - block_hashes: dict[int, list[BlockHashType]], num_tokens: int): - """ - Remove the last block hash for the case where the prompt length is divisible - by the block size and all blocks are cached. - """ - last_block_hashs: dict[int, BlockHashType] = {} - for block_size in block_hashes: - if len(block_hashes[block_size]) * block_size == num_tokens: - last_block_hashs[block_size] = block_hashes[block_size].pop() - yield - for block_size, block_hash in last_block_hashs.items(): - block_hashes[block_size].append(block_hash) - - # KVCacheBlocks for the same block of all kv cache groups with the same kv cache # spec (and belongs to the same manager) # TODO: more notes -# TODO: optimize the creation of GroupedKVCacheBlock +# TODO: optimize the creation of KVCacheBlockBundle @dataclass -class GroupedKVCacheBlock: +class KVCacheBlockBundle: blocks: tuple[KVCacheBlock, ...] block_hash: Optional[BlockHashType] = None - master_block_id: int = -1 + # Reference count. 
ref_cnt: int = 0 + def incr_ref(self): + self.ref_cnt += 1 + + def decr_ref(self): + self.ref_cnt -= 1 + + @property + def master_block_id(self): + return self.blocks[0].block_id + @staticmethod def from_kv_cache_blocks(blocks: tuple[KVCacheBlock, ...]): - return GroupedKVCacheBlock(blocks=blocks, - block_hash=blocks[0].block_hash, - master_block_id=blocks[0].block_id) + return KVCacheBlockBundle(blocks=blocks, + block_hash=blocks[0].block_hash) def reset_hash(self): for block in self.blocks: block.reset_hash() self.block_hash = None + + def block_hash_is_none(self): + return self.block_hash is None and all(block.block_hash is None + for block in self.blocks) + + def init_block_hash(self, block_hash: BlockHashType, manager_id: int): + self.block_hash = block_hash + for b in self.blocks: + b.block_hash = block_hash + b.manager_id = manager_id diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 32d814f3a2e8..23b17a28f361 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -5,7 +5,7 @@ from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHashType, GroupedKVCacheBlock +from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlockBundle from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, SlidingWindowSpec) from vllm.v1.request import Request @@ -49,7 +49,7 @@ def __init__( # for each request, so that we can free the blocks when the request # is finished. self.req_to_blocks: defaultdict[ - str, list[GroupedKVCacheBlock]] = defaultdict(list) + str, list[KVCacheBlockBundle]] = defaultdict(list) # {req_id: The number of cached blocks for this given request} # This is used to track the number of cached blocks for each request. @@ -63,7 +63,7 @@ def __init__( def get_num_blocks_to_allocate( self, request_id: str, num_tokens: int, - new_computed_blocks: list[GroupedKVCacheBlock]) -> int: + new_computed_blocks: list[KVCacheBlockBundle]) -> int: """ Get the number of blocks needed to be allocated for the request. @@ -92,7 +92,7 @@ def get_num_blocks_to_allocate( def save_new_computed_blocks( self, request_id: str, - new_computed_blocks: list[GroupedKVCacheBlock]) -> None: + new_computed_blocks: list[KVCacheBlockBundle]) -> None: """ Add the new computed blocks to the request. @@ -112,7 +112,7 @@ def save_new_computed_blocks( assert len(new_computed_blocks) == 0 def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[GroupedKVCacheBlock]: + num_tokens: int) -> list[KVCacheBlockBundle]: """ Allocate new blocks for the request to give it at least `num_tokens` token slots. 
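Since the reference count now lives on the bundle rather than on each block, the allocate/free lifecycle can be summarized with a stripped-down stand-in for the KVCacheBlockBundle helpers above (a sketch, not the vLLM types; block ids are assumed):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Block:
        block_id: int
        block_hash: Optional[int] = None

    @dataclass
    class Bundle:
        blocks: tuple                       # one Block per kv cache group
        block_hash: Optional[int] = None
        ref_cnt: int = 0                    # counted per bundle, not per block

        @property
        def master_block_id(self) -> int:
            return self.blocks[0].block_id  # first block represents the bundle

    bundle = Bundle(blocks=(Block(7), Block(8)))
    bundle.ref_cnt += 1                     # allocated to a request
    assert bundle.master_block_id == 7
    bundle.ref_cnt -= 1                     # freed: at ref_cnt == 0 all of its
    assert bundle.ref_cnt == 0              # blocks re-enter the free queue together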
@@ -131,16 +131,8 @@ def allocate_new_blocks(self, request_id: str, if num_new_blocks <= 0: return [] else: - flat_new_blocks = self.block_pool.get_new_blocks( - num_new_blocks * self.num_kv_cache_groups) - # TODO: accelerate for num_blocks=1 - new_blocks = [] - for i in range(num_new_blocks): - blocks = flat_new_blocks[i * self.num_kv_cache_groups:(i + 1) * - self.num_kv_cache_groups] - grouped_block = GroupedKVCacheBlock.from_kv_cache_blocks( - tuple(blocks)) - new_blocks.append(grouped_block) + new_blocks = self.block_pool.get_new_blocks( + num_new_blocks, self.num_kv_cache_groups) req_blocks.extend(new_blocks) return new_blocks @@ -172,6 +164,12 @@ def cache_blocks(self, request: Request, block_hashes: list[BlockHashType], self.num_cached_block[request.request_id] = num_full_blocks def free(self, request_id: str) -> None: + """ + Free the blocks for the request. + + Args: + request_id: The request ID. + """ # Default to [] in case a request is freed (aborted) before alloc. req_blocks = self.req_to_blocks.pop(request_id, []) @@ -199,8 +197,11 @@ def get_num_common_prefix_blocks(self, request_id: str, raise NotImplementedError @abstractmethod - def find_longest_cache_hit(self, block_hashes: list[BlockHashType], - max_length: int) -> list[GroupedKVCacheBlock]: + def find_longest_cache_hit( + self, + block_hashes: list[BlockHashType], + max_length: int, + ) -> list[KVCacheBlockBundle]: """ Get the longest cache hit prefix of the blocks that is not longer than `max_length`. If no cache hit is found, return an empty list. @@ -210,7 +211,7 @@ def find_longest_cache_hit(self, block_hashes: list[BlockHashType], Args: block_hashes: The block hashes of the request. - max_length: The maximum length of the cache hit prefix. + max_length: The maximum length of the cache hit. Returns: A list of cached blocks with skipped blocks replaced by null block. @@ -226,10 +227,8 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: """ Remove the blocks that are no longer needed from `blocks`. The removed - blocks should be replaced by null_block. Return the removed blocks in - eviction order, where the first returned block should be evicted first. - Don't free the removed blocks in this function. Need to be customized - for each attention type. + blocks should be replaced by null_block. Need to be customized for each + attention type. Args: request_id: The request ID. @@ -241,14 +240,14 @@ def remove_skipped_blocks(self, request_id: str, class FullAttentionManager(SingleTypeKVCacheManager): def find_longest_cache_hit(self, block_hashes: list[BlockHashType], - max_length: int) -> list[GroupedKVCacheBlock]: - computed_blocks: list[GroupedKVCacheBlock] = [] + max_length: int) -> list[KVCacheBlockBundle]: + computed_blocks = [] max_num_blocks = max_length // self.block_size for i in range(max_num_blocks): block_hash = block_hashes[i] - # block_hashes is a chain of block hashes. If a block hash is not - # in the cached_block_hash_to_id, the following block hashes are - # not computed yet for sure. + # block_hashes is a chain of block hashes. If a block hash is + # not in the cached_block_hash_to_id, the following block hashes + # are not computed yet for sure. if cached_block := self.block_pool.get_cached_block( block_hash, self.manager_id): computed_blocks.append(cached_block) @@ -292,13 +291,46 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, # the last matched block. 
self.sliding_window_contiguous_blocks += 1 single_null_block = block_pool.null_block - self._null_block = GroupedKVCacheBlock.from_kv_cache_blocks( + self._null_block = KVCacheBlockBundle.from_kv_cache_blocks( tuple([single_null_block] * self.num_kv_cache_groups)) - def find_longest_cache_hit(self, block_hashes: list[BlockHashType], - max_length: int) -> list[GroupedKVCacheBlock]: - # TODO - return [] + def find_longest_cache_hit( + self, + block_hashes: list[BlockHashType], + max_length: int, + ) -> list[KVCacheBlockBundle]: + # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to + # optimize the time complexity from O(len(block_hashes)) to + # O(len(block_hashes) / sliding_window_contiguous_blocks + + # sliding_window_contiguous_blocks), + # which is good for low cache hit rate scenarios. + max_num_blocks = max_length // self.block_size + computed_blocks = [self._null_block] * max_num_blocks + num_contiguous_blocks = 0 + match_found = False + # Search from right to left and early stop when a match is found. + for i in range(max_num_blocks - 1, -1, -1): + if cached_block := self.block_pool.get_cached_block( + block_hashes[i], self.manager_id): + computed_blocks[i] = cached_block + num_contiguous_blocks += 1 + if (num_contiguous_blocks + >= self.sliding_window_contiguous_blocks): + # Trim the trailing blocks. + # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] + # when sliding_window_contiguous_blocks=2. + del computed_blocks[i + num_contiguous_blocks:] + match_found = True + break + else: + num_contiguous_blocks = 0 + if not match_found: + # The first `num_contiguous_blocks` is a cache hit even if + # `num_contiguous_blocks < sliding_window_contiguous_blocks`. + del computed_blocks[num_contiguous_blocks:] + if self.use_eagle and len(computed_blocks) > 0: + computed_blocks.pop() + return computed_blocks def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: @@ -307,7 +339,7 @@ def remove_skipped_blocks(self, request_id: str, last_useful_token = num_computed_tokens - self.sliding_window + 1 last_useful_block = last_useful_token // self.block_size blocks = self.req_to_blocks[request_id] - removed_blocks: list[GroupedKVCacheBlock] = [] + removed_blocks: list[KVCacheBlockBundle] = [] for i in range(last_useful_block - 1, -1, -1): if blocks[i] == self._null_block: # If the block is already a null block, the blocks before it From 41e027ab5bab318158d54951596f1e866c9c9729 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 14 May 2025 20:26:40 -0700 Subject: [PATCH 03/44] fix bug Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 9 +++++++++ vllm/v1/core/kv_cache_manager.py | 8 ++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index ca57bc47244d..faf9e62fca67 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -164,6 +164,15 @@ def remove_skipped_blocks(self, request_id: str, for manager in self.single_type_managers: manager.remove_skipped_blocks(request_id, num_computed_tokens) + def get_block_ids(self, request_id: str) -> list[list[KVCacheBlockBundle]]: + """ + Get the block IDs for the request. 
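The right-to-left search added to SlidingWindowManager above can be reproduced in isolation; `cached` stands in for block_pool.get_cached_block(), and every concrete value below is assumed:

    NULL = None

    def sliding_window_cache_hit(block_hashes, cached, max_num_blocks,
                                 contiguous_needed):
        computed = [NULL] * max_num_blocks
        run = 0
        # Search from right to left, stopping at the first long-enough run.
        for i in range(max_num_blocks - 1, -1, -1):
            block = cached.get(block_hashes[i])
            if block is not None:
                computed[i] = block
                run += 1
                if run >= contiguous_needed:
                    del computed[i + run:]   # trim everything after the run
                    return computed
            else:
                run = 0
        del computed[run:]                   # keep only the contiguous prefix
        return computed

    hits = sliding_window_cache_hit(
        block_hashes=[10, 11, 12, 13],
        cached={11: "B11", 12: "B12"},       # only blocks 1 and 2 are cached
        max_num_blocks=4,
        contiguous_needed=2)
    assert hits == [NULL, "B11", "B12"]      # the trailing miss at block 3 is trimmed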
+ """ + return [ + manager.req_to_blocks[request_id] + for manager in self.single_type_managers + ] + def find_longest_cache_hit( self, request: Request, block_hashes_dict: dict[int, list[BlockHashType]], diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 0ffc967c46b0..30c97f39a67e 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -379,9 +379,5 @@ def take_events(self) -> list[KVCacheEvent]: def get_block_ids(self, request_id: str) -> list[list[int]]: """Get the block ids of a request.""" - # TODO: implement this - return [] - # assert request_id in self.single_type_manager.req_to_blocks - # return KVCacheBlocks(self.single_type_manager.req_to_blocks - # [request_id] - # ).get_block_ids() + return KVCacheBlocks(self.coordinator.get_block_ids(request_id), + self.group_to_manager).get_block_ids() From 0735539c57e77c8e39cbc434d3c72cdbf5eef474 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 14 May 2025 22:59:17 -0700 Subject: [PATCH 04/44] minor updates Signed-off-by: Chen Zhang --- vllm/config.py | 2 +- vllm/v1/core/block_pool.py | 43 ++++++++++++++------ vllm/v1/core/single_type_kv_cache_manager.py | 2 +- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 6bfe89d2fe68..082f22d5e869 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4356,7 +4356,7 @@ def __post_init__(self): if (envs.VLLM_USE_V1 and not self.scheduler_config.disable_hybrid_kv_cache_manager): # logger should only print warning message for hybrid models. As we - # can't know whether the model is hybrid or not, we don't log + # can't know whether the model is hybrid or not now, we don't log # warning message here and will log it later. if not (current_platform.is_cuda() or current_platform.is_rocm()): # Hybrid KV cache manager is not supported on non-GPU platforms. diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 6c72060b2d83..9d9bea67775b 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -178,6 +178,18 @@ def cache_full_blocks( # Update and added the full block to the cache. blk.init_block_hash(block_hash, manager_id) + # We make all blocks in the same KVCacheBlockBundle cached & + # uncached together. This is achieved by: + # 1. Here, use the master_block_id as the representative of the + # KVCacheBlockBundle in the cache. + # 2. In `free_blocks`, add the master block to the free list before + # adding the other blocks in the bundle. + # 3. In `_maybe_evict_cached_block`, as the master block is in front + # of other blocks in the bundle, it will be the first evicted block + # in the bundle. When a master block needs to be evicted, we remove + # the full bundle from cached_block_hash_to_block and remove the + # master block from free_block_queue. The other blocks are still in + # the free_block_queue but won't be hit by get_cached_block. self.cached_block_hash_to_block[manager_id][block_hash][ blk.master_block_id] = blk if new_hashes is not None: @@ -197,9 +209,9 @@ def cache_full_blocks( if request.lora_request else None, )) - def get_new_blocks(self, num_block_bundle: int, - bundle_size: int) -> list[KVCacheBlockBundle]: - """Get new blocks from the free block pool. + def get_new_block_bundles(self, num_block_bundle: int, + bundle_size: int) -> list[KVCacheBlockBundle]: + """Get new block bundles from the free block pool. Note that we do not check block cache in this function. 
@@ -215,7 +227,7 @@ def get_new_blocks(self, num_block_bundle: int, raise ValueError( f"Cannot get {num_total_blocks} free blocks from the pool") - flat_new_blocks: list[KVCacheBlock] = [] + new_blocks: list[KVCacheBlock] = [] idx = 0 while idx < num_total_blocks: # First allocate blocks. @@ -226,17 +238,18 @@ def get_new_blocks(self, num_block_bundle: int, self._maybe_evict_cached_block(curr_block) assert curr_block.block_hash is None - flat_new_blocks.append(curr_block) + new_blocks.append(curr_block) idx += 1 - new_blocks = [] + new_block_bundles: list[KVCacheBlockBundle] = [] for i in range(num_block_bundle): - blocks = flat_new_blocks[i * bundle_size:(i + 1) * bundle_size] + blocks = new_blocks[i * bundle_size:(i + 1) * bundle_size] + # TODO: optimize the creation of KVCacheBlockBundle class block_bundle = KVCacheBlockBundle.from_kv_cache_blocks( tuple(blocks)) block_bundle.incr_ref() - new_blocks.append(block_bundle) - return new_blocks + new_block_bundles.append(block_bundle) + return new_block_bundles def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: """ @@ -253,10 +266,11 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: manager_id = block.manager_id if block_hash and block_hash in self.cached_block_hash_to_block[ manager_id]: - cached_blocks = ( - self.cached_block_hash_to_block[manager_id][block_hash]) + cached_blocks = self.cached_block_hash_to_block[manager_id][ + block_hash] cached_block = cached_blocks[block.block_id] - # TODO: add notes + # The block is the master block of the KVCacheBlockBundle. + # See comments in cache_full_blocks for details. assert cached_block.master_block_id == block.block_id assert cached_block.ref_cnt == 0 cached_block.reset_hash() @@ -296,11 +310,14 @@ def free_blocks(self, ordered_blocks: A list of blocks to free ordered by their eviction priority. """ - # TODO: make sure blocks in the first group are evicted first for block_bundle in ordered_blocks: block_bundle.decr_ref() if block_bundle.ref_cnt > 0: continue + # NOTE: should add the master block to the free list before adding + # the other blocks. See the comment in `cache_full_blocks` + # for the reason. The following loop implicitly achieves it because + # the master block is the first block in the bundle. for block in block_bundle.blocks: # null_block should not be added to the free list. 
if block != self.null_block: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 23b17a28f361..b2aeb9c3aca8 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -131,7 +131,7 @@ def allocate_new_blocks(self, request_id: str, if num_new_blocks <= 0: return [] else: - new_blocks = self.block_pool.get_new_blocks( + new_blocks = self.block_pool.get_new_block_bundles( num_new_blocks, self.num_kv_cache_groups) req_blocks.extend(new_blocks) return new_blocks From 5e5384002e5dd895f341aff0a8d1b55711e74764 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 14 May 2025 23:41:24 -0700 Subject: [PATCH 05/44] minor updates Signed-off-by: Chen Zhang --- vllm/v1/core/block_pool.py | 4 +- vllm/v1/core/kv_cache_coordinator.py | 40 ++++++++++++-------- vllm/v1/core/kv_cache_manager.py | 16 +++----- vllm/v1/core/kv_cache_utils.py | 14 +++---- vllm/v1/core/sched/scheduler.py | 1 - vllm/v1/core/single_type_kv_cache_manager.py | 24 +++++------- vllm/v1/worker/gpu_model_runner.py | 12 +++++- 7 files changed, 59 insertions(+), 52 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 9d9bea67775b..e889577e1885 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -244,7 +244,7 @@ def get_new_block_bundles(self, num_block_bundle: int, new_block_bundles: list[KVCacheBlockBundle] = [] for i in range(num_block_bundle): blocks = new_blocks[i * bundle_size:(i + 1) * bundle_size] - # TODO: optimize the creation of KVCacheBlockBundle class + # TODO: avoid frequent creation of KVCacheBlockBundle class block_bundle = KVCacheBlockBundle.from_kv_cache_blocks( tuple(blocks)) block_bundle.incr_ref() @@ -269,7 +269,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: cached_blocks = self.cached_block_hash_to_block[manager_id][ block_hash] cached_block = cached_blocks[block.block_id] - # The block is the master block of the KVCacheBlockBundle. + # The block is the master block of its KVCacheBlockBundle. # See comments in cache_full_blocks for details. assert cached_block.master_block_id == block.block_id assert cached_block.ref_cnt == 0 diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index faf9e62fca67..a8de932fabcd 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -16,16 +16,25 @@ class KVCacheCoordinator: Coordinator the KV cache of different KV cache groups. """ - def __init__(self, kv_cache_config: KVCacheConfig, block_pool: BlockPool, - max_model_len: int, use_eagle: bool, - caching_hash_fn: Callable): - self.block_pool = block_pool + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_caching: bool, + caching_hash_fn: Callable, + enable_kv_cache_events: bool, + ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len + # One manager for each different kv_cache_spec, managing all kv cache + # groups with the same kv_cache_spec. 
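For the kind of hybrid layout this comment describes, the two mappings built just below would look like this (a model with one full-attention group and two sliding-window groups is assumed):

    kv_cache_groups = ["full", "sw", "sw"]        # group ids 0, 1, 2 (assumed layout)
    manager_to_group = [[0], [1, 2]]              # full attention manager comes first
    group_to_manager = [(0, 0), (1, 0), (1, 1)]   # group id -> (manager id, index in manager)
    assert [manager_to_group[m][i] for m, i in group_to_manager] == [0, 1, 2]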
self.manager_to_group, self.group_to_manager = ( self.generate_group_manager_map()) - + self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching, + len(self.manager_to_group), + enable_kv_cache_events) self.single_type_managers: list[SingleTypeKVCacheManager] = [] for i in range(len(self.manager_to_group)): group_ids = self.manager_to_group[i] @@ -174,9 +183,10 @@ def get_block_ids(self, request_id: str) -> list[list[KVCacheBlockBundle]]: ] def find_longest_cache_hit( - self, request: Request, block_hashes_dict: dict[int, - list[BlockHashType]], - max_cache_hit_length: int + self, + request: Request, + block_hashes_dict: dict[int, list[BlockHashType]], + max_cache_hit_length: int, ) -> tuple[list[list[KVCacheBlockBundle]], int]: """ Find the longest cache hit for the request. @@ -184,7 +194,7 @@ def find_longest_cache_hit( Args: request: The request. block_hashes_dict: The block hashes of the request. - max_cache_hit_length: TODO(Chen): docstring + max_cache_hit_length: The maximum length of the cache hit. Returns: A tuple containing: @@ -244,18 +254,18 @@ def generate_group_manager_map( group_to_manager: list[tuple[int, int]], the manager id and the index of the group in the manager for each kv cache group. """ - gathered = defaultdict(list) - full_attention_type_ids = set() + groups_by_type_id: dict[str, list[int]] = defaultdict(list) + full_attention_type_ids: set[str] = set() for i, g in enumerate(self.kv_cache_config.kv_cache_groups): - gathered[g.kv_cache_spec.type_id].append(i) + groups_by_type_id[g.kv_cache_spec.type_id].append(i) if isinstance(g.kv_cache_spec, FullAttentionSpec): full_attention_type_ids.add(g.kv_cache_spec.type_id) manager_to_group = [] for type_id in full_attention_type_ids: - manager_to_group.append(gathered[type_id]) - for type_id in gathered.keys() - full_attention_type_ids: - manager_to_group.append(gathered[type_id]) + manager_to_group.append(groups_by_type_id[type_id]) + for type_id in groups_by_type_id.keys() - full_attention_type_ids: + manager_to_group.append(groups_by_type_id[type_id]) group_to_manager_dict = { group_id: (manager_id, group_id_in_manager) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 30c97f39a67e..f94f19d229f0 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -7,7 +7,6 @@ from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger from vllm.utils import sha256 -from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_coordinator import KVCacheCoordinator from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlockBundle, hash_request_tokens) @@ -75,7 +74,6 @@ def __init__( log_stats: bool = False, enable_kv_cache_events: bool = False, ) -> None: - self.num_gpu_blocks = kv_cache_config.num_blocks self.max_model_len = max_model_len self.enable_caching = enable_caching @@ -84,27 +82,25 @@ def __init__( self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - # TODO: remove hardcode num_managers - self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching, 2, - enable_kv_cache_events) self.coordinator = KVCacheCoordinator( kv_cache_config=kv_cache_config, - block_pool=self.block_pool, max_model_len=self.max_model_len, use_eagle=self.use_eagle, + enable_caching=enable_caching, caching_hash_fn=self.caching_hash_fn, + enable_kv_cache_events=enable_kv_cache_events, ) self.group_to_manager = 
self.coordinator.group_to_manager + self.block_pool = self.coordinator.block_pool - # Mapping from request ID to kv block hashes. + self.all_block_sizes = set(g.kv_cache_spec.block_size + for g in kv_cache_config.kv_cache_groups) + # Mapping from request ID to kv block hashes of all block sizes. # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. - # TODO: update comment self.req_to_block_hashes: defaultdict[str, dict[ int, list[BlockHashType]]] = defaultdict(dict) - self.all_block_sizes = set(g.kv_cache_spec.block_size - for g in kv_cache_config.kv_cache_groups) @property def usage(self) -> float: diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 8fc885a5484e..3f382bb7126f 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -748,7 +748,7 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: if not is_hybrid(kv_cache_spec): return - # TODO: better warning message + logger.warning("Hybrid KV cache manager is disabled for this hybrid model," "There can be some waste of KV cache memory.") @@ -769,18 +769,15 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: ) if not is_hybrid(kv_cache_spec): - # TODO: better error message - raise ValueError( - "Hybrid KV cache manager is disabled but we failed to " - "convert the KV cache specs to one type.") + raise ValueError("Hybrid KV cache manager is disabled but failed to " + "convert the KV cache specs to one unified type.") def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], available_memory: int) -> KVCacheConfig: """ - Generates the KV cache configuration for a model - TODO: support hybrid models with more than one type of KV cache. + Generates the KV cache configuration for a model. Args: vllm_config: The global VllmConfig @@ -802,7 +799,8 @@ def get_kv_cache_config(vllm_config: VllmConfig, return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, available_memory) elif is_kv_cache_page_size_uniform(kv_cache_spec): - # KV cache of all layers have the same page size. TODO more notes + # KV cache of all layers have the same page size. TODO notes about + # hybrid allocator return _get_kv_cache_config_uniform_page_size(vllm_config, kv_cache_spec, available_memory) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5061cbc81203..4a0df4f652ed 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -352,7 +352,6 @@ def schedule(self) -> SchedulerOutput: request) else: # P/D: skip checking prefix cache if loaded from remote kvs. - # TODO: add util function to create empty blocks new_computed_blocks = KVCacheBlocks.create_empty( self.kv_cache_manager.group_to_manager) num_native_computed_tokens = 0 diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index b2aeb9c3aca8..c43701cc682c 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -197,11 +197,8 @@ def get_num_common_prefix_blocks(self, request_id: str, raise NotImplementedError @abstractmethod - def find_longest_cache_hit( - self, - block_hashes: list[BlockHashType], - max_length: int, - ) -> list[KVCacheBlockBundle]: + def find_longest_cache_hit(self, block_hashes: list[BlockHashType], + max_length: int) -> list[KVCacheBlockBundle]: """ Get the longest cache hit prefix of the blocks that is not longer than `max_length`. 
If no cache hit is found, return an empty list. @@ -211,7 +208,7 @@ def find_longest_cache_hit( Args: block_hashes: The block hashes of the request. - max_length: The maximum length of the cache hit. + max_length: The maximum length of the cache hit prefix. Returns: A list of cached blocks with skipped blocks replaced by null block. @@ -241,13 +238,13 @@ class FullAttentionManager(SingleTypeKVCacheManager): def find_longest_cache_hit(self, block_hashes: list[BlockHashType], max_length: int) -> list[KVCacheBlockBundle]: - computed_blocks = [] + computed_blocks: list[KVCacheBlockBundle] = [] max_num_blocks = max_length // self.block_size for i in range(max_num_blocks): block_hash = block_hashes[i] - # block_hashes is a chain of block hashes. If a block hash is - # not in the cached_block_hash_to_id, the following block hashes - # are not computed yet for sure. + # block_hashes is a chain of block hashes. If a block hash is not + # in the cached_block_hash_to_id, the following block hashes are + # not computed yet for sure. if cached_block := self.block_pool.get_cached_block( block_hash, self.manager_id): computed_blocks.append(cached_block) @@ -294,11 +291,8 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, self._null_block = KVCacheBlockBundle.from_kv_cache_blocks( tuple([single_null_block] * self.num_kv_cache_groups)) - def find_longest_cache_hit( - self, - block_hashes: list[BlockHashType], - max_length: int, - ) -> list[KVCacheBlockBundle]: + def find_longest_cache_hit(self, block_hashes: list[BlockHashType], + max_length: int) -> list[KVCacheBlockBundle]: # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to # optimize the time complexity from O(len(block_hashes)) to # O(len(block_hashes) / sliding_window_contiguous_blocks + diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f4cb5cab7950..a631a65b1876 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1873,6 +1873,7 @@ def _initialize_kv_cache_buffer( """ Initializes the KV cache buffer with the correct size. The buffer needs to be reshaped to the desired shape before being used by the models. + Args: kv_cache_config: The KV cache config Returns: @@ -1901,6 +1902,7 @@ def _setup_kv_cache_shapes( ) -> dict[str, torch.Tensor]: """ Reshape the KV cache tensors to the desired shape. + Args: kv_cache_config: The KV cache config kv_cache_raw_tensors: The KV cache buffer of each layer, with @@ -1931,7 +1933,15 @@ def _setup_kv_cache_shapes( def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: - # TODO: docstring + """ + Initialize the memory buffer for KV cache. + + Args: + kv_cache_config: The KV cache config + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. 
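The two helpers documented above split initialization into a flat byte-buffer allocation followed by a reshape pass; the same idea in miniature, with the layer name, dtype, and all sizes assumed and the tensor kept on CPU just for the sketch:

    import torch

    block_size, num_kv_heads, head_dim = 16, 8, 128
    page_size_bytes = 2 * block_size * num_kv_heads * head_dim * 2   # K and V, fp16
    num_blocks = 4

    # Phase 1: allocate an untyped byte buffer per layer.
    raw = {
        "model.layers.0.attn":
        torch.zeros(num_blocks * page_size_bytes, dtype=torch.int8)
    }

    # Phase 2: reinterpret and reshape to the backend's KV cache layout.
    shaped = {
        name: buf.view(torch.float16).view(2, num_blocks, block_size,
                                           num_kv_heads, head_dim)
        for name, buf in raw.items()
    }
    assert shaped["model.layers.0.attn"].shape == (2, 4, 16, 8, 128)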
+ """ # Initialize the memory buffer for KV cache kv_cache_raw_tensors = self._initialize_kv_cache_buffer( kv_cache_config) From dcfe6caffc2726952be48176f6d8418ff57747ac Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 16 May 2025 06:01:21 -0700 Subject: [PATCH 06/44] avoid frequent creation of block_bundle Signed-off-by: Chen Zhang --- vllm/v1/core/block_pool.py | 13 ++++++++----- vllm/v1/core/kv_cache_utils.py | 16 ++++++++++++---- vllm/v1/core/single_type_kv_cache_manager.py | 2 +- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index e889577e1885..2c164189d553 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from collections import defaultdict +from collections import defaultdict, deque from collections.abc import Iterable from typing import Callable, Optional @@ -44,6 +44,10 @@ def __init__( self.blocks: list[KVCacheBlock] = [ KVCacheBlock(idx) for idx in range(num_gpu_blocks) ] + # A pool of block bundle instances, to avoid frequent creation of + # KVCacheBlockBundle class. + self._block_bundle_pool: deque[KVCacheBlockBundle] = deque( + KVCacheBlockBundle(blocks=()) for _ in range(num_gpu_blocks)) # Free block queue that constructs and manipulates a doubly linked # list of free blocks (including eviction candidates when caching is # enabled). @@ -244,8 +248,7 @@ def get_new_block_bundles(self, num_block_bundle: int, new_block_bundles: list[KVCacheBlockBundle] = [] for i in range(num_block_bundle): blocks = new_blocks[i * bundle_size:(i + 1) * bundle_size] - # TODO: avoid frequent creation of KVCacheBlockBundle class - block_bundle = KVCacheBlockBundle.from_kv_cache_blocks( + block_bundle = self._block_bundle_pool.pop().init_kv_cache_blocks( tuple(blocks)) block_bundle.incr_ref() new_block_bundles.append(block_bundle) @@ -272,8 +275,8 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: # The block is the master block of its KVCacheBlockBundle. # See comments in cache_full_blocks for details. 
assert cached_block.master_block_id == block.block_id - assert cached_block.ref_cnt == 0 - cached_block.reset_hash() + cached_block.reset() + self._block_bundle_pool.append(cached_block) del cached_blocks[block.block_id] if len(cached_blocks) == 0: del self.cached_block_hash_to_block[manager_id][block_hash] diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 3f382bb7126f..6149eff0dcda 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -865,16 +865,24 @@ def decr_ref(self): def master_block_id(self): return self.blocks[0].block_id - @staticmethod - def from_kv_cache_blocks(blocks: tuple[KVCacheBlock, ...]): - return KVCacheBlockBundle(blocks=blocks, - block_hash=blocks[0].block_hash) + def init_kv_cache_blocks( + self, blocks: tuple[KVCacheBlock, ...]) -> 'KVCacheBlockBundle': + assert self.block_hash is None + assert self.ref_cnt == 0 + self.blocks = blocks + self.block_hash = blocks[0].block_hash + return self def reset_hash(self): for block in self.blocks: block.reset_hash() self.block_hash = None + def reset(self): + assert self.ref_cnt == 0 + self.reset_hash() + self.blocks = () + def block_hash_is_none(self): return self.block_hash is None and all(block.block_hash is None for block in self.blocks) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index c43701cc682c..eea5d574d9b1 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -288,7 +288,7 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, # the last matched block. self.sliding_window_contiguous_blocks += 1 single_null_block = block_pool.null_block - self._null_block = KVCacheBlockBundle.from_kv_cache_blocks( + self._null_block = KVCacheBlockBundle( tuple([single_null_block] * self.num_kv_cache_groups)) def find_longest_cache_hit(self, block_hashes: list[BlockHashType], From b94ed656f79a429c6438a01f4e3b3aedf48fc2bb Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 16 May 2025 06:33:08 -0700 Subject: [PATCH 07/44] update notes Signed-off-by: Chen Zhang --- vllm/v1/core/block_pool.py | 2 +- vllm/v1/core/kv_cache_utils.py | 11 ++++++----- vllm/v1/kv_cache_interface.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 2c164189d553..ae23db7fba35 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -183,7 +183,7 @@ def cache_full_blocks( # Update and added the full block to the cache. blk.init_block_hash(block_hash, manager_id) # We make all blocks in the same KVCacheBlockBundle cached & - # uncached together. This is achieved by: + # evicted together. This is achieved by: # 1. Here, use the master_block_id as the representative of the # KVCacheBlockBundle in the cache. # 2. In `free_blocks`, add the master block to the free list before diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6149eff0dcda..a5923dafd204 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -123,6 +123,7 @@ class KVCacheBlock: prev_free_block: Optional["KVCacheBlock"] = None next_free_block: Optional["KVCacheBlock"] = None + # The single_type_kv_cache_manager this block belongs to. 
manager_id: int = -1 @property @@ -799,8 +800,9 @@ def get_kv_cache_config(vllm_config: VllmConfig, return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, available_memory) elif is_kv_cache_page_size_uniform(kv_cache_spec): - # KV cache of all layers have the same page size. TODO notes about - # hybrid allocator + # KV cache of all layers have the same page size. Split the layers into + # groups with the same number of layers, and thus same total page size. + # See KVCacheConfig.kv_cache_groups for more details. return _get_kv_cache_config_uniform_page_size(vllm_config, kv_cache_spec, available_memory) @@ -845,9 +847,8 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): # KVCacheBlocks for the same block of all kv cache groups with the same kv cache -# spec (and belongs to the same manager) -# TODO: more notes -# TODO: optimize the creation of KVCacheBlockBundle +# spec (and belongs to the same manager). All blocks in the bundle have the same +# block hash, and are allocated & freed & cached & evicted together. @dataclass class KVCacheBlockBundle: blocks: tuple[KVCacheBlock, ...] diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 821c08a78dbe..f058732706a2 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -216,7 +216,7 @@ class KVCacheConfig: 1. A model only uses full attention. The pattern is (num_hidden_layers * full), so there is only one group and the block table is shared by all layers. - 2. (WIP) A model with 10 full attention layers and 20 sliding window + 2. A model with 10 full attention layers and 20 sliding window attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so there are 3 groups, each of which represents 10 layers in the model. """ From deafbda98752e39ca547c34174a52b63dc564ea1 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 16 May 2025 06:57:50 -0700 Subject: [PATCH 08/44] fix config Signed-off-by: Chen Zhang --- vllm/engine/arg_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0f18af5d8559..deaba2829f74 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -368,7 +368,7 @@ class EngineArgs: bool] = SchedulerConfig.enable_chunked_prefill disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input - disable_hybrid_allocator: bool = ( + disable_hybrid_kv_cache_manager: bool = ( SchedulerConfig.disable_hybrid_kv_cache_manager) guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend @@ -817,6 +817,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **scheduler_kwargs["disable_chunked_mm_input"]) scheduler_group.add_argument("--scheduler-cls", **scheduler_kwargs["scheduler_cls"]) + scheduler_group.add_argument( + "--disable-hybrid-kv-cache-manager", + **scheduler_kwargs["disable_hybrid_kv_cache_manager"]) # vLLM arguments vllm_kwargs = get_kwargs(VllmConfig) @@ -1126,6 +1129,8 @@ def create_engine_config( max_num_partial_prefills=self.max_num_partial_prefills, max_long_partial_prefills=self.max_long_partial_prefills, long_prefill_token_threshold=self.long_prefill_token_threshold, + disable_hybrid_kv_cache_manager=self. 
+ disable_hybrid_kv_cache_manager, ) lora_config = LoRAConfig( From 18798da7d45876933a846f7107d30395bf5bd5c1 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 16 May 2025 10:16:44 -0700 Subject: [PATCH 09/44] fix tests in v1/core Signed-off-by: Chen Zhang --- tests/v1/core/test_kv_cache_utils.py | 41 ++- tests/v1/core/test_prefix_caching.py | 273 ++++++++++--------- tests/v1/core/test_scheduler.py | 23 +- tests/v1/core/test_specialized_manager.py | 40 ++- vllm/v1/core/block_pool.py | 11 +- vllm/v1/core/kv_cache_coordinator.py | 2 - vllm/v1/core/kv_cache_manager.py | 18 +- vllm/v1/core/kv_cache_utils.py | 3 +- vllm/v1/core/single_type_kv_cache_manager.py | 8 +- 9 files changed, 225 insertions(+), 194 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index e572100fe7a1..6ee4f2d46b19 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -19,7 +19,7 @@ hash_request_tokens, unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheTensor, + KVCacheGroupSpec, KVCacheNewTensor, SlidingWindowSpec) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -75,15 +75,8 @@ def test_kv_cache_block(): # Test KVCacheBlock initialization block = KVCacheBlock(block_id=0) assert block.block_id == 0 - assert block.ref_cnt == 0 assert block.block_hash is None - # Test reference count manipulation - block.incr_ref() - assert block.ref_cnt == 1 - block.decr_ref() - assert block.ref_cnt == 0 - # Test block hash setting and resetting block_hash = BlockHashType(hash_value=123, token_ids=(1, 2, 3)) block.block_hash = block_hash @@ -387,8 +380,8 @@ def test_unify_kv_cache_configs(): KVCacheConfig( num_blocks=10, tensors={ - "layer1": KVCacheTensor(100), - "layer2": KVCacheTensor(100), + "layer1": KVCacheNewTensor(100), + "layer2": KVCacheNewTensor(100), }, kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), @@ -399,8 +392,8 @@ def test_unify_kv_cache_configs(): KVCacheConfig( num_blocks=20, tensors={ - "layer1": KVCacheTensor(100), - "layer2": KVCacheTensor(100), + "layer1": KVCacheNewTensor(100), + "layer2": KVCacheNewTensor(100), }, kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), @@ -417,8 +410,8 @@ def test_unify_kv_cache_configs(): KVCacheConfig( num_blocks=10, tensors={ - "layer1": KVCacheTensor(100), - "layer2": KVCacheTensor(100), + "layer1": KVCacheNewTensor(100), + "layer2": KVCacheNewTensor(100), }, kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), @@ -429,8 +422,8 @@ def test_unify_kv_cache_configs(): KVCacheConfig( num_blocks=20, tensors={ - "layer1": KVCacheTensor(100), - "layer2": KVCacheTensor(100), + "layer1": KVCacheNewTensor(100), + "layer2": KVCacheNewTensor(100), }, kv_cache_groups=[ KVCacheGroupSpec(["layer2"], @@ -448,8 +441,8 @@ def test_unify_kv_cache_configs(): KVCacheConfig( num_blocks=10, tensors={ - "layer1": KVCacheTensor(100), - "layer2": KVCacheTensor(100), + "layer1": KVCacheNewTensor(100), + "layer2": KVCacheNewTensor(100), }, kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), @@ -460,8 +453,8 @@ def test_unify_kv_cache_configs(): KVCacheConfig( num_blocks=20, tensors={ - "layer1": KVCacheTensor(100), - "layer2": KVCacheTensor(100), + "layer1": KVCacheNewTensor(100), + "layer2": KVCacheNewTensor(100), }, kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), @@ -584,7 +577,7 @@ def 
test_allocate_with_lookahead(): config = KVCacheConfig( num_blocks=10, tensors={ - "layer1": KVCacheTensor(100), + "layer1": KVCacheNewTensor(100), }, kv_cache_groups=[ KVCacheGroupSpec(["layer1"], @@ -607,7 +600,7 @@ def test_allocate_with_lookahead(): num_new_tokens=3, num_lookahead_tokens=2, # Total required: 3+2=5 tokens ) - assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks + assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks # Test case 2: With precomputed blocks kv_cache_manager = KVCacheManager(kv_cache_config=config, @@ -618,7 +611,7 @@ def test_allocate_with_lookahead(): num_new_tokens=3, num_lookahead_tokens=2, ) - assert len(blocks.blocks) == 2 + assert len(blocks.get_block_ids()[0]) == 2 # Test case 3: With precomputed blocks # required_blocks = ceil((3 + 4) / 4) = 2 @@ -629,4 +622,4 @@ def test_allocate_with_lookahead(): num_new_tokens=3, num_lookahead_tokens=4, ) - assert len(blocks.blocks) == 2 + assert len(blocks.get_block_ids()[0]) == 2 diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 3da27786b1f2..e7d20fb989e6 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -13,7 +13,7 @@ from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, - hash_block_tokens) + KVCacheBlockBundle, hash_block_tokens) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, SlidingWindowSpec) @@ -59,8 +59,9 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: @pytest.mark.parametrize("hash_algo", ["sha256", "hash"]) def test_prefill(hash_algo): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, caching_hash_algo=hash_algo, @@ -78,11 +79,11 @@ def test_prefill(hash_algo): all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 3 - assert not computed_blocks.blocks + assert len(manager.req_to_block_hashes[req0.request_id][block_size]) == 3 + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[1, 2, 3, 4]] @@ -93,28 +94,29 @@ def test_prefill(hash_algo): block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) assert manager.block_pool.blocks[block_id].block_hash == block_hash - assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value # Check partial block metadata for block_id in (4, ): assert manager.block_pool.blocks[block_id].block_hash is None - assert manager.block_pool.blocks[block_id].ref_cnt == 1 + + for block in blocks.blocks[0]: + assert block.ref_cnt == 1 # Cache hit in the common prefix when the original block is still in use. 
# Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(manager.req_to_block_hashes[req1.request_id][block_size]) == 3 assert computed_blocks.get_block_ids() == [[1, 2, 3]] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[5]] - for block in computed_blocks.blocks: + for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 # At this point, we should have 5 free blocks left. @@ -140,22 +142,18 @@ def test_prefill(hash_algo): unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(manager.req_to_block_hashes[req2.request_id]) == 3 + assert len(manager.req_to_block_hashes[req2.request_id][block_size]) == 3 assert computed_blocks.get_block_ids() == [[1, 2, 3]] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[6]] # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. assert manager.block_pool.free_block_queue.num_free_blocks == 6 - assert all([ - b.ref_cnt == 0 - for b in manager.block_pool.free_block_queue.get_all_free_blocks() - ]) assert len([ b for b in manager.block_pool.free_block_queue.get_all_free_blocks() ]) == 6 @@ -165,10 +163,10 @@ def test_prefill(hash_algo): # Cache miss and eviction. req3 = make_request("3", [99] * (16 * 10)) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 16 * 10, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) # This block ID order also checks the eviction order. assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]] @@ -184,8 +182,9 @@ def test_prefill_plp(): 2. Schedule non-plp request and validate blocks 3. 
Schedule plp request; no hit should occur; validate blocks ''' + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -202,14 +201,14 @@ def test_prefill_plp(): all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 0 - assert not computed_blocks.blocks + assert len(manager.req_to_block_hashes[req0.request_id][block_size]) == 0 + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[1, 2, 3, 4]] - req0_block_hashes = [b.block_hash for b in blocks.blocks] + req0_block_hashes = [b.block_hash for b in blocks.blocks[0]] # Check full block metadata parent_block_hash = None @@ -218,13 +217,14 @@ def test_prefill_plp(): block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) assert manager.block_pool.blocks[block_id].block_hash == block_hash - assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value # Check partial block metadata for block_id in (4, ): assert manager.block_pool.blocks[block_id].block_hash is None - assert manager.block_pool.blocks[block_id].ref_cnt == 1 + + for block in blocks.blocks[0]: + assert block.ref_cnt == 1 # Request #1 is a non-prompt-logprobs request: # Cache hit in the common prefix when the original block is still in use. @@ -232,15 +232,15 @@ def test_prefill_plp(): unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(manager.req_to_block_hashes[req1.request_id][block_size]) == 3 assert computed_blocks.get_block_ids() == [[1, 2, 3]] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[5]] - for block in computed_blocks.blocks: + for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 # At this point, we should have 5 free blocks left. @@ -268,21 +268,21 @@ def test_prefill_plp(): common_token_ids + unique_token_ids, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(manager.req_to_block_hashes[req2.request_id]) == 0 - assert not computed_blocks.blocks + assert len(manager.req_to_block_hashes[req2.request_id][block_size]) == 0 + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 55, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 - assert [b.block_hash for b in blocks.blocks] == req0_block_hashes + assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes assert block_ids != [[1, 2, 3, 4]] # Request #2 block hashes are valid since request #0 hashes are. # Check block reference counts. 
- for block_id in block_ids[0]: - assert manager.block_pool.blocks[block_id].ref_cnt == 1 + for block in blocks.blocks[0]: + assert block.ref_cnt == 1 manager.free(req2) @@ -302,10 +302,10 @@ def test_decode(): unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[1, 2, 3, 4]] @@ -314,10 +314,10 @@ def test_decode(): for _ in range(4): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 4, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert new_blocks is not None and len(new_blocks.blocks) == 0 - assert manager.single_type_manager.req_to_blocks[ + assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 + assert manager.coordinator.single_type_managers[0].req_to_blocks[ req0.request_id][-1].block_hash is None # Append slots with allocating a new block. @@ -327,12 +327,12 @@ def test_decode(): for _ in range(9 + 10): req0.append_output_token_ids(7) new_blocks = manager.allocate_slots(req0, 19, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert new_blocks is not None and len(new_blocks.blocks) == 1 - assert manager.single_type_manager.req_to_blocks[ + assert new_blocks is not None and len(new_blocks.blocks[0]) == 1 + assert manager.coordinator.single_type_managers[0].req_to_blocks[ req0.request_id][-2].block_hash is not None - assert manager.single_type_manager.req_to_blocks[ + assert manager.coordinator.single_type_managers[0].req_to_blocks[ req0.request_id][-1].block_hash is None @@ -346,23 +346,23 @@ def test_evict(): last_token_id = 5 * 16 + 7 req0 = make_request("0", list(range(last_token_id))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 5 * 16 + 7, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 6 # 5 full + 1 partial + assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial # 3 blocks. 
req1 = make_request("1", list(range(last_token_id, last_token_id + 3 * 16))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 3 * 16, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 3 # 3 full blocks + assert len(blocks.blocks[0]) == 3 # 3 full blocks last_token_id += 3 * 16 # 10 - (6 + 3) == 1 @@ -382,7 +382,7 @@ def test_evict(): assert computed_blocks.get_block_ids() == [[1, 2]] assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[10]] assert manager.block_pool.free_block_queue.num_free_blocks == 7 @@ -404,12 +404,12 @@ def test_hash_block_correct_reuse(): num_tokens = block_size * 1 req = make_request("0", list(range(num_tokens))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 1 + assert len(blocks.blocks[0]) == 1 # Deallocate the block. manager.free(req) @@ -418,15 +418,15 @@ def test_hash_block_correct_reuse(): # block is cleared. req = make_request("1", list(range(num_tokens - 1))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req, num_tokens - 1, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 1 + assert len(blocks.blocks[0]) == 1 - assert manager.block_pool.blocks[ - blocks.blocks[0].block_id].block_hash is None + assert manager.block_pool.blocks[blocks.blocks[0] + [0].master_block_id].block_hash is None def test_computed_blocks_not_evicted(): @@ -445,24 +445,24 @@ def test_computed_blocks_not_evicted(): num_tokens = block_size * 1 req0 = make_request("0", list(range(num_tokens))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, num_tokens, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 1 - assert blocks.blocks[0].block_id == 1 + assert len(blocks.blocks[0]) == 1 + assert blocks.blocks[0][0].master_block_id == 1 # Allocate another block. req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, num_tokens, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 1 - assert blocks.blocks[0].block_id == 2 + assert len(blocks.blocks[0]) == 1 + assert blocks.blocks[0][0].master_block_id == 2 # Free the blocks. manager.free(req0) @@ -472,15 +472,15 @@ def test_computed_blocks_not_evicted(): # cached block rather than the first one. 
req2 = make_request("2", list(range(num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(computed_blocks.blocks) == 1 - assert computed_blocks.blocks[0].block_id == 1 + assert len(computed_blocks.blocks[0]) == 1 + assert computed_blocks.blocks[0][0].master_block_id == 1 assert num_computed_tokens == block_size blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 1 - assert blocks.blocks[0].block_id == 2 + assert len(blocks.blocks[0]) == 1 + assert blocks.blocks[0][0].master_block_id == 2 def test_basic_prefix_caching_disabled(): @@ -497,12 +497,12 @@ def test_basic_prefix_caching_disabled(): req1 = make_request("1", list(range(10))) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, 10, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 3 + assert len(blocks.blocks[0]) == 3 # Free the blocks. manager.free(req1) @@ -510,20 +510,20 @@ def test_basic_prefix_caching_disabled(): # No caching. req2 = make_request("2", list(range(16))) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 16, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert len(blocks.blocks) == 4 + assert len(blocks.blocks[0]) == 4 # New requests should not have any blocks. req3 = make_request("3", list(range(4))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 4, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert not blocks @@ -538,6 +538,7 @@ def test_cache_blocks(hash_fn): block_pool = BlockPool( num_gpu_blocks=5, enable_caching=True, + num_single_type_managers=1, ) # Req: # Block 0: [0, 1, 2, 3] @@ -547,7 +548,10 @@ def test_cache_blocks(hash_fn): req = make_request("0", list(range(14))) # Test that blocks are cached correctly for 2 full blocks from the start. - blocks = [KVCacheBlock(block_id=i) for i in range(2)] + blocks = [ + KVCacheBlockBundle(blocks=(KVCacheBlock(block_id=i), )) + for i in range(2) + ] block_hashes: list[BlockHashType] = [] block_pool.cache_full_blocks( @@ -558,13 +562,14 @@ def test_cache_blocks(hash_fn): num_full_blocks=2, block_size=block_size, hash_fn=hash_fn, + manager_id=0, ) - assert len(block_pool.cached_block_hash_to_block) == 2 + assert len(block_pool.cached_block_hash_to_block[0]) == 2 assert all([block.block_hash is not None for block in blocks]) # Test that blocks that don't start from the beginning are cached correctly. 
- blocks += [KVCacheBlock(block_id=2)] + blocks += [KVCacheBlockBundle(blocks=(KVCacheBlock(block_id=2), ))] block_pool.cache_full_blocks( request=req, blocks=blocks, @@ -573,8 +578,9 @@ def test_cache_blocks(hash_fn): num_full_blocks=3, block_size=block_size, hash_fn=hash_fn, + manager_id=0, ) - assert len(block_pool.cached_block_hash_to_block) == 3 + assert len(block_pool.cached_block_hash_to_block[0]) == 3 assert blocks[0].block_hash is not None @@ -582,6 +588,7 @@ def test_mm_prefix_caching(): """ This tests that the multi-modal prefix caching is correct. """ + block_size = 16 manager = KVCacheManager( make_kv_cache_config(16, 11), max_model_len=8192, @@ -614,16 +621,16 @@ def test_mm_prefix_caching(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req0.request_id] + block_hashes = manager.req_to_block_hashes[req0.request_id][block_size] assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("aaa", ) assert block_hashes[1].extra_keys == ("aaa", "bbb") assert block_hashes[2].extra_keys == ("bbb", ) blocks = manager.allocate_slots(req0, 59, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[1, 2, 3, 4]] req0.num_computed_tokens = 59 @@ -632,9 +639,9 @@ def test_mm_prefix_caching(): for _ in range(5): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 5, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert new_blocks is not None and len(new_blocks.blocks) == 0 + assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 # The just completed block should have hashes with extra keys. assert len(block_hashes) == 4 @@ -652,7 +659,7 @@ def test_mm_prefix_caching(): mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(computed_blocks.blocks) == 3 + assert len(computed_blocks.blocks[0]) == 3 assert num_computed_tokens == 3 * 16 @@ -675,16 +682,16 @@ def test_cache_key_salting(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req0.request_id] + block_hashes = manager.req_to_block_hashes[req0.request_id][block_size] assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("salt1", ) assert block_hashes[1].extra_keys is None assert block_hashes[2].extra_keys is None blocks = manager.allocate_slots(req0, 59, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[1, 2, 3, 4]] req0.num_computed_tokens = 59 @@ -693,9 +700,9 @@ def test_cache_key_salting(): for _ in range(5): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 5, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert new_blocks is not None and len(new_blocks.blocks) == 0 + assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 # Now one more block that should not have extra keys. 
assert len(block_hashes) == 4 @@ -706,16 +713,16 @@ def test_cache_key_salting(): req1 = make_request("1", token_ids, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should match only a prefix of 3 blocks. - assert len(computed_blocks.blocks) == 3 + assert len(computed_blocks.blocks[0]) == 3 assert num_computed_tokens == 3 * block_size # Test cache miss with same content but different salt. token_ids = common_token_ids + [4] * 11 req2 = make_request("2", token_ids, cache_salt="salt2") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(computed_blocks.blocks) == 0 + assert len(computed_blocks.blocks[0]) == 0 assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req2.request_id] + block_hashes = manager.req_to_block_hashes[req2.request_id][block_size] assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("salt2", ) @@ -738,20 +745,24 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): common_token_ids = [i for i in range(3) for _ in range(16)] req0 = make_request("0", common_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req0, 48, - len(computed_blocks.blocks) * 16, computed_blocks) - block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id] + len(computed_blocks.blocks[0]) * 16, + computed_blocks) + block_part0 = manager.coordinator.single_type_managers[0].req_to_blocks[ + req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | req1 = make_request("1", common_token_ids * 2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert computed_blocks.blocks == block_part0 + assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 manager.allocate_slots(req1, 48, - len(computed_blocks.blocks) * 16, computed_blocks) - block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id] + len(computed_blocks.blocks[0]) * 16, + computed_blocks) + block_part1 = manager.coordinator.single_type_managers[0].req_to_blocks[ + req1.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| ... | manager.free(req1) @@ -762,10 +773,11 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Req1-5(F)| Req2-0 | Req2-1 | ... | req2 = make_request("2", [7] * block_size * 2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req2, block_size * 2, - len(computed_blocks.blocks) * 16, computed_blocks) + len(computed_blocks.blocks[0]) * 16, + computed_blocks) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). @@ -773,11 +785,11 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): assert manager.block_pool.free_block_queue.num_free_blocks == 5 req3 = make_request("3", common_token_ids * 3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) - assert computed_blocks.blocks == block_part1 + assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. 
assert manager.allocate_slots(req3, 48, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) is None # Block 0-2 are used by Req 1. assert {block.ref_cnt for block in block_part1[:3]} == {1} @@ -786,6 +798,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): def test_reset_prefix_cache(): + block_size = 16 manager = KVCacheManager( make_kv_cache_config(16, 11), max_model_len=8192, @@ -803,10 +816,10 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req1 = make_request("1", all_token_ids) computed_blocks, _ = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert len(computed_blocks.blocks) == 3 + assert len(manager.req_to_block_hashes[req1.request_id][block_size]) == 3 + assert len(computed_blocks.blocks[0]) == 3 blocks = manager.allocate_slots(req1, 7, - len(computed_blocks.blocks) * 16, + len(computed_blocks.blocks[0]) * 16, computed_blocks) assert blocks.get_block_ids() == [[5]] @@ -819,7 +832,7 @@ def test_reset_prefix_cache(): manager.free(req1) assert manager.reset_prefix_cache() - assert not manager.block_pool.cached_block_hash_to_block + assert not manager.block_pool.cached_block_hash_to_block[0] assert all([blk.block_hash is None for blk in manager.block_pool.blocks]) @@ -836,10 +849,11 @@ def test_prefix_cache_stats_disabled(): # Call all functions that check whether log_stats is disabled. req = make_request("0", list(range(16))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert not computed_blocks.blocks + assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req, 16, - len(computed_blocks.blocks) * 16, computed_blocks) + len(computed_blocks.blocks[0]) * 16, + computed_blocks) manager.reset_prefix_cache() # Ensure prefix_cache_stats remains None @@ -869,7 +883,7 @@ def test_kv_cache_events(blocks_to_cache: int): block = events[-1] assert (len(block.block_hashes) == blocks_to_cache == len( - manager.block_pool.cached_block_hash_to_block)) + manager.block_pool.cached_block_hash_to_block[0])) assert len(block.token_ids) == block.block_size * len(block.block_hashes) assert len(manager.block_pool.kv_event_queue) == 0 @@ -888,7 +902,7 @@ def test_kv_cache_events(blocks_to_cache: int): assert len(events) == blocks_to_cache + 1 assert (isinstance(events[-2], BlockRemoved)) assert (len(events[-1].block_hashes) == blocks_to_cache == len( - manager.block_pool.cached_block_hash_to_block)) + manager.block_pool.cached_block_hash_to_block[0])) # All Blocks Cleared # Should see a single all blocks cleared event @@ -897,7 +911,7 @@ def test_kv_cache_events(blocks_to_cache: int): events = manager.take_events() assert isinstance(events[-1], AllBlocksCleared) - assert len(manager.block_pool.cached_block_hash_to_block) == 0 + assert len(manager.block_pool.cached_block_hash_to_block[0]) == 0 def test_eagle_enabled_removes_last_block(): @@ -918,7 +932,8 @@ def test_eagle_enabled_removes_last_block(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks) * 16, computed_blocks) + len(computed_blocks.blocks[0]) * 16, + computed_blocks) manager.free(req) # New request with same tokens + Eagle enabled @@ -928,7 +943,7 @@ def test_eagle_enabled_removes_last_block(): # Should retain 1 block: # 1. Original 3 blocks → pop last hash → 2 matched blocks # 2. 
drop last matched block → 1 remaining block - assert len(computed_blocks.blocks) == 1 + assert len(computed_blocks.blocks[0]) == 1 assert num_tokens == 1 * block_size # 16 tokens @@ -948,14 +963,15 @@ def test_eagle_with_partial_blocks(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks) * 16, computed_blocks) + len(computed_blocks.blocks[0]) * 16, + computed_blocks) manager.free(req) # New request with Eagle enabled req_eagle = make_request("partial_eagle", token_ids) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining - assert len(computed_blocks.blocks) == 1 + assert len(computed_blocks.blocks[0]) == 1 assert num_tokens == 1 * block_size @@ -988,9 +1004,11 @@ def test_eagle_with_sliding_window(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks) * 16, computed_blocks) + len(computed_blocks.blocks[0]) * 16, + computed_blocks) # record the block hash of the first block in the request for later use - block_hash_first_block = manager.req_to_block_hashes[req.request_id][0] + block_hash_first_block = manager.req_to_block_hashes[ + req.request_id][block_size][0] assert block_hash_first_block is not None manager.free(req) @@ -998,13 +1016,14 @@ def test_eagle_with_sliding_window(): req_eagle = make_request("partial_eagle", token_ids) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining - assert len(computed_blocks.blocks) == 1 + assert len(computed_blocks.blocks[0]) == 1 assert num_tokens == 1 * block_size # Evict the first block in the request - assert manager.block_pool.get_cached_block( - block_hash_first_block) is not None - manager.block_pool.cached_block_hash_to_block.pop(block_hash_first_block) + assert manager.block_pool.get_cached_block(block_hash_first_block, + manager_id=0) is not None + manager.block_pool.cached_block_hash_to_block[0].pop( + block_hash_first_block) # New request req_after_evict = make_request("partial_eagle_after_evict", token_ids) @@ -1012,5 +1031,5 @@ def test_eagle_with_sliding_window(): # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. But after dropping the last matched block due to eagle, # there will be no matched prefix. - assert len(computed_blocks.blocks) == 0 + assert len(computed_blocks.blocks[0]) == 0 assert num_tokens == 0 diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index f40d477a0036..9d8ba51ee84b 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -812,13 +812,13 @@ def _assert_right_kv_cache_manager( # Make sure the request stats are right. EXPECTED_TOTAL_BLOCKS = num_tokens // block_size for req_id in req_ids: - blocks = (scheduler.kv_cache_manager.single_type_manager. - req_to_blocks[req_id]) + blocks = (scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks[req_id]) hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id] - assert (scheduler.kv_cache_manager.single_type_manager. + assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0]. 
num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS) assert len(blocks) == EXPECTED_TOTAL_BLOCKS - assert len(hashes) == EXPECTED_TOTAL_BLOCKS + assert len(hashes[block_size]) == EXPECTED_TOTAL_BLOCKS # Make sure we actually touched all the blocks. BLOCKS_PER_REQ = num_tokens / block_size @@ -1196,21 +1196,22 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.encoder_cache_manager.cached) == 0 # KVCache Manager. - assert len( - scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. + req_to_blocks) == 0 assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 - assert len( - scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0 + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. + num_cached_block) == 0 num_free_blocks = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) assert num_free_blocks == ( scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + # TODO(Chen): find a way to test no leak on ref_cnt. # NOTE(rob): just the ref count on blocks will be 0. The hash # value, etc will remain since we lazily evict for prefix cache. - for block in scheduler.kv_cache_manager.block_pool.blocks: - assert block.ref_cnt == 0 - # assert block._block_hash is None + # for block in scheduler.kv_cache_manager.block_pool.blocks: + # assert block.ref_cnt == 0 + # assert block._block_hash is None # assert ( # len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block # ) == 0) diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index 101a2379be37..49ba2e15454c 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -3,7 +3,8 @@ import torch from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, + KVCacheBlockBundle) from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager from vllm.v1.kv_cache_interface import SlidingWindowSpec @@ -13,7 +14,8 @@ def get_sliding_window_manager(sliding_window_spec, block_pool): block_pool, use_eagle=False, num_kv_cache_groups=1, - caching_hash_fn=lambda x: x) + caching_hash_fn=lambda x: x, + manager_id=0) def test_sliding_window_possible_cached_prefix(): @@ -27,7 +29,9 @@ def test_sliding_window_possible_cached_prefix(): use_mla=False, ) - block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + block_pool = BlockPool(num_gpu_blocks=100, + enable_caching=True, + num_single_type_managers=1) manager = get_sliding_window_manager(sliding_window_spec, block_pool) def run_one_case(block_is_cached, expect_length): @@ -35,14 +39,14 @@ def run_one_case(block_is_cached, expect_length): BlockHashType(i, ()) for i in range(len(block_is_cached)) ] - block_pool.cached_block_hash_to_block.clear() + block_pool.cached_block_hash_to_block[0].clear() # Mock the block pool with the cached blocks for i, (block_hash, is_cached) in enumerate(zip(block_hash_list, block_is_cached)): if is_cached: - block_pool.cached_block_hash_to_block[block_hash] = { - i: block_pool.blocks[i + 10] + block_pool.cached_block_hash_to_block[0][block_hash] = { + i: KVCacheBlockBundle(blocks=(block_pool.blocks[i + 10], )) } computed_blocks = manager.find_longest_cache_hit( @@ -50,13 +54,13 @@ def run_one_case(block_is_cached, expect_length): len(block_hash_list) * block_size) 
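        # Illustrative note (not part of the original test): for the sliding
        # window manager, any matched block that falls entirely outside the
        # sliding window is reported as the manager's shared null-block
        # placeholder; only the trailing blocks still inside the window (the
        # last two checked by the asserts below) must be real cached blocks,
        # whose master block ids come from the mock above (i + 10).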
assert len(computed_blocks) == expect_length - assert all(block == block_pool.null_block + assert all(block == manager.null_block for block in computed_blocks[:expect_length - 2]) for i in range(2): if i < expect_length: block_index = expect_length - i - 1 assert computed_blocks[ - block_index].block_id == block_index + 10 + block_index].master_block_id == block_index + 10 run_one_case([False] * 10, 0) run_one_case([True], 1) @@ -88,29 +92,33 @@ def test_sliding_window_remove_skipped_blocks(): use_mla=False, ) - block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) + block_pool = BlockPool(num_gpu_blocks=2000, + enable_caching=True, + num_single_type_managers=1) manager = get_sliding_window_manager(sliding_window_spec, block_pool) null_block_id = block_pool.null_block.block_id - def id_to_block_table(ids): + def id_to_block_table(ids) -> list[KVCacheBlockBundle]: return [ - KVCacheBlock(id_) - if id_ != null_block_id else block_pool.null_block for id_ in ids + KVCacheBlockBundle(blocks=(KVCacheBlock(id_), )) + if id_ != null_block_id else manager.null_block for id_ in ids ] - def assert_block_id(block_table, ids): + def assert_block_id(block_table: list[KVCacheBlockBundle], ids: list[int]): for block, id_ in zip(block_table, ids): if id_ == null_block_id: - assert block == block_pool.null_block + assert block == manager.null_block else: - assert block.block_id == id_ + assert block.master_block_id == id_ original_block_ids = [ 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010 ] block_table = id_to_block_table(original_block_ids) + for block in block_table: + block.incr_ref() manager.req_to_blocks["test"] = block_table manager.remove_skipped_blocks("test", 0) @@ -143,3 +151,5 @@ def assert_block_id(block_table, ids): # of removed blocks should be [1003, 1002]. manager.remove_skipped_blocks("test", 11) assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:]) + + manager.free("test") diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index ae23db7fba35..967e36c532c4 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -45,7 +45,9 @@ def __init__( KVCacheBlock(idx) for idx in range(num_gpu_blocks) ] # A pool of block bundle instances, to avoid frequent creation of - # KVCacheBlockBundle class. + # KVCacheBlockBundle class. As each KVCacheBlockBundle contains a + # distinct set of blocks, the number of KVCacheBlockBundle object won't + # exceed num_gpu_blocks. self._block_bundle_pool: deque[KVCacheBlockBundle] = deque( KVCacheBlockBundle(blocks=()) for _ in range(num_gpu_blocks)) # Free block queue that constructs and manipulates a doubly linked @@ -275,8 +277,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: # The block is the master block of its KVCacheBlockBundle. # See comments in cache_full_blocks for details. assert cached_block.master_block_id == block.block_id - cached_block.reset() - self._block_bundle_pool.append(cached_block) + self._block_bundle_pool.append(cached_block.reset()) del cached_blocks[block.block_id] if len(cached_blocks) == 0: del self.cached_block_hash_to_block[manager_id][block_hash] @@ -326,6 +327,10 @@ def free_blocks(self, if block != self.null_block: self.free_block_queue.append(block) + if (block_bundle.block_hash is None and + block_bundle.master_block_id != self.null_block.block_id): + self._block_bundle_pool.append(block_bundle.reset()) + def reset_prefix_cache(self) -> bool: """Reset prefix cache. 
This function may be used in RLHF flows to invalid prefix caching after the weights are updated, diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index a8de932fabcd..1b6768f346ab 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -184,7 +184,6 @@ def get_block_ids(self, request_id: str) -> list[list[KVCacheBlockBundle]]: def find_longest_cache_hit( self, - request: Request, block_hashes_dict: dict[int, list[BlockHashType]], max_cache_hit_length: int, ) -> tuple[list[list[KVCacheBlockBundle]], int]: @@ -192,7 +191,6 @@ def find_longest_cache_hit( Find the longest cache hit for the request. Args: - request: The request. block_hashes_dict: The block hashes of the request. max_cache_hit_length: The maximum length of the cache hit. diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index f94f19d229f0..81ace317f5a8 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -2,7 +2,7 @@ from collections import defaultdict from dataclasses import dataclass -from typing import Optional +from typing import Callable, Optional from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger @@ -99,8 +99,13 @@ def __init__( # Mapping from request ID to kv block hashes of all block sizes. # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. + empty_block_hash_fn: Callable[[], dict[int, list[BlockHashType]]] = ( + lambda: { + block_size: [] + for block_size in self.all_block_sizes + }) self.req_to_block_hashes: defaultdict[str, dict[ - int, list[BlockHashType]]] = defaultdict(dict) + int, list[BlockHashType]]] = defaultdict(empty_block_hash_fn) @property def usage(self) -> float: @@ -144,8 +149,8 @@ def get_computed_blocks(self, # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. - block_hashes = self.req_to_block_hashes[request.request_id] - if not block_hashes: + block_hashes = self.req_to_block_hashes.get(request.request_id, None) + if block_hashes is None: block_hashes = { block_size: hash_request_tokens(self.caching_hash_fn, block_size, request) @@ -164,9 +169,8 @@ def get_computed_blocks(self, # num_computed_tokens to be block-size aligned. Removing this limitation # could slightly improve performance in the future. 
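        # Illustrative arithmetic (assuming block_size=16): for a 48-token
        # prompt, capping the hit length at num_tokens - 1 = 47 means at most
        # 2 full blocks (32 tokens) can be reported as a prefix-cache hit, so
        # the last block of tokens is always left to be computed even when the
        # whole prompt is cached.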
max_cache_hit_length = request.num_tokens - 1 - computed_blocks, num_new_computed_tokens = ( - self.coordinator.find_longest_cache_hit(request, block_hashes, + self.coordinator.find_longest_cache_hit(block_hashes, max_cache_hit_length)) if self.log_stats: @@ -262,7 +266,7 @@ def allocate_slots( if self.enable_caching: self.block_pool.touch(new_computed_block_list) else: - assert not new_computed_block_list, ( + assert all(not blocks for blocks in new_computed_block_list), ( "Computed blocks should be empty when " "prefix caching is disabled") diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index a5923dafd204..774e33bb1e83 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -879,10 +879,11 @@ def reset_hash(self): block.reset_hash() self.block_hash = None - def reset(self): + def reset(self) -> 'KVCacheBlockBundle': assert self.ref_cnt == 0 self.reset_hash() self.blocks = () + return self def block_hash_is_none(self): return self.block_hash is None and all(block.block_hash is None diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index eea5d574d9b1..c2f44e6fba83 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -288,7 +288,7 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, # the last matched block. self.sliding_window_contiguous_blocks += 1 single_null_block = block_pool.null_block - self._null_block = KVCacheBlockBundle( + self.null_block = KVCacheBlockBundle( tuple([single_null_block] * self.num_kv_cache_groups)) def find_longest_cache_hit(self, block_hashes: list[BlockHashType], @@ -299,7 +299,7 @@ def find_longest_cache_hit(self, block_hashes: list[BlockHashType], # sliding_window_contiguous_blocks), # which is good for low cache hit rate scenarios. max_num_blocks = max_length // self.block_size - computed_blocks = [self._null_block] * max_num_blocks + computed_blocks = [self.null_block] * max_num_blocks num_contiguous_blocks = 0 match_found = False # Search from right to left and early stop when a match is found. @@ -335,13 +335,13 @@ def remove_skipped_blocks(self, request_id: str, blocks = self.req_to_blocks[request_id] removed_blocks: list[KVCacheBlockBundle] = [] for i in range(last_useful_block - 1, -1, -1): - if blocks[i] == self._null_block: + if blocks[i] == self.null_block: # If the block is already a null block, the blocks before it # should also have been set to null blocks by the previous calls # to this function. 
break removed_blocks.append(blocks[i]) - blocks[i] = self._null_block + blocks[i] = self.null_block self.block_pool.free_blocks(removed_blocks) def get_num_common_prefix_blocks(self, request_id: str, From 94bd895b89fe360277479a7a7f116e5c9116ae9d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 19 May 2025 20:16:33 -0700 Subject: [PATCH 10/44] mapping as clas attribute Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_manager.py | 36 ++++++++++++++++++-------------- vllm/v1/core/sched/scheduler.py | 3 +-- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 81ace317f5a8..c87712393083 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -2,7 +2,7 @@ from collections import defaultdict from dataclasses import dataclass -from typing import Callable, Optional +from typing import Callable, ClassVar, Optional from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger @@ -20,21 +20,25 @@ @dataclass class KVCacheBlocks: blocks: list[list[KVCacheBlockBundle]] - group_to_manager: list[tuple[int, int]] + """ + blocks[i][j].blocks[k] refers to the i-th single_type_manager, the j-th + block of the tokens, and the k-th kv cache group managed by that + single_type_manager. + """ + group_to_manager: ClassVar[list[tuple[int, int]]] = [] + """ + tuple(manager_id, group_id_in_manager) for each kv cache group. + """ def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": """Adds two KVCacheBlocks instances.""" - assert self.group_to_manager is other.group_to_manager return KVCacheBlocks( - [blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)], - self.group_to_manager) + [blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)]) @classmethod - def create_empty( - cls, group_to_manager: list[tuple[int, int]]) -> "KVCacheBlocks": + def create_empty(cls) -> "KVCacheBlocks": """Creates a new KVCacheBlocks instance with no blocks.""" - return cls([[] for _ in range(len(group_to_manager))], - group_to_manager) + return cls([[] for _ in range(len(cls.group_to_manager))]) def get_block_ids(self) -> list[list[int]]: """ @@ -92,6 +96,7 @@ def __init__( enable_kv_cache_events=enable_kv_cache_events, ) self.group_to_manager = self.coordinator.group_to_manager + KVCacheBlocks.group_to_manager = self.group_to_manager self.block_pool = self.coordinator.block_pool self.all_block_sizes = set(g.kv_cache_spec.block_size @@ -145,7 +150,7 @@ def get_computed_blocks(self, # When the request requires prompt logprobs, we skip prefix caching. if (not self.enable_caching or request.sampling_params.prompt_logprobs is not None): - return KVCacheBlocks.create_empty(self.group_to_manager), 0 + return KVCacheBlocks.create_empty(), 0 # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. @@ -178,8 +183,7 @@ def get_computed_blocks(self, self.prefix_cache_stats.queries += len(request.all_token_ids) self.prefix_cache_stats.hits += num_new_computed_tokens - return KVCacheBlocks(computed_blocks, - self.group_to_manager), num_new_computed_tokens + return KVCacheBlocks(computed_blocks), num_new_computed_tokens def allocate_slots( self, @@ -281,7 +285,7 @@ def allocate_slots( # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. 
if not self.enable_caching or delay_cache_blocks: - return KVCacheBlocks(new_blocks, self.group_to_manager) + return KVCacheBlocks(new_blocks) # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with @@ -290,7 +294,7 @@ def allocate_slots( request, self.req_to_block_hashes[request.request_id], num_computed_tokens + num_new_tokens - len(request.spec_token_ids)) - return KVCacheBlocks(new_blocks, self.group_to_manager) + return KVCacheBlocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. @@ -379,5 +383,5 @@ def take_events(self) -> list[KVCacheEvent]: def get_block_ids(self, request_id: str) -> list[list[int]]: """Get the block ids of a request.""" - return KVCacheBlocks(self.coordinator.get_block_ids(request_id), - self.group_to_manager).get_block_ids() + return KVCacheBlocks( + self.coordinator.get_block_ids(request_id)).get_block_ids() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4a0df4f652ed..e002f5bcb679 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -352,8 +352,7 @@ def schedule(self) -> SchedulerOutput: request) else: # P/D: skip checking prefix cache if loaded from remote kvs. - new_computed_blocks = KVCacheBlocks.create_empty( - self.kv_cache_manager.group_to_manager) + new_computed_blocks = KVCacheBlocks.create_empty() num_native_computed_tokens = 0 # Get externally-cached tokens if using a KVConnector. From e50e33ced19978ba2ded43ed319637754f5268ba Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 23 May 2025 02:51:25 -0700 Subject: [PATCH 11/44] simplify KVCacheBlocks Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 55 +++++++++++++---- vllm/v1/core/kv_cache_manager.py | 62 ++++++++------------ vllm/v1/core/sched/scheduler.py | 1 - vllm/v1/core/single_type_kv_cache_manager.py | 3 + 4 files changed, 74 insertions(+), 47 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 1b6768f346ab..e8aa8d1e7eda 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -3,7 +3,8 @@ from typing import Callable from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlockBundle +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, + KVCacheBlockBundle) from vllm.v1.core.single_type_kv_cache_manager import ( FullAttentionManager, SingleTypeKVCacheManager, get_manager_for_kv_cache_spec) @@ -49,6 +50,7 @@ def __init__( manager_id=i, caching_hash_fn=caching_hash_fn, )) + self.computed_blocks: dict[str, list[list[KVCacheBlockBundle]]] = {} self.verify_support_find_longest_cache_hit() def get_num_blocks_to_allocate( @@ -89,7 +91,7 @@ def save_new_computed_blocks( new_computed_blocks[i]) def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[list[KVCacheBlockBundle]]: + num_tokens: int) -> list[list[KVCacheBlock]]: """ Allocate new blocks for the request to give it at least `num_tokens` token slots. 
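# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the coordinator tracks blocks
# per single_type_manager as lists of KVCacheBlockBundle, while callers expect
# one block list per KV cache group. The toy code below shows the
# (manager_id, group_id_in_manager) regrouping under simplified stand-in
# classes; ToyBlock and ToyBundle are assumptions for illustration only.
# ---------------------------------------------------------------------------
from dataclasses import dataclass


@dataclass
class ToyBlock:
    block_id: int


@dataclass
class ToyBundle:
    # One block per KV cache group handled by the owning manager.
    blocks: tuple[ToyBlock, ...]


def regroup_by_kv_cache_group(
    bundles_per_manager: list[list[ToyBundle]],
    group_to_manager: list[tuple[int, int]],
) -> list[list[ToyBlock]]:
    """Return one block list per KV cache group."""
    per_group: list[list[ToyBlock]] = []
    for manager_id, group_id_in_manager in group_to_manager:
        per_group.append([
            bundle.blocks[group_id_in_manager]
            for bundle in bundles_per_manager[manager_id]
        ])
    return per_group


# Example: manager 0 handles one full-attention group; manager 1 handles two
# sliding-window groups, so each of its bundles carries two blocks.
bundles = [
    [ToyBundle((ToyBlock(1), )), ToyBundle((ToyBlock(2), ))],  # manager 0
    [ToyBundle((ToyBlock(3), ToyBlock(4)))],                   # manager 1
]
group_to_manager = [(0, 0), (1, 0), (1, 1)]
assert [[b.block_id for b in grp]
        for grp in regroup_by_kv_cache_group(bundles, group_to_manager)
        ] == [[1, 2], [3], [4]]
# ---------------------------------------------------------------------------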
@@ -106,7 +108,7 @@ def allocate_new_blocks(self, request_id: str, for manager in self.single_type_managers: new_blocks.append( manager.allocate_new_blocks(request_id, num_tokens)) - return new_blocks + return self.to_groups(new_blocks) def cache_blocks(self, request: Request, block_hashes: dict[int, list[BlockHashType]], @@ -173,24 +175,26 @@ def remove_skipped_blocks(self, request_id: str, for manager in self.single_type_managers: manager.remove_skipped_blocks(request_id, num_computed_tokens) - def get_block_ids(self, request_id: str) -> list[list[KVCacheBlockBundle]]: + def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]: """ - Get the block IDs for the request. + Get the blocks for the request. """ - return [ + return self.to_groups([ manager.req_to_blocks[request_id] for manager in self.single_type_managers - ] + ]) def find_longest_cache_hit( self, + request_id: str, block_hashes_dict: dict[int, list[BlockHashType]], max_cache_hit_length: int, - ) -> tuple[list[list[KVCacheBlockBundle]], int]: + ) -> tuple[list[list[KVCacheBlock]], int]: """ Find the longest cache hit for the request. Args: + request_id: The request ID. block_hashes_dict: The block hashes of the request. max_cache_hit_length: The maximum length of the cache hit. @@ -205,7 +209,9 @@ def find_longest_cache_hit( 0].kv_cache_spec.block_size hit_blocks = self.single_type_managers[0].find_longest_cache_hit( block_hashes_dict[block_size], max_length=max_cache_hit_length) - return [hit_blocks], len(hit_blocks) * block_size + if len(hit_blocks) > 0: + self.computed_blocks[request_id] = [hit_blocks] + return self.to_groups([hit_blocks]), len(hit_blocks) * block_size elif len(self.single_type_managers) == 2: # For simplicity, we assume the first manager is for full @@ -234,13 +240,31 @@ def find_longest_cache_hit( # cache hit of the other attention. del hit_blocks_full_attn[hit_length // block_size_0:] - return [hit_blocks_full_attn, hit_blocks_other_attn], hit_length + hit_blocks_two_mgr = [hit_blocks_full_attn, hit_blocks_other_attn] + if hit_length > 0: + self.computed_blocks[request_id] = hit_blocks_two_mgr + return self.to_groups(hit_blocks_two_mgr), hit_length else: raise NotImplementedError( "KVCacheCoordinator does not support more than 2 different" "types of layers yet.") + def get_computed_blocks( + self, request_id: str, + num_computed_tokens: int) -> list[list[KVCacheBlockBundle]]: + """ + Get the computed blocks for the request. 
+ """ + if num_computed_tokens == 0: + assert request_id not in self.computed_blocks + return [[] for _ in self.single_type_managers] + computed_blocks = self.computed_blocks.pop(request_id) + for i, manager in enumerate(self.single_type_managers): + assert len(computed_blocks[i] * + manager.block_size) == num_computed_tokens + return computed_blocks + def generate_group_manager_map( self) -> tuple[list[list[int]], list[tuple[int, int]]]: """ @@ -301,3 +325,14 @@ def verify_support_find_longest_cache_hit(self) -> None: raise NotImplementedError( "KVCacheCoordinator does not support more than 2 different " "types of layers yet.") + + def to_groups( + self, block_bundles: list[list[KVCacheBlockBundle]] + ) -> list[list[KVCacheBlock]]: + blocks = [] + for manager_id, group_id_in_manager in self.group_to_manager: + blocks.append([ + blk.blocks[group_id_in_manager] + for blk in block_bundles[manager_id] + ]) + return blocks diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index c87712393083..bc2780023d31 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -8,7 +8,7 @@ from vllm.logger import init_logger from vllm.utils import sha256 from vllm.v1.core.kv_cache_coordinator import KVCacheCoordinator -from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlockBundle, +from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_request_tokens) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats @@ -19,27 +19,18 @@ @dataclass class KVCacheBlocks: - blocks: list[list[KVCacheBlockBundle]] + blocks: list[list[KVCacheBlock]] """ - blocks[i][j].blocks[k] refers to the i-th single_type_manager, the j-th - block of the tokens, and the k-th kv cache group managed by that - single_type_manager. - """ - group_to_manager: ClassVar[list[tuple[int, int]]] = [] - """ - tuple(manager_id, group_id_in_manager) for each kv cache group. + blocks[i][j].blocks[k] refers to the i-th kv_cache_group and the j-th + block of the tokens. """ + num_kv_cache_groups: ClassVar[int] def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": """Adds two KVCacheBlocks instances.""" return KVCacheBlocks( [blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)]) - @classmethod - def create_empty(cls) -> "KVCacheBlocks": - """Creates a new KVCacheBlocks instance with no blocks.""" - return cls([[] for _ in range(len(cls.group_to_manager))]) - def get_block_ids(self) -> list[list[int]]: """ Converts the KVCacheBlocks instance to block_ids. 
@@ -50,18 +41,20 @@ def get_block_ids(self) -> list[list[int]]: * each inner list contains the block_ids of the blocks in that group """ block_ids = [] - for manager_id, group_id_in_manager in self.group_to_manager: - block_ids.append([ - blk.blocks[group_id_in_manager].block_id - for blk in self.blocks[manager_id] - ]) + for group in self.blocks: + block_ids.append([blk.block_id for blk in group]) return block_ids + @classmethod + def create_empty(cls) -> "KVCacheBlocks": + """Creates a new KVCacheBlocks instance with no blocks.""" + return cls([[] for _ in range(cls.num_kv_cache_groups)]) + def get_unhashed_block_ids(self) -> list[int]: """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" - assert len(self.group_to_manager) == 1, "Only one group is supported" + assert self.num_kv_cache_groups == 1, "Only one group is supported" return [ - block.master_block_id for block in self.blocks[0] + block.block_id for block in self.blocks[0] if block.block_hash is None ] @@ -95,8 +88,8 @@ def __init__( caching_hash_fn=self.caching_hash_fn, enable_kv_cache_events=enable_kv_cache_events, ) - self.group_to_manager = self.coordinator.group_to_manager - KVCacheBlocks.group_to_manager = self.group_to_manager + KVCacheBlocks.num_kv_cache_groups = len( + kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool self.all_block_sizes = set(g.kv_cache_spec.block_size @@ -175,7 +168,8 @@ def get_computed_blocks(self, # could slightly improve performance in the future. max_cache_hit_length = request.num_tokens - 1 computed_blocks, num_new_computed_tokens = ( - self.coordinator.find_longest_cache_hit(block_hashes, + self.coordinator.find_longest_cache_hit(request.request_id, + block_hashes, max_cache_hit_length)) if self.log_stats: @@ -190,7 +184,6 @@ def allocate_slots( request: Request, num_new_tokens: int, num_new_computed_tokens: int = 0, - new_computed_blocks: Optional[KVCacheBlocks] = None, num_lookahead_tokens: int = 0, delay_cache_blocks: bool = False, ) -> Optional[KVCacheBlocks]: @@ -232,12 +225,9 @@ def allocate_slots( if num_new_tokens == 0: raise ValueError("num_new_tokens must be greater than 0") - if new_computed_blocks is not None: - new_computed_block_list = new_computed_blocks.blocks - else: - new_computed_block_list = [ - [] for _ in self.coordinator.single_type_managers - ] + # Get the new computed blocks detected by get_computed_blocks. + new_computed_blocks = self.coordinator.get_computed_blocks( + request.request_id, num_new_computed_tokens) # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). @@ -259,7 +249,7 @@ def allocate_slots( num_blocks_to_allocate = (self.coordinator.get_num_blocks_to_allocate( request_id=request.request_id, num_tokens=num_tokens_need_slot, - new_computed_blocks=new_computed_block_list, + new_computed_blocks=new_computed_blocks, )) if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): @@ -268,16 +258,16 @@ def allocate_slots( # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: - self.block_pool.touch(new_computed_block_list) + self.block_pool.touch(new_computed_blocks) else: - assert all(not blocks for blocks in new_computed_block_list), ( + assert all(not blocks for blocks in new_computed_blocks), ( "Computed blocks should be empty when " "prefix caching is disabled") # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. 
self.coordinator.save_new_computed_blocks(request.request_id, - new_computed_block_list) + new_computed_blocks) new_blocks = self.coordinator.allocate_new_blocks( request.request_id, num_tokens_need_slot) @@ -384,4 +374,4 @@ def take_events(self) -> list[KVCacheEvent]: def get_block_ids(self, request_id: str) -> list[list[int]]: """Get the block ids of a request.""" return KVCacheBlocks( - self.coordinator.get_block_ids(request_id)).get_block_ids() + self.coordinator.get_blocks(request_id)).get_block_ids() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index e002f5bcb679..e5a9017489e6 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -400,7 +400,6 @@ def schedule(self) -> SchedulerOutput: request, num_new_tokens + num_external_computed_tokens, num_native_computed_tokens, - new_computed_blocks, num_lookahead_tokens=self.num_lookahead_tokens, delay_cache_blocks=load_kv_async, ) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index c2f44e6fba83..8bcb5594d007 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -51,6 +51,9 @@ def __init__( self.req_to_blocks: defaultdict[ str, list[KVCacheBlockBundle]] = defaultdict(list) + self.req_to_hit_blocks: defaultdict[ + str, list[KVCacheBlockBundle]] = defaultdict(list) + # {req_id: The number of cached blocks for this given request} # This is used to track the number of cached blocks for each request. # This is only used to track the RUNNING requests, we do not track the From 5c2887af7d1083f7a944516d1ed24d3aa04ebf0c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 23 May 2025 03:21:12 -0700 Subject: [PATCH 12/44] unify interface Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 23 ++++++++++++++--------- vllm/v1/core/kv_cache_manager.py | 10 ++++++---- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index e8aa8d1e7eda..7aa393f89d8e 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -91,7 +91,7 @@ def save_new_computed_blocks( new_computed_blocks[i]) def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[list[KVCacheBlock]]: + num_tokens: int) -> list[list[KVCacheBlockBundle]]: """ Allocate new blocks for the request to give it at least `num_tokens` token slots. @@ -108,7 +108,7 @@ def allocate_new_blocks(self, request_id: str, for manager in self.single_type_managers: new_blocks.append( manager.allocate_new_blocks(request_id, num_tokens)) - return self.to_groups(new_blocks) + return new_blocks def cache_blocks(self, request: Request, block_hashes: dict[int, list[BlockHashType]], @@ -175,21 +175,21 @@ def remove_skipped_blocks(self, request_id: str, for manager in self.single_type_managers: manager.remove_skipped_blocks(request_id, num_computed_tokens) - def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]: + def get_blocks(self, request_id: str) -> list[list[KVCacheBlockBundle]]: """ Get the blocks for the request. 
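A minimal, self-contained sketch of the per-manager to per-group conversion performed here, assuming a group_to_manager list of (manager_id, group_id_in_manager) tuples as in this patch; Bundle and to_groups below are simplified stand-ins rather than the real vLLM types.

from dataclasses import dataclass

@dataclass
class Bundle:
    # One block id per kv cache group handled by the same manager.
    blocks: tuple[int, ...]

def to_groups(bundles_per_manager: list[list[Bundle]],
              group_to_manager: list[tuple[int, int]]) -> list[list[int]]:
    # group_to_manager[g] = (manager_id, group_id_in_manager)
    out: list[list[int]] = []
    for manager_id, group_id_in_manager in group_to_manager:
        out.append([b.blocks[group_id_in_manager]
                    for b in bundles_per_manager[manager_id]])
    return out

# Groups 0 and 1 share manager 0; group 2 is served by manager 1.
bundles = [[Bundle((10, 11)), Bundle((12, 13))], [Bundle((20,))]]
assert to_groups(bundles, [(0, 0), (0, 1), (1, 0)]) == [[10, 12], [11, 13], [20]]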
""" - return self.to_groups([ + return [ manager.req_to_blocks[request_id] for manager in self.single_type_managers - ]) + ] def find_longest_cache_hit( self, request_id: str, block_hashes_dict: dict[int, list[BlockHashType]], max_cache_hit_length: int, - ) -> tuple[list[list[KVCacheBlock]], int]: + ) -> tuple[list[list[KVCacheBlockBundle]], int]: """ Find the longest cache hit for the request. @@ -211,7 +211,7 @@ def find_longest_cache_hit( block_hashes_dict[block_size], max_length=max_cache_hit_length) if len(hit_blocks) > 0: self.computed_blocks[request_id] = [hit_blocks] - return self.to_groups([hit_blocks]), len(hit_blocks) * block_size + return [hit_blocks], len(hit_blocks) * block_size elif len(self.single_type_managers) == 2: # For simplicity, we assume the first manager is for full @@ -243,7 +243,7 @@ def find_longest_cache_hit( hit_blocks_two_mgr = [hit_blocks_full_attn, hit_blocks_other_attn] if hit_length > 0: self.computed_blocks[request_id] = hit_blocks_two_mgr - return self.to_groups(hit_blocks_two_mgr), hit_length + return hit_blocks_two_mgr, hit_length else: raise NotImplementedError( @@ -326,9 +326,14 @@ def verify_support_find_longest_cache_hit(self) -> None: "KVCacheCoordinator does not support more than 2 different " "types of layers yet.") - def to_groups( + def to_group_format( self, block_bundles: list[list[KVCacheBlockBundle]] ) -> list[list[KVCacheBlock]]: + """ + Convert the blocks from `list[list[KVCacheBlockBundle]` + (`list[KVCacheBlockBundle]` for each manager) to + `list[list[KVCacheBlock]]`(`list[KVCacheBlock]` for each group). + """ blocks = [] for manager_id, group_id_in_manager in self.group_to_manager: blocks.append([ diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index bc2780023d31..644617aa162e 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -177,7 +177,8 @@ def get_computed_blocks(self, self.prefix_cache_stats.queries += len(request.all_token_ids) self.prefix_cache_stats.hits += num_new_computed_tokens - return KVCacheBlocks(computed_blocks), num_new_computed_tokens + return KVCacheBlocks(self.coordinator.to_group_format( + computed_blocks)), num_new_computed_tokens def allocate_slots( self, @@ -275,7 +276,7 @@ def allocate_slots( # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. if not self.enable_caching or delay_cache_blocks: - return KVCacheBlocks(new_blocks) + return KVCacheBlocks(self.coordinator.to_group_format(new_blocks)) # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with @@ -284,7 +285,7 @@ def allocate_slots( request, self.req_to_block_hashes[request.request_id], num_computed_tokens + num_new_tokens - len(request.spec_token_ids)) - return KVCacheBlocks(new_blocks) + return KVCacheBlocks(self.coordinator.to_group_format(new_blocks)) def free(self, request: Request) -> None: """Free the blocks allocated for the request. 
@@ -374,4 +375,5 @@ def take_events(self) -> list[KVCacheEvent]: def get_block_ids(self, request_id: str) -> list[list[int]]: """Get the block ids of a request.""" return KVCacheBlocks( - self.coordinator.get_blocks(request_id)).get_block_ids() + self.coordinator.to_group_format( + self.coordinator.get_blocks(request_id))).get_block_ids() From 32afa9d3433df8e6c8800dbe1782511b4fef138b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 24 May 2025 03:01:58 -0700 Subject: [PATCH 13/44] add a tofix Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 774e33bb1e83..cee57ecdef37 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -633,6 +633,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, logger.info("GPU KV cache size: %s tokens", num_tokens_str) max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" max_concurrency = num_tokens / vllm_config.model_config.max_model_len + # TODO: fix for hybrid allocator logger.info("Maximum concurrency for %s tokens per request: %.2fx", max_model_len_str, max_concurrency) From c75c9b5b24974c70a5fca8105c72002bb5c386f0 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 30 May 2025 05:02:42 -0700 Subject: [PATCH 14/44] a runable version without prefix caching Signed-off-by: Chen Zhang --- vllm/v1/core/block_pool.py | 178 ++++++-------- vllm/v1/core/kv_cache_coordinator.py | 211 ++++++++--------- vllm/v1/core/kv_cache_manager.py | 32 +-- vllm/v1/core/kv_cache_utils.py | 236 +++++++++---------- vllm/v1/core/sched/scheduler.py | 1 + vllm/v1/core/single_type_kv_cache_manager.py | 139 +++++------ 6 files changed, 359 insertions(+), 438 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 967e36c532c4..fab85a8ce4f7 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 -from collections import defaultdict, deque +from collections import defaultdict from collections.abc import Iterable from typing import Callable, Optional from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved, BlockStored, KVCacheEvent) from vllm.logger import init_logger -from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, KVCacheBlockBundle, +from vllm.v1.core.kv_cache_utils import (BlockHashType, + BlockHashTypeWithGroupId, + FreeKVCacheBlockQueue, KVCacheBlock, generate_block_hash_extra_keys, hash_block_tokens) from vllm.v1.request import Request @@ -34,7 +35,6 @@ def __init__( self, num_gpu_blocks: int, enable_caching: bool, - num_single_type_managers: int, enable_kv_cache_events: bool = False, ): assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 @@ -44,70 +44,64 @@ def __init__( self.blocks: list[KVCacheBlock] = [ KVCacheBlock(idx) for idx in range(num_gpu_blocks) ] - # A pool of block bundle instances, to avoid frequent creation of - # KVCacheBlockBundle class. As each KVCacheBlockBundle contains a - # distinct set of blocks, the number of KVCacheBlockBundle object won't - # exceed num_gpu_blocks. - self._block_bundle_pool: deque[KVCacheBlockBundle] = deque( - KVCacheBlockBundle(blocks=()) for _ in range(num_gpu_blocks)) # Free block queue that constructs and manipulates a doubly linked # list of free blocks (including eviction candidates when caching is # enabled). 
self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) - # {manager_id: {block_hash: {block ID: KVCacheBlockBundle}}}. - # A cached block is a full block with a block hash that can be used for - # prefix caching. + # {tuple[block_hash, manager_id]: {block ID: block}}. A cached block is + # a full block with a block hash that can be used for prefix caching. # The cached block may be used by running requests or in the # free_block_queue that could potentially be evicted. - # Use KVCacheBlockBundle to make sure different kv cache groups managed - # by the same single_type_manager are cached & evicted together. # NOTE: We currently don't de-duplicate the blocks in the cache, # meaning that if a block becomes full and is cached, we don't check # if there is already an identical block in the cache. This is because # we want to make sure the allocated block IDs won't change so that # block tables are append-only. - self.cached_block_hash_to_block: list[dict[BlockHashType, dict[ - int, KVCacheBlockBundle]]] = [ - defaultdict(dict) for _ in range(num_single_type_managers) - ] + self.cached_block_hash_to_block: dict[BlockHashTypeWithGroupId, dict[ + int, KVCacheBlock]] = defaultdict(dict) + # To represent a placeholder block with block_id=0. # The ref_cnt of null_block is not maintained, needs special care to # avoid freeing it. self.null_block = self.free_block_queue.popleft() - self.num_single_type_managers = num_single_type_managers self.enable_kv_cache_events = enable_kv_cache_events self.kv_event_queue: list[KVCacheEvent] = [] - def get_cached_block(self, block_hash: BlockHashType, - manager_id: int) -> Optional[KVCacheBlockBundle]: + def get_cached_block( + self, block_hash: BlockHashType, + kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]: """Get a cached block by the block hash, or None if cache miss. If there are duplicated blocks, we return the first block in the cache. + TODO: update notes Args: block_hash: The hash value of the block. - manager_id: The id of the single_type_manager. + kv_cache_group_id: The id of the KV cache group. Returns: The cached block if it exists, or None. """ - cached_blocks = self.cached_block_hash_to_block[manager_id].get( - block_hash) - if not cached_blocks: - return None - first_block_id = next(iter(cached_blocks)) - return cached_blocks[first_block_id] + cached_blocks = [] + for group_id in kv_cache_group_ids: + cached_blocks_one_group = self.cached_block_hash_to_block[ + BlockHashTypeWithGroupId(block_hash, group_id)] + if not cached_blocks_one_group: + return None + first_block_id = next(iter(cached_blocks_one_group)) + cached_blocks.append(cached_blocks_one_group[first_block_id]) + return cached_blocks def cache_full_blocks( self, request: Request, - blocks: list[KVCacheBlockBundle], + blocks: list[KVCacheBlock], block_hashes: list[BlockHashType], num_cached_blocks: int, num_full_blocks: int, block_size: int, - manager_id: int, + kv_cache_group_id: int, hash_fn: Callable, ) -> None: """Cache a list of full blocks for prefix caching. @@ -127,7 +121,7 @@ def cache_full_blocks( num_full_blocks: The number of blocks that are full and should be cached after this function. block_size: Number of tokens in each block. - manager_id: The id of the single_type_manager. + kv_cache_group_id: The id of the KV cache group. hash_fn: The hash function to use for block hashes. 
""" if num_cached_blocks == num_full_blocks: @@ -142,13 +136,13 @@ def cache_full_blocks( else: prev_block = blocks[num_cached_blocks - 1] assert prev_block.block_hash is not None - prev_block_hash_value = prev_block.block_hash.hash_value + prev_block_hash_value = prev_block.block_hash.get_hash_value() parent_block_hash = prev_block_hash_value new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events else None) for i, blk in enumerate(new_full_blocks): - assert blk.block_hash_is_none() + assert blk.block_hash is None if i < len(new_block_hashes): # The block hash may already be computed in @@ -183,21 +177,11 @@ def cache_full_blocks( block_hashes.append(block_hash) # Update and added the full block to the cache. - blk.init_block_hash(block_hash, manager_id) - # We make all blocks in the same KVCacheBlockBundle cached & - # evicted together. This is achieved by: - # 1. Here, use the master_block_id as the representative of the - # KVCacheBlockBundle in the cache. - # 2. In `free_blocks`, add the master block to the free list before - # adding the other blocks in the bundle. - # 3. In `_maybe_evict_cached_block`, as the master block is in front - # of other blocks in the bundle, it will be the first evicted block - # in the bundle. When a master block needs to be evicted, we remove - # the full bundle from cached_block_hash_to_block and remove the - # master block from free_block_queue. The other blocks are still in - # the free_block_queue but won't be hit by get_cached_block. - self.cached_block_hash_to_block[manager_id][block_hash][ - blk.master_block_id] = blk + block_hash_with_group_id = BlockHashTypeWithGroupId( + block_hash, kv_cache_group_id) + blk.block_hash = block_hash_with_group_id + self.cached_block_hash_to_block[block_hash_with_group_id][ + blk.block_id] = blk if new_hashes is not None: new_hashes.append(block_hash.hash_value) prev_block_hash_value = block_hash.hash_value @@ -215,9 +199,8 @@ def cache_full_blocks( if request.lora_request else None, )) - def get_new_block_bundles(self, num_block_bundle: int, - bundle_size: int) -> list[KVCacheBlockBundle]: - """Get new block bundles from the free block pool. + def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: + """Get new blocks from the free block pool. Note that we do not check block cache in this function. @@ -228,33 +211,26 @@ def get_new_block_bundles(self, num_block_bundle: int, Returns: A list of new block. """ - num_total_blocks = num_block_bundle * bundle_size - if num_total_blocks > self.get_num_free_blocks(): + if num_blocks > self.get_num_free_blocks(): raise ValueError( - f"Cannot get {num_total_blocks} free blocks from the pool") + f"Cannot get {num_blocks} free blocks from the pool") - new_blocks: list[KVCacheBlock] = [] + ret: list[KVCacheBlock] = [] idx = 0 - while idx < num_total_blocks: + while idx < num_blocks: # First allocate blocks. curr_block = self.free_block_queue.popleft() + assert curr_block.ref_cnt == 0 # If the block is cached, evict it. 
if self.enable_caching: self._maybe_evict_cached_block(curr_block) - assert curr_block.block_hash is None - new_blocks.append(curr_block) + curr_block.incr_ref() + ret.append(curr_block) idx += 1 - new_block_bundles: list[KVCacheBlockBundle] = [] - for i in range(num_block_bundle): - blocks = new_blocks[i * bundle_size:(i + 1) * bundle_size] - block_bundle = self._block_bundle_pool.pop().init_kv_cache_blocks( - tuple(blocks)) - block_bundle.incr_ref() - new_block_bundles.append(block_bundle) - return new_block_bundles + return ret def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: """ @@ -268,26 +244,20 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: True if the block is evicted, False otherwise. """ block_hash = block.block_hash - manager_id = block.manager_id - if block_hash and block_hash in self.cached_block_hash_to_block[ - manager_id]: - cached_blocks = self.cached_block_hash_to_block[manager_id][ - block_hash] - cached_block = cached_blocks[block.block_id] - # The block is the master block of its KVCacheBlockBundle. - # See comments in cache_full_blocks for details. - assert cached_block.master_block_id == block.block_id - self._block_bundle_pool.append(cached_block.reset()) - del cached_blocks[block.block_id] - if len(cached_blocks) == 0: - del self.cached_block_hash_to_block[manager_id][block_hash] + if block_hash and block_hash in self.cached_block_hash_to_block: + block.reset_hash() + del self.cached_block_hash_to_block[block_hash][block.block_id] + + if len(self.cached_block_hash_to_block[block_hash]) == 0: + del self.cached_block_hash_to_block[block_hash] + if self.enable_kv_cache_events: self.kv_event_queue.append( - BlockRemoved(block_hashes=[block_hash.hash_value])) + BlockRemoved(block_hashes=[block_hash.get_hash_value()])) return True return False - def touch(self, blocks: list[list[KVCacheBlockBundle]]) -> None: + def touch(self, blocks: list[list[KVCacheBlock]]) -> None: """Touch a block increases its reference count by 1, and may remove the block from the free queue. This is used when a block is hit by another request with the same prefix. @@ -296,17 +266,14 @@ def touch(self, blocks: list[list[KVCacheBlockBundle]]) -> None: blocks: A list of blocks to touch. """ for blocks_one_manager in blocks: - for block_bundle in blocks_one_manager: - if block_bundle.ref_cnt == 0: - # ref_cnt=0 means the blocks are in the free list (i.e. - # eviction candidate), so remove them. - for block in block_bundle.blocks: - if block != self.null_block: - self.free_block_queue.remove(block) - block_bundle.incr_ref() - - def free_blocks(self, - ordered_blocks: Iterable[KVCacheBlockBundle]) -> None: + for block in blocks_one_manager: + # ref_cnt=0 means this block is in the free list (i.e. eviction + # candidate), so remove it. + if block.ref_cnt == 0 and block != self.null_block: + self.free_block_queue.remove(block) + block.incr_ref() + + def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: """Free a list of blocks. The blocks should be ordered by their eviction priority, where the first block will be evicted first. @@ -314,22 +281,11 @@ def free_blocks(self, ordered_blocks: A list of blocks to free ordered by their eviction priority. """ - for block_bundle in ordered_blocks: - block_bundle.decr_ref() - if block_bundle.ref_cnt > 0: - continue - # NOTE: should add the master block to the free list before adding - # the other blocks. See the comment in `cache_full_blocks` - # for the reason. 
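The interplay between touch, allocation, and lazy eviction here reduces to a small invariant: a cached block with ref_cnt == 0 waits in the free queue as an eviction candidate, a prefix hit revives it via touch, and reallocating it drops its stale hash entry. ToyPool below is a hypothetical illustration of that invariant, not the BlockPool API.

from collections import OrderedDict

class ToyPool:
    def __init__(self, num_blocks: int):
        self.free_queue = OrderedDict((i, None) for i in range(num_blocks))
        self.ref_cnt = [0] * num_blocks
        self.hash_of = {}          # block_id -> hash
        self.cached = {}           # hash -> block_id

    def touch(self, block_id: int):
        # A prefix hit on a free (evictable) block revives it.
        if self.ref_cnt[block_id] == 0:
            self.free_queue.pop(block_id, None)
        self.ref_cnt[block_id] += 1

    def allocate(self) -> int:
        block_id, _ = self.free_queue.popitem(last=False)  # LRU end
        stale_hash = self.hash_of.pop(block_id, None)
        if stale_hash is not None:     # lazily evict the old cache entry
            self.cached.pop(stale_hash, None)
        self.ref_cnt[block_id] += 1
        return block_id

    def free(self, block_id: int):
        self.ref_cnt[block_id] -= 1
        if self.ref_cnt[block_id] == 0:
            self.free_queue[block_id] = None   # evictable again

pool = ToyPool(2)
a = pool.allocate()
pool.hash_of[a] = "h"; pool.cached["h"] = a
pool.free(a)          # still cached, but an eviction candidate
pool.touch(a)         # prefix hit revives it without reallocation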
The following loop implicitly achieves it because - # the master block is the first block in the bundle. - for block in block_bundle.blocks: - # null_block should not be added to the free list. - if block != self.null_block: - self.free_block_queue.append(block) - - if (block_bundle.block_hash is None and - block_bundle.master_block_id != self.null_block.block_id): - self._block_bundle_pool.append(block_bundle.reset()) + for block in ordered_blocks: + block.decr_ref() + # null_block should not be added to the free list. + if block.ref_cnt == 0 and block != self.null_block: + self.free_block_queue.append(block) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -348,9 +304,7 @@ def reset_prefix_cache(self) -> bool: return False # Remove all hashes so that no new blocks will hit. - self.cached_block_hash_to_block = [ - defaultdict(dict) for _ in range(self.num_single_type_managers) - ] + self.cached_block_hash_to_block = defaultdict(dict) # Remove all hashes from all blocks. for block in self.blocks: diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 7aa393f89d8e..4bc0ccbdf298 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -3,11 +3,9 @@ from typing import Callable from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, - KVCacheBlockBundle) +from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock from vllm.v1.core.single_type_kv_cache_manager import ( - FullAttentionManager, SingleTypeKVCacheManager, - get_manager_for_kv_cache_spec) + SingleTypeKVCacheManager, get_manager_for_kv_cache_spec) from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig from vllm.v1.request import Request @@ -29,33 +27,27 @@ def __init__( self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len - # One manager for each different kv_cache_spec, managing all kv cache - # groups with the same kv_cache_spec. - self.manager_to_group, self.group_to_manager = ( - self.generate_group_manager_map()) self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching, - len(self.manager_to_group), enable_kv_cache_events) self.single_type_managers: list[SingleTypeKVCacheManager] = [] - for i in range(len(self.manager_to_group)): - group_ids = self.manager_to_group[i] - kv_cache_spec = kv_cache_config.kv_cache_groups[ - group_ids[0]].kv_cache_spec + for i in range(len(self.kv_cache_config.kv_cache_groups)): + kv_cache_spec = self.kv_cache_config.kv_cache_groups[ + i].kv_cache_spec self.single_type_managers.append( get_manager_for_kv_cache_spec( kv_cache_spec=kv_cache_spec, block_pool=self.block_pool, use_eagle=use_eagle, - num_kv_cache_groups=len(self.manager_to_group[i]), - manager_id=i, + kv_cache_group_id=i, caching_hash_fn=caching_hash_fn, )) - self.computed_blocks: dict[str, list[list[KVCacheBlockBundle]]] = {} - self.verify_support_find_longest_cache_hit() + + self.type0_group_ids, self.type1_group_ids = ( + self.verify_support_find_longest_cache_hit()) def get_num_blocks_to_allocate( self, request_id: str, num_tokens: int, - new_computed_blocks: list[list[KVCacheBlockBundle]]) -> int: + new_computed_blocks: list[list[KVCacheBlock]]) -> int: """ Get the number of blocks needed to be allocated for the request. 
@@ -77,7 +69,7 @@ def get_num_blocks_to_allocate( def save_new_computed_blocks( self, request_id: str, - new_computed_blocks: list[list[KVCacheBlockBundle]]) -> None: + new_computed_blocks: list[list[KVCacheBlock]]) -> None: """ Add the new computed blocks to the request. @@ -91,7 +83,7 @@ def save_new_computed_blocks( new_computed_blocks[i]) def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[list[KVCacheBlockBundle]]: + num_tokens: int) -> list[list[KVCacheBlock]]: """ Allocate new blocks for the request to give it at least `num_tokens` token slots. @@ -151,15 +143,11 @@ def get_num_common_prefix_blocks( Returns: The number of common prefix blocks. """ - num_blocks_per_manager = [ + num_blocks_per_group = [ manager.get_num_common_prefix_blocks(request_id, num_running_requests) for manager in self.single_type_managers ] - num_blocks_per_group = [ - num_blocks_per_manager[manager_id] - for manager_id, _ in self.group_to_manager - ] return num_blocks_per_group def remove_skipped_blocks(self, request_id: str, @@ -175,7 +163,7 @@ def remove_skipped_blocks(self, request_id: str, for manager in self.single_type_managers: manager.remove_skipped_blocks(request_id, num_computed_tokens) - def get_blocks(self, request_id: str) -> list[list[KVCacheBlockBundle]]: + def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]: """ Get the blocks for the request. """ @@ -189,7 +177,7 @@ def find_longest_cache_hit( request_id: str, block_hashes_dict: dict[int, list[BlockHashType]], max_cache_hit_length: int, - ) -> tuple[list[list[KVCacheBlockBundle]], int]: + ) -> tuple[list[list[KVCacheBlock]], int]: """ Find the longest cache hit for the request. @@ -203,67 +191,53 @@ def find_longest_cache_hit( - A list of the cache hit blocks for each single type manager. - The number of tokens of the longest cache hit. """ - if len(self.single_type_managers) == 1: - # Return the cache hit blocks for the only kv cache group. - block_size = self.kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.block_size - hit_blocks = self.single_type_managers[0].find_longest_cache_hit( - block_hashes_dict[block_size], max_length=max_cache_hit_length) - if len(hit_blocks) > 0: - self.computed_blocks[request_id] = [hit_blocks] - return [hit_blocks], len(hit_blocks) * block_size - - elif len(self.single_type_managers) == 2: - # For simplicity, we assume the first manager is for full - # attention layers, and the block_size of full attention layers - # is divisible by other attention layers. This has been verified - # in verify_support_find_longest_cache_hit(). - - block_size_0 = self.single_type_managers[0].block_size - block_size_1 = self.single_type_managers[1].block_size - - # First, find the longest cache hit for full attention. - hit_blocks_full_attn = self.single_type_managers[ - 0].find_longest_cache_hit(block_hashes_dict[block_size_0], - max_length=max_cache_hit_length) - hit_length = len(hit_blocks_full_attn) * block_size_0 - - # Next, find the cache hit for the other attention WITHIN - # the cache hit of full attention. - hit_blocks_other_attn = self.single_type_managers[ - 1].find_longest_cache_hit(block_hashes_dict[block_size_1], - max_length=hit_length) - hit_length = len(hit_blocks_other_attn) * block_size_1 - assert hit_length % block_size_0 == 0 - - # Truncate the full attention cache hit to the length of the - # cache hit of the other attention. 
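The two-level hit search shown in this hunk (longest full-attention hit first, then the other attention constrained within that prefix, then truncation to the common length) can be traced with plain token counts. The longest_hit helper below is invented for illustration and assumes the block-size divisibility property this patch checks elsewhere.

def longest_hit(num_cached_blocks: int, max_length: int, block_size: int) -> int:
    # Pretend the first `num_cached_blocks` blocks of the request are cached.
    return min(num_cached_blocks * block_size,
               (max_length // block_size) * block_size)

block_size_full, block_size_other = 16, 32   # full-attention size divides the other
# Step 1: longest full-attention hit, capped at max_cache_hit_length.
hit_full = longest_hit(num_cached_blocks=10, max_length=100,
                       block_size=block_size_full)
# Step 2: the other attention may only hit WITHIN the full-attention prefix.
hit_other = longest_hit(num_cached_blocks=4, max_length=hit_full,
                        block_size=block_size_other)
# Step 3: truncate the full-attention hit to the shorter, common prefix.
hit = hit_other
assert hit % block_size_full == 0  # guaranteed by the divisibility assumption
print(hit_full, hit_other, hit)    # 96 96 96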
- del hit_blocks_full_attn[hit_length // block_size_0:] - - hit_blocks_two_mgr = [hit_blocks_full_attn, hit_blocks_other_attn] - if hit_length > 0: - self.computed_blocks[request_id] = hit_blocks_two_mgr - return hit_blocks_two_mgr, hit_length - - else: - raise NotImplementedError( - "KVCacheCoordinator does not support more than 2 different" - "types of layers yet.") - - def get_computed_blocks( - self, request_id: str, - num_computed_tokens: int) -> list[list[KVCacheBlockBundle]]: - """ - Get the computed blocks for the request. - """ - if num_computed_tokens == 0: - assert request_id not in self.computed_blocks - return [[] for _ in self.single_type_managers] - computed_blocks = self.computed_blocks.pop(request_id) - for i, manager in enumerate(self.single_type_managers): - assert len(computed_blocks[i] * - manager.block_size) == num_computed_tokens - return computed_blocks + return [[] for _ in self.kv_cache_config.kv_cache_groups], 0 + # if len(self.kv_cache_config.kv_cache_groups) == 1: + # # Return the cache hit blocks for the only kv cache group. + # block_size = self.kv_cache_config.kv_cache_groups[ + # 0].kv_cache_spec.block_size + # hit_blocks = self.single_type_managers[0].find_longest_cache_hit( + # block_hashes_dict[block_size], max_length=max_cache_hit_length) # noqa + # if len(hit_blocks) > 0: + # self.computed_blocks[request_id] = [hit_blocks] + # return [hit_blocks], len(hit_blocks) * block_size + + # elif len(self.kv_cache_config.kv_cache_groups) > 1: + # # For simplicity, we assume the first manager is for full + # # attention layers, and the block_size of full attention layers + # # is divisible by other attention layers. This has been verified + # # in verify_support_find_longest_cache_hit(). + + # block_size_0 = self.single_type_managers[0].block_size + # block_size_1 = self.single_type_managers[1].block_size + + # # First, find the longest cache hit for full attention. + # hit_blocks_full_attn = self.single_type_managers[ + # 0].find_longest_cache_hit(block_hashes_dict[block_size_0], + # max_length=max_cache_hit_length) + # hit_length = len(hit_blocks_full_attn) * block_size_0 + + # # Next, find the cache hit for the other attention WITHIN + # # the cache hit of full attention. + # hit_blocks_other_attn = self.single_type_managers[ + # 1].find_longest_cache_hit(block_hashes_dict[block_size_1], + # max_length=hit_length) + # hit_length = len(hit_blocks_other_attn) * block_size_1 + # assert hit_length % block_size_0 == 0 + + # # Truncate the full attention cache hit to the length of the + # # cache hit of the other attention. + # del hit_blocks_full_attn[hit_length // block_size_0:] + + # hit_blocks_two_mgr = [hit_blocks_full_attn, hit_blocks_other_attn] + # if hit_length > 0: + # self.computed_blocks[request_id] = hit_blocks_two_mgr + # return hit_blocks_two_mgr, hit_length + + # else: + # raise AssertionError("This line should be unreachable as " + # "unsupported cases should be caught by " + # "verify_support_find_longest_cache_hit()") def generate_group_manager_map( self) -> tuple[list[list[int]], list[tuple[int, int]]]: @@ -300,44 +274,45 @@ def generate_group_manager_map( ] return manager_to_group, group_to_manager - def verify_support_find_longest_cache_hit(self) -> None: + def verify_support_find_longest_cache_hit( + self) -> tuple[list[int], list[int]]: """ For simplicity, find_longest_cache_hit makes some assumptions on the model architecture instead of provides a general solution. This function checks if the assumptions hold. 
NOTE(Chen): Please open an issue to discuss if you need other cases. + + TODO: add more notes """ - if len(self.single_type_managers) == 1: - return - if len(self.single_type_managers) == 2: - if not isinstance(self.single_type_managers[0], - FullAttentionManager): - raise NotImplementedError( - "KVCacheCoordinator assumes hybrid models have at least one" - " full attention layer now") - block_size_0 = self.single_type_managers[0].block_size - block_size_1 = self.single_type_managers[1].block_size + if len(self.kv_cache_config.kv_cache_groups) == 1: + return list(range(len(self.kv_cache_config.kv_cache_groups))), [] + else: + groups_by_type_id: dict[str, list[int]] = defaultdict(list) + full_attention_type_ids: set[str] = set() + for i, g in enumerate(self.kv_cache_config.kv_cache_groups): + groups_by_type_id[g.kv_cache_spec.type_id].append(i) + if isinstance(g.kv_cache_spec, FullAttentionSpec): + full_attention_type_ids.add(g.kv_cache_spec.type_id) + + assert len(full_attention_type_ids) == 1, ( + "find_longest_cache_hit assumes hybrid models have exactly " + "one type of full attention groups now") + assert len(groups_by_type_id) == 2, ( + "find_longest_cache_hit assumes hybrid models have exactly " + "one other type of groups except full attention now") + + type0_group_ids = groups_by_type_id[next( + iter(full_attention_type_ids))] + type1_group_ids = groups_by_type_id[next( + iter(groups_by_type_id.keys() - full_attention_type_ids))] + + block_size_0 = self.kv_cache_config.kv_cache_groups[ + type0_group_ids[0]].kv_cache_spec.block_size + block_size_1 = self.kv_cache_config.kv_cache_groups[ + type1_group_ids[0]].kv_cache_spec.block_size if block_size_1 % block_size_0 != 0: raise NotImplementedError( "KVCacheCoordinator assumes the block_size of the full " "attention layer is divisible by other layers now.") - else: - raise NotImplementedError( - "KVCacheCoordinator does not support more than 2 different " - "types of layers yet.") - def to_group_format( - self, block_bundles: list[list[KVCacheBlockBundle]] - ) -> list[list[KVCacheBlock]]: - """ - Convert the blocks from `list[list[KVCacheBlockBundle]` - (`list[KVCacheBlockBundle]` for each manager) to - `list[list[KVCacheBlock]]`(`list[KVCacheBlock]` for each group). - """ - blocks = [] - for manager_id, group_id_in_manager in self.group_to_manager: - blocks.append([ - blk.blocks[group_id_in_manager] - for blk in block_bundles[manager_id] - ]) - return blocks + return type0_group_ids, type1_group_ids diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 644617aa162e..954aa8a405b9 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -21,8 +21,7 @@ class KVCacheBlocks: blocks: list[list[KVCacheBlock]] """ - blocks[i][j].blocks[k] refers to the i-th kv_cache_group and the j-th - block of the tokens. + blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens. 
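With the blocks[i][j] layout described in this docstring, extracting block ids is a per-group projection; a rough, self-contained approximation (Block here is a stand-in, not the real KVCacheBlock):

from dataclasses import dataclass

@dataclass
class Block:
    block_id: int

blocks = [[Block(3), Block(7)],   # kv cache group 0
          [Block(5)]]             # kv cache group 1
block_ids = [[blk.block_id for blk in group] for group in blocks]
assert block_ids == [[3, 7], [5]]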
""" num_kv_cache_groups: ClassVar[int] @@ -91,6 +90,7 @@ def __init__( KVCacheBlocks.num_kv_cache_groups = len( kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool + self.kv_cache_config = kv_cache_config self.all_block_sizes = set(g.kv_cache_spec.block_size for g in kv_cache_config.kv_cache_groups) @@ -177,14 +177,14 @@ def get_computed_blocks(self, self.prefix_cache_stats.queries += len(request.all_token_ids) self.prefix_cache_stats.hits += num_new_computed_tokens - return KVCacheBlocks(self.coordinator.to_group_format( - computed_blocks)), num_new_computed_tokens + return KVCacheBlocks(computed_blocks), num_new_computed_tokens def allocate_slots( self, request: Request, num_new_tokens: int, num_new_computed_tokens: int = 0, + new_computed_blocks: Optional[KVCacheBlocks] = None, num_lookahead_tokens: int = 0, delay_cache_blocks: bool = False, ) -> Optional[KVCacheBlocks]: @@ -226,9 +226,12 @@ def allocate_slots( if num_new_tokens == 0: raise ValueError("num_new_tokens must be greater than 0") - # Get the new computed blocks detected by get_computed_blocks. - new_computed_blocks = self.coordinator.get_computed_blocks( - request.request_id, num_new_computed_tokens) + if new_computed_blocks is not None: + new_computed_block_list = new_computed_blocks.blocks + else: + new_computed_block_list = [ + [] for _ in range(len(self.kv_cache_config.kv_cache_groups)) + ] # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). @@ -250,7 +253,7 @@ def allocate_slots( num_blocks_to_allocate = (self.coordinator.get_num_blocks_to_allocate( request_id=request.request_id, num_tokens=num_tokens_need_slot, - new_computed_blocks=new_computed_blocks, + new_computed_blocks=new_computed_block_list, )) if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): @@ -259,16 +262,16 @@ def allocate_slots( # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: - self.block_pool.touch(new_computed_blocks) + self.block_pool.touch(new_computed_block_list) else: - assert all(not blocks for blocks in new_computed_blocks), ( + assert all(not blocks for blocks in new_computed_block_list), ( "Computed blocks should be empty when " "prefix caching is disabled") # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. self.coordinator.save_new_computed_blocks(request.request_id, - new_computed_blocks) + new_computed_block_list) new_blocks = self.coordinator.allocate_new_blocks( request.request_id, num_tokens_need_slot) @@ -276,7 +279,7 @@ def allocate_slots( # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. if not self.enable_caching or delay_cache_blocks: - return KVCacheBlocks(self.coordinator.to_group_format(new_blocks)) + return KVCacheBlocks(new_blocks) # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with @@ -285,7 +288,7 @@ def allocate_slots( request, self.req_to_block_hashes[request.request_id], num_computed_tokens + num_new_tokens - len(request.spec_token_ids)) - return KVCacheBlocks(self.coordinator.to_group_format(new_blocks)) + return KVCacheBlocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. 
@@ -375,5 +378,4 @@ def take_events(self) -> list[KVCacheEvent]: def get_block_ids(self, request_id: str) -> list[list[int]]: """Get the block ids of a request.""" return KVCacheBlocks( - self.coordinator.to_group_format( - self.coordinator.get_blocks(request_id))).get_block_ids() + self.coordinator.get_blocks(request_id)).get_block_ids() diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index cee57ecdef37..e4570b24f96a 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """KV-Cache Utilities.""" + import os from collections import defaultdict, deque from collections.abc import Sequence @@ -25,6 +26,7 @@ class BlockHashType(NamedTuple): hash collisions when the hash value is the same. By using SHA256 however, hash collisions are practically impossible. """ + # Hash value of the block in an integer. hash_value: int # Token IDs in the block. @@ -33,6 +35,14 @@ class BlockHashType(NamedTuple): extra_keys: Optional[Any] = None +class BlockHashTypeWithGroupId(NamedTuple): + block_hash: BlockHashType + group_id: int + + def get_hash_value(self) -> int: + return self.block_hash.hash_value + + # The hash seed for the first block of the prefix block sequence. # # Even if the hash function is the builtin hash(), we use sha256 to generate @@ -43,8 +53,9 @@ class BlockHashType(NamedTuple): # variable if set such that processes can share the seed if needed. # This aligns with the behavior of Python's hash() function, which also uses # a random seed if PYTHONHASHSEED is not set. -NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv( - 'PYTHONHASHSEED') is None else sha256(os.getenv('PYTHONHASHSEED')) +NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big") + if os.getenv("PYTHONHASHSEED") is None else sha256( + os.getenv("PYTHONHASHSEED"))) class PrefixCachingMetrics: @@ -112,26 +123,32 @@ def hit_rate(self) -> float: @dataclass class KVCacheBlock: """KV-cache block metadata.""" + # Block ID, ranging from 0 to num_gpu_blocks - 1. block_id: int + # Reference count. + ref_cnt: int = 0 # The hash of the block composed of (block hash, tuple of token IDs). # It is only available when the block is full. - _block_hash: Optional[BlockHashType] = None + _block_hash: Optional[BlockHashTypeWithGroupId] = None # Used to construct a doubly linked list for free blocks. # These two attributes should only be manipulated by FreeKVCacheBlockQueue. prev_free_block: Optional["KVCacheBlock"] = None next_free_block: Optional["KVCacheBlock"] = None - # The single_type_kv_cache_manager this block belongs to. - manager_id: int = -1 + def incr_ref(self): + self.ref_cnt += 1 + + def decr_ref(self): + self.ref_cnt -= 1 @property - def block_hash(self) -> Optional[BlockHashType]: + def block_hash(self) -> Optional[BlockHashTypeWithGroupId]: return self._block_hash @block_hash.setter - def block_hash(self, block_hash: BlockHashType): + def block_hash(self, block_hash: BlockHashTypeWithGroupId): assert self.block_hash is None, ( "The block already has a hash. This should not happen.") self._block_hash = block_hash @@ -144,11 +161,12 @@ def reset_hash(self): def __repr__(self) -> str: # Use block_id instead of KVCacheBlock object to avoid calling __repr__ # on KVCacheBlock object recursively. 
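The per-block state introduced in this hunk, a reference count plus a write-once hash, can be exercised with a trimmed copy of the dataclass; MiniBlock below is a simplified restatement for illustration, not an import of the real class.

from dataclasses import dataclass
from typing import Optional

@dataclass
class MiniBlock:
    block_id: int
    ref_cnt: int = 0
    _block_hash: Optional[tuple] = None

    def incr_ref(self): self.ref_cnt += 1
    def decr_ref(self): self.ref_cnt -= 1

    @property
    def block_hash(self): return self._block_hash

    @block_hash.setter
    def block_hash(self, value):
        assert self._block_hash is None, "hash may only be set once"
        self._block_hash = value

    def reset_hash(self): self._block_hash = None

blk = MiniBlock(0)
blk.incr_ref()                  # a request starts using the block
blk.block_hash = (123, (1, 2))  # block becomes full and cacheable
blk.decr_ref()                  # ref_cnt == 0: now an eviction candidate
blk.reset_hash()                # eviction clears the hash for reuse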
- prev_block_id = self.prev_free_block.block_id \ - if self.prev_free_block else None - next_block_id = self.next_free_block.block_id \ - if self.next_free_block else None + prev_block_id = (self.prev_free_block.block_id + if self.prev_free_block else None) + next_block_id = (self.next_free_block.block_id + if self.next_free_block else None) return (f"KVCacheBlock(block_id={self.block_id}, " + f"ref_cnt={self.ref_cnt}, " f"_block_hash={self._block_hash}, " f"prev_free_block={prev_block_id}, " f"next_free_block={next_block_id})") @@ -272,14 +290,16 @@ def need_extra_keys(request: Request) -> bool: # Multimodal requests need to include the MM hash. # LoRA requests need to include the LoRA ID. # Request with provided cache salt need to include the salt. - return bool(request.mm_positions) or (request.lora_request - is not None) or (request.cache_salt - is not None) + return (bool(request.mm_positions) or (request.lora_request is not None) + or (request.cache_salt is not None)) -def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, - end_token_idx: int, - start_mm_idx: int) -> tuple[list[Any], int]: +def _gen_mm_extra_hash_keys( + request: Request, + start_token_idx: int, + end_token_idx: int, + start_mm_idx: int, +) -> tuple[list[Any], int]: """Generate extra keys related to MultiModal request for block hash computation. For multi-modal inputs, the extra keys are (mm_hash, start_offset) that indicate a mm input contained in the @@ -361,8 +381,11 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[int]: def generate_block_hash_extra_keys( - request: Request, start_token_idx: int, end_token_idx: int, - start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]: + request: Request, + start_token_idx: int, + end_token_idx: int, + start_mm_idx: int, +) -> tuple[Optional[tuple[Any, ...]], int]: """Generate extra keys for the block hash. The extra keys can come from the multi-modal inputs and request specific metadata (e.g., LoRA ID). @@ -379,8 +402,9 @@ def generate_block_hash_extra_keys( mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( request, start_token_idx, end_token_idx, start_mm_idx) lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) - cache_salt_keys: list[str] = [request.cache_salt] if ( - start_token_idx == 0 and request.cache_salt) else [] + cache_salt_keys: list[str] = ([request.cache_salt] if + (start_token_idx == 0 + and request.cache_salt) else []) extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys @@ -391,10 +415,11 @@ def generate_block_hash_extra_keys( def hash_block_tokens( - hash_function: Callable, - parent_block_hash: Optional[int], - curr_block_token_ids: Sequence[int], - extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType: + hash_function: Callable, + parent_block_hash: Optional[int], + curr_block_token_ids: Sequence[int], + extra_keys: Optional[tuple[Any, ...]] = None, +) -> BlockHashType: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing @@ -415,10 +440,13 @@ def hash_block_tokens( parent_block_hash = NONE_HASH curr_block_token_ids_tuple = tuple(curr_block_token_ids) + # NOTE: not add group_id. 
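hash_block_tokens builds a prefix hash chain: each block hash folds in the parent block's hash, the block's token ids, and any extra keys. The sketch below reproduces that chaining with the builtin hash() and leaves out the NONE_HASH seed and the sha256 option for brevity.

from typing import Any, Optional

def toy_block_hash(parent: Optional[int], token_ids: tuple[int, ...],
                   extra_keys: Optional[Any] = None) -> int:
    parent = parent if parent is not None else 0  # stand-in for NONE_HASH
    return hash((parent, token_ids, extra_keys))

def toy_request_hashes(token_ids: list[int], block_size: int) -> list[int]:
    hashes, parent = [], None
    # Only full blocks are hashed, exactly as in hash_request_tokens.
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        h = toy_block_hash(parent, tuple(token_ids[start:start + block_size]))
        hashes.append(h)
        parent = h
    return hashes

a = toy_request_hashes([1, 2, 3, 4, 5, 6, 7, 8], block_size=4)
b = toy_request_hashes([1, 2, 3, 4, 9, 9, 9, 9], block_size=4)
assert a[0] == b[0] and a[1] != b[1]   # shared prefix, divergent second block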
return BlockHashType( hash_function( (parent_block_hash, curr_block_token_ids_tuple, extra_keys)), - curr_block_token_ids_tuple, extra_keys) + curr_block_token_ids_tuple, + extra_keys, + ) def hash_request_tokens(hash_function: Any, block_size: int, @@ -453,16 +481,22 @@ def hash_request_tokens(hash_function: Any, block_size: int, req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( request, start, end, curr_mm_idx) - block_hash = hash_block_tokens(hash_function, parent_block_hash_value, - block_token_ids, req_extra_keys) + block_hash = hash_block_tokens( + hash_function, + parent_block_hash_value, + block_token_ids, + req_extra_keys, + ) ret.append(block_hash) parent_block_hash_value = block_hash.hash_value return ret -def estimate_max_model_len(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> int: +def estimate_max_model_len( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +) -> int: """ Estimates the maximum model length that can fit in the available memory using binary search. @@ -508,9 +542,11 @@ def fits_in_memory(model_len: int) -> bool: return result -def check_enough_kv_cache_memory(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int): +def check_enough_kv_cache_memory( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +): """ Checks whether `available_memory` is enough for the KV cache to hold at least one request with the model's max_model_len. @@ -545,9 +581,9 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, raise ValueError( f"To serve at least one request with the models's max seq len " - f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV " + f"({max_model_len}), ({needed_memory / GiB_bytes:.2f} GiB KV " f"cache is needed, which is larger than the available KV cache " - f"memory ({available_memory/GiB_bytes:.2f} GiB)." + f"memory ({available_memory / GiB_bytes:.2f} GiB)." f"{estimated_msg} " f" Try increasing `gpu_memory_utilization` or decreasing " f"`max_model_len` when initializing the engine.") @@ -557,20 +593,20 @@ def create_kv_cache_group_specs( kv_cache_spec: dict[str, KVCacheSpec], grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]: """ - Create KVCacheGroupSpec object for each kv cache group layer. - The layers in the same group should share the same - KVCacheSpec. - - Args: - kv_cache_spec: - A mapping from each layer name to its corresponding KVCacheSpec. - grouped_layer_names: - A list of kv cache groups, where each element is a list of layer - names that belong to the same group and should share the same - KVCacheSpec. - Returns: - A list of KVCacheGroupSpec objects, one for each group. - """ + Create KVCacheGroupSpec object for each kv cache group layer. + The layers in the same group should share the same + KVCacheSpec. + + Args: + kv_cache_spec: + A mapping from each layer name to its corresponding KVCacheSpec. + grouped_layer_names: + A list of kv cache groups, where each element is a list of layer + names that belong to the same group and should share the same + KVCacheSpec. + Returns: + A list of KVCacheGroupSpec objects, one for each group. 
+ """ kv_cache_groups = [] for layer_names_one_group in grouped_layer_names: layer_specs = [ @@ -597,9 +633,11 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: return len(layer_keys) == 1 -def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> KVCacheConfig: +def _get_kv_cache_config_uniform_type( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +) -> KVCacheConfig: """ Generates the KV cache configuration for a model with one type of KV cache. Divide the available memory equally among all layers. @@ -621,11 +659,13 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, num_blocks = max(num_blocks, 0) if vllm_config.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = \ - vllm_config.cache_config.num_gpu_blocks_override + num_gpu_blocks_override = ( + vllm_config.cache_config.num_gpu_blocks_override) logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) + "Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d", + num_blocks, + num_gpu_blocks_override, + ) num_blocks = num_gpu_blocks_override num_tokens = num_blocks * vllm_config.cache_config.block_size @@ -634,8 +674,11 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" max_concurrency = num_tokens / vllm_config.model_config.max_model_len # TODO: fix for hybrid allocator - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - max_model_len_str, max_concurrency) + logger.info( + "Maximum concurrency for %s tokens per request: %.2fx", + max_model_len_str, + max_concurrency, + ) per_layer_size = page_size * num_blocks # All layers have the same KV cache spec, so we create one kv cache group @@ -655,12 +698,12 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, def is_kv_cache_page_size_uniform( - kv_cache_spec: dict[str, KVCacheSpec]) -> bool: + kv_cache_spec: dict[str, KVCacheSpec], ) -> bool: """ Whether all layers in the given KVCacheSpec have the same page size. Args: kv_cache_spec: The KVCacheSpec of each attention layer in the model - + Returns: True if all layers have the same page size, False otherwise. """ @@ -670,8 +713,10 @@ def is_kv_cache_page_size_uniform( def _get_kv_cache_config_uniform_page_size( - vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> KVCacheConfig: + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +) -> KVCacheConfig: """ Generates the KV cache configuration for a model with one page size. Args: @@ -701,7 +746,8 @@ def _get_kv_cache_config_uniform_page_size( logger.warning( "Add %d padding layers, may waste at most %.2f%% KV cache memory", # noqa num_padding_layers, - num_padding_layers / len(layers) * 100) + num_padding_layers / len(layers) * 100, + ) for i in range(0, len(layers), group_size): grouped_layers.append(layers[i:i + group_size]) @@ -735,8 +781,8 @@ def _get_kv_cache_config_uniform_page_size( def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ - This function tries to convert the KV cache specs to one type if the model - is a hybrid model with multiple type of KV cache. It will convert all + This function tries to convert the KV cache specs to one type if the model + is a hybrid model with multiple type of KV cache. 
It will convert all SlidingWindowSpec to FullAttentionSpec if both types are present. Args: @@ -775,9 +821,11 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: "convert the KV cache specs to one unified type.") -def get_kv_cache_config(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> KVCacheConfig: +def get_kv_cache_config( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +) -> KVCacheConfig: """ Generates the KV cache configuration for a model. @@ -845,53 +893,3 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): kv_cache_config.num_blocks = min_num_blocks return kv_cache_configs - - -# KVCacheBlocks for the same block of all kv cache groups with the same kv cache -# spec (and belongs to the same manager). All blocks in the bundle have the same -# block hash, and are allocated & freed & cached & evicted together. -@dataclass -class KVCacheBlockBundle: - blocks: tuple[KVCacheBlock, ...] - block_hash: Optional[BlockHashType] = None - # Reference count. - ref_cnt: int = 0 - - def incr_ref(self): - self.ref_cnt += 1 - - def decr_ref(self): - self.ref_cnt -= 1 - - @property - def master_block_id(self): - return self.blocks[0].block_id - - def init_kv_cache_blocks( - self, blocks: tuple[KVCacheBlock, ...]) -> 'KVCacheBlockBundle': - assert self.block_hash is None - assert self.ref_cnt == 0 - self.blocks = blocks - self.block_hash = blocks[0].block_hash - return self - - def reset_hash(self): - for block in self.blocks: - block.reset_hash() - self.block_hash = None - - def reset(self) -> 'KVCacheBlockBundle': - assert self.ref_cnt == 0 - self.reset_hash() - self.blocks = () - return self - - def block_hash_is_none(self): - return self.block_hash is None and all(block.block_hash is None - for block in self.blocks) - - def init_block_hash(self, block_hash: BlockHashType, manager_id: int): - self.block_hash = block_hash - for b in self.blocks: - b.block_hash = block_hash - b.manager_id = manager_id diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index e5a9017489e6..e002f5bcb679 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -400,6 +400,7 @@ def schedule(self) -> SchedulerOutput: request, num_new_tokens + num_external_computed_tokens, num_native_computed_tokens, + new_computed_blocks, num_lookahead_tokens=self.num_lookahead_tokens, delay_cache_blocks=load_kv_async, ) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8bcb5594d007..7bfebb9abf94 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -5,7 +5,7 @@ from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlockBundle +from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, SlidingWindowSpec) from vllm.v1.request import Request @@ -22,8 +22,7 @@ def __init__( kv_cache_spec: KVCacheSpec, block_pool: BlockPool, use_eagle: bool, - num_kv_cache_groups: int, - manager_id: int, + kv_cache_group_id: int, caching_hash_fn: Callable, ) -> None: """ @@ -32,9 +31,7 @@ def __init__( kv_cache_spec: The kv_cache_spec for this manager. block_pool: The block pool. use_eagle: Whether to use eagle. - num_kv_cache_groups: The number of kv cache groups managed by this - manager. 
- manager_id: The id of this manager. + kv_cache_group_id: The id of the kv cache group of this manager. caching_hash_fn: The caching hash function. """ @@ -48,11 +45,8 @@ def __init__( # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: defaultdict[ - str, list[KVCacheBlockBundle]] = defaultdict(list) - - self.req_to_hit_blocks: defaultdict[ - str, list[KVCacheBlockBundle]] = defaultdict(list) + self.req_to_blocks: defaultdict[str, + list[KVCacheBlock]] = defaultdict(list) # {req_id: The number of cached blocks for this given request} # This is used to track the number of cached blocks for each request. @@ -60,13 +54,12 @@ def __init__( # data for reempted ones. self.num_cached_block: dict[str, int] = {} - self.num_kv_cache_groups = num_kv_cache_groups self.caching_hash_fn = caching_hash_fn - self.manager_id = manager_id + self.kv_cache_group_id = kv_cache_group_id def get_num_blocks_to_allocate( self, request_id: str, num_tokens: int, - new_computed_blocks: list[KVCacheBlockBundle]) -> int: + new_computed_blocks: list[KVCacheBlock]) -> int: """ Get the number of blocks needed to be allocated for the request. @@ -90,12 +83,11 @@ def get_num_blocks_to_allocate( # it as needed to be allocated. num_evictable_computed_blocks = sum(blk.ref_cnt == 0 for blk in new_computed_blocks) - return ((num_new_blocks + num_evictable_computed_blocks) * - self.num_kv_cache_groups) + return num_new_blocks + num_evictable_computed_blocks def save_new_computed_blocks( self, request_id: str, - new_computed_blocks: list[KVCacheBlockBundle]) -> None: + new_computed_blocks: list[KVCacheBlock]) -> None: """ Add the new computed blocks to the request. @@ -115,7 +107,7 @@ def save_new_computed_blocks( assert len(new_computed_blocks) == 0 def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[KVCacheBlockBundle]: + num_tokens: int) -> list[KVCacheBlock]: """ Allocate new blocks for the request to give it at least `num_tokens` token slots. @@ -134,8 +126,7 @@ def allocate_new_blocks(self, request_id: str, if num_new_blocks <= 0: return [] else: - new_blocks = self.block_pool.get_new_block_bundles( - num_new_blocks, self.num_kv_cache_groups) + new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) return new_blocks @@ -160,7 +151,7 @@ def cache_blocks(self, request: Request, block_hashes: list[BlockHashType], num_cached_blocks=num_cached_blocks, num_full_blocks=num_full_blocks, block_size=self.block_size, - manager_id=self.manager_id, + kv_cache_group_id=self.kv_cache_group_id, hash_fn=self.caching_hash_fn, ) @@ -201,7 +192,7 @@ def get_num_common_prefix_blocks(self, request_id: str, @abstractmethod def find_longest_cache_hit(self, block_hashes: list[BlockHashType], - max_length: int) -> list[KVCacheBlockBundle]: + max_length: int) -> list[KVCacheBlock]: """ Get the longest cache hit prefix of the blocks that is not longer than `max_length`. If no cache hit is found, return an empty list. @@ -240,22 +231,23 @@ def remove_skipped_blocks(self, request_id: str, class FullAttentionManager(SingleTypeKVCacheManager): def find_longest_cache_hit(self, block_hashes: list[BlockHashType], - max_length: int) -> list[KVCacheBlockBundle]: - computed_blocks: list[KVCacheBlockBundle] = [] - max_num_blocks = max_length // self.block_size - for i in range(max_num_blocks): - block_hash = block_hashes[i] - # block_hashes is a chain of block hashes. 
If a block hash is not - # in the cached_block_hash_to_id, the following block hashes are - # not computed yet for sure. - if cached_block := self.block_pool.get_cached_block( - block_hash, self.manager_id): - computed_blocks.append(cached_block) - else: - break - if self.use_eagle and len(computed_blocks) > 0: - computed_blocks.pop() - return computed_blocks + max_length: int) -> list[KVCacheBlock]: + return [] + # computed_blocks: list[KVCacheBlock] = [] + # max_num_blocks = max_length // self.block_size + # for i in range(max_num_blocks): + # block_hash = block_hashes[i] + # # block_hashes is a chain of block hashes. If a block hash is not + # # in the cached_block_hash_to_id, the following block hashes are + # # not computed yet for sure. + # if cached_block := self.block_pool.get_cached_block( + # block_hash, self.manager_id): + # computed_blocks.append(cached_block) + # else: + # break + # if self.use_eagle and len(computed_blocks) > 0: + # computed_blocks.pop() + # return computed_blocks def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: @@ -290,44 +282,43 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, # contiguous blocks needed for prefix cache hit by one and dropping # the last matched block. self.sliding_window_contiguous_blocks += 1 - single_null_block = block_pool.null_block - self.null_block = KVCacheBlockBundle( - tuple([single_null_block] * self.num_kv_cache_groups)) + self.null_block = block_pool.null_block def find_longest_cache_hit(self, block_hashes: list[BlockHashType], - max_length: int) -> list[KVCacheBlockBundle]: + max_length: int) -> list[KVCacheBlock]: + return [] # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to - # optimize the time complexity from O(len(block_hashes)) to - # O(len(block_hashes) / sliding_window_contiguous_blocks + - # sliding_window_contiguous_blocks), - # which is good for low cache hit rate scenarios. - max_num_blocks = max_length // self.block_size - computed_blocks = [self.null_block] * max_num_blocks - num_contiguous_blocks = 0 - match_found = False - # Search from right to left and early stop when a match is found. - for i in range(max_num_blocks - 1, -1, -1): - if cached_block := self.block_pool.get_cached_block( - block_hashes[i], self.manager_id): - computed_blocks[i] = cached_block - num_contiguous_blocks += 1 - if (num_contiguous_blocks - >= self.sliding_window_contiguous_blocks): - # Trim the trailing blocks. - # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] - # when sliding_window_contiguous_blocks=2. - del computed_blocks[i + num_contiguous_blocks:] - match_found = True - break - else: - num_contiguous_blocks = 0 - if not match_found: - # The first `num_contiguous_blocks` is a cache hit even if - # `num_contiguous_blocks < sliding_window_contiguous_blocks`. - del computed_blocks[num_contiguous_blocks:] - if self.use_eagle and len(computed_blocks) > 0: - computed_blocks.pop() - return computed_blocks + # # optimize the time complexity from O(len(block_hashes)) to + # # O(len(block_hashes) / sliding_window_contiguous_blocks + + # # sliding_window_contiguous_blocks), + # # which is good for low cache hit rate scenarios. + # max_num_blocks = max_length // self.block_size + # computed_blocks = [self.null_block] * max_num_blocks + # num_contiguous_blocks = 0 + # match_found = False + # # Search from right to left and early stop when a match is found. 
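The commented-out sliding-window lookup above scans hit candidates right to left, keeps null placeholders for gaps, and stops once sliding_window_contiguous_blocks cached blocks appear in a row. The helper below captures that search with booleans; it is a simplified stand-in, not the manager's real signature.

NULL = None

def sliding_window_hit(cached: list[bool], contiguous_needed: int) -> list:
    # cached[i] is True if block i of the request is in the prefix cache.
    computed = [NULL] * len(cached)
    run = 0
    for i in range(len(cached) - 1, -1, -1):
        if cached[i]:
            computed[i] = i          # stand-in for the cached block object
            run += 1
            if run >= contiguous_needed:
                del computed[i + run:]   # trim everything after the match
                return computed
        else:
            run = 0
    # No full run found: only the leading run (if any) is usable.
    del computed[run:]
    return computed

# The window needs 2 contiguous blocks; blocks 2 and 3 are cached.
print(sliding_window_hit([False, False, True, True, False, True], 2))
# -> [None, None, 2, 3]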
+ # for i in range(max_num_blocks - 1, -1, -1): + # if cached_block := self.block_pool.get_cached_block( + # block_hashes[i], self.manager_id): + # computed_blocks[i] = cached_block + # num_contiguous_blocks += 1 + # if (num_contiguous_blocks + # >= self.sliding_window_contiguous_blocks): + # # Trim the trailing blocks. + # # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] + # # when sliding_window_contiguous_blocks=2. + # del computed_blocks[i + num_contiguous_blocks:] + # match_found = True + # break + # else: + # num_contiguous_blocks = 0 + # if not match_found: + # # The first `num_contiguous_blocks` is a cache hit even if + # # `num_contiguous_blocks < sliding_window_contiguous_blocks`. + # del computed_blocks[num_contiguous_blocks:] + # if self.use_eagle and len(computed_blocks) > 0: + # computed_blocks.pop() + # return computed_blocks def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: @@ -336,7 +327,7 @@ def remove_skipped_blocks(self, request_id: str, last_useful_token = num_computed_tokens - self.sliding_window + 1 last_useful_block = last_useful_token // self.block_size blocks = self.req_to_blocks[request_id] - removed_blocks: list[KVCacheBlockBundle] = [] + removed_blocks: list[KVCacheBlock] = [] for i in range(last_useful_block - 1, -1, -1): if blocks[i] == self.null_block: # If the block is already a null block, the blocks before it From 5800dcf509d37db850c815800296bb008f654cdf Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 30 May 2025 06:07:41 -0700 Subject: [PATCH 15/44] support prefix caching Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 258 ++++++++++--------- vllm/v1/core/kv_cache_manager.py | 4 +- vllm/v1/core/single_type_kv_cache_manager.py | 119 +++++---- 3 files changed, 197 insertions(+), 184 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 4bc0ccbdf298..0a0163b2f7a0 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +from abc import abstractmethod from collections import defaultdict from typing import Callable @@ -42,9 +43,6 @@ def __init__( caching_hash_fn=caching_hash_fn, )) - self.type0_group_ids, self.type1_group_ids = ( - self.verify_support_find_longest_cache_hit()) - def get_num_blocks_to_allocate( self, request_id: str, num_tokens: int, new_computed_blocks: list[list[KVCacheBlock]]) -> int: @@ -172,6 +170,83 @@ def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]: for manager in self.single_type_managers ] + @abstractmethod + def find_longest_cache_hit( + self, request_id: str, + block_hashes_dict: dict[int, list[BlockHashType]], + max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]: + pass + + +class UnifiedKVCacheCoordinator(KVCacheCoordinator): + + def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, + use_eagle: bool, enable_caching: bool, + caching_hash_fn: Callable, enable_kv_cache_events: bool): + super().__init__(kv_cache_config, max_model_len, use_eagle, + enable_caching, caching_hash_fn, + enable_kv_cache_events) + self.block_size = self.kv_cache_config.kv_cache_groups[ + 0].kv_cache_spec.block_size + assert len(self.kv_cache_config.kv_cache_groups) == 1, ( + "UnifiedKVCacheCoordinator assumes only one kv cache group") + + def find_longest_cache_hit( + self, request_id: str, + block_hashes_dict: dict[int, list[BlockHashType]], + max_cache_hit_length: int) -> 
tuple[list[list[KVCacheBlock]], int]: + hit_blocks = self.single_type_managers[0].find_longest_cache_hit( + block_hashes_dict[self.block_size], max_cache_hit_length, [0]) + return hit_blocks, len(hit_blocks[0]) * self.block_size + + +class HybridKVCacheCoordinator(KVCacheCoordinator): + + def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, + use_eagle: bool, enable_caching: bool, + caching_hash_fn: Callable, enable_kv_cache_events: bool): + super().__init__(kv_cache_config, max_model_len, use_eagle, + enable_caching, caching_hash_fn, + enable_kv_cache_events) + self.initialize_group_ids() + + def initialize_group_ids(self) -> None: + """ + For simplicity, find_longest_cache_hit makes some assumptions on the + model architecture instead of provides a general solution. This function + checks if the assumptions hold. + NOTE(Chen): Please open an issue to discuss if you need other cases. + + TODO: add more notes + """ + groups_by_type_id: dict[str, list[int]] = defaultdict(list) + full_attention_type_ids: set[str] = set() + for i, g in enumerate(self.kv_cache_config.kv_cache_groups): + groups_by_type_id[g.kv_cache_spec.type_id].append(i) + if isinstance(g.kv_cache_spec, FullAttentionSpec): + full_attention_type_ids.add(g.kv_cache_spec.type_id) + + assert len(full_attention_type_ids) == 1, ( + "find_longest_cache_hit assumes hybrid models have exactly " + "one type of full attention groups now") + assert len(groups_by_type_id) == 2, ( + "find_longest_cache_hit assumes hybrid models have exactly " + "one other type of groups except full attention now") + + self.full_attention_group_ids = groups_by_type_id[next( + iter(full_attention_type_ids))] + self.other_group_ids = groups_by_type_id[next( + iter(groups_by_type_id.keys() - full_attention_type_ids))] + + self.full_attention_block_size = self.kv_cache_config.kv_cache_groups[ + self.full_attention_group_ids[0]].kv_cache_spec.block_size + self.other_block_size = self.kv_cache_config.kv_cache_groups[ + self.other_group_ids[0]].kv_cache_spec.block_size + if self.other_block_size % self.full_attention_block_size != 0: + raise NotImplementedError( + "KVCacheCoordinator assumes the block_size of the full " + "attention layer is divisible by other layers now.") + def find_longest_cache_hit( self, request_id: str, @@ -191,128 +266,55 @@ def find_longest_cache_hit( - A list of the cache hit blocks for each single type manager. - The number of tokens of the longest cache hit. """ - return [[] for _ in self.kv_cache_config.kv_cache_groups], 0 - # if len(self.kv_cache_config.kv_cache_groups) == 1: - # # Return the cache hit blocks for the only kv cache group. - # block_size = self.kv_cache_config.kv_cache_groups[ - # 0].kv_cache_spec.block_size - # hit_blocks = self.single_type_managers[0].find_longest_cache_hit( - # block_hashes_dict[block_size], max_length=max_cache_hit_length) # noqa - # if len(hit_blocks) > 0: - # self.computed_blocks[request_id] = [hit_blocks] - # return [hit_blocks], len(hit_blocks) * block_size - - # elif len(self.kv_cache_config.kv_cache_groups) > 1: - # # For simplicity, we assume the first manager is for full - # # attention layers, and the block_size of full attention layers - # # is divisible by other attention layers. This has been verified - # # in verify_support_find_longest_cache_hit(). - - # block_size_0 = self.single_type_managers[0].block_size - # block_size_1 = self.single_type_managers[1].block_size - - # # First, find the longest cache hit for full attention. 
- # hit_blocks_full_attn = self.single_type_managers[ - # 0].find_longest_cache_hit(block_hashes_dict[block_size_0], - # max_length=max_cache_hit_length) - # hit_length = len(hit_blocks_full_attn) * block_size_0 - - # # Next, find the cache hit for the other attention WITHIN - # # the cache hit of full attention. - # hit_blocks_other_attn = self.single_type_managers[ - # 1].find_longest_cache_hit(block_hashes_dict[block_size_1], - # max_length=hit_length) - # hit_length = len(hit_blocks_other_attn) * block_size_1 - # assert hit_length % block_size_0 == 0 - - # # Truncate the full attention cache hit to the length of the - # # cache hit of the other attention. - # del hit_blocks_full_attn[hit_length // block_size_0:] - - # hit_blocks_two_mgr = [hit_blocks_full_attn, hit_blocks_other_attn] - # if hit_length > 0: - # self.computed_blocks[request_id] = hit_blocks_two_mgr - # return hit_blocks_two_mgr, hit_length - - # else: - # raise AssertionError("This line should be unreachable as " - # "unsupported cases should be caught by " - # "verify_support_find_longest_cache_hit()") - - def generate_group_manager_map( - self) -> tuple[list[list[int]], list[tuple[int, int]]]: - """ - Generate the mapping between kv cache groups and managers. - - Returns: - manager_to_group: list[list[int]], the kv cache groups managed by - each manager. - group_to_manager: list[tuple[int, int]], the manager id and the - index of the group in the manager for each kv cache group. - """ - groups_by_type_id: dict[str, list[int]] = defaultdict(list) - full_attention_type_ids: set[str] = set() - for i, g in enumerate(self.kv_cache_config.kv_cache_groups): - groups_by_type_id[g.kv_cache_spec.type_id].append(i) - if isinstance(g.kv_cache_spec, FullAttentionSpec): - full_attention_type_ids.add(g.kv_cache_spec.type_id) - - manager_to_group = [] - for type_id in full_attention_type_ids: - manager_to_group.append(groups_by_type_id[type_id]) - for type_id in groups_by_type_id.keys() - full_attention_type_ids: - manager_to_group.append(groups_by_type_id[type_id]) - - group_to_manager_dict = { - group_id: (manager_id, group_id_in_manager) - for manager_id, group_ids in enumerate(manager_to_group) - for group_id_in_manager, group_id in enumerate(group_ids) - } - group_to_manager = [ - group_to_manager_dict[i] - for i in range(len(self.kv_cache_config.kv_cache_groups)) - ] - return manager_to_group, group_to_manager - - def verify_support_find_longest_cache_hit( - self) -> tuple[list[int], list[int]]: - """ - For simplicity, find_longest_cache_hit makes some assumptions on the - model architecture instead of provides a general solution. This function - checks if the assumptions hold. - NOTE(Chen): Please open an issue to discuss if you need other cases. 
- - TODO: add more notes - """ - if len(self.kv_cache_config.kv_cache_groups) == 1: - return list(range(len(self.kv_cache_config.kv_cache_groups))), [] - else: - groups_by_type_id: dict[str, list[int]] = defaultdict(list) - full_attention_type_ids: set[str] = set() - for i, g in enumerate(self.kv_cache_config.kv_cache_groups): - groups_by_type_id[g.kv_cache_spec.type_id].append(i) - if isinstance(g.kv_cache_spec, FullAttentionSpec): - full_attention_type_ids.add(g.kv_cache_spec.type_id) - - assert len(full_attention_type_ids) == 1, ( - "find_longest_cache_hit assumes hybrid models have exactly " - "one type of full attention groups now") - assert len(groups_by_type_id) == 2, ( - "find_longest_cache_hit assumes hybrid models have exactly " - "one other type of groups except full attention now") - - type0_group_ids = groups_by_type_id[next( - iter(full_attention_type_ids))] - type1_group_ids = groups_by_type_id[next( - iter(groups_by_type_id.keys() - full_attention_type_ids))] - - block_size_0 = self.kv_cache_config.kv_cache_groups[ - type0_group_ids[0]].kv_cache_spec.block_size - block_size_1 = self.kv_cache_config.kv_cache_groups[ - type1_group_ids[0]].kv_cache_spec.block_size - if block_size_1 % block_size_0 != 0: - raise NotImplementedError( - "KVCacheCoordinator assumes the block_size of the full " - "attention layer is divisible by other layers now.") - - return type0_group_ids, type1_group_ids + # For simplicity, we assume the first manager is for full + # attention layers, and the block_size of full attention layers + # is divisible by other attention layers. This has been verified + # in verify_support_find_longest_cache_hit(). + + # First, find the longest cache hit for full attention. + hit_blocks_full_attn = self.single_type_managers[ + 0].find_longest_cache_hit( + block_hashes_dict[self.full_attention_block_size], + max_length=max_cache_hit_length, + kv_cache_group_ids=self.full_attention_group_ids) + hit_length = len( + hit_blocks_full_attn[0]) * self.full_attention_block_size + + # Next, find the cache hit for the other attention WITHIN + # the cache hit of full attention. + hit_blocks_other_attn = self.single_type_managers[ + 1].find_longest_cache_hit(block_hashes_dict[self.other_block_size], + max_length=hit_length, + kv_cache_group_ids=self.other_group_ids) + hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size + assert hit_length % self.full_attention_block_size == 0 + + # Truncate the full attention cache hit to the length of the + # cache hit of the other attention. + for i in range(len(hit_blocks_full_attn)): + del hit_blocks_full_attn[i][hit_length // + self.full_attention_block_size:] + # Merge the hit blocks of full attention and other attention. + hit_blocks = hit_blocks_other_attn + for group_id, blocks in enumerate(hit_blocks_full_attn): + del blocks[hit_length // self.full_attention_block_size:] + # NOTE: there is only one full attention group in most cases. So + # the time complexity of insert is fine. 
+ hit_blocks.insert(group_id, blocks) + return hit_blocks, hit_length + + +def get_kv_cache_coordinator( + kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, + enable_caching: bool, caching_hash_fn: Callable, + enable_kv_cache_events: bool) -> KVCacheCoordinator: + if len(kv_cache_config.kv_cache_groups) == 1: + return UnifiedKVCacheCoordinator(kv_cache_config, max_model_len, + use_eagle, enable_caching, + caching_hash_fn, + enable_kv_cache_events) + else: + return HybridKVCacheCoordinator(kv_cache_config, max_model_len, + use_eagle, enable_caching, + caching_hash_fn, + enable_kv_cache_events) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 954aa8a405b9..08b20493f2d1 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -7,7 +7,7 @@ from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger from vllm.utils import sha256 -from vllm.v1.core.kv_cache_coordinator import KVCacheCoordinator +from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_request_tokens) from vllm.v1.kv_cache_interface import KVCacheConfig @@ -79,7 +79,7 @@ def __init__( # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - self.coordinator = KVCacheCoordinator( + self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, use_eagle=self.use_eagle, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 7bfebb9abf94..3b9a9b6c5397 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -191,8 +191,9 @@ def get_num_common_prefix_blocks(self, request_id: str, raise NotImplementedError @abstractmethod - def find_longest_cache_hit(self, block_hashes: list[BlockHashType], - max_length: int) -> list[KVCacheBlock]: + def find_longest_cache_hit( + self, block_hashes: list[BlockHashType], max_length: int, + kv_cache_group_ids: list[int]) -> list[list[KVCacheBlock]]: """ Get the longest cache hit prefix of the blocks that is not longer than `max_length`. If no cache hit is found, return an empty list. @@ -230,24 +231,29 @@ def remove_skipped_blocks(self, request_id: str, class FullAttentionManager(SingleTypeKVCacheManager): - def find_longest_cache_hit(self, block_hashes: list[BlockHashType], - max_length: int) -> list[KVCacheBlock]: - return [] - # computed_blocks: list[KVCacheBlock] = [] - # max_num_blocks = max_length // self.block_size - # for i in range(max_num_blocks): - # block_hash = block_hashes[i] - # # block_hashes is a chain of block hashes. If a block hash is not - # # in the cached_block_hash_to_id, the following block hashes are - # # not computed yet for sure. 
- # if cached_block := self.block_pool.get_cached_block( - # block_hash, self.manager_id): - # computed_blocks.append(cached_block) - # else: - # break - # if self.use_eagle and len(computed_blocks) > 0: - # computed_blocks.pop() - # return computed_blocks + def find_longest_cache_hit( + self, block_hashes: list[BlockHashType], max_length: int, + kv_cache_group_ids: list[int]) -> list[list[KVCacheBlock]]: + # NOTE: different from other list[list[KVCacheBlock]] + computed_blocks: list[list[KVCacheBlock]] = [ + [] for _ in range(len(kv_cache_group_ids)) + ] + max_num_blocks = max_length // self.block_size + for i in range(max_num_blocks): + block_hash = block_hashes[i] + # block_hashes is a chain of block hashes. If a block hash is not + # in the cached_block_hash_to_id, the following block hashes are + # not computed yet for sure. + if cached_block := self.block_pool.get_cached_block( + block_hash, kv_cache_group_ids): + for j in range(len(kv_cache_group_ids)): + computed_blocks[j].append(cached_block[j]) + else: + break + if self.use_eagle and len(computed_blocks) > 0: + for j in range(len(kv_cache_group_ids)): + computed_blocks[j].pop() + return computed_blocks def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: @@ -284,41 +290,46 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, self.sliding_window_contiguous_blocks += 1 self.null_block = block_pool.null_block - def find_longest_cache_hit(self, block_hashes: list[BlockHashType], - max_length: int) -> list[KVCacheBlock]: - return [] + def find_longest_cache_hit( + self, block_hashes: list[BlockHashType], max_length: int, + kv_cache_group_ids: list[int]) -> list[list[KVCacheBlock]]: # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to - # # optimize the time complexity from O(len(block_hashes)) to - # # O(len(block_hashes) / sliding_window_contiguous_blocks + - # # sliding_window_contiguous_blocks), - # # which is good for low cache hit rate scenarios. - # max_num_blocks = max_length // self.block_size - # computed_blocks = [self.null_block] * max_num_blocks - # num_contiguous_blocks = 0 - # match_found = False - # # Search from right to left and early stop when a match is found. - # for i in range(max_num_blocks - 1, -1, -1): - # if cached_block := self.block_pool.get_cached_block( - # block_hashes[i], self.manager_id): - # computed_blocks[i] = cached_block - # num_contiguous_blocks += 1 - # if (num_contiguous_blocks - # >= self.sliding_window_contiguous_blocks): - # # Trim the trailing blocks. - # # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] - # # when sliding_window_contiguous_blocks=2. - # del computed_blocks[i + num_contiguous_blocks:] - # match_found = True - # break - # else: - # num_contiguous_blocks = 0 - # if not match_found: - # # The first `num_contiguous_blocks` is a cache hit even if - # # `num_contiguous_blocks < sliding_window_contiguous_blocks`. - # del computed_blocks[num_contiguous_blocks:] - # if self.use_eagle and len(computed_blocks) > 0: - # computed_blocks.pop() - # return computed_blocks + # optimize the time complexity from O(len(block_hashes)) to + # O(len(block_hashes) / sliding_window_contiguous_blocks + + # sliding_window_contiguous_blocks), + # which is good for low cache hit rate scenarios. 
+ max_num_blocks = max_length // self.block_size + computed_blocks = [[self.null_block] * max_num_blocks + for _ in range(len(kv_cache_group_ids))] + num_contiguous_blocks = 0 + match_found = False + # Search from right to left and early stop when a match is found. + for i in range(max_num_blocks - 1, -1, -1): + if cached_block := self.block_pool.get_cached_block( + block_hashes[i], kv_cache_group_ids): + for j in range(len(kv_cache_group_ids)): + computed_blocks[j][i] = cached_block[j] + num_contiguous_blocks += 1 + if (num_contiguous_blocks + >= self.sliding_window_contiguous_blocks): + # Trim the trailing blocks. + # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] + # when sliding_window_contiguous_blocks=2. + for j in range(len(kv_cache_group_ids)): + del computed_blocks[j][i + num_contiguous_blocks:] + match_found = True + break + else: + num_contiguous_blocks = 0 + if not match_found: + # The first `num_contiguous_blocks` is a cache hit even if + # `num_contiguous_blocks < sliding_window_contiguous_blocks`. + for j in range(len(kv_cache_group_ids)): + del computed_blocks[j][num_contiguous_blocks:] + if self.use_eagle and len(computed_blocks) > 0: + for j in range(len(kv_cache_group_ids)): + computed_blocks[j].pop() + return computed_blocks def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: From c96e13b72ac3eb05c4779aeb9ec66892f7c861f2 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 30 May 2025 06:28:54 -0700 Subject: [PATCH 16/44] fix bug for --disable-hybrid-kv-cache-manager Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index e4570b24f96a..0b1b618c2d44 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -816,7 +816,7 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: sliding_window=spec.sliding_window, ) - if not is_hybrid(kv_cache_spec): + if is_hybrid(kv_cache_spec): raise ValueError("Hybrid KV cache manager is disabled but failed to " "convert the KV cache specs to one unified type.") From d8ad1be423e16f7db285f564072ea43e1b642c68 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 30 May 2025 07:22:09 -0700 Subject: [PATCH 17/44] clean up Signed-off-by: Chen Zhang --- vllm/v1/core/block_pool.py | 14 ++-- vllm/v1/core/kv_cache_coordinator.py | 24 +++---- vllm/v1/core/kv_cache_manager.py | 37 +++++----- vllm/v1/core/kv_cache_utils.py | 75 ++++++++------------ vllm/v1/core/single_type_kv_cache_manager.py | 14 ++-- 5 files changed, 66 insertions(+), 98 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index fab85a8ce4f7..2638a08e6b22 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -27,7 +27,6 @@ class BlockPool: Args: num_gpu_blocks: The number of blocks in the pool. enable_caching: Whether to enable prefix caching. - num_single_type_managers: The number of single_type_managers. enable_kv_cache_events: Whether to enable kv cache events. """ @@ -49,7 +48,7 @@ def __init__( # enabled). self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) - # {tuple[block_hash, manager_id]: {block ID: block}}. A cached block is + # {block_hash: {block ID: block}}. A cached block is # a full block with a block hash that can be used for prefix caching. # The cached block may be used by running requests or in the # free_block_queue that could potentially be evicted. 
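The next hunk documents the per-group lookup contract of `get_cached_block`: a block hash only counts as a hit when every requested KV cache group has a block cached for it. A minimal, self-contained sketch of that semantics (the `Block`/`CachedMap` names and plain-string hashes are illustrative stand-ins, not the real vLLM types):

from typing import Optional

class Block:
    """Illustrative stand-in for a cached KV cache block."""

    def __init__(self, block_id: int) -> None:
        self.block_id = block_id

# One {block_hash: {block_id: Block}} mapping per KV cache group; the real
# pool keys on BlockHashType, plain strings are used here for brevity.
CachedMap = list[dict[str, dict[int, Block]]]

def get_cached_block(cached: CachedMap, block_hash: str,
                     kv_cache_group_ids: list[int]) -> Optional[list[Block]]:
    """Return one cached block per requested group, or None on any miss."""
    hit: list[Block] = []
    for group_id in kv_cache_group_ids:
        candidates = cached[group_id].get(block_hash)
        if not candidates:
            # A miss in any group makes the whole lookup a miss: the prefix
            # is only reusable if every group has this block cached.
            return None
        # Duplicated cached blocks are possible; take the first one.
        hit.append(next(iter(candidates.values())))
    return hit

# Group 0 has the hash cached, group 1 does not -> overall miss.
cached: CachedMap = [{"h0": {3: Block(3)}}, {}]
assert get_cached_block(cached, "h0", [0]) is not None
assert get_cached_block(cached, "h0", [0, 1]) is None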
@@ -72,16 +71,16 @@ def __init__( def get_cached_block( self, block_hash: BlockHashType, kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]: - """Get a cached block by the block hash, or None if cache miss. + """Get the cached block by the block hash for each group in + `kv_cache_group_ids`, or None if cache miss for any group. If there are duplicated blocks, we return the first block in the cache. - TODO: update notes Args: block_hash: The hash value of the block. - kv_cache_group_id: The id of the KV cache group. + kv_cache_group_ids: The ids of the KV cache groups. Returns: - The cached block if it exists, or None. + The cached blocks if exists, or None. """ cached_blocks = [] for group_id in kv_cache_group_ids: @@ -205,8 +204,7 @@ def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: Note that we do not check block cache in this function. Args: - num_block_bundle: The number of KVCacheBlockBundle to allocate. - bundle_size: The number of blocks in each KVCacheBlockBundle. + num_blocks: The number of blocks to allocate. Returns: A list of new block. diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 0a0163b2f7a0..34f9f24433cd 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -100,8 +100,7 @@ def allocate_new_blocks(self, request_id: str, manager.allocate_new_blocks(request_id, num_tokens)) return new_blocks - def cache_blocks(self, request: Request, - block_hashes: dict[int, list[BlockHashType]], + def cache_blocks(self, request: Request, block_hashes: list[BlockHashType], num_computed_tokens: int) -> None: """ Cache the blocks for the request. @@ -113,8 +112,7 @@ def cache_blocks(self, request: Request, (including tokens that are already cached). """ for manager in self.single_type_managers: - manager.cache_blocks(request, block_hashes[manager.block_size], - num_computed_tokens) + manager.cache_blocks(request, block_hashes, num_computed_tokens) def free(self, request_id: str) -> None: """ @@ -172,8 +170,7 @@ def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]: @abstractmethod def find_longest_cache_hit( - self, request_id: str, - block_hashes_dict: dict[int, list[BlockHashType]], + self, block_hashes: list[BlockHashType], max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]: pass @@ -192,11 +189,10 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, "UnifiedKVCacheCoordinator assumes only one kv cache group") def find_longest_cache_hit( - self, request_id: str, - block_hashes_dict: dict[int, list[BlockHashType]], + self, block_hashes: list[BlockHashType], max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]: hit_blocks = self.single_type_managers[0].find_longest_cache_hit( - block_hashes_dict[self.block_size], max_cache_hit_length, [0]) + block_hashes, max_cache_hit_length, [0]) return hit_blocks, len(hit_blocks[0]) * self.block_size @@ -249,16 +245,14 @@ def initialize_group_ids(self) -> None: def find_longest_cache_hit( self, - request_id: str, - block_hashes_dict: dict[int, list[BlockHashType]], + block_hashes: list[BlockHashType], max_cache_hit_length: int, ) -> tuple[list[list[KVCacheBlock]], int]: """ Find the longest cache hit for the request. Args: - request_id: The request ID. - block_hashes_dict: The block hashes of the request. + block_hashes: The block hashes of the request. max_cache_hit_length: The maximum length of the cache hit. 
Returns: @@ -274,7 +268,7 @@ def find_longest_cache_hit( # First, find the longest cache hit for full attention. hit_blocks_full_attn = self.single_type_managers[ 0].find_longest_cache_hit( - block_hashes_dict[self.full_attention_block_size], + block_hashes, max_length=max_cache_hit_length, kv_cache_group_ids=self.full_attention_group_ids) hit_length = len( @@ -283,7 +277,7 @@ def find_longest_cache_hit( # Next, find the cache hit for the other attention WITHIN # the cache hit of full attention. hit_blocks_other_attn = self.single_type_managers[ - 1].find_longest_cache_hit(block_hashes_dict[self.other_block_size], + 1].find_longest_cache_hit(block_hashes, max_length=hit_length, kv_cache_group_ids=self.other_group_ids) hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 08b20493f2d1..7e4066f4a6af 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -2,7 +2,7 @@ from collections import defaultdict from dataclasses import dataclass -from typing import Callable, ClassVar, Optional +from typing import ClassVar, Optional from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger @@ -78,6 +78,12 @@ def __init__( self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None + assert len( + set(g.kv_cache_spec.block_size + for g in kv_cache_config.kv_cache_groups) + ) == 1, "Only one block size is supported for now" + self.block_size = kv_cache_config.kv_cache_groups[ + 0].kv_cache_spec.block_size self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, @@ -92,18 +98,11 @@ def __init__( self.block_pool = self.coordinator.block_pool self.kv_cache_config = kv_cache_config - self.all_block_sizes = set(g.kv_cache_spec.block_size - for g in kv_cache_config.kv_cache_groups) # Mapping from request ID to kv block hashes of all block sizes. # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. - empty_block_hash_fn: Callable[[], dict[int, list[BlockHashType]]] = ( - lambda: { - block_size: [] - for block_size in self.all_block_sizes - }) - self.req_to_block_hashes: defaultdict[str, dict[ - int, list[BlockHashType]]] = defaultdict(empty_block_hash_fn) + self.req_to_block_hashes: defaultdict[ + str, list[BlockHashType]] = defaultdict(list) @property def usage(self) -> float: @@ -147,13 +146,10 @@ def get_computed_blocks(self, # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. - block_hashes = self.req_to_block_hashes.get(request.request_id, None) - if block_hashes is None: - block_hashes = { - block_size: - hash_request_tokens(self.caching_hash_fn, block_size, request) - for block_size in self.all_block_sizes - } + block_hashes = self.req_to_block_hashes[request.request_id] + if not block_hashes: + block_hashes = hash_request_tokens(self.caching_hash_fn, + self.block_size, request) self.req_to_block_hashes[request.request_id] = block_hashes if self.log_stats: @@ -168,8 +164,7 @@ def get_computed_blocks(self, # could slightly improve performance in the future. 
max_cache_hit_length = request.num_tokens - 1 computed_blocks, num_new_computed_tokens = ( - self.coordinator.find_longest_cache_hit(request.request_id, - block_hashes, + self.coordinator.find_longest_cache_hit(block_hashes, max_cache_hit_length)) if self.log_stats: @@ -250,11 +245,11 @@ def allocate_slots( num_computed_tokens + num_new_tokens + num_lookahead_tokens, self.max_model_len) - num_blocks_to_allocate = (self.coordinator.get_num_blocks_to_allocate( + num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( request_id=request.request_id, num_tokens=num_tokens_need_slot, new_computed_blocks=new_computed_block_list, - )) + ) if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): # Cannot allocate new blocks diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 0b1b618c2d44..03d72ceb4a64 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -156,7 +156,6 @@ def block_hash(self, block_hash: BlockHashTypeWithGroupId): def reset_hash(self): """Reset the block hash when the block is evicted.""" self._block_hash = None - self.manager_id = -1 def __repr__(self) -> str: # Use block_id instead of KVCacheBlock object to avoid calling __repr__ @@ -294,12 +293,9 @@ def need_extra_keys(request: Request) -> bool: or (request.cache_salt is not None)) -def _gen_mm_extra_hash_keys( - request: Request, - start_token_idx: int, - end_token_idx: int, - start_mm_idx: int, -) -> tuple[list[Any], int]: +def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, + end_token_idx: int, + start_mm_idx: int) -> tuple[list[Any], int]: """Generate extra keys related to MultiModal request for block hash computation. For multi-modal inputs, the extra keys are (mm_hash, start_offset) that indicate a mm input contained in the @@ -402,9 +398,8 @@ def generate_block_hash_extra_keys( mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( request, start_token_idx, end_token_idx, start_mm_idx) lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) - cache_salt_keys: list[str] = ([request.cache_salt] if - (start_token_idx == 0 - and request.cache_salt) else []) + cache_salt_keys: list[str] = [request.cache_salt] if ( + start_token_idx == 0 and request.cache_salt) else [] extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys @@ -415,11 +410,10 @@ def generate_block_hash_extra_keys( def hash_block_tokens( - hash_function: Callable, - parent_block_hash: Optional[int], - curr_block_token_ids: Sequence[int], - extra_keys: Optional[tuple[Any, ...]] = None, -) -> BlockHashType: + hash_function: Callable, + parent_block_hash: Optional[int], + curr_block_token_ids: Sequence[int], + extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. 
We use LRU cache for this function to avoid recomputing @@ -444,9 +438,7 @@ def hash_block_tokens( return BlockHashType( hash_function( (parent_block_hash, curr_block_token_ids_tuple, extra_keys)), - curr_block_token_ids_tuple, - extra_keys, - ) + curr_block_token_ids_tuple, extra_keys) def hash_request_tokens(hash_function: Any, block_size: int, @@ -492,11 +484,9 @@ def hash_request_tokens(hash_function: Any, block_size: int, return ret -def estimate_max_model_len( - vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int, -) -> int: +def estimate_max_model_len(vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int) -> int: """ Estimates the maximum model length that can fit in the available memory using binary search. @@ -542,11 +532,9 @@ def fits_in_memory(model_len: int) -> bool: return result -def check_enough_kv_cache_memory( - vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int, -): +def check_enough_kv_cache_memory(vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int): """ Checks whether `available_memory` is enough for the KV cache to hold at least one request with the model's max_model_len. @@ -581,9 +569,9 @@ def check_enough_kv_cache_memory( raise ValueError( f"To serve at least one request with the models's max seq len " - f"({max_model_len}), ({needed_memory / GiB_bytes:.2f} GiB KV " + f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV " f"cache is needed, which is larger than the available KV cache " - f"memory ({available_memory / GiB_bytes:.2f} GiB)." + f"memory ({available_memory/GiB_bytes:.2f} GiB)." f"{estimated_msg} " f" Try increasing `gpu_memory_utilization` or decreasing " f"`max_model_len` when initializing the engine.") @@ -633,11 +621,9 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: return len(layer_keys) == 1 -def _get_kv_cache_config_uniform_type( - vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int, -) -> KVCacheConfig: +def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: """ Generates the KV cache configuration for a model with one type of KV cache. Divide the available memory equally among all layers. @@ -659,14 +645,11 @@ def _get_kv_cache_config_uniform_type( num_blocks = max(num_blocks, 0) if vllm_config.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = ( - vllm_config.cache_config.num_gpu_blocks_override) + num_gpu_blocks_override = \ + vllm_config.cache_config.num_gpu_blocks_override logger.info( - "Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d", - num_blocks, - num_gpu_blocks_override, - ) - num_blocks = num_gpu_blocks_override + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) num_tokens = num_blocks * vllm_config.cache_config.block_size num_tokens_str = f"{num_tokens:,}" @@ -698,7 +681,7 @@ def _get_kv_cache_config_uniform_type( def is_kv_cache_page_size_uniform( - kv_cache_spec: dict[str, KVCacheSpec], ) -> bool: + kv_cache_spec: dict[str, KVCacheSpec]) -> bool: """ Whether all layers in the given KVCacheSpec have the same page size. 
Args: @@ -713,10 +696,8 @@ def is_kv_cache_page_size_uniform( def _get_kv_cache_config_uniform_page_size( - vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int, -) -> KVCacheConfig: + vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: """ Generates the KV cache configuration for a model with one page size. Args: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 3b9a9b6c5397..7ca90d31d92f 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -219,8 +219,9 @@ def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: """ Remove the blocks that are no longer needed from `blocks`. The removed - blocks should be replaced by null_block. Need to be customized for each - attention type. + blocks should be replaced by null_block. Return the removed blocks in + eviction order, where the first returned block should be evicted first. + Need to be customized for each attention type. Args: request_id: The request ID. @@ -234,7 +235,6 @@ class FullAttentionManager(SingleTypeKVCacheManager): def find_longest_cache_hit( self, block_hashes: list[BlockHashType], max_length: int, kv_cache_group_ids: list[int]) -> list[list[KVCacheBlock]]: - # NOTE: different from other list[list[KVCacheBlock]] computed_blocks: list[list[KVCacheBlock]] = [ [] for _ in range(len(kv_cache_group_ids)) ] @@ -288,7 +288,7 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, # contiguous blocks needed for prefix cache hit by one and dropping # the last matched block. self.sliding_window_contiguous_blocks += 1 - self.null_block = block_pool.null_block + self._null_block = block_pool.null_block def find_longest_cache_hit( self, block_hashes: list[BlockHashType], max_length: int, @@ -299,7 +299,7 @@ def find_longest_cache_hit( # sliding_window_contiguous_blocks), # which is good for low cache hit rate scenarios. max_num_blocks = max_length // self.block_size - computed_blocks = [[self.null_block] * max_num_blocks + computed_blocks = [[self._null_block] * max_num_blocks for _ in range(len(kv_cache_group_ids))] num_contiguous_blocks = 0 match_found = False @@ -340,13 +340,13 @@ def remove_skipped_blocks(self, request_id: str, blocks = self.req_to_blocks[request_id] removed_blocks: list[KVCacheBlock] = [] for i in range(last_useful_block - 1, -1, -1): - if blocks[i] == self.null_block: + if blocks[i] == self._null_block: # If the block is already a null block, the blocks before it # should also have been set to null blocks by the previous calls # to this function. 
break removed_blocks.append(blocks[i]) - blocks[i] = self.null_block + blocks[i] = self._null_block self.block_pool.free_blocks(removed_blocks) def get_num_common_prefix_blocks(self, request_id: str, From 375168888e5a733d79291964c37dcc61ed23c5c2 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 30 May 2025 07:33:23 -0700 Subject: [PATCH 18/44] fix padding calculation Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 39 +++++++++++++--------------------- 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 03d72ceb4a64..3723c4425311 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -53,9 +53,8 @@ def get_hash_value(self) -> int: # variable if set such that processes can share the seed if needed. # This aligns with the behavior of Python's hash() function, which also uses # a random seed if PYTHONHASHSEED is not set. -NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big") - if os.getenv("PYTHONHASHSEED") is None else sha256( - os.getenv("PYTHONHASHSEED"))) +NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv( + "PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED")) class PrefixCachingMetrics: @@ -123,7 +122,6 @@ def hit_rate(self) -> float: @dataclass class KVCacheBlock: """KV-cache block metadata.""" - # Block ID, ranging from 0 to num_gpu_blocks - 1. block_id: int # Reference count. @@ -160,10 +158,10 @@ def reset_hash(self): def __repr__(self) -> str: # Use block_id instead of KVCacheBlock object to avoid calling __repr__ # on KVCacheBlock object recursively. - prev_block_id = (self.prev_free_block.block_id - if self.prev_free_block else None) - next_block_id = (self.next_free_block.block_id - if self.next_free_block else None) + prev_block_id = self.prev_free_block.block_id \ + if self.prev_free_block else None + next_block_id = self.next_free_block.block_id \ + if self.next_free_block else None return (f"KVCacheBlock(block_id={self.block_id}, " f"ref_cnt={self.ref_cnt}, " f"_block_hash={self._block_hash}, " @@ -289,8 +287,9 @@ def need_extra_keys(request: Request) -> bool: # Multimodal requests need to include the MM hash. # LoRA requests need to include the LoRA ID. # Request with provided cache salt need to include the salt. - return (bool(request.mm_positions) or (request.lora_request is not None) - or (request.cache_salt is not None)) + return bool(request.mm_positions) or (request.lora_request + is not None) or (request.cache_salt + is not None) def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, @@ -377,11 +376,8 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[int]: def generate_block_hash_extra_keys( - request: Request, - start_token_idx: int, - end_token_idx: int, - start_mm_idx: int, -) -> tuple[Optional[tuple[Any, ...]], int]: + request: Request, start_token_idx: int, end_token_idx: int, + start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]: """Generate extra keys for the block hash. The extra keys can come from the multi-modal inputs and request specific metadata (e.g., LoRA ID). @@ -434,7 +430,6 @@ def hash_block_tokens( parent_block_hash = NONE_HASH curr_block_token_ids_tuple = tuple(curr_block_token_ids) - # NOTE: not add group_id. 
return BlockHashType( hash_function( (parent_block_hash, curr_block_token_ids_tuple, extra_keys)), @@ -473,12 +468,8 @@ def hash_request_tokens(hash_function: Any, block_size: int, req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( request, start, end, curr_mm_idx) - block_hash = hash_block_tokens( - hash_function, - parent_block_hash_value, - block_token_ids, - req_extra_keys, - ) + block_hash = hash_block_tokens(hash_function, parent_block_hash_value, + block_token_ids, req_extra_keys) ret.append(block_hash) parent_block_hash_value = block_hash.hash_value return ret @@ -722,8 +713,8 @@ def _get_kv_cache_config_uniform_page_size( group_size = min([len(layers) for layers in same_type_layers.values()]) grouped_layers = [] for layers in same_type_layers.values(): - num_padding_layers = len(layers) % group_size - if num_padding_layers > 0: + num_padding_layers = group_size - len(layers) % group_size + if num_padding_layers != group_size: logger.warning( "Add %d padding layers, may waste at most %.2f%% KV cache memory", # noqa num_padding_layers, From 159f51c56721bc8d0b94e7b5afcfa2be4e251bd6 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 30 May 2025 07:45:35 -0700 Subject: [PATCH 19/44] clean up Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 37 ++++++++++---------- vllm/v1/core/single_type_kv_cache_manager.py | 5 ++- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 34f9f24433cd..04cd95292fa3 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -124,11 +124,8 @@ def free(self, request_id: str) -> None: for manager in self.single_type_managers: manager.free(request_id) - def get_num_common_prefix_blocks( - self, - request_id: str, - num_running_requests: int, - ) -> list[int]: + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> list[int]: """ Get the number of common prefix blocks for a request. @@ -176,6 +173,10 @@ def find_longest_cache_hit( class UnifiedKVCacheCoordinator(KVCacheCoordinator): + """ + KV cache coordinator for unified models with only one KV cache type, and + thus one kv cache group. + """ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, @@ -197,6 +198,13 @@ def find_longest_cache_hit( class HybridKVCacheCoordinator(KVCacheCoordinator): + """ + KV cache coordinator for hybrid models with multiple KV cache types, and + thus multiple kv cache groups. + To simplify `find_longest_cache_hit`, it only supports the combination of + two types of KV cache groups, and one of them must be full attention. + May extend to more general cases in the future. + """ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, @@ -204,16 +212,13 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, super().__init__(kv_cache_config, max_model_len, use_eagle, enable_caching, caching_hash_fn, enable_kv_cache_events) - self.initialize_group_ids() + self.verify_and_split_kv_cache_groups() - def initialize_group_ids(self) -> None: + def verify_and_split_kv_cache_groups(self) -> None: """ - For simplicity, find_longest_cache_hit makes some assumptions on the - model architecture instead of provides a general solution. This function - checks if the assumptions hold. - NOTE(Chen): Please open an issue to discuss if you need other cases. 
- - TODO: add more notes + Verifies that the model has exactly two types of KV cache groups, and + one of them is full attention. Then, split the kv cache groups into full + attention groups and other groups. """ groups_by_type_id: dict[str, list[int]] = defaultdict(list) full_attention_type_ids: set[str] = set() @@ -260,11 +265,6 @@ def find_longest_cache_hit( - A list of the cache hit blocks for each single type manager. - The number of tokens of the longest cache hit. """ - # For simplicity, we assume the first manager is for full - # attention layers, and the block_size of full attention layers - # is divisible by other attention layers. This has been verified - # in verify_support_find_longest_cache_hit(). - # First, find the longest cache hit for full attention. hit_blocks_full_attn = self.single_type_managers[ 0].find_longest_cache_hit( @@ -288,6 +288,7 @@ def find_longest_cache_hit( for i in range(len(hit_blocks_full_attn)): del hit_blocks_full_attn[i][hit_length // self.full_attention_block_size:] + # Merge the hit blocks of full attention and other attention. hit_blocks = hit_blocks_other_attn for group_id, blocks in enumerate(hit_blocks_full_attn): diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 7ca90d31d92f..58148e1c8e3a 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -196,7 +196,9 @@ def find_longest_cache_hit( kv_cache_group_ids: list[int]) -> list[list[KVCacheBlock]]: """ Get the longest cache hit prefix of the blocks that is not longer than - `max_length`. If no cache hit is found, return an empty list. + `max_length`. The prefix should be a common prefix hit for all the + kv cache groups in `kv_cache_group_ids`. If no cache hit is found, + return an empty list. If eagle is enabled, drop the last matched block to force recompute the last block to get the required hidden states for eagle drafting head. Need to be customized for each attention type. @@ -204,6 +206,7 @@ def find_longest_cache_hit( Args: block_hashes: The block hashes of the request. max_length: The maximum length of the cache hit prefix. + kv_cache_group_ids: The ids of the kv cache groups. Returns: A list of cached blocks with skipped blocks replaced by null block. From 84f27b9f5c9466fbd20ca0a94b70091ed97b5025 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 1 Jun 2025 09:19:46 -0700 Subject: [PATCH 20/44] update code inside vllm/core Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 67 ++++++++++----- vllm/v1/core/kv_cache_manager.py | 25 +++--- vllm/v1/core/sched/scheduler.py | 7 +- vllm/v1/core/single_type_kv_cache_manager.py | 90 +++++++++++++------- 4 files changed, 122 insertions(+), 67 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 04cd95292fa3..90cd9cbb1ff2 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -13,7 +13,7 @@ class KVCacheCoordinator: """ - Coordinator the KV cache of different KV cache groups. + Coordinate the KV cache of different KV cache groups. 
""" def __init__( @@ -31,6 +31,10 @@ def __init__( self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events) self.single_type_managers: list[SingleTypeKVCacheManager] = [] + + # Needs special handling for find_longest_cache_hit if eagle is enabled + self.use_eagle = use_eagle + for i in range(len(self.kv_cache_config.kv_cache_groups)): kv_cache_spec = self.kv_cache_config.kv_cache_groups[ i].kv_cache_spec @@ -38,7 +42,6 @@ def __init__( get_manager_for_kv_cache_spec( kv_cache_spec=kv_cache_spec, block_pool=self.block_pool, - use_eagle=use_eagle, kv_cache_group_id=i, caching_hash_fn=caching_hash_fn, )) @@ -172,10 +175,11 @@ def find_longest_cache_hit( pass -class UnifiedKVCacheCoordinator(KVCacheCoordinator): +class SingleGroupKVCacheCoordinator(KVCacheCoordinator): """ - KV cache coordinator for unified models with only one KV cache type, and - thus one kv cache group. + KV cache coordinator for models with only one KV cache group. This is the + case for models with only one KV cache type, e.g., all attention layers use + full attention or all attention layers use sliding window attention. """ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, @@ -184,8 +188,9 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, super().__init__(kv_cache_config, max_model_len, use_eagle, enable_caching, caching_hash_fn, enable_kv_cache_events) - self.block_size = self.kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.block_size + self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[ + 0].kv_cache_spec + self.block_size = self.kv_cache_spec.block_size assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "UnifiedKVCacheCoordinator assumes only one kv cache group") @@ -193,7 +198,13 @@ def find_longest_cache_hit( self, block_hashes: list[BlockHashType], max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]: hit_blocks = self.single_type_managers[0].find_longest_cache_hit( - block_hashes, max_cache_hit_length, [0]) + block_hashes=block_hashes, + max_length=max_cache_hit_length, + kv_cache_group_ids=[0], + block_pool=self.block_pool, + kv_cache_spec=self.kv_cache_spec, + use_eagle=self.use_eagle, + ) return hit_blocks, len(hit_blocks[0]) * self.block_size @@ -239,10 +250,13 @@ def verify_and_split_kv_cache_groups(self) -> None: self.other_group_ids = groups_by_type_id[next( iter(groups_by_type_id.keys() - full_attention_type_ids))] - self.full_attention_block_size = self.kv_cache_config.kv_cache_groups[ - self.full_attention_group_ids[0]].kv_cache_spec.block_size - self.other_block_size = self.kv_cache_config.kv_cache_groups[ - self.other_group_ids[0]].kv_cache_spec.block_size + self.full_attention_spec = self.kv_cache_config.kv_cache_groups[ + self.full_attention_group_ids[0]].kv_cache_spec + self.other_spec = self.kv_cache_config.kv_cache_groups[ + self.other_group_ids[0]].kv_cache_spec + + self.full_attention_block_size = self.full_attention_spec.block_size + self.other_block_size = self.other_spec.block_size if self.other_block_size % self.full_attention_block_size != 0: raise NotImplementedError( "KVCacheCoordinator assumes the block_size of the full " @@ -267,19 +281,28 @@ def find_longest_cache_hit( """ # First, find the longest cache hit for full attention. 
hit_blocks_full_attn = self.single_type_managers[ - 0].find_longest_cache_hit( - block_hashes, + self.full_attention_group_ids[0]].find_longest_cache_hit( + block_hashes=block_hashes, max_length=max_cache_hit_length, - kv_cache_group_ids=self.full_attention_group_ids) + kv_cache_group_ids=self.full_attention_group_ids, + block_pool=self.block_pool, + kv_cache_spec=self.full_attention_spec, + use_eagle=self.use_eagle, + ) hit_length = len( hit_blocks_full_attn[0]) * self.full_attention_block_size # Next, find the cache hit for the other attention WITHIN # the cache hit of full attention. hit_blocks_other_attn = self.single_type_managers[ - 1].find_longest_cache_hit(block_hashes, - max_length=hit_length, - kv_cache_group_ids=self.other_group_ids) + self.other_group_ids[0]].find_longest_cache_hit( + block_hashes=block_hashes, + max_length=hit_length, + kv_cache_group_ids=self.other_group_ids, + block_pool=self.block_pool, + kv_cache_spec=self.other_spec, + use_eagle=self.use_eagle, + ) hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size assert hit_length % self.full_attention_block_size == 0 @@ -304,10 +327,10 @@ def get_kv_cache_coordinator( enable_caching: bool, caching_hash_fn: Callable, enable_kv_cache_events: bool) -> KVCacheCoordinator: if len(kv_cache_config.kv_cache_groups) == 1: - return UnifiedKVCacheCoordinator(kv_cache_config, max_model_len, - use_eagle, enable_caching, - caching_hash_fn, - enable_kv_cache_events) + return SingleGroupKVCacheCoordinator(kv_cache_config, max_model_len, + use_eagle, enable_caching, + caching_hash_fn, + enable_kv_cache_events) else: return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, enable_caching, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 7e4066f4a6af..09c75e15eb2a 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -2,7 +2,7 @@ from collections import defaultdict from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger @@ -23,7 +23,6 @@ class KVCacheBlocks: """ blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens. """ - num_kv_cache_groups: ClassVar[int] def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": """Adds two KVCacheBlocks instances.""" @@ -44,14 +43,9 @@ def get_block_ids(self) -> list[list[int]]: block_ids.append([blk.block_id for blk in group]) return block_ids - @classmethod - def create_empty(cls) -> "KVCacheBlocks": - """Creates a new KVCacheBlocks instance with no blocks.""" - return cls([[] for _ in range(cls.num_kv_cache_groups)]) - def get_unhashed_block_ids(self) -> list[int]: """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" - assert self.num_kv_cache_groups == 1, "Only one group is supported" + assert len(self.blocks) == 1, "Only one group is supported" return [ block.block_id for block in self.blocks[0] if block.block_hash is None @@ -93,8 +87,7 @@ def __init__( caching_hash_fn=self.caching_hash_fn, enable_kv_cache_events=enable_kv_cache_events, ) - KVCacheBlocks.num_kv_cache_groups = len( - kv_cache_config.kv_cache_groups) + self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool self.kv_cache_config = kv_cache_config @@ -142,7 +135,7 @@ def get_computed_blocks(self, # When the request requires prompt logprobs, we skip prefix caching. 
if (not self.enable_caching or request.sampling_params.prompt_logprobs is not None): - return KVCacheBlocks.create_empty(), 0 + return self.create_empty_block_list(), 0 # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. @@ -374,3 +367,13 @@ def get_block_ids(self, request_id: str) -> list[list[int]]: """Get the block ids of a request.""" return KVCacheBlocks( self.coordinator.get_blocks(request_id)).get_block_ids() + + def cache_blocks(self, request: Request, block_hashes: list[BlockHashType], + num_computed_tokens: int) -> None: + """Cache the blocks for the request.""" + self.coordinator.cache_blocks(request, block_hashes, + num_computed_tokens) + + def create_empty_block_list(self) -> KVCacheBlocks: + """Creates a new KVCacheBlocks instance with no blocks.""" + return KVCacheBlocks([[] for _ in range(self.num_kv_cache_groups)]) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index e002f5bcb679..facdc2f71b90 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -17,7 +17,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager +from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) @@ -352,7 +352,8 @@ def schedule(self) -> SchedulerOutput: request) else: # P/D: skip checking prefix cache if loaded from remote kvs. - new_computed_blocks = KVCacheBlocks.create_empty() + new_computed_blocks = ( + self.kv_cache_manager.create_empty_block_list()) num_native_computed_tokens = 0 # Get externally-cached tokens if using a KVConnector. @@ -966,7 +967,7 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: num_computed_tokens = len(block_ids) * self.block_size if num_computed_tokens == request.num_tokens: num_computed_tokens -= 1 - self.kv_cache_manager.coordinator.cache_blocks( + self.kv_cache_manager.cache_blocks( request, self.kv_cache_manager.req_to_block_hashes[request.request_id], num_computed_tokens, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 58148e1c8e3a..eb94fdd211a4 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -21,7 +21,6 @@ def __init__( self, kv_cache_spec: KVCacheSpec, block_pool: BlockPool, - use_eagle: bool, kv_cache_group_id: int, caching_hash_fn: Callable, ) -> None: @@ -30,7 +29,6 @@ def __init__( Args: kv_cache_spec: The kv_cache_spec for this manager. block_pool: The block pool. - use_eagle: Whether to use eagle. kv_cache_group_id: The id of the kv cache group of this manager. caching_hash_fn: The caching hash function. """ @@ -39,9 +37,6 @@ def __init__( self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool - # Needs special handling for find_longest_cache_hit if eagle is enabled - self.use_eagle = use_eagle - # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. 
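The next hunk turns `find_longest_cache_hit` into a classmethod: the coordinator can ask a manager class for a cache hit that is common to several KV cache group ids, passing the block pool, spec, and eagle flag explicitly instead of keeping them as per-group instance state. A toy, runnable sketch of that calling convention under simplified assumptions (the `Toy*` classes, integer block ids, and string hashes are hypothetical, not vLLM's real API):

from typing import Optional

class ToyBlockPool:

    def __init__(self, cached: dict[str, list[int]]) -> None:
        # {block_hash: one cached block id per KV cache group}
        self.cached = cached

    def get_cached_block(self, block_hash: str,
                         group_ids: list[int]) -> Optional[list[int]]:
        blocks = self.cached.get(block_hash)
        if blocks is None:
            return None
        return [blocks[g] for g in group_ids]

class ToyFullAttentionManager:

    @classmethod
    def find_longest_cache_hit(cls, block_hashes: list[str], max_length: int,
                               group_ids: list[int], block_pool: ToyBlockPool,
                               block_size: int,
                               use_eagle: bool) -> list[list[int]]:
        hit: list[list[int]] = [[] for _ in group_ids]
        # Walk the hash chain from the start of the prompt.
        for i in range(min(len(block_hashes), max_length // block_size)):
            cached = block_pool.get_cached_block(block_hashes[i], group_ids)
            if cached is None:
                break  # hashes form a chain; the first miss ends the prefix
            for j, block_id in enumerate(cached):
                hit[j].append(block_id)
        if use_eagle and hit[0]:
            # Drop the last matched block so its hidden states are recomputed
            # for the eagle drafting head.
            for group_hit in hit:
                group_hit.pop()
        return hit

pool = ToyBlockPool({"h0": [10, 20], "h1": [11, 21]})
print(ToyFullAttentionManager.find_longest_cache_hit(
    ["h0", "h1", "h2"], max_length=48, group_ids=[0, 1],
    block_pool=pool, block_size=16, use_eagle=False))
# -> [[10, 11], [20, 21]]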
@@ -190,10 +185,17 @@ def get_num_common_prefix_blocks(self, request_id: str, raise NotImplementedError + @classmethod @abstractmethod def find_longest_cache_hit( - self, block_hashes: list[BlockHashType], max_length: int, - kv_cache_group_ids: list[int]) -> list[list[KVCacheBlock]]: + cls, + block_hashes: list[BlockHashType], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + ) -> list[list[KVCacheBlock]]: """ Get the longest cache hit prefix of the blocks that is not longer than `max_length`. The prefix should be a common prefix hit for all the @@ -207,12 +209,19 @@ def find_longest_cache_hit( block_hashes: The block hashes of the request. max_length: The maximum length of the cache hit prefix. kv_cache_group_ids: The ids of the kv cache groups. + block_pool: The block pool. + kv_cache_spec: The kv cache spec. + use_eagle: Whether to use eagle. Returns: - A list of cached blocks with skipped blocks replaced by null block. + A list of cached blocks with skipped blocks replaced by null block + for each kv cache group in `kv_cache_group_ids`. + Return a list of length `len(kv_cache_group_ids)`, where the i-th + element is a list of cached blocks for the i-th kv cache group + in `kv_cache_group_ids`. For example, sliding window manager should return a list like - [NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and - sliding window 8. + [[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]] for block size 4 + and sliding window 8 and len(kv_cache_group_ids) = 1. """ raise NotImplementedError @@ -235,25 +244,34 @@ def remove_skipped_blocks(self, request_id: str, class FullAttentionManager(SingleTypeKVCacheManager): + @classmethod def find_longest_cache_hit( - self, block_hashes: list[BlockHashType], max_length: int, - kv_cache_group_ids: list[int]) -> list[list[KVCacheBlock]]: + cls, + block_hashes: list[BlockHashType], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + ) -> list[list[KVCacheBlock]]: + assert isinstance(kv_cache_spec, FullAttentionSpec), ( + "FullAttentionManager can only be used for full attention groups") computed_blocks: list[list[KVCacheBlock]] = [ [] for _ in range(len(kv_cache_group_ids)) ] - max_num_blocks = max_length // self.block_size + max_num_blocks = max_length // kv_cache_spec.block_size for i in range(max_num_blocks): block_hash = block_hashes[i] # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. 
- if cached_block := self.block_pool.get_cached_block( + if cached_block := block_pool.get_cached_block( block_hash, kv_cache_group_ids): for j in range(len(kv_cache_group_ids)): computed_blocks[j].append(cached_block[j]) else: break - if self.use_eagle and len(computed_blocks) > 0: + if use_eagle and len(computed_blocks) > 0: for j in range(len(kv_cache_group_ids)): computed_blocks[j].pop() return computed_blocks @@ -278,43 +296,53 @@ def get_num_common_prefix_blocks(self, request_id: str, class SlidingWindowManager(SingleTypeKVCacheManager): def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, - use_eagle: bool, **kwargs) -> None: - super().__init__(kv_cache_spec, block_pool, use_eagle, **kwargs) + **kwargs) -> None: + super().__init__(kv_cache_spec, block_pool, **kwargs) self.sliding_window = kv_cache_spec.sliding_window + self._null_block = block_pool.null_block + + @classmethod + def find_longest_cache_hit( + cls, + block_hashes: list[BlockHashType], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + ) -> list[list[KVCacheBlock]]: + assert isinstance(kv_cache_spec, SlidingWindowSpec), ( + "SlidingWindowManager can only be used for sliding window groups") + # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window - self.sliding_window_contiguous_blocks = cdiv( - (kv_cache_spec.sliding_window - 1), self.block_size) - if self.use_eagle: + sliding_window_contiguous_blocks = cdiv( + kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size) + if use_eagle: # Need to drop the last matched block if eagle is enabled. For # sliding window layer, we achieve this by increasing the number of # contiguous blocks needed for prefix cache hit by one and dropping # the last matched block. - self.sliding_window_contiguous_blocks += 1 - self._null_block = block_pool.null_block + sliding_window_contiguous_blocks += 1 - def find_longest_cache_hit( - self, block_hashes: list[BlockHashType], max_length: int, - kv_cache_group_ids: list[int]) -> list[list[KVCacheBlock]]: # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to # optimize the time complexity from O(len(block_hashes)) to # O(len(block_hashes) / sliding_window_contiguous_blocks + # sliding_window_contiguous_blocks), # which is good for low cache hit rate scenarios. - max_num_blocks = max_length // self.block_size - computed_blocks = [[self._null_block] * max_num_blocks + max_num_blocks = max_length // kv_cache_spec.block_size + computed_blocks = [[block_pool.null_block] * max_num_blocks for _ in range(len(kv_cache_group_ids))] num_contiguous_blocks = 0 match_found = False # Search from right to left and early stop when a match is found. for i in range(max_num_blocks - 1, -1, -1): - if cached_block := self.block_pool.get_cached_block( + if cached_block := block_pool.get_cached_block( block_hashes[i], kv_cache_group_ids): for j in range(len(kv_cache_group_ids)): computed_blocks[j][i] = cached_block[j] num_contiguous_blocks += 1 - if (num_contiguous_blocks - >= self.sliding_window_contiguous_blocks): + if (num_contiguous_blocks >= sliding_window_contiguous_blocks): # Trim the trailing blocks. # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] # when sliding_window_contiguous_blocks=2. @@ -329,7 +357,7 @@ def find_longest_cache_hit( # `num_contiguous_blocks < sliding_window_contiguous_blocks`. 
for j in range(len(kv_cache_group_ids)): del computed_blocks[j][num_contiguous_blocks:] - if self.use_eagle and len(computed_blocks) > 0: + if use_eagle and len(computed_blocks) > 0: for j in range(len(kv_cache_group_ids)): computed_blocks[j].pop() return computed_blocks From 9c0480242909b625d7c20ab018616a2d6023f85a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 1 Jun 2025 10:25:30 -0700 Subject: [PATCH 21/44] update kv cache init Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 118 ++++++++++++++++++++--------- vllm/v1/kv_cache_interface.py | 28 ++----- vllm/v1/worker/gpu_model_runner.py | 27 ++++--- vllm/v1/worker/tpu_model_runner.py | 49 ++++++------ 4 files changed, 125 insertions(+), 97 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 3723c4425311..4725755eb6e4 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -11,9 +11,8 @@ from vllm.logger import init_logger from vllm.utils import GiB_bytes, sha256 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheNewTensor, - KVCacheReuseTensor, KVCacheSpec, - SlidingWindowSpec) + KVCacheGroupSpec, KVCacheSpec, + KVCacheTensor, SlidingWindowSpec) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -612,6 +611,38 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: return len(layer_keys) == 1 +def get_num_blocks(vllm_config: VllmConfig, num_layers: int, + available_memory: int, page_size: int) -> int: + """ + Get the number of kv cache blocks. + + Args: + vllm_config: The global VllmConfig + num_layers: The number of layers + available_memory: Memory available for KV cache in bytes. + page_size: The page size of the KV cache. + + """ + num_blocks = int(available_memory // page_size // num_layers) + num_blocks = max(num_blocks, 0) + if vllm_config.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = \ + vllm_config.cache_config.num_gpu_blocks_override + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) + return num_blocks + + +def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int: + """ + Get the page size of the KV cache. 
+ """ + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + assert len(page_sizes) == 1 + return page_sizes.pop() + + def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], available_memory: int) -> KVCacheConfig: @@ -628,12 +659,9 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, The generated KVCacheConfig """ - page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} - assert len(page_sizes) == 1 - page_size = page_sizes.pop() - - num_blocks = int(available_memory // page_size // len(kv_cache_spec)) - num_blocks = max(num_blocks, 0) + page_size = get_uniform_page_size(kv_cache_spec) + num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec), + available_memory, page_size) if vllm_config.cache_config.num_gpu_blocks_override is not None: num_gpu_blocks_override = \ @@ -647,7 +675,6 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, logger.info("GPU KV cache size: %s tokens", num_tokens_str) max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" max_concurrency = num_tokens / vllm_config.model_config.max_model_len - # TODO: fix for hybrid allocator logger.info( "Maximum concurrency for %s tokens per request: %.2fx", max_model_len_str, @@ -659,12 +686,15 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, # for all layers. grouped_layer_names = [list(kv_cache_spec.keys())] + # Each layer uses a separate Tensor to store its KV cache. + kv_cache_tensors = [ + KVCacheTensor(size=per_layer_size, shared_by=[layer_name]) + for layer_name in kv_cache_spec + ] + kv_cache_config = KVCacheConfig( num_blocks=num_blocks, - tensors={ - layer_name: KVCacheNewTensor(size=per_layer_size) - for layer_name in kv_cache_spec - }, + kv_cache_tensors=kv_cache_tensors, kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec, grouped_layer_names), ) @@ -722,32 +752,48 @@ def _get_kv_cache_config_uniform_page_size( ) for i in range(0, len(layers), group_size): grouped_layers.append(layers[i:i + group_size]) + kv_cache_groups = create_kv_cache_group_specs(kv_cache_spec, + grouped_layers) - # Divide the available memory equally among all layers in the first group. - # The memory layout in the example will be: - # full.0: Tensor with size=available_memory//2 - # full.1: Tensor with size=available_memory//2 - kv_cache_spec_first_group = { - layer_name: kv_cache_spec[layer_name] - for layer_name in grouped_layers[0] - } - kv_cache_config = _get_kv_cache_config_uniform_type( - vllm_config, kv_cache_spec_first_group, available_memory) - - # Reuse the KV cache tensors of the first group for the other groups. + # Determine how model runners should initialize the KV cache tensors. + # We will have group_size memory pools, each is shared by one layer from + # each group. As layers of different groups have different block table, + # they will use different parts of the shared Tensor. # The memory layout in the example will be: # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2 # full.1, sw.1: share another Tensor with size=available_memory//2 - # Layers of different groups have different block table, so they will - # use different parts of the shared Tensor. 
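The sharing layout spelled out in the comments above can be reproduced in a few lines of plain Python. This is a standalone sketch with a stand-in dataclass and a placeholder pool size, not the vLLM implementation; it only shows how one layer from each group ends up in each memory pool.

# Standalone sketch; the pool size is an arbitrary placeholder.
from dataclasses import dataclass


@dataclass
class SharedKVCacheTensor:
    size: int             # bytes backing this memory pool
    shared_by: list[str]  # layer names that store their KV cache in this pool


# 2 full-attention + 3 sliding-window layers, so group_size = 2 and the
# groups are (full.0, full.1), (sw.0, sw.1), (sw.2, padding).
grouped_layers = [["full.0", "full.1"], ["sw.0", "sw.1"], ["sw.2"]]
per_pool_size = 1 << 30  # placeholder for page_size * num_blocks

pools = []
for i in range(2):  # one memory pool per layer slot in a group
    shared_by = [group[i] for group in grouped_layers if i < len(group)]
    pools.append(SharedKVCacheTensor(size=per_pool_size, shared_by=shared_by))

assert pools[0].shared_by == ["full.0", "sw.0", "sw.2"]
assert pools[1].shared_by == ["full.1", "sw.1"]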
- for layers in grouped_layers[1:]: - for layer_name, layer_name_first_group in zip( - layers, grouped_layers[0][:len(layers)]): - kv_cache_config.tensors[layer_name] = KVCacheReuseTensor( - reused_layer_name=layer_name_first_group) - - kv_cache_config.kv_cache_groups = create_kv_cache_group_specs( - kv_cache_spec, grouped_layers) + page_size = get_uniform_page_size(kv_cache_spec) + num_blocks = get_num_blocks(vllm_config, group_size, available_memory, + page_size) + per_memory_pool_size = page_size * num_blocks + kv_cache_tensors = [] + for i in range(group_size): + shared_by = [] + for j in range(len(kv_cache_groups)): + if i < len(grouped_layers[j]): + shared_by.append(grouped_layers[j][i]) + kv_cache_tensors.append( + KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by)) + + # Print the KV cache size and maximum concurrency. + # TODO in this PR: Now just copy from the uniform type implementation. + # Should reimplement this for hybrid model + num_tokens = num_blocks * vllm_config.cache_config.block_size + num_tokens_str = f"{num_tokens:,}" + logger.info("GPU KV cache size: %s tokens", num_tokens_str) + max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" + max_concurrency = num_tokens / vllm_config.model_config.max_model_len + logger.info( + "Maximum concurrency for %s tokens per request: %.2fx", + max_model_len_str, + max_concurrency, + ) + + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=kv_cache_tensors, + kv_cache_groups=kv_cache_groups, + ) return kv_cache_config diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index f058732706a2..c14d2caa63c8 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -154,28 +154,12 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: @dataclass -class KVCacheTensorBase: +class KVCacheTensor: """ - A dataclass for specifying how the workers should initialize the KV cache - for a layer. + A class for specifying the KV cache tensor. """ - pass - - -@dataclass -class KVCacheNewTensor(KVCacheTensorBase): - """ - Initialize the KV cache with a tensor of `size` bytes. - """ - size: int # The size of KV cache Tensor in bytes - - -@dataclass -class KVCacheReuseTensor(KVCacheTensorBase): - """ - Reuse the KV cache tensor of `layer_name` for the current layer. - """ - reused_layer_name: str + size: int # size of the KV cache tensor in bytes + shared_by: list[str] # layer names that share the same KV cache tensor @dataclass @@ -197,8 +181,8 @@ class KVCacheConfig: """ """The number of KV cache blocks""" num_blocks: int - """layer_name -> how to initialize KV cache for that layer""" - tensors: dict[str, KVCacheTensorBase] + """How should model runner initialize the KV cache tensors for each layer""" + kv_cache_tensors: list[KVCacheTensor] """ The kv cache groups of the model. 
The layers in the models are repeated with some patterns, e.g., a model diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a631a65b1876..674cacf73173 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -39,8 +39,7 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, - KVCacheConfig, KVCacheNewTensor, - KVCacheReuseTensor, KVCacheSpec, + KVCacheConfig, KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) @@ -1881,18 +1880,18 @@ def _initialize_kv_cache_buffer( corresponding memory buffer for KV cache. """ kv_cache_raw_tensors: dict[str, torch.Tensor] = {} - for layer_name, tensor_config in kv_cache_config.tensors.items(): - if isinstance(tensor_config, KVCacheNewTensor): - # A new tensor with `tensor_config.size` bytes - kv_cache_raw_tensors[layer_name] = torch.zeros( - tensor_config.size, dtype=torch.int8, device=self.device) - for layer_name, tensor_config in kv_cache_config.tensors.items(): - if isinstance(tensor_config, KVCacheReuseTensor): - # Reuse a tensor from `kv_cache_raw_tensors` - kv_cache_raw_tensors[layer_name] = kv_cache_raw_tensors[ - tensor_config.reused_layer_name] - assert len(kv_cache_raw_tensors) == len( - kv_cache_config.tensors), "Some layers are not initialized" + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + tensor = torch.zeros(kv_cache_tensor.size, + dtype=torch.int8, + device=self.device) + for layer_name in kv_cache_tensor.shared_by: + kv_cache_raw_tensors[layer_name] = tensor + + layer_names = set() + for group in kv_cache_config.kv_cache_groups: + layer_names.update(group.layer_names) + assert layer_names == set(kv_cache_raw_tensors.keys( + )), "Some layers are not correctly initialized" return kv_cache_raw_tensors def _setup_kv_cache_shapes( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index cf701f7c1fe9..1b61ec64cc14 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -27,11 +27,9 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available -from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, - PallasMetadata) +from vllm.v1.attention.backends.pallas import PallasMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget -from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, - KVCacheConfig, KVCacheNewTensor, +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) @@ -1267,27 +1265,28 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: 0].get_cpu_tensor().dtype kv_caches: dict[str, torch.Tensor] = {} - - for kv_cache_group in kv_cache_config.kv_cache_groups: - kv_cache_spec = kv_cache_group.kv_cache_spec - for layer_name in kv_cache_group.layer_names: - tensor_config = kv_cache_config.tensors[layer_name] - assert isinstance(tensor_config, KVCacheNewTensor) - assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes - if isinstance(kv_cache_spec, 
AttentionSpec): - kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - dtype = kv_cache_spec.dtype - - tpu_kv_cache = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) - - kv_caches[layer_name] = tpu_kv_cache - else: - raise NotImplementedError + # TODO in this PR: update to the new kv cache config interface. + # comment out temporarily to pass type checker + # for kv_cache_group in kv_cache_config.kv_cache_groups: + # kv_cache_spec = kv_cache_group.kv_cache_spec + # for layer_name in kv_cache_group.layer_names: + # tensor_config = kv_cache_config.tensors[layer_name] + # assert isinstance(tensor_config, KVCacheNewTensor) + # assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 + # num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes # noqa + # if isinstance(kv_cache_spec, AttentionSpec): + # kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( # noqa + # num_blocks, kv_cache_spec.block_size, + # kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + # dtype = kv_cache_spec.dtype + + # tpu_kv_cache = torch.zeros(kv_cache_shape, + # dtype=dtype, + # device=self.device) + + # kv_caches[layer_name] = tpu_kv_cache + # else: + # raise NotImplementedError bind_kv_cache( kv_caches, From 72c2671717c318520c754c1720ce00192109ff90 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 1 Jun 2025 10:26:04 -0700 Subject: [PATCH 22/44] coordinator Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 90cd9cbb1ff2..6d804a42a67b 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -6,7 +6,8 @@ from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock from vllm.v1.core.single_type_kv_cache_manager import ( - SingleTypeKVCacheManager, get_manager_for_kv_cache_spec) + FullAttentionManager, SingleTypeKVCacheManager, + get_manager_for_kv_cache_spec) from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig from vllm.v1.request import Request @@ -255,6 +256,14 @@ def verify_and_split_kv_cache_groups(self) -> None: self.other_spec = self.kv_cache_config.kv_cache_groups[ self.other_group_ids[0]].kv_cache_spec + self.full_attention_manager_cls = FullAttentionManager + other_attention_clses = set(self.single_type_managers[i].__class__ + for i in self.other_group_ids) + assert len(other_attention_clses) == 1, ( + "KVCacheCoordinator assumes all other groups have the same " + "attention manager class now.") + self.other_attention_cls = next(iter(other_attention_clses)) + self.full_attention_block_size = self.full_attention_spec.block_size self.other_block_size = self.other_spec.block_size if self.other_block_size % self.full_attention_block_size != 0: @@ -280,29 +289,29 @@ def find_longest_cache_hit( - The number of tokens of the longest cache hit. """ # First, find the longest cache hit for full attention. 
- hit_blocks_full_attn = self.single_type_managers[ - self.full_attention_group_ids[0]].find_longest_cache_hit( + hit_blocks_full_attn = ( + self.full_attention_manager_cls.find_longest_cache_hit( block_hashes=block_hashes, max_length=max_cache_hit_length, kv_cache_group_ids=self.full_attention_group_ids, block_pool=self.block_pool, kv_cache_spec=self.full_attention_spec, use_eagle=self.use_eagle, - ) + )) hit_length = len( hit_blocks_full_attn[0]) * self.full_attention_block_size # Next, find the cache hit for the other attention WITHIN # the cache hit of full attention. - hit_blocks_other_attn = self.single_type_managers[ - self.other_group_ids[0]].find_longest_cache_hit( + hit_blocks_other_attn = ( + self.other_attention_cls.find_longest_cache_hit( block_hashes=block_hashes, max_length=hit_length, kv_cache_group_ids=self.other_group_ids, block_pool=self.block_pool, kv_cache_spec=self.other_spec, use_eagle=self.use_eagle, - ) + )) hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size assert hit_length % self.full_attention_block_size == 0 From 05f3406026efe7313d51e7992abf93f748a7b380 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 1 Jun 2025 20:22:59 -0700 Subject: [PATCH 23/44] add some notes Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_manager.py | 9 +++++++++ vllm/v1/core/kv_cache_utils.py | 7 +++++++ 2 files changed, 16 insertions(+) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 09c75e15eb2a..278aa51e1bf3 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -19,9 +19,18 @@ @dataclass class KVCacheBlocks: + """ + The allocation result of KVCacheManager, work as the interface between + Scheduler and KVCacheManager, to hide KVCacheManager's internal data + structure from the Scheduler. + """ blocks: list[list[KVCacheBlock]] """ blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens. + We don't use block of tokens as the outer dimension because it assumes all + kv_cache_groups have the same number of blocks, which is true for now but + will be broken if we want to give different block_size to different + kv_cache_groups in the future. """ def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 4725755eb6e4..49646181a5db 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -740,6 +740,13 @@ def _get_kv_cache_config_uniform_page_size( # E.g., (full.0, full.1), (sw.0, sw.1, sw.2) # split to 3 groups with 2 layers each: # (full.0, full.1), (sw.0, sw.1), (sw.2, padding). + # FIXME(Chen): At the moment of writing this code (2025-06-02), all + # open-source hybrid model follows a n:1 pattern between different attention + # types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and + # full), so we can use the "1" in the n:1 pattern as the group size, which + # is the minimum number of layers among all attention types. Need a better + # strategy if we want to support more complex patterns (e.g., 20 full + 30 + # sw, where the group size should be 10). 
group_size = min([len(layers) for layers in same_type_layers.values()]) grouped_layers = [] for layers in same_type_layers.values(): From 3019f1b69522e06e87950a54d6cbe8dc6467862b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 2 Jun 2025 00:53:40 -0700 Subject: [PATCH 24/44] add notes about assumptions Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 49646181a5db..1979d7d1fc58 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -720,7 +720,36 @@ def _get_kv_cache_config_uniform_page_size( vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], available_memory: int) -> KVCacheConfig: """ - Generates the KV cache configuration for a model with one page size. + Generates the KV cache configuration for models with a uniform page size. + + NOTE(Chen): To simplify the kv cache management logic for hybrid models, we + make the following assumptions: + 1. Physical memory per block: Must be the same across all KV cache groups. + Breaking this assumption is non-trivial due to memory fragmentation concerns + when allocating blocks of different sizes. + 2. Tokens per block (block_size): currently, we directly use + `CacheConfig.block_size` for all layers. It can be extended to vary by KV + cache group, but within each KV cache group, all layers must share the same + block size. + 3. Physical memory per token per layer: This property is decided by model + config. Currently we only support models that have the same physical memory + per token per layer for all layers. Can be relaxed with a simple extension, + but still need to keep physical memory per block per group the same. + 4. Number of layers per group: Currently assumed the same for all layers. + Can be relaxed with a simple extension, but still need to keep byte per + block per group the same. + 5. Attention type within groups: All layers in a group must share the same + attention type. One exception is that, when + `--disable-hybrid-kv-cache-manager` is true, the single group for full + attention layers may also include attention layers using sliding window or + LLaMA 4 local attention. + 6. Support for multiple attention types: The design for most components is + general to an arbitrary number of attention types. But + `find_longest_cache_hit` only supports one attention type or two + types of full-attention plus exactly one another type. The general + implementation of this function is feasible but we don't know how to + implement it cleanly yet. 
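As a concrete illustration of the grouping rule above (group_size taken as the minimum layer count over attention types), the following standalone sketch reproduces the 10 full + 20 sliding-window example and the 20 full + 30 sliding-window case mentioned in the FIXME. Layer names are made up; this is not the vLLM code path itself.

# Standalone sketch of the min-based grouping rule.
def group_layers(same_type_layers: dict[str, list[str]]) -> list[list[str]]:
    group_size = min(len(layers) for layers in same_type_layers.values())
    grouped: list[list[str]] = []
    for layers in same_type_layers.values():
        for i in range(0, len(layers), group_size):
            grouped.append(layers[i:i + group_size])
    return grouped


# 10 full-attention + 20 sliding-window layers -> 3 groups of 10 layers each.
groups = group_layers({
    "full": [f"full.{i}" for i in range(10)],
    "sw": [f"sw.{i}" for i in range(20)],
})
assert [len(g) for g in groups] == [10, 10, 10]

# The 20 full + 30 sw case from the FIXME above: min() gives group_size=20,
# so the groups come out uneven; a group size of 10 would keep them uniform,
# which is why the FIXME calls for a better strategy.
groups = group_layers({
    "full": [f"full.{i}" for i in range(20)],
    "sw": [f"sw.{i}" for i in range(30)],
})
assert [len(g) for g in groups] == [20, 20, 10]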
+ Args: vllm_config: The global VllmConfig kv_cache_spec: The KVCacheSpec of each attention layer in the model From ca6a00b89e3a01e90f8bbef37954499900265aac Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 2 Jun 2025 01:12:35 -0700 Subject: [PATCH 25/44] simplify verify_and_split_kv_cache_groups Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 62 +++++++++++++++------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 6d804a42a67b..53f5e818ca9f 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from abc import abstractmethod -from collections import defaultdict -from typing import Callable +from typing import Callable, Optional from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock @@ -232,44 +231,49 @@ def verify_and_split_kv_cache_groups(self) -> None: one of them is full attention. Then, split the kv cache groups into full attention groups and other groups. """ - groups_by_type_id: dict[str, list[int]] = defaultdict(list) - full_attention_type_ids: set[str] = set() + full_attention_type_id: Optional[str] = None + other_type_id: Optional[str] = None + self.full_attention_group_ids: list[int] = [] + self.other_group_ids: list[int] = [] for i, g in enumerate(self.kv_cache_config.kv_cache_groups): - groups_by_type_id[g.kv_cache_spec.type_id].append(i) if isinstance(g.kv_cache_spec, FullAttentionSpec): - full_attention_type_ids.add(g.kv_cache_spec.type_id) + if full_attention_type_id is None: + full_attention_type_id = g.kv_cache_spec.type_id + else: + assert full_attention_type_id == g.kv_cache_spec.type_id, ( + "HybridKVCacheCoordinator assumes exactly one type of " + "full attention groups now.") + self.full_attention_group_ids.append(i) + else: + if other_type_id is None: + other_type_id = g.kv_cache_spec.type_id + else: + assert other_type_id == g.kv_cache_spec.type_id, ( + "HybridKVCacheCoordinator assumes " + "exactly one other type of groups now.") + self.other_group_ids.append(i) + + assert full_attention_type_id is not None, ( + "HybridKVCacheCoordinator assumes exactly one type of full " + "attention groups now.") + assert other_type_id is not None, ( + "HybridKVCacheCoordinator assumes exactly one type of other " + "groups now.") - assert len(full_attention_type_ids) == 1, ( - "find_longest_cache_hit assumes hybrid models have exactly " - "one type of full attention groups now") - assert len(groups_by_type_id) == 2, ( - "find_longest_cache_hit assumes hybrid models have exactly " - "one other type of groups except full attention now") - - self.full_attention_group_ids = groups_by_type_id[next( - iter(full_attention_type_ids))] - self.other_group_ids = groups_by_type_id[next( - iter(groups_by_type_id.keys() - full_attention_type_ids))] + self.full_attention_manager_cls = FullAttentionManager + self.other_attention_cls = self.single_type_managers[ + self.other_group_ids[0]].__class__ self.full_attention_spec = self.kv_cache_config.kv_cache_groups[ self.full_attention_group_ids[0]].kv_cache_spec self.other_spec = self.kv_cache_config.kv_cache_groups[ self.other_group_ids[0]].kv_cache_spec - self.full_attention_manager_cls = FullAttentionManager - other_attention_clses = set(self.single_type_managers[i].__class__ - for i in self.other_group_ids) - assert len(other_attention_clses) == 1, ( - 
"KVCacheCoordinator assumes all other groups have the same " - "attention manager class now.") - self.other_attention_cls = next(iter(other_attention_clses)) - self.full_attention_block_size = self.full_attention_spec.block_size self.other_block_size = self.other_spec.block_size - if self.other_block_size % self.full_attention_block_size != 0: - raise NotImplementedError( - "KVCacheCoordinator assumes the block_size of the full " - "attention layer is divisible by other layers now.") + assert self.other_block_size % self.full_attention_block_size == 0, ( + "KVCacheCoordinator assumes the block_size of the full " + "attention layer is divisible by other layers now.") def find_longest_cache_hit( self, From 15b4449d2e58d66c4d7a941fa8394b06b4200ee0 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 2 Jun 2025 02:02:20 -0700 Subject: [PATCH 26/44] update explaination Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 51 ++++++++++++++++++++++++++-------- vllm/v1/kv_cache_interface.py | 22 +++------------ 2 files changed, 43 insertions(+), 30 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 1979d7d1fc58..e8f832f85f27 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -720,29 +720,51 @@ def _get_kv_cache_config_uniform_page_size( vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], available_memory: int) -> KVCacheConfig: """ - Generates the KV cache configuration for models with a uniform page size. - - NOTE(Chen): To simplify the kv cache management logic for hybrid models, we - make the following assumptions: + Generates the KV cache configuration for hybrid models with multiple + attention types but still with a uniform page size (physical memory per + block per layer) for all layers. + + Detailed explanation about kv cache management of hybrid models: + The layers in the models are repeated with some patterns, e.g., a model + with 10 full attention layers and 20 sliding window attention layers can be + regarded as repeating the pattern (1 * full, 2 * sw) 10 times. + The KVCacheManager allocates different block tables for each of the 3 layers + in the pattern, and repeats each of them 10 times to generate the + block_table for the 30 layers in the model. + Therefore, we can group the layers in the model into 3 kv_cache_groups, each + of which contains 10 layers in the model. + The KVCacheManager allocates the block_table for each group based on its + kv_cache spec, and the model runner applies the block table to each layer + in the group. + For example: + 1. A model only uses full attention. The pattern is + (num_hidden_layers * full), so there is only one group and the block table + is shared by all layers. It is already handled by + `_get_kv_cache_config_uniform_type`. + 2. A model with 10 full attention layers and 20 sliding window + attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so + there are 3 kv_cache_groups, each of which represents 10 layers. + + To simplify the implementation, we make the following assumptions: 1. Physical memory per block: Must be the same across all KV cache groups. Breaking this assumption is non-trivial due to memory fragmentation concerns when allocating blocks of different sizes. - 2. Tokens per block (block_size): currently, we directly use + 2. Tokens per block (block_size): Currently, we directly use `CacheConfig.block_size` for all layers. 
It can be extended to vary by KV cache group, but within each KV cache group, all layers must share the same block size. 3. Physical memory per token per layer: This property is decided by model config. Currently we only support models that have the same physical memory per token per layer for all layers. Can be relaxed with a simple extension, - but still need to keep physical memory per block per group the same. + but still need to keep physical memory per block the same for all groups. 4. Number of layers per group: Currently assumed the same for all layers. - Can be relaxed with a simple extension, but still need to keep byte per - block per group the same. + Can be relaxed with a simple extension, but still need to keep physical + memory per block the same for all groups. 5. Attention type within groups: All layers in a group must share the same attention type. One exception is that, when `--disable-hybrid-kv-cache-manager` is true, the single group for full attention layers may also include attention layers using sliding window or - LLaMA 4 local attention. + LLaMA 4 local attention. See `unify_hybrid_kv_cache_specs` for more details. 6. Support for multiple attention types: The design for most components is general to an arbitrary number of attention types. But `find_longest_cache_hit` only supports one attention type or two @@ -750,6 +772,10 @@ def _get_kv_cache_config_uniform_page_size( implementation of this function is feasible but we don't know how to implement it cleanly yet. + As we assume tokens per block, physical memory per token per layer, and + number of layers per group are the same now, we can ensure that physical + memory per block is the same for all groups. + Args: vllm_config: The global VllmConfig kv_cache_spec: The KVCacheSpec of each attention layer in the model @@ -903,9 +929,10 @@ def get_kv_cache_config( return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, available_memory) elif is_kv_cache_page_size_uniform(kv_cache_spec): - # KV cache of all layers have the same page size. Split the layers into - # groups with the same number of layers, and thus same total page size. - # See KVCacheConfig.kv_cache_groups for more details. + # Model contains multiple attention types, but KV cache of all layers + # have the same physical memory per block per layer. Split the layers + # into groups with the same number of layers, and thus same total page + # size. return _get_kv_cache_config_uniform_page_size(vllm_config, kv_cache_spec, available_memory) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index c14d2caa63c8..0c65110a8d01 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -185,23 +185,9 @@ class KVCacheConfig: kv_cache_tensors: list[KVCacheTensor] """ The kv cache groups of the model. - The layers in the models are repeated with some patterns, e.g., a model - with 10 full attention layers and 20 sliding window attention layers can be - regarded as repeating the pattern (1 * full, 2 * sw) 10 times. - The KVCacheManager allocates different block tables for each of the 3 layers - in the pattern, and repeats each of them 10 times to generate the - block_table for the 30 layers in the model. - Therefore, we can group the layers in the model into 3 groups, each of which - contains 10 layers in the model. - The KVCacheManager allocates the block_table for each group based on its - kv_cache spec, and the model runner applies the block table to each layer - in the group. - For example: - 1. 
A model only uses full attention. The pattern is - (num_hidden_layers * full), so there is only one group and the block table - is shared by all layers. - 2. A model with 10 full attention layers and 20 sliding window - attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so - there are 3 groups, each of which represents 10 layers in the model. + For models with only one type of attention, there is only one group that + contains all layers. + For models with multiple types of attention, there will be multiple groups, + see `_get_kv_cache_config_uniform_page_size` for more details. """ kv_cache_groups: list[KVCacheGroupSpec] From 1a862f9e40e3a657fc43012438b41bc3d4c83feb Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 2 Jun 2025 02:17:29 -0700 Subject: [PATCH 27/44] BlockHashType->BlockHash Signed-off-by: Chen Zhang --- docs/source/design/v1/prefix_caching.md | 2 +- tests/v1/core/test_kv_cache_utils.py | 10 +++++----- tests/v1/core/test_specialized_manager.py | 4 ++-- vllm/v1/core/block_pool.py | 13 ++++++------- vllm/v1/core/kv_cache_coordinator.py | 10 +++++----- vllm/v1/core/kv_cache_manager.py | 6 +++--- vllm/v1/core/kv_cache_utils.py | 18 +++++++++--------- vllm/v1/core/single_type_kv_cache_manager.py | 10 +++++----- 8 files changed, 36 insertions(+), 37 deletions(-) diff --git a/docs/source/design/v1/prefix_caching.md b/docs/source/design/v1/prefix_caching.md index 0f7475777797..b898611dec3a 100644 --- a/docs/source/design/v1/prefix_caching.md +++ b/docs/source/design/v1/prefix_caching.md @@ -104,7 +104,7 @@ class KVCacheBlock: block_id: int # The block hash (will be assigned when the block is full, # and will be reset when the block is evicted). - block_hash: BlockHashType + block_hash: BlockHash # The number of requests using this block now. 
ref_cnt: int diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 6ee4f2d46b19..85ddb0c21aa2 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -10,7 +10,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager # disable yapf here as it formats differently than isort such that both fail # yapf: disable -from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType, +from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, estimate_max_model_len, @@ -78,7 +78,7 @@ def test_kv_cache_block(): assert block.block_hash is None # Test block hash setting and resetting - block_hash = BlockHashType(hash_value=123, token_ids=(1, 2, 3)) + block_hash = BlockHash(hash_value=123, token_ids=(1, 2, 3)) block.block_hash = block_hash assert block.block_hash == block_hash @@ -258,7 +258,7 @@ def test_hash_block_tokens(hash_fn): block_hash = hash_block_tokens(hash_fn, parent_block_hash, curr_block_token_ids, extra_keys) - assert isinstance(block_hash, BlockHashType) + assert isinstance(block_hash, BlockHash) assert block_hash.hash_value == hash_fn( (parent_block_hash, curr_block_token_ids, extra_keys)) assert block_hash.token_ids == curr_block_token_ids @@ -281,8 +281,8 @@ def test_hash_request_tokens(hash_fn): block_hashes = hash_request_tokens(hash_fn, block_size, request) assert len(block_hashes) == 2 - assert isinstance(block_hashes[0], BlockHashType) - assert isinstance(block_hashes[1], BlockHashType) + assert isinstance(block_hashes[0], BlockHash) + assert isinstance(block_hashes[1], BlockHash) # Check the first block assert block_hashes[0].token_ids == (0, 1, 2) diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index 49ba2e15454c..82d1f3a7da72 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -3,7 +3,7 @@ import torch from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, +from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, KVCacheBlockBundle) from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager from vllm.v1.kv_cache_interface import SlidingWindowSpec @@ -36,7 +36,7 @@ def test_sliding_window_possible_cached_prefix(): def run_one_case(block_is_cached, expect_length): block_hash_list = [ - BlockHashType(i, ()) for i in range(len(block_is_cached)) + BlockHash(i, ()) for i in range(len(block_is_cached)) ] block_pool.cached_block_hash_to_block[0].clear() diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 2638a08e6b22..f637c2cbad11 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -6,8 +6,7 @@ from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved, BlockStored, KVCacheEvent) from vllm.logger import init_logger -from vllm.v1.core.kv_cache_utils import (BlockHashType, - BlockHashTypeWithGroupId, +from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, FreeKVCacheBlockQueue, KVCacheBlock, generate_block_hash_extra_keys, hash_block_tokens) @@ -57,7 +56,7 @@ def __init__( # if there is already an identical block in the cache. This is because # we want to make sure the allocated block IDs won't change so that # block tables are append-only. 
- self.cached_block_hash_to_block: dict[BlockHashTypeWithGroupId, dict[ + self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[ int, KVCacheBlock]] = defaultdict(dict) # To represent a placeholder block with block_id=0. @@ -69,7 +68,7 @@ def __init__( self.kv_event_queue: list[KVCacheEvent] = [] def get_cached_block( - self, block_hash: BlockHashType, + self, block_hash: BlockHash, kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]: """Get the cached block by the block hash for each group in `kv_cache_group_ids`, or None if cache miss for any group. @@ -85,7 +84,7 @@ def get_cached_block( cached_blocks = [] for group_id in kv_cache_group_ids: cached_blocks_one_group = self.cached_block_hash_to_block[ - BlockHashTypeWithGroupId(block_hash, group_id)] + BlockHashWithGroupId(block_hash, group_id)] if not cached_blocks_one_group: return None first_block_id = next(iter(cached_blocks_one_group)) @@ -96,7 +95,7 @@ def cache_full_blocks( self, request: Request, blocks: list[KVCacheBlock], - block_hashes: list[BlockHashType], + block_hashes: list[BlockHash], num_cached_blocks: int, num_full_blocks: int, block_size: int, @@ -176,7 +175,7 @@ def cache_full_blocks( block_hashes.append(block_hash) # Update and added the full block to the cache. - block_hash_with_group_id = BlockHashTypeWithGroupId( + block_hash_with_group_id = BlockHashWithGroupId( block_hash, kv_cache_group_id) blk.block_hash = block_hash_with_group_id self.cached_block_hash_to_block[block_hash_with_group_id][ diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 53f5e818ca9f..69814cccacda 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -3,7 +3,7 @@ from typing import Callable, Optional from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.single_type_kv_cache_manager import ( FullAttentionManager, SingleTypeKVCacheManager, get_manager_for_kv_cache_spec) @@ -103,7 +103,7 @@ def allocate_new_blocks(self, request_id: str, manager.allocate_new_blocks(request_id, num_tokens)) return new_blocks - def cache_blocks(self, request: Request, block_hashes: list[BlockHashType], + def cache_blocks(self, request: Request, block_hashes: list[BlockHash], num_computed_tokens: int) -> None: """ Cache the blocks for the request. 
@@ -170,7 +170,7 @@ def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]: @abstractmethod def find_longest_cache_hit( - self, block_hashes: list[BlockHashType], + self, block_hashes: list[BlockHash], max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]: pass @@ -195,7 +195,7 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, "UnifiedKVCacheCoordinator assumes only one kv cache group") def find_longest_cache_hit( - self, block_hashes: list[BlockHashType], + self, block_hashes: list[BlockHash], max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]: hit_blocks = self.single_type_managers[0].find_longest_cache_hit( block_hashes=block_hashes, @@ -277,7 +277,7 @@ def verify_and_split_kv_cache_groups(self) -> None: def find_longest_cache_hit( self, - block_hashes: list[BlockHashType], + block_hashes: list[BlockHash], max_cache_hit_length: int, ) -> tuple[list[list[KVCacheBlock]], int]: """ diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 278aa51e1bf3..69f477b30c30 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -8,7 +8,7 @@ from vllm.logger import init_logger from vllm.utils import sha256 from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator -from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, +from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, hash_request_tokens) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats @@ -104,7 +104,7 @@ def __init__( # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. self.req_to_block_hashes: defaultdict[ - str, list[BlockHashType]] = defaultdict(list) + str, list[BlockHash]] = defaultdict(list) @property def usage(self) -> float: @@ -377,7 +377,7 @@ def get_block_ids(self, request_id: str) -> list[list[int]]: return KVCacheBlocks( self.coordinator.get_blocks(request_id)).get_block_ids() - def cache_blocks(self, request: Request, block_hashes: list[BlockHashType], + def cache_blocks(self, request: Request, block_hashes: list[BlockHash], num_computed_tokens: int) -> None: """Cache the blocks for the request.""" self.coordinator.cache_blocks(request, block_hashes, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index e8f832f85f27..1639f5cdf177 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -19,7 +19,7 @@ logger = init_logger(__name__) -class BlockHashType(NamedTuple): +class BlockHash(NamedTuple): """Hash value of a block (int), the token IDs in the block, and extra keys. We keep a tuple of token IDs and extra keys to reduce the likelihood of hash collisions when the hash value is the same. By using SHA256 however, @@ -34,8 +34,8 @@ class BlockHashType(NamedTuple): extra_keys: Optional[Any] = None -class BlockHashTypeWithGroupId(NamedTuple): - block_hash: BlockHashType +class BlockHashWithGroupId(NamedTuple): + block_hash: BlockHash group_id: int def get_hash_value(self) -> int: @@ -127,7 +127,7 @@ class KVCacheBlock: ref_cnt: int = 0 # The hash of the block composed of (block hash, tuple of token IDs). # It is only available when the block is full. - _block_hash: Optional[BlockHashTypeWithGroupId] = None + _block_hash: Optional[BlockHashWithGroupId] = None # Used to construct a doubly linked list for free blocks. # These two attributes should only be manipulated by FreeKVCacheBlockQueue. 
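The relationship between the renamed BlockHash and BlockHashWithGroupId can be seen in isolation with local stand-in definitions that mirror the NamedTuples above. This is a sketch, not an import from vLLM; it only shows why the prefix-cache key carries a group id on top of the content hash.

# Standalone sketch with local stand-ins for the two NamedTuples.
from typing import Any, NamedTuple, Optional


class BlockHash(NamedTuple):
    hash_value: int
    token_ids: tuple[int, ...]
    extra_keys: Optional[Any] = None


class BlockHashWithGroupId(NamedTuple):
    block_hash: BlockHash
    group_id: int

    def get_hash_value(self) -> int:
        return self.block_hash.hash_value


# The same 4 tokens produce the same content hash regardless of group...
content_hash = BlockHash(hash_value=hash((None, (1, 2, 3, 4))),
                         token_ids=(1, 2, 3, 4))
key_full = BlockHashWithGroupId(content_hash, group_id=0)  # full attention
key_sw = BlockHashWithGroupId(content_hash, group_id=1)    # sliding window

# ...but the per-group keys differ, so the full-attention copy and the
# sliding-window copy of this block are cached and evicted independently.
assert key_full != key_sw
assert key_full.get_hash_value() == key_sw.get_hash_value()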
@@ -141,11 +141,11 @@ def decr_ref(self): self.ref_cnt -= 1 @property - def block_hash(self) -> Optional[BlockHashTypeWithGroupId]: + def block_hash(self) -> Optional[BlockHashWithGroupId]: return self._block_hash @block_hash.setter - def block_hash(self, block_hash: BlockHashTypeWithGroupId): + def block_hash(self, block_hash: BlockHashWithGroupId): assert self.block_hash is None, ( "The block already has a hash. This should not happen.") self._block_hash = block_hash @@ -408,7 +408,7 @@ def hash_block_tokens( hash_function: Callable, parent_block_hash: Optional[int], curr_block_token_ids: Sequence[int], - extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType: + extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing @@ -429,14 +429,14 @@ def hash_block_tokens( parent_block_hash = NONE_HASH curr_block_token_ids_tuple = tuple(curr_block_token_ids) - return BlockHashType( + return BlockHash( hash_function( (parent_block_hash, curr_block_token_ids_tuple, extra_keys)), curr_block_token_ids_tuple, extra_keys) def hash_request_tokens(hash_function: Any, block_size: int, - request: Request) -> list[BlockHashType]: + request: Request) -> list[BlockHash]: """Computes hash values of a chain of blocks given a sequence of token IDs. The hash value is used for prefix caching. diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index eb94fdd211a4..1f5d96fd5550 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -5,7 +5,7 @@ from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, SlidingWindowSpec) from vllm.v1.request import Request @@ -125,7 +125,7 @@ def allocate_new_blocks(self, request_id: str, req_blocks.extend(new_blocks) return new_blocks - def cache_blocks(self, request: Request, block_hashes: list[BlockHashType], + def cache_blocks(self, request: Request, block_hashes: list[BlockHash], num_tokens: int) -> None: """ Cache the blocks for the request. 
@@ -189,7 +189,7 @@ def get_num_common_prefix_blocks(self, request_id: str, @abstractmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHashType], + block_hashes: list[BlockHash], max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, @@ -247,7 +247,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHashType], + block_hashes: list[BlockHash], max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, @@ -304,7 +304,7 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHashType], + block_hashes: list[BlockHash], max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, From 904bd256513c7d594ec9a4b272f1dd9c5c440364 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 2 Jun 2025 02:38:45 -0700 Subject: [PATCH 28/44] update coordinator Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 69814cccacda..649ab4b57c1a 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -175,7 +175,7 @@ def find_longest_cache_hit( pass -class SingleGroupKVCacheCoordinator(KVCacheCoordinator): +class UnitaryKVCacheCoordinator(KVCacheCoordinator): """ KV cache coordinator for models with only one KV cache group. This is the case for models with only one KV cache type, e.g., all attention layers use @@ -192,7 +192,7 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, 0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size assert len(self.kv_cache_config.kv_cache_groups) == 1, ( - "UnifiedKVCacheCoordinator assumes only one kv cache group") + "UnitaryKVCacheCoordinator assumes only one kv cache group") def find_longest_cache_hit( self, block_hashes: list[BlockHash], @@ -317,6 +317,14 @@ def find_longest_cache_hit( use_eagle=self.use_eagle, )) hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size + + # NOTE: the prefix cache hit length must be a multiply of block_size as + # we don't support partial block cache hit yet. The cache hit length + # of other attention is ensured to be a multiple of the block size of + # full attention layers in current implementation, because hit_length is + # a multiple of other attention's block size, and other attention's + # block size is a multiple of full attention's block size (verified in + # `verify_and_split_kv_cache_groups`). assert hit_length % self.full_attention_block_size == 0 # Truncate the full attention cache hit to the length of the @@ -328,7 +336,6 @@ def find_longest_cache_hit( # Merge the hit blocks of full attention and other attention. hit_blocks = hit_blocks_other_attn for group_id, blocks in enumerate(hit_blocks_full_attn): - del blocks[hit_length // self.full_attention_block_size:] # NOTE: there is only one full attention group in most cases. So # the time complexity of insert is fine. 
hit_blocks.insert(group_id, blocks) @@ -340,10 +347,10 @@ def get_kv_cache_coordinator( enable_caching: bool, caching_hash_fn: Callable, enable_kv_cache_events: bool) -> KVCacheCoordinator: if len(kv_cache_config.kv_cache_groups) == 1: - return SingleGroupKVCacheCoordinator(kv_cache_config, max_model_len, - use_eagle, enable_caching, - caching_hash_fn, - enable_kv_cache_events) + return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len, + use_eagle, enable_caching, + caching_hash_fn, + enable_kv_cache_events) else: return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, enable_caching, From 66032cffc9959b53e026566579c475b6eb7731eb Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 2 Jun 2025 03:08:42 -0700 Subject: [PATCH 29/44] small fix Signed-off-by: Chen Zhang --- vllm/v1/core/block_pool.py | 8 ++++++-- vllm/v1/core/kv_cache_utils.py | 6 +++++- vllm/v1/worker/gpu_model_runner.py | 13 ++++++------- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index f637c2cbad11..9ac7d2836dc6 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -249,6 +249,10 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: del self.cached_block_hash_to_block[block_hash] if self.enable_kv_cache_events: + # FIXME (Chen): Not sure whether we should return `hash_value` + # or `(hash_value, group_id)` here. But it's fine now because + # we disable hybrid kv cache manager when kv cache event is + # enabled, so there is only one group. self.kv_event_queue.append( BlockRemoved(block_hashes=[block_hash.get_hash_value()])) return True @@ -262,8 +266,8 @@ def touch(self, blocks: list[list[KVCacheBlock]]) -> None: Args: blocks: A list of blocks to touch. """ - for blocks_one_manager in blocks: - for block in blocks_one_manager: + for blocks_per_group in blocks: + for block in blocks_per_group: # ref_cnt=0 means this block is in the free list (i.e. eviction # candidate), so remove it. if block.ref_cnt == 0 and block != self.null_block: diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 1639f5cdf177..313aafad62c0 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -35,7 +35,11 @@ class BlockHash(NamedTuple): class BlockHashWithGroupId(NamedTuple): + # The hash value for the contents (e.g., token_ids) of a block without group + # ID. The value is the same for blocks representing the same tokens but for + # different groups. block_hash: BlockHash + # The KV cache group ID. group_id: int def get_hash_value(self) -> int: @@ -638,7 +642,7 @@ def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int: """ Get the page size of the KV cache. 
""" - page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values()) assert len(page_sizes) == 1 return page_sizes.pop() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 674cacf73173..38d06862f3be 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1867,7 +1867,7 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) - def _initialize_kv_cache_buffer( + def _allocate_kv_cache_tensors( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: """ Initializes the KV cache buffer with the correct size. The buffer needs @@ -1894,13 +1894,13 @@ def _initialize_kv_cache_buffer( )), "Some layers are not correctly initialized" return kv_cache_raw_tensors - def _setup_kv_cache_shapes( + def _reshape_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, kv_cache_raw_tensors: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]: """ - Reshape the KV cache tensors to the desired shape. + Reshape the KV cache tensors to the desired shape and dtype. Args: kv_cache_config: The KV cache config @@ -1942,11 +1942,10 @@ def initialize_kv_cache_tensors( corresponding memory buffer for KV cache. """ # Initialize the memory buffer for KV cache - kv_cache_raw_tensors = self._initialize_kv_cache_buffer( - kv_cache_config) + kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) # Change the memory buffer to the desired shape - kv_caches = self._setup_kv_cache_shapes(kv_cache_config, - kv_cache_raw_tensors) + kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, + kv_cache_raw_tensors) bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, []) From 395e2bc866ac06da19f1658c68e65510431f8bce Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 2 Jun 2025 03:40:15 -0700 Subject: [PATCH 30/44] small fix Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_coordinator.py | 4 ++-- vllm/v1/core/kv_cache_utils.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 649ab4b57c1a..f6c8d569b9c6 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Callable, Optional from vllm.v1.core.block_pool import BlockPool @@ -11,7 +11,7 @@ from vllm.v1.request import Request -class KVCacheCoordinator: +class KVCacheCoordinator(ABC): """ Coordinate the KV cache of different KV cache groups. """ diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 313aafad62c0..ed1588503f8f 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -611,7 +611,7 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: True if all layers have the same type, False otherwise. 
""" - layer_keys = set(layer.type_id for layer in kv_cache_spec.values()) + layer_keys = {layer.type_id for layer in kv_cache_spec.values()} return len(layer_keys) == 1 @@ -881,8 +881,11 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: if not is_hybrid(kv_cache_spec): return - logger.warning("Hybrid KV cache manager is disabled for this hybrid model," - "There can be some waste of KV cache memory.") + logger.warning( + "Hybrid KV cache manager is disabled for this hybrid model, " + "This means we do not enable any optimizations for saving KV cache " + "memory (e.g., dropping the KV cache outside the sliding window). " + "The compute of layers like sliding window is still saved.") has_full_attention = any( isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values()) From 3556db8cc74a97a503a407c48d688b6321321c10 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 2 Jun 2025 08:44:00 -0700 Subject: [PATCH 31/44] update logging Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index ed1588503f8f..24fb6ec589dc 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -842,12 +842,13 @@ def _get_kv_cache_config_uniform_page_size( KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by)) # Print the KV cache size and maximum concurrency. - # TODO in this PR: Now just copy from the uniform type implementation. - # Should reimplement this for hybrid model - num_tokens = num_blocks * vllm_config.cache_config.block_size + num_tokens = num_blocks // len( + grouped_layers) * vllm_config.cache_config.block_size num_tokens_str = f"{num_tokens:,}" logger.info("GPU KV cache size: %s tokens", num_tokens_str) max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" + # TODO in this PR: Now just copy from the uniform type implementation. 
+ # Update after https://github.com/vllm-project/vllm/pull/19029 max_concurrency = num_tokens / vllm_config.model_config.max_model_len logger.info( "Maximum concurrency for %s tokens per request: %.2fx", From b4169634b1fc24ff01a656297c6d2198f85d56bc Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 2 Jun 2025 09:33:32 -0700 Subject: [PATCH 32/44] add todo in this pr Signed-off-by: Chen Zhang --- vllm/v1/engine/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f281a186e6dd..3cf5757c3a83 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -155,7 +155,7 @@ def _initialize_kv_caches( num_gpu_blocks = kv_cache_configs[0].num_blocks num_cpu_blocks = 0 scheduler_kv_cache_config = kv_cache_configs[0] - # TODO: remove this debug print + # TODO in this PR: remove this debug print print("kv_cache_config", scheduler_kv_cache_config) # Initialize kv cache and warmup the execution From e629ee828edfe0f4250b9614eff51f9b24f68da3 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 2 Jun 2025 16:57:28 +0000 Subject: [PATCH 33/44] fix tpu backend Signed-off-by: Chen Zhang --- vllm/v1/worker/tpu_model_runner.py | 56 +++++++++++++++++------------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 1b61ec64cc14..96cbb16e4dad 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -27,10 +27,12 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available -from vllm.v1.attention.backends.pallas import PallasMetadata +from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, + PallasMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec, SlidingWindowSpec) +from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, + KVCacheConfig, KVCacheSpec, + SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata @@ -1264,29 +1266,33 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: assert self.block_table_cpu.dtype == self.input_batch.block_table[ 0].get_cpu_tensor().dtype + kv_cache_sizes = {} + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + assert len(kv_cache_tensor.shared_by) == 1, ( + "KV cache tensor shared by multiple layers is not supported in " + "TPU.") + kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size + kv_caches: dict[str, torch.Tensor] = {} - # TODO in this PR: update to the new kv cache config interface. 
- # comment out temporarily to pass type checker - # for kv_cache_group in kv_cache_config.kv_cache_groups: - # kv_cache_spec = kv_cache_group.kv_cache_spec - # for layer_name in kv_cache_group.layer_names: - # tensor_config = kv_cache_config.tensors[layer_name] - # assert isinstance(tensor_config, KVCacheNewTensor) - # assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 - # num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes # noqa - # if isinstance(kv_cache_spec, AttentionSpec): - # kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( # noqa - # num_blocks, kv_cache_spec.block_size, - # kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - # dtype = kv_cache_spec.dtype - - # tpu_kv_cache = torch.zeros(kv_cache_shape, - # dtype=dtype, - # device=self.device) - - # kv_caches[layer_name] = tpu_kv_cache - # else: - # raise NotImplementedError + for kv_cache_group in kv_cache_config.kv_cache_groups: + kv_cache_spec = kv_cache_group.kv_cache_spec + for layer_name in kv_cache_group.layer_names: + tensor_size = kv_cache_sizes[layer_name] + assert tensor_size % kv_cache_spec.page_size_bytes == 0 + num_blocks = tensor_size // kv_cache_spec.page_size_bytes # noqa + if isinstance(kv_cache_spec, AttentionSpec): + kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( # noqa + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + + tpu_kv_cache = torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + + kv_caches[layer_name] = tpu_kv_cache + else: + raise NotImplementedError bind_kv_cache( kv_caches, From a52d27114699e1ab635205e776aa53818e2a8fdf Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 3 Jun 2025 08:05:41 -0700 Subject: [PATCH 34/44] pass tests in v1/core Signed-off-by: Chen Zhang --- tests/v1/core/test_kv_cache_utils.py | 56 ++++++------- tests/v1/core/test_prefix_caching.py | 83 +++++++++----------- tests/v1/core/test_scheduler.py | 4 +- tests/v1/core/test_specialized_manager.py | 49 ++++++------ vllm/v1/core/single_type_kv_cache_manager.py | 4 +- 5 files changed, 95 insertions(+), 101 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 85ddb0c21aa2..7e35bd949eb7 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -19,7 +19,7 @@ hash_request_tokens, unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheNewTensor, + KVCacheGroupSpec, KVCacheTensor, SlidingWindowSpec) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -379,10 +379,10 @@ def test_unify_kv_cache_configs(): same_kv_cache_config = [ KVCacheConfig( num_blocks=10, - tensors={ - "layer1": KVCacheNewTensor(100), - "layer2": KVCacheNewTensor(100), - }, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer1"]), + KVCacheTensor(size=100, shared_by=["layer2"]), + ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), KVCacheGroupSpec(["layer2"], @@ -391,10 +391,10 @@ def test_unify_kv_cache_configs(): ), KVCacheConfig( num_blocks=20, - tensors={ - "layer1": KVCacheNewTensor(100), - "layer2": KVCacheNewTensor(100), - }, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer1"]), + KVCacheTensor(size=100, shared_by=["layer2"]), + ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), KVCacheGroupSpec(["layer2"], @@ -409,10 +409,10 @@ def 
test_unify_kv_cache_configs(): need_sort_kv_cache_config = [ KVCacheConfig( num_blocks=10, - tensors={ - "layer1": KVCacheNewTensor(100), - "layer2": KVCacheNewTensor(100), - }, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer1"]), + KVCacheTensor(size=100, shared_by=["layer2"]), + ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), KVCacheGroupSpec(["layer2"], @@ -421,10 +421,10 @@ def test_unify_kv_cache_configs(): ), KVCacheConfig( num_blocks=20, - tensors={ - "layer1": KVCacheNewTensor(100), - "layer2": KVCacheNewTensor(100), - }, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer1"]), + KVCacheTensor(size=100, shared_by=["layer2"]), + ], kv_cache_groups=[ KVCacheGroupSpec(["layer2"], new_kv_cache_spec(num_kv_heads=4)), @@ -440,10 +440,10 @@ def test_unify_kv_cache_configs(): diff_kv_cache_config = [ KVCacheConfig( num_blocks=10, - tensors={ - "layer1": KVCacheNewTensor(100), - "layer2": KVCacheNewTensor(100), - }, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer1"]), + KVCacheTensor(size=100, shared_by=["layer2"]), + ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), KVCacheGroupSpec(["layer2"], @@ -452,10 +452,10 @@ def test_unify_kv_cache_configs(): ), KVCacheConfig( num_blocks=20, - tensors={ - "layer1": KVCacheNewTensor(100), - "layer2": KVCacheNewTensor(100), - }, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer1"]), + KVCacheTensor(size=100, shared_by=["layer2"]), + ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), KVCacheGroupSpec(["layer2"], @@ -576,9 +576,9 @@ def test_allocate_with_lookahead(): block_size = 4 config = KVCacheConfig( num_blocks=10, - tensors={ - "layer1": KVCacheNewTensor(100), - }, + kv_cache_tensors=[ + KVCacheTensor(size=100, shared_by=["layer1"]), + ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec(block_size=block_size)), diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index e7d20fb989e6..ee3882bff477 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -12,8 +12,8 @@ from vllm.utils import sha256 from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request -from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, - KVCacheBlockBundle, hash_block_tokens) +from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, + KVCacheBlock, hash_block_tokens) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, SlidingWindowSpec) @@ -47,7 +47,7 @@ def make_request(request_id, def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: return KVCacheConfig( num_blocks=num_blocks, - tensors={}, + kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec( ["layer"], @@ -79,7 +79,7 @@ def test_prefill(hash_algo): all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id][block_size]) == 3 + assert len(manager.req_to_block_hashes[req0.request_id]) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, @@ -93,7 +93,8 @@ def test_prefill(hash_algo): block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) - assert 
manager.block_pool.blocks[block_id].block_hash == block_hash + assert manager.block_pool.blocks[ + block_id].block_hash.block_hash == block_hash parent_block_hash = block_hash.hash_value # Check partial block metadata @@ -108,7 +109,7 @@ def test_prefill(hash_algo): unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id][block_size]) == 3 + assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert computed_blocks.get_block_ids() == [[1, 2, 3]] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -142,7 +143,7 @@ def test_prefill(hash_algo): unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(manager.req_to_block_hashes[req2.request_id][block_size]) == 3 + assert len(manager.req_to_block_hashes[req2.request_id]) == 3 assert computed_blocks.get_block_ids() == [[1, 2, 3]] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -201,7 +202,7 @@ def test_prefill_plp(): all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id][block_size]) == 0 + assert len(manager.req_to_block_hashes[req0.request_id]) == 0 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, @@ -216,7 +217,8 @@ def test_prefill_plp(): block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) - assert manager.block_pool.blocks[block_id].block_hash == block_hash + assert manager.block_pool.blocks[ + block_id].block_hash.block_hash == block_hash parent_block_hash = block_hash.hash_value # Check partial block metadata @@ -232,7 +234,7 @@ def test_prefill_plp(): unique_token_ids = [3] * 5 req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id][block_size]) == 3 + assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert computed_blocks.get_block_ids() == [[1, 2, 3]] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -268,7 +270,7 @@ def test_prefill_plp(): common_token_ids + unique_token_ids, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(manager.req_to_block_hashes[req2.request_id][block_size]) == 0 + assert len(manager.req_to_block_hashes[req2.request_id]) == 0 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 55, @@ -426,7 +428,7 @@ def test_hash_block_correct_reuse(): assert len(blocks.blocks[0]) == 1 assert manager.block_pool.blocks[blocks.blocks[0] - [0].master_block_id].block_hash is None + [0].block_id].block_hash is None def test_computed_blocks_not_evicted(): @@ -451,7 +453,7 @@ def test_computed_blocks_not_evicted(): len(computed_blocks.blocks[0]) * 16, computed_blocks) assert len(blocks.blocks[0]) == 1 - assert blocks.blocks[0][0].master_block_id == 1 + assert blocks.blocks[0][0].block_id == 1 # Allocate another block. 
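The assertions in these tests (the `.block_hash.block_hash` unwrapping above and the group-aware lookups further down) rely on the block pool keying its prefix cache by `(block_hash, group_id)`. A minimal sketch of that lookup, using hypothetical simplified structures rather than the real `BlockPool` API, under the assumption that a usable hit must cover every requested group:

```python
from typing import NamedTuple, Optional


class BlockHashWithGroupId(NamedTuple):
    block_hash: int  # simplified; the real key wraps a BlockHash tuple
    group_id: int


# {key: {block_id: block}}; ints stand in for KVCacheBlock objects.
cached: dict[BlockHashWithGroupId, dict[int, int]] = {
    BlockHashWithGroupId(111, 0): {1: 1},
    BlockHashWithGroupId(111, 1): {5: 5},
}


def get_cached_block(block_hash: int,
                     kv_cache_group_ids: list[int]) -> Optional[list[int]]:
    """Return one cached block per group, or None if any group misses."""
    blocks = []
    for group_id in kv_cache_group_ids:
        group_blocks = cached.get(BlockHashWithGroupId(block_hash, group_id))
        if not group_blocks:
            return None  # a usable hit must cover every requested group
        blocks.append(next(iter(group_blocks.values())))
    return blocks


assert get_cached_block(111, [0, 1]) == [1, 5]
assert get_cached_block(222, [0]) is None
```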
req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) @@ -462,7 +464,7 @@ def test_computed_blocks_not_evicted(): len(computed_blocks.blocks[0]) * 16, computed_blocks) assert len(blocks.blocks[0]) == 1 - assert blocks.blocks[0][0].master_block_id == 2 + assert blocks.blocks[0][0].block_id == 2 # Free the blocks. manager.free(req0) @@ -473,14 +475,14 @@ def test_computed_blocks_not_evicted(): req2 = make_request("2", list(range(num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 1 - assert computed_blocks.blocks[0][0].master_block_id == 1 + assert computed_blocks.blocks[0][0].block_id == 1 assert num_computed_tokens == block_size blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) assert len(blocks.blocks[0]) == 1 - assert blocks.blocks[0][0].master_block_id == 2 + assert blocks.blocks[0][0].block_id == 2 def test_basic_prefix_caching_disabled(): @@ -538,7 +540,6 @@ def test_cache_blocks(hash_fn): block_pool = BlockPool( num_gpu_blocks=5, enable_caching=True, - num_single_type_managers=1, ) # Req: # Block 0: [0, 1, 2, 3] @@ -548,11 +549,8 @@ def test_cache_blocks(hash_fn): req = make_request("0", list(range(14))) # Test that blocks are cached correctly for 2 full blocks from the start. - blocks = [ - KVCacheBlockBundle(blocks=(KVCacheBlock(block_id=i), )) - for i in range(2) - ] - block_hashes: list[BlockHashType] = [] + blocks = [KVCacheBlock(block_id=i) for i in range(2)] + block_hashes: list[BlockHash] = [] block_pool.cache_full_blocks( request=req, @@ -562,14 +560,14 @@ def test_cache_blocks(hash_fn): num_full_blocks=2, block_size=block_size, hash_fn=hash_fn, - manager_id=0, + kv_cache_group_id=0, ) - assert len(block_pool.cached_block_hash_to_block[0]) == 2 + assert len(block_pool.cached_block_hash_to_block) == 2 assert all([block.block_hash is not None for block in blocks]) # Test that blocks that don't start from the beginning are cached correctly. - blocks += [KVCacheBlockBundle(blocks=(KVCacheBlock(block_id=2), ))] + blocks += [KVCacheBlock(block_id=2)] block_pool.cache_full_blocks( request=req, blocks=blocks, @@ -578,9 +576,9 @@ def test_cache_blocks(hash_fn): num_full_blocks=3, block_size=block_size, hash_fn=hash_fn, - manager_id=0, + kv_cache_group_id=0, ) - assert len(block_pool.cached_block_hash_to_block[0]) == 3 + assert len(block_pool.cached_block_hash_to_block) == 3 assert blocks[0].block_hash is not None @@ -588,7 +586,6 @@ def test_mm_prefix_caching(): """ This tests that the multi-modal prefix caching is correct. """ - block_size = 16 manager = KVCacheManager( make_kv_cache_config(16, 11), max_model_len=8192, @@ -623,7 +620,7 @@ def test_mm_prefix_caching(): # Completed block should have hashes with extra keys. assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req0.request_id][block_size] + block_hashes = manager.req_to_block_hashes[req0.request_id] assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("aaa", ) assert block_hashes[1].extra_keys == ("aaa", "bbb") @@ -684,7 +681,7 @@ def test_cache_key_salting(): # Completed block should have hashes with extra keys. 
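The `extra_keys` assertions that follow (multi-modal identifiers and, later, cache salts) come from the block-hash chaining scheme: each block hash folds in the parent hash, the block's tokens, and any extra keys, so requests that differ only in those keys cannot share cache entries. A standalone sketch under that assumption, with a simplified `BlockHash` and a hypothetical `hash_block` helper rather than the real `hash_block_tokens` signature:

```python
from typing import Any, NamedTuple, Optional


class BlockHash(NamedTuple):
    hash_value: int
    token_ids: tuple[int, ...]
    extra_keys: Optional[tuple[Any, ...]] = None


def hash_block(parent: Optional[BlockHash], token_ids: tuple[int, ...],
               extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash:
    parent_value = parent.hash_value if parent is not None else 0
    return BlockHash(hash((parent_value, token_ids, extra_keys)), token_ids,
                     extra_keys)


salted = hash_block(None, (0, 1, 2, 3), extra_keys=("salt1", ))
unsalted = hash_block(None, (0, 1, 2, 3))
assert salted != unsalted  # different extra keys -> different cache entries

# Later blocks chain off the parent hash, so the salt only needs to be
# attached to the first block for the whole prefix to be distinguished.
second = hash_block(salted, (4, 5, 6, 7))
assert second.extra_keys is None
```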
assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req0.request_id][block_size] + block_hashes = manager.req_to_block_hashes[req0.request_id] assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("salt1", ) assert block_hashes[1].extra_keys is None @@ -722,7 +719,7 @@ def test_cache_key_salting(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 0 assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req2.request_id][block_size] + block_hashes = manager.req_to_block_hashes[req2.request_id] assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("salt2", ) @@ -798,7 +795,6 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): def test_reset_prefix_cache(): - block_size = 16 manager = KVCacheManager( make_kv_cache_config(16, 11), max_model_len=8192, @@ -816,7 +812,7 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req1 = make_request("1", all_token_ids) computed_blocks, _ = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id][block_size]) == 3 + assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(computed_blocks.blocks[0]) == 3 blocks = manager.allocate_slots(req1, 7, len(computed_blocks.blocks[0]) * 16, @@ -832,7 +828,7 @@ def test_reset_prefix_cache(): manager.free(req1) assert manager.reset_prefix_cache() - assert not manager.block_pool.cached_block_hash_to_block[0] + assert not manager.block_pool.cached_block_hash_to_block assert all([blk.block_hash is None for blk in manager.block_pool.blocks]) @@ -883,7 +879,7 @@ def test_kv_cache_events(blocks_to_cache: int): block = events[-1] assert (len(block.block_hashes) == blocks_to_cache == len( - manager.block_pool.cached_block_hash_to_block[0])) + manager.block_pool.cached_block_hash_to_block)) assert len(block.token_ids) == block.block_size * len(block.block_hashes) assert len(manager.block_pool.kv_event_queue) == 0 @@ -902,7 +898,7 @@ def test_kv_cache_events(blocks_to_cache: int): assert len(events) == blocks_to_cache + 1 assert (isinstance(events[-2], BlockRemoved)) assert (len(events[-1].block_hashes) == blocks_to_cache == len( - manager.block_pool.cached_block_hash_to_block[0])) + manager.block_pool.cached_block_hash_to_block)) # All Blocks Cleared # Should see a single all blocks cleared event @@ -911,7 +907,7 @@ def test_kv_cache_events(blocks_to_cache: int): events = manager.take_events() assert isinstance(events[-1], AllBlocksCleared) - assert len(manager.block_pool.cached_block_hash_to_block[0]) == 0 + assert len(manager.block_pool.cached_block_hash_to_block) == 0 def test_eagle_enabled_removes_last_block(): @@ -989,7 +985,7 @@ def test_eagle_with_sliding_window(): manager = KVCacheManager( KVCacheConfig( num_blocks=10, - tensors={}, + kv_cache_tensors=[], kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)], ), max_model_len=8192, @@ -1007,8 +1003,7 @@ def test_eagle_with_sliding_window(): len(computed_blocks.blocks[0]) * 16, computed_blocks) # record the block hash of the first block in the request for later use - block_hash_first_block = manager.req_to_block_hashes[ - req.request_id][block_size][0] + block_hash_first_block = manager.req_to_block_hashes[req.request_id][0] assert block_hash_first_block is not None manager.free(req) @@ -1020,10 +1015,10 @@ def test_eagle_with_sliding_window(): assert num_tokens == 1 * block_size # 
Evict the first block in the request - assert manager.block_pool.get_cached_block(block_hash_first_block, - manager_id=0) is not None - manager.block_pool.cached_block_hash_to_block[0].pop( - block_hash_first_block) + assert manager.block_pool.get_cached_block( + block_hash_first_block, kv_cache_group_ids=[0]) is not None + manager.block_pool.cached_block_hash_to_block.pop( + BlockHashWithGroupId(block_hash_first_block, 0)) # New request req_after_evict = make_request("partial_eagle_after_evict", token_ids) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 9d8ba51ee84b..1870e46e36c7 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -96,7 +96,7 @@ def create_scheduler( ) kv_cache_config = KVCacheConfig( num_blocks=num_blocks, # A large number of blocks to hold all requests - tensors={}, + kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, @@ -818,7 +818,7 @@ def _assert_right_kv_cache_manager( assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0]. num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS) assert len(blocks) == EXPECTED_TOTAL_BLOCKS - assert len(hashes[block_size]) == EXPECTED_TOTAL_BLOCKS + assert len(hashes) == EXPECTED_TOTAL_BLOCKS # Make sure we actually touched all the blocks. BLOCKS_PER_REQ = num_tokens / block_size diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index 82d1f3a7da72..bb5f196204b0 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -3,8 +3,8 @@ import torch from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, - KVCacheBlockBundle) +from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, + KVCacheBlock) from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager from vllm.v1.kv_cache_interface import SlidingWindowSpec @@ -12,10 +12,8 @@ def get_sliding_window_manager(sliding_window_spec, block_pool): return SlidingWindowManager(sliding_window_spec, block_pool, - use_eagle=False, - num_kv_cache_groups=1, caching_hash_fn=lambda x: x, - manager_id=0) + kv_cache_group_id=0) def test_sliding_window_possible_cached_prefix(): @@ -29,9 +27,7 @@ def test_sliding_window_possible_cached_prefix(): use_mla=False, ) - block_pool = BlockPool(num_gpu_blocks=100, - enable_caching=True, - num_single_type_managers=1) + block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) manager = get_sliding_window_manager(sliding_window_spec, block_pool) def run_one_case(block_is_cached, expect_length): @@ -39,28 +35,33 @@ def run_one_case(block_is_cached, expect_length): BlockHash(i, ()) for i in range(len(block_is_cached)) ] - block_pool.cached_block_hash_to_block[0].clear() + block_pool.cached_block_hash_to_block.clear() # Mock the block pool with the cached blocks for i, (block_hash, is_cached) in enumerate(zip(block_hash_list, block_is_cached)): if is_cached: - block_pool.cached_block_hash_to_block[0][block_hash] = { - i: KVCacheBlockBundle(blocks=(block_pool.blocks[i + 10], )) - } + block_pool.cached_block_hash_to_block[BlockHashWithGroupId( + block_hash, 0)] = { + i: block_pool.blocks[i + 10], + } computed_blocks = manager.find_longest_cache_hit( - block_hash_list, - len(block_hash_list) * block_size) + block_hashes=block_hash_list, + max_length=len(block_hash_list) * block_size, + kv_cache_group_ids=[0], + block_pool=block_pool, + 
kv_cache_spec=sliding_window_spec, + use_eagle=False)[0] assert len(computed_blocks) == expect_length - assert all(block == manager.null_block + assert all(block == block_pool.null_block for block in computed_blocks[:expect_length - 2]) for i in range(2): if i < expect_length: block_index = expect_length - i - 1 assert computed_blocks[ - block_index].master_block_id == block_index + 10 + block_index].block_id == block_index + 10 run_one_case([False] * 10, 0) run_one_case([True], 1) @@ -92,26 +93,24 @@ def test_sliding_window_remove_skipped_blocks(): use_mla=False, ) - block_pool = BlockPool(num_gpu_blocks=2000, - enable_caching=True, - num_single_type_managers=1) + block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) manager = get_sliding_window_manager(sliding_window_spec, block_pool) null_block_id = block_pool.null_block.block_id - def id_to_block_table(ids) -> list[KVCacheBlockBundle]: + def id_to_block_table(ids) -> list[KVCacheBlock]: return [ - KVCacheBlockBundle(blocks=(KVCacheBlock(id_), )) - if id_ != null_block_id else manager.null_block for id_ in ids + KVCacheBlock(id_) + if id_ != null_block_id else block_pool.null_block for id_ in ids ] - def assert_block_id(block_table: list[KVCacheBlockBundle], ids: list[int]): + def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): for block, id_ in zip(block_table, ids): if id_ == null_block_id: - assert block == manager.null_block + assert block == block_pool.null_block else: - assert block.master_block_id == id_ + assert block.block_id == id_ original_block_ids = [ 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010 diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 1f5d96fd5550..52c5ce426487 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -271,7 +271,7 @@ def find_longest_cache_hit( computed_blocks[j].append(cached_block[j]) else: break - if use_eagle and len(computed_blocks) > 0: + if use_eagle and len(computed_blocks[0]) > 0: for j in range(len(kv_cache_group_ids)): computed_blocks[j].pop() return computed_blocks @@ -357,7 +357,7 @@ def find_longest_cache_hit( # `num_contiguous_blocks < sliding_window_contiguous_blocks`. 
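For reference, here is a self-contained sketch of the search strategy exercised here (booleans in place of block hashes, `-1` for the null block); the truncation and eagle handling continue in the diff below. This illustrates the idea under simplified assumptions and is not the exact `SlidingWindowManager` code:

```python
NULL_BLOCK = -1  # stands in for block_pool.null_block


def sliding_window_cache_hit(cached: list[bool],
                             sliding_window_contiguous_blocks: int) -> list[int]:
    computed: list[int] = [NULL_BLOCK] * len(cached)
    num_contiguous = 0
    for i in range(len(cached) - 1, -1, -1):
        if cached[i]:
            computed[i] = i  # keep the real block at this position
            num_contiguous += 1
            if num_contiguous >= sliding_window_contiguous_blocks:
                # The hit ends at the last block of this contiguous run.
                return computed[:i + num_contiguous]
        else:
            num_contiguous = 0
    # No long-enough run: only a run starting at block 0 is usable, because
    # the first tokens of a request need fewer blocks inside the window.
    return computed[:num_contiguous]


# Blocks 0-1 cached, block 2 missing, blocks 3-4 cached; the window spans two
# blocks, so the hit covers blocks 0-4 with null blocks outside the window.
assert sliding_window_cache_hit([True, True, False, True, True], 2) == \
    [NULL_BLOCK, NULL_BLOCK, NULL_BLOCK, 3, 4]
assert sliding_window_cache_hit([True, False, False, True], 2) == [0]
```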
for j in range(len(kv_cache_group_ids)): del computed_blocks[j][num_contiguous_blocks:] - if use_eagle and len(computed_blocks) > 0: + if use_eagle and len(computed_blocks[0]) > 0: for j in range(len(kv_cache_group_ids)): computed_blocks[j].pop() return computed_blocks From 08e0888492256f2843f9f68fea973adb523940ef Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 3 Jun 2025 08:23:08 -0700 Subject: [PATCH 35/44] revert previous change in tests/v1/core Signed-off-by: Chen Zhang --- tests/v1/core/test_kv_cache_utils.py | 7 +++++++ tests/v1/core/test_prefix_caching.py | 24 ++++++++++++------------ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 7e35bd949eb7..4bdf8e47c671 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -75,8 +75,15 @@ def test_kv_cache_block(): # Test KVCacheBlock initialization block = KVCacheBlock(block_id=0) assert block.block_id == 0 + assert block.ref_cnt == 0 assert block.block_hash is None + # Test reference count manipulation + block.incr_ref() + assert block.ref_cnt == 1 + block.decr_ref() + assert block.ref_cnt == 0 + # Test block hash setting and resetting block_hash = BlockHash(hash_value=123, token_ids=(1, 2, 3)) block.block_hash = block_hash diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index ee3882bff477..c5a6e1806441 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -59,9 +59,8 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: @pytest.mark.parametrize("hash_algo", ["sha256", "hash"]) def test_prefill(hash_algo): - block_size = 16 manager = KVCacheManager( - make_kv_cache_config(block_size, 11), + make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, caching_hash_algo=hash_algo, @@ -95,14 +94,13 @@ def test_prefill(hash_algo): block_tokens) assert manager.block_pool.blocks[ block_id].block_hash.block_hash == block_hash + assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value # Check partial block metadata for block_id in (4, ): assert manager.block_pool.blocks[block_id].block_hash is None - - for block in blocks.blocks[0]: - assert block.ref_cnt == 1 + assert manager.block_pool.blocks[block_id].ref_cnt == 1 # Cache hit in the common prefix when the original block is still in use. # Incomplete 1 block (5 tokens) @@ -155,6 +153,10 @@ def test_prefill(hash_algo): # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. assert manager.block_pool.free_block_queue.num_free_blocks == 6 + assert all([ + b.ref_cnt == 0 + for b in manager.block_pool.free_block_queue.get_all_free_blocks() + ]) assert len([ b for b in manager.block_pool.free_block_queue.get_all_free_blocks() ]) == 6 @@ -183,9 +185,8 @@ def test_prefill_plp(): 2. Schedule non-plp request and validate blocks 3. 
Schedule plp request; no hit should occur; validate blocks ''' - block_size = 16 manager = KVCacheManager( - make_kv_cache_config(block_size, 11), + make_kv_cache_config(16, 11), max_model_len=8192, enable_caching=True, ) @@ -219,14 +220,13 @@ def test_prefill_plp(): block_tokens) assert manager.block_pool.blocks[ block_id].block_hash.block_hash == block_hash + assert manager.block_pool.blocks[block_id].ref_cnt == 1 parent_block_hash = block_hash.hash_value # Check partial block metadata for block_id in (4, ): assert manager.block_pool.blocks[block_id].block_hash is None - - for block in blocks.blocks[0]: - assert block.ref_cnt == 1 + assert manager.block_pool.blocks[block_id].ref_cnt == 1 # Request #1 is a non-prompt-logprobs request: # Cache hit in the common prefix when the original block is still in use. @@ -283,8 +283,8 @@ def test_prefill_plp(): # Request #2 block hashes are valid since request #0 hashes are. # Check block reference counts. - for block in blocks.blocks[0]: - assert block.ref_cnt == 1 + for block_id in block_ids[0]: + assert manager.block_pool.blocks[block_id].ref_cnt == 1 manager.free(req2) From 2140dc6111a59b0274ecc67eb9203c521c1d8d6c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 3 Jun 2025 08:26:36 -0700 Subject: [PATCH 36/44] update worker test Signed-off-by: Chen Zhang --- tests/v1/worker/test_gpu_input_batch.py | 6 +++--- tests/v1/worker/test_gpu_model_runner.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 638f5bedcfca..964786b3af51 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -27,9 +27,9 @@ def get_kv_cache_config() -> KVCacheConfig: return KVCacheConfig( num_blocks=10, - tensors={ - "layer.0": KVCacheTensor(size=1024), - }, + kv_cache_tensors=[ + KVCacheTensor(size=1024, shared_by=["layer.0"]), + ], kv_cache_groups=[ KVCacheGroupSpec( layer_names=["layer.0"], diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index e44660525763..40aee7f0ee2f 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -20,9 +20,9 @@ def initialize_kv_cache(runner: GPUModelRunner): """ kv_cache_config = KVCacheConfig( num_blocks=10, - tensors={ - "layer.0": KVCacheTensor(size=1024), - }, + kv_cache_tensors=[ + KVCacheTensor(size=1024, shared_by=["layer.0"]), + ], kv_cache_groups=[ KVCacheGroupSpec( layer_names=["layer.0"], From b63d8eaeefe180f778f05d8123c606f7adb13d55 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 3 Jun 2025 08:38:11 -0700 Subject: [PATCH 37/44] test_cache_blocks_multi_group Signed-off-by: Chen Zhang --- tests/v1/core/test_prefix_caching.py | 71 ++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index c5a6e1806441..8130b0eb3d16 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -582,6 +582,77 @@ def test_cache_blocks(hash_fn): assert blocks[0].block_hash is not None +def test_cache_blocks_multi_group(): + """ + This tests that blocks are cached correctly for different kv cache groups. 
+ """ + block_size = 4 + block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True) + + # Req: + # Block 0/4: [0, 1, 2, 3] + # Block 1/5: [4, 5, 6, 7] + # Block 2/6: [8, 9, 10, 11] + # Block 3/7: [12, 13] + req = make_request("0", list(range(14))) + + # Cache the blocks for group 0. + blocks = [KVCacheBlock(block_id=i) for i in range(2)] + block_hashes: list[BlockHash] = [] + block_pool.cache_full_blocks( + request=req, + blocks=blocks, + block_hashes=block_hashes, + num_cached_blocks=0, + num_full_blocks=2, + block_size=block_size, + hash_fn=hash, + kv_cache_group_id=0, + ) + assert len(block_pool.cached_block_hash_to_block) == 2 + assert len(block_hashes) == 2 + assert all([block.block_hash is not None for block in blocks]) + + # Cache the blocks for group 1. + blocks = [KVCacheBlock(block_id=i) for i in range(3)] + block_pool.cache_full_blocks( + request=req, + blocks=blocks, + block_hashes=block_hashes, + num_cached_blocks=0, + num_full_blocks=3, + block_size=block_size, + hash_fn=hash, + kv_cache_group_id=1, + ) + assert len(block_pool.cached_block_hash_to_block) == 5 + assert len(block_hashes) == 3 + assert all([block.block_hash is not None for block in blocks]) + + # Block hash 0: hit for group 0 and 1 + # Block hash 1: hit for group 0 and 1 + # Block hash 2: hit for group 1 + + assert block_pool.get_cached_block(block_hashes[0], + kv_cache_group_ids=[0]) is not None + assert block_pool.get_cached_block(block_hashes[1], + kv_cache_group_ids=[0]) is not None + assert block_pool.get_cached_block(block_hashes[2], + kv_cache_group_ids=[0]) is None + assert block_pool.get_cached_block(block_hashes[0], + kv_cache_group_ids=[1]) is not None + assert block_pool.get_cached_block(block_hashes[1], + kv_cache_group_ids=[1]) is not None + assert block_pool.get_cached_block(block_hashes[2], + kv_cache_group_ids=[1]) is not None + assert block_pool.get_cached_block(block_hashes[0], + kv_cache_group_ids=[0, 1]) is not None + assert block_pool.get_cached_block(block_hashes[1], + kv_cache_group_ids=[0, 1]) is not None + assert block_pool.get_cached_block(block_hashes[2], + kv_cache_group_ids=[0, 1]) is None + + def test_mm_prefix_caching(): """ This tests that the multi-modal prefix caching is correct. 
From b64b8b1176e1d2e3db4968d2ae43abd6118c98ef Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 3 Jun 2025 09:24:45 -0700 Subject: [PATCH 38/44] test_prefill_hybrid_model Signed-off-by: Chen Zhang --- tests/v1/core/test_prefix_caching.py | 165 +++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 8130b0eb3d16..a3f05422121b 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Compare the with and without prefix caching.""" +import copy from typing import Optional import pytest @@ -57,6 +58,38 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: ) +def make_kv_cache_config_hybrid_model(block_size: int, + num_blocks: int) -> KVCacheConfig: + return KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer1"], + FullAttentionSpec(block_size, 1, 1, torch.float32, False), + ), + KVCacheGroupSpec( + ["layer2"], + SlidingWindowSpec(block_size, + 1, + 1, + torch.float32, + False, + sliding_window=2 * block_size), + ), + KVCacheGroupSpec( + ["layer3"], + SlidingWindowSpec(block_size, + 1, + 1, + torch.float32, + False, + sliding_window=2 * block_size), + ), + ], + ) + + @pytest.mark.parametrize("hash_algo", ["sha256", "hash"]) def test_prefill(hash_algo): manager = KVCacheManager( @@ -178,6 +211,138 @@ def test_prefill(hash_algo): assert manager.block_pool.free_block_queue.free_list_tail is None +def test_prefill_hybrid_model(): + block_size = 16 + manager = KVCacheManager( + make_kv_cache_config_hybrid_model(block_size, 21), + max_model_len=8192, + enable_caching=True, + ) + + hash_fn = hash + + # Complete 3 blocks (48 tokens) + common_token_ids = [i for i in range(3) for _ in range(block_size)] + + # Fully cache miss + # Incomplete 1 block (7 tokens) + unique_token_ids = [3] * 7 + all_token_ids = common_token_ids + unique_token_ids + req0 = make_request("0", all_token_ids) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + assert len(manager.req_to_block_hashes[req0.request_id]) == 3 + assert not computed_blocks.blocks[0] + assert num_computed_tokens == 0 + blocks = manager.allocate_slots(req0, 55, + len(computed_blocks.blocks[0]) * 16, + computed_blocks) + assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8], + [9, 10, 11, 12]] + + # Check full block metadata + parent_block_hash = None + for length, block_ids in zip((1, 2, 3), + ((1, 5, 9), (2, 6, 10), (3, 7, 11))): + block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16]) + block_hash = hash_block_tokens(hash_fn, parent_block_hash, + block_tokens) + for block_id in block_ids: + assert manager.block_pool.blocks[ + block_id].block_hash.block_hash == block_hash + assert manager.block_pool.blocks[block_id].ref_cnt == 1 + parent_block_hash = block_hash.hash_value + + # Check partial block metadata + for block_id in (4, 8, 12): + assert manager.block_pool.blocks[block_id].block_hash is None + assert manager.block_pool.blocks[block_id].ref_cnt == 1 + + # Cache hit in the common prefix + # Incomplete 1 block (5 tokens) + unique_token_ids = [3] * 5 + req1 = make_request("1", common_token_ids + unique_token_ids) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7], + [0, 10, 11]] + 
assert num_computed_tokens == 3 * 16 + num_new_tokens = 53 - 3 * 16 + blocks = manager.allocate_slots(req1, num_new_tokens, + len(computed_blocks.blocks[0]) * 16, + computed_blocks) + assert blocks.get_block_ids() == [[13], [14], [15]] + for block_per_group in computed_blocks.blocks: + for block in block_per_group: + if block != manager.block_pool.null_block: + assert block.ref_cnt == 2 + + block_hashes = manager.req_to_block_hashes[req1.request_id] + manager.free(req0) + manager.free(req1) + + cached_block_hash_to_block_bak = copy.copy( + manager.block_pool.cached_block_hash_to_block) + + def test_partial_request_hit(request_id: str, + hash_to_evict: list[BlockHashWithGroupId], + expect_hit_length: int): + req = make_request(request_id, common_token_ids + unique_token_ids) + for hash_with_group_id in hash_to_evict: + manager.block_pool.cached_block_hash_to_block.pop( + hash_with_group_id) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) + assert len(manager.req_to_block_hashes[req.request_id]) == 3 + assert num_computed_tokens == expect_hit_length * block_size + for block_per_group in computed_blocks.blocks: + assert len(block_per_group) == num_computed_tokens // block_size + for hash_with_group_id in hash_to_evict: + manager.block_pool.cached_block_hash_to_block[ + hash_with_group_id] = cached_block_hash_to_block_bak[ + hash_with_group_id] + manager.free(req) + + # Evict the blocks outside sliding window, does not affect the hit length. + test_partial_request_hit("2", [ + BlockHashWithGroupId(block_hashes[0], 1), + BlockHashWithGroupId(block_hashes[0], 2) + ], 3) + + # Evict the first block of full attention, makes total cache miss. + test_partial_request_hit("3", [ + BlockHashWithGroupId(block_hashes[0], 0), + ], 0) + + # Evict the last block of all layers, reduces the hit length to 2. + test_partial_request_hit("4", [ + BlockHashWithGroupId(block_hashes[2], 0), + BlockHashWithGroupId(block_hashes[2], 1), + BlockHashWithGroupId(block_hashes[2], 2), + ], 2) + + # Evict the last block of full attention, reduces the hit length to 2. + test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)], + 2) + + # Evict the last block of sliding window, reduces the hit length to 2. + test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)], + 2) + + # Evict the last block of sliding window, reduces the hit length to 2. + test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)], + 2) + + # Evict different set of blocks for full attention and sliding window makes + # total cache miss. + # The cache hit length of full attention is 1 * block_size. + # The cache hit length of sliding window is 2 * block_size. + # Then it is cache miss as the two type of layers have different hit length. + test_partial_request_hit("8", [ + BlockHashWithGroupId(block_hashes[2], 0), + BlockHashWithGroupId(block_hashes[0], 1), + BlockHashWithGroupId(block_hashes[0], 2), + ], 0) + + def test_prefill_plp(): '''Test prefill with APC and some prompt logprobs (plp) requests. 
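Before the follow-up fixes below, it is worth spelling out the invariant these eviction scenarios pin down: the usable prefix-cache hit is bounded by the full-attention groups first, the sliding-window groups are then searched within that bound, and any disagreement collapses the hit. The helper below is a simplified behavioural model of my own (not the coordinator's real code), with booleans in place of cached block hashes:

```python
def hybrid_hit_blocks(full_cached: list[bool], sliding_cached: list[bool],
                      window_blocks: int) -> int:
    # Full attention needs a contiguously cached prefix.
    full_hit = 0
    for is_cached in full_cached:
        if not is_cached:
            break
        full_hit += 1
    # A sliding-window group needs the trailing `window_blocks` blocks of the
    # hit (or everything, for shorter hits) to be cached; search within the
    # full-attention bound for the longest prefix it can also serve.
    for hit in range(full_hit, 0, -1):
        needed = range(max(hit - window_blocks, 0), hit)
        if all(sliding_cached[i] for i in needed):
            return hit
    return 0


# Dropping a sliding-window block outside the window keeps the 3-block hit.
assert hybrid_hit_blocks([True] * 3, [False, True, True], window_blocks=2) == 3
# Dropping the first full-attention block kills the hit entirely.
assert hybrid_hit_blocks([False, True, True], [True] * 3, window_blocks=2) == 0
# Disagreement between the two layer types leaves no usable common prefix.
assert hybrid_hit_blocks([True, True, False], [False, True, True],
                         window_blocks=2) == 0
```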
From 5fb5e490d99280f0e37ef7499661335af3bd8d66 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 3 Jun 2025 09:26:41 -0700 Subject: [PATCH 39/44] revert test_scheduler Signed-off-by: Chen Zhang --- tests/v1/core/test_scheduler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 1870e46e36c7..5e2aea64dcda 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1206,11 +1206,10 @@ def assert_scheduler_empty(scheduler: Scheduler): assert num_free_blocks == ( scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) - # TODO(Chen): find a way to test no leak on ref_cnt. # NOTE(rob): just the ref count on blocks will be 0. The hash # value, etc will remain since we lazily evict for prefix cache. - # for block in scheduler.kv_cache_manager.block_pool.blocks: - # assert block.ref_cnt == 0 + for block in scheduler.kv_cache_manager.block_pool.blocks: + assert block.ref_cnt == 0 # assert block._block_hash is None # assert ( # len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block From 13b486a1be16dfd69d4d3c48cfbd383e332c7b58 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 3 Jun 2025 09:29:38 -0700 Subject: [PATCH 40/44] revert test_manager Signed-off-by: Chen Zhang --- tests/v1/core/test_specialized_manager.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index bb5f196204b0..6a82cabed7b6 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -116,8 +116,6 @@ def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010 ] block_table = id_to_block_table(original_block_ids) - for block in block_table: - block.incr_ref() manager.req_to_blocks["test"] = block_table manager.remove_skipped_blocks("test", 0) @@ -150,5 +148,3 @@ def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): # of removed blocks should be [1003, 1002]. 
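The `remove_skipped_blocks` call that follows verifies that blocks which have slid entirely out of the window are swapped for the shared null block so they can be freed and reused. A rough sketch of that behaviour with plain ints and a hypothetical helper; the exact boundary arithmetic of the real manager may differ:

```python
NULL_BLOCK = -1


def remove_skipped_blocks(block_table: list[int], num_computed_tokens: int,
                          sliding_window: int, block_size: int) -> list[int]:
    """Return the block ids freed; mutates block_table in place."""
    # Tokens before this position can no longer be attended to.
    first_useful_token = max(num_computed_tokens - sliding_window + 1, 0)
    first_useful_block = first_useful_token // block_size
    removed = []
    for i in range(first_useful_block):
        if block_table[i] != NULL_BLOCK:
            removed.append(block_table[i])
            block_table[i] = NULL_BLOCK
    return removed


table = [1000, 1001, 1002, 1003]
# Window of 8 tokens, block size 4: after 12 computed tokens only block 0
# (tokens 0-3) is entirely outside the window and can be handed back.
freed = remove_skipped_blocks(table, num_computed_tokens=12,
                              sliding_window=8, block_size=4)
assert freed == [1000]
assert table == [NULL_BLOCK, 1001, 1002, 1003]
```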
manager.remove_skipped_blocks("test", 11) assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:]) - - manager.free("test") From b598c0eadc1279e4df51be362cb7d420bbe394fb Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 3 Jun 2025 09:54:46 -0700 Subject: [PATCH 41/44] test_get_kv_cache_config Signed-off-by: Chen Zhang --- tests/v1/core/test_kv_cache_utils.py | 176 +++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 4bdf8e47c671..b6fb8f07fbd5 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -15,6 +15,7 @@ PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, + get_kv_cache_config, hash_block_tokens, hash_request_tokens, unify_kv_cache_configs) @@ -65,6 +66,20 @@ def new_kv_cache_spec(block_size=16, sliding_window=sliding_window) +def new_sliding_window_spec(block_size=16, + num_kv_heads=2, + head_size=64, + dtype=torch.float32, + use_mla=False, + sliding_window=1): + return SlidingWindowSpec(block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + use_mla=use_mla, + sliding_window=sliding_window) + + def test_none_hash(): assert NONE_HASH is not None assert isinstance(NONE_HASH, int) @@ -630,3 +645,164 @@ def test_allocate_with_lookahead(): num_lookahead_tokens=4, ) assert len(blocks.get_block_ids()[0]) == 2 + + +def test_get_kv_cache_config(): + # pass max_model_len to pass check_enough_kv_cache_memory + model_config = ModelConfig(max_model_len=16) + vllm_config = VllmConfig(model_config=model_config) + + mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2 + # all layers are full attention -> single group + kv_cache_specs_full = { + 'layer_1': new_kv_cache_spec(), + 'layer_2': new_kv_cache_spec(), + } + kv_cache_config_full = get_kv_cache_config( + vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32) + assert kv_cache_config_full == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) + ]) + + # all layers are sliding window -> single group + kv_cache_specs_sliding = { + 'layer_1': new_sliding_window_spec(), + 'layer_2': new_sliding_window_spec(), + } + kv_cache_config_sliding = get_kv_cache_config( + vllm_config, kv_cache_specs_sliding, mem_per_block_per_layer * 2 * 32) + assert kv_cache_config_sliding == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2"], new_sliding_window_spec()) + ]) + + # full + sliding, but disable_hybrid_kv_cache_manager + vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = True + kv_cache_specs_hybrid = { + 'layer_1': new_kv_cache_spec(), + 'layer_2': new_sliding_window_spec(), + } + kv_cache_config_hybrid = get_kv_cache_config( + vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_2"]), + ], + 
kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2"], + new_kv_cache_spec(sliding_window=1)), + ], + ) + vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False + + # full + sliding, with hybrid_kv_cache_manager + kv_cache_specs_hybrid = { + 'layer_1': new_kv_cache_spec(), + 'layer_2': new_sliding_window_spec(), + } + kv_cache_config_hybrid = get_kv_cache_config( + vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=64, + kv_cache_tensors=[ + KVCacheTensor(size=mem_per_block_per_layer * 64, + shared_by=["layer_1", "layer_2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer_2"], new_sliding_window_spec()), + ], + ) + + # 2 full + 4 sliding, 2 layers per group + kv_cache_specs_hybrid = { + 'layer_1': new_kv_cache_spec(), + 'layer_2': new_kv_cache_spec(), + 'layer_3': new_sliding_window_spec(), + 'layer_4': new_sliding_window_spec(), + 'layer_5': new_sliding_window_spec(), + 'layer_6': new_sliding_window_spec(), + } + kv_cache_config_hybrid = get_kv_cache_config( + vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_1", "layer_3", "layer_5"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_2", "layer_4", "layer_6"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer_3", "layer_4"], + new_sliding_window_spec()), + KVCacheGroupSpec(["layer_5", "layer_6"], + new_sliding_window_spec()), + ], + ) + + # 3 full + 7 sliding, pad to 3 full + 9 sliding + kv_cache_specs_hybrid = { + 'layer_1': new_kv_cache_spec(), + 'layer_2': new_kv_cache_spec(), + 'layer_3': new_kv_cache_spec(), + 'layer_4': new_sliding_window_spec(), + 'layer_5': new_sliding_window_spec(), + 'layer_6': new_sliding_window_spec(), + 'layer_7': new_sliding_window_spec(), + 'layer_8': new_sliding_window_spec(), + 'layer_9': new_sliding_window_spec(), + 'layer_10': new_sliding_window_spec(), + } + kv_cache_config_hybrid = get_kv_cache_config( + vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 3 * 32) + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor( + size=mem_per_block_per_layer * 32, + shared_by=["layer_1", "layer_4", "layer_7", "layer_10"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_2", "layer_5", "layer_8"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_3", "layer_6", "layer_9"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"], + new_kv_cache_spec()), + KVCacheGroupSpec(["layer_4", "layer_5", "layer_6"], + new_sliding_window_spec()), + KVCacheGroupSpec(["layer_7", "layer_8", "layer_9"], + new_sliding_window_spec()), + KVCacheGroupSpec(["layer_10"], new_sliding_window_spec()), + ], + ) + + # different hidden size, unimplemented + kv_cache_specs_hybrid = { + 'layer_1': new_kv_cache_spec(head_size=128), + 'layer_2': new_kv_cache_spec(), + } + with pytest.raises(NotImplementedError): + get_kv_cache_config(vllm_config, kv_cache_specs_hybrid, + mem_per_block_per_layer * 2 * 32) From 85798c5d948cf9db15e9245c134c4580ba79997d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 3 Jun 2025 20:39:27 -0700 Subject: [PATCH 42/44] fix ci Signed-off-by: Chen Zhang --- 
tests/v1/core/test_kv_cache_utils.py | 6 +++--- tests/v1/worker/test_gpu_model_runner.py | 18 +++++++++--------- vllm/v1/worker/gpu_model_runner.py | 3 ++- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index c2ea7d653b3d..ab7aa02823ab 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -650,7 +650,7 @@ def test_get_max_concurrency_for_kv_cache_config(): kv_cache_config_full_attention = KVCacheConfig( num_blocks=int(1024 * 1.5), - tensors={}, + kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec([f"layer_{i}" for i in range(32)], full_attention_spec), @@ -662,7 +662,7 @@ def test_get_max_concurrency_for_kv_cache_config(): kv_cache_config_sliding_window = KVCacheConfig( num_blocks=129 * 3, - tensors={}, + kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec([f"layer_{i}" for i in range(32)], sliding_window_spec), @@ -674,7 +674,7 @@ def test_get_max_concurrency_for_kv_cache_config(): kv_cache_config_hybrid_model = KVCacheConfig( num_blocks=(1024 + 129) * 3, - tensors={}, + kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec([f"layer_{i}" for i in range(32)], full_attention_spec), diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index de9fc6698781..caacb1652e9a 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -519,9 +519,9 @@ def test_init_kv_cache_without_kv_sharing(): kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, available_memory) assert kv_cache_config.num_blocks == num_expected_blocks - assert len(kv_cache_config.tensors) == 2 - assert kv_cache_config.tensors[layer_0].size == available_memory // 2 - assert kv_cache_config.tensors[layer_1].size == available_memory // 2 + assert len(kv_cache_config.kv_cache_tensors) == 2 + assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2 + assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2 max_context_len =\ estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) @@ -531,9 +531,9 @@ def test_init_kv_cache_without_kv_sharing(): # important: override tensor size to prevent large mem alloc during test # this will only allocate 2 block worth of memory (2 * 32kb) kv_cache_config.num_blocks = 1 - for layer in kv_cache_config.tensors: - kv_cache_config.tensors[layer].size =\ - kv_cache_spec[layer].page_size_bytes + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + kv_cache_tensor.size = ( + kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes) runner.initialize_kv_cache(kv_cache_config) @@ -590,10 +590,10 @@ def test_init_kv_cache_with_kv_sharing_valid(): kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, available_memory) assert kv_cache_config.num_blocks == num_expected_blocks - assert len(kv_cache_config.tensors) == 1 + assert len(kv_cache_config.kv_cache_tensors) == 1 # Each layer now has twice the available memory for KV cache # compared to no KV sharing - assert kv_cache_config.tensors[layer_0].size == available_memory + assert kv_cache_config.kv_cache_tensors[0].size == available_memory max_context_len =\ estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) @@ -603,7 +603,7 @@ def test_init_kv_cache_with_kv_sharing_valid(): # important: override tensor size to prevent large mem alloc during test # this will only allocate 1 block worth of memory (32kb) kv_cache_config.num_blocks = 1 - 
kv_cache_config.tensors[layer_0].size =\ + kv_cache_config.kv_cache_tensors[0].size =\ kv_cache_spec[layer_0].page_size_bytes runner.initialize_kv_cache(kv_cache_config) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 65f3436ab4ae..c5b69dbd6ca8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2168,7 +2168,8 @@ def initialize_kv_cache_tensors( bind_kv_cache( kv_caches, - self.vllm_config.compilation_config.static_forward_context, []) + self.vllm_config.compilation_config.static_forward_context, + self.kv_caches) return kv_caches def may_reinitialize_input_batch(self, From b5fa8e1632c7a3e7a738cc608b0fc821daa32a19 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 3 Jun 2025 23:16:31 -0700 Subject: [PATCH 43/44] fix kv connector tests Signed-off-by: Chen Zhang --- tests/v1/core/test_scheduler.py | 2 +- .../kv_connector/unit/test_nixl_connector.py | 4 ++-- .../unit/test_remote_decode_lifecycle.py | 4 ++-- .../unit/test_remote_prefill_lifecycle.py | 24 +++++++++---------- tests/v1/kv_connector/unit/utils.py | 10 ++++---- vllm/v1/core/block_pool.py | 4 ++-- 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 66860a8b23fe..d348956aa177 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1212,7 +1212,7 @@ def assert_scheduler_empty(scheduler: Scheduler): # value, etc will remain since we lazily evict for prefix cache. for block in scheduler.kv_cache_manager.block_pool.blocks: assert block.ref_cnt == 0 - # assert block._block_hash is None + # assert block._block_hash is None # assert ( # len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block # ) == 0) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 9b257143d69d..622ab6f35db3 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -36,8 +36,8 @@ def test_basic_inferface(): req_meta = kv_connector_metadata.requests[request_id] for block_id, block in zip( - req_meta.local_block_ids, scheduler.kv_cache_manager. - single_type_manager.req_to_blocks[request_id]): + req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks[request_id]): assert block_id == block.block_id diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index 52dc21a2cdba..ff36a281c413 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -54,8 +54,8 @@ def test_basic_lifecycle(): assert len(scheduler.waiting) == 0 # ... but blocks should not be freed. 
- blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ - request_id] + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_id] for block in blocks: assert block.ref_cnt == 1 diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index 2312e2135908..a1156306dc4b 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -51,8 +51,8 @@ def test_basic_lifecycle(): assert (block_pool.free_block_queue.num_free_blocks < START_FREE_BLOCK_QUEUE_SIZE) assert len(block_pool.cached_block_hash_to_block) == 0 - blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ - request_id] + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_id] for block in blocks: assert block._block_hash is None @@ -87,8 +87,8 @@ def test_basic_lifecycle(): # Confirm the block are actually allocated. num_hashed_blocks = 0 - blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ - request_id] + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_id] for block in blocks: assert block.ref_cnt == 1 num_hashed_blocks += (1 if block._block_hash is not None else 0) @@ -261,10 +261,10 @@ def test_no_spurious_prefix_caching(): assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 - local_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ - request_local.request_id] - remote_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ # noqa: E501 - request_remote.request_id] + local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_local.request_id] + remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks[request_remote.request_id] # Local should have cached blocks (but not all due to preallocate). num_hashed_blocks = 0 @@ -300,8 +300,8 @@ def test_full_block_prompt(): # STEP (1): Initialize a recv. scheduler_output = scheduler.schedule() # All blocks should be allocated. - num_blocks = len(scheduler.kv_cache_manager.single_type_manager. - req_to_blocks[request_id]) + num_blocks = len(scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks[request_id]) assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT scheduler.update_from_output(scheduler_output, model_runner_output) @@ -319,8 +319,8 @@ def test_full_block_prompt(): # We need to recompute the final token of the prompt to generate # the first new token, so we should not have a new block. - num_blocks = len(scheduler.kv_cache_manager.single_type_manager. - req_to_blocks[request_id]) + num_blocks = len(scheduler.kv_cache_manager.coordinator. + single_type_managers[0].req_to_blocks[request_id]) assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens == NUM_TOKENS - 1) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index e190e956170d..4a9e3a7ad807 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -32,11 +32,11 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.encoder_cache_manager.cached) == 0 # KVCache Manager. 
- assert len( - scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0 + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. + req_to_blocks) == 0 assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 - assert len( - scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0 + assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. + num_cached_block) == 0 num_free_blocks = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) assert num_free_blocks == ( @@ -96,7 +96,7 @@ def create_scheduler( block_size = vllm_config.cache_config.block_size kv_cache_config = KVCacheConfig( num_blocks=num_blocks, # A large number of blocks to hold all requests - tensors={}, + kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec(['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 9aa64962da44..3b2a4f936000 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -85,8 +85,8 @@ def get_cached_block( """ cached_blocks = [] for group_id in kv_cache_group_ids: - cached_blocks_one_group = self.cached_block_hash_to_block[ - BlockHashWithGroupId(block_hash, group_id)] + cached_blocks_one_group = self.cached_block_hash_to_block.get( + BlockHashWithGroupId(block_hash, group_id)) if not cached_blocks_one_group: return None first_block_id = next(iter(cached_blocks_one_group)) From fa2f7bcb1d1bcc8dd650a34df5e3255427ba62d1 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 3 Jun 2025 23:54:28 -0700 Subject: [PATCH 44/44] small updates Signed-off-by: Chen Zhang --- vllm/config.py | 2 +- vllm/v1/core/kv_cache_coordinator.py | 10 ++-- vllm/v1/core/kv_cache_manager.py | 4 +- vllm/v1/core/kv_cache_utils.py | 31 +++------- vllm/v1/core/single_type_kv_cache_manager.py | 9 ++- vllm/v1/engine/core.py | 2 - vllm/v1/kv_cache_interface.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 59 ++++++++++---------- 8 files changed, 51 insertions(+), 68 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 3b4edda2c9e7..15e1b530dc9e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4472,7 +4472,7 @@ def __post_init__(self): if (envs.VLLM_USE_V1 and not self.scheduler_config.disable_hybrid_kv_cache_manager): # logger should only print warning message for hybrid models. As we - # can't know whether the model is hybrid or not now, we don't log + # can't know whether the model is hybrid or not now, so we don't log # warning message here and will log it later. if not (current_platform.is_cuda() or current_platform.is_rocm()): # Hybrid KV cache manager is not supported on non-GPU platforms. 
diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py
index f6c8d569b9c6..993ce4b484f9 100644
--- a/vllm/v1/core/kv_cache_coordinator.py
+++ b/vllm/v1/core/kv_cache_coordinator.py
@@ -272,8 +272,8 @@ def verify_and_split_kv_cache_groups(self) -> None:
         self.full_attention_block_size = self.full_attention_spec.block_size
         self.other_block_size = self.other_spec.block_size
         assert self.other_block_size % self.full_attention_block_size == 0, (
-            "KVCacheCoordinator assumes the block_size of the full "
-            "attention layer is divisible by other layers now.")
+            "KVCacheCoordinator assumes the block_size of other layers is "
+            "divisible by the block_size of full attention layers for now.")
 
     def find_longest_cache_hit(
         self,
@@ -320,10 +320,10 @@ def find_longest_cache_hit(
 
         # NOTE: the prefix cache hit length must be a multiple of block_size as
         # we don't support partial block cache hit yet. The cache hit length
-        # of other attention is ensured to be a multiple of the block size of
+        # of other attention is ensured to be a multiple of the block size of
         # full attention layers in current implementation, because hit_length is
-        # a multiple of other attention's block size, and other attention's
-        # block size is a multiple of full attention's block size (verified in
+        # a multiple of other attention's block size, and other attention's
+        # block size is a multiple of full attention's block size (verified in
         # `verify_and_split_kv_cache_groups`).
         assert hit_length % self.full_attention_block_size == 0
 
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index fda74b6b013b..fc701215ba5d 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -101,7 +101,7 @@ def __init__(
         self.block_pool = self.coordinator.block_pool
         self.kv_cache_config = kv_cache_config
 
-        # Mapping from request ID to kv block hashes of all block sizes.
+        # Mapping from request ID to kv block hashes.
         # This is to avoid recomputing the block hashes for each call of
         # `get_computed_blocks` or `allocate_slots`.
         self.req_to_block_hashes: defaultdict[
@@ -172,7 +172,7 @@ def get_computed_blocks(self,
 
         if self.log_stats:
             assert self.prefix_cache_stats is not None
-            self.prefix_cache_stats.queries += len(request.all_token_ids)
+            self.prefix_cache_stats.queries += request.num_tokens
            self.prefix_cache_stats.hits += num_new_computed_tokens
 
         return KVCacheBlocks(computed_blocks), num_new_computed_tokens
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index a1c3661cc87a..6d4bcfe64a35 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -26,7 +26,6 @@ class BlockHash(NamedTuple):
     hash collisions when the hash value is the same. By using SHA256 however,
     hash collisions are practically impossible.
     """
-
     # Hash value of the block in an integer.
     hash_value: int
     # Token IDs in the block.
@@ -165,10 +164,10 @@ def reset_hash(self):
     def __repr__(self) -> str:
         # Use block_id instead of KVCacheBlock object to avoid calling __repr__
         # on KVCacheBlock object recursively.
- prev_block_id = self.prev_free_block.block_id \ - if self.prev_free_block else None - next_block_id = self.next_free_block.block_id \ - if self.next_free_block else None + prev_block_id = (self.prev_free_block.block_id + if self.prev_free_block else None) + next_block_id = (self.next_free_block.block_id + if self.next_free_block else None) return (f"KVCacheBlock(block_id={self.block_id}, " f"ref_cnt={self.ref_cnt}, " f"_block_hash={self._block_hash}, " @@ -620,7 +619,7 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: True if all layers have the same type, False otherwise. """ - layer_keys = {layer.type_id for layer in kv_cache_spec.values()} + layer_keys = set(layer.type_id for layer in kv_cache_spec.values()) return len(layer_keys) == 1 @@ -652,7 +651,6 @@ def get_num_blocks(vllm_config: VllmConfig, num_layers: int, num_layers: The number of layers available_memory: Memory available for KV cache in bytes. page_size: The page size of the KV cache. - """ num_blocks = int(available_memory // page_size // num_layers) num_blocks = max(num_blocks, 0) @@ -694,13 +692,6 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec), available_memory, page_size) - if vllm_config.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = \ - vllm_config.cache_config.num_gpu_blocks_override - logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) - per_layer_size = page_size * num_blocks # All layers have the same KV cache spec, so we create one kv cache group # for all layers. @@ -878,14 +869,10 @@ def _get_kv_cache_config_uniform_page_size( num_tokens_str = f"{num_tokens:,}" logger.info("GPU KV cache size: %s tokens", num_tokens_str) max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" - # TODO in this PR: Now just copy from the uniform type implementation. - # Update after https://github.com/vllm-project/vllm/pull/19029 - max_concurrency = num_tokens / vllm_config.model_config.max_model_len - logger.info( - "Maximum concurrency for %s tokens per request: %.2fx", - max_model_len_str, - max_concurrency, - ) + max_concurrency = get_max_concurrency_for_kv_cache_config( + vllm_config, kv_cache_config) + logger.info("Maximum concurrency for %s tokens per request: %.2fx", + max_model_len_str, max_concurrency) return kv_cache_config diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 7680517c3d12..98d758f820ad 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -232,9 +232,8 @@ def find_longest_cache_hit( def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: """ - Remove the blocks that are no longer needed from `blocks`. The removed - blocks should be replaced by null_block. Return the removed blocks in - eviction order, where the first returned block should be evicted first. + Remove the blocks that are no longer needed from `blocks` and free the + blocks. The removed blocks should be replaced by null_block. Need to be customized for each attention type. 
Args: @@ -328,8 +327,8 @@ def find_longest_cache_hit( sliding_window_contiguous_blocks += 1 # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to - # optimize the time complexity from O(len(block_hashes)) to - # O(len(block_hashes) / sliding_window_contiguous_blocks + + # optimize the time complexity from O(max_num_blocks) to + # O(max_num_blocks / sliding_window_contiguous_blocks + # sliding_window_contiguous_blocks), # which is good for low cache hit rate scenarios. max_num_blocks = max_length // kv_cache_spec.block_size diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0f5929367d34..f36a491a1970 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -163,8 +163,6 @@ def _initialize_kv_caches( num_gpu_blocks = kv_cache_configs[0].num_blocks num_cpu_blocks = 0 scheduler_kv_cache_config = kv_cache_configs[0] - # TODO in this PR: remove this debug print - print("kv_cache_config", scheduler_kv_cache_config) # Initialize kv cache and warmup the execution self.model_executor.initialize_from_config(kv_cache_configs) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 6e25455c95a5..e938f3bfc671 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -157,7 +157,7 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: @dataclass class KVCacheTensor: """ - A class for specifying the KV cache tensor. + A class for specifying how the workers should initialize the KV cache. """ size: int # size of the KV cache tensor in bytes shared_by: list[str] # layer names that share the same KV cache tensor diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c5b69dbd6ca8..def80c5421c5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2056,6 +2056,35 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) + def may_reinitialize_input_batch(self, + kv_cache_config: KVCacheConfig) -> None: + """ + Re-initialize the input batch if the block sizes are different from + `[self.cache_config.block_size]`. This usually happens when there + are multiple KV cache groups. + + Args: + kv_cache_config: The KV cache configuration. + """ + block_sizes = [ + kv_cache_group.kv_cache_spec.block_size + for kv_cache_group in kv_cache_config.kv_cache_groups + ] + if block_sizes != [self.cache_config.block_size]: + assert self.cache_config.cpu_offload_gb == 0, ( + "Cannot re-initialize the input batch when CPU weight " + "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 + "for more details.") + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_sizes=block_sizes, + ) + def _allocate_kv_cache_tensors( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: """ @@ -2172,35 +2201,6 @@ def initialize_kv_cache_tensors( self.kv_caches) return kv_caches - def may_reinitialize_input_batch(self, - kv_cache_config: KVCacheConfig) -> None: - """ - Re-initialize the input batch if the block sizes are different from - `[self.cache_config.block_size]`. This usually happens when there - are multiple KV cache groups. 
- - Args: - kv_cache_config: The KV cache configuration. - """ - block_sizes = [ - kv_cache_group.kv_cache_spec.block_size - for kv_cache_group in kv_cache_config.kv_cache_groups - ] - if block_sizes != [self.cache_config.block_size]: - assert self.cache_config.cpu_offload_gb == 0, ( - "Cannot re-initialize the input batch when CPU weight " - "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 - "for more details.") - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), - block_sizes=block_sizes, - ) - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -2211,7 +2211,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.kv_cache_config = kv_cache_config self.may_reinitialize_input_batch(kv_cache_config) self.initialize_attn_backend(kv_cache_config) - kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) if self.speculative_config and self.speculative_config.use_eagle():