Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,7 @@ def _launch_subprocesses(
pp_rank,
None,
writer,
None,
),
)

Expand Down
52 changes: 50 additions & 2 deletions python/sglang/srt/managers/data_parallel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,20 @@
import logging
import multiprocessing as mp
import signal
import struct
import sys
import threading
import time
from enum import Enum, auto
from multiprocessing import shared_memory
from typing import Dict, List

import psutil
import setproctitle
import zmq

from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.managers.data_parallel_meta import DPBalanceMeta
from sglang.srt.managers.io_struct import (
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
Expand All @@ -44,6 +49,7 @@ class LoadBalanceMethod(Enum):

ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
MINIMUM_TOKENS = auto()

@classmethod
def from_str(cls, method: str):
Expand All @@ -57,7 +63,16 @@ def from_str(cls, method: str):
class DataParallelController:
"""A controller that dispatches requests to multiple data parallel workers."""

def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
dp_balance_meta: DPBalanceMeta,
) -> None:
# for dp balance
self.global_balance_id = 0
self.balance_meta = dp_balance_meta

# Parse args
self.max_total_num_tokens = None
self.server_args = server_args
Expand All @@ -78,6 +93,7 @@ def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
dispatch_lookup = {
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler,
}
self.dispatching = dispatch_lookup[self.load_balance_method]

Expand Down Expand Up @@ -231,6 +247,7 @@ def launch_tensor_parallel_group(
pp_rank,
dp_rank,
writer,
self.balance_meta,
),
)
with memory_saver_adapter.configure_subprocess():
Expand Down Expand Up @@ -266,6 +283,31 @@ def round_robin_scheduler(self, req: Req):
def shortest_queue_scheduler(self, input_requests):
    """Dispatch ``input_requests`` to the worker with the shortest queue.

    Placeholder for ``LoadBalanceMethod.SHORTEST_QUEUE``; not implemented yet.
    """
    raise NotImplementedError()

def minimum_tokens_scheduler(self, req):
def get_next_global_balance_id() -> int:
INT32_MAX = 2147483647
current_id = self.global_balance_id
self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain the meaning of this variable?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, this variable corresponds to the balance_id in TokenizedGenerateReqInput.
We use it to control the number of on-fly tokens (requests dispatched to workers but not yet received by the scheduler).

return current_id

req.dp_balance_id = get_next_global_balance_id()
with self.balance_meta.mutex:
# 1. local_tokens represents the tokens currently inferring on the worker,
# while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
onfly_info = self.balance_meta.get_shared_onfly()
Comment on lines 301 to 302
Copy link
Copy Markdown
Contributor

@Edenzzzz Edenzzzz Jun 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be more formal to name these as "on_the_fly" or "in_flight"?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, we can first discuss the current implementation of the algorithm, and then I’ll make a unified update accordingly.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m wondering how long onfly_reqs can actually survive.
It seems the scheduler receives reqs from DPC almost immediately.
Meanwhile, onfly_reqs are appended in process_input_requests then excluded at the end of get_next_batch_to_run per 40 iterations.

Given this flow, the comment here might be misleading.

local_tokens = self.balance_meta.get_shared_local_tokens()
total_tokens = [
local_token + sum(onfly_dict.values())
for local_token, onfly_dict in zip(local_tokens, onfly_info)
]
target_worker = total_tokens.index(min(total_tokens))
onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
# 2. write the new onfly info to the shm
self.balance_meta.set_shared_onfly_info(onfly_info)

logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This log message uses logger.info and includes potentially large data structures (local_tokens, onfly_info). If requests are frequent, this could lead to excessive logging and performance overhead. Consider changing this to logger.debug or making it conditional, for example, logging only every N requests or if a specific debug flag is enabled.

self.workers[target_worker].send_pyobj(req)

