Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 219 additions & 10 deletions python/sglang/srt/disaggregation/conn.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from __future__ import annotations

import logging
from enum import Enum
from typing import Optional
import struct
import threading
from functools import cache
from typing import Dict, Optional, Tuple

import numpy as np
import numpy.typing as npt
import zmq

from sglang.srt.disaggregation.transfer_engine.mooncake import MooncakeTransferEngine

logger = logging.getLogger(__name__)

Expand All @@ -21,8 +26,169 @@ class KVArgs:
ib_device: str


RequestPoolType = Dict[int, Tuple[npt.NDArray[np.int32], Optional[int]]]
KVSENDER_POLLING_PORT = 17788
KVRECIVER_POLLING_PORT = 17789


class KVManager:
    """Owns the Mooncake transfer engine and the ZMQ plumbing for
    prefill/decode-disaggregated KV-cache transfer on one engine rank.

    The prefill side runs ``start_prefill_thread`` and pushes KV/aux data to
    decode ranks; the decode side runs ``start_decode_thread`` and waits for
    the ``Done`` notification.
    """

    # TODO: make it general and support multiple transfer backend before merging
    def __init__(self, args: KVArgs):
        self.engine = MooncakeTransferEngine()
        self.kv_args = args
        # bootstrap_room -> (kv_indices, aux_index). The room-0 entry is a
        # placeholder kept from the initial implementation.
        self.request_pool: RequestPoolType = {0: (np.array([0], dtype=np.int32), None)}
        self.server_socket = zmq.Context().socket(zmq.PULL)
        self.register_buffer_to_engine()
        self.prefill_thread_started = False
        self.decode_thread_started = False
        # Per-endpoint PUSH sockets. A plain dict is used instead of
        # functools.cache on the method: caching a bound method keys on
        # `self` and keeps the instance alive forever (ruff B019).
        self._socket_cache = {}

    def register_buffer_to_engine(self):
        """Register every KV-layer buffer and aux buffer with the engine."""
        for kv_data_ptr, kv_data_len in zip(
            self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
        ):
            self.engine.register(kv_data_ptr, kv_data_len)

        for aux_data_ptr, aux_data_len in zip(
            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
        ):
            self.engine.register(aux_data_ptr, aux_data_len)

    def _connect(self, endpoint: str):
        """Return a PUSH socket connected to *endpoint*, creating it once."""
        socket = self._socket_cache.get(endpoint)
        if socket is None:
            socket = zmq.Context().socket(zmq.PUSH)
            socket.connect(endpoint)
            self._socket_cache[endpoint] = socket
        return socket

    def send_kvcache(
        self,
        endpoint: str,
        bootstrap_room: int,
        dst_ptrs: list[int],
        dst_kv_indices: npt.NDArray[np.int32],
    ):
        """Transfer this rank's KV cache for *bootstrap_room* to *endpoint*.

        ``kv_data_ptrs`` is assumed to be laid out as
        ``[K_0 .. K_{L-1}, V_0 .. V_{L-1}]`` — hence the ``len // 2`` split.
        """
        prefill_indices, _ = self.request_pool[bootstrap_room]
        layer_num = len(self.kv_args.kv_data_ptrs) // 2
        for layer_id in range(layer_num):
            prefill_key_layer_ptr = self.kv_args.kv_data_ptrs[layer_id]
            key_item_len = self.kv_args.kv_item_lens[layer_id]
            prefill_value_layer_ptr = self.kv_args.kv_data_ptrs[layer_num + layer_id]
            value_item_len = self.kv_args.kv_item_lens[layer_num + layer_id]

            decode_key_layer_ptr = dst_ptrs[layer_id]
            decode_value_layer_ptr = dst_ptrs[layer_num + layer_id]
            # TODO: Maybe combine multiple contiguous indices into one transfer_sync op
            for prefill_index, decode_index in zip(prefill_indices, dst_kv_indices):
                # Cast np.int32 scalars to Python int: 64-bit pointer
                # arithmetic with int32 raises
                # "OverflowError: Python integer ... out of bounds for int32".
                prefill_index = int(prefill_index)
                decode_index = int(decode_index)

                prefill_key_addr = prefill_key_layer_ptr + prefill_index * key_item_len
                decode_key_addr = decode_key_layer_ptr + decode_index * key_item_len
                # TODO: mooncake transfer engine can do async transfer. Do async later
                self.engine.transfer_sync(
                    endpoint, prefill_key_addr, decode_key_addr, key_item_len
                )

                prefill_value_addr = (
                    prefill_value_layer_ptr + prefill_index * value_item_len
                )
                decode_value_addr = (
                    decode_value_layer_ptr + decode_index * value_item_len
                )
                # TODO: mooncake transfer engine can do async transfer. Do async later
                self.engine.transfer_sync(
                    endpoint, prefill_value_addr, decode_value_addr, value_item_len
                )

    def send_aux(
        self,
        endpoint: str,
        bootstrap_room: int,
        dst_aux_ptrs: list[int],
        dst_aux_index: int,
    ):
        """Transfer the auxiliary-buffer slot for *bootstrap_room*."""
        _, prefill_aux_index = self.request_pool[bootstrap_room]
        aux_item_len = self.kv_args.aux_data_lens[0]
        prefill_aux_addr = (
            self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
        )
        decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
        # TODO: mooncake transfer engine can do async transfer. Do async later
        # Not sure about the amount of aux data, maybe transfer it by zmq is more effective
        self.engine.transfer_sync(
            endpoint, prefill_aux_addr, decode_aux_addr, aux_item_len
        )

    def start_prefill_thread(self):
        """Start (at most once) the thread that serves decode-side requests."""
        if self.prefill_thread_started:
            return
        self.prefill_thread_started = True
        sender_rank_port = KVSENDER_POLLING_PORT + self.kv_args.engine_rank
        self.server_socket.bind("tcp://*:" + str(sender_rank_port))

        def prefill_thread():
            while True:
                (
                    endpoint,
                    bootstrap_room,
                    dst_ptrs,
                    dst_kv_indices,
                    dst_aux_ptrs,
                    dst_aux_index,
                ) = self.server_socket.recv_multipart()
                if bootstrap_room.decode("ascii") == "None":
                    continue
                endpoint = endpoint.decode("ascii")
                bootstrap_room = int(bootstrap_room.decode("ascii"))
                # Pointers were packed as 8-byte little-endian signed ints.
                dst_ptrs = list(struct.unpack(f"{len(dst_ptrs)//8}q", dst_ptrs))
                dst_kv_indices = np.frombuffer(dst_kv_indices, dtype=np.int32)
                dst_aux_ptrs = list(
                    struct.unpack(f"{len(dst_aux_ptrs)//8}q", dst_aux_ptrs)
                )
                dst_aux_index = int(dst_aux_index.decode("ascii"))
                # NOTE(review): assumes bootstrap_room was already enqueued
                # locally; a request that arrives first raises KeyError here.
                self.send_kvcache(endpoint, bootstrap_room, dst_ptrs, dst_kv_indices)
                self.send_aux(endpoint, bootstrap_room, dst_aux_ptrs, dst_aux_index)
                self.request_pool.pop(bootstrap_room)
                self._connect(
                    "tcp://"
                    + endpoint
                    + ":"
                    + str(KVRECIVER_POLLING_PORT + self.kv_args.engine_rank)
                ).send_multipart(
                    [
                        str(bootstrap_room).encode("ascii"),
                        # Must be bytes: pyzmq send_multipart rejects str frames.
                        b"Done",
                    ]
                )

        # daemon=True so the endless serving loop cannot block interpreter exit.
        threading.Thread(target=prefill_thread, daemon=True).start()

    def start_decode_thread(self):
        """Start (at most once) the thread that consumes 'Done' notifications."""
        if self.decode_thread_started:
            return
        self.decode_thread_started = True
        reciver_rank_port = KVRECIVER_POLLING_PORT + self.kv_args.engine_rank
        self.server_socket.bind("tcp://*:" + str(reciver_rank_port))

        def decode_thread():
            while True:
                (bootstrap_room, status) = self.server_socket.recv_multipart()
                bootstrap_room = int(bootstrap_room.decode("ascii"))
                # NOTE(review): raises KeyError if the room was never enqueued
                # on this side — TODO confirm that is the desired fail-fast.
                self.request_pool.pop(bootstrap_room)

        # daemon=True so the endless serving loop cannot block interpreter exit.
        threading.Thread(target=decode_thread, daemon=True).start()

    def enqueue_request(
        self,
        bootstrap_room: int,
        kv_indices: npt.NDArray[np.int32],
        aux_index: Optional[int],
    ):
        """Record an in-flight transfer; removed again once it completes."""
        self.request_pool[bootstrap_room] = (kv_indices, aux_index)

    def has_finished(self, bootstrap_room: int):
        """True once *bootstrap_room* is no longer tracked in the pool."""
        return bootstrap_room not in self.request_pool

    def get_localhost(self):
        """IP address the transfer engine is reachable at, per the engine."""
        return self.engine.get_localhost()


