Skip to content

Commit 8e1eadc

Browse files
updated
Signed-off-by: [email protected] <[email protected]>
1 parent 05349a5 commit 8e1eadc

File tree

4 files changed

+133
-104
lines changed

4 files changed

+133
-104
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
from vllm.attention.backends.abstract import AttentionMetadata
2020
from vllm.config import VllmConfig
2121
from vllm.forward_context import ForwardContext
22-
from vllm.v1.core.kv_cache_manager import KVCacheManager
23-
from vllm.v1.core.kv_cache_utils import KVCacheBlock
2422
from vllm.v1.request import Request
2523

2624

@@ -148,32 +146,33 @@ def wait_for_save(self):
148146
# Scheduler-side methods
149147
# ==============================
150148
@abstractmethod
151-
def get_external_prefix_cache_blocks(
149+
def get_num_matched_tokens(
152150
self,
153151
request: "Request",
154-
computed_blocks: list["KVCacheBlock"],
155152
num_computed_tokens: int,
156-
kv_cache_manager: "KVCacheManager",
157-
) -> list["KVCacheBlock"]:
153+
) -> int:
158154
"""
159-
Get the external prefix cache blocks from the connector.
160-
161-
This function may change the state of the connector, which will
162-
be used by `build_connector_meta` later.
163-
164-
This function will also allocate/free the blocks dynamically when
165-
there is remote cache hit.
166-
155+
Check for external KV cache hit.
156+
167157
Args:
168158
request (Request): the request object.
169-
computed_blocks (list[KVCacheBlock]): the 'local' computed blocks.
170-
num_computed_tokens (int): the number of 'local' computed tokens.
171-
kv_cache_manager (KVCacheManager): the KV cache manager to
172-
allocate/free the blocks if needed.
159+
num_computed_tokens (int): the number of locally
160+
computed tokens for this request
173161
174162
Returns:
175-
The updated list of the computed blocks (appended with the remote
176-
cached blocks)
163+
the number of tokens that can be loaded from the
164+
external KV cache beyond what is already computed.
165+
"""
166+
pass
167+
168+
@abstractmethod
169+
def update_state_after_alloc(self, request: Request,
170+
num_allocated_blocks: int):
171+
"""
172+
Update KVConnector state after temporary buffer alloc.
173+
174+
For SharedStorageConnector, update _requests_need_load
175+
if the CacheManager allocated blocks for us.
177176
"""
178177
pass
179178

vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py

Lines changed: 33 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
if TYPE_CHECKING:
1717
from vllm.attention.backends.abstract import AttentionMetadata
1818
from vllm.forward_context import ForwardContext
19-
from vllm.v1.core.kv_cache_manager import KVCacheManager
20-
from vllm.v1.core.kv_cache_utils import KVCacheBlock
2119
from vllm.v1.request import Request
2220

