@@ -38,7 +38,8 @@ def __init__(
3838 from lmcache.integration.vllm.utils import ENGINE_NAME
3939 from lmcache.integration.vllm.vllm_adapter import (
4040 RetrieveStatus, StoreStatus, init_lmcache_engine,
41- lmcache_retrieve_kv, lmcache_should_store, lmcache_store_kv)
41+ lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
42+ lmcache_store_kv)
4243 logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
4344 self.transfer_config)
4445
@@ -54,6 +55,7 @@ def __init__(
5455 self.cache_config = config.cache_config
5556 self.lmcache_retrieve_kv = lmcache_retrieve_kv
5657 self.lmcache_store_kv = lmcache_store_kv
58+ self.lmcache_should_retrieve = lmcache_should_retrieve
5759 self.lmcache_should_store = lmcache_should_store
5860 self.store_status = StoreStatus
5961 self.retrieve_status = RetrieveStatus
@@ -65,15 +67,11 @@ def recv_kv_caches_and_hidden_states(
6567 ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
6668 "ModelInputForGPUWithSamplingMetadata"]:
6769
68- hidden_or_intermediate_states = None
69-
70- # TODO (Jiayi): Need to support chunked prefill
71- retrieve_status = self.retrieve_status.PREFILL
72-
73- model_input, bypass_model_exec = self.lmcache_retrieve_kv(
74- model_executable, model_input, self.cache_config, kv_caches,
75- retrieve_status)
76-
70+ retrieve_status = self.lmcache_should_retrieve(model_input)
71+ model_input, bypass_model_exec, hidden_or_intermediate_states = \
72+ self.lmcache_retrieve_kv(
73+ model_executable, model_input, self.cache_config, kv_caches,
74+ retrieve_status)
7775 return hidden_or_intermediate_states, bypass_model_exec, model_input
7876
7977 def send_kv_caches_and_hidden_states (
@@ -84,15 +82,7 @@ def send_kv_caches_and_hidden_states(
8482 hidden_or_intermediate_states: Union[torch.Tensor,
8583 IntermediateTensors],
8684 ) -> None:
87- num_reqs = 0
88- seq_group_list = model_input.sampling_metadata.seq_groups
89- assert seq_group_list is not None
90- for seq_group in seq_group_list:
91- seq_ids = seq_group.seq_ids
92- for seq_id in seq_ids:
93- num_reqs += 1
94-
95- # TODO (Jiayi): Only normal prefill is supported for now
85+
9686 store_status = self.lmcache_should_store(model_input)
9787 self.lmcache_store_kv(
9888 self.model_config,
0 commit comments