Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/backend/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--dp-size` | The data parallelism size. | 1 |
| `--load-balance-method` | The load balancing strategy for data parallelism. | round_robin |
| `--load-balance-method` | The load balancing strategy for data parallelism. Options include: 'round_robin', 'minimum_tokens'. The Minimum Token algorithm can only be used when DP attention is applied. This algorithm performs load balancing based on the real-time token load of the DP workers. | round_robin |

## Multi-node distributed serving

Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,7 @@ def _launch_subprocesses(
pp_rank,
None,
writer,
None,
),
)

Expand Down
54 changes: 52 additions & 2 deletions python/sglang/srt/managers/data_parallel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
import logging
import multiprocessing as mp
import signal
import struct
import sys
import threading
import time
from enum import Enum, auto
from multiprocessing import shared_memory
from typing import Dict, List

import psutil
import setproctitle
Expand All @@ -32,6 +36,7 @@
)
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.utils import DPBalanceMeta
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
Expand All @@ -45,6 +50,7 @@ class LoadBalanceMethod(Enum):

ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
MINIMUM_TOKENS = auto()

@classmethod
def from_str(cls, method: str):
Expand All @@ -58,7 +64,16 @@ def from_str(cls, method: str):
class DataParallelController:
"""A controller that dispatches requests to multiple data parallel workers."""

def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
dp_balance_meta: DPBalanceMeta,
) -> None:
# for dp balance
self.global_balance_id = 0
self.balance_meta = dp_balance_meta

# Parse args
self.max_total_num_tokens = None
self.server_args = server_args
Expand All @@ -79,6 +94,7 @@ def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
dispatch_lookup = {
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler,
}
self.dispatching = dispatch_lookup[self.load_balance_method]

Expand Down Expand Up @@ -234,6 +250,7 @@ def launch_tensor_parallel_group(
pp_rank,
dp_rank,
writer,
self.balance_meta,
),
)
with memory_saver_adapter.configure_subprocess():
Expand Down Expand Up @@ -269,6 +286,33 @@ def round_robin_scheduler(self, req: Req):
def shortest_queue_scheduler(self, input_requests):
raise NotImplementedError()

def minimum_tokens_scheduler(self, req):
    """Dispatch ``req`` to the DP worker with the smallest estimated token load.

    A worker's load is estimated as the sum of:
      1. ``local_tokens`` -- tokens currently being inferred on the worker, and
      2. its "onfly" tokens -- requests already dispatched by this controller
         but not yet received by the worker's scheduler.

    Both values live in shared memory (``self.balance_meta``); they are read
    and updated atomically under the shared mutex.
    """

    def get_next_global_balance_id() -> int:
        # This id corresponds to the balance_id in TokenizedGenerateReqInput.
        # We use it to track onfly tokens (requests dispatched to workers but
        # not yet received), so schedulers can report back which requests they
        # actually received and the controller can retire them from shm.
        INT32_MAX = 2147483647  # wrap so the id always fits in an int32 tensor
        current_id = self.global_balance_id
        self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
        return current_id

    req.dp_balance_id = get_next_global_balance_id()
    with self.balance_meta.mutex:
        # 1. Pick the worker with the minimum (local + onfly) token count.
        onfly_info = self.balance_meta.get_shared_onfly()
        local_tokens = self.balance_meta.get_shared_local_tokens()
        total_tokens = [
            local_token + sum(onfly_dict.values())
            for local_token, onfly_dict in zip(local_tokens, onfly_info)
        ]
        target_worker = total_tokens.index(min(total_tokens))
        onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
        # 2. Write the updated onfly info back to the shm.
        self.balance_meta.set_shared_onfly_info(onfly_info)

    # logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
    self.workers[target_worker].send_pyobj(req)

def event_loop(self):
while True:
while True:
Expand Down Expand Up @@ -302,9 +346,12 @@ def run_data_parallel_controller_process(
setproctitle.setproctitle("sglang::data_parallel_controller")
configure_logger(server_args)
parent_process = psutil.Process().parent()
balance_meta = DPBalanceMeta(server_args.dp_size)

try:
controller = DataParallelController(server_args, port_args)
controller = DataParallelController(
server_args, port_args, dp_balance_meta=balance_meta
)
pipe_writer.send(
{
"status": "ready",
Expand All @@ -323,3 +370,6 @@ def run_data_parallel_controller_process(
traceback = get_exception_traceback()
logger.error(f"DataParallelController hit an exception: {traceback}")
parent_process.send_signal(signal.SIGQUIT)
finally:
# We must manually shut down the mp.Manager() held by balance_meta.
balance_meta.destructor()
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,9 @@ class TokenizedGenerateReqInput:
# For data parallel rank routing
data_parallel_rank: Optional[int] = None

# For dp balance
dp_balance_id: int = -1


@dataclass
class EmbeddingReqInput:
Expand Down Expand Up @@ -648,6 +651,8 @@ class TokenizedEmbeddingReqInput:
token_type_ids: List[int]
# Dummy sampling params for compatibility
sampling_params: SamplingParams
# For dp balance
dp_balance_id: int = -1


@dataclass
Expand Down
113 changes: 111 additions & 2 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
Expand Down Expand Up @@ -203,6 +203,7 @@ def __init__(
moe_ep_rank: int,
pp_rank: int,
dp_rank: Optional[int],
dp_balance_meta: Optional[DPBalanceMeta] = None,
):
# Parse args
self.server_args = server_args
Expand Down Expand Up @@ -522,6 +523,15 @@ def __init__(
]
)

self.balance_meta = dp_balance_meta
if (
server_args.enable_dp_attention
and server_args.load_balance_method == "minimum_tokens"
):
assert dp_balance_meta is not None

self.recv_dp_balance_id_this_term = []

def init_tokenizer(self):
server_args = self.server_args

Expand Down Expand Up @@ -1032,6 +1042,12 @@ def handle_generate_request(
self,
recv_req: TokenizedGenerateReqInput,
):
if (
self.server_args.enable_dp_attention
and self.server_args.load_balance_method == "minimum_tokens"
):
self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)

# Create a new request
if (
recv_req.session_params is None
Expand Down Expand Up @@ -1442,6 +1458,11 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:

# Handle DP attention
if need_dp_attn_preparation:
if (
self.server_args.load_balance_method == "minimum_tokens"
and self.forward_ct % 40 == 0
):
self.handle_dp_balance_data(ret)
ret = self.prepare_mlp_sync_batch(ret)

return ret
Expand Down Expand Up @@ -1767,6 +1788,86 @@ def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
)

def handle_dp_balance_data(self, local_batch: ScheduleBatch):
    """Synchronize DP-balance bookkeeping with the shared-memory state.

    Called periodically (every 40 forward passes) when the
    ``minimum_tokens`` load-balance method is active. All TP ranks gather
    onto tp_rank 0 each worker's holding-token count together with the
    ``dp_balance_id``s received since the previous call; rank 0 then
    retires those ids from the shared onfly table and refreshes the
    shared local-token counts.
    """
    # Size of the int32 buffer gathered from every worker.
    # Layout: | holding_tokens | num_ids | dp_balance_ids ... |
    GATHER_TENSOR_SIZE = 512
    # Two header slots (holding_tokens, num_ids) leave this many id slots.
    # NOTE: the previous bound of 511 could overflow the 512-slot tensor,
    # since ids occupy indices [2, num_ids + 2).
    MAX_IDS_PER_TERM = GATHER_TENSOR_SIZE - 2

    def gather_dp_balance_info(worker_holding_tokens):
        """Gather holding tokens and received balance ids from all workers.

        Returns ``(ids_per_worker, holding_tokens_per_worker)`` on
        tp_rank 0, and ``(None, None)`` on every other rank.
        """
        recv_list = self.recv_dp_balance_id_this_term
        assert len(recv_list) <= MAX_IDS_PER_TERM, (
            "The number of requests received this round is too large. "
            "Please increase gather_tensor_size and onfly_info_size."
        )

        send_tensor = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
        send_tensor[0] = worker_holding_tokens
        send_tensor[1] = len(recv_list)  # number of valid ids that follow
        send_tensor[2 : len(recv_list) + 2] = torch.tensor(
            recv_list, dtype=torch.int32
        )

        if self.tp_rank == 0:
            gathered_list = [
                torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
                for _ in range(self.balance_meta.num_workers)
            ]
        else:
            gathered_list = None

        torch.distributed.gather(
            send_tensor, gathered_list, group=self.tp_cpu_group
        )

        ids_per_worker = None
        holding_tokens_per_worker = None
        if self.tp_rank == 0:
            ids_per_worker = []
            holding_tokens_per_worker = []
            for tensor in gathered_list:
                holding_tokens_per_worker.append(tensor[0].item())
                num_ids = tensor[1].item()
                ids_per_worker.append(tensor[2 : num_ids + 2].tolist())

        return ids_per_worker, holding_tokens_per_worker

    def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
        meta = self.balance_meta

        with meta.mutex:
            onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
            assert len(new_recv_rid_lists) == len(
                onfly_list
            ), "num_worker not equal"
            # 1. Every id received by a worker this round must be present in
            #    its onfly table; retire the corresponding onfly item.
            for worker_id, (new_recv_rids, on_fly_reqs) in enumerate(
                zip(new_recv_rid_lists, onfly_list)
            ):
                for new_recv_rid in new_recv_rids:
                    assert (
                        new_recv_rid in on_fly_reqs
                    ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
                    del on_fly_reqs[new_recv_rid]
            # 2. Atomically write local_tokens and onfly into shm under the mutex.
            meta.set_shared_onfly_info(onfly_list)
            meta.set_shared_local_tokens(local_tokens)

    holding_tokens = self.get_load()

    new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
        holding_tokens
    )

    self.recv_dp_balance_id_this_term.clear()
    if self.tp_rank == 0:  # only the first rank writes the shared info
        write_shared_dp_balance_info(
            new_recv_dp_balance_id_list, holding_token_list
        )

@staticmethod
def prepare_mlp_sync_batch_raw(
local_batch: ScheduleBatch,
Expand Down Expand Up @@ -2367,6 +2468,7 @@ def run_scheduler_process(
pp_rank: int,
dp_rank: Optional[int],
pipe_writer,
balance_meta: Optional[DPBalanceMeta] = None,
):
# Generate the prefix
prefix = ""
Expand Down Expand Up @@ -2400,7 +2502,14 @@ def run_scheduler_process(
# Create a scheduler and run the event loop
try:
scheduler = Scheduler(
server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
server_args,
port_args,
gpu_id,
tp_rank,
moe_ep_rank,
pp_rank,
dp_rank,
dp_balance_meta=balance_meta,
)
pipe_writer.send(
{
Expand Down
46 changes: 45 additions & 1 deletion python/sglang/srt/managers/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import multiprocessing as mp
from http import HTTPStatus
from typing import Optional
from typing import Dict, List, Optional

from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req

Expand Down Expand Up @@ -38,3 +39,46 @@ def validate_input_length(
return error_msg

return None


class DPBalanceMeta:
    """Cross-process bookkeeping for the ``minimum_tokens`` DP load balancer.

    This class is used in both the scheduler and the DP controller. State is
    shared between processes via a ``multiprocessing.Manager``; every read or
    write of the shared state should be performed while holding ``self.mutex``.
    """

    def __init__(self, num_workers: int):
        self.num_workers = num_workers
        self._manager = mp.Manager()
        # Guards all access to the shared state below.
        self.mutex = self._manager.Lock()

        init_local_tokens = [0] * self.num_workers
        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]

        self.shared_state = self._manager.Namespace()
        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
        self.shared_state.onfly_info = self._manager.list(init_onfly_info)

    def destructor(self):
        # The manager process does not exit on its own; this must be called
        # manually (e.g. from a ``finally`` block) to shut it down.
        self._manager.shutdown()

    def get_shared_onfly(self) -> List[Dict[int, int]]:
        # Copy out of the manager proxies so callers can mutate freely.
        return [dict(d) for d in self.shared_state.onfly_info]

    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
        # Assigning through the Namespace proxy ships the whole value to the
        # manager process, so storing plain lists/dicts here still propagates
        # across processes.
        self.shared_state.onfly_info = data

    def get_shared_local_tokens(self) -> List[int]:
        return list(self.shared_state.local_tokens)

    def set_shared_local_tokens(self, data: List[int]):
        self.shared_state.local_tokens = data

    def __getstate__(self):
        # The Manager itself is not picklable; drop it when this object is
        # sent to child processes. Children only need the proxies above.
        state = self.__dict__.copy()
        del state["_manager"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._manager = None
1 change: 1 addition & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
choices=[
"round_robin",
"shortest_queue",
"minimum_tokens",
],
)

Expand Down
Loading
Loading