Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ def test_multi_xfer_one_engine(
num_xfers + 6,
],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
Expand Down Expand Up @@ -526,6 +527,7 @@ def test_async_load_kv(
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": "prefill-id",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": prefill_tp_size,
Expand Down Expand Up @@ -581,6 +583,7 @@ def test_concurrent_load_kv(
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-id-{i}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
Expand Down Expand Up @@ -746,6 +749,7 @@ def test_kv_connector_stats(dist_init):
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
Expand Down Expand Up @@ -1459,6 +1463,7 @@ def test_handshake_failure_returns_finished(dist_init):
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
Expand Down Expand Up @@ -1508,6 +1513,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):
kv_transfer_params={
"remote_block_ids": [10, 11, 12],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
Expand Down
1 change: 1 addition & 0 deletions tests/v1/kv_connector/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def create_request(
do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_request_id=f"prefill-{request_id}",
remote_block_ids=list(range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,
Expand Down
14 changes: 12 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class ReqMeta:
remote_host: str
remote_port: int
remote_engine_id: str
remote_request_id: str
tp_size: int


Expand All @@ -144,6 +145,7 @@ def add_new_req(
local_physical_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_request_id=kv_transfer_params["remote_request_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
# P workers don't need to receive tp_size from proxy here.
Expand Down Expand Up @@ -530,7 +532,12 @@ def update_state_after_alloc(
if params.get("remote_block_ids"):
if all(
p in params
- for p in ("remote_engine_id", "remote_host", "remote_port")
+ for p in (
+ "remote_engine_id",
+ "remote_request_id",
+ "remote_host",
+ "remote_port",
+ )
):
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
Expand Down Expand Up @@ -659,6 +666,7 @@ def request_finished(
do_remote_decode=False,
remote_block_ids=block_ids,
remote_engine_id=self.engine_id,
remote_request_id=request.request_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
Expand Down Expand Up @@ -1946,6 +1954,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
remote_request_id=meta.remote_request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote_block_ids,
)
Expand All @@ -1956,6 +1965,7 @@ def _read_blocks(
remote_block_ids: list[int],
dst_engine_id: str,
request_id: str,
remote_request_id: str,
):
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
if block_size_ratio > 1:
Expand Down Expand Up @@ -1988,7 +1998,7 @@ def _read_blocks(
# Number of D TP workers that will read from dst P. Propagate tp_ratio
# on notification so that dst worker can wait before freeing blocks.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)
- notif_id = f"{request_id}:{tp_ratio}".encode()
+ notif_id = f"{remote_request_id}:{tp_ratio}".encode()

# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
Expand Down