Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/sglang/srt/disaggregation/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .conn import (
BaseKVBootstrapServer,
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
106 changes: 106 additions & 0 deletions python/sglang/srt/disaggregation/base/conn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np
import numpy.typing as npt

from sglang.srt.disaggregation.utils import DisaggregationMode


class KVArgs:
engine_rank: int
kv_data_ptrs: list[int]
kv_data_lens: list[int]
kv_item_lens: list[int]
aux_data_ptrs: list[int]
aux_data_lens: list[int]
aux_item_lens: list[int]
ib_device: str


class KVPoll:
Failed = 0
Bootstrapping = 1
WaitingForInput = 2
Transferring = 3
Success = 4


class BaseKVManager(ABC):
"""Base class for managing transfers states"""

@abstractmethod
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): ...


class BaseKVSender(ABC):

@abstractmethod
def __init__(
self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int
): ...

@abstractmethod
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
"""
Notify the decoder server about the kv indices length and aux index
"""
...

@abstractmethod
def send(self, kv_indices: npt.NDArray[np.int64]):
"""
Send the kv cache at the given kv indices to the decoder server
"""
...

@abstractmethod
def poll(self) -> KVPoll:
"""
Check the status of the kv cache transfer
"""
...

@abstractmethod
def failure_exception(self):
"""
Raise an exception if the kv cache transfer fails
"""
...


class BaseKVReceiver(ABC):

@abstractmethod
def __init__(
self,
mgr: BaseKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
): ...

@abstractmethod
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
"""
Notify the prefill server about the kv indices and aux index
"""
...

@abstractmethod
def poll(self) -> KVPoll:
"""
Check the status of the kv cache transfer
"""
...

@abstractmethod
def failure_exception(self):
"""
Raise an exception if the kv cache transfer fails
"""
...


class BaseKVBootstrapServer(ABC):
@abstractmethod
def __init__(self, port: int): ...
23 changes: 18 additions & 5 deletions python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,19 @@
import torch
from torch.distributed import ProcessGroup