2321
logger = init_logger(__name__)
@@ -152,7 +150,7 @@ def inject_kv_into_layer(
152150
kv_cache_layer = attn_layer.kv_cache[\
153151
forward_context.virtual_engine]
154152

155-
filename = self.generate_filename_debug(
153+
filename = self._generate_filename_debug(
156154
layer_name, request.token_ids)
157155
kv_cache = safetensors.torch.load_file(
158156
filename)["kv_cache"].cuda()
@@ -201,7 +199,7 @@ def extract_kv_from_layer(
201199
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
202200
for request in connector_metadata.requests:
203201
if request.is_store:
204-
filename = self.generate_filename_debug(
202+
filename = self._generate_filename_debug(
205203
layer_name, request.token_ids)
206204
kv_cache = extract_kv_from_layer(kv_layer,
207205
request.slot_mapping)
@@ -211,78 +209,47 @@ def extract_kv_from_layer(
211209
def wait_for_save(self):
212210
return
213211

214-
def get_external_prefix_cache_blocks(
212+
def get_num_matched_tokens(
215213
self,
216214
request: "Request",
217-
computed_blocks: list["KVCacheBlock"],
218215
num_computed_tokens: int,
219-
kv_cache_manager: "KVCacheManager",
220-
) -> list["KVCacheBlock"]:
221-
"""Get the external prefix cache blocks from the connector.
222-
223-
This function may change the state of the connector, which will be
224-
used by `build_connector_meta` later.
225-
226-
Args:
227-
request (Request): the request object.
228-
computed_blocks (list[KVCacheBlock]): the 'local' computed blocks.
229-
num_computed_tokens (int): the number of 'local' computed tokens.
230-
kv_cache_manager (KVCacheManager): the KV cache manager to
231-
allocate/free the blocks if needed.
232-
233-
Returns:
234-
The updated list of the computed blocks (appended with the remote
235-
cached blocks)
216+
) -> int:
236217
"""
218+
Check for external KV cache hit.
219+
220+
Returns the number of tokens that can be loaded from the
221+
external KV cache beyond what is already computed.
222+
"""
223+
237224
# NOTE: in this debug implementation, we assume that the prompt is
238225
# cached_prompt + newly_generated_single_token
239226
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
240227

241228
# NOTE: in current v1 scheduler, the num_computed_tokens is aligned
242229
# with the block granularity. And it expects the returned blocks and
243230
# num_computed_tokens to also be aligned with the block granularity.
244-
if not self.found_match_for_request(request):
245-
return computed_blocks
231+
if not self._found_match_for_request(request):
232+
return 0
233+
234+
logger.info("External Cache Hit!")
246235

247236
# Now, first num_tokens_to_check tokens are hit, we need to prepare
248237
# the metadata for the worker connector to correctly load the KV
249-
250-
logger.info("Hit the cache! Allocate new blocks!")
251238
num_tokens_to_check = align_to_block_size(
252239
len(request.prompt_token_ids) - 1, self._block_size)
253-
need_to_allocate = num_tokens_to_check - num_computed_tokens
254-
if need_to_allocate > 0:
255-
# HACK: We don't want the scheduler see the blocks are allocated
256-
# and associated with the current request. Instead, we want the
257-
# scheduler find that the blocks are already allocated and they
258-
# are associated with some other requests (i.e., the case of
259-
prefix caching).
260-
261-
# HACK: KVCacheManager.allocate_slots will pre-allocate a few
262-
# blocks, which will cause problems in the later allocations.
263-
# We should make sure the pre allocation does not happen.
264-
old_req_id = request.request_id
265-
request.request_id = "temp-req-id-for-connector"
266-
allocated_blocks = kv_cache_manager.allocate_slots(
267-
request,
268-
need_to_allocate,
269-
computed_blocks,
270-
skip_preallocate=True,
271-
skip_inc_ref_count=True)
272-
request.request_id = old_req_id
273-
kv_cache_manager.req_to_blocks.pop("temp-req-id-for-connector")
274-
kv_cache_manager.num_cached_block.pop("temp-req-id-for-connector")
275-
276-
num_expected_blocks = need_to_allocate // self._block_size
277-
if len(allocated_blocks) > num_expected_blocks:
278-
logger.error("Detected pre-allocated blocks in the connector!"
279-
"This should not happen!")
280-
allocated_blocks = allocated_blocks[:num_expected_blocks]
281240

241+
return num_tokens_to_check - num_computed_tokens
242+
243+
def update_state_after_alloc(self, request: Request,
244+
num_allocated_blocks: int):
245+
"""
246+
Update KVConnector state after temporary buffer alloc.
247+
248+
For SharedStorageConnector, update _requests_need_load
249+
if the CacheManager allocated blocks for us.
250+
"""
251+
if num_allocated_blocks > 0:
282252
self._requests_need_load.append(request.request_id)
283-
return computed_blocks + allocated_blocks
284-
else:
285-
return computed_blocks
286253

287254
def build_connector_meta(
288255
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
@@ -302,7 +269,7 @@ def build_connector_meta(
302269
# NOTE: here, we set the store and load being exclusive,
303270
# but in LMCache use case, a single request can have both
304271
# store and load status
305-
if not self.found_match_for_request(request):
272+
if not self._found_match_for_request(request):
306273
meta.add_request(request, self._block_size, is_store=True)
307274

308275
self._requests_need_load.clear()
@@ -312,20 +279,20 @@ def build_connector_meta(
312279
# Helper functions
313280
# ==============================
314281

315-
def found_match_for_request(
282+
def _found_match_for_request(
316283
self,
317284
request: "Request",
318285
) -> bool:
319286
"""Check if the cache is hit for the request.
320287
"""
321288
num_tokens_to_check = align_to_block_size(
322289
len(request.prompt_token_ids) - 1, self._block_size)
323-
foldername = self.generate_foldername_debug(torch.tensor(
290+
foldername = self._generate_foldername_debug(torch.tensor(
324291
request.prompt_token_ids)[:num_tokens_to_check],
325-
create_folder=False)
292+
create_folder=False)
326293
return os.path.exists(foldername)
327294

328-
def generate_foldername_debug(
295+
def _generate_foldername_debug(
329296
self,
330297
input_ids: torch.Tensor,
331298
create_folder=False,
@@ -340,16 +307,16 @@ def generate_foldername_debug(
340307
os.makedirs(foldername, exist_ok=True)
341308
return foldername
342309

343-
def generate_filename_debug(
310+
def _generate_filename_debug(
344311
self,
345312
layer_name: str,
346313
input_ids: torch.Tensor,
347314
) -> str:
348315
"""Generate a file name based on the layer name and the hash
349316
of the bytes of the input ids.
350317
"""
351-
foldername = self.generate_foldername_debug(input_ids,
352-
create_folder=True)
318+
foldername = self._generate_foldername_debug(input_ids,
319+
create_folder=True)
353320
return os.path.join(foldername, f"{layer_name}.safetensors")
354321

355322

vllm/v1/core/kv_cache_manager.py

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,11 @@ def __init__(
8484
# data for reempted ones.
8585
self.num_cached_block: dict[str, int] = {}
8686
self.prefix_cache_stats = PrefixCacheStats()
87-
self.connector = connector
87+
88+
# KVConnector: buffer reqs for KVConnector. We write
89+
# the external KVs to the "buffer" req and leverage
90+
# prefix caching to share with the "real" req
91+
self.kv_connector_buffer_reqs: list[Request] = []
8892

8993
@property
9094
def usage(self) -> float:
@@ -159,13 +163,6 @@ def get_computed_blocks(
159163
# we shouldn't modify it directly.
160164
block_hashes.append(last_block_hash)
161165

162-
# Check the remote cache for the external prefix cache blocks.
163-
if self.connector is not None:
164-
computed_blocks =\
165-
self.connector.get_external_prefix_cache_blocks(
166-
request, computed_blocks,
167-
len(computed_blocks) * self.block_size, self)
168-
169166
# NOTE(woosuk): Since incomplete blocks are not eligible for
170167
# sharing, `num_computed_tokens` is always a multiple of
171168
# `block_size`.
@@ -178,7 +175,6 @@ def allocate_slots(
178175
num_tokens: int,
179176
new_computed_blocks: Optional[list[KVCacheBlock]] = None,
180177
skip_preallocate: bool = False,
181-
skip_inc_ref_count: bool = False,
182178
) -> Optional[list[KVCacheBlock]]:
183179
"""Add slots for a request with new tokens to append.
184180
@@ -188,11 +184,7 @@ def allocate_slots(
188184
not include the tokens that have already been computed.
189185
new_computed_blocks: A list of new computed blocks just hitting the
190186
prefix caching.
191-
skip_preallocate: Whether to skip preallocating blocks for
192-
the request.
193-
skip_preallocate: Whether to skip incrementing the ref count. This
194-
is useful for the KVConnector to allocate blocks which will be
195-
filled by the remote KVs for a single model step().
187+
skip_preallocate: Whether to skip preallocating blocks.
196188
197189
Blocks layout:
198190
-----------------------------------------------------------------------
@@ -246,12 +238,11 @@ def allocate_slots(
246238
return None
247239

248240
# Touch the computed blocks to make sure they won't be evicted.
249-
if self.enable_caching and not skip_inc_ref_count:
241+
if self.enable_caching:
250242
self.block_pool.touch(new_computed_blocks)
251243
else:
252-
assert not new_computed_blocks, (
253-
"Computed blocks should be empty when "
254-
"prefix caching is disabled")
244+
assert not new_computed_blocks, "Computed blocks should "\
245+
"be empty when prefix caching is disabled"
255246

256247
# Append the new computed blocks to the request blocks until now to
257248
# avoid the case where the new blocks cannot be allocated.
@@ -396,3 +387,56 @@ def free_block_hashes(self, request: Request) -> None:
396387
is finished, not when it is preempted.
397388
"""
398389
self.req_to_block_hashes.pop(request.request_id, None)
390+
391+
def alloc_and_get_external_blocks(
392+
self,
393+
request: "Request",
394+
computed_blocks: list["KVCacheBlock"],
395+
num_computed_tokens: int,
396+
kv_connector: KVConnectorBase_V1,
397+
) -> tuple[list["KVCacheBlock"], int]:
398+
399+
# Check for cache hit.
400+
need_to_allocate = kv_connector.get_num_matched_tokens(
401+
request, num_computed_tokens)
402+
num_allocated_blocks = 0
403+
404+
# Cache hit: allocate buffer.
405+
if need_to_allocate > 0:
406+
# HACK: We don't want the scheduler see the blocks are allocated
407+
# and associated with the current request. Instead, we want the
408+
# scheduler find that the blocks are already allocated and they
409+
# are associated with some other requests (i.e., the case of
410+
prefix caching).
411+
412+
old_req_id = request.request_id
413+
request.request_id = f"{old_req_id}-buf-for-kv-connector"
414+
allocated_blocks = self.allocate_slots(
415+
request,
416+
need_to_allocate,
417+
computed_blocks,
418+
skip_preallocate=True,
419+
)
420+
request.request_id = old_req_id
421+
422+
num_expected_blocks = need_to_allocate // self.block_size
423+
num_allocated_blocks = len(
424+
allocated_blocks) if allocated_blocks else 0
425+
assert num_allocated_blocks <= num_expected_blocks, ""\
426+
"Detected pre-allocated blocks in the connector! "\
427+
"This should not happen!"
428+
429+
# Update internal state. In case of:
430+
# * SharedStorageConnector: add req_id to _requests_need_load
431+
# so that we know to load this requests KVs later.
432+
kv_connector.update_state_after_alloc(request, num_allocated_blocks)
433+
num_computed_blocks = len(computed_blocks) * self.block_size
434+
return computed_blocks, num_computed_blocks
435+
436+
def free_buffer_requests(self) -> None:
437+
"""Free buffer requests for the KV connector."""
438+
439+
for buffer_req in self.kv_connector_buffer_reqs:
440+
self.free(buffer_req)
441+
self.free_block_hashes(buffer_req)
442+
self.kv_connector_buffer_reqs.clear()

vllm/v1/core/sched/scheduler.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,20 @@ def schedule(self) -> SchedulerOutput:
308308
# Get already-cached tokens.
309309
computed_blocks, num_computed_tokens = \
310310
self.kv_cache_manager.get_computed_blocks(request)
311+
312+
# KVConnector: get blocks for externally-cached tokens.
313+
# Internally, this allocates a "buffer" req with a prompt
314+
# corresponding to externally cached tokens. In alloc_slots
315+
# below, we will compute a cache hit and thus skip the
316+
# computation for externally cached tokens.
317+
# NOTE: since this allocates temporary buffer requests,
318+
# we must call kv_cache_manager.free_buffer_requests() below.
319+
if self.connector is not None:
320+
computed_blocks, num_computed_tokens = \
321+
self.kv_cache_manager.alloc_and_get_external_blocks(
322+
request, computed_blocks,
323+
num_computed_tokens, self.connector)
324+
311325
# Number of tokens to be scheduled.
312326
# We use `request.num_tokens` instead of
313327
# `request.num_prompt_tokens` to consider the resumed requests,
@@ -467,6 +481,11 @@ def schedule(self) -> SchedulerOutput:
467481
for req_id, num_scheduled_token in num_scheduled_tokens.items():
468482
self.requests[req_id].num_computed_tokens += num_scheduled_token
469483

484+
# KVConnector: once we have allocated the buffer blocks to the
485+
# "real" requests (via prefix caching), free the tmp buffer reqs.
486+
if self.connector is not None:
487+
self.kv_cache_manager.free_buffer_requests()
488+
470489
self.finished_req_ids = set()
471490
return scheduler_output
472491

0 commit comments

Comments
 (0)