-
Notifications
You must be signed in to change notification settings - Fork 5.1k
[feat] support minimum token load balance in dp attention #7379
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
85ad5ac
912e14a
bae0c2f
79962c5
ebb79ca
43e1861
95b8614
3d5c409
e5f1742
4cd4014
9d76ee1
0112b56
a8a03a6
ae9b671
34dcb3f
d8b8672
854ab53
457de67
464db43
0b03b3d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -731,6 +731,7 @@ def _launch_subprocesses( | |
| pp_rank, | ||
| None, | ||
| writer, | ||
| None, | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,9 +16,13 @@ | |
| import logging | ||
| import multiprocessing as mp | ||
| import signal | ||
| import struct | ||
| import sys | ||
| import threading | ||
| import time | ||
| from enum import Enum, auto | ||
| from multiprocessing import shared_memory | ||
| from typing import Dict, List | ||
|
|
||
| import psutil | ||
| import setproctitle | ||
|
|
@@ -32,6 +36,7 @@ | |
| ) | ||
| from sglang.srt.managers.schedule_batch import Req | ||
| from sglang.srt.managers.scheduler import run_scheduler_process | ||
| from sglang.srt.managers.utils import DPBalanceMeta | ||
| from sglang.srt.server_args import PortArgs, ServerArgs | ||
| from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter | ||
| from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket | ||
|
|
@@ -45,6 +50,7 @@ class LoadBalanceMethod(Enum): | |
|
|
||
| ROUND_ROBIN = auto() | ||
| SHORTEST_QUEUE = auto() | ||
| MINIMUM_TOKENS = auto() | ||
|
|
||
| @classmethod | ||
| def from_str(cls, method: str): | ||
|
|
@@ -58,7 +64,16 @@ def from_str(cls, method: str): | |
| class DataParallelController: | ||
| """A controller that dispatches requests to multiple data parallel workers.""" | ||
|
|
||
| def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None: | ||
| def __init__( | ||
| self, | ||
| server_args: ServerArgs, | ||
| port_args: PortArgs, | ||
| dp_balance_meta: DPBalanceMeta, | ||
| ) -> None: | ||
| # for dp balance | ||
| self.global_balance_id = 0 | ||
| self.balance_meta = dp_balance_meta | ||
|
|
||
| # Parse args | ||
| self.max_total_num_tokens = None | ||
| self.server_args = server_args | ||
|
|
@@ -79,6 +94,7 @@ def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None: | |
| dispatch_lookup = { | ||
| LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler, | ||
| LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler, | ||
| LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler, | ||
| } | ||
| self.dispatching = dispatch_lookup[self.load_balance_method] | ||
|
|
||
|
|
@@ -234,6 +250,7 @@ def launch_tensor_parallel_group( | |
| pp_rank, | ||
| dp_rank, | ||
| writer, | ||
| self.balance_meta, | ||
| ), | ||
| ) | ||
| with memory_saver_adapter.configure_subprocess(): | ||
|
|
@@ -269,6 +286,33 @@ def round_robin_scheduler(self, req: Req): | |
| def shortest_queue_scheduler(self, input_requests): | ||
| raise NotImplementedError() | ||
|
|
||
| def minimum_tokens_scheduler(self, req): | ||
| # This variable corresponds to the balance_id in TokenizedGenerateReqInput. | ||
| # We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received). | ||
| def get_next_global_balance_id() -> int: | ||
| INT32_MAX = 2147483647 | ||
| current_id = self.global_balance_id | ||
| self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX | ||
| return current_id | ||
|
|
||
| req.dp_balance_id = get_next_global_balance_id() | ||
| with self.balance_meta.mutex: | ||
| # 1. local_tokens represents the tokens currently inferring on the worker, | ||
| # while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler. | ||
| onfly_info = self.balance_meta.get_shared_onfly() | ||
|
||
| local_tokens = self.balance_meta.get_shared_local_tokens() | ||
| total_tokens = [ | ||
| local_token + sum(onfly_dict.values()) | ||
| for local_token, onfly_dict in zip(local_tokens, onfly_info) | ||
| ] | ||
| target_worker = total_tokens.index(min(total_tokens)) | ||
| onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids) | ||
| # 2. write the new onfly info to the shm | ||
| self.balance_meta.set_shared_onfly_info(onfly_info) | ||
|
|
||
| # logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}") | ||
| self.workers[target_worker].send_pyobj(req) | ||
|
|
||
| def event_loop(self): | ||
| while True: | ||
| while True: | ||
|
|
@@ -302,9 +346,12 @@ def run_data_parallel_controller_process( | |
| setproctitle.setproctitle("sglang::data_parallel_controller") | ||
| configure_logger(server_args) | ||
| parent_process = psutil.Process().parent() | ||
| balance_meta = DPBalanceMeta(server_args.dp_size) | ||
|
|
||
| try: | ||
| controller = DataParallelController(server_args, port_args) | ||
| controller = DataParallelController( | ||
| server_args, port_args, dp_balance_meta=balance_meta | ||
| ) | ||
| pipe_writer.send( | ||
| { | ||
| "status": "ready", | ||
|
|
@@ -323,3 +370,6 @@ def run_data_parallel_controller_process( | |
| traceback = get_exception_traceback() | ||
| logger.error(f"DataParallelController hit an exception: {traceback}") | ||
| parent_process.send_signal(signal.SIGQUIT) | ||
| finally: | ||
| # we need to destruct mp.Manager() in balance_meta | ||
| balance_meta.destructor() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -125,7 +125,7 @@ | |
| from sglang.srt.managers.session_controller import Session | ||
| from sglang.srt.managers.tp_worker import TpModelWorker | ||
| from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient | ||
| from sglang.srt.managers.utils import validate_input_length | ||
| from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length | ||
| from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache | ||
| from sglang.srt.mem_cache.hiradix_cache import HiRadixCache | ||
| from sglang.srt.mem_cache.radix_cache import RadixCache | ||
|
|
@@ -203,6 +203,7 @@ def __init__( | |
| moe_ep_rank: int, | ||
| pp_rank: int, | ||
| dp_rank: Optional[int], | ||
| dp_balance_meta: Optional[DPBalanceMeta] = None, | ||
| ): | ||
| # Parse args | ||
| self.server_args = server_args | ||
|
|
@@ -522,6 +523,15 @@ def __init__( | |
| ] | ||
| ) | ||
|
|
||
| self.balance_meta = dp_balance_meta | ||
| if ( | ||
| server_args.enable_dp_attention | ||
| and server_args.load_balance_method == "minimum_tokens" | ||
| ): | ||
| assert dp_balance_meta is not None | ||
|
|
||
| self.recv_dp_balance_id_this_term = [] | ||
|
|
||
| def init_tokenizer(self): | ||
| server_args = self.server_args | ||
|
|
||
|
|
@@ -1032,6 +1042,12 @@ def handle_generate_request( | |
| self, | ||
| recv_req: TokenizedGenerateReqInput, | ||
| ): | ||
| if ( | ||
| self.server_args.enable_dp_attention | ||
| and self.server_args.load_balance_method == "minimum_tokens" | ||
| ): | ||
| self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id) | ||
|
|
||
| # Create a new request | ||
| if ( | ||
| recv_req.session_params is None | ||
|
|
@@ -1442,6 +1458,11 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: | |
|
|
||
| # Handle DP attention | ||
| if need_dp_attn_preparation: | ||
| if ( | ||
| self.server_args.load_balance_method == "minimum_tokens" | ||
| and self.forward_ct % 40 == 0 | ||
| ): | ||
| self.handle_dp_balance_data(ret) | ||
| ret = self.prepare_mlp_sync_batch(ret) | ||
|
|
||
| return ret | ||
|
|
@@ -1767,6 +1788,86 @@ def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch): | |
| disable_overlap_schedule=self.server_args.disable_overlap_schedule, | ||
| ) | ||
|
|
||
| def handle_dp_balance_data(self, local_batch: ScheduleBatch): | ||
| def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]: | ||
|
||
| """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance""" | ||
| recv_list = self.recv_dp_balance_id_this_term | ||
| assert len(recv_list) <= 511, ( | ||
| "The number of requests received this round is too large. " | ||
| "Please increase gather_tensor_size and onfly_info_size." | ||
| ) | ||
| # The maximum size of the tensor used for gathering data from all workers. | ||
| gather_tensor_size = 512 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add explanation for this value?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it means the maximum size of the tensor used for gathering data.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be assert len(recv_list) < 511? Also, as the Gemini bot mentioned, holding_tokens_list is misleading when it’s actually length of tokens rather than a list — especially since it’s later mixed with real list operations. |
||
|
|
||
| # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids | ||
| recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32) | ||
| recv_tensor[0] = holding_tokens_list | ||
| recv_tensor[1] = len( | ||
| recv_list | ||
| ) # The first element is the length of the list. | ||
| recv_tensor[2 : len(recv_list) + 2] = torch.tensor( | ||
| recv_list, dtype=torch.int32 | ||
| ) | ||
|
|
||
| if self.tp_rank == 0: | ||
| gathered_list = [ | ||
| torch.zeros(gather_tensor_size, dtype=torch.int32) | ||
| for _ in range(self.balance_meta.num_workers) | ||
| ] | ||
| else: | ||
| gathered_list = None | ||
|
|
||
| torch.distributed.gather( | ||
| recv_tensor, gathered_list, group=self.tp_cpu_group | ||
| ) | ||
|
|
||
| gathered_id_list_per_worker = None | ||
| if self.tp_rank == 0: | ||
| gathered_id_list_per_worker = [] | ||
| holding_tokens_list = [] | ||
| for tensor in gathered_list: | ||
| holding_tokens_list.append(tensor[0].item()) | ||
| list_length = tensor[1].item() | ||
| gathered_id_list_per_worker.append( | ||
| tensor[2 : list_length + 2].tolist() | ||
| ) | ||
|
|
||
| return gathered_id_list_per_worker, holding_tokens_list | ||
|
|
||
| def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens): | ||
| meta = self.balance_meta | ||
|
|
||
| with meta.mutex: | ||
| onfly_list: List[Dict[int, int]] = meta.get_shared_onfly() | ||
| assert len(new_recv_rid_lists) == len( | ||
| onfly_list | ||
| ), "num_worker not equal" | ||
| # 1.Check if the rid received by each worker this round is present in onfly. | ||
| # If it is, remove the corresponding onfly item. | ||
| worker_id = 0 | ||
| for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list): | ||
| for new_recv_rid in new_recv_rids: | ||
| assert ( | ||
| new_recv_rid in on_fly_reqs | ||
| ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong" | ||
| del on_fly_reqs[new_recv_rid] | ||
| worker_id += 1 | ||
| # 2. Atomically write local_tokens and onfly into shm under the mutex | ||
| meta.set_shared_onfly_info(onfly_list) | ||
| meta.set_shared_local_tokens(local_tokens) | ||
|
|
||
| holding_tokens = self.get_load() | ||
|
|
||
| new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info( | ||
| holding_tokens | ||
| ) | ||
|
|
||
| self.recv_dp_balance_id_this_term.clear() | ||
| if self.tp_rank == 0: # only first worker write info | ||
| write_shared_dp_balance_info( | ||
| new_recv_dp_balance_id_list, holding_token_list | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def prepare_mlp_sync_batch_raw( | ||
| local_batch: ScheduleBatch, | ||
|
|
@@ -2367,6 +2468,7 @@ def run_scheduler_process( | |
| pp_rank: int, | ||
| dp_rank: Optional[int], | ||
| pipe_writer, | ||
| balance_meta: Optional[DPBalanceMeta] = None, | ||
| ): | ||
| # Generate the prefix | ||
| prefix = "" | ||
|
|
@@ -2400,7 +2502,14 @@ def run_scheduler_process( | |
| # Create a scheduler and run the event loop | ||
| try: | ||
| scheduler = Scheduler( | ||
| server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank | ||
| server_args, | ||
| port_args, | ||
| gpu_id, | ||
| tp_rank, | ||
| moe_ep_rank, | ||
| pp_rank, | ||
| dp_rank, | ||
| dp_balance_meta=balance_meta, | ||
| ) | ||
| pipe_writer.send( | ||
| { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import logging | ||
| import multiprocessing as mp | ||
| from http import HTTPStatus | ||
| from typing import Optional | ||
| from typing import Dict, List, Optional | ||
|
|
||
| from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req | ||
|
|
||
|
|
@@ -38,3 +39,46 @@ def validate_input_length( | |
| return error_msg | ||
|
|
||
| return None | ||
|
|
||
|
|
||
| class DPBalanceMeta: | ||
| """ | ||
| This class will be use in scheduler and dp controller | ||
| """ | ||
|
|
||
| def __init__(self, num_workers: int): | ||
| self.num_workers = num_workers | ||
| self._manager = mp.Manager() | ||
| self.mutex = self._manager.Lock() | ||
|
|
||
| init_local_tokens = [0] * self.num_workers | ||
| init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)] | ||
|
|
||
| self.shared_state = self._manager.Namespace() | ||
| self.shared_state.local_tokens = self._manager.list(init_local_tokens) | ||
| self.shared_state.onfly_info = self._manager.list(init_onfly_info) | ||
|
|
||
| def destructor(self): | ||
| # we must destructor this class manually | ||
| self._manager.shutdown() | ||
|
|
||
| def get_shared_onfly(self) -> List[Dict[int, int]]: | ||
| return [dict(d) for d in self.shared_state.onfly_info] | ||
|
|
||
| def set_shared_onfly_info(self, data: List[Dict[int, int]]): | ||
| self.shared_state.onfly_info = data | ||
|
|
||
| def get_shared_local_tokens(self) -> List[int]: | ||
| return list(self.shared_state.local_tokens) | ||
|
|
||
| def set_shared_local_tokens(self, data: List[int]): | ||
| self.shared_state.local_tokens = data | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure if this is intentional, but in both set_* functions, passing a Python list directly into a multiprocessing.Manager().List() could replace the managed object, losing the cross-process synchronization. |
||
|
|
||
| def __getstate__(self): | ||
| state = self.__dict__.copy() | ||
| del state["_manager"] | ||
| return state | ||
|
|
||
| def __setstate__(self, state): | ||
| self.__dict__.update(state) | ||
| self._manager = None | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain the meaning of this variable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure, this variable corresponds to the balance_id in TokenizedGenerateReqInput.
We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received).