class KVPoll:
Expand All @@ -34,48 +200,91 @@ class KVPoll:


class KVSender:
    """Prefill-side handle for one KV-cache transfer request."""

    def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int):
        self.kv_mgr = mgr
        self.bootstrap_room = bootstrap_room
        self.aux_index = None
        self.has_sent = False
        # Idempotent: the manager starts its serving thread only once.
        self.kv_mgr.start_prefill_thread()

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """Record the transfer sizes announced before ``send``."""
        self.aux_index = aux_index
        self.num_kv_indices = num_kv_indices

    def send(self, kv_indices: npt.NDArray[np.int32]):
        """Hand the request to the manager; the background thread transfers it."""
        self.has_sent = True
        self.kv_mgr.enqueue_request(self.bootstrap_room, kv_indices, self.aux_index)

    def poll(self) -> KVPoll:
        """Poll transfer status for this request."""
        if self.has_sent:
            # Assume transfer completed instantly
            return KVPoll.Success
        # Assume handshake completed instantly
        if not self.kv_mgr.has_finished(self.bootstrap_room):
            return KVPoll.WaitingForInput
        self.has_sent = True
        return KVPoll.Success

    def failure_exception(self):
        raise Exception("Fake KVSender Exception")


class KVReceiver:
    """Decode-side handle for one KV-cache transfer request.

    Sends its buffer pointers and slot indices to the matching prefill rank,
    which then pushes the data over the transfer engine.
    """

    def __init__(
        self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None
    ):
        self.bootstrap_room = bootstrap_room
        self.bootstrap_addr = bootstrap_addr
        self.kv_mgr = mgr
        # NOTE(review): uses the *decode* engine_rank to derive the prefill
        # port, which assumes prefill and decode share the same TP size and
        # rank layout — TODO confirm / generalize for mismatched TP.
        self.prefill_server_url = (
            bootstrap_addr.split(":")[0]
            + ":"
            + str(KVSENDER_POLLING_PORT + self.kv_mgr.kv_args.engine_rank)
        )
        self.decode_ip = self.kv_mgr.get_localhost()
        # Idempotent: the manager starts its notification thread only once.
        self.kv_mgr.start_decode_thread()
        self.has_init = False
        # Per-endpoint PUSH sockets. A plain dict replaces functools.cache on
        # the method, which keys on `self` and pins the instance (ruff B019).
        self._socket_cache = {}

    def _connect(self, endpoint: str):
        """Return a PUSH socket connected to *endpoint*, creating it once."""
        socket = self._socket_cache.get(endpoint)
        if socket is None:
            socket = zmq.Context().socket(zmq.PUSH)
            socket.connect(endpoint)
            self._socket_cache[endpoint] = socket
        return socket

    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
        """Announce this request's destination buffers to the prefill rank."""
        self.has_init = True
        self.kv_mgr.enqueue_request(self.bootstrap_room, kv_indices, aux_index)
        # Pointers are packed as 8-byte signed ints ("q"), mirroring the
        # unpack on the prefill side.
        packed_kv_data_ptrs = b"".join(
            struct.pack("q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
        )
        packed_aux_data_ptrs = b"".join(
            struct.pack("q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
        )
        self._connect("tcp://" + self.prefill_server_url).send_multipart(
            [
                self.decode_ip.encode("ascii"),
                str(self.bootstrap_room).encode("ascii"),
                packed_kv_data_ptrs,
                kv_indices.tobytes(),
                packed_aux_data_ptrs,
                str(aux_index).encode("ascii"),
            ]
        )

    def poll(self) -> KVPoll:
        """Poll transfer status for this request."""
        if self.has_init is False:
            # Assume handshake completed instantly
            if self.kv_mgr.has_finished(self.bootstrap_room):
                self.has_init = True
                return KVPoll.Success
            return KVPoll.WaitingForInput
        else:
            # Assume transfer completed instantly
            return KVPoll.Success

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")


class KVBootstrapServer:
    """Placeholder bootstrap/rendezvous server for PD disaggregation.

    Both methods are stubs (`...`) in this revision — presumably the real
    handshake service lands in a follow-up PR; confirm before relying on it.
    """

    def __init__(self, port: int): ...  # port: intended listen port (unused in stub)

    def poll(self) -> KVPoll: ...  # stub: returns no meaningful status yet
Loading