-
Notifications
You must be signed in to change notification settings - Fork 5.1k
Support dynamic connection and TP 16 #5351
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,15 +2,18 @@ | |
|
|
||
| import asyncio | ||
| import dataclasses | ||
| import json | ||
| import logging | ||
| import queue | ||
| import random | ||
| import struct | ||
| import threading | ||
| from functools import cache | ||
| from typing import Dict, List, Optional, Tuple | ||
| from typing import Dict, List, Optional, Tuple, Union | ||
|
|
||
| import numpy as np | ||
| import numpy.typing as npt | ||
| import requests | ||
| import zmq | ||
| from aiohttp import web | ||
|
|
||
|
|
@@ -24,9 +27,21 @@ | |
| ) | ||
| from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine | ||
| from sglang.srt.disaggregation.utils import DisaggregationMode | ||
| from sglang.srt.utils import is_port_available | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
def find_available_ports(
    base_port: int, count: int, max_port: int = 65535
) -> List[int]:
    """Collect ``count`` free ports at or above ``base_port``.

    The ports are NOT consecutive: after every probe, whether the port was
    free or not, the search jumps ahead by a random stride of 100-1000.

    Args:
        base_port: First port to probe.
        count: Number of free ports wanted. Zero or a negative value yields
            an empty list.
        max_port: Highest port number that may be probed.

    Returns:
        Ports for which ``is_port_available`` returned a truthy value, in
        increasing order.

    Raises:
        RuntimeError: If the search passes ``max_port`` before ``count``
            free ports were found.
    """
    available_ports: List[int] = []
    current_port = base_port

    while len(available_ports) < count:
        # Without this bound the loop would keep probing port numbers that
        # can never be bound once it walks past the valid range.
        if current_port > max_port:
            raise RuntimeError(
                f"Could not find {count} available port(s) in range "
                f"[{base_port}, {max_port}]"
            )
        if is_port_available(current_port):
            available_ports.append(current_port)
        current_port += random.randint(100, 1000)

    return available_ports
