diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index de06da0fc11c..101a2379be37 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -17,8 +17,9 @@ def get_sliding_window_manager(sliding_window_spec, block_pool): def test_sliding_window_possible_cached_prefix(): + block_size = 2 sliding_window_spec = SlidingWindowSpec( - block_size=2, + block_size=block_size, num_kv_heads=1, head_size=1, dtype=torch.float32, @@ -44,7 +45,9 @@ def run_one_case(block_is_cached, expect_length): i: block_pool.blocks[i + 10] } - computed_blocks = manager.find_longest_cache_hit(block_hash_list) + computed_blocks = manager.find_longest_cache_hit( + block_hash_list, + len(block_hash_list) * block_size) assert len(computed_blocks) == expect_length assert all(block == block_pool.null_block diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index b34b53155cc3..61ccb5311b2d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -146,21 +146,16 @@ def get_computed_blocks(self, assert self.prefix_cache_stats is not None self.prefix_cache_stats.requests += 1 - if len(block_hashes) * self.block_size == request.num_tokens: - # When prompt length is divisible by the block size and all - # blocks are cached, we need to recompute the last token. This - # have to be achieved by re-computing an entire block because - # allocate_slots() assumes num_computed_tokens is always a - # multiple of the block size. To achieve this, remove the last - # block hash from the block_hashes for find_longest_cache_hit - # This limitation can potentially be removed in the future to - # slightly improve the performance. - last_block_hash = block_hashes.pop() - else: - last_block_hash = None - - computed_blocks = ( - self.single_type_manager.find_longest_cache_hit(block_hashes)) + # NOTE: When all tokens hit the cache, we must recompute the last token + # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. + # This can trigger recomputation of an entire block, rather than just + # the single last token, because allocate_slots() requires + # num_computed_tokens to be block-size aligned. Removing this limitation + # could slightly improve performance in the future. + max_cache_hit_length = request.num_tokens - 1 + + computed_blocks = self.single_type_manager.find_longest_cache_hit( + block_hashes, max_cache_hit_length) # NOTE(woosuk): Since incomplete blocks are not eligible for # sharing, `num_computed_tokens` is always a multiple of # `block_size`. @@ -171,12 +166,6 @@ def get_computed_blocks(self, self.prefix_cache_stats.queries += request.num_tokens self.prefix_cache_stats.hits += num_computed_tokens - if last_block_hash is not None: - # Add back the last block hash if it was removed. - # NOTE: Because block_hashes is cached in req_to_block_hashes, - # we shouldn't modify it directly. - block_hashes.append(last_block_hash) - return KVCacheBlocks(computed_blocks), num_computed_tokens def allocate_slots( diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 3fd3cb2841e0..0223c9ceec8d 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -187,17 +187,19 @@ def get_num_common_prefix_blocks(self, request_id: str, raise NotImplementedError @abstractmethod - def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + def find_longest_cache_hit(self, block_hashes: list[BlockHashType], + max_length: int) -> list[KVCacheBlock]: """ - Get the longest cache hit prefix of the blocks. If no cache hit is - found, return an empty list. if eagle is enabled, drop the last matched - block to force recompute the last block to get the required hidden - states for eagle drafting head. Need to be customized for each attention - type. + Get the longest cache hit prefix of the blocks that is not longer than + `max_length`. If no cache hit is found, return an empty list. + If eagle is enabled, drop the last matched block to force recompute the + last block to get the required hidden states for eagle drafting head. + Need to be customized for each attention type. Args: block_hashes: The block hashes of the request. + max_length: The maximum length of the cache hit prefix. + Returns: A list of cached blocks with skipped blocks replaced by null block. For example, sliding window manager should return a list like @@ -226,10 +228,12 @@ def remove_skipped_blocks(self, request_id: str, class FullAttentionManager(SingleTypeKVCacheManager): - def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + def find_longest_cache_hit(self, block_hashes: list[BlockHashType], + max_length: int) -> list[KVCacheBlock]: computed_blocks: list[KVCacheBlock] = [] - for block_hash in block_hashes: + max_num_blocks = max_length // self.block_size + for i in range(max_num_blocks): + block_hash = block_hashes[i] # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. @@ -276,19 +280,20 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, self.sliding_window_contiguous_blocks += 1 self._null_block = block_pool.null_block - def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + def find_longest_cache_hit(self, block_hashes: list[BlockHashType], + max_length: int) -> list[KVCacheBlock]: # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to - # optimize the time complexity from O(len(block_hashes)) to - # O(len(block_hashes) / sliding_window_contiguous_blocks + + # optimize the time complexity from O(max_num_blocks) to + # O(max_num_blocks / sliding_window_contiguous_blocks + # sliding_window_contiguous_blocks), # which is good for low cache hit rate scenarios. - computed_blocks = [self._null_block] * len(block_hashes) + max_num_blocks = max_length // self.block_size + computed_blocks = [self._null_block] * max_num_blocks num_contiguous_blocks = 0 match_found = False # Search from right to left and early stop when a match is found. - for i in range(len(block_hashes) - 1, -1, -1): + for i in range(max_num_blocks - 1, -1, -1): if cached_block := self.block_pool.get_cached_block( block_hashes[i]): computed_blocks[i] = cached_block