Skip to content

Commit 715fc83

Browse files
wseaton0xrushi
authored and committed
[P/D] [NixlConnector] kv load recovery integration (vllm-project#26171)
Signed-off-by: Will Eaton <[email protected]> Signed-off-by: 0xrushi <[email protected]>
1 parent 60b847d commit 715fc83

File tree

3 files changed

+252
-26
lines changed

3 files changed

+252
-26
lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,6 @@ def _make_fake_nixl_pkg():
190190
# Copy of FakeNixlWrapper implementation for Ray workers
191191
import uuid
192192
from collections import defaultdict
193-
from typing import Optional
194193
195194
{fake_nixl_source}
196195
@@ -1143,3 +1142,145 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
11431142
# After abort, the worker should not keep tracking it as "in-batch"
11441143
assert req.request_id not in connector.connector_worker._reqs_to_process
11451144
#### Model Runner end ####
1145+
1146+
1147+
class FailingNixlWrapper(FakeNixlWrapper):
    """FakeNixlWrapper variant whose operations can be forced to fail.

    Three boolean switches (all False after construction) select which
    method raises instead of delegating to the fake implementation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Failure switches; a test flips one to True to inject that fault.
        self.fail_handshake = False
        self.fail_transfer_setup = False
        self.fail_send_notif = False

    def add_remote_agent(self, agent_metadata: bytes) -> str:
        """Delegate to the fake, or raise zmq ``Again`` when
        ``fail_handshake`` is set."""
        if not self.fail_handshake:
            return super().add_remote_agent(agent_metadata)

        # Imported lazily so zmq is only touched on the failure path.
        from zmq.error import Again

        raise Again("Simulated timeout failure")

    def make_prepped_xfer(
        self,
        xfer_type: str,
        local_xfer_side_handle: int,
        local_block_descs_ids: list[int],
        remote_xfer_side_handle: int,
        remote_block_descs_ids: list[int],
        notif_msg: bytes | None = None,
    ) -> int:
        """Delegate to the fake, or raise ``RuntimeError`` when
        ``fail_transfer_setup`` is set."""
        if not self.fail_transfer_setup:
            return super().make_prepped_xfer(
                xfer_type,
                local_xfer_side_handle,
                local_block_descs_ids,
                remote_xfer_side_handle,
                remote_block_descs_ids,
                notif_msg,
            )

        # A plain RuntimeError stands in for a failed transfer setup.
        raise RuntimeError("BAD STATUS")

    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
        """Delegate to the fake, or raise ``RuntimeError`` when
        ``fail_send_notif`` is set."""
        if not self.fail_send_notif:
            return super().send_notif(agent_name, notif_msg)

        raise RuntimeError("Simulated send_notif failure")
1188+
1189+
1190+
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FailingNixlWrapper,
)
def test_handshake_failure_returns_finished(dist_init):
    """A failed handshake must invalidate the request's local blocks and
    report the request through get_finished."""
    config = create_vllm_config()

    connector = NixlConnector(config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        config, connector.engine_id, hand_shake_latency=0.1
    )
    # Make the wrapper raise when the handshake tries to add the remote agent.
    connector.connector_worker.nixl_wrapper.fail_handshake = True

    request_id = "test_handshake_fail"
    transfer_params = {
        "remote_block_ids": [4, 5, 6],
        "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
        "remote_host": "localhost",
        "remote_port": 1234,
        "remote_tp_size": 1,
    }
    req_metadata = NixlConnectorMetadata()
    req_metadata.add_new_req(
        request_id=request_id,
        local_block_ids=[1, 2, 3],
        kv_transfer_params=transfer_params,
    )
    connector.bind_connector_metadata(req_metadata)

    forward_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )
    connector.start_load_kv(forward_ctx)

    # Give the background handshake (0.1s simulated latency) time to fail.
    time.sleep(0.3)

    # The request's local blocks must be reported as invalid ...
    failed_blocks = connector.get_block_ids_with_load_errors()
    assert failed_blocks == {1, 2, 3}

    # ... and the request itself must show up as done receiving.
    _, finished_recving = connector.get_finished(finished_req_ids=set())
    assert request_id in finished_recving
1236+
1237+
1238+
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FailingNixlWrapper,
)
def test_transfer_setup_failure_returns_finished(dist_init):
    """A failure while preparing the transfer must invalidate the request's
    local blocks and report the request through get_finished."""
    config = create_vllm_config()

    connector = NixlConnector(config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        config, connector.engine_id, hand_shake_latency=0
    )
    # Make the wrapper raise from make_prepped_xfer.
    connector.connector_worker.nixl_wrapper.fail_transfer_setup = True

    request_id = "test_transfer_fail"
    transfer_params = {
        "remote_block_ids": [10, 11, 12],
        "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
        "remote_host": "localhost",
        "remote_port": 1234,
        "remote_tp_size": 1,
    }
    req_metadata = NixlConnectorMetadata()
    req_metadata.add_new_req(
        request_id=request_id,
        local_block_ids=[7, 8, 9],
        kv_transfer_params=transfer_params,
    )
    connector.bind_connector_metadata(req_metadata)

    forward_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )
    connector.start_load_kv(forward_ctx)

    # Let the handshake finish, then run a second step with empty metadata
    # so the now-ready request is processed.
    connector.bind_connector_metadata(NixlConnectorMetadata())
    time.sleep(0.1)
    connector.start_load_kv(forward_ctx)

    # The request's local blocks must be reported as invalid ...
    failed_blocks = connector.get_block_ids_with_load_errors()
    assert failed_blocks == {7, 8, 9}

    # ... and the request itself must show up as done receiving.
    _, finished_recving = connector.get_finished(finished_req_ids=set())
    assert request_id in finished_recving

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 109 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
NixlWrapper = None
6969
nixlXferTelemetry = None
7070

71+
7172
try:
7273
from nixl._api import nixl_agent_config
7374
except ImportError:
@@ -234,6 +235,11 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
234235
assert self.connector_worker is not None
235236
return self.connector_worker.get_finished()
236237

238+
def get_block_ids_with_load_errors(self) -> set[int]:
    """Return the block IDs the connector worker reports as failed loads.

    Requires ``self.connector_worker`` to be set; the call is forwarded
    to the worker unchanged.
    """
    worker = self.connector_worker
    assert worker is not None
    return worker.get_block_ids_with_load_errors()
242+
237243
def get_kv_connector_stats(self) -> KVConnectorStats | None:
238244
assert self.connector_worker is not None
239245
return self.connector_worker.get_kv_connector_stats()
@@ -614,6 +620,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
614620
# Set of requests that have been part of a batch, regardless of status.
615621
self._reqs_to_process: set[ReqId] = set()
616622

623+
# invalid blocks from failed NIXL operations
624+
self._invalid_block_ids: set[int] = set()
625+
# requests that skipped transfer (handshake or transfer failures)
626+
self._failed_recv_reqs: set[ReqId] = set()
627+
617628
# Background thread for handling new handshake requests.
618629
self._nixl_handshake_listener_t: threading.Thread | None = None
619630
# Background thread for initializing new NIXL handshakes.
@@ -713,6 +724,8 @@ def _nixl_handshake(
713724

714725
# Send query for the request.
715726
with zmq_ctx(zmq.REQ, path) as sock:
727+
# Set receive timeout to 5 seconds to avoid hanging on dead server
728+
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
716729
sock.send(GET_META_MSG)
717730
metadata_bytes = sock.recv()
718731
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
@@ -795,10 +808,20 @@ def done_callback(f: Future[dict[int, str]], eid=remote_engine_id):
795808

796809
fut.add_done_callback(done_callback)
797810

798-
# TODO: handle failure state of future in the
799-
# callback, we want to fail the request in this case.
800-
def request_ready(_f: Future[Any], entry=(req_id, meta)):
801-
self._ready_requests.put(entry)
811+
# check handshake success before proceeding with request
812+
def request_ready(f: Future[Any], entry=(req_id, meta)):
    """Done-callback for the handshake future.

    On success the bound ``entry`` is queued on ``_ready_requests``. On
    any exception the request's local blocks (if its metadata is still
    tracked) are added to ``_invalid_block_ids`` and the request is added
    to ``_failed_recv_reqs``.
    """
    try:
        # result() re-raises whatever the handshake future captured.
        f.result()
        self._ready_requests.put(entry)
    except Exception:
        logger.exception(
            "Handshake failed for request %s, marking blocks as invalid", req_id
        )
        failed_meta = self._recving_metadata.get(req_id)
        if failed_meta:
            self._invalid_block_ids.update(failed_meta.local_block_ids)
        self._failed_recv_reqs.add(req_id)
802825

803826
fut.add_done_callback(request_ready)
804827

@@ -1205,6 +1228,11 @@ def get_finished(self) -> tuple[set[str], set[str]]:
12051228
"""
12061229
done_sending = self._get_new_notifs()
12071230
done_recving = self._pop_done_transfers(self._recving_transfers)
1231+
1232+
# add requests that skipped transfer to done_recving
1233+
done_recving.update(self._failed_recv_reqs)
1234+
self._failed_recv_reqs.clear()
1235+
12081236
if len(done_sending) > 0 or len(done_recving) > 0:
12091237
logger.debug(
12101238
"Rank %s, get_finished: %s requests done sending "
@@ -1214,10 +1242,10 @@ def get_finished(self) -> tuple[set[str], set[str]]:
12141242
len(done_recving),
12151243
)
12161244

