From 85ad5acb71c5fb7a9a15f2e57b42f833f18c9aec Mon Sep 17 00:00:00 2001 From: guanyewang Date: Thu, 19 Jun 2025 15:32:40 +0800 Subject: [PATCH 01/11] [feat] support minimum token load balance in dp attention --- python/sglang/srt/entrypoints/engine.py | 1 + .../srt/managers/data_parallel_controller.py | 52 +++++++- .../sglang/srt/managers/data_parallel_meta.py | 96 +++++++++++++++ python/sglang/srt/managers/io_struct.py | 5 + python/sglang/srt/managers/scheduler.py | 115 +++++++++++++++++- python/sglang/srt/server_args.py | 1 + python/sglang/srt/utils.py | 17 +++ 7 files changed, 284 insertions(+), 3 deletions(-) create mode 100644 python/sglang/srt/managers/data_parallel_meta.py diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index fe1a7844fa20..656abd3362da 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -737,6 +737,7 @@ def _launch_subprocesses( pp_rank, None, writer, + None, ), ) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 62c3800c2ef4..dc4717b93929 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -16,15 +16,20 @@ import logging import multiprocessing as mp import signal +import struct +import sys import threading import time from enum import Enum, auto +from multiprocessing import shared_memory +from typing import Dict, List import psutil import setproctitle import zmq from sglang.srt.layers.dp_attention import compute_dp_attention_world_info +from sglang.srt.managers.data_parallel_meta import DPBalanceMeta from sglang.srt.managers.io_struct import ( TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -44,6 +49,7 @@ class LoadBalanceMethod(Enum): ROUND_ROBIN = auto() SHORTEST_QUEUE = auto() + MINIMUM_TOKENS = auto() @classmethod def from_str(cls, method: str): @@ -57,7 +63,16 @@ def from_str(cls, method: str): class DataParallelController: """A controller that dispatches requests to multiple data parallel workers.""" - def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None: + def __init__( + self, + server_args: ServerArgs, + port_args: PortArgs, + dp_balance_meta: DPBalanceMeta, + ) -> None: + # for dp balance + self.global_balance_id = 0 + self.balance_meta = dp_balance_meta + # Parse args self.max_total_num_tokens = None self.server_args = server_args @@ -78,6 +93,7 @@ def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None: dispatch_lookup = { LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler, LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler, + LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler, } self.dispatching = dispatch_lookup[self.load_balance_method] @@ -231,6 +247,7 @@ def launch_tensor_parallel_group( pp_rank, dp_rank, writer, + self.balance_meta, ), ) with memory_saver_adapter.configure_subprocess(): @@ -266,6 +283,31 @@ def round_robin_scheduler(self, req: Req): def shortest_queue_scheduler(self, input_requests): raise NotImplementedError() + def minimum_tokens_scheduler(self, req): + def get_next_global_balance_id() -> int: + INT32_MAX = 2147483647 + current_id = self.global_balance_id + self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX + return current_id + + req.dp_balance_id = get_next_global_balance_id() + with self.balance_meta.mutex: + # 1. 
local_tokens represents the tokens currently inferring on the worker,
+            # while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
+            onfly_info = self.balance_meta.get_shared_onfly()
+            local_tokens = self.balance_meta.get_shared_local_tokens()
+            total_tokens = [
+                local_token + sum(onfly_dict.values())
+                for local_token, onfly_dict in zip(local_tokens, onfly_info)
+            ]
+            target_worker = total_tokens.index(min(total_tokens))
+            onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
+            # 2. write the new onfly info to the shm
+            self.balance_meta.set_shared_onfly_info(onfly_info)
+
+        logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
+        self.workers[target_worker].send_pyobj(req)
+
     def event_loop(self):
         while True:
             while True:
@@ -296,9 +338,12 @@ def run_data_parallel_controller_process(
     setproctitle.setproctitle("sglang::data_parallel_controller")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
+    balance_meta = DPBalanceMeta(server_args.dp_size, do_init=True)
 
     try:
-        controller = DataParallelController(server_args, port_args)
+        controller = DataParallelController(
+            server_args, port_args, dp_balance_meta=balance_meta
+        )
         pipe_writer.send(
             {
                 "status": "ready",
@@ -317,3 +362,6 @@ def run_data_parallel_controller_process(
         traceback = get_exception_traceback()
         logger.error(f"DataParallelController hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
+    finally:
+        # we need to free shared mem manually, because POSIX shm is a named file
+        balance_meta.destructor()
diff --git a/python/sglang/srt/managers/data_parallel_meta.py b/python/sglang/srt/managers/data_parallel_meta.py
new file mode 100644
index 000000000000..25c04be16090
--- /dev/null
+++ b/python/sglang/srt/managers/data_parallel_meta.py
@@ -0,0 +1,96 @@
+import logging
+import multiprocessing as mp
+import pickle
+import struct
+from multiprocessing import shared_memory
+from typing import Dict, List
+
+logger = logging.getLogger(__name__)
+
+"""
+This class is used by both the scheduler and the DP controller.
+If this class were placed in the DP controller module,
+it would cause a circular import, so it is placed in a separate file.
+""" + + +class DPBalanceMeta: + def __init__(self, num_workers: int, do_init: bool = False): + self.mutex = mp.Lock() + self.shm_name_onfly_info = "sglang_dp_balance_onfly_info" + self.shm_name_local_tokens = "sglang_dp_balance_local_tokens" + self.onfly_info_size = ( + 512 * num_workers * 8 + ) # max_onfly_req_per_worker * num_workers * dByte + self.local_tokens_size = num_workers * 8 + self.num_workers = num_workers + + if do_init: + self.shm1 = shared_memory.SharedMemory( + name=self.shm_name_onfly_info, create=True, size=self.onfly_info_size + ) + self.shm2 = shared_memory.SharedMemory( + name=self.shm_name_local_tokens, + create=True, + size=self.local_tokens_size, + ) + init_local_tokens = [0 for _ in range(num_workers)] + init_onfly_req = [{} for _ in range(num_workers)] + self.set_shared_local_tokens(init_local_tokens) + self.set_shared_onfly_info(init_onfly_req) + self.shm1.name + + def destructor(self): + # we must destructor this class manually, otherwise will cause shm leak + self.shm1.close() + self.shm1.unlink() + self.shm2.close() + self.shm2.unlink() + + def get_shared_onfly(self) -> List[Dict[int, int]]: + """Retrieve data from shared memory and deserialize it into List[Dict[int, int]]""" + shm = shared_memory.SharedMemory(name=self.shm_name_onfly_info) + + header_size = struct.calcsize("Q") + data_size = struct.unpack("Q", shm.buf[:header_size])[0] + assert 0 <= data_size <= self.onfly_info_size, "no valid data in shared memory" + + serialized_data = bytes(shm.buf[header_size : header_size + data_size]) + onfly_info = pickle.loads(serialized_data) + + shm.close() + return onfly_info + + def set_shared_onfly_info(self, data: List[Dict[int, int]]): + """Serialize the data and write it to shared memory.""" + serialized_data = pickle.dumps(data) + data_size = len(serialized_data) + + assert data_size < self.onfly_info_size, ( + f"The size of the serialized data {data_size} " + f"exceeds the shared memory capacity {self.onfly_info_size} bytes. " + "Please increase onfly_info_size." 
+ ) + + shm = shared_memory.SharedMemory(name=self.shm_name_onfly_info) + shm.buf[: struct.calcsize("Q")] = struct.pack("Q", data_size) + shm.buf[struct.calcsize("Q") : struct.calcsize("Q") + data_size] = ( + serialized_data + ) + shm.close() + + def get_shared_local_tokens(self) -> List[int]: + shm = shared_memory.SharedMemory(name=self.shm_name_local_tokens) + serialized_data = bytes(shm.buf) + worker_onfly_data = pickle.loads(serialized_data) + shm.close() + return worker_onfly_data + + def set_shared_local_tokens(self, data: List[int]): + serialized_data = pickle.dumps(data) + data_size = len(serialized_data) + + shm = shared_memory.SharedMemory(name=self.shm_name_local_tokens) + shm.buf[:data_size] = serialized_data + + shm.close() diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 9b2768160169..b69a4267b79c 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -510,6 +510,9 @@ class TokenizedGenerateReqInput: # For data parallel rank routing data_parallel_rank: Optional[int] = None + # For dp balance + dp_balance_id: int = -1 + @dataclass class EmbeddingReqInput: @@ -637,6 +640,8 @@ class TokenizedEmbeddingReqInput: token_type_ids: List[int] # Dummy sampling params for compatibility sampling_params: SamplingParams + # For dp balance + dp_balance_id: int = -1 @dataclass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 452a6d5ab4fa..8c67a64adf2d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -67,6 +67,7 @@ ) from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.data_parallel_meta import DPBalanceMeta from sglang.srt.managers.io_struct import ( AbortReq, CloseSessionReqInput, @@ -237,6 +238,7 @@ def __init__( tp_rank: int, pp_rank: int, dp_rank: Optional[int], + dp_balance_meta: Optional[DPBalanceMeta] = None, ): # Parse args self.server_args = server_args @@ -528,6 +530,15 @@ def __init__( if get_bool_env_var("SGLANG_GC_LOG"): configure_gc_logger() + + self.balance_meta = dp_balance_meta + if ( + server_args.enable_dp_attention + and server_args.load_balance_method == "minimum_tokens" + ): + assert dp_balance_meta is not None + + self.recv_dp_balance_id_this_term = [] def maybe_sleep_on_idle(self): if self.idle_sleeper is not None: @@ -1057,6 +1068,12 @@ def handle_generate_request( self, recv_req: TokenizedGenerateReqInput, ): + if ( + self.server_args.enable_dp_attention + and self.server_args.load_balance_method == "minimum_tokens" + ): + self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id) + # Create a new request if ( recv_req.session_params is None @@ -1519,6 +1536,8 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: # Handle DP attention if need_dp_attn_preparation: + if self.server_args.load_balance_method == "minimum_tokens": + self.handle_dp_balance_data(ret) ret = self.prepare_mlp_sync_batch(ret) return ret @@ -1837,6 +1856,91 @@ def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch): require_mlp_tp_gather=require_mlp_tp_gather(self.server_args), ) + def handle_dp_balance_data(self, local_batch: ScheduleBatch): + def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]: + """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance""" + recv_list = self.recv_dp_balance_id_this_term + 
assert len(recv_list) <= 511, (
+                "The number of requests received this round is too large. "
+                "Please increase gather_tensor_size and onfly_info_size."
+            )
+
+            gather_tensor_size = 512
+            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
+            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+            recv_tensor[0] = holding_tokens_list
+            recv_tensor[1] = len(
+                recv_list
+            )  # The second element is the number of received balance ids.
+            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
+                recv_list, dtype=torch.int32
+            )
+
+            if self.tp_rank == 0:
+                gathered_list = [
+                    torch.zeros(gather_tensor_size, dtype=torch.int32)
+                    for _ in range(self.balance_meta.num_workers)
+                ]
+            else:
+                gathered_list = None
+
+            torch.distributed.gather(
+                recv_tensor, gathered_list, group=self.tp_cpu_group
+            )
+
+            gathered_id_list_per_worker = None
+            if self.tp_rank == 0:
+                gathered_id_list_per_worker = []
+                holding_tokens_list = []
+                for tensor in gathered_list:
+                    holding_tokens_list.append(tensor[0].item())
+                    list_length = tensor[1].item()
+                    gathered_id_list_per_worker.append(
+                        tensor[2 : list_length + 2].tolist()
+                    )
+
+            return gathered_id_list_per_worker, holding_tokens_list
+
+        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
+            meta = self.balance_meta
+
+            with meta.mutex:
+                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+                assert len(new_recv_rid_lists) == len(
+                    onfly_list
+                ), "num_worker not equal"
+                # 1. Check if the rid received by each worker this round is present in onfly.
+                # If it is, remove the corresponding onfly item.
+                worker_id = 0
+                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                    for new_recv_rid in new_recv_rids:
+                        assert (
+                            new_recv_rid in on_fly_reqs
+                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                        del on_fly_reqs[new_recv_rid]
+                    worker_id += 1
+                # 2. Atomically write local_tokens and onfly into shm under the mutex
+                meta.set_shared_onfly_info(onfly_list)
+                meta.set_shared_local_tokens(local_tokens)
+
+        # prepare worker holding tokens this scheduler term
+        if local_batch is None:
+            holding_tokens = sum(req.seqlen for req in self.waiting_queue)
+        else:
+            holding_tokens = sum(req.seqlen for req in local_batch.reqs) + sum(
+                req.seqlen for req in self.waiting_queue
+            )
+
+        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
+            holding_tokens
+        )
+
+        self.recv_dp_balance_id_this_term.clear()
+        if self.tp_rank == 0:  # only the first worker writes the info
+            write_shared_dp_balance_info(
+                new_recv_dp_balance_id_list, holding_token_list
+            )
+
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
@@ -2716,6 +2820,7 @@ def run_scheduler_process(
     pp_rank: int,
    dp_rank: Optional[int],
     pipe_writer,
+    balance_meta: Optional[DPBalanceMeta] = None,
 ):
     # Generate the prefix
     prefix = ""
@@ -2750,7 +2855,15 @@ def run_scheduler_process(
     init_embedding_cache(embedding_cache_size * 1024 * 1024)
     # Create a scheduler and run the event loop
     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
+        scheduler = Scheduler(
+            server_args,
+            port_args,
+            gpu_id,
+            tp_rank,
+            pp_rank,
+            dp_rank,
+            dp_balance_meta=balance_meta,
+        )
         pipe_writer.send(
             {
                 "status": "ready",
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index c44f53f7ec90..8ec183377acd 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1076,6 +1076,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
         choices=[
             "round_robin",
             "shortest_queue",
+            "minimum_tokens",
         ],
     )
 
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index ce159a4da77b..f0f1326e25f4 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -847,6 +847,23 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
 
 def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
     """Kill the process and all its child processes."""
+
+    def close_and_unlink_shm(name):
+        from multiprocessing import shared_memory
+
+        try:
+            shm = shared_memory.SharedMemory(name=name)
+            shm.close()
+            shm.unlink()
+        except FileNotFoundError:
+            pass
+
+    # Clean up shared memory. If the scheduler's code crashes or is killed,
+    # the POSIX shared memory will leak globally, causing DPBalance to fail to init.
+    shm_names = ["sglang_dp_balance_onfly_info", "sglang_dp_balance_local_tokens"]
+    for shm_name in shm_names:
+        close_and_unlink_shm(shm_name)
+
     # Remove sigchld handler to avoid spammy logs.
if threading.current_thread() is threading.main_thread(): signal.signal(signal.SIGCHLD, signal.SIG_DFL) From 912e14a2efa86703df42b2f7024d0eb3f528d368 Mon Sep 17 00:00:00 2001 From: guanyewang Date: Tue, 1 Jul 2025 11:10:51 +0800 Subject: [PATCH 02/11] fix lint error --- python/sglang/srt/managers/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8c67a64adf2d..5bdc375fa179 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -530,7 +530,7 @@ def __init__( if get_bool_env_var("SGLANG_GC_LOG"): configure_gc_logger() - + self.balance_meta = dp_balance_meta if ( server_args.enable_dp_attention From bae0c2f211c2542c1ccf19587130dfba02c72639 Mon Sep 17 00:00:00 2001 From: guanyewang Date: Wed, 2 Jul 2025 18:25:14 +0800 Subject: [PATCH 03/11] fix cap small error --- python/sglang/srt/managers/data_parallel_meta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/data_parallel_meta.py b/python/sglang/srt/managers/data_parallel_meta.py index 25c04be16090..4c8dc36ac954 100644 --- a/python/sglang/srt/managers/data_parallel_meta.py +++ b/python/sglang/srt/managers/data_parallel_meta.py @@ -22,7 +22,7 @@ def __init__(self, num_workers: int, do_init: bool = False): self.onfly_info_size = ( 512 * num_workers * 8 ) # max_onfly_req_per_worker * num_workers * dByte - self.local_tokens_size = num_workers * 8 + self.local_tokens_size = num_workers * 8 + 512 self.num_workers = num_workers if do_init: From 79962c5c9446e2d79426847068aeefa4f07895f8 Mon Sep 17 00:00:00 2001 From: guanyewang Date: Thu, 10 Jul 2025 19:42:47 +0800 Subject: [PATCH 04/11] manager v1 --- .../srt/managers/data_parallel_controller.py | 2 +- .../sglang/srt/managers/data_parallel_meta.py | 102 ++++++------------ 2 files changed, 36 insertions(+), 68 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index dc4717b93929..4941d7204fac 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -338,7 +338,7 @@ def run_data_parallel_controller_process( setproctitle.setproctitle("sglang::data_parallel_controller") configure_logger(server_args) parent_process = psutil.Process().parent() - balance_meta = DPBalanceMeta(server_args.dp_size, do_init=True) + balance_meta = DPBalanceMeta(server_args.dp_size) try: controller = DataParallelController( diff --git a/python/sglang/srt/managers/data_parallel_meta.py b/python/sglang/srt/managers/data_parallel_meta.py index 4c8dc36ac954..7a7f3cb19b8f 100644 --- a/python/sglang/srt/managers/data_parallel_meta.py +++ b/python/sglang/srt/managers/data_parallel_meta.py @@ -3,6 +3,7 @@ import pickle import struct from multiprocessing import shared_memory +from multiprocessing.managers import BaseManager from typing import Dict, List logger = logging.getLogger(__name__) @@ -15,82 +16,49 @@ class DPBalanceMeta: - def __init__(self, num_workers: int, do_init: bool = False): - self.mutex = mp.Lock() - self.shm_name_onfly_info = "sglang_dp_balance_onfly_info" - self.shm_name_local_tokens = "sglang_dp_balance_local_tokens" - self.onfly_info_size = ( - 512 * num_workers * 8 - ) # max_onfly_req_per_worker * num_workers * dByte - self.local_tokens_size = num_workers * 8 + 512 + def __init__(self, num_workers: int): self.num_workers = num_workers + self._manager = 
mp.Manager() + self.mutex = self._manager.Lock() - if do_init: - self.shm1 = shared_memory.SharedMemory( - name=self.shm_name_onfly_info, create=True, size=self.onfly_info_size - ) - self.shm2 = shared_memory.SharedMemory( - name=self.shm_name_local_tokens, - create=True, - size=self.local_tokens_size, - ) - init_local_tokens = [0 for _ in range(num_workers)] - init_onfly_req = [{} for _ in range(num_workers)] - self.set_shared_local_tokens(init_local_tokens) - self.set_shared_onfly_info(init_onfly_req) - self.shm1.name + init_local_tokens = [0] * self.num_workers + init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)] + + self.shared_state = self._manager.Namespace() + self.shared_state.local_tokens = self._manager.list(init_local_tokens) + self.shared_state.onfly_info = self._manager.list(init_onfly_info) def destructor(self): - # we must destructor this class manually, otherwise will cause shm leak - self.shm1.close() - self.shm1.unlink() - self.shm2.close() - self.shm2.unlink() + # we must destructor this class manually + self._manager.shutdown() def get_shared_onfly(self) -> List[Dict[int, int]]: - """Retrieve data from shared memory and deserialize it into List[Dict[int, int]]""" - shm = shared_memory.SharedMemory(name=self.shm_name_onfly_info) - - header_size = struct.calcsize("Q") - data_size = struct.unpack("Q", shm.buf[:header_size])[0] - assert 0 <= data_size <= self.onfly_info_size, "no valid data in shared memory" - - serialized_data = bytes(shm.buf[header_size : header_size + data_size]) - onfly_info = pickle.loads(serialized_data) - - shm.close() - return onfly_info + return [dict(d) for d in self.shared_state.onfly_info] def set_shared_onfly_info(self, data: List[Dict[int, int]]): - """Serialize the data and write it to shared memory.""" - serialized_data = pickle.dumps(data) - data_size = len(serialized_data) - - assert data_size < self.onfly_info_size, ( - f"The size of the serialized data {data_size} " - f"exceeds the shared memory capacity {self.onfly_info_size} bytes. " - "Please increase onfly_info_size." 
-        )
-
-        shm = shared_memory.SharedMemory(name=self.shm_name_onfly_info)
-        shm.buf[: struct.calcsize("Q")] = struct.pack("Q", data_size)
-        shm.buf[struct.calcsize("Q") : struct.calcsize("Q") + data_size] = (
-            serialized_data
-        )
-        shm.close()
+        self.shared_state.onfly_info = data
 
     def get_shared_local_tokens(self) -> List[int]:
-        shm = shared_memory.SharedMemory(name=self.shm_name_local_tokens)
-        serialized_data = bytes(shm.buf)
-        worker_onfly_data = pickle.loads(serialized_data)
-        shm.close()
-        return worker_onfly_data
+        return list(self.shared_state.local_tokens)
 
     def set_shared_local_tokens(self, data: List[int]):
-        serialized_data = pickle.dumps(data)
-        data_size = len(serialized_data)
-
-        shm = shared_memory.SharedMemory(name=self.shm_name_local_tokens)
-        shm.buf[:data_size] = serialized_data
-
-        shm.close()
+        self.shared_state.local_tokens = data
+
+    def __getstate__(self):
+        """
+        Custom serialization method.
+        Excludes the _manager attribute, which cannot be pickled.
+        """
+        state = self.__dict__.copy()
+        del state["_manager"]
+        return state
+
+    def __setstate__(self, state):
+        """
+        Custom deserialization method.
+        Restores the object state in a new process. The _manager attribute will be missing,
+        which is fine: the child process only needs the proxy objects (mutex, shared_state), not the manager itself.
+        """
+        self.__dict__.update(state)
+        # In the child process, self._manager will be None
+        self._manager = None

From ebb79ca6fcfc1a9583801fcef88c2b2ff81450ce Mon Sep 17 00:00:00 2001
From: guanyewang
Date: Fri, 11 Jul 2025 20:42:56 +0800
Subject: [PATCH 05/11] collect message per 10 ct

---
 python/sglang/srt/managers/scheduler.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 5bdc375fa179..1056c9395f34 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1536,7 +1536,10 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
 
         # Handle DP attention
         if need_dp_attn_preparation:
-            if self.server_args.load_balance_method == "minimum_tokens":
+            if (
+                self.server_args.load_balance_method == "minimum_tokens"
+                and self.forward_ct % 10 == 0
+            ):
                 self.handle_dp_balance_data(ret)
             ret = self.prepare_mlp_sync_batch(ret)
 

From 43e18611bb710e852c8af10955da2b7eeb5ef846 Mon Sep 17 00:00:00 2001
From: guanyewang
Date: Fri, 11 Jul 2025 21:51:59 +0800
Subject: [PATCH 06/11] use mp.manager to do load balancing every 40 forward_cts

---
 python/sglang/srt/managers/scheduler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1056c9395f34..d6f61bf0e726 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1538,7 +1538,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         if need_dp_attn_preparation:
             if (
                 self.server_args.load_balance_method == "minimum_tokens"
-                and self.forward_ct % 10 == 0
+                and self.forward_ct % 40 == 0
             ):
                 self.handle_dp_balance_data(ret)
             ret = self.prepare_mlp_sync_batch(ret)
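The dispatch rule that patches 01-06 implement can be summarized independently of the
server plumbing. The sketch below is illustrative only (plain Python, no sglang imports);
`local_tokens` and `onfly_info` stand in for the two shared structures kept in
DPBalanceMeta, and `pick_worker` mirrors the selection logic in minimum_tokens_scheduler:

    def pick_worker(local_tokens, onfly_info):
        # Estimated total load per worker: tokens already being inferred locally,
        # plus the tokens of requests dispatched to it but not yet received.
        total_tokens = [
            held + sum(onfly.values())
            for held, onfly in zip(local_tokens, onfly_info)
        ]
        return total_tokens.index(min(total_tokens))

    # Worker 1 has the smaller combined load (90 + 10 < 120 + 30), so it wins.
    assert pick_worker([120, 90], [{7: 30}, {8: 10}]) == 1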
From 4cd40145e2deec84598aa70b5a4a330776be581e Mon Sep 17 00:00:00 2001
From: guanyewang
Date: Tue, 15 Jul 2025 20:46:50 +0800
Subject: [PATCH 07/11] clean irrelevant code

---
 .../srt/managers/data_parallel_controller.py  |  2 +-
 .../sglang/srt/managers/data_parallel_meta.py |  9 ---------
 python/sglang/srt/utils.py                    | 17 -----------------
 3 files changed, 1 insertion(+), 27 deletions(-)

diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 4941d7204fac..0b93418ac803 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -363,5 +363,5 @@ def run_data_parallel_controller_process(
         logger.error(f"DataParallelController hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
     finally:
-        # we need to free shared mem manually, because POSIX shm is a named file
+        # we need to destruct mp.Manager() in balance_meta
         balance_meta.destructor()
diff --git a/python/sglang/srt/managers/data_parallel_meta.py b/python/sglang/srt/managers/data_parallel_meta.py
index 7a7f3cb19b8f..9f778808a91b 100644
--- a/python/sglang/srt/managers/data_parallel_meta.py
+++ b/python/sglang/srt/managers/data_parallel_meta.py
@@ -45,20 +45,11 @@ def set_shared_local_tokens(self, data: List[int]):
         self.shared_state.local_tokens = data
 
     def __getstate__(self):
-        """
-        Custom serialization method.
-        Excludes the _manager attribute, which cannot be pickled.
-        """
         state = self.__dict__.copy()
         del state["_manager"]
         return state
 
     def __setstate__(self, state):
-        """
-        Custom deserialization method.
-        Restores the object state in a new process. The _manager attribute will be missing,
-        which is fine: the child process only needs the proxy objects (mutex, shared_state), not the manager itself.
-        """
         self.__dict__.update(state)
         # In the child process, self._manager will be None
         self._manager = None
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index f0f1326e25f4..ce159a4da77b 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -847,23 +847,6 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
 
 def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
     """Kill the process and all its child processes."""
-
-    def close_and_unlink_shm(name):
-        from multiprocessing import shared_memory
-
-        try:
-            shm = shared_memory.SharedMemory(name=name)
-            shm.close()
-            shm.unlink()
-        except FileNotFoundError:
-            pass
-
-    # Clean up shared memory. If the scheduler's code crashes or is killed,
-    # the POSIX shared memory will leak globally, causing DPBalance to fail to init.
-    shm_names = ["sglang_dp_balance_onfly_info", "sglang_dp_balance_local_tokens"]
-    for shm_name in shm_names:
-        close_and_unlink_shm(shm_name)
-
     # Remove sigchld handler to avoid spammy logs.
     if threading.current_thread() is threading.main_thread():
         signal.signal(signal.SIGCHLD, signal.SIG_DFL)
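For reference, the pickling contract that the docstrings removed above described can be
shown standalone. This is a hypothetical miniature of DPBalanceMeta, not sglang code:
manager proxies are picklable and reconnect to the original manager process, while the
SyncManager object itself is not, so __getstate__ drops it before pickling:

    import multiprocessing as mp
    import pickle

    class Meta:
        def __init__(self):
            self._manager = mp.Manager()       # manager process: cannot be pickled
            self.mutex = self._manager.Lock()  # proxy object: picklable

        def __getstate__(self):
            state = self.__dict__.copy()
            del state["_manager"]              # drop the unpicklable manager
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
            self._manager = None               # receivers only use the proxies

    if __name__ == "__main__":
        meta = Meta()
        clone = pickle.loads(pickle.dumps(meta))
        assert clone._manager is None
        with clone.mutex:                      # the proxy still reaches the manager
            pass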
From 9d76ee15fdb3e069a8ab2862d5895018e4654d84 Mon Sep 17 00:00:00 2001
From: guanyewang
Date: Tue, 15 Jul 2025 20:49:24 +0800
Subject: [PATCH 08/11] clean irrelevant code 2

---
 python/sglang/srt/managers/data_parallel_meta.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/sglang/srt/managers/data_parallel_meta.py b/python/sglang/srt/managers/data_parallel_meta.py
index 9f778808a91b..85f38d48ca4b 100644
--- a/python/sglang/srt/managers/data_parallel_meta.py
+++ b/python/sglang/srt/managers/data_parallel_meta.py
@@ -51,5 +51,4 @@ def __getstate__(self):
 
     def __setstate__(self, state):
         self.__dict__.update(state)
-        # In the child process, self._manager will be None
         self._manager = None

From d8b86728dade219720e8b46f13448b11955e738d Mon Sep 17 00:00:00 2001
From: guanyewang
Date: Mon, 28 Jul 2025 10:26:36 +0800
Subject: [PATCH 09/11] move DPBalanceMeta to manager/utils.py and add some comments

---
 .../srt/managers/data_parallel_controller.py  |  4 +-
 .../sglang/srt/managers/data_parallel_meta.py | 54 ------------------
 python/sglang/srt/managers/scheduler.py       |  6 +--
 python/sglang/srt/managers/utils.py           | 46 +++++++++++++++-
 4 files changed, 51 insertions(+), 59 deletions(-)
 delete mode 100644 python/sglang/srt/managers/data_parallel_meta.py

diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 0b93418ac803..a7614e7b2d9e 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -29,13 +29,13 @@
 import zmq
 
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
-from sglang.srt.managers.data_parallel_meta import DPBalanceMeta
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
@@ -284,6 +284,8 @@ def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
 
     def minimum_tokens_scheduler(self, req):
+        # This variable corresponds to the balance_id in TokenizedGenerateReqInput.
+        # We use it to control the number of onfly tokens (requests dispatched to workers but not yet received).
         def get_next_global_balance_id() -> int:
diff --git a/python/sglang/srt/managers/data_parallel_meta.py b/python/sglang/srt/managers/data_parallel_meta.py
deleted file mode 100644
index 85f38d48ca4b..000000000000
--- a/python/sglang/srt/managers/data_parallel_meta.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import logging
-import multiprocessing as mp
-import pickle
-import struct
-from multiprocessing import shared_memory
-from multiprocessing.managers import BaseManager
-from typing import Dict, List
-
-logger = logging.getLogger(__name__)
-
-"""
-This class is used by both the scheduler and the DP controller.
-If this class were placed in the DP controller module,
-it would cause a circular import, so it is placed in a separate file.
-""" - - -class DPBalanceMeta: - def __init__(self, num_workers: int): - self.num_workers = num_workers - self._manager = mp.Manager() - self.mutex = self._manager.Lock() - - init_local_tokens = [0] * self.num_workers - init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)] - - self.shared_state = self._manager.Namespace() - self.shared_state.local_tokens = self._manager.list(init_local_tokens) - self.shared_state.onfly_info = self._manager.list(init_onfly_info) - - def destructor(self): - # we must destructor this class manually - self._manager.shutdown() - - def get_shared_onfly(self) -> List[Dict[int, int]]: - return [dict(d) for d in self.shared_state.onfly_info] - - def set_shared_onfly_info(self, data: List[Dict[int, int]]): - self.shared_state.onfly_info = data - - def get_shared_local_tokens(self) -> List[int]: - return list(self.shared_state.local_tokens) - - def set_shared_local_tokens(self, data: List[int]): - self.shared_state.local_tokens = data - - def __getstate__(self): - state = self.__dict__.copy() - del state["_manager"] - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self._manager = None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index bb6e23bd06ca..d641ae15c031 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -67,7 +67,6 @@ ) from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.data_parallel_meta import DPBalanceMeta from sglang.srt.managers.io_struct import ( AbortReq, CloseSessionReqInput, @@ -129,7 +128,7 @@ from sglang.srt.managers.session_controller import Session from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient -from sglang.srt.managers.utils import validate_input_length +from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache @@ -2011,8 +2010,9 @@ def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]: "The number of requests received this round is too large. " "Please increase gather_tensor_size and onfly_info_size." ) - + # The maximum size of the tensor used for gathering data from all workers. 
         gather_tensor_size = 512
+
         # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
         recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
         recv_tensor[0] = holding_tokens_list
diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py
index 2909e759739b..2ab32f242778 100644
--- a/python/sglang/srt/managers/utils.py
+++ b/python/sglang/srt/managers/utils.py
@@ -1,6 +1,7 @@
 import logging
+import multiprocessing as mp
 from http import HTTPStatus
-from typing import Optional
+from typing import Dict, List, Optional
 
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
 
@@ -38,3 +39,46 @@ def validate_input_length(
         return error_msg
 
     return None
+
+
+class DPBalanceMeta:
+    """
+    This class is used by both the scheduler and the DP controller.
+    """
+
+    def __init__(self, num_workers: int):
+        self.num_workers = num_workers
+        self._manager = mp.Manager()
+        self.mutex = self._manager.Lock()
+
+        init_local_tokens = [0] * self.num_workers
+        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
+
+        self.shared_state = self._manager.Namespace()
+        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
+        self.shared_state.onfly_info = self._manager.list(init_onfly_info)
+
+    def destructor(self):
+        # this must be called manually, otherwise the manager process leaks
+        self._manager.shutdown()
+
+    def get_shared_onfly(self) -> List[Dict[int, int]]:
+        return [dict(d) for d in self.shared_state.onfly_info]
+
+    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
+        self.shared_state.onfly_info = data
+
+    def get_shared_local_tokens(self) -> List[int]:
+        return list(self.shared_state.local_tokens)
+
+    def set_shared_local_tokens(self, data: List[int]):
+        self.shared_state.local_tokens = data
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state["_manager"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._manager = None
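The fixed-width gather buffer that patch 09 comments on has a simple layout. The helpers
below are an illustrative reconstruction (not sglang code) of how each worker packs its
report into the 512-slot int32 tensor and how rank 0 unpacks it; the names pack/unpack
are invented for this sketch:

    import torch

    GATHER_TENSOR_SIZE = 512  # fixed buffer; at most 510 balance ids fit per round

    def pack(holding_tokens: int, balance_ids: list) -> torch.Tensor:
        t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
        t[0] = holding_tokens            # slot 0: tokens this worker is holding
        t[1] = len(balance_ids)          # slot 1: how many ids follow
        t[2 : 2 + len(balance_ids)] = torch.tensor(balance_ids, dtype=torch.int32)
        return t

    def unpack(t: torch.Tensor):
        n = int(t[1].item())
        return int(t[0].item()), t[2 : 2 + n].tolist()

    holding, ids = unpack(pack(1234, [7, 8, 9]))
    assert holding == 1234 and ids == [7, 8, 9]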
From 464db430fbfd46103659ff0e6d8757283beed1a9 Mon Sep 17 00:00:00 2001
From: guanyewang
Date: Sun, 3 Aug 2025 13:25:12 +0800
Subject: [PATCH 10/11] add ci and remove dp log

---
 .../srt/managers/data_parallel_controller.py  |  2 +-
 test/srt/test_dp_attention.py                 | 55 +++++++++++++++++++
 2 files changed, 56 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 7c8d08db8a2a..76b9e1a018a9 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -310,7 +310,7 @@ def get_next_global_balance_id() -> int:
         # 2. write the new onfly info to the shm
         self.balance_meta.set_shared_onfly_info(onfly_info)
 
-        logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
+        # logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
         self.workers[target_worker].send_pyobj(req)
diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py
index af50dc7803c1..f997382f9404 100644
--- a/test/srt/test_dp_attention.py
+++ b/test/srt/test_dp_attention.py
@@ -137,5 +137,60 @@ def test_gsm8k(self):
         self.assertGreater(avg_spec_accept_length, 2.5)
 
 
+class TestDPAttentionMinimumTokenLoadBalance(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "2",
+                "--enable-dp-attention",
+                "--dp",
+                "2",
+                "--enable-torch-compile",
+                "--torch-compile-max-bs",
+                "2",
+                "--load-balance-method",
+                "minimum_tokens",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+        self.assertGreater(metrics["score"], 0.5)
+
+    def test_mgsm_en(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mgsm_en",
+            num_examples=None,
+            num_threads=1024,
+        )
+
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+        self.assertGreater(metrics["score"], 0.8)
+
+
 if __name__ == "__main__":
     unittest.main()

From 0b03b3db8b2ac3e17e70851b6a947eecb7083caa Mon Sep 17 00:00:00 2001
From: guanyewang
Date: Sun, 3 Aug 2025 14:56:34 +0800
Subject: [PATCH 11/11] add doc

---
 docs/backend/server_arguments.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index 636bb4f1b3cf..09e77fca2055 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -155,7 +155,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | Arguments | Description | Defaults |
 |-----------|-------------|----------|
 | `--dp-size` | The data parallelism size. | 1 |
-| `--load-balance-method` | The load balancing strategy for data parallelism. | round_robin |
+| `--load-balance-method` | The load balancing strategy for data parallelism. Options include: 'round_robin', 'minimum_tokens'. The 'minimum_tokens' algorithm can only be used when DP attention is enabled; it balances requests across the DP workers based on their real-time token load. | round_robin |
 
 ## Multi-node distributed serving
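With the series applied, the new strategy is opt-in. Based on the flags exercised by the
CI test above, a server launch would look like the following (the model path and the
parallel sizes are illustrative):

    python -m sglang.launch_server --model-path <model> --trust-remote-code \
        --tp 2 --enable-dp-attention --dp 2 --load-balance-method minimum_tokens

round_robin remains the default, so existing deployments are unaffected unless the flag
is set explicitly.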