Conversation


@wseaton wseaton commented Oct 3, 2025

Purpose

This PR integrates the NixlConnector with the additional scheduler features exposed in #19330 for retrying requests whose KV blocks failed to load.

It also includes a small bugfix: if the prefill (P) node crashes during the zmq handshake, the decode (D) node's request would previously get stuck in WAITING_FOR_REMOTE_KV forever.
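
For orientation, here is a minimal sketch of the recovery idea: the connector reports the block IDs whose remote KV read failed, and the scheduler reschedules the affected requests so the blocks are recomputed locally. The class and function names below are illustrative only and are not the actual vLLM / #19330 API.

```python
from dataclasses import dataclass, field


@dataclass
class KVLoadFailure:
    """Hypothetical report from the connector: request id -> failed block IDs."""
    invalid_blocks: dict[str, set[int]] = field(default_factory=dict)


def recover(report: KVLoadFailure, requests: dict[str, object], waiting: list) -> None:
    """Reschedule requests whose remote KV blocks could not be loaded."""
    for req_id, blocks in report.invalid_blocks.items():
        req = requests.get(req_id)
        if req is None:
            continue
        # The failed blocks are discarded, so the next scheduling step
        # recomputes them (i.e. the decode node does a local prefill).
        waiting.append(req)
        print(f"Recovered from KV load failure: {req_id} rescheduled, "
              f"{len(blocks)} block(s) invalidated")
```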

Test Plan

For integration testing, faults were injected into a vLLM process instrumented with https://github.com/wseaton/ucx-fault-injector/, which forces NIXL exceptions to be thrown during transfers.
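
The linked ucx-fault-injector operates at the UCX layer; the wrapper below is only a hand-rolled stand-in showing the shape of fault injection for exercising the recovery path in tests, with made-up names.

```python
import random


class InjectedTransferError(RuntimeError):
    """Simulated transfer failure, standing in for a NIXL error."""


def flaky_transfer(do_transfer, *args, failure_rate: float = 0.2, **kwargs):
    """Call the real transfer, but occasionally raise a simulated NIXL error."""
    if random.random() < failure_rate:
        raise InjectedTransferError("injected NIXL_ERR_REMOTE_DISCONNECT")
    return do_transfer(*args, **kwargs)
```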

Logs

  1. KV Load Failure Recovery after failed metadata handshake (decode does local prefill):
(EngineCore_DP0 pid=3317003) ERROR 10-07 17:11:06 [distributed/.../v1/nixl_connector.py:820]   File "zmq/backend/cython/_zmq.py", line 186, in zmq.backend.cython._zmq._check_rc
(EngineCore_DP0 pid=3317003) ERROR 10-07 17:11:06 [distributed/.../v1/nixl_connector.py:820]     raise Again(errno)
(EngineCore_DP0 pid=3317003) ERROR 10-07 17:11:06 [distributed/.../v1/nixl_connector.py:820]     ^^^^^^^^^^^
(EngineCore_DP0 pid=3317003) ERROR 10-07 17:11:06 [distributed/.../v1/nixl_connector.py:820] zmq.error.Again: Resource temporarily unavailable
(EngineCore_DP0 pid=3317003) ERROR 10-07 17:11:06 [distributed/.../v1/nixl_connector.py:832] Handshake failed for request cmpl-6e8a849f-7ec0-4abe-b4c7-38bce04d53d1-0, marking blocks as invalid
(EngineCore_DP0 pid=3317003) WARNING 10-07 17:11:06 [v1/core/sched/scheduler.py:1511] Recovered from KV load failure: 1 request(s) rescheduled (64 tokens affected).
(EngineCore_DP0 pid=3317003) DEBUG 10-07 17:11:06 [distributed/.../v1/nixl_connector.py:349] NIXLConnector get_num_new_matched_tokens: num_computed_tokens=0, kv_transfer_params={'do_remote_prefill': False, 'do_remote_decode': False, 'remote_block_ids': [1, 2, 3, 4], 'remote_engine_id': 'c1ce8a2a-f52d-42fe-8b54-4faa2d6ade51', 'remote_host': 'localhost', 'remote_port': 45307, 'tp_size': 1}
(EngineCore_DP0 pid=3317003) DEBUG 10-07 17:11:06 [distributed/.../v1/nixl_connector.py:369] NIXLConnector update_state_after_alloc: num_external_tokens=0, kv_transfer_params={'do_remote_prefill': False, 'do_remote_decode': False, 'remote_block_ids': [1, 2, 3, 4], 'remote_engine_id': 'c1ce8a2a-f52d-42fe-8b54-4faa2d6ade51', 'remote_host': 'localhost', 'remote_port': 45307, 'tp_size': 1}
(EngineCore_DP0 pid=3317003) DEBUG 10-07 17:11:09 [distributed/.../v1/nixl_connector.py:477] NIXLConnector request_finished, request_status=FINISHED_LENGTH_CAPPED, kv_transfer_params={'do_remote_prefill': False, 'do_remote_decode': False, 'remote_block_ids': [1, 2, 3, 4], 'remote_engine_id': 'c1ce8a2a-f52d-42fe-8b54-4faa2d6ade51', 'remote_host': 'localhost', 'remote_port': 45307, 'tp_size': 1}
(EngineCore_DP0 pid=3317003) DEBUG 10-07 17:11:09 [v1/engine/core.py:810] EngineCore waiting for work.
(APIServer pid=3316606) INFO 10-07 17:11:09 [v1/metrics/loggers.py:128] Engine 000: Avg prompt throughput: 3.6 tokens/s, Avg generation throughput: 7.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
(APIServer pid=3316606) INFO:     127.0.0.1:40116 - "POST /v1/completions HTTP/1.1" 200 OK
  2. Load failure recovery during block pull (decode does local prefill):
[1759857217.590230](EngineCore_DP0 pid=3320874) ERROR 10-07 17:13:37 [distributed/.../v1/nixl_connector.py:1536] NIXL transfer setup/initiation failed for request cmpl-8f11141d-6a13-42f4-aaf7-548dd83a8952-0: NIXL_ERR_REMOTE_DISCONNECT. Marking blocks as invalid.
(EngineCore_DP0 pid=3320874) WARNING 10-07 17:13:37 [v1/core/sched/scheduler.py:1511] Recovered from KV load failure: 1 request(s) rescheduled (64 tokens affected).
(EngineCore_DP0 pid=3320874) DEBUG 10-07 17:13:37 [distributed/.../v1/nixl_connector.py:349] NIXLConnector get_num_new_matched_tokens: num_computed_tokens=0, kv_transfer_params={'do_remote_prefill': False, 'do_remote_decode': False, 'remote_block_ids': [1, 2, 3, 4], 'remote_engine_id': '37ae0670-4572-46d6-8f06-4de596379492', 'remote_host': 'localhost', 'remote_port': 45307, 'tp_size': 1}
(EngineCore_DP0 pid=3320874) DEBUG 10-07 17:13:37 [distributed/.../v1/nixl_connector.py:369] NIXLConnector update_state_after_alloc: num_external_tokens=0, kv_transfer_params={'do_remote_prefill': False, 'do_remote_decode': False, 'remote_block_ids': [1, 2, 3, 4], 'remote_engine_id': '37ae0670-4572-46d6-8f06-4de596379492', 'remote_host': 'localhost', 'remote_port': 45307, 'tp_size': 1}
(EngineCore_DP0 pid=3320874) DEBUG 10-07 17:13:38 [distributed/.../v1/nixl_connector.py:477] NIXLConnector request_finished, request_status=FINISHED_LENGTH_CAPPED, kv_transfer_params={'do_remote_prefill': False, 'do_remote_decode': False, 'remote_block_ids': [1, 2, 3, 4], 'remote_engine_id': '37ae0670-4572-46d6-8f06-4de596379492', 'remote_host': 'localhost', 'remote_port': 45307, 'tp_size': 1}
(EngineCore_DP0 pid=3320874) DEBUG 10-07 17:13:38 [v1/engine/core.py:810] EngineCore waiting for work.
(APIServer pid=3320562) INFO:     127.0.0.1:44080 - "POST /v1/completions HTTP/1.1" 200 OK
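
Both traces follow the same recovery shape: a failure (a zmq handshake timeout in the first case, a NIXL remote disconnect during the block pull in the second) marks the request's blocks invalid, the scheduler reschedules the request with num_computed_tokens=0, and the decode node completes it with a local prefill. A condensed sketch of that flow, with illustrative names rather than the real nixl_connector internals:

```python
import zmq


def load_remote_blocks(handshake, read_blocks, req_id, block_ids,
                       mark_invalid, reschedule):
    """Try the remote KV path; on any failure, poison the blocks and retry locally."""
    try:
        handshake()             # may raise zmq.error.Again on a handshake timeout
        read_blocks(block_ids)  # may raise on NIXL_ERR_REMOTE_DISCONNECT
    except (zmq.error.Again, RuntimeError) as exc:
        # Handshake or transfer failed: mark the blocks invalid so the scheduler
        # reschedules the request from num_computed_tokens=0, i.e. the decode
        # node does a local prefill instead of pulling remote KV.
        mark_invalid(req_id, block_ids)
        reschedule(req_id)
        print(f"Handshake/transfer failed for {req_id}: {exc!r}; "
              f"marked {len(block_ids)} block(s) invalid")
```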

Future Work

Make this behavior opt-out via a global configuration option, and then enable aborting in the API server for the failure path, since the current recovery mechanism results in local prefills on the decode node.
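
The proposed knob might look roughly like this; the config class and flag name are placeholders, not an existing vLLM option.

```python
from dataclasses import dataclass


@dataclass
class RecoveryConfig:
    # When False, a KV load failure aborts the request in the API server
    # instead of falling back to a local prefill on the decode node.
    retry_failed_kv_load: bool = True


def on_kv_load_failure(cfg: RecoveryConfig, reschedule, abort, req_id: str) -> None:
    if cfg.retry_failed_kv_load:
        reschedule(req_id)  # current behavior: local prefill on decode
    else:
        abort(req_id)       # proposed fail-fast path
```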

mergify bot commented Oct 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wseaton.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces failure recovery mechanisms for the NIXL KV connector. It adds error handling for transfer initiation failures and failures during block reads. When a failure occurs, the affected KV cache blocks are marked as invalid, and this information is propagated to the scheduler for retrying the request. The changes also include adding statistics for failed transfers and notifications, and rate-limiting for some log messages to prevent spam.

The overall approach is sound and significantly improves the robustness of the NIXL connector. However, I've found a critical issue where failed blocks are not reported correctly when use_host_buffer is disabled, which would prevent failure recovery in that configuration. I've left a comment with details on the issue and a suggested fix.
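
Sketch of the reviewer's concern, with hypothetical names: if invalid block IDs are only gathered on the host-buffer path, nothing is reported when use_host_buffer is disabled, so the scheduler never learns about the failure.

```python
def gather_invalid_blocks(use_host_buffer: bool,
                          direct_read_failures: set[int],
                          host_copy_failures: set[int]) -> set[int]:
    """Collect failed block IDs from both read paths, not just the buffered one."""
    invalid = set(direct_read_failures)  # direct device reads must report too
    if use_host_buffer:
        invalid |= host_copy_failures    # plus failures in the host-buffer copy
    return invalid
```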

@mergify mergify bot added the frontend label Oct 3, 2025
@wseaton wseaton changed the title [P/D] [NixlConnecotr] Draft: add KV load failure recovery to nixl connector [P/D] [NixlConnector] Draft: failure handling + context propogation Oct 4, 2025
@wseaton wseaton changed the title [P/D] [NixlConnector] Draft: failure handling + context propogation [P/D] [NixlConnector] Draft: improved failure handling + context propagation Oct 4, 2025
@wseaton wseaton force-pushed the nixl-failure-recovery branch from 035e54b to bfd1f52 Compare October 6, 2025 13:46
@mergify mergify bot removed the needs-rebase label Oct 6, 2025
@wseaton wseaton force-pushed the nixl-failure-recovery branch from 84e9a53 to 23612e9 Compare October 6, 2025 20:56
wseaton and others added 10 commits October 12, 2025 20:14
Signed-off-by: Will Eaton <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Will Eaton <[email protected]>
Signed-off-by: Will Eaton <[email protected]>
Signed-off-by: Will Eaton <[email protected]>
Signed-off-by: Will Eaton <[email protected]>
Signed-off-by: Will Eaton <[email protected]>
Signed-off-by: Will Eaton <[email protected]>
auto-merge was automatically disabled October 13, 2025 00:17

Head branch was pushed to by a user without write access

@wseaton wseaton force-pushed the nixl-failure-recovery branch from 86f86f0 to 52ab9f6 Compare October 13, 2025 00:17

wseaton commented Oct 13, 2025

@njhill this needs a manual merge, had to rebase because of formatting changes 😬

Signed-off-by: Will Eaton <[email protected]>
@njhill njhill merged commit 53c9a7c into vllm-project:main Oct 13, 2025
50 checks passed
VladOS95-cyber pushed a commit to VladOS95-cyber/vllm that referenced this pull request Oct 13, 2025
1994 pushed a commit to 1994/vllm that referenced this pull request Oct 14, 2025
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
bbartels pushed a commit to bbartels/vllm that referenced this pull request Oct 16, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
Zhathw pushed a commit to Zhathw/vllm that referenced this pull request Nov 12, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025