 from typing import Optional
 from unittest.mock import patch
 
-import pytest
-
-try:
-    from nixl._api import nixl_agent as NixlWrapper
-except ImportError:
-    NixlWrapper = None
-
 from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker)
@@ -92,7 +85,8 @@ def test_prompt_less_than_block_size():
 class FakeNixlWrapper:
     """Mock implementation of NixlWrapper for testing.
 
-    We don't inherit from NixlWrapper because NixlWrapper could be None.
+    We don't inherit from nixl._api.nixl_agent because nixl may not be
+    installed.
     """
 
     AGENT_METADATA = b"fake_agent_metadata"
@@ -167,7 +161,7 @@ def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
         super().__init__(*args, **kwargs)
         self._hand_shake_latency = hand_shake_latency
 
-    def _nixl_handshake(self, host: str, port: int):
+    def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
         # Mimic slow _nixl_handshake, as well as bypass zmq communication.
         time.sleep(self._hand_shake_latency)
         # These should've been done in register_kv_caches(), called by
@@ -177,7 +171,7 @@ def _nixl_handshake(self, host: str, port: int):
         self.num_blocks = 1
         self.dst_num_blocks[self.engine_id] = self.num_blocks
 
-        self.add_remote_agent(
+        remote_agent_name = self.add_remote_agent(
             NixlAgentMetadata(
                 engine_id=self.REMOTE_ENGINE_ID,
                 agent_metadata=FakeNixlWrapper.AGENT_METADATA,
@@ -187,60 +181,176 @@ def _nixl_handshake(self, host: str, port: int):
                 block_len=self.block_len,
                 attn_backend_name=self.backend_name,
             ))
-
-
-@pytest.mark.skipif(NixlWrapper is None, reason="nixl not installed")
-@patch(
-    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
-    FakeNixlWrapper)
-def test_multi_xfer_one_engine(
-        # dist_init is a fixture that initializes the distributed environment.
-        dist_init):
-    """Test case where multiple xfers are initiated to the same engine.
-
-    This test triggers the connector to load remote KV for the same
-    `request_id`. The transfer is not done immediately due to
-    `set_cycles_before_xfer_done`, so there is a state where there are multiple
-    transfer states for the same `request_id`, and `get_finished` should handle
-    it correctly (wait for all transfers to be done).
-    """
-    vllm_config = create_vllm_config()
-
-    request_id = "req_id"
-
-    # Test worker role in decode server.
-    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
-    connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
-                                                         connector.engine_id,
-                                                         hand_shake_latency=0)
-    assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
-    connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
-    for i in range(4):
+        return {0: remote_agent_name}
+
+
+class TestNixlHandshake:
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_multi_xfer_one_engine(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test case where multiple xfers are initiated to the same engine.
+
+        This test triggers the connector to load remote KV for the same
+        `request_id`. The transfer is not done immediately due to
+        `set_cycles_before_xfer_done`, so there is a state where there are
+        multiple transfer states for the same `request_id`, and `get_finished`
+        should handle it correctly (wait for all transfers to be done).
+        """
+        vllm_config = create_vllm_config()
+
+        request_id = "req_id"
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id, hand_shake_latency=0)
+        assert isinstance(connector.connector_worker.nixl_wrapper,
+                          FakeNixlWrapper)
+        connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
+        num_xfers = 4
+        while True:
+            # For the same request_id, initiate multiple xfers across different
+            # rounds of `execute_model` calls.
+            metadata = NixlConnectorMetadata()
+            if num_xfers > 0:
+                num_xfers -= 1
+                metadata.add_new_req(
+                    request_id=request_id,
+                    local_block_ids=[
+                        num_xfers + 1, num_xfers + 2, num_xfers + 3
+                    ],
+                    kv_transfer_params={
+                        "remote_block_ids":
+                        [num_xfers + 4, num_xfers + 5, num_xfers + 6],
+                        "remote_engine_id":
+                        FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                        "remote_host":
+                        "localhost",
+                        "remote_port":
+                        1234,
+                    })
+            connector.bind_connector_metadata(metadata)
+
+            # Mimic maybe_setup_kv_connector in gpu_model_runner.
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+
+            # Mimic get_finished_kv_transfers in gpu_model_runner.
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                assert request_id in done_recving
+                break
+
+            connector.clear_connector_metadata()
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_async_load_kv(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test that NixlConnector's start_load_kv is non-blocking."""
+
+        vllm_config = create_vllm_config()
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id)
         metadata = NixlConnectorMetadata()
-        metadata.add_new_req(request_id=request_id,
-                             local_block_ids=[i + 1, i + 2, i + 3],
+        metadata.add_new_req(request_id="id",
+                             local_block_ids=[1, 2, 3],
                              kv_transfer_params={
-                                 "remote_block_ids": [i + 4, i + 5, i + 6],
+                                 "remote_block_ids": [4, 5, 6],
                                  "remote_engine_id":
                                  FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                                  "remote_host": "localhost",
                                  "remote_port": 1234,
                              })
         connector.bind_connector_metadata(metadata)
 
-        dummy_ctx = ForwardContext(
-            no_compile_layers={},
-            attn_metadata={},
-            virtual_engine=0,
-        )
-        _before_load = time.perf_counter()
-        connector.start_load_kv(dummy_ctx)
-        _after_load = time.perf_counter()
-        assert _after_load - _before_load < 0.1, "start_load_kv took " \
-            f"{_after_load - _before_load} seconds"
-
-    while True:
-        _, done_recving = connector.get_finished(finished_req_ids=set())
-        if len(done_recving) > 0:
-            assert request_id in done_recving
-            break
+        timeout = 2.5
+        start = time.perf_counter()
+        while time.perf_counter() - start < timeout:
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+            time.sleep(0.5)  # backoff for the async handshake to complete.
+            connector.bind_connector_metadata(NixlConnectorMetadata())
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                return
+        raise TimeoutError("Took too long to complete async handshake.")
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_concurrent_load_kv(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test that multiple start_load_kv calls occur concurrently."""
+
+        vllm_config = create_vllm_config()
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id)
+        metadata = NixlConnectorMetadata()
+        total_reqs = 5
+        for i in range(total_reqs):
+            metadata.add_new_req(request_id=f"id_{i}",
+                                 local_block_ids=[1, 2, 3],
+                                 kv_transfer_params={
+                                     "remote_block_ids": [4, 5, 6],
+                                     "remote_engine_id":
+                                     FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                                     "remote_host": "localhost",
+                                     "remote_port": 1234,
+                                 })
+        connector.bind_connector_metadata(metadata)
+
+        timeout = 2.5 * total_reqs
+        cnt_finished_reqs = 0
+        start = time.perf_counter()
+        while time.perf_counter() - start < timeout:
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+            time.sleep(0.5)  # backoff for the async handshake to complete.
+            connector.bind_connector_metadata(NixlConnectorMetadata())
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                cnt_finished_reqs += len(done_recving)
+                if cnt_finished_reqs == total_reqs:
+                    return
+        raise TimeoutError("Took too long to complete async handshake.")
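
All three tests above repeat the same bind-metadata / start_load_kv / poll-get_finished loop. Below is a minimal sketch of a shared helper that captures this pattern; it is an illustration only, not part of the change. The helper name and its timeout/poll_interval/max_load_secs parameters are hypothetical, it assumes the test module's existing imports (time, ForwardContext, NixlConnectorMetadata), and callers would bind the initial NixlConnectorMetadata before entering the loop, as the tests do.

def _poll_until_recv_done(connector, expected_req_ids, *, timeout=2.5,
                          poll_interval=0.5, max_load_secs=0.1):
    """Drive the connector until every expected request's KV has arrived."""
    remaining = set(expected_req_ids)
    start = time.perf_counter()
    while time.perf_counter() - start < timeout:
        dummy_ctx = ForwardContext(
            no_compile_layers={},
            attn_metadata={},
            virtual_engine=0,
        )
        before = time.perf_counter()
        connector.start_load_kv(dummy_ctx)  # must stay non-blocking
        assert time.perf_counter() - before < max_load_secs
        time.sleep(poll_interval)  # let async handshakes/transfers progress
        connector.bind_connector_metadata(NixlConnectorMetadata())
        _, done_recving = connector.get_finished(finished_req_ids=set())
        remaining.difference_update(done_recving)
        if not remaining:
            return
    raise TimeoutError("KV transfers did not finish in time.")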