|
|
||
| def group_concurrent_contiguous( | ||
| src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] | ||
|
|
@@ -65,9 +80,10 @@ class TransferKVChunk: | |
|
|
||
| @dataclasses.dataclass | ||
| class TransferInfo: | ||
| room: int | ||
| endpoint: str | ||
| decode_port: int | ||
| mooncake_session_id: str | ||
| room: int | ||
| dst_kv_ptrs: list[int] | ||
| dst_kv_indices: npt.NDArray[np.int64] | ||
| dst_aux_ptrs: list[int] | ||
|
|
@@ -77,25 +93,24 @@ class TransferInfo: | |
| def from_zmq(cls, msg: List[bytes]): | ||
| return cls( | ||
| endpoint=msg[0].decode("ascii"), | ||
| mooncake_session_id=msg[1].decode("ascii"), | ||
| room=int(msg[2].decode("ascii")), | ||
| dst_kv_ptrs=list(struct.unpack(f"{len(msg[3])//8}Q", msg[3])), | ||
| dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64), | ||
| dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])), | ||
| dst_aux_index=int(msg[6].decode("ascii")), | ||
| decode_port=int(msg[1].decode("ascii")), | ||
| mooncake_session_id=msg[2].decode("ascii"), | ||
| room=int(msg[3].decode("ascii")), | ||
| dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])), | ||
| dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64), | ||
| dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])), | ||
| dst_aux_index=int(msg[7].decode("ascii")), | ||
| ) | ||
|
|
||
|
|
||
| KVSENDER_POLLING_PORT = 17788 | ||
| KVRECEIVER_POLLING_PORT = 27788 | ||
|
|
||
|
|
||
| class MooncakeKVManager(BaseKVManager): | ||
| def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): | ||
| self.engine = MooncakeTransferEngine() | ||
| self.kv_args = args | ||
| self.disaggregation_mode = disaggregation_mode | ||
| self.request_status: Dict[int, KVPoll] = {} | ||
| self.connection_pool: Dict[int, Dict[str, Union[str, int]]] = {} | ||
| self.rank_port = None | ||
| self.server_socket = zmq.Context().socket(zmq.PULL) | ||
| self.register_buffer_to_engine() | ||
| if self.disaggregation_mode == DisaggregationMode.PREFILL: | ||
|
|
@@ -202,31 +217,27 @@ def send_aux( | |
| ) | ||
| return status | ||
|
|
||
| def sync_status_to_decode_endpoint(self, remote: str, room: int): | ||
| def sync_status_to_decode_endpoint(self, remote: str, dst_port: int, room: int): | ||
| if ":" in remote: | ||
| remote = remote.split(":")[0] | ||
| self._connect( | ||
| "tcp://" | ||
| + remote | ||
| + ":" | ||
| + str(KVRECEIVER_POLLING_PORT + self.kv_args.engine_rank) | ||
| ).send_multipart( | ||
| self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart( | ||
| [ | ||
| str(room).encode("ascii"), | ||
| str(self.request_status[room]).encode("ascii"), | ||
| ] | ||
| ) | ||
|
|
||
| def start_prefill_thread(self): | ||
| sender_rank_port = KVSENDER_POLLING_PORT + self.kv_args.engine_rank | ||
| self.server_socket.bind("tcp://*:" + str(sender_rank_port)) | ||
| # Find available port for prefill tp | ||
| self.rank_port = find_available_ports(20000, 1)[0] | ||
|
||
| self.server_socket.bind("tcp://*:" + str(self.rank_port)) | ||
|
|
||
| def bootstrap_thread(): | ||
| """This thread recvs pre-alloc notification from the decode engine""" | ||
| # KVPoll.Bootstrapping -> KVPoll.WaitingForInput | ||
| while True: | ||
| waiting_req_bytes = self.server_socket.recv_multipart() | ||
| room = waiting_req_bytes[2].decode("ascii") | ||
| room = waiting_req_bytes[3].decode("ascii") | ||
| if room == "None": | ||
| continue | ||
| room = int(room) | ||
|
|
@@ -254,7 +265,7 @@ def transfer_thread(): | |
| ) | ||
| if ret != 0: | ||
| self.request_status[kv_chunk.room] = KVPoll.Failed | ||
| self.sync_status_to_decode_endpoint(req.endpoint, req.room) | ||
| self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room) | ||
| continue | ||
|
|
||
| if kv_chunk.is_last: | ||
|
|
@@ -268,7 +279,7 @@ def transfer_thread(): | |
| self.request_status[req.room] = ( | ||
| KVPoll.Success if ret == 0 else KVPoll.Failed | ||
| ) | ||
| self.sync_status_to_decode_endpoint(req.endpoint, req.room) | ||
| self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room) | ||
| self.transfer_infos.pop(req.room) | ||
|
|
||
| except queue.Empty: | ||
|
|
@@ -278,8 +289,8 @@ def transfer_thread(): | |
| threading.Thread(target=transfer_thread).start() | ||
|
|
||
| def start_decode_thread(self): | ||
| receiver_rank_port = KVRECEIVER_POLLING_PORT + self.kv_args.engine_rank | ||
| self.server_socket.bind("tcp://*:" + str(receiver_rank_port)) | ||
| self.rank_port = find_available_ports(25000, 1)[0] | ||
| self.server_socket.bind("tcp://*:" + str(self.rank_port)) | ||
|
|
||
| def decode_thread(): | ||
| while True: | ||
|
|
@@ -342,6 +353,38 @@ def __init__( | |
| self.bootstrap_room = bootstrap_room | ||
| self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping) | ||
| self.aux_index = None | ||
| self.bootstrap_server_url = bootstrap_addr | ||
|
|
||
| self.session_id = self.kv_mgr.get_session_id() | ||
|
|
||
| # Register to bootstrap server | ||
| self._register_to_bootstrap() | ||
|
|
||
def _register_to_bootstrap(self):
    """Register this KVSender with the bootstrap server via HTTP PUT.

    Sends the session id, the local IP, the manager's ``rank_port`` and the
    engine rank to ``/kv_route`` on ``self.bootstrap_server_url``.
    Registration is best-effort: every failure is logged and swallowed so
    the caller is never interrupted.
    """
    serve_port = self.kv_mgr.rank_port
    tp_rank = self.kv_mgr.kv_args.engine_rank
    url = f"http://{self.bootstrap_server_url}/kv_route"
    payload = {
        "identity": self.session_id,
        "role": "Prefill",
        "serve_ip": self.kv_mgr.get_localhost(),
        "serve_port": serve_port,
        "tp_rank": tp_rank,
    }

    logger.info(
        f"Register prefill server port {serve_port} for tp_rank {tp_rank}"
    )

    try:
        # requests applies no timeout by default; bound the wait so an
        # unreachable bootstrap server cannot block this call forever.
        response = requests.put(url, json=payload, timeout=10)
    except Exception as e:
        logger.error(f"Prefill Failed to register to bootstrap server: {e}")
    else:
        if response.status_code == 200:
            logger.info("Prefill successfully registered to bootstrap server.")
        else:
            logger.error(
                f"Prefill Failed to register to bootstrap server: {response.status_code}, {response.text}"
            )
