
Commit 91f7d9d

lk-chen and njhill authored
[P/D] Asynchronously do _nixl_handshake (#19836)
Signed-off-by: Linkun Chen <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent 8619e71 commit 91f7d9d
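
The diff below covers only the updated unit tests; the connector-side change that actually makes the handshake asynchronous is in the second changed file, which is not shown in this excerpt. As a rough sketch of the behaviour the tests expect (the handshake is submitted to a background thread so start_load_kv stays non-blocking, and get_finished later releases the queued requests), assuming a simple ThreadPoolExecutor design; the class and helper names below are illustrative, not the vLLM implementation:

    # Illustrative sketch only, not the vLLM implementation: the handshake is
    # submitted to a background thread so start_load_kv() stays non-blocking,
    # and get_finished() releases requests whose handshake has completed.
    from concurrent.futures import Future, ThreadPoolExecutor

    class AsyncHandshakeSketch:

        def __init__(self):
            self._executor = ThreadPoolExecutor(max_workers=1)
            self._pending: dict[str, Future] = {}     # engine_id -> handshake future
            self._waiting: dict[str, list[str]] = {}  # engine_id -> queued request_ids

        def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
            # Stand-in for the real (slow, zmq-based) metadata exchange.
            return {0: f"remote-agent@{host}:{port}"}

        def start_load_kv(self, engine_id: str, request_id: str, host: str,
                          port: int) -> None:
            # Non-blocking: start at most one handshake per remote engine and
            # queue the request instead of waiting for the handshake here.
            if engine_id not in self._pending:
                self._pending[engine_id] = self._executor.submit(
                    self._nixl_handshake, host, port)
            self._waiting.setdefault(engine_id, []).append(request_id)

        def get_finished(self) -> list[str]:
            # Polled every scheduler step: release requests whose remote
            # handshake has completed since the last call.
            ready: list[str] = []
            for engine_id, fut in list(self._pending.items()):
                if fut.done():
                    fut.result()  # surface handshake errors, if any
                    ready.extend(self._waiting.pop(engine_id, []))
                    del self._pending[engine_id]
            return ready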

File tree

2 files changed: +264 -96 lines changed


tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 168 additions & 58 deletions
@@ -7,13 +7,6 @@
 from typing import Optional
 from unittest.mock import patch
 
-import pytest
-
-try:
-    from nixl._api import nixl_agent as NixlWrapper
-except ImportError:
-    NixlWrapper = None
-
 from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker)
@@ -92,7 +85,8 @@ def test_prompt_less_than_block_size():
 class FakeNixlWrapper:
     """Mock implementation of NixlWrapper for testing.
 
-    We don't inherit from NixlWrapper because NixlWrapper could be None.
+    We don't inherit from nixl._api.nixl_agent because nixl may not be
+    installed.
     """
 
     AGENT_METADATA = b"fake_agent_metadata"
@@ -167,7 +161,7 @@ def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
         super().__init__(*args, **kwargs)
         self._hand_shake_latency = hand_shake_latency
 
-    def _nixl_handshake(self, host: str, port: int):
+    def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
         # Mimic slow _nixl_handshake, as well as bypass zmq communication.
         time.sleep(self._hand_shake_latency)
         # These should've been done in register_kv_caches(), called by
@@ -177,7 +171,7 @@ def _nixl_handshake(self, host: str, port: int):
         self.num_blocks = 1
         self.dst_num_blocks[self.engine_id] = self.num_blocks
 
-        self.add_remote_agent(
+        remote_agent_name = self.add_remote_agent(
             NixlAgentMetadata(
                 engine_id=self.REMOTE_ENGINE_ID,
                 agent_metadata=FakeNixlWrapper.AGENT_METADATA,
@@ -187,60 +181,176 @@
                 block_len=self.block_len,
                 attn_backend_name=self.backend_name,
             ))
-
-
-@pytest.mark.skipif(NixlWrapper is None, reason="nixl not installed")
-@patch(
-    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
-    FakeNixlWrapper)
-def test_multi_xfer_one_engine(
-        # dist_init is a fixture that initializes the distributed environment.
-        dist_init):
-    """Test case where multiple xfers are initiated to the same engine.
-
-    This test triggers the connector to load remote KV for the same
-    `request_id`. The transfer is not done immediately due to
-    `set_cycles_before_xfer_done`, so there is a state where there are multiple
-    transfer states for the same `request_id`, and `get_finished` should handle
-    it correctly (wait for all transfers to be done).
-    """
-    vllm_config = create_vllm_config()
-
-    request_id = "req_id"
-
-    # Test worker role in decode server.
-    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
-    connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
-                                                         connector.engine_id,
-                                                         hand_shake_latency=0)
-    assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
-    connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
-    for i in range(4):
+        return {0: remote_agent_name}
+
+
+class TestNixlHandshake:
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_multi_xfer_one_engine(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test case where multiple xfers are initiated to the same engine.
+
+        This test triggers the connector to load remote KV for the same
+        `request_id`. The transfer is not done immediately due to
+        `set_cycles_before_xfer_done`, so there is a state where there are
+        multiple transfer states for the same `request_id`, and `get_finished`
+        should handle it correctly (wait for all transfers to be done).
+        """
+        vllm_config = create_vllm_config()
+
+        request_id = "req_id"
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id, hand_shake_latency=0)
+        assert isinstance(connector.connector_worker.nixl_wrapper,
+                          FakeNixlWrapper)
+        connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
+        num_xfers = 4
+        while True:
+            # For the same request_id, initiate multiple xfers across different
+            # round of `execute_model` calls.
+            metadata = NixlConnectorMetadata()
+            if num_xfers > 0:
+                num_xfers -= 1
+                metadata.add_new_req(
+                    request_id=request_id,
+                    local_block_ids=[
+                        num_xfers + 1, num_xfers + 2, num_xfers + 3
+                    ],
+                    kv_transfer_params={
+                        "remote_block_ids":
+                        [num_xfers + 4, num_xfers + 5, num_xfers + 6],
+                        "remote_engine_id":
+                        FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                        "remote_host":
+                        "localhost",
+                        "remote_port":
+                        1234,
+                    })
+            connector.bind_connector_metadata(metadata)
+
+            # Mimic maybe_setup_kv_connector in gpu_model_runner.
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+
+            # Mimic get_finished_kv_transfers in gpu_model_runner.
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                assert request_id in done_recving
+                break
+
+            connector.clear_connector_metadata()
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_async_load_kv(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test that NixlConnector's start_load_kv should be non-blocking."""
+
+        vllm_config = create_vllm_config()
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id)
         metadata = NixlConnectorMetadata()
-        metadata.add_new_req(request_id=request_id,
-                             local_block_ids=[i + 1, i + 2, i + 3],
+        metadata.add_new_req(request_id="id",
+                             local_block_ids=[1, 2, 3],
                              kv_transfer_params={
-                                 "remote_block_ids": [i + 4, i + 5, i + 6],
+                                 "remote_block_ids": [4, 5, 6],
                                  "remote_engine_id":
                                  FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                                  "remote_host": "localhost",
                                  "remote_port": 1234,
                              })
         connector.bind_connector_metadata(metadata)
 
-        dummy_ctx = ForwardContext(
-            no_compile_layers={},
-            attn_metadata={},
-            virtual_engine=0,
-        )
-        _before_load = time.perf_counter()
-        connector.start_load_kv(dummy_ctx)
-        _after_load = time.perf_counter()
-        assert _after_load - _before_load < 0.1, "start_load_kv took " \
-            f"{_after_load - _before_load} seconds"
-
-    while True:
-        _, done_recving = connector.get_finished(finished_req_ids=set())
-        if len(done_recving) > 0:
-            assert request_id in done_recving
-            break
+        timeout = 2.5
+        start = time.perf_counter()
+        while time.perf_counter() - start < timeout:
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+            time.sleep(0.5)  # backoff for the async handshake to complete.
+            connector.bind_connector_metadata(NixlConnectorMetadata())
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                return
+        raise TimeoutError("Took too long to complete async handshake.")
+
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_concurrent_load_kv(
+            self,
+            # dist_init is a fixture that initializes the distributed environment.
+            dist_init):
+        """Test that multiple start_load_kv calls should occur concurrently."""
+
+        vllm_config = create_vllm_config()
+
+        # Test worker role in decode server.
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id)
+        metadata = NixlConnectorMetadata()
+        total_reqs = 5
+        for i in range(total_reqs):
+            metadata.add_new_req(request_id=f"id_{i}",
+                                 local_block_ids=[1, 2, 3],
+                                 kv_transfer_params={
+                                     "remote_block_ids": [4, 5, 6],
+                                     "remote_engine_id":
+                                     FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                                     "remote_host": "localhost",
+                                     "remote_port": 1234,
+                                 })
+        connector.bind_connector_metadata(metadata)
+
+        timeout = 2.5 * total_reqs
+        cnt_finished_reqs = 0
+        start = time.perf_counter()
+        while time.perf_counter() - start < timeout:
+            dummy_ctx = ForwardContext(
+                no_compile_layers={},
+                attn_metadata={},
+                virtual_engine=0,
+            )
+            _before_load = time.perf_counter()
+            connector.start_load_kv(dummy_ctx)
+            _after_load = time.perf_counter()
+            assert _after_load - _before_load < 0.1, "start_load_kv took " \
+                f"{_after_load - _before_load} seconds"
+            time.sleep(0.5)  # backoff for the async handshake to complete.
+            connector.bind_connector_metadata(NixlConnectorMetadata())
+            _, done_recving = connector.get_finished(finished_req_ids=set())
+            if len(done_recving) > 0:
+                cnt_finished_reqs += len(done_recving)
+                if cnt_finished_reqs == total_reqs:
+                    return
+        raise TimeoutError("Took too long to complete async handshake.")
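
Because these tests patch NixlWrapper with FakeNixlWrapper, they no longer require nixl to be installed (hence the removal of the pytest.mark.skipif guard above). Assuming a vLLM development checkout with pytest available, the new test class can be run on its own with, for example:

    pytest tests/v1/kv_connector/unit/test_nixl_connector.py -k TestNixlHandshake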