from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver
from sglang.srt.disaggregation.base import (
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
KVClassType,
ReqToMetadataIdxAllocator,
TransferBackend,
get_kv_class,
poll_and_all_reduce,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
Expand All @@ -51,7 +60,7 @@
@dataclass
class DecodeRequest:
req: Req
kv_receiver: KVReceiver
kv_receiver: BaseKVReceiver
waiting_for_input: bool = False
metadata_buffer_index: int = -1

Expand All @@ -75,6 +84,7 @@ def __init__(
tp_rank: int,
tp_size: int,
bootstrap_port: int,
transfer_backend: TransferBackend,
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
Expand All @@ -94,9 +104,10 @@ def __init__(

# Queue for requests pending pre-allocation
self.queue: List[DecodeRequest] = []
self.transfer_backend = transfer_backend
self.kv_manager = self._init_kv_manager()

def _init_kv_manager(self) -> KVManager:
def _init_kv_manager(self) -> BaseKVManager:
kv_args = KVArgs()
kv_args.engine_rank = self.tp_rank
kv_data_ptrs, kv_data_lens, kv_item_lens = (
Expand All @@ -117,13 +128,15 @@ def _init_kv_manager(self) -> KVManager:
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.ib_device = "mock-ib-device"
kv_manager = KVManager(kv_args, DisaggregationMode("decode"))
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
kv_manager = kv_manager_class(kv_args, DisaggregationMode.DECODE)
return kv_manager

def add(self, req: Req) -> None:
"""Add a request to the pending queue."""

kv_receiver = KVReceiver(
kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
kv_receiver = kv_receiver_class(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/disaggregation/mooncake/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .conn import (
MooncakeKVBootstrapServer,
MooncakeKVManager,
MooncakeKVReceiver,
MooncakeKVSender,
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@
import zmq
from aiohttp import web

from sglang.srt.disaggregation.transfer_engine.mooncake import MooncakeTransferEngine
from sglang.srt.disaggregation.base.conn import (
BaseKVBootstrapServer,
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -44,25 +52,6 @@ def group_concurrent_contiguous(
return src_groups, dst_groups


class KVArgs:
engine_rank: int
kv_data_ptrs: list[int]
kv_data_lens: list[int]
kv_item_lens: list[int]
aux_data_ptrs: list[int]
aux_data_lens: list[int]
aux_item_lens: list[int]
ib_device: str


class KVPoll:
Failed = 0
Bootstrapping = 1
WaitingForInput = 2
Transferring = 3
Success = 4


RequestPoolType = Dict[int, Tuple[npt.NDArray[np.int64], Optional[int]]]
WaitingPoolType = Dict[
int, Tuple[str, list[int], npt.NDArray[np.int64], list[int], int]
Expand All @@ -71,8 +60,7 @@ class KVPoll:
KVRECEIVER_POLLING_PORT = 27788


class KVManager:
# TODO: make it general and support multiple transfer backend before merging
class MooncakeKVManager(BaseKVManager):
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode):
self.engine = MooncakeTransferEngine()
self.kv_args = args
Expand Down Expand Up @@ -331,9 +319,11 @@ def get_session_id(self):
return self.engine.get_session_id()


class KVSender:
class MooncakeKVSender(BaseKVSender):

def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int):
def __init__(
self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
):
self.kv_mgr = mgr
self.bootstrap_room = bootstrap_room
self.kv_mgr.set_status(bootstrap_room, KVPoll.WaitingForInput)
Expand All @@ -353,10 +343,13 @@ def failure_exception(self):
raise Exception("Fake KVSender Exception")


class KVReceiver:
class MooncakeKVReceiver(BaseKVReceiver):

def __init__(
self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None
self,
mgr: MooncakeKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
Expand Down Expand Up @@ -403,7 +396,7 @@ def failure_exception(self):
raise Exception("Fake KVReceiver Exception")


class KVBootstrapServer:
class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
def __init__(self, port: int):
self.port = port
self.app = web.Application()
Expand Down
22 changes: 18 additions & 4 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,19 @@

import torch

from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender
from sglang.srt.disaggregation.base import (
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
KVClassType,
ReqToMetadataIdxAllocator,
TransferBackend,
get_kv_class,
poll_and_all_reduce,
)
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
Expand All @@ -38,6 +47,7 @@
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
from sglang.srt.mem_cache.memory_pool import KVCache


logger = logging.getLogger(__name__)


Expand All @@ -56,6 +66,7 @@ def __init__(
tp_size: int,
bootstrap_port: int,
gloo_group: ProcessGroup,
transfer_backend: TransferBackend,
):
self.token_to_kv_pool = token_to_kv_pool
self.aux_dtype = aux_dtype
Expand All @@ -64,6 +75,7 @@ def __init__(
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
self.tp_rank = tp_rank
self.tp_size = tp_size
self.transfer_backend = transfer_backend
self.kv_manager = self._init_kv_manager()
self.queue: List[Req] = []
self.gloo_group = gloo_group
Expand All @@ -74,7 +86,7 @@ def allocate_token_id(self, idx: int, token_id: int):
output_id_buffer = self.metadata_buffers[0]
output_id_buffer[idx] = token_id

def _init_kv_manager(self) -> KVManager:
def _init_kv_manager(self) -> BaseKVManager:
kv_args = KVArgs()
kv_args.engine_rank = self.tp_rank
kv_data_ptrs, kv_data_lens, kv_item_lens = (
Expand All @@ -96,11 +108,13 @@ def _init_kv_manager(self) -> KVManager:
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.ib_device = "mock-ib-device"
kv_manager = KVManager(kv_args, DisaggregationMode("prefill"))
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
kv_manager = kv_manager_class(kv_args, DisaggregationMode.PREFILL)
return kv_manager

def add(self, req: Req) -> None:
req.disagg_kv_sender = KVSender(
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
req.disagg_kv_sender = kv_sender_class(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
Expand Down
31 changes: 31 additions & 0 deletions python/sglang/srt/disaggregation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,34 @@ def alloc(self) -> List[int]:

def free(self, free_index: int):
self.free_slots.append(free_index)


class TransferBackend(Enum):
MOONCAKE = "mooncake"
FAKE = "fake"


class KVClassType(Enum):
MANAGER = "manager"
SENDER = "sender"
RECEIVER = "receiver"
BOOTSTRAP_SERVER = "bootstrap_server"


def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
if transfer_backend == TransferBackend.MOONCAKE:
from sglang.srt.disaggregation.mooncake import (
MooncakeKVBootstrapServer,
MooncakeKVManager,
MooncakeKVReceiver,
MooncakeKVSender,
)

class_mapping = {
KVClassType.MANAGER: MooncakeKVManager,
KVClassType.SENDER: MooncakeKVSender,
KVClassType.RECEIVER: MooncakeKVReceiver,
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
}
return class_mapping.get(class_type)
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
Loading
Loading