Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 173 additions & 37 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@

import asyncio
import dataclasses
import json
import logging
import queue
import random
import struct
import threading
from functools import cache
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt
import requests
import zmq
from aiohttp import web

Expand All @@ -24,9 +27,21 @@
)
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.utils import is_port_available

logger = logging.getLogger(__name__)

def find_available_ports(base_port: int, count: int) -> List[int]:
"""Find consecutive available ports starting from base_port."""
available_ports = []
current_port = base_port

while len(available_ports) < count:
if is_port_available(current_port):
available_ports.append(current_port)
current_port += random.randint(100, 1000)

return available_ports

def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
Expand Down Expand Up @@ -65,9 +80,10 @@ class TransferKVChunk:

@dataclasses.dataclass
class TransferInfo:
room: int
endpoint: str
decode_port: int
mooncake_session_id: str
room: int
dst_kv_ptrs: list[int]
dst_kv_indices: npt.NDArray[np.int64]
dst_aux_ptrs: list[int]
Expand All @@ -77,25 +93,24 @@ class TransferInfo:
def from_zmq(cls, msg: List[bytes]):
return cls(
endpoint=msg[0].decode("ascii"),
mooncake_session_id=msg[1].decode("ascii"),
room=int(msg[2].decode("ascii")),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[3])//8}Q", msg[3])),
dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
dst_aux_index=int(msg[6].decode("ascii")),
decode_port=int(msg[1].decode("ascii")),
mooncake_session_id=msg[2].decode("ascii"),
room=int(msg[3].decode("ascii")),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
dst_aux_index=int(msg[7].decode("ascii")),
)


KVSENDER_POLLING_PORT = 17788
KVRECEIVER_POLLING_PORT = 27788


class MooncakeKVManager(BaseKVManager):
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode):
self.engine = MooncakeTransferEngine()
self.kv_args = args
self.disaggregation_mode = disaggregation_mode
self.request_status: Dict[int, KVPoll] = {}
self.connection_pool: Dict[int, Dict[str, Union[str, int]]] = {}
self.rank_port = None
self.server_socket = zmq.Context().socket(zmq.PULL)
self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
Expand Down Expand Up @@ -202,31 +217,27 @@ def send_aux(
)
return status

def sync_status_to_decode_endpoint(self, remote: str, room: int):
def sync_status_to_decode_endpoint(self, remote: str, dst_port: int, room: int):
if ":" in remote:
remote = remote.split(":")[0]
self._connect(
"tcp://"
+ remote
+ ":"
+ str(KVRECEIVER_POLLING_PORT + self.kv_args.engine_rank)
).send_multipart(
self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
[
str(room).encode("ascii"),
str(self.request_status[room]).encode("ascii"),
]
)

def start_prefill_thread(self):
sender_rank_port = KVSENDER_POLLING_PORT + self.kv_args.engine_rank
self.server_socket.bind("tcp://*:" + str(sender_rank_port))
# Find available port for prefill tp
self.rank_port = find_available_ports(20000, 1)[0]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should use socket method to get free port , to avoid the potential. conflicts

self.server_socket.bind("tcp://*:" + str(self.rank_port))

def bootstrap_thread():
"""This thread recvs pre-alloc notification from the decode engine"""
# KVPoll.Bootstrapping -> KVPoll.WaitingForInput
while True:
waiting_req_bytes = self.server_socket.recv_multipart()
room = waiting_req_bytes[2].decode("ascii")
room = waiting_req_bytes[3].decode("ascii")
if room == "None":
continue
room = int(room)
Expand Down Expand Up @@ -254,7 +265,7 @@ def transfer_thread():
)
if ret != 0:
self.request_status[kv_chunk.room] = KVPoll.Failed
self.sync_status_to_decode_endpoint(req.endpoint, req.room)
self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
continue

if kv_chunk.is_last:
Expand All @@ -268,7 +279,7 @@ def transfer_thread():
self.request_status[req.room] = (
KVPoll.Success if ret == 0 else KVPoll.Failed
)
self.sync_status_to_decode_endpoint(req.endpoint, req.room)
self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
self.transfer_infos.pop(req.room)

except queue.Empty:
Expand All @@ -278,8 +289,8 @@ def transfer_thread():
threading.Thread(target=transfer_thread).start()

def start_decode_thread(self):
receiver_rank_port = KVRECEIVER_POLLING_PORT + self.kv_args.engine_rank
self.server_socket.bind("tcp://*:" + str(receiver_rank_port))
self.rank_port = find_available_ports(25000, 1)[0]
self.server_socket.bind("tcp://*:" + str(self.rank_port))

