@@ -879,12 +879,12 @@ def update_from_output(
879879 kv_connector_stats = (kv_connector_output .kv_connector_stats
880880 if kv_connector_output else None )
881881
882- affected_req_ids = None
882+ failed_kv_load_req_ids = None
883883 if kv_connector_output and kv_connector_output .invalid_block_ids :
884884 # These blocks contain externally computed tokens that failed to
885885 # load. Identify affected requests and adjust their computed token
886886 # count to trigger recomputation of the invalid blocks.
887- affected_req_ids = self ._handle_invalid_blocks (
887+ failed_kv_load_req_ids = self ._handle_invalid_blocks (
888888 kv_connector_output .invalid_block_ids )
889889
890890 # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
@@ -894,7 +894,7 @@ def update_from_output(
894894 stopped_preempted_reqs : set [Request ] = set ()
895895 for req_id , num_tokens_scheduled in num_scheduled_tokens .items ():
896896 assert num_tokens_scheduled > 0
897- if affected_req_ids and req_id in affected_req_ids :
897+ if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids :
898898 # Skip requests that were recovered from KV load failure
899899 continue
900900 request = self .requests .get (req_id )
@@ -1325,11 +1325,30 @@ def _update_from_kv_xfer_finished(self,
13251325
13261326 def _update_requests_with_invalid_blocks (
13271327 self , requests : Iterable [Request ],
1328- invalid_block_ids : set [int ]) -> tuple [set [str ], int , set [int ]]:
1328+ invalid_block_ids : set [int ]) -> tuple [set [str ], int ]:
1329+ """
1330+ Identify and update requests affected by invalid KV cache blocks.
1331+
1332+ This method scans the given requests, detects those with invalid blocks
1333+ and adjusts their `num_computed_tokens` to the longest valid prefix.
1334+ For observability, it also accumulates the total number of tokens that
1335+ will need to be recomputed across all affected requests.
1336+
1337+ Args:
1338+ requests: Iterable of requests to scan for invalid blocks.
1339+ invalid_block_ids: IDs of invalid blocks.
1340+
1341+ Returns:
1342+ tuple:
1343+ - affected_req_ids (set[str]): IDs of requests impacted by
1344+ invalid blocks.
1345+ - total_affected_tokens (int): Total number of tokens that must
1346+ be recomputed across all affected requests (for observability).
1347+ """
13291348 affected_req_ids : set [str ] = set ()
13301349 total_affected_tokens = 0
13311350 # If a block is invalid and shared by multiple requests in the batch,
1332- # all requests must be rescheduled, but only the first will recompute
1351+ # these requests must be rescheduled, but only the first will recompute
13331352 # it. This set tracks blocks already marked for recomputation.
13341353 marked_invalid_block_ids : set [int ] = set ()
13351354 for request in requests :
@@ -1341,12 +1360,14 @@ def _update_requests_with_invalid_blocks(
13411360 # We iterate only over blocks that may contain externally computed
13421361 # tokens
13431362 if request .status == RequestStatus .WAITING_FOR_REMOTE_KVS :
1363+ # Async loading. A request already in failed_recving_kv_req_ids had
1364+ # block failures processed in a prior step; reuse its truncated count
13441365 req_num_computed_tokens = (
1345- request .num_computed_tokens if request . request_id
1366+ request .num_computed_tokens if req_id
13461367 in self .failed_recving_kv_req_ids else len (req_block_ids ) *
13471368 self .block_size )
13481369 else :
1349- # In sync load, num_computed_tokens includes new tokens
1370+ # Sync loading. num_computed_tokens includes new tokens
13501371 req_num_computed_tokens = request .num_cached_tokens
13511372
13521373 req_num_computed_blocks = (req_num_computed_tokens +
@@ -1364,6 +1385,8 @@ def _update_requests_with_invalid_blocks(
13641385 # and was already marked for recomputation.
13651386 # This means this request can still consider this block
13661387 # as computed when rescheduled.
1388+ # Currently this only applies to sync loading; async
1389+ # loading does not yet support block sharing.
13671390 continue
13681391
13691392 marked_invalid_block_ids .add (block_id )
@@ -1374,6 +1397,7 @@ def _update_requests_with_invalid_blocks(
13741397 continue
13751398
13761399 marked_invalid_block = True
1400+ # Truncate the computed tokens at the first failed block
13771401 request .num_computed_tokens = idx * self .block_size
13781402 total_affected_tokens += (req_num_computed_tokens -
13791403 request .num_computed_tokens )
@@ -1383,14 +1407,15 @@ def _update_requests_with_invalid_blocks(
13831407 # All invalid blocks of this request are shared with
13841408 # previous requests and will be recomputed by them.
13851409 # Revert to considering only cached tokens as computed.
1410+ # Currently this only applies to sync loading; async
1411+ # loading does not yet support block sharing.
13861412 total_affected_tokens += (request .num_computed_tokens -
13871413 request .num_cached_tokens )
13881414 request .num_computed_tokens = request .num_cached_tokens
13891415
13901416 affected_req_ids .add (request .request_id )
13911417
1392- return (affected_req_ids , total_affected_tokens ,
1393- marked_invalid_block_ids )
1418+ return (affected_req_ids , total_affected_tokens )
13941419
13951420 def _handle_invalid_blocks (self , invalid_block_ids : set [int ]) -> set [str ]:
13961421 total_requests_to_reschedule = 0
@@ -1400,36 +1425,31 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
14001425 async_load_reqs = (
14011426 req for req in self .waiting
14021427 if req .status == RequestStatus .WAITING_FOR_REMOTE_KVS )
1403- (affected_req_ids , num_tokens_to_reschedule ,
1404- marked_invalid_block_ids ) = (
1405- self ._update_requests_with_invalid_blocks (async_load_reqs ,
1406- invalid_block_ids ))
1428+ async_affected_req_ids , num_tokens_to_reschedule = (
1429+ self ._update_requests_with_invalid_blocks (async_load_reqs ,
1430+ invalid_block_ids ))
14071431
1408- total_requests_to_reschedule += len (affected_req_ids )
1432+ total_requests_to_reschedule += len (async_affected_req_ids )
14091433 total_tokens_to_reschedule += num_tokens_to_reschedule
14101434
14111435 # Mark requests with async KV load failures; they will be rescheduled
14121436 # once loading completes
1413- self .failed_recving_kv_req_ids |= affected_req_ids
1414-
1415- # Remove async loaded invalid blocks already handled,
1416- # as they cannot be shared with running requests.
1417- invalid_block_ids .difference_update (marked_invalid_block_ids )
1437+ self .failed_recving_kv_req_ids |= async_affected_req_ids
14181438
14191439 # --- Handle sync KV loads (running requests) ---
1420- affected_req_ids , num_tokens_to_reschedule , _ = (
1440+ sync_affected_req_ids , num_tokens_to_reschedule = (
14211441 self ._update_requests_with_invalid_blocks (self .running ,
14221442 invalid_block_ids ))
14231443
1424- total_requests_to_reschedule += len (affected_req_ids )
1444+ total_requests_to_reschedule += len (sync_affected_req_ids )
14251445 total_tokens_to_reschedule += num_tokens_to_reschedule
14261446
14271447 if total_requests_to_reschedule :
1428- logger .info (
1448+ logger .warning (
14291449 "Recovered from KV load failure: "
14301450 "%d request(s) rescheduled (%d tokens affected)." ,
14311451 total_requests_to_reschedule , total_tokens_to_reschedule )
14321452
14331453 # Return the IDs of affected running requests to skip in
14341454 # update_from_output.
1435- return affected_req_ids
1455+ return sync_affected_req_ids
0 commit comments