Commit 06847be

robertgshaw2-redhat (Robert Shaw) authored, with co-authors ApostaC and tlrmchlsmth (Tyler Michael Smith)
[P/D Disagg] [1/N] Support Homogeneous TP > 1 (#65)
Squashed commit history:

* [Update] LMcache connector v1 implementation (ApostaC)
* [Add] Examples for disaggregated prefill (ApostaC)
* [Add] Extra information about envs (ApostaC)
* Initial stubs for P/D scheduling changes (Tyler Michael Smith)
* Rs branch (#3, #5)
* Remove Unneeded Arguments (#7)
* Improve disagg-example.sh (#8): fix spelling; CUDA_VISIBLE_DEVICES should be set externally
* Added connector; many incremental "updated" commits, including a revert of commit 97316d9
* Diffs for local dev on macOS (Robert Shaw)
* Scheduler-side updates; cleanup; ensure request is removed from the running list
* Runs E2E with garbage output and crashes on the 2nd request; second request no longer crashes
* Remove gpu_model_runner hacks; clean up Justfile
* [Bugfix] Stale finished requests in EMPTY_MODEL_RUNNER_OUTPUT
* Fixes: lm_eval gsm8k has correctness; "just delete the assert"; fix up pre-commit issues
* updated (#12)
* Add Accuracy Test (#13)
* Preemption Bugfixes (#15): fixed double-free issue
* updated (#16)
* Fix Bad Merge | Fix Memory Leak in Upstream (#18)
* Code cleanup; revert spurious changes
* Review updates to vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py (co-authored by Tyler Michael Smith)

Signed-off-by: ApostaC <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: [email protected] <[email protected]>
Signed-off-by: Robert Shaw <[email protected]>
Co-authored-by: ApostaC <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
1 parent 8061a5c commit 06847be

File tree

1 file changed: +160 −125 lines

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 160 additions & 125 deletions
@@ -17,6 +17,9 @@
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
+from vllm.distributed.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
+    get_tp_group)
 from vllm.logger import init_logger
 from vllm.sampling_params import KVTransferParams
 from vllm.utils import round_down
@@ -47,8 +50,6 @@ class NixlAgentMetadata(
         dict=True):
     engine_id: str
     agent_metadata: bytes
-    # Base addr for each layer for KVs
-    # NOTE: we will need another list for TP>1
     kv_caches_base_addr: list[int]
     num_blocks: int
 
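The hunk above drops the planned per-rank address-list from NixlAgentMetadata: with homogeneous TP, each rank handshakes only with its counterpart rank, so one flat list of base addresses suffices. As a minimal sketch of how a struct like this can be serialized for the side channel (the msgspec base class options and all field values here are assumptions for illustration, not vLLM's definition):

```python
# Hypothetical stand-in for NixlAgentMetadata; msgspec.Struct options and
# values are illustrative assumptions, not the vLLM source.
import msgspec


class AgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
    engine_id: str
    agent_metadata: bytes
    kv_caches_base_addr: list[int]  # one base address per cache region
    num_blocks: int


meta = AgentMetadata(
    engine_id="prefill-0",
    agent_metadata=b"nixl-agent-blob",
    kv_caches_base_addr=[0x7F0000000000, 0x7F0000100000],
    num_blocks=4096,
)

# Encode for the ZMQ side channel, then decode back into a typed struct.
wire = msgspec.msgpack.encode(meta)
decoded = msgspec.msgpack.decode(wire, type=AgentMetadata)
assert decoded.num_blocks == 4096
```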
@@ -222,47 +223,53 @@ def __init__(self, engine_id: str):
 
         # Agent.
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
-        # Map of engine_id -> list[agent_names] (1 per rank).
-        self._remote_agents: dict[str, list[str]] = {}
+        # Map of engine_id -> agent_name.
+        self._remote_agents: dict[str, str] = {}
 
         # Metadata.
         self.engine_id = engine_id
-        self.rank = 0
+        self.rank = get_tensor_model_parallel_rank()
+        self.world_size = get_tensor_model_parallel_world_size()
+        self.tp_group = get_tp_group()
 
         # KV Caches and nixl tracking data.
         self.kv_caches: dict[str, torch.Tensor] = {}
 
         # Map of engine_id -> kv_caches_base_addr
-        # For Local: base addr for *this* rank, each layer for K,V
-        # For Remote: base addr for *each* rank, each layer for K,V
-        # KV_CACHES_ADDR_TYPE = Union[list[tuple[int, int]],
-        #                             list[list[tuple[int, int]]]]
         self.kv_caches_base_addr: dict[str, list[int]] = {}
 
         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
         self.num_regions = 0
 
-        # Map of tp_mult -> nixl_prepped_dlist_handle (int).
-        self.src_xfer_side_handles: dict[int, int] = {}
-        # Map of engine_id -> map[tp_mult -> nixl_prepped_dlist_handle (int)].
-        self.dst_xfer_side_handles: defaultdict[str,
-                                                dict[int,
-                                                     int]] = defaultdict(dict)
+        # nixl_prepped_dlist_handle (int).
+        self.src_xfer_side_handle: int = 0
+        # Map of engine_id -> nixl_prepped_dlist_handle (int).
+        self.dst_xfer_side_handles: dict[str, int] = {}
+
         # Map of engine_id -> num_blocks.
         self.dst_num_blocks: dict[str, int] = {}
         self._registered_descs: list[Any] = []
 
         # In progress transfers.
         # [req_id -> list[handle]]
-        self._recving_transfers: dict[str, list[Any]] = defaultdict(list[Any])
+        self._recving_transfers: defaultdict[str, list[Any]] = defaultdict(
+            list[Any])
+
+        # Completed transfer tracker. Used by rank 0 to track finished
+        # transactions on ranks 1 to N-1.
+        # [req_id -> count]
+        self._done_recving_count: defaultdict[str,
+                                              int] = defaultdict(lambda: 0)
+        self._done_sending_count: defaultdict[str,
+                                              int] = defaultdict(lambda: 0)
 
         # Background thread for establishing new connections.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
 
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
-                                 ready_event: threading.Event):
+                                 ready_event: threading.Event, rank: int):
         """Background thread for getting new NIXL handshakes."""
         # NOTE(rob): this is a simple implementation. We will move
         # to a better approach like an ETCD server in the future.
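To see why `_recving_transfers` above maps a request id to a *list* of handles: a single request's KV read can be issued as more than one NIXL transfer, and the request only counts as received once every handle is done. A single-process stand-in for that bookkeeping (dict-based fake handles, not NIXL objects):

```python
# Stand-in sketch of per-request transfer tracking; "handles" here are
# plain dicts, not NIXL transfer handles.
from collections import defaultdict
from typing import Any

recving_transfers: defaultdict[str, list[Any]] = defaultdict(list)

# Two transfers in flight for the same request (hypothetical states).
recving_transfers["req-A"].append({"state": "DONE"})
recving_transfers["req-A"].append({"state": "PROC"})


def pop_done(transfers: defaultdict[str, list[Any]]) -> set[str]:
    """Pop and return request ids whose transfers have all completed."""
    done: set[str] = set()
    for req_id in list(transfers.keys()):
        if all(h["state"] == "DONE" for h in transfers[req_id]):
            del transfers[req_id]
            done.add(req_id)
    return done


assert pop_done(recving_transfers) == set()      # one handle still in flight
recving_transfers["req-A"][1]["state"] = "DONE"
assert pop_done(recving_transfers) == {"req-A"}  # now fully received
```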
@@ -280,8 +287,13 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
 
         # Listen for new requests for metadata.
         host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
-        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT
-        with zmq_ctx(zmq.ROUTER, f"tcp://{host}:{port}") as sock:
+        # NOTE(rob): we need each rank to have a unique port. This
+        # hack keeps us moving. We will switch when moving to etcd
+        # or when we have a single ZMQ socket in the scheduler.
+        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
+        path = f"tcp://{host}:{port}"
+        logger.debug("Starting listening on path: %s", path)
+        with zmq_ctx(zmq.ROUTER, path) as sock:
             ready_event.set()
             while True:
                 identity, _, msg = sock.recv_multipart()
@@ -294,7 +306,12 @@ def _nixl_handshake(self, host: str, port: int):
         """Do a NIXL handshake with a remote instance."""
 
         start_time = time.perf_counter()
-        with zmq_ctx(zmq.REQ, f"tcp://{host}:{port}") as sock:
+        # NOTE(rob): we need each rank to have a unique port. This is
+        # a hack to keep us moving. We will switch when moving to etcd
+        # or when we have a single ZMQ socket in the scheduler.
+        path = f"tcp://{host}:{port + self.rank}"
+        logger.debug("Querying metadata on path: %s", path)
+        with zmq_ctx(zmq.REQ, path) as sock:
             # Send query for the request.
             sock.send(GET_META_MSG)
             metadata_bytes = sock.recv()
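The two hunks above give each TP rank its own side-channel port (base port + rank), so rank i of the decode instance always handshakes with rank i of the prefill instance. A self-contained sketch of that pattern with raw pyzmq (the port number, message bytes, and helper names are illustrative; vLLM wraps this in its zmq_ctx helper):

```python
# Minimal per-rank ROUTER/REQ handshake sketch; constants are illustrative.
import threading

import zmq

BASE_PORT = 5600  # stand-in for VLLM_NIXL_SIDE_CHANNEL_PORT


def listener(rank: int, payload: bytes, ready: threading.Event) -> None:
    """Serve this rank's metadata on base port + rank (one request)."""
    ctx = zmq.Context()
    sock = ctx.socket(zmq.ROUTER)
    sock.bind(f"tcp://127.0.0.1:{BASE_PORT + rank}")
    ready.set()
    identity, _, _msg = sock.recv_multipart()  # request for metadata
    sock.send_multipart((identity, b"", payload))
    sock.close()
    ctx.term()


def handshake(rank: int) -> bytes:
    """Query the remote rank's listener on the matching port offset."""
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    sock.connect(f"tcp://127.0.0.1:{BASE_PORT + rank}")
    sock.send(b"get_meta")
    reply = sock.recv()
    sock.close()
    ctx.term()
    return reply


ready = threading.Event()
threading.Thread(target=listener, args=(0, b"agent-meta-rank0", ready),
                 daemon=True).start()
ready.wait()
assert handshake(0) == b"agent-meta-rank0"
```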
@@ -364,90 +381,125 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
-            args=(metadata, ready_event),
+            args=(metadata, ready_event, self.rank),
             daemon=True,
             name="nixl_handshake_listener")
-        import os
-        if os.getenv("SKIP", None) != "1":
-            self._nixl_handshake_listener_t.start()
-            ready_event.wait()
+        self._nixl_handshake_listener_t.start()
+        ready_event.wait()
 
-    def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, tp_idx=0):
+    def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
         engine_id = nixl_agent_meta.engine_id
         if engine_id in self._remote_agents:
             return
 
-        num_blocks = nixl_agent_meta.num_blocks
-        logger.debug("Adding remote agent %s %s", engine_id, str(num_blocks))
-
-        agent_names = [
-            self.nixl_wrapper.add_remote_agent(nixl_agent_meta.agent_metadata)
-        ]
-
-        self._remote_agents[engine_id] = agent_names
+        self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
+            nixl_agent_meta.agent_metadata)
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr
 
-        # NOTE: once we support heterogeneous TP, we will need to
-        # maintain the src for each TP multiplier.
-        # NOTE(rob): Dynamo only supports D TP size > P TP size.
-        # https://github.com/vllm-project/vllm/pull/16124/files#diff-876efa5533f5dcff3fba850e8684a47d53c112e287988957c115b11691374f4bR331 # noqa: E501
-        # Create descs and xfer side handles.
-        tp_multiplier = 1
-        dst_block_len = self.block_len // tp_multiplier
-        if tp_multiplier not in self.src_xfer_side_handles:
-            # Create descs and xfer side handles.
-            blocks_data = []
-            for base_addr in self.kv_caches_base_addr[self.engine_id]:
-                for block_id in range(self.num_blocks):
-                    block_offset = block_id * self.block_len
-                    for i in range(tp_multiplier):
-                        tp_multiplier_offset = tp_idx * dst_block_len
-                        blocks_data.append(
-                            (base_addr + block_offset + tp_multiplier_offset,
-                             dst_block_len, self.rank))
-            logger.debug("Created %s blocks for src engine %s and rank %s",
-                         len(blocks_data), self.engine_id, self.rank)
-
-            # Register with NIXL.
-            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-            self.src_xfer_side_handles[tp_multiplier] = (
-                self.nixl_wrapper.prep_xfer_dlist("", descs))
-
-        # create dst xfer side handles
-        self.dst_num_blocks[engine_id] = num_blocks
+        # Create src descs and xfer side handles.
+        blocks_data = []
+        for base_addr in self.kv_caches_base_addr[self.engine_id]:
+            for block_id in range(self.num_blocks):
+                block_offset = block_id * self.block_len
+                # (addr, len, device id)
+                blocks_data.append(
+                    (base_addr + block_offset, self.block_len, self.rank))
+        logger.debug("Created %s blocks for src engine %s and rank %s",
+                     len(blocks_data), self.engine_id, self.rank)
+
+        # Register with NIXL.
+        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
+            "NIXL_INIT_AGENT", descs)
+
+        # Create dst descs and xfer side handles.
+        self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
         blocks_data = []
         for base_addr in self.kv_caches_base_addr[engine_id]:
-            for block_id in range(num_blocks):
-                block_offset = block_id * dst_block_len
-                blocks_data.append((base_addr + block_offset, dst_block_len,
-                                    self.rank * tp_multiplier))
+            for block_id in range(nixl_agent_meta.num_blocks):
+                block_offset = block_id * self.block_len
+                # (addr, len, device id)
+                blocks_data.append(
+                    (base_addr + block_offset, self.block_len, self.rank))
         logger.debug("Created %s blocks for dst engine %s and rank %s",
                      len(blocks_data), engine_id, self.rank)
+
         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-        self.dst_xfer_side_handles[engine_id][tp_idx] = (
-            self.nixl_wrapper.prep_xfer_dlist(
-                self._remote_agents[engine_id][self.rank * tp_multiplier +
-                                               tp_idx], descs))
+        self.dst_xfer_side_handles[
+            engine_id] = self.nixl_wrapper.prep_xfer_dlist(
+                self._remote_agents[engine_id], descs)
 
435-
"""Get requests that are done sending or recving."""
435+
"""
436+
Get requests that are done sending or recving.
437+
438+
In TP>1 setup, each rank exchanges KVs with its counterpart
439+
ranks independently. get_finished() runs in a worker creates
440+
the done_sending and done_recving sets that are sent to the
441+
scheduler via ModelRunnerOutput by Rank 0. To avoid race
442+
ensure trnxs are done before adding to finished, Ranks 1 to
443+
N-1 communicate to Rank 0 once their transaction is done.
444+
Rank 0 only returns finished once all ranks are complete.
445+
"""
436446
done_sending = self._get_new_notifs()
437447
done_recving = self._pop_done_transfers(self._recving_transfers)
438448
if len(done_sending) > 0 or len(done_recving) > 0:
439449
logger.debug(
440-
"get_finished: %s requests done sending "
441-
"and %s requests done recving", len(done_sending),
450+
"Rank %s, get_finished: %s requests done sending "
451+
"and %s requests done recving", self.rank, len(done_sending),
442452
len(done_recving))
443-
return done_sending, done_recving
453+
454+
if self.world_size == 1:
455+
return done_sending, done_recving
456+
457+
# Rank 0: get finished from all other ranks.
458+
if self.rank == 0:
459+
for req_id in done_sending:
460+
self._done_sending_count[req_id] += 1
461+
for req_id in done_recving:
462+
self._done_recving_count[req_id] += 1
463+
464+
# Keep track of how many other ranks have finished.
465+
other_ranks_finished_ids: list[str] = []
466+
for i in range(1, self.world_size):
467+
other_ranks_finished_ids.extend(
468+
self.tp_group.recv_object(src=i))
469+
for req_id in other_ranks_finished_ids:
470+
if (req_id in self._done_recving_count
471+
or req_id in self._recving_transfers):
472+
self._done_recving_count[req_id] += 1
473+
else:
474+
self._done_sending_count[req_id] += 1
475+
476+
# Return ids that finished on all ranks to the scheduler.
477+
all_done_recving: set[str] = set()
478+
for req_id in list(self._done_recving_count.keys()):
479+
if self._done_recving_count[req_id] == self.world_size:
480+
del self._done_recving_count[req_id]
481+
all_done_recving.add(req_id)
482+
483+
all_done_sending: set[str] = set()
484+
for req_id in list(self._done_sending_count.keys()):
485+
if self._done_sending_count[req_id] == self.world_size:
486+
del self._done_sending_count[req_id]
487+
all_done_sending.add(req_id)
488+
489+
return all_done_sending, all_done_recving
490+
491+
# Ranks 1 to N-1: send finished ids to Rank 0.
492+
else:
493+
finished_req_ids = list(done_recving.union(done_sending))
494+
self.tp_group.send_object(finished_req_ids, dst=0)
495+
496+
# Unused as only Rank 0 results are sent to scheduler.
497+
return done_sending, done_recving
444498

445499
def _get_new_notifs(self) -> set[str]:
446500
"""Get req_ids which got a remote xfer message."""
447501

448502
notified_req_ids: set[str] = set()
449-
# TODO: handle the TP case (N notifies for TP=N).
450-
# See: vllm/worker/worker_base.py L476 in DynamoPR.
451503
for req_ids in self.nixl_wrapper.get_new_notifs().values():
452504
for req_id in req_ids:
453505
assert req_id not in notified_req_ids
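The rank-0 aggregation in get_finished() can be modeled without any distributed machinery: count per-request completions from every rank and release a request to the scheduler only when its count reaches world_size. A simplified single-process model of the sending side of that protocol (names are illustrative, not vLLM APIs):

```python
# Single-process model of rank 0's completion counting.
from collections import defaultdict

WORLD_SIZE = 4
done_sending_count: defaultdict[str, int] = defaultdict(int)


def rank0_update(local_done: set[str],
                 other_ranks_done: list[str]) -> set[str]:
    """Fold one step's local and remote completions; return fully-done ids."""
    for req_id in local_done:
        done_sending_count[req_id] += 1
    for req_id in other_ranks_done:
        done_sending_count[req_id] += 1
    finished = {r for r, c in done_sending_count.items() if c == WORLD_SIZE}
    for req_id in finished:
        del done_sending_count[req_id]
    return finished


# Ranks 1..3 report "req-A"; nothing is released until rank 0 also sees it.
assert rank0_update(set(), ["req-A", "req-A", "req-A"]) == set()
assert rank0_update({"req-A"}, []) == {"req-A"}
```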
@@ -539,61 +591,44 @@ def _read_blocks(
         if len(local_block_ids) == 0:
             return
 
-        # TODO: support TP multipliers.
-        tp_multiplier = 1
-        remote_block_descs_ids = self._get_block_descs_ids(
-            dst_engine_id, "all", remote_block_ids)
-        local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
-
-        # Read the data from the remote.
-        for i in range(tp_multiplier):
-            local_block_descs_ids = self._get_block_descs_ids(
-                self.engine_id,
-                "all",
-                local_block_ids,
-                i=None,  # TODO: Enable both tp_multiplier and staging_ranges.
-                tp_multiplier=tp_multiplier,
-                staging_ranges=None)
-            assert len(local_block_descs_ids) == len(remote_block_descs_ids)
-            remote_xfer_side_handle = self.dst_xfer_side_handles[
-                dst_engine_id][i]
-
-            # NOTE(rob): we use the request_id as the notify msg, so we
-            # must use the same request_id in both the p and d workers.
-            handle = self.nixl_wrapper.make_prepped_xfer(
-                "READ",
-                local_xfer_side_handle,
-                local_block_descs_ids,
-                remote_xfer_side_handle,
-                remote_block_descs_ids,
-                notif_msg=request_id.encode("utf-8"),
-            )
-
-            # Call transfer to begin the async transfer.
-            # We will check this is done in the next forward pass.
-            self.nixl_wrapper.transfer(handle)
-            self._recving_transfers[request_id].append(handle)
+        # Get side handles.
+        local_xfer_side_handle = self.src_xfer_side_handle
+        remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
 
-    def _get_block_descs_ids(self,
-                             engine_id,
-                             region_ids,
-                             block_ids,
-                             i=None,
-                             tp_multiplier=1,
-                             staging_ranges=None):
+        # Get descs ids.
+        remote_block_descs_ids = self._get_block_descs_ids(
+            dst_engine_id, remote_block_ids)
+        local_block_descs_ids = self._get_block_descs_ids(
+            self.engine_id, local_block_ids)
+        assert len(local_block_descs_ids) == len(remote_block_descs_ids)
+
+        # Prepare transfer with Nixl.
+        handle = self.nixl_wrapper.make_prepped_xfer(
+            "READ",
+            local_xfer_side_handle,
+            local_block_descs_ids,
+            remote_xfer_side_handle,
+            remote_block_descs_ids,
+            notif_msg=request_id.encode("utf-8"),
+        )
 
-        if region_ids == "all":
-            region_ids = range(self.num_regions)
-        if block_ids == "all":
-            block_ids = range(self.num_blocks)
+        # Begin async xfer.
+        self.nixl_wrapper.transfer(handle)
 
-        descs_ids = []
+        # Use handle to check completion in future step().
+        self._recving_transfers[request_id].append(handle)
 
-        if i is not None:
-            raise NotImplementedError("Prefill and Decode instances must have "
-                                      "the same TP size.")
+    def _get_block_descs_ids(self, engine_id: str,
+                             block_ids: list[int]) -> list[int]:
+        """Get the descs ids for a set of block ids."""
+        # TODO(rob): should we precompute this?
 
+        # range(1) for MLA, range(2) otherwise.
+        region_ids = range(self.num_regions)
         num_blocks = self.dst_num_blocks[engine_id]
+
+        # Compute the desc ids for each block.
+        descs_ids: list[int] = []
         for reg_id in region_ids:
             for block_id in block_ids:
                 descs_ids.append(reg_id * num_blocks + block_id)
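A worked example of the desc-id arithmetic in the new _get_block_descs_ids(): descriptor ids index the flat registered descriptor list region-major, so block b of region r maps to r * num_blocks + b (the region and block counts below are illustrative):

```python
# Worked example of the region-major desc-id mapping; values are made up.
num_regions = 2    # K and V regions (1 for MLA)
num_blocks = 1000  # blocks registered for the destination engine


def block_descs_ids(block_ids: list[int]) -> list[int]:
    return [reg_id * num_blocks + block_id
            for reg_id in range(num_regions)
            for block_id in block_ids]


# Blocks 3 and 7 map to ids 3 and 7 in region 0, then 1003 and 1007 in
# region 1.
assert block_descs_ids([3, 7]) == [3, 7, 1003, 1007]
```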