def decode_thread():
while True:
Expand Down Expand Up @@ -342,6 +353,38 @@ def __init__(
self.bootstrap_room = bootstrap_room
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
self.aux_index = None
self.bootstrap_server_url = bootstrap_addr

self.session_id = self.kv_mgr.get_session_id()

# Register to bootstrap server
self._register_to_bootstrap()

def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
url = f"http://{self.bootstrap_server_url}/kv_route"
payload = {
"identity": self.session_id,
"role": "Prefill",
"serve_ip": self.kv_mgr.get_localhost(),
"serve_port": self.kv_mgr.rank_port,
"tp_rank": self.kv_mgr.kv_args.engine_rank,
}

logger.info(
f"Register prefill server port {self.kv_mgr.rank_port} for tp_rank {self.kv_mgr.kv_args.engine_rank}"
)

try:
response = requests.put(url, json=payload)
if response.status_code == 200:
logger.info(f"Prefill successfully registered to bootstrap server.")
else:
logger.info(
f"Prefill Failed to register to bootstrap server: {response.status_code}, {response.text}"
)
except Exception as e:
logger.info(f"Prefill Failed to register to bootstrap server: {e}")

def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices
Expand Down Expand Up @@ -384,14 +427,28 @@ def __init__(
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.prefill_server_url = (
bootstrap_addr.split(":")[0]
+ ":"
+ str(KVSENDER_POLLING_PORT + self.kv_mgr.kv_args.engine_rank)
)
self.decode_ip = self.kv_mgr.get_localhost()
self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
self.prefill_engine_rank = None
self.decode_port = self.kv_mgr.rank_port
self.dealer_socket = None

def _get_prefill_info_from_bootstrap(self, tp_rank: int):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it has a serious problem in tp16 case

"""Fetch the prefill server port corresponding to tp_rank from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/kv_route?tp_rank={tp_rank}"
response = requests.get(url)
if response.status_code == 200:
prefill_info = response.json()
return prefill_info
else:
logger.error(f"Failed to get prefill server info: {response.status_code}, {response.text}")
return None
except Exception as e:
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None


@cache
def _connect(self, endpoint: str):
Expand All @@ -400,6 +457,31 @@ def _connect(self, endpoint: str):
return socket

def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
prefill_info = None
logger.info(f"Decode bootstrap addr {self.bootstrap_addr}.")

if self.kv_mgr.kv_args.engine_rank not in self.kv_mgr.connection_pool:
prefill_info = self._get_prefill_info_from_bootstrap(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should cached result per decode instance.

self.kv_mgr.kv_args.engine_rank
)
if prefill_info is None:
logger.error(
logger.error(f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}")
)
else:
self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = prefill_info
else:
prefill_info = self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank]

if prefill_info:
self.prefill_server_url = f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"

logger.info(f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}")
self.handshake_prefill_server(kv_indices, aux_index)

def handshake_prefill_server(
self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None
):
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
Expand All @@ -409,6 +491,7 @@ def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = Non
self._connect("tcp://" + self.prefill_server_url).send_multipart(
[
self.decode_ip.encode("ascii"),
str(self.decode_port).encode("ascii"),
self.session_id.encode("ascii"),
str(self.bootstrap_room).encode("ascii"),
packed_kv_data_ptrs,
Expand All @@ -432,6 +515,12 @@ def __init__(self, port: int):
self.store = dict()
self.lock = asyncio.Lock()
self._setup_routes()
# prefill_engine_rank -> prefill_info
self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}

self.context = zmq.Context()

self.prefill_engine_rank = None

# Start bootstrap server
self.thread = threading.Thread(target=self._run_server, daemon=True)
Expand All @@ -442,21 +531,22 @@ def run(self):

def _setup_routes(self):
self.app.router.add_route("*", "/metadata", self._handle_metadata)
self.app.router.add_route("*", "/kv_route", self._handle_kv_route)

async def _handle_metadata(self, request: web.Request):
key = request.query.get("key", "")

if request.method == "GET":
return await self._handle_get(key)
return await self._handle_metadata_get(key)
elif request.method == "PUT":
return await self._handle_put(key, request)
return await self._handle_metadata_put(key, request)
elif request.method == "DELETE":
return await self._handle_delete(key)
return await self._handle_metadata_delete(key)
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)

async def _handle_get(self, key):
async def _handle_metadata_get(self, key):
async with self.lock:
value = self.store.get(key)
if value is None:
Expand All @@ -465,15 +555,15 @@ async def _handle_get(self, key):
)
return web.Response(body=value, status=200, content_type="application/json")

async def _handle_put(self, key, request):
async def _handle_metadata_put(self, key, request):
data = await request.read()
async with self.lock:
self.store[key] = data
return web.Response(
text="metadata updated", status=200, content_type="application/json"
)

async def _handle_delete(self, key):
async def _handle_metadata_delete(self, key):
async with self.lock:
if key not in self.store:
return web.Response(
Expand All @@ -486,6 +576,52 @@ async def _handle_delete(self, key):
text="metadata deleted", status=200, content_type="application/json"
)

async def _handle_kv_route(self, request: web.Request):
method = request.method
if method == "PUT":
return await self._handle_kv_route_put(request)
elif method == "GET":
return await self._handle_kv_route_get(request)
else:
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)

async def _handle_kv_route_put(self, request: web.Request):
data = await request.json()
identity = data["identity"]
role = data["role"]
serve_ip = data["serve_ip"]
serve_port = int(data["serve_port"]) # Assuming serve_port is an integer
tp_rank = int(data["tp_rank"])

# Add lock to make sure thread-safe
if role == "Prefill":
async with self.lock:
self.prefill_port_table[tp_rank] = {"serve_ip": serve_ip, "serve_port": serve_port}
logger.info(f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}")

return web.Response(text="OK", status=200)

async def _handle_kv_route_get(self, request: web.Request):
tp_rank = request.query.get("tp_rank")
if not tp_rank:
return web.Response(text="Missing tp_rank", status=400)
try:
tp_rank = int(tp_rank)
except ValueError:
return web.Response(text="tp_rank must be int", status=400)

# Find corresponding prefill info
async with self.lock:
prefill_info = self.prefill_port_table.get(tp_rank)

if prefill_info is not None:
return web.json_response(prefill_info, status=200)

else:
return web.Response(text="Not Found", status=404)

def _run_server(self):
try:
# Event Loop
Expand Down
Loading