[P/D Disagg] [1/N] Support Homogeneous TP > 1 #65
Changes from 141 commits
@@ -17,6 +17,9 @@
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
+from vllm.distributed.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
+    get_tp_group)
 from vllm.logger import init_logger
 from vllm.sampling_params import KVTransferParams
 from vllm.utils import round_down

@@ -47,8 +50,6 @@ class NixlAgentMetadata(
         dict=True):
     engine_id: str
     agent_metadata: bytes
-    # Base addr for each layer for KVs
-    # NOTE: we will need another list for TP>1
     kv_caches_base_addr: list[int]
     num_blocks: int

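The struct above is the payload each worker serializes over the side channel during the handshake; with TP > 1, every rank publishes its own copy (its NIXL agent blob and its kv-cache base addresses). Below is a minimal, hedged sketch of that round trip using msgspec; the class name and field values are illustrative stand-ins, not the connector's real data.

```python
import msgspec

class AgentMeta(msgspec.Struct, dict=True):
    # Mirrors NixlAgentMetadata: one instance per TP rank.
    engine_id: str
    agent_metadata: bytes            # opaque NIXL agent blob for this rank
    kv_caches_base_addr: list[int]   # this rank's base addr per K/V region
    num_blocks: int

# Encode on the prefill rank, decode on the decode rank (over ZMQ in this PR).
meta = AgentMeta(engine_id="engine-0", agent_metadata=b"\x00\x01",
                 kv_caches_base_addr=[0x1000, 0x9000], num_blocks=1024)
wire = msgspec.msgpack.encode(meta)
assert msgspec.msgpack.decode(wire, type=AgentMeta) == meta
```
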
@@ -222,47 +223,53 @@ def __init__(self, engine_id: str):

         # Agent.
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
-        # Map of engine_id -> list[agent_names] (1 per rank).
-        self._remote_agents: dict[str, list[str]] = {}
+        # Map of engine_id -> agent_name.
+        self._remote_agents: dict[str, str] = {}

         # Metadata.
         self.engine_id = engine_id
-        self.rank = 0
+        self.rank = get_tensor_model_parallel_rank()
+        self.world_size = get_tensor_model_parallel_world_size()
+        self.tp_group = get_tp_group()

         # KV Caches and nixl tracking data.
         self.kv_caches: dict[str, torch.Tensor] = {}

         # Map of engine_id -> kv_caches_base_addr
         # For Local: base addr for *this* rank, each layer for K,V
         # For Remote: base addr for *each* rank, each layer for K,V
-        # KV_CACHES_ADDR_TYPE = Union[list[tuple[int, int]],
-        #                             list[list[tuple[int, int]]]]
         self.kv_caches_base_addr: dict[str, list[int]] = {}

         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
         self.num_regions = 0

-        # Map of tp_mult -> nixl_prepped_dlist_handle (int).
-        self.src_xfer_side_handles: dict[int, int] = {}
-        # Map of engine_id -> map[tp_mult -> nixl_prepped_dlist_handle (int)].
-        self.dst_xfer_side_handles: defaultdict[str,
-                                                dict[int,
-                                                     int]] = defaultdict(dict)
+        # nixl_prepped_dlist_handle (int).
+        self.src_xfer_side_handle: int = 0
+        # Map of engine_id -> nixl_prepped_dlist_handle (int)].
+        self.dst_xfer_side_handles: dict[str, int] = {}

         # Map of engine_id -> num_blocks.
         self.dst_num_blocks: dict[str, int] = {}
         self._registered_descs: list[Any] = []

         # In progress transfers.
         # [req_id -> list[handle]]
-        self._recving_transfers: dict[str, list[Any]] = defaultdict(list[Any])
+        self._recving_transfers: defaultdict[str, list[Any]] = defaultdict(
+            list[Any])
+
+        # Complete transfer tracker. Used by the rank 0 to track finished
+        # transactions on ranks 1 to N-1.
+        # [req_id -> count]
+        self._done_recving_count: defaultdict[str,
+                                              int] = defaultdict(lambda: 0)
+        self._done_sending_count: defaultdict[str,
+                                              int] = defaultdict(lambda: 0)

         # Background thread for establishing new connections.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None

     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
-                                 ready_event: threading.Event):
+                                 ready_event: threading.Event, rank: int):
         """Background thread for getting new NIXL handshakes."""
         # NOTE(rob): this is a simple implementation. We will move
         # to a better approach like an ETCD server in the future.

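The two defaultdict counters added here are consumed later in get_finished(): a request id is only surfaced to the scheduler once every TP rank has counted it. A tiny, hedged illustration of that bookkeeping in isolation (plain Python, outside the connector; the helper name and world size are made up):

```python
from collections import defaultdict

WORLD_SIZE = 4
done_sending_count: defaultdict[str, int] = defaultdict(lambda: 0)

def mark_done_sending(req_id: str) -> bool:
    """Record one rank's completion; True only once all ranks reported."""
    done_sending_count[req_id] += 1
    if done_sending_count[req_id] == WORLD_SIZE:
        done_sending_count.pop(req_id)   # finished everywhere; stop tracking
        return True
    return False

# Three ranks report first: not finished yet. The fourth completes it.
assert not any([mark_done_sending("req-1") for _ in range(WORLD_SIZE - 1)])
assert mark_done_sending("req-1")
```
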
@@ -280,8 +287,13 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,

         # Listen for new requests for metadata.
         host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
-        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT
-        with zmq_ctx(zmq.ROUTER, f"tcp://{host}:{port}") as sock:
+        # NOTE(rob): we need each rank to have a unique port. This
+        # hack to keeps us moving. We will switch when moving to etcd
+        # or where we have a single ZMQ socket in the scheduler.
+        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
+        path = f"tcp://{host}:{port}"
+        logger.debug("Starting listening on path: %s", path)
+        with zmq_ctx(zmq.ROUTER, path) as sock:
             ready_event.set()
             while True:
                 identity, _, msg = sock.recv_multipart()

@@ -294,7 +306,12 @@ def _nixl_handshake(self, host: str, port: int):
         """Do a NIXL handshake with a remote instance."""

         start_time = time.perf_counter()
-        with zmq_ctx(zmq.REQ, f"tcp://{host}:{port}") as sock:
+        # NOTE(rob): we need each rank to have a unique port. This
+        # hack to keeps us moving. We will switch when moving to etcd
+        # or where we have a single ZMQ socket in the scheduler.
+        path = f"tcp://{host}:{port + self.rank}"
+        logger.debug("Querying metadata on path: %s", path)
+        with zmq_ctx(zmq.REQ, path) as sock:
            # Send query for the request.
            sock.send(GET_META_MSG)
            metadata_bytes = sock.recv()

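Both hunks use the same stopgap: rank i listens on (and dials) VLLM_NIXL_SIDE_CHANNEL_PORT + i, so prefill rank i and decode rank i pair up one-to-one. Below is a self-contained sketch of that scheme with pyzmq; the host, base port, helper names, and message contents are illustrative, not the connector's actual values.

```python
import threading
import zmq

def side_channel_path(host: str, base_port: int, rank: int) -> str:
    # One port per TP rank: decode rank i queries prefill rank i.
    return f"tcp://{host}:{base_port + rank}"

def listener(path: str, metadata: bytes, ready: threading.Event) -> None:
    # Plays the role of _nixl_handshake_listener for a single query.
    ctx = zmq.Context()
    sock = ctx.socket(zmq.ROUTER)
    sock.bind(path)
    ready.set()
    identity, _, _query = sock.recv_multipart()
    sock.send_multipart((identity, b"", metadata))
    sock.close()
    ctx.term()

def handshake(path: str) -> bytes:
    # Plays the role of _nixl_handshake: ask the remote rank for its metadata.
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    sock.connect(path)
    sock.send(b"get_meta")               # stand-in for GET_META_MSG
    reply = sock.recv()
    sock.close()
    ctx.term()
    return reply

if __name__ == "__main__":
    rank = 1                                             # pretend TP rank 1
    path = side_channel_path("127.0.0.1", 5600, rank)    # 5600 is made up
    ready = threading.Event()
    t = threading.Thread(target=listener,
                         args=(path, b"rank-1-agent-metadata", ready),
                         daemon=True)
    t.start()
    ready.wait()
    print(handshake(path))                               # b'rank-1-agent-metadata'
```
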
@@ -364,83 +381,115 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
-            args=(metadata, ready_event),
+            args=(metadata, ready_event, self.rank),
             daemon=True,
             name="nixl_handshake_listener")
-        import os
-        if os.getenv("SKIP", None) != "1":
-            self._nixl_handshake_listener_t.start()
-            ready_event.wait()
+        self._nixl_handshake_listener_t.start()
+        ready_event.wait()

-    def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, tp_idx=0):
+    def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
         engine_id = nixl_agent_meta.engine_id
         if engine_id in self._remote_agents:
             return

-        num_blocks = nixl_agent_meta.num_blocks
-        logger.debug("Adding remote agent %s %s", engine_id, str(num_blocks))
-
-        agent_names = [
-            self.nixl_wrapper.add_remote_agent(nixl_agent_meta.agent_metadata)
-        ]
-        self._remote_agents[engine_id] = agent_names
+        self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
+            nixl_agent_meta.agent_metadata)
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr

-        # NOTE: once we support heterogeneous TP, we will need maintain the
-        # src for each TP multiplier.
-        # NOTE(rob): Dynamo only supports D TP size > P TP size.
-        # https://github.com/vllm-project/vllm/pull/16124/files#diff-876efa5533f5dcff3fba850e8684a47d53c112e287988957c115b11691374f4bR331 # noqa: E501
-        # Create descs and xfer side handles.
-        tp_multiplier = 1
-        dst_block_len = self.block_len // tp_multiplier
-        if tp_multiplier not in self.src_xfer_side_handles:
-            # Create descs and xfer side handles.
-            blocks_data = []
-            for base_addr in self.kv_caches_base_addr[self.engine_id]:
-                for block_id in range(self.num_blocks):
-                    block_offset = block_id * self.block_len
-                    for i in range(tp_multiplier):
-                        tp_multiplier_offset = tp_idx * dst_block_len
-                        blocks_data.append(
-                            (base_addr + block_offset + tp_multiplier_offset,
-                             dst_block_len, self.rank))
-            logger.debug("Created %s blocks for src engine %s and rank %s",
-                         len(blocks_data), self.engine_id, self.rank)
-
-            # Register with NIXL.
-            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-            self.src_xfer_side_handles[tp_multiplier] = (
-                self.nixl_wrapper.prep_xfer_dlist("", descs))
-
-        # create dst xfer side handles
-        self.dst_num_blocks[engine_id] = num_blocks
+        # Create src descs and xfer side handles.
+        blocks_data = []
+        for base_addr in self.kv_caches_base_addr[self.engine_id]:
+            for block_id in range(self.num_blocks):
+                block_offset = block_id * self.block_len
+                # (addr, len, device id)
+                blocks_data.append(
+                    (base_addr + block_offset, self.block_len, self.rank))
+        logger.debug("Created %s blocks for src engine %s and rank %s",
+                     len(blocks_data), self.engine_id, self.rank)
+
+        # Register with NIXL.
+        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
+            "NIXL_INIT_AGENT", descs)
+
+        # Create dst descs and xfer side handles.
+        self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
         blocks_data = []
         for base_addr in self.kv_caches_base_addr[engine_id]:
-            for block_id in range(num_blocks):
-                block_offset = block_id * dst_block_len
-                blocks_data.append((base_addr + block_offset, dst_block_len,
-                                    self.rank * tp_multiplier))
+            for block_id in range(nixl_agent_meta.num_blocks):
+                block_offset = block_id * self.block_len
+                # (addr, len, device id)
+                blocks_data.append(
+                    (base_addr + block_offset, self.block_len, self.rank))
         logger.debug("Created %s blocks for dst engine %s and rank %s",
                      len(blocks_data), engine_id, self.rank)

         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-        self.dst_xfer_side_handles[engine_id][tp_idx] = (
-            self.nixl_wrapper.prep_xfer_dlist(
-                self._remote_agents[engine_id][self.rank * tp_multiplier +
-                                               tp_idx], descs))
+        self.dst_xfer_side_handles[
+            engine_id] = self.nixl_wrapper.prep_xfer_dlist(
+                self._remote_agents[engine_id], descs)

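With homogeneous TP, each rank registers full-size blocks and pairs only with its counterpart rank, so the descriptor list is just one (address, length, device id) tuple per block per region, with no tp_multiplier splitting. A hedged, standalone sketch of that layout with toy addresses and sizes (the helper name is not part of the connector):

```python
def build_block_descs(kv_caches_base_addr: list[int], num_blocks: int,
                      block_len: int, rank: int) -> list[tuple[int, int, int]]:
    blocks_data = []
    for base_addr in kv_caches_base_addr:      # one base addr per region (K/V per layer)
        for block_id in range(num_blocks):
            block_offset = block_id * block_len
            blocks_data.append((base_addr + block_offset, block_len, rank))
    return blocks_data

# Toy numbers: 2 regions, 3 blocks of 4096 bytes, on rank 1.
descs = build_block_descs([0x1000, 0x9000], num_blocks=3, block_len=4096, rank=1)
assert descs[0] == (0x1000, 4096, 1)
assert len(descs) == 2 * 3
```
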
     def get_finished(self) -> tuple[set[str], set[str]]:
         """Get requests that are done sending or recving."""
         done_sending = self._get_new_notifs()
         done_recving = self._pop_done_transfers(self._recving_transfers)
         if len(done_sending) > 0 or len(done_recving) > 0:
             logger.debug(
-                "get_finished: %s requests done sending "
-                "and %s requests done recving", len(done_sending),
+                "Rank %s, get_finished: %s requests done sending "
+                "and %s requests done recving", self.rank, len(done_sending),
                 len(done_recving))
-        return done_sending, done_recving
+
+        if self.world_size == 1:
+            return done_sending, done_recving
+
+        # In TP>1 setup, each rank exchanges KVs with its counterpart
+        # ranks independently. get_finished() runs in a worker creates
+        # the done_sending and done_recving sets that are sent to the
+        # scheduler via ModelRunnerOutput by Rank 0. To avoid race
+        # ensure trnxs are done before adding to finished, Ranks 1 to
+        # N-1 communicate to Rank 0 once their transaction is done.
+        # Rank 0 only returns finished once all ranks are complete.
+        if self.rank == 0:
+            for req_id in done_sending:
+                self._done_sending_count[req_id] += 1
+            for req_id in done_recving:
+                self._done_recving_count[req_id] += 1
+
+            # Update the counts of how many ranks have finished.
+            # Get notifies from other ranks that txns are done.
+            other_ranks_finished_ids: list[str] = []
+            for i in range(1, self.world_size):
+                other_ranks_finished_ids.extend(
+                    self.tp_group.recv_object(src=i))

Collaborator (Author): This is how Dynamo does it (with the tp_group). I wonder if there is a better way. cc @njhill

Member: @robertgshaw2-redhat here is an alternative to consider: robertgshaw2-redhat#7. Guess this might be preferable latency-wise since we don't have an additional gather collective, but not sure (since now the scheduler needs to receive from all ranks ... though it was doing this anyhow until recently).

Collaborator (Author): Let's just time things and see which one is faster.

Collaborator (Author): Looks like for TP=2, the setup I have is taking <1 ms, so I think this is good enough for now, as I would prefer to keep the changes in this file if possible.

+            for req_id in other_ranks_finished_ids:
+                if (req_id in self._done_recving_count
+                        or req_id in self._recving_transfers):
+                    self._done_recving_count[req_id] += 1
+                else:
+                    self._done_sending_count[req_id] += 1
+
+            # Return ids that have finished on all ranks to the scheduler.
+            all_done_sending: set[str] = set()
+            all_done_recving: set[str] = set()
+            for req_id in list(self._done_recving_count.keys()):
+                if self._done_recving_count[req_id] == self.world_size:
+                    self._done_recving_count.pop(req_id)
+                    all_done_recving.add(req_id)
+            for req_id in list(self._done_sending_count.keys()):
+                if self._done_sending_count[req_id] == self.world_size:
+                    self._done_sending_count.pop(req_id)
+                    all_done_sending.add(req_id)
+
+            return all_done_sending, all_done_recving
+
+        else:
+            finished_req_ids = list(done_recving.union(done_sending))
+            self.tp_group.send_object(finished_req_ids, dst=0)
+
+            # NOTE(rob): unused as only Rank 0 sends to sched.
+            return done_sending, done_recving

     def _get_new_notifs(self) -> set[str]:
         """Get req_ids which got a remote xfer message."""

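The connector aggregates via vLLM's tp_group.send_object/recv_object; the sketch below simulates the same rank-0 pattern with plain torch.distributed object sends over gloo (the port, request ids, and helper names are illustrative, and this is a stand-in for the tp_group API rather than the connector's actual code). Only requests reported by every rank are treated as finished.

```python
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int) -> None:
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)
    # Requests this rank believes are done after the current step.
    local_finished = [f"req-only-on-rank-{rank}", "req-on-all-ranks"]

    if rank == 0:
        counts: dict[str, int] = {}
        for req_id in local_finished:
            counts[req_id] = counts.get(req_id, 0) + 1
        # Ranks 1..N-1 report their finished ids to rank 0.
        for src in range(1, world_size):
            buf: list[list[str]] = [[]]
            dist.recv_object_list(buf, src=src)
            for req_id in buf[0]:
                counts[req_id] = counts.get(req_id, 0) + 1
        # Only ids counted by every rank go back to the scheduler.
        done = {r for r, c in counts.items() if c == world_size}
        print("finished on all ranks:", done)   # {'req-on-all-ranks'}
    else:
        dist.send_object_list([local_finished], dst=0)

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```
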
@@ -539,61 +588,44 @@ def _read_blocks(
         if len(local_block_ids) == 0:
             return

-        # TODO: support TP multipliers.
-        tp_multiplier = 1
-        remote_block_descs_ids = self._get_block_descs_ids(
-            dst_engine_id, "all", remote_block_ids)
-        local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
-
-        # Read the data from the remote.
-        for i in range(tp_multiplier):
-            local_block_descs_ids = self._get_block_descs_ids(
-                self.engine_id,
-                "all",
-                local_block_ids,
-                i=None, #TODO: Enable both tp_multiplier and staging_ranges.
-                tp_multiplier=tp_multiplier,
-                staging_ranges=None)
-            assert len(local_block_descs_ids) == len(remote_block_descs_ids)
-            remote_xfer_side_handle = self.dst_xfer_side_handles[
-                dst_engine_id][i]
-
-            # NOTE(rob): we use the request_id as the notify msg, so we
-            # must use the same request_id in both the p and d workers.
-            handle = self.nixl_wrapper.make_prepped_xfer(
-                "READ",
-                local_xfer_side_handle,
-                local_block_descs_ids,
-                remote_xfer_side_handle,
-                remote_block_descs_ids,
-                notif_msg=request_id.encode("utf-8"),
-            )
-
-            # Call transfer to begin the async transfer
-            # We will check this is done in the next forward pass.
-            self.nixl_wrapper.transfer(handle)
-            self._recving_transfers[request_id].append(handle)
+        # Get side handles.
+        local_xfer_side_handle = self.src_xfer_side_handle
+        remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]

-    def _get_block_descs_ids(self,
-                             engine_id,
-                             region_ids,
-                             block_ids,
-                             i=None,
-                             tp_multiplier=1,
-                             staging_ranges=None):
+        # Get descs ids.
+        remote_block_descs_ids = self._get_block_descs_ids(
+            dst_engine_id, remote_block_ids)
+        local_block_descs_ids = self._get_block_descs_ids(
+            self.engine_id, local_block_ids)
+        assert len(local_block_descs_ids) == len(remote_block_descs_ids)

-        if region_ids == "all":
-            region_ids = range(self.num_regions)
-        if block_ids == "all":
-            block_ids = range(self.num_blocks)
+        # Prepare transfer with Nixl.
+        handle = self.nixl_wrapper.make_prepped_xfer(
+            "READ",
+            local_xfer_side_handle,
+            local_block_descs_ids,
+            remote_xfer_side_handle,
+            remote_block_descs_ids,
+            notif_msg=request_id.encode("utf-8"),
+        )

-        descs_ids = []
+        # Begin async xfer.
+        self.nixl_wrapper.transfer(handle)

-        if i is not None:
-            raise NotImplementedError("Prefill and Decode instances must have "
-                                      "the same TP size.")
+        # Use handle to check completion in future step().
+        self._recving_transfers[request_id].append(handle)

+    def _get_block_descs_ids(self, engine_id: str,
+                             block_ids: list[int]) -> list[int]:
+        """Get the descs ids for a set of block ids."""
+        # TODO(rob): should we precompute this?

+        # range(1) for MLA, range(2) otherwise.
+        region_ids = range(self.num_regions)
+        num_blocks = self.dst_num_blocks[engine_id]

+        # Compute the desc ids for each block.
+        descs_ids: list[int] = []
         for reg_id in region_ids:
             for block_id in block_ids:
                 descs_ids.append(reg_id * num_blocks + block_id)

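The rewritten helper flattens (region, block) into a single descriptor index: regions are laid out back to back, so a block's descriptor id is region_id * num_blocks + block_id. A quick, hedged check of that arithmetic with toy sizes (the function name is illustrative):

```python
def block_descs_ids(num_regions: int, num_blocks: int,
                    block_ids: list[int]) -> list[int]:
    descs_ids = []
    for reg_id in range(num_regions):
        for block_id in block_ids:
            descs_ids.append(reg_id * num_blocks + block_id)
    return descs_ids

# 2 regions (K and V), 10 blocks per region, pulling blocks 3 and 7:
assert block_descs_ids(2, 10, [3, 7]) == [3, 7, 13, 17]
```
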