1616if TYPE_CHECKING :
1717 from vllm .attention .backends .abstract import AttentionMetadata
1818 from vllm .forward_context import ForwardContext
19- from vllm .v1 .core .kv_cache_manager import KVCacheManager
20- from vllm .v1 .core .kv_cache_utils import KVCacheBlock
2119 from vllm .v1 .request import Request
2220
2321logger = init_logger (__name__ )
@@ -152,7 +150,7 @@ def inject_kv_into_layer(
152150 kv_cache_layer = attn_layer .kv_cache [\
153151 forward_context .virtual_engine ]
154152
155- filename = self .generate_filename_debug (
153+ filename = self ._generate_filename_debug (
156154 layer_name , request .token_ids )
157155 kv_cache = safetensors .torch .load_file (
158156 filename )["kv_cache" ].cuda ()
@@ -201,7 +199,7 @@ def extract_kv_from_layer(
201199 assert isinstance (connector_metadata , SharedStorageConnectorMetadata )
202200 for request in connector_metadata .requests :
203201 if request .is_store :
204- filename = self .generate_filename_debug (
202+ filename = self ._generate_filename_debug (
205203 layer_name , request .token_ids )
206204 kv_cache = extract_kv_from_layer (kv_layer ,
207205 request .slot_mapping )
@@ -211,78 +209,47 @@ def extract_kv_from_layer(
211209 def wait_for_save (self ):
212210 return
213211
214- def get_external_prefix_cache_blocks (
212+ def get_num_matched_tokens (
215213 self ,
216214 request : "Request" ,
217- computed_blocks : list ["KVCacheBlock" ],
218215 num_computed_tokens : int ,
219- kv_cache_manager : "KVCacheManager" ,
220- ) -> list ["KVCacheBlock" ]:
221- """Get the external prefix cache blocks from the connector.
222-
223- This function may change the state of the connector, which will be
224- used by `build_connector_meta` later.
225-
226- Args:
227- request (Request): the request object.
228- computed_blocks (list[KVCacheBlock]): the 'local' computed blocks.
229- num_computed_tokens (int): the number of 'local' computed tokens.
230- kv_cache_manager (KVCacheManager): the KV cache manager to
231- allocate/free the blocks if needed.
232-
233- Returns:
234- The updated list of the computed blocks (appended with the remote
235- cached blocks)
216+ ) -> int :
236217 """
218+ Check for external KV cache hit.
219+
220+ Returns the number of tokens that can be loaded from the
221+ external KV cache beyond what is already computed.
222+ """
223+
237224 # NOTE: in this debug implementation, we assume that the prompt is
238225 # cached_prompt + newly_generated_single_token
239226 # Therefore, we use prompt_token_ids[:-1] to determine the folder name
240227
241228 # NOTE: in current v1 scheduler, the num_computed_tokens is aligned
242229 # with the block granularity. And it expects the returned blocks and
243230 # num_computed_tokens to also be aligned with the block granularity.
244- if not self .found_match_for_request (request ):
245- return computed_blocks
231+ if not self ._found_match_for_request (request ):
232+ return 0
233+
234+ logger .info ("External Cache Hit!" )
246235
247236 # Now that the first num_tokens_to_check tokens are a hit, we need to
248237 # prepare the metadata for the worker connector to correctly load the KV.
249-
250- logger .info ("Hit the cache! Allocate new blocks!" )
251238 num_tokens_to_check = align_to_block_size (
252239 len (request .prompt_token_ids ) - 1 , self ._block_size )
253- need_to_allocate = num_tokens_to_check - num_computed_tokens
254- if need_to_allocate > 0 :
255- # HACK: We don't want the scheduler see the blocks are allocated
256- # and associated with the current request. Instead, we want the
257- # scheduler find that the blocks are already allocated and they
258- # are associated with some other requests (i.e., the case of
259- # prefix caching.
260-
261- # HACK: KVCacheManager.allocate_slots will pre-allocate a few
262- # blocks, which will cause problems in the later allocations.
263- # We should make sure the pre allocation does not happen.
264- old_req_id = request .request_id
265- request .request_id = "temp-req-id-for-connector"
266- allocated_blocks = kv_cache_manager .allocate_slots (
267- request ,
268- need_to_allocate ,
269- computed_blocks ,
270- skip_preallocate = True ,
271- skip_inc_ref_count = True )
272- request .request_id = old_req_id
273- kv_cache_manager .req_to_blocks .pop ("temp-req-id-for-connector" )
274- kv_cache_manager .num_cached_block .pop ("temp-req-id-for-connector" )
275-
276- num_expected_blocks = need_to_allocate // self ._block_size
277- if len (allocated_blocks ) > num_expected_blocks :
278- logger .error ("Detected pre-allocated blocks in the connector!"
279- "This should not happen!" )
280- allocated_blocks = allocated_blocks [:num_expected_blocks ]
281240
241+ return num_tokens_to_check - num_computed_tokens
242+
243+ def update_state_after_alloc (self , request : Request ,
244+ num_allocated_blocks : int ):
245+ """
246+ Update KVConnector state after temporary buffer alloc.
247+
248+ For SharedStorageConnector, update _requests_need_load
249+ if the CacheManager has allocated blocks for us.
250+ """
251+ if num_allocated_blocks > 0 :
282252 self ._requests_need_load .append (request .request_id )
283- return computed_blocks + allocated_blocks
284- else :
285- return computed_blocks
286253
287254 def build_connector_meta (
288255 self , scheduler_output : SchedulerOutput ) -> KVConnectorMetadata :
@@ -302,7 +269,7 @@ def build_connector_meta(
302269 # NOTE: here, we treat store and load as mutually exclusive,
303270 # but in the LMCache use case, a single request can have both
304271 # store and load status.
305- if not self .found_match_for_request (request ):
272+ if not self ._found_match_for_request (request ):
306273 meta .add_request (request , self ._block_size , is_store = True )
307274
308275 self ._requests_need_load .clear ()
@@ -312,20 +279,20 @@ def build_connector_meta(
312279 # Helper functions
313280 # ==============================
314281
315- def found_match_for_request (
282+ def _found_match_for_request (
316283 self ,
317284 request : "Request" ,
318285 ) -> bool :
319286 """Check if the cache is hit for the request.
320287 """
321288 num_tokens_to_check = align_to_block_size (
322289 len (request .prompt_token_ids ) - 1 , self ._block_size )
323- foldername = self .generate_foldername_debug (torch .tensor (
290+ foldername = self ._generate_foldername_debug (torch .tensor (
324291 request .prompt_token_ids )[:num_tokens_to_check ],
325- create_folder = False )
292+ create_folder = False )
326293 return os .path .exists (foldername )
327294
328- def generate_foldername_debug (
295+ def _generate_foldername_debug (
329296 self ,
330297 input_ids : torch .Tensor ,
331298 create_folder = False ,
@@ -340,16 +307,16 @@ def generate_foldername_debug(
340307 os .makedirs (foldername , exist_ok = True )
341308 return foldername
342309
343- def generate_filename_debug (
310+ def _generate_filename_debug (
344311 self ,
345312 layer_name : str ,
346313 input_ids : torch .Tensor ,
347314 ) -> str :
348315 """Generate a file name based on the layer name and the hash
349316 of the bytes of the input ids.
350317 """
351- foldername = self .generate_foldername_debug (input_ids ,
352- create_folder = True )
318+ foldername = self ._generate_foldername_debug (input_ids ,
319+ create_folder = True )
353320 return os .path .join (foldername , f"{ layer_name } .safetensors" )
354321
355322
0 commit comments