def event_loop(self):
while True:
while True:
Expand Down Expand Up @@ -296,9 +338,12 @@ def run_data_parallel_controller_process(
setproctitle.setproctitle("sglang::data_parallel_controller")
configure_logger(server_args)
parent_process = psutil.Process().parent()
balance_meta = DPBalanceMeta(server_args.dp_size)

try:
controller = DataParallelController(server_args, port_args)
controller = DataParallelController(
server_args, port_args, dp_balance_meta=balance_meta
)
pipe_writer.send(
{
"status": "ready",
Expand All @@ -317,3 +362,6 @@ def run_data_parallel_controller_process(
traceback = get_exception_traceback()
logger.error(f"DataParallelController hit an exception: {traceback}")
parent_process.send_signal(signal.SIGQUIT)
finally:
# we need to destruct mp.Manager() in balance_meta
balance_meta.destructor()
54 changes: 54 additions & 0 deletions python/sglang/srt/managers/data_parallel_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import logging
import multiprocessing as mp
import pickle
import struct
from multiprocessing import shared_memory
from multiprocessing.managers import BaseManager
from typing import Dict, List

logger = logging.getLogger(__name__)

"""
This class will be use in scheduler and dp controller
If this class is placed in the dp controller,
it will cause circular references, so it is placed in a separate file.
"""


class DPBalanceMeta:
def __init__(self, num_workers: int):
self.num_workers = num_workers
self._manager = mp.Manager()
self.mutex = self._manager.Lock()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking if we could abstract some methods to python/sglang/srt/utils.py and python/sglang/srt/distributed/parallel_state.py ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, i moved it to the manager/utils.py


init_local_tokens = [0] * self.num_workers
init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]

self.shared_state = self._manager.Namespace()
self.shared_state.local_tokens = self._manager.list(init_local_tokens)
self.shared_state.onfly_info = self._manager.list(init_onfly_info)

def destructor(self):
# we must destructor this class manually
self._manager.shutdown()

def get_shared_onfly(self) -> List[Dict[int, int]]:
return [dict(d) for d in self.shared_state.onfly_info]

def set_shared_onfly_info(self, data: List[Dict[int, int]]):
self.shared_state.onfly_info = data

def get_shared_local_tokens(self) -> List[int]:
return list(self.shared_state.local_tokens)

def set_shared_local_tokens(self, data: List[int]):
self.shared_state.local_tokens = data

def __getstate__(self):
state = self.__dict__.copy()
del state["_manager"]
return state

def __setstate__(self, state):
self.__dict__.update(state)
self._manager = None
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,9 @@ class TokenizedGenerateReqInput:
# For data parallel rank routing
data_parallel_rank: Optional[int] = None

# For dp balance
dp_balance_id: int = -1


@dataclass
class EmbeddingReqInput:
Expand Down Expand Up @@ -650,6 +653,8 @@ class TokenizedEmbeddingReqInput:
token_type_ids: List[int]
# Dummy sampling params for compatibility
sampling_params: SamplingParams
# For dp balance
dp_balance_id: int = -1


@dataclass
Expand Down
118 changes: 117 additions & 1 deletion python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
)
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.data_parallel_meta import DPBalanceMeta
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
Expand Down Expand Up @@ -237,6 +238,7 @@ def __init__(
tp_rank: int,
pp_rank: int,
dp_rank: Optional[int],
dp_balance_meta: Optional[DPBalanceMeta] = None,
):
# Parse args
self.server_args = server_args
Expand Down Expand Up @@ -542,6 +544,15 @@ def __init__(
if get_bool_env_var("SGLANG_GC_LOG"):
configure_gc_logger()

self.balance_meta = dp_balance_meta
if (
server_args.enable_dp_attention
and server_args.load_balance_method == "minimum_tokens"
):
assert dp_balance_meta is not None

self.recv_dp_balance_id_this_term = []

def current_scheduler_metrics_enabled(self):
return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers

Expand Down Expand Up @@ -1092,6 +1103,12 @@ def handle_generate_request(
self,
recv_req: TokenizedGenerateReqInput,
):
if (
self.server_args.enable_dp_attention
and self.server_args.load_balance_method == "minimum_tokens"
):
self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)

# Create a new request
if (
recv_req.session_params is None
Expand Down Expand Up @@ -1653,6 +1670,11 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:

# Handle DP attention
if need_dp_attn_preparation:
if (
self.server_args.load_balance_method == "minimum_tokens"
and self.forward_ct % 40 == 0
):
self.handle_dp_balance_data(ret)
ret = self.prepare_mlp_sync_batch(ret)

return ret
Expand Down Expand Up @@ -1978,6 +2000,91 @@ def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
)

def handle_dp_balance_data(self, local_batch: ScheduleBatch):
def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The parameter holding_tokens_list in gather_dp_balance_info is misleading. Based on its usage in handle_dp_balance_data (line 1854-1855), it's a single integer representing the current worker's holding tokens, not a list.

Later, within gather_dp_balance_info (line 1814), holding_tokens_list is re-assigned to be a list of holding tokens gathered from all workers if self.tp_rank == 0.

This dual meaning and misnaming can cause confusion. Consider renaming the parameter to something like current_worker_holding_tokens: int and the locally gathered list to all_workers_holding_tokens: List[int].

"""gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
recv_list = self.recv_dp_balance_id_this_term
assert len(recv_list) <= 511, (
"The number of requests received this round is too large. "
"Please increase gather_tensor_size and onfly_info_size."
)

gather_tensor_size = 512
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The values 511 and 512 (for gather_tensor_size) are used here and seem related to max_onfly_req_per_worker defined implicitly in DPBalanceMeta. It would be better to define these as named constants, possibly in data_parallel_meta.py or a shared constants module, and import them here. This improves readability and makes it easier to update if the underlying limits change.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add explanation for this value?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it means the maximum size of the tensor used for gathering data.

# recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
recv_tensor[0] = holding_tokens_list
recv_tensor[1] = len(
recv_list
) # The first element is the length of the list.
recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
recv_list, dtype=torch.int32
)

if self.tp_rank == 0:
gathered_list = [
torch.zeros(gather_tensor_size, dtype=torch.int32)
for _ in range(self.balance_meta.num_workers)
]
else:
gathered_list = None

torch.distributed.gather(
recv_tensor, gathered_list, group=self.tp_cpu_group
)

gathered_id_list_per_worker = None
if self.tp_rank == 0:
gathered_id_list_per_worker = []
holding_tokens_list = []
for tensor in gathered_list:
holding_tokens_list.append(tensor[0].item())
list_length = tensor[1].item()
gathered_id_list_per_worker.append(
tensor[2 : list_length + 2].tolist()
)

return gathered_id_list_per_worker, holding_tokens_list
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The return type hint Union[None, List[List[int]]] for gather_dp_balance_info is incorrect. The function actually returns a tuple: (gathered_id_list_per_worker, holding_tokens_list).
Based on the logic:

  • If self.tp_rank != 0, it returns (None, input_argument_holding_tokens).
  • If self.tp_rank == 0, it returns (List[List[int]], List[int]) (where the second list is the gathered holding tokens for all workers).

A more accurate type hint would be Tuple[Optional[List[List[int]]], List[int]].

Suggested change
def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
"""gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
recv_list = self.recv_dp_balance_id_this_term
assert len(recv_list) <= 511, (
"The number of requests received this round is too large. "
"Please increase gather_tensor_size and onfly_info_size."
)
gather_tensor_size = 512
# recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
recv_tensor[0] = holding_tokens_list
recv_tensor[1] = len(
recv_list
) # The first element is the length of the list.
recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
recv_list, dtype=torch.int32
)
if self.tp_rank == 0:
gathered_list = [
torch.zeros(gather_tensor_size, dtype=torch.int32)
for _ in range(self.balance_meta.num_workers)
]
else:
gathered_list = None
torch.distributed.gather(
recv_tensor, gathered_list, group=self.tp_cpu_group
)
gathered_id_list_per_worker = None
if self.tp_rank == 0:
gathered_id_list_per_worker = []
holding_tokens_list = []
for tensor in gathered_list:
holding_tokens_list.append(tensor[0].item())
list_length = tensor[1].item()
gathered_id_list_per_worker.append(
tensor[2 : list_length + 2].tolist()
)
return gathered_id_list_per_worker, holding_tokens_list
def gather_dp_balance_info(current_worker_holding_tokens: int) -> Tuple[Optional[List[List[int]]], List[int]]:


def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
meta = self.balance_meta

with meta.mutex:
onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
assert len(new_recv_rid_lists) == len(
onfly_list
), "num_worker not equal"
# 1.Check if the rid received by each worker this round is present in onfly.
# If it is, remove the corresponding onfly item.
worker_id = 0
for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
for new_recv_rid in new_recv_rids:
assert (
new_recv_rid in on_fly_reqs
), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
del on_fly_reqs[new_recv_rid]
worker_id += 1
# 2. Atomically write local_tokens and onfly into shm under the mutex
meta.set_shared_onfly_info(onfly_list)
meta.set_shared_local_tokens(local_tokens)

# prepare worker holding tokens this scheduler term
if local_batch is None:
holding_tokens = sum(req.seqlen for req in self.waiting_queue)
else:
holding_tokens = sum(req.seqlen for req in local_batch.reqs) + sum(
req.seqlen for req in self.waiting_queue
)

new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
holding_tokens
)

self.recv_dp_balance_id_this_term.clear()
if self.tp_rank == 0: # only first worker write info
write_shared_dp_balance_info(
new_recv_dp_balance_id_list, holding_token_list
)

@staticmethod
def prepare_mlp_sync_batch_raw(
local_batch: ScheduleBatch,
Expand Down Expand Up @@ -2884,6 +2991,7 @@ def run_scheduler_process(
pp_rank: int,
dp_rank: Optional[int],
pipe_writer,
balance_meta: Optional[DPBalanceMeta] = None,
):
# Generate the prefix
prefix = ""
Expand Down Expand Up @@ -2918,7 +3026,15 @@ def run_scheduler_process(
init_embedding_cache(embedding_cache_size * 1024 * 1024)
# Create a scheduler and run the event loop
try:
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
scheduler = Scheduler(
server_args,
port_args,
gpu_id,
tp_rank,
pp_rank,
dp_rank,
dp_balance_meta=balance_meta,
)
pipe_writer.send(
{
"status": "ready",
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
choices=[
"round_robin",
"shortest_queue",
"minimum_tokens",
],
)

Expand Down