diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index aaac2deb12ac..1a705ad827c7 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -1948,6 +1948,84 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
     assert len(scheduler.waiting) == 1
 
 
+def test_kv_connector_finished_sending_race_condition():
+    """
+    Test the race condition where a request is in finished_sending
+    but not actually in a finished state on the scheduler side.
+
+    This can happen when:
+    1. Worker-side NIXL connector times out waiting for decode workers
+    2. Worker reports request in finished_sending to prevent stranding blocks
+    3. Scheduler-side request hasn't reached a finished state yet
+
+    Before the fix, this would crash with AssertionError in _free_blocks.
+    After the fix, it should log a warning and skip block freeing.
+    """
+    from vllm.v1.outputs import KVConnectorOutput
+
+    # Setup scheduler with KV connector
+    scheduler = create_scheduler(
+        enable_prefix_caching=True,
+        use_kv_connector=True,
+    )
+
+    # Create and schedule a request
+    requests = create_requests(num_requests=1, max_tokens=10)
+    request = requests[0]
+    scheduler.add_request(request)
+
+    # Schedule the request
+    scheduler_output = scheduler.schedule()
+    assert len(scheduler_output.scheduled_new_reqs) == 1
+    assert request.request_id in scheduler.requests
+
+    # Simulate model execution - generate one token but DON'T finish
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[request.request_id],
+        req_id_to_index={request.request_id: 0},
+        sampled_token_ids=[[100]],  # One token, not EOS
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+
+    # Update with this output - request should still be RUNNING
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert request.status == RequestStatus.RUNNING
+    assert len(scheduler.running) == 1
+
+    # Now simulate the race condition: create a KVConnectorOutput that
+    # reports this request as finished_sending, even though the request
+    # is still RUNNING on the scheduler side.
+    # This simulates the timeout scenario in NIXL connector.
+    kv_connector_output = KVConnectorOutput(finished_sending={request.request_id})
+
+    # Schedule again to trigger the race condition
+    scheduler_output2 = scheduler.schedule()
+    model_runner_output2 = ModelRunnerOutput(
+        req_ids=[request.request_id],
+        req_id_to_index={request.request_id: 0},
+        sampled_token_ids=[[101]],  # Another token, not EOS
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+        kv_connector_output=kv_connector_output,
+    )
+
+    # This should handle the race condition gracefully
+    # by logging a warning and skipping block freeing.
+    _ = scheduler.update_from_output(scheduler_output2, model_runner_output2)
+
+    # Verify the request is still in the system and running
+    # (i.e., it was NOT incorrectly freed)
+    assert request.request_id in scheduler.requests
+    assert request.status == RequestStatus.RUNNING
+    assert len(scheduler.running) == 1
+
+    # The request should NOT have been freed
+    assert request.request_id not in scheduler.finished_req_ids
+
+
 def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
     """Test that priority scheduling preempts lower priority requests when
     out of KV cache space."""
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 08368b7d99ef..387145372e86 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -1365,8 +1365,24 @@ def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput):
             self.finished_recving_kv_req_ids.add(req_id)
         for req_id in kv_connector_output.finished_sending or ():
             logger.debug("Finished sending KV transfer for request %s", req_id)
-            assert req_id in self.requests
-            self._free_blocks(self.requests[req_id])
+            request = self.requests.get(req_id)
+            if request is None:
+                logger.warning(
+                    "Got finished sending KV transfer for request %s, "
+                    "but the request is already freed.",
+                    req_id,
+                )
+            elif not request.is_finished():
+                logger.warning(
+                    "Got finished sending KV transfer for request %s, "
+                    "but the request is not finished (status=%s). "
+                    "This may indicate the request was aborted or the KV "
+                    "transfer timed out before the request completed.",
+                    req_id,
+                    request.status,
+                )
+            else:
+                self._free_blocks(request)
 
     def _update_requests_with_invalid_blocks(
         self, requests: Iterable[Request], invalid_block_ids: set[int]
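
For review convenience, here is a minimal, self-contained sketch of the defensive pattern the scheduler hunk adopts: look the request up with `dict.get()` and check its status before freeing, so a stale or premature `finished_sending` notification degrades to a logged warning instead of an `AssertionError`. The `MiniScheduler`, `Request`, and `Status` classes below are simplified stand-ins invented for illustration, not vLLM's real types; only the control flow mirrors the patch.

```python
import enum
import logging

logger = logging.getLogger(__name__)


class Status(enum.Enum):
    RUNNING = "running"
    FINISHED = "finished"


class Request:
    def __init__(self, req_id: str, status: Status = Status.RUNNING) -> None:
        self.req_id = req_id
        self.status = status

    def is_finished(self) -> bool:
        return self.status is Status.FINISHED


class MiniScheduler:
    def __init__(self) -> None:
        self.requests: dict[str, Request] = {}
        self.freed: list[str] = []

    def _free_blocks(self, request: Request) -> None:
        # Stand-in for returning the request's KV cache blocks to the pool.
        self.freed.append(request.req_id)

    def on_finished_sending(self, req_id: str) -> None:
        # The pre-fix code asserted `req_id in self.requests` and freed
        # unconditionally; here both racy cases become warnings instead.
        request = self.requests.get(req_id)
        if request is None:
            logger.warning("finished_sending for %s: already freed", req_id)
        elif not request.is_finished():
            logger.warning(
                "finished_sending for %s: not finished yet (status=%s)",
                req_id,
                request.status,
            )
        else:
            self._free_blocks(request)


sched = MiniScheduler()
sched.requests["r1"] = Request("r1")

sched.on_finished_sending("r1")  # still RUNNING: warn, skip freeing
assert sched.freed == []

sched.requests["r1"].status = Status.FINISHED
sched.on_finished_sending("r1")  # finished: blocks are freed
assert sched.freed == ["r1"]

sched.on_finished_sending("r2")  # unknown id: warn, no crash
```

Warning-level logging, rather than silently skipping, keeps the timeout/abort race visible in production logs while leaving block cleanup to the request's normal finish path.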