From 6d7bea73e741a196ac7dc1baf43ab2861f35d3ac Mon Sep 17 00:00:00 2001
From: Chen Zhang
Date: Mon, 12 May 2025 08:07:37 -0700
Subject: [PATCH 1/3] last block

Signed-off-by: Chen Zhang
---
 tests/v1/core/test_specialized_manager.py    |  7 +++-
 vllm/v1/core/kv_cache_manager.py             | 35 +++++++-----------
 vllm/v1/core/single_type_kv_cache_manager.py | 39 +++++++++++---------
 3 files changed, 41 insertions(+), 40 deletions(-)

diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py
index de06da0fc11c..101a2379be37 100644
--- a/tests/v1/core/test_specialized_manager.py
+++ b/tests/v1/core/test_specialized_manager.py
@@ -17,8 +17,9 @@ def get_sliding_window_manager(sliding_window_spec, block_pool):
 
 
 def test_sliding_window_possible_cached_prefix():
+    block_size = 2
     sliding_window_spec = SlidingWindowSpec(
-        block_size=2,
+        block_size=block_size,
         num_kv_heads=1,
         head_size=1,
         dtype=torch.float32,
@@ -44,7 +45,9 @@ def run_one_case(block_is_cached, expect_length):
             i: block_pool.blocks[i + 10]
         }
 
-        computed_blocks = manager.find_longest_cache_hit(block_hash_list)
+        computed_blocks = manager.find_longest_cache_hit(
+            block_hash_list,
+            len(block_hash_list) * block_size)
 
         assert len(computed_blocks) == expect_length
         assert all(block == block_pool.null_block
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index ad8468a89dc5..9113b71edc62 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -134,33 +134,26 @@ def get_computed_blocks(self,
             assert self.prefix_cache_stats is not None
             self.prefix_cache_stats.requests += 1
 
-        if len(block_hashes) * self.block_size == request.num_tokens:
-            # When prompt length is divisible by the block size and all
-            # blocks are cached, we need to recompute the last token. This
-            # have to be achieved by re-computing an entire block because
-            # allocate_slots() assumes num_computed_tokens is always a
-            # multiple of the block size. To achieve this, remove the last
-            # block hash from the block_hashes for find_longest_cache_hit
-            # This limitation can potentially be removed in the future to
-            # slightly improve the performance.
-            last_block_hash = block_hashes.pop()
-        else:
-            last_block_hash = None
-
-        computed_blocks = (
-            self.single_type_manager.find_longest_cache_hit(block_hashes))
+        # When prompt length is divisible by the block size and all
+        # blocks are cached, we need to recompute the last token. This
+        # has to be achieved by re-computing an entire block because
+        # allocate_slots() assumes num_computed_tokens is always a
+        # multiple of the block size. To achieve this, set max_cache_hit_length
+        # to prompt_length - 1 in this case.
+        # This limitation can potentially be removed in the future to
+        # slightly improve performance.
+        max_cache_hit_length = request.num_tokens
+        if max_cache_hit_length % self.block_size == 0:
+            max_cache_hit_length -= 1
+
+        computed_blocks = (self.single_type_manager.find_longest_cache_hit(
+            block_hashes, max_cache_hit_length))
 
         if self.log_stats:
             assert self.prefix_cache_stats is not None
             self.prefix_cache_stats.queries += len(block_hashes)
             self.prefix_cache_stats.hits += len(computed_blocks)
 
-        if last_block_hash is not None:
-            # Add back the last block hash if it was removed.
-            # NOTE: Because block_hashes is cached in req_to_block_hashes,
-            # we shouldn't modify it directly.
-            block_hashes.append(last_block_hash)
-
         # NOTE(woosuk): Since incomplete blocks are not eligible for
         # sharing, `num_computed_tokens` is always a multiple of
         # `block_size`.
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
index 3fd3cb2841e0..f0ed2f18eb81 100644
--- a/vllm/v1/core/single_type_kv_cache_manager.py
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -187,17 +187,19 @@ def get_num_common_prefix_blocks(self, request_id: str,
         raise NotImplementedError
 
     @abstractmethod
-    def find_longest_cache_hit(
-            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
+    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+                               max_length: int) -> list[KVCacheBlock]:
         """
-        Get the longest cache hit prefix of the blocks. If no cache hit is
-        found, return an empty list. if eagle is enabled, drop the last matched
-        block to force recompute the last block to get the required hidden
-        states for eagle drafting head. Need to be customized for each attention
-        type.
+        Get the longest cache hit prefix of the blocks that is not longer than
+        `max_length`. If no cache hit is found, return an empty list.
+        If eagle is enabled, drop the last matched block to force the last
+        block to be recomputed, yielding the hidden states required by the
+        eagle drafting head. Needs to be customized for each attention type.
 
         Args:
             block_hashes: The block hashes of the request.
+            max_length: The maximum length of the cache hit prefix.
+
         Returns:
             A list of cached blocks with skipped blocks replaced by null block.
             For example, sliding window manager should return a list like
@@ -226,10 +228,12 @@ def remove_skipped_blocks(self, request_id: str,
 
 class FullAttentionManager(SingleTypeKVCacheManager):
 
-    def find_longest_cache_hit(
-            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
-        computed_blocks: list[KVCacheBlock] = []
-        for block_hash in block_hashes:
+    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+                               max_length: int) -> list[KVCacheBlock]:
+        computed_blocks = []
+        max_num_blocks = max_length // self.block_size
+        for i in range(max_num_blocks):
+            block_hash = block_hashes[i]
             # block_hashes is a chain of block hashes. If a block hash is not
             # in the cached_block_hash_to_id, the following block hashes are
             # not computed yet for sure.
@@ -276,19 +280,20 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
             self.sliding_window_contiguous_blocks += 1
         self._null_block = block_pool.null_block
 
-    def find_longest_cache_hit(
-            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
+    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+                               max_length: int) -> list[KVCacheBlock]:
         # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
-        # optimize the time complexity from O(len(block_hashes)) to
-        # O(len(block_hashes) / sliding_window_contiguous_blocks +
+        # optimize the time complexity from O(max_num_blocks) to
+        # O(max_num_blocks / sliding_window_contiguous_blocks +
         #  sliding_window_contiguous_blocks),
         # which is good for low cache hit rate scenarios.
-        computed_blocks = [self._null_block] * len(block_hashes)
+        max_num_blocks = max_length // self.block_size
+        computed_blocks = [self._null_block] * max_num_blocks
         num_contiguous_blocks = 0
 
         match_found = False
         # Search from right to left and early stop when a match is found.
-        for i in range(len(block_hashes) - 1, -1, -1):
+        for i in range(max_num_blocks - 1, -1, -1):
             if cached_block := self.block_pool.get_cached_block(
                     block_hashes[i]):
                 computed_blocks[i] = cached_block

From c07f1c0b5d4ba74d1643fcd445eefd19e516d1a2 Mon Sep 17 00:00:00 2001
From: Chen Zhang
Date: Tue, 13 May 2025 11:18:48 +0800
Subject: [PATCH 2/3] Apply suggestions from code review

Signed-off-by: Chen Zhang
Co-authored-by: Woosuk Kwon
---
 vllm/v1/core/kv_cache_manager.py             | 22 +++++++-------------
 vllm/v1/core/single_type_kv_cache_manager.py |  2 +-
 2 files changed, 9 insertions(+), 15 deletions(-)

diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index c690486dca36..aa441f2295a2 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -146,20 +146,14 @@ def get_computed_blocks(self,
             assert self.prefix_cache_stats is not None
             self.prefix_cache_stats.requests += 1
 
-        # When prompt length is divisible by the block size and all
-        # blocks are cached, we need to recompute the last token. This
-        # has to be achieved by re-computing an entire block because
-        # allocate_slots() assumes num_computed_tokens is always a
-        # multiple of the block size. To achieve this, set max_cache_hit_length
-        # to prompt_length - 1 in this case.
-        # This limitation can potentially be removed in the future to
-        # slightly improve performance.
-        max_cache_hit_length = request.num_tokens
-        if max_cache_hit_length % self.block_size == 0:
-            max_cache_hit_length -= 1
-
-        computed_blocks = (self.single_type_manager.find_longest_cache_hit(
-            block_hashes, max_cache_hit_length))
+        # NOTE: When all tokens hit the cache, we must recompute the last token to obtain logits.
+        # Thus, set max_cache_hit_length to prompt_length - 1. This can trigger recomputation of an entire block,
+        # rather than just the single last token, because allocate_slots() requires num_computed_tokens to be
+        # block-size aligned. Removing this limitation could slightly improve performance in the future.
+        max_cache_hit_length = request.num_tokens - 1
+
+        computed_blocks = self.single_type_manager.find_longest_cache_hit(
+            block_hashes, max_cache_hit_length)
 
         # NOTE(woosuk): Since incomplete blocks are not eligible for
         # sharing, `num_computed_tokens` is always a multiple of
         # `block_size`.
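A minimal standalone sketch (plain Python, not vLLM's API; the block size and token counts are made up for illustration) of why the `request.num_tokens - 1` cap above always leaves the last token to be recomputed:

def num_cachable_blocks(num_tokens: int, block_size: int) -> int:
    # Cap the cache hit at num_tokens - 1 so that at least the last
    # token is recomputed, even when every block of the prompt is cached.
    max_cache_hit_length = num_tokens - 1
    return max_cache_hit_length // block_size

# A 64-token prompt with block_size=16 is exactly 4 full blocks. The cap
# counts only 3 of them as hits, so the whole last block is re-run, since
# allocate_slots() needs num_computed_tokens to be block-size aligned.
assert num_cachable_blocks(64, 16) == 3
# With 65 tokens the last block is partial anyway and the cap is a no-op:
# 64 // 16 == 4 full blocks can still be cache hits.
assert num_cachable_blocks(65, 16) == 4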
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index f0ed2f18eb81..0223c9ceec8d 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -230,7 +230,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): def find_longest_cache_hit(self, block_hashes: list[BlockHashType], max_length: int) -> list[KVCacheBlock]: - computed_blocks = [] + computed_blocks: list[KVCacheBlock] = [] max_num_blocks = max_length // self.block_size for i in range(max_num_blocks): block_hash = block_hashes[i] From b33c41331ddd96f361cf4b4f04e7bfe46e7df2fc Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 12 May 2025 20:58:40 -0700 Subject: [PATCH 3/3] fix precommit Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_manager.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index aa441f2295a2..61ccb5311b2d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -146,10 +146,12 @@ def get_computed_blocks(self, assert self.prefix_cache_stats is not None self.prefix_cache_stats.requests += 1 - # NOTE: When all tokens hit the cache, we must recompute the last token to obtain logits. - # Thus, set max_cache_hit_length to prompt_length - 1. This can trigger recomputation of an entire block, - # rather than just the single last token, because allocate_slots() requires num_computed_tokens to be - # block-size aligned. Removing this limitation could slightly improve performance in the future. + # NOTE: When all tokens hit the cache, we must recompute the last token + # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. + # This can trigger recomputation of an entire block, rather than just + # the single last token, because allocate_slots() requires + # num_computed_tokens to be block-size aligned. Removing this limitation + # could slightly improve performance in the future. max_cache_hit_length = request.num_tokens - 1 computed_blocks = self.single_type_manager.find_longest_cache_hit(
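Taken together, the series leaves `find_longest_cache_hit` with a `max_length` cap in every manager. Below is a self-contained sketch of the full-attention variant (simplified: plain ints stand in for `KVCacheBlock`, a dict stands in for the block pool's `cached_block_hash_to_id`, and the eagle special case is omitted):

BlockHash = str

def find_longest_cache_hit(block_hashes: list[BlockHash], max_length: int,
                           block_size: int,
                           cached: dict[BlockHash, int]) -> list[int]:
    computed_blocks: list[int] = []
    # Never match more blocks than max_length tokens allow.
    max_num_blocks = max_length // block_size
    for i in range(max_num_blocks):
        # block_hashes form a chain: the first miss ends the prefix,
        # because every later hash depends on the earlier ones.
        block = cached.get(block_hashes[i])
        if block is None:
            break
        computed_blocks.append(block)
    return computed_blocks

# Four blocks of size 2 are cached, but with max_length=7 (an 8-token
# prompt capped at num_tokens - 1) only 7 // 2 == 3 blocks may count.
cached = {"h0": 10, "h1": 11, "h2": 12, "h3": 13}
assert find_longest_cache_hit(["h0", "h1", "h2", "h3"], 7, 2,
                              cached) == [10, 11, 12]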