1217-
if self.use_host_buffer:
1218-
for req_id in done_recving:
1219-
meta = self._recving_metadata.pop(req_id)
1220-
assert meta, f"{req_id} not found in recving_metadata list"
1245+
# clean up metadata for completed requests
1246+
for req_id in done_recving:
1247+
meta = self._recving_metadata.pop(req_id, None)
1248+
if self.use_host_buffer and meta:
12211249
self.sync_recved_kv_to_device(req_id, meta)
12221250

12231251
# Handle timeout to avoid stranding blocks on remote.
@@ -1296,7 +1324,19 @@ def _pop_done_transfers(
12961324
in_progress = True
12971325
continue
12981326
else:
1299-
raise RuntimeError("Transfer failed with state %s", xfer_state)
1327+
# transfer failed - mark blocks as invalid
1328+
logger.error(
1329+
"NIXL transfer failed for request %s with state %s. "
1330+
"Marking blocks as invalid.",
1331+
req_id,
1332+
xfer_state,
1333+
)
1334+
# mark all blocks for this request as invalid
1335+
if meta := self._recving_metadata.pop(req_id, None):
1336+
self._invalid_block_ids.update(meta.local_block_ids)
1337+
self._recving_metadata.pop(req_id, None)
1338+
self.nixl_wrapper.release_xfer_handle(handle)
1339+
self.xfer_stats.record_failed_transfer()
13001340
if not in_progress:
13011341
done_req_ids.add(req_id)
13021342
del transfers[req_id]
@@ -1317,8 +1357,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
13171357
len(meta.local_block_ids),
13181358
len(meta.remote_block_ids),
13191359
)
1320-
if self.use_host_buffer:
1321-
self._recving_metadata[req_id] = meta
1360+
# always store metadata for failure recovery
1361+
self._recving_metadata[req_id] = meta
13221362
if remote_engine_id not in self._remote_agents:
13231363
# Initiate handshake with remote engine to exchange metadata.
13241364
with self._handshake_lock:
@@ -1394,7 +1434,16 @@ def _read_blocks(
13941434
if num_local_blocks == 0:
13951435
remote_rank = self.tp_rank // tp_ratio
13961436
agent_name = self._remote_agents[dst_engine_id][remote_rank]
1397-
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
1437+
try:
1438+
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
1439+
except Exception:
1440+
logger.exception(
1441+
"NIXL send_notif failed for request %s: "
1442+
"P worker blocks will be freed after timeout. "
1443+
"This may indicate network issues.",
1444+
request_id,
1445+
)
1446+
self.xfer_stats.record_failed_notification()
13981447
return
13991448

14001449
# Partial prefix cache hit: just read uncomputed blocks.
@@ -1456,20 +1505,35 @@ def _read_blocks(
14561505
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
14571506

14581507
# Prepare transfer with Nixl.
1459-
handle = self.nixl_wrapper.make_prepped_xfer(
1460-
"READ",
1461-
local_xfer_side_handle,
1462-
local_block_descs_ids,
1463-
remote_xfer_side_handle,
1464-
remote_block_descs_ids,
1465-
notif_msg=notif_id,
1466-
)
1508+
handle = None
1509+
try:
1510+
handle = self.nixl_wrapper.make_prepped_xfer(
1511+
"READ",
1512+
local_xfer_side_handle,
1513+
local_block_descs_ids,
1514+
remote_xfer_side_handle,
1515+
remote_block_descs_ids,
1516+
notif_msg=notif_id,
1517+
)
14671518

1468-
# Begin async xfer.
1469-
self.nixl_wrapper.transfer(handle)
1519+
# Begin async xfer.
1520+
self.nixl_wrapper.transfer(handle)
14701521

1471-
# Use handle to check completion in future step().
1472-
self._recving_transfers[request_id].append((handle, time.perf_counter()))
1522+
# Use handle to check completion in future step().
1523+
self._recving_transfers[request_id].append((handle, time.perf_counter()))
1524+
except Exception:
1525+
logger.exception(
1526+
"NIXL transfer setup/initiation failed for request %s. "
1527+
"Marking blocks as invalid.",
1528+
request_id,
1529+
)
1530+
# mark all blocks for this request as invalid
1531+
if meta := self._recving_metadata.get(request_id):
1532+
self._invalid_block_ids.update(meta.local_block_ids)
1533+
self.xfer_stats.record_failed_transfer()
1534+
if handle is not None:
1535+
self.nixl_wrapper.release_xfer_handle(handle)
1536+
self._failed_recv_reqs.add(request_id)
14731537

14741538
def _get_block_descs_ids(
14751539
self, engine_id: str, block_ids: list[int], layer_idx: int | None = None
@@ -1527,6 +1591,17 @@ def get_kv_connector_stats(self) -> KVConnectorStats | None:
15271591
return self.xfer_stats.clone_and_reset()
15281592
return None
15291593

1594+
def get_block_ids_with_load_errors(self) -> set[int]:
    """
    Hand back the accumulated failed-load block IDs and reset the tracker.

    The set currently held in ``_invalid_block_ids`` is returned as-is
    (not copied) and a fresh empty set takes its place, so each failed
    block ID is reported by exactly one call.
    """
    failed_block_ids, self._invalid_block_ids = self._invalid_block_ids, set()
    return failed_block_ids
1604+
15301605
def shutdown(self):
15311606
"""Shutdown the connector worker."""
15321607
self._handshake_initiation_executor.shutdown(wait=False)
@@ -1586,6 +1661,8 @@ def reset(self):
15861661
"post_duration": [],
15871662
"bytes_transferred": [],
15881663
"num_descriptors": [],
1664+
"num_failed_transfers": [],
1665+
"num_failed_notifications": [],
15891666
}
15901667

15911668
def record_transfer(self, res: nixlXferTelemetry):
@@ -1595,6 +1672,14 @@ def record_transfer(self, res: nixlXferTelemetry):
15951672
self.data["bytes_transferred"].append(res.totalBytes)
15961673
self.data["num_descriptors"].append(res.descCount)
15971674

1675+
def record_failed_transfer(self):
    """Add one entry (1.0) to the ``num_failed_transfers`` series."""
    failed_transfers = self.data["num_failed_transfers"]
    failed_transfers.append(1.0)
1678+
1679+
def record_failed_notification(self):
    """Add one entry (1.0) to the ``num_failed_notifications`` series."""
    failed_notifs = self.data["num_failed_notifications"]
    failed_notifs.append(1.0)
1682+
15981683
def clone_and_reset(self) -> "NixlKVConnectorStats":
15991684
old = copy.copy(self)
16001685
self.reset()

vllm/v1/core/sched/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1487,7 +1487,7 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
14871487
total_tokens_to_reschedule += num_tokens_to_reschedule
14881488

14891489
# Mark requests with async KV load failures; they will be rescheduled
1490-
# once loading completes
1490+
# once loading completes.
14911491
self.failed_recving_kv_req_ids |= async_affected_req_ids
14921492

14931493
# --- Handle sync KV loads (running requests) ---

0 commit comments

Comments
 (0)