|
|
||
| def init(self, num_kv_indices: int, aux_index: Optional[int] = None): | ||
| self.num_kv_indices = num_kv_indices | ||
|
|
@@ -384,14 +427,28 @@ def __init__( | |
| self.bootstrap_room = bootstrap_room | ||
| self.bootstrap_addr = bootstrap_addr | ||
| self.kv_mgr = mgr | ||
| self.prefill_server_url = ( | ||
| bootstrap_addr.split(":")[0] | ||
| + ":" | ||
| + str(KVSENDER_POLLING_PORT + self.kv_mgr.kv_args.engine_rank) | ||
| ) | ||
| self.decode_ip = self.kv_mgr.get_localhost() | ||
| self.session_id = self.kv_mgr.get_session_id() | ||
| self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput) | ||
| self.prefill_engine_rank = None | ||
| self.decode_port = self.kv_mgr.rank_port | ||
| self.dealer_socket = None | ||
|
|
||
def _get_prefill_info_from_bootstrap(self, tp_rank: int):
    """Fetch the prefill server info registered for ``tp_rank``.

    Issues ``GET /kv_route?tp_rank=<tp_rank>`` against ``self.bootstrap_addr``.

    Args:
        tp_rank: Rank whose prefill endpoint should be looked up.

    Returns:
        The decoded JSON body on HTTP 200, otherwise ``None``. Failures are
        logged and never raised.
    """
    url = f"http://{self.bootstrap_addr}/kv_route"
    try:
        # requests applies no timeout by default; bound the wait so an
        # unreachable bootstrap server cannot block this call forever.
        response = requests.get(url, params={"tp_rank": tp_rank}, timeout=10)
        if response.status_code == 200:
            return response.json()
        logger.error(
            f"Failed to get prefill server info: {response.status_code}, {response.text}"
        )
        return None
    except Exception as e:
        logger.error(f"Error fetching prefill info from bootstrap: {e}")
        return None
|
|
||
|
|
||
| @cache | ||
| def _connect(self, endpoint: str): | ||
|
|
@@ -400,6 +457,31 @@ def _connect(self, endpoint: str): | |
| return socket | ||
|
|
||
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
    """Resolve the prefill endpoint for this rank, then start the handshake.

    The endpoint is taken from ``self.kv_mgr.connection_pool`` when this
    engine rank is already cached there. Otherwise it is fetched from the
    bootstrap server and cached on success.

    Args:
        kv_indices: Forwarded unchanged to ``handshake_prefill_server``.
        aux_index: Forwarded unchanged to ``handshake_prefill_server``.
    """
    logger.info(f"Decode bootstrap addr {self.bootstrap_addr}.")

    engine_rank = self.kv_mgr.kv_args.engine_rank
    pool = self.kv_mgr.connection_pool

    if engine_rank not in pool:
        prefill_info = self._get_prefill_info_from_bootstrap(engine_rank)
        if prefill_info is None:
            # Log exactly once; the earlier nested logger.error(...) call
            # emitted a second, meaningless "None" error record.
            logger.error(
                f"Could not fetch prefill server info for tp_rank {engine_rank}"
            )
        else:
            pool[engine_rank] = prefill_info
    else:
        prefill_info = pool[engine_rank]

    if prefill_info:
        self.prefill_server_url = f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
        # Only report a fetch when there is something to report.
        logger.info(f"Fetched prefill server info: {prefill_info} for tp_rank {engine_rank}")

    # NOTE(review): when the lookup failed, prefill_server_url is not updated
    # here. Confirm it is initialised elsewhere before the handshake uses it.
    self.handshake_prefill_server(kv_indices, aux_index)
