Skip to content

Commit 3d41b47

Browse files
author
David Ben-David
committed
Fix PR comments
Signed-off-by: David Ben-David <[email protected]>
1 parent c299ff3 commit 3d41b47

File tree

5 files changed

+59
-29
lines changed

5 files changed

+59
-29
lines changed

tests/v1/kv_connector/unit/test_offloading_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,8 @@ def _run(self, decoded_tokens: list[int]):
281281

282282
model_runner_output = create_model_runner_output(
283283
reqs=self.scheduler.running,
284-
finished_sending=list(finished_sending),
285-
finished_recving=list(finished_recving),
284+
finished_sending=finished_sending,
285+
finished_recving=finished_recving,
286286
token_id=token_id)
287287

288288
if self.scheduler.running:

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,16 @@ def get_block_ids_with_load_errors(self) -> set[int]:
236236
Returns:
237237
Set of block IDs that encountered load errors.
238238
Empty set if no load errors occurred.
239+
240+
Notes:
241+
- Applies to both sync- and async-loading requests.
242+
- Async loading: failed blocks may be reported in any forward pass
243+
up to and including the pass where the request ID is returned by
244+
`get_finished()`. Even if failures occur, the request must still
245+
be reported via `get_finished()`, and the failed block IDs must
246+
appear here no later than that same pass.
247+
- Sync loading: failed blocks should be reported in the forward
248+
pass in which they are detected.
239249
"""
240250
return set()
241251

vllm/v1/core/sched/scheduler.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -879,12 +879,12 @@ def update_from_output(
879879
kv_connector_stats = (kv_connector_output.kv_connector_stats
880880
if kv_connector_output else None)
881881

882-
affected_req_ids = None
882+
failed_kv_load_req_ids = None
883883
if kv_connector_output and kv_connector_output.invalid_block_ids:
884884
# These blocks contain externally computed tokens that failed to
885885
# load. Identify affected requests and adjust their computed token
886886
# count to trigger recomputation of the invalid blocks.
887-
affected_req_ids = self._handle_invalid_blocks(
887+
failed_kv_load_req_ids = self._handle_invalid_blocks(
888888
kv_connector_output.invalid_block_ids)
889889

890890
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
@@ -894,7 +894,7 @@ def update_from_output(
894894
stopped_preempted_reqs: set[Request] = set()
895895
for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
896896
assert num_tokens_scheduled > 0
897-
if affected_req_ids and req_id in affected_req_ids:
897+
if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
898898
# Skip requests that were recovered from KV load failure
899899
continue
900900
request = self.requests.get(req_id)
@@ -1325,11 +1325,30 @@ def _update_from_kv_xfer_finished(self,
13251325

13261326
def _update_requests_with_invalid_blocks(
13271327
self, requests: Iterable[Request],
1328-
invalid_block_ids: set[int]) -> tuple[set[str], int, set[int]]:
1328+
invalid_block_ids: set[int]) -> tuple[set[str], int]:
1329+
"""
1330+
Identify and update requests affected by invalid KV cache blocks.
1331+
1332+
This method scans the given requests, detects those with invalid blocks
1333+
and adjusts their `num_computed_tokens` to the longest valid prefix.
1334+
For observability, it also accumulates the total number of tokens that
1335+
will need to be recomputed across all affected requests.
1336+
1337+
Args:
1338+
requests: The set of requests to scan for invalid blocks.
1339+
invalid_block_ids: IDs of invalid blocks.
1340+
1341+
Returns:
1342+
tuple:
1343+
- affected_req_ids (set[str]): IDs of requests impacted by
1344+
invalid blocks.
1345+
- total_affected_tokens (int): Total number of tokens that must
1346+
be recomputed across all affected requests (for observability).
1347+
"""
13291348
affected_req_ids: set[str] = set()
13301349
total_affected_tokens = 0
13311350
# If a block is invalid and shared by multiple requests in the batch,
1332-
# all requests must be rescheduled, but only the first will recompute
1351+
# these requests must be rescheduled, but only the first will recompute
13331352
# it. This set tracks blocks already marked for recomputation.
13341353
marked_invalid_block_ids: set[int] = set()
13351354
for request in requests:
@@ -1341,12 +1360,14 @@ def _update_requests_with_invalid_blocks(
13411360
# We iterate only over blocks that may contain externally computed
13421361
# tokens
13431362
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
1363+
# Async loading. If num_computed_tokens is set it implies we
1364+
# already processed some block failures for it in a prior step
13441365
req_num_computed_tokens = (
1345-
request.num_computed_tokens if request.request_id
1366+
request.num_computed_tokens if req_id
13461367
in self.failed_recving_kv_req_ids else len(req_block_ids) *
13471368
self.block_size)
13481369
else:
1349-
# In sync load, num_computed_tokens includes new tokens
1370+
# Sync loading. num_computed_tokens includes new tokens
13501371
req_num_computed_tokens = request.num_cached_tokens
13511372

13521373
req_num_computed_blocks = (req_num_computed_tokens +
@@ -1364,6 +1385,8 @@ def _update_requests_with_invalid_blocks(
13641385
# and was already marked for recomputation.
13651386
# This means this request can still consider this block
13661387
# as computed when rescheduled.
1388+
# Currently this only applies to sync loading; Async
1389+
# loading does not yet support block sharing
13671390
continue
13681391

13691392
marked_invalid_block_ids.add(block_id)
@@ -1374,6 +1397,7 @@ def _update_requests_with_invalid_blocks(
13741397
continue
13751398

13761399
marked_invalid_block = True
1400+
# Truncate the computed tokens at the first failed block
13771401
request.num_computed_tokens = idx * self.block_size
13781402
total_affected_tokens += (req_num_computed_tokens -
13791403
request.num_computed_tokens)
@@ -1383,14 +1407,15 @@ def _update_requests_with_invalid_blocks(
13831407
# All invalid blocks of this request are shared with
13841408
# previous requests and will be recomputed by them.
13851409
# Revert to considering only cached tokens as computed.
1410+
# Currently this only applies to sync loading; Async
1411+
# loading does not yet support block sharing
13861412
total_affected_tokens += (request.num_computed_tokens -
13871413
request.num_cached_tokens)
13881414
request.num_computed_tokens = request.num_cached_tokens
13891415

13901416
affected_req_ids.add(request.request_id)
13911417

1392-
return (affected_req_ids, total_affected_tokens,
1393-
marked_invalid_block_ids)
1418+
return (affected_req_ids, total_affected_tokens)
13941419

13951420
def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
13961421
total_requests_to_reschedule = 0
@@ -1400,36 +1425,31 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
14001425
async_load_reqs = (
14011426
req for req in self.waiting
14021427
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
1403-
(affected_req_ids, num_tokens_to_reschedule,
1404-
marked_invalid_block_ids) = (
1405-
self._update_requests_with_invalid_blocks(async_load_reqs,
1406-
invalid_block_ids))
1428+
async_affected_req_ids, num_tokens_to_reschedule = (
1429+
self._update_requests_with_invalid_blocks(async_load_reqs,
1430+
invalid_block_ids))
14071431

1408-
total_requests_to_reschedule += len(affected_req_ids)
1432+
total_requests_to_reschedule += len(async_affected_req_ids)
14091433
total_tokens_to_reschedule += num_tokens_to_reschedule
14101434

14111435
# Mark requests with async KV load failures; they will be rescheduled
14121436
# once loading completes
1413-
self.failed_recving_kv_req_ids |= affected_req_ids
1414-
1415-
# Remove async loaded invalid blocks already handled,
1416-
# as they cannot be shared with running requests.
1417-
invalid_block_ids.difference_update(marked_invalid_block_ids)
1437+
self.failed_recving_kv_req_ids |= async_affected_req_ids
14181438

14191439
# --- Handle sync KV loads (running requests) ---
1420-
affected_req_ids, num_tokens_to_reschedule, _ = (
1440+
sync_affected_req_ids, num_tokens_to_reschedule = (
14211441
self._update_requests_with_invalid_blocks(self.running,
14221442
invalid_block_ids))
14231443

1424-
total_requests_to_reschedule += len(affected_req_ids)
1444+
total_requests_to_reschedule += len(sync_affected_req_ids)
14251445
total_tokens_to_reschedule += num_tokens_to_reschedule
14261446

14271447
if total_requests_to_reschedule:
1428-
logger.info(
1448+
logger.warning(
14291449
"Recovered from KV load failure: "
14301450
"%d request(s) rescheduled (%d tokens affected).",
14311451
total_requests_to_reschedule, total_tokens_to_reschedule)
14321452

14331453
# Return the IDs of affected running requests to skip in
14341454
# update_from_output.
1435-
return affected_req_ids
1455+
return sync_affected_req_ids

vllm/v1/worker/gpu_input_batch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ class CachedRequestState:
4848
def __post_init__(self):
4949
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
5050
self.prompt_token_ids, self.prompt_embeds)
51-
# 'last_generator_offset' and 'last_gelen_last_output_token_ids' are
52-
# used to allow safe rollback in case a sampled token turns out to be
53-
# invalid (e.g., due to KV load errors).
51+
# 'last_generator_offset' and 'len_last_output_token_ids' are used to
52+
# allow safe rollback in case a sampled token turns out to be invalid
53+
# (e.g., due to KV load errors).
5454
self.last_generator_offset = 0 if self.generator else None
5555
self.len_last_output_token_ids = len(self.output_token_ids)
5656

vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
636636
resumed_from_preemption = req_data.resumed_from_preemption[i]
637637

638638
# Update the cached states.
639-
if (num_computed_tokens <= req_state.num_computed_tokens):
639+
if num_computed_tokens <= req_state.num_computed_tokens:
640640
# The request was rescheduled after a KV load failure. Clear
641641
# the last sampled tokens and rewind the generator state
642642
len_output_token_ids = len(req_state.output_token_ids)

0 commit comments

Comments
 (0)