@@ -187,17 +187,19 @@ def get_num_common_prefix_blocks(self, request_id: str,
187187 raise NotImplementedError
188188
189189 @abstractmethod
190- def find_longest_cache_hit (
191- self , block_hashes : list [ BlockHashType ] ) -> list [KVCacheBlock ]:
190+ def find_longest_cache_hit (self , block_hashes : list [ BlockHashType ],
191+ max_length : int ) -> list [KVCacheBlock ]:
192192 """
193- Get the longest cache hit prefix of the blocks. If no cache hit is
194- found, return an empty list. if eagle is enabled, drop the last matched
195- block to force recompute the last block to get the required hidden
196- states for eagle drafting head. Need to be customized for each attention
197- type.
193+ Get the longest cache hit prefix of the blocks that is not longer than
194+ `max_length`. If no cache hit is found, return an empty list.
195+ If eagle is enabled, drop the last matched block to force recompute the
196+ last block to get the required hidden states for eagle drafting head.
197+ Needs to be customized for each attention type.
198198
199199 Args:
200200 block_hashes: The block hashes of the request.
201+ max_length: The maximum length of the cache hit prefix.
202+
201203 Returns:
202204 A list of cached blocks with skipped blocks replaced by null block.
203205 For example, sliding window manager should return a list like
@@ -226,10 +228,12 @@ def remove_skipped_blocks(self, request_id: str,
226228
227229class FullAttentionManager (SingleTypeKVCacheManager ):
228230
229- def find_longest_cache_hit (
230- self , block_hashes : list [ BlockHashType ] ) -> list [KVCacheBlock ]:
231+ def find_longest_cache_hit (self , block_hashes : list [ BlockHashType ],
232+ max_length : int ) -> list [KVCacheBlock ]:
231233 computed_blocks : list [KVCacheBlock ] = []
232- for block_hash in block_hashes :
234+ max_num_blocks = max_length // self .block_size
235+ for i in range (max_num_blocks ):
236+ block_hash = block_hashes [i ]
233237 # block_hashes is a chain of block hashes. If a block hash is not
234238 # in the cached_block_hash_to_id, the following block hashes are
235239 # not computed yet for sure.
@@ -276,19 +280,20 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
276280 self .sliding_window_contiguous_blocks += 1
277281 self ._null_block = block_pool .null_block
278282
279- def find_longest_cache_hit (
280- self , block_hashes : list [ BlockHashType ] ) -> list [KVCacheBlock ]:
283+ def find_longest_cache_hit (self , block_hashes : list [ BlockHashType ],
284+ max_length : int ) -> list [KVCacheBlock ]:
281285 # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
282- # optimize the time complexity from O(len(block_hashes) ) to
283- # O(len(block_hashes) / sliding_window_contiguous_blocks +
286+ # optimize the time complexity from O(max_num_blocks ) to
287+ # O(max_num_blocks / sliding_window_contiguous_blocks +
284288 # sliding_window_contiguous_blocks),
285289 # which is good for low cache hit rate scenarios.
286- computed_blocks = [self ._null_block ] * len (block_hashes )
290+ max_num_blocks = max_length // self .block_size
291+ computed_blocks = [self ._null_block ] * max_num_blocks
287292 num_contiguous_blocks = 0
288293
289294 match_found = False
290295 # Search from right to left and early stop when a match is found.
291- for i in range (len ( block_hashes ) - 1 , - 1 , - 1 ):
296+ for i in range (max_num_blocks - 1 , - 1 , - 1 ):
292297 if cached_block := self .block_pool .get_cached_block (
293298 block_hashes [i ]):
294299 computed_blocks [i ] = cached_block
0 commit comments