6868 NixlWrapper = None
6969 nixlXferTelemetry = None
7070
71+
7172try :
7273 from nixl ._api import nixl_agent_config
7374except ImportError :
@@ -234,6 +235,11 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
234235 assert self .connector_worker is not None
235236 return self .connector_worker .get_finished ()
236237
def get_block_ids_with_load_errors(self) -> set[int]:
    """Return the block IDs the connector worker reports as failed loads.

    Pure delegation: the result of the worker's
    ``get_block_ids_with_load_errors()`` is returned unchanged.

    Raises:
        AssertionError: if ``connector_worker`` is ``None``.
    """
    worker = self.connector_worker
    # Same guard the sibling delegating methods use before touching the worker.
    assert worker is not None
    return worker.get_block_ids_with_load_errors()
242+
def get_kv_connector_stats(self) -> KVConnectorStats | None:
    """Return whatever the connector worker's ``get_kv_connector_stats()`` yields.

    The value (a stats object or ``None``) is passed through untouched.

    Raises:
        AssertionError: if ``connector_worker`` is ``None``.
    """
    worker = self.connector_worker
    # The worker must exist before stats can be requested from it.
    assert worker is not None
    return worker.get_kv_connector_stats()
@@ -614,6 +620,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
614620 # Set of requests that have been part of a batch, regardless of status.
615621 self ._reqs_to_process : set [ReqId ] = set ()
616622
623+ # invalid blocks from failed NIXL operations
624+ self ._invalid_block_ids : set [int ] = set ()
625+ # requests that skipped transfer (handshake or transfer failures)
626+ self ._failed_recv_reqs : set [ReqId ] = set ()
627+
617628 # Background thread for handling new handshake requests.
618629 self ._nixl_handshake_listener_t : threading .Thread | None = None
619630 # Background thread for initializing new NIXL handshakes.
@@ -713,6 +724,8 @@ def _nixl_handshake(
713724
714725 # Send query for the request.
715726 with zmq_ctx (zmq .REQ , path ) as sock :
727+ # Set receive timeout to 5 seconds to avoid hanging on dead server
728+ sock .setsockopt (zmq .RCVTIMEO , 5000 ) # milliseconds
716729 sock .send (GET_META_MSG )
717730 metadata_bytes = sock .recv ()
718731 decoder = msgspec .msgpack .Decoder (NixlAgentMetadata )
@@ -795,10 +808,20 @@ def done_callback(f: Future[dict[int, str]], eid=remote_engine_id):
795808
796809 fut .add_done_callback (done_callback )
797810
798- # TODO: handle failure state of future in the
799- # callback, we want to fail the request in this case.
800- def request_ready (_f : Future [Any ], entry = (req_id , meta )):
801- self ._ready_requests .put (entry )
811+ # check handshake success before proceeding with request
812+ def request_ready (f : Future [Any ], entry = (req_id , meta )):
813+ try :
814+ # check if handshake succeeded
815+ f .result ()
816+ self ._ready_requests .put (entry )
817+ except Exception :
818+ # handshake failed - mark blocks as invalid
819+ logger .exception (
820+ "Handshake failed for request %s, marking blocks as invalid" , req_id
821+ )
822+ if req_meta := self ._recving_metadata .get (req_id ):
823+ self ._invalid_block_ids .update (req_meta .local_block_ids )
824+ self ._failed_recv_reqs .add (req_id )
802825
803826 fut .add_done_callback (request_ready )
804827
@@ -1205,6 +1228,11 @@ def get_finished(self) -> tuple[set[str], set[str]]:
12051228 """
12061229 done_sending = self ._get_new_notifs ()
12071230 done_recving = self ._pop_done_transfers (self ._recving_transfers )
1231+
1232+ # add requests that skipped transfer to done_recving
1233+ done_recving .update (self ._failed_recv_reqs )
1234+ self ._failed_recv_reqs .clear ()
1235+
12081236 if len (done_sending ) > 0 or len (done_recving ) > 0 :
12091237 logger .debug (
12101238 "Rank %s, get_finished: %s requests done sending "
@@ -1214,10 +1242,10 @@ def get_finished(self) -> tuple[set[str], set[str]]:
12141242 len (done_recving ),
12151243 )
12161244
1217- if self . use_host_buffer :
1218- for req_id in done_recving :
1219- meta = self ._recving_metadata .pop (req_id )
1220- assert meta , f" { req_id } not found in recving_metadata list"
1245+ # clean up metadata for completed requests
1246+ for req_id in done_recving :
1247+ meta = self ._recving_metadata .pop (req_id , None )
1248+ if self . use_host_buffer and meta :
12211249 self .sync_recved_kv_to_device (req_id , meta )
12221250
12231251 # Handle timeout to avoid stranding blocks on remote.
@@ -1296,7 +1324,19 @@ def _pop_done_transfers(
12961324 in_progress = True
12971325 continue
12981326 else :
1299- raise RuntimeError ("Transfer failed with state %s" , xfer_state )
1327+ # transfer failed - mark blocks as invalid
1328+ logger .error (
1329+ "NIXL transfer failed for request %s with state %s. "
1330+ "Marking blocks as invalid." ,
1331+ req_id ,
1332+ xfer_state ,
1333+ )
1334+ # mark all blocks for this request as invalid
1335+ if meta := self ._recving_metadata .pop (req_id , None ):
1336+ self ._invalid_block_ids .update (meta .local_block_ids )
1337+ self ._recving_metadata .pop (req_id , None )
1338+ self .nixl_wrapper .release_xfer_handle (handle )
1339+ self .xfer_stats .record_failed_transfer ()
13001340 if not in_progress :
13011341 done_req_ids .add (req_id )
13021342 del transfers [req_id ]
@@ -1317,8 +1357,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
13171357 len (meta .local_block_ids ),
13181358 len (meta .remote_block_ids ),
13191359 )
1320- if self . use_host_buffer :
1321- self ._recving_metadata [req_id ] = meta
1360+ # always store metadata for failure recovery
1361+ self ._recving_metadata [req_id ] = meta
13221362 if remote_engine_id not in self ._remote_agents :
13231363 # Initiate handshake with remote engine to exchange metadata.
13241364 with self ._handshake_lock :
@@ -1394,7 +1434,16 @@ def _read_blocks(
13941434 if num_local_blocks == 0 :
13951435 remote_rank = self .tp_rank // tp_ratio
13961436 agent_name = self ._remote_agents [dst_engine_id ][remote_rank ]
1397- self .nixl_wrapper .send_notif (agent_name , notif_msg = notif_id )
1437+ try :
1438+ self .nixl_wrapper .send_notif (agent_name , notif_msg = notif_id )
1439+ except Exception :
1440+ logger .exception (
1441+ "NIXL send_notif failed for request %s: "
1442+ "P worker blocks will be freed after timeout. "
1443+ "This may indicate network issues." ,
1444+ request_id ,
1445+ )
1446+ self .xfer_stats .record_failed_notification ()
13981447 return
13991448
14001449 # Partial prefix cache hit: just read uncomputed blocks.
@@ -1456,20 +1505,35 @@ def _read_blocks(
14561505 assert len (local_block_descs_ids ) == len (remote_block_descs_ids )
14571506
14581507 # Prepare transfer with Nixl.
1459- handle = self .nixl_wrapper .make_prepped_xfer (
1460- "READ" ,
1461- local_xfer_side_handle ,
1462- local_block_descs_ids ,
1463- remote_xfer_side_handle ,
1464- remote_block_descs_ids ,
1465- notif_msg = notif_id ,
1466- )
1508+ handle = None
1509+ try :
1510+ handle = self .nixl_wrapper .make_prepped_xfer (
1511+ "READ" ,
1512+ local_xfer_side_handle ,
1513+ local_block_descs_ids ,
1514+ remote_xfer_side_handle ,
1515+ remote_block_descs_ids ,
1516+ notif_msg = notif_id ,
1517+ )
14671518
1468- # Begin async xfer.
1469- self .nixl_wrapper .transfer (handle )
1519+ # Begin async xfer.
1520+ self .nixl_wrapper .transfer (handle )
14701521
1471- # Use handle to check completion in future step().
1472- self ._recving_transfers [request_id ].append ((handle , time .perf_counter ()))
1522+ # Use handle to check completion in future step().
1523+ self ._recving_transfers [request_id ].append ((handle , time .perf_counter ()))
1524+ except Exception :
1525+ logger .exception (
1526+ "NIXL transfer setup/initiation failed for request %s. "
1527+ "Marking blocks as invalid." ,
1528+ request_id ,
1529+ )
1530+ # mark all blocks for this request as invalid
1531+ if meta := self ._recving_metadata .get (request_id ):
1532+ self ._invalid_block_ids .update (meta .local_block_ids )
1533+ self .xfer_stats .record_failed_transfer ()
1534+ if handle is not None :
1535+ self .nixl_wrapper .release_xfer_handle (handle )
1536+ self ._failed_recv_reqs .add (request_id )
14731537
14741538 def _get_block_descs_ids (
14751539 self , engine_id : str , block_ids : list [int ], layer_idx : int | None = None
@@ -1527,6 +1591,17 @@ def get_kv_connector_stats(self) -> KVConnectorStats | None:
15271591 return self .xfer_stats .clone_and_reset ()
15281592 return None
15291593
def get_block_ids_with_load_errors(self) -> set[int]:
    """
    Hand back the accumulated failed-load block IDs and start a fresh set.

    The set currently stored in ``_invalid_block_ids`` is returned as-is
    (not copied) and a new empty set takes its place, so each failed
    block ID is reported by exactly one call.

    Per the original note, the scheduler calls this to learn which blocks
    need to be retried after a NIXL transfer failure.
    """
    # Right-hand side is evaluated first, so the old set is captured
    # before the attribute is rebound to the new empty set.
    failed_block_ids, self._invalid_block_ids = self._invalid_block_ids, set()
    return failed_block_ids
1604+
15301605 def shutdown (self ):
15311606 """Shutdown the connector worker."""
15321607 self ._handshake_initiation_executor .shutdown (wait = False )
@@ -1586,6 +1661,8 @@ def reset(self):
15861661 "post_duration" : [],
15871662 "bytes_transferred" : [],
15881663 "num_descriptors" : [],
1664+ "num_failed_transfers" : [],
1665+ "num_failed_notifications" : [],
15891666 }
15901667
15911668 def record_transfer (self , res : nixlXferTelemetry ):
@@ -1595,6 +1672,14 @@ def record_transfer(self, res: nixlXferTelemetry):
15951672 self .data ["bytes_transferred" ].append (res .totalBytes )
15961673 self .data ["num_descriptors" ].append (res .descCount )
15971674
def record_failed_transfer(self):
    """Log one failed NIXL transfer by appending a ``1.0`` sample.

    The sample goes onto the ``"num_failed_transfers"`` list in ``self.data``.
    """
    failed_transfers = self.data["num_failed_transfers"]
    failed_transfers.append(1.0)
1678+
def record_failed_notification(self):
    """Log one failed NIXL notification (send_notif) by appending ``1.0``.

    The sample goes onto the ``"num_failed_notifications"`` list in
    ``self.data``.
    """
    failed_notifications = self.data["num_failed_notifications"]
    failed_notifications.append(1.0)
1682+
15981683 def clone_and_reset (self ) -> "NixlKVConnectorStats" :
15991684 old = copy .copy (self )
16001685 self .reset ()
0 commit comments