|
|
||
| def handshake_prefill_server( | ||
| self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None | ||
| ): | ||
| packed_kv_data_ptrs = b"".join( | ||
| struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs | ||
| ) | ||
|
|
@@ -409,6 +491,7 @@ def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = Non | |
| self._connect("tcp://" + self.prefill_server_url).send_multipart( | ||
| [ | ||
| self.decode_ip.encode("ascii"), | ||
| str(self.decode_port).encode("ascii"), | ||
| self.session_id.encode("ascii"), | ||
| str(self.bootstrap_room).encode("ascii"), | ||
| packed_kv_data_ptrs, | ||
|
|
@@ -432,6 +515,12 @@ def __init__(self, port: int): | |
| self.store = dict() | ||
| self.lock = asyncio.Lock() | ||
| self._setup_routes() | ||
| # prefill_engine_rank -> prefill_info | ||
| self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {} | ||
|
|
||
| self.context = zmq.Context() | ||
|
|
||
| self.prefill_engine_rank = None | ||
|
|
||
| # Start bootstrap server | ||
| self.thread = threading.Thread(target=self._run_server, daemon=True) | ||
|
|
@@ -442,21 +531,22 @@ def run(self): | |
|
|
||
| def _setup_routes(self): | ||
| self.app.router.add_route("*", "/metadata", self._handle_metadata) | ||
| self.app.router.add_route("*", "/kv_route", self._handle_kv_route) | ||
|
|
||
| async def _handle_metadata(self, request: web.Request): | ||
| key = request.query.get("key", "") | ||
|
|
||
| if request.method == "GET": | ||
| return await self._handle_get(key) | ||
| return await self._handle_metadata_get(key) | ||
| elif request.method == "PUT": | ||
| return await self._handle_put(key, request) | ||
| return await self._handle_metadata_put(key, request) | ||
| elif request.method == "DELETE": | ||
| return await self._handle_delete(key) | ||
| return await self._handle_metadata_delete(key) | ||
| return web.Response( | ||
| text="Method not allowed", status=405, content_type="application/json" | ||
| ) | ||
|
|
||
| async def _handle_get(self, key): | ||
| async def _handle_metadata_get(self, key): | ||
| async with self.lock: | ||
| value = self.store.get(key) | ||
| if value is None: | ||
|
|
@@ -465,15 +555,15 @@ async def _handle_get(self, key): | |
| ) | ||
| return web.Response(body=value, status=200, content_type="application/json") | ||
|
|
||
| async def _handle_put(self, key, request): | ||
| async def _handle_metadata_put(self, key, request): | ||
| data = await request.read() | ||
| async with self.lock: | ||
| self.store[key] = data | ||
| return web.Response( | ||
| text="metadata updated", status=200, content_type="application/json" | ||
| ) | ||
|
|
||
| async def _handle_delete(self, key): | ||
| async def _handle_metadata_delete(self, key): | ||
| async with self.lock: | ||
| if key not in self.store: | ||
| return web.Response( | ||
|
|
@@ -486,6 +576,52 @@ async def _handle_delete(self, key): | |
| text="metadata deleted", status=200, content_type="application/json" | ||
| ) | ||
|
|
||
async def _handle_kv_route(self, request: web.Request):
    """Dispatch a ``/kv_route`` request to the PUT or GET handler.

    Any other HTTP method is answered with a 405 response.
    """
    dispatch = {
        "PUT": self._handle_kv_route_put,
        "GET": self._handle_kv_route_get,
    }
    handler = dispatch.get(request.method)
    if handler is None:
        return web.Response(
            text="Method not allowed", status=405, content_type="application/json"
        )
    return await handler(request)
|
|
||
async def _handle_kv_route_put(self, request: web.Request):
    """Record a prefill endpoint announced through ``PUT /kv_route``.

    The JSON body must carry ``role``, ``serve_ip``, ``serve_port`` and
    ``tp_rank``. Only entries whose role is ``"Prefill"`` are stored, keyed
    by ``tp_rank`` in ``self.prefill_port_table``.

    Returns:
        200 "OK" for a well-formed payload (whatever the role), or 400 when
        the body is not valid JSON or a required field is missing or not
        convertible.
    """
    try:
        data = await request.json()
        role = data["role"]
        serve_ip = data["serve_ip"]
        serve_port = int(data["serve_port"])
        tp_rank = int(data["tp_rank"])
    except (KeyError, TypeError, ValueError) as e:
        # Invalid JSON surfaces as a ValueError subclass; a bad client
        # payload is a client error, not a server crash.
        return web.Response(text=f"Invalid kv_route payload: {e!r}", status=400)

    if role == "Prefill":
        # The table is also read by the GET handler, so mutate it under the lock.
        async with self.lock:
            self.prefill_port_table[tp_rank] = {"serve_ip": serve_ip, "serve_port": serve_port}
        logger.info(f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}")

    return web.Response(text="OK", status=200)
|
|
||
async def _handle_kv_route_get(self, request: web.Request):
    """Answer ``GET /kv_route?tp_rank=<n>`` with the stored prefill entry.

    Returns 400 when ``tp_rank`` is absent, empty or not an integer, 404
    when no entry is stored for that rank, and otherwise the entry as JSON.
    """
    raw_rank = request.query.get("tp_rank")
    if not raw_rank:
        return web.Response(text="Missing tp_rank", status=400)

    try:
        rank = int(raw_rank)
    except ValueError:
        return web.Response(text="tp_rank must be int", status=400)

    # Read the table under the same lock the PUT handler writes with.
    async with self.lock:
        entry = self.prefill_port_table.get(rank)

    if entry is None:
        return web.Response(text="Not Found", status=404)
    return web.json_response(entry, status=200)
|
|
||
| def _run_server(self): | ||
| try: | ||
| # Event Loop | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.