-
Notifications
You must be signed in to change notification settings - Fork 5.1k
[feat] support minimum token load balance in dp attention #7379
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
85ad5ac
912e14a
bae0c2f
79962c5
ebb79ca
43e1861
95b8614
3d5c409
e5f1742
4cd4014
9d76ee1
0112b56
a8a03a6
ae9b671
34dcb3f
d8b8672
854ab53
457de67
464db43
0b03b3d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -741,6 +741,7 @@ def _launch_subprocesses( | |
| pp_rank, | ||
| None, | ||
| writer, | ||
| None, | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,15 +16,20 @@ | |
| import logging | ||
| import multiprocessing as mp | ||
| import signal | ||
| import struct | ||
| import sys | ||
| import threading | ||
| import time | ||
| from enum import Enum, auto | ||
| from multiprocessing import shared_memory | ||
| from typing import Dict, List | ||
|
|
||
| import psutil | ||
| import setproctitle | ||
| import zmq | ||
|
|
||
| from sglang.srt.layers.dp_attention import compute_dp_attention_world_info | ||
| from sglang.srt.managers.data_parallel_meta import DPBalanceMeta | ||
| from sglang.srt.managers.io_struct import ( | ||
| TokenizedEmbeddingReqInput, | ||
| TokenizedGenerateReqInput, | ||
|
|
@@ -44,6 +49,7 @@ class LoadBalanceMethod(Enum): | |
|
|
||
| ROUND_ROBIN = auto() | ||
| SHORTEST_QUEUE = auto() | ||
| MINIMUM_TOKENS = auto() | ||
|
|
||
| @classmethod | ||
| def from_str(cls, method: str): | ||
|
|
@@ -57,7 +63,16 @@ def from_str(cls, method: str): | |
| class DataParallelController: | ||
| """A controller that dispatches requests to multiple data parallel workers.""" | ||
|
|
||
| def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None: | ||
| def __init__( | ||
| self, | ||
| server_args: ServerArgs, | ||
| port_args: PortArgs, | ||
| dp_balance_meta: DPBalanceMeta, | ||
| ) -> None: | ||
| # for dp balance | ||
| self.global_balance_id = 0 | ||
| self.balance_meta = dp_balance_meta | ||
|
|
||
| # Parse args | ||
| self.max_total_num_tokens = None | ||
| self.server_args = server_args | ||
|
|
@@ -78,6 +93,7 @@ def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None: | |
| dispatch_lookup = { | ||
| LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler, | ||
| LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler, | ||
| LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler, | ||
| } | ||
| self.dispatching = dispatch_lookup[self.load_balance_method] | ||
|
|
||
|
|
@@ -231,6 +247,7 @@ def launch_tensor_parallel_group( | |
| pp_rank, | ||
| dp_rank, | ||
| writer, | ||
| self.balance_meta, | ||
| ), | ||
| ) | ||
| with memory_saver_adapter.configure_subprocess(): | ||
|
|
@@ -266,6 +283,31 @@ def round_robin_scheduler(self, req: Req): | |
| def shortest_queue_scheduler(self, input_requests): | ||
| raise NotImplementedError() | ||
|
|
||
| def minimum_tokens_scheduler(self, req): | ||
| def get_next_global_balance_id() -> int: | ||
| INT32_MAX = 2147483647 | ||
| current_id = self.global_balance_id | ||
| self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX | ||
| return current_id | ||
|
|
||
| req.dp_balance_id = get_next_global_balance_id() | ||
| with self.balance_meta.mutex: | ||
| # 1. local_tokens represents the tokens currently inferring on the worker, | ||
| # while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler. | ||
| onfly_info = self.balance_meta.get_shared_onfly() | ||
|
||
| local_tokens = self.balance_meta.get_shared_local_tokens() | ||
| total_tokens = [ | ||
| local_token + sum(onfly_dict.values()) | ||
| for local_token, onfly_dict in zip(local_tokens, onfly_info) | ||
| ] | ||
| target_worker = total_tokens.index(min(total_tokens)) | ||
| onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids) | ||
| # 2. write the new onfly info to the shm | ||
| self.balance_meta.set_shared_onfly_info(onfly_info) | ||
|
|
||
| logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}") | ||
|
||
| self.workers[target_worker].send_pyobj(req) | ||
|
|
||
| def event_loop(self): | ||
| while True: | ||
| while True: | ||
|
|
@@ -296,9 +338,12 @@ def run_data_parallel_controller_process( | |
| setproctitle.setproctitle("sglang::data_parallel_controller") | ||
| configure_logger(server_args) | ||
| parent_process = psutil.Process().parent() | ||
| balance_meta = DPBalanceMeta(server_args.dp_size) | ||
|
|
||
| try: | ||
| controller = DataParallelController(server_args, port_args) | ||
| controller = DataParallelController( | ||
| server_args, port_args, dp_balance_meta=balance_meta | ||
| ) | ||
| pipe_writer.send( | ||
| { | ||
| "status": "ready", | ||
|
|
@@ -317,3 +362,6 @@ def run_data_parallel_controller_process( | |
| traceback = get_exception_traceback() | ||
| logger.error(f"DataParallelController hit an exception: {traceback}") | ||
| parent_process.send_signal(signal.SIGQUIT) | ||
| finally: | ||
| # we need to destruct mp.Manager() in balance_meta | ||
| balance_meta.destructor() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| import logging | ||
| import multiprocessing as mp | ||
| import pickle | ||
| import struct | ||
| from multiprocessing import shared_memory | ||
| from multiprocessing.managers import BaseManager | ||
| from typing import Dict, List | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| """ | ||
| This class will be use in scheduler and dp controller | ||
| If this class is placed in the dp controller, | ||
| it will cause circular references, so it is placed in a separate file. | ||
| """ | ||
|
|
||
|
|
||
| class DPBalanceMeta: | ||
| def __init__(self, num_workers: int): | ||
| self.num_workers = num_workers | ||
| self._manager = mp.Manager() | ||
| self.mutex = self._manager.Lock() | ||
|
||
|
|
||
| init_local_tokens = [0] * self.num_workers | ||
| init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)] | ||
|
|
||
| self.shared_state = self._manager.Namespace() | ||
| self.shared_state.local_tokens = self._manager.list(init_local_tokens) | ||
| self.shared_state.onfly_info = self._manager.list(init_onfly_info) | ||
|
|
||
| def destructor(self): | ||
| # we must destructor this class manually | ||
| self._manager.shutdown() | ||
|
|
||
| def get_shared_onfly(self) -> List[Dict[int, int]]: | ||
| return [dict(d) for d in self.shared_state.onfly_info] | ||
|
|
||
| def set_shared_onfly_info(self, data: List[Dict[int, int]]): | ||
| self.shared_state.onfly_info = data | ||
|
|
||
| def get_shared_local_tokens(self) -> List[int]: | ||
| return list(self.shared_state.local_tokens) | ||
|
|
||
| def set_shared_local_tokens(self, data: List[int]): | ||
| self.shared_state.local_tokens = data | ||
|
|
||
| def __getstate__(self): | ||
| state = self.__dict__.copy() | ||
| del state["_manager"] | ||
| return state | ||
|
|
||
| def __setstate__(self, state): | ||
| self.__dict__.update(state) | ||
| self._manager = None | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -67,6 +67,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from sglang.srt.layers.dp_attention import compute_dp_attention_world_info | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from sglang.srt.layers.logits_processor import LogitsProcessorOutput | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from sglang.srt.managers.data_parallel_meta import DPBalanceMeta | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from sglang.srt.managers.io_struct import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AbortReq, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| CloseSessionReqInput, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -237,6 +238,7 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tp_rank: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pp_rank: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dp_rank: Optional[int], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dp_balance_meta: Optional[DPBalanceMeta] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Parse args | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.server_args = server_args | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -542,6 +544,15 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if get_bool_env_var("SGLANG_GC_LOG"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| configure_gc_logger() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.balance_meta = dp_balance_meta | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_args.enable_dp_attention | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and server_args.load_balance_method == "minimum_tokens" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert dp_balance_meta is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.recv_dp_balance_id_this_term = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def current_scheduler_metrics_enabled(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1092,6 +1103,12 @@ def handle_generate_request( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recv_req: TokenizedGenerateReqInput, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.server_args.enable_dp_attention | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and self.server_args.load_balance_method == "minimum_tokens" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Create a new request | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recv_req.session_params is None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1653,6 +1670,11 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Handle DP attention | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if need_dp_attn_preparation: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.server_args.load_balance_method == "minimum_tokens" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and self.forward_ct % 40 == 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.handle_dp_balance_data(ret) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ret = self.prepare_mlp_sync_batch(ret) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return ret | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1978,6 +2000,91 @@ def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| disable_overlap_schedule=self.server_args.disable_overlap_schedule, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def handle_dp_balance_data(self, local_batch: ScheduleBatch): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recv_list = self.recv_dp_balance_id_this_term | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert len(recv_list) <= 511, ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "The number of requests received this round is too large. " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Please increase gather_tensor_size and onfly_info_size." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gather_tensor_size = 512 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recv_tensor[0] = holding_tokens_list | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recv_tensor[1] = len( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recv_list | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) # The first element is the length of the list. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recv_tensor[2 : len(recv_list) + 2] = torch.tensor( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recv_list, dtype=torch.int32 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.tp_rank == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gathered_list = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.zeros(gather_tensor_size, dtype=torch.int32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for _ in range(self.balance_meta.num_workers) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gathered_list = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.distributed.gather( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recv_tensor, gathered_list, group=self.tp_cpu_group | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gathered_id_list_per_worker = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.tp_rank == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gathered_id_list_per_worker = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| holding_tokens_list = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for tensor in gathered_list: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| holding_tokens_list.append(tensor[0].item()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| list_length = tensor[1].item() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gathered_id_list_per_worker.append( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensor[2 : list_length + 2].tolist() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return gathered_id_list_per_worker, holding_tokens_list | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]: | |
| """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance""" | |
| recv_list = self.recv_dp_balance_id_this_term | |
| assert len(recv_list) <= 511, ( | |
| "The number of requests received this round is too large. " | |
| "Please increase gather_tensor_size and onfly_info_size." | |
| ) | |
| gather_tensor_size = 512 | |
| # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids | |
| recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32) | |
| recv_tensor[0] = holding_tokens_list | |
| recv_tensor[1] = len( | |
| recv_list | |
| ) # The first element is the length of the list. | |
| recv_tensor[2 : len(recv_list) + 2] = torch.tensor( | |
| recv_list, dtype=torch.int32 | |
| ) | |
| if self.tp_rank == 0: | |
| gathered_list = [ | |
| torch.zeros(gather_tensor_size, dtype=torch.int32) | |
| for _ in range(self.balance_meta.num_workers) | |
| ] | |
| else: | |
| gathered_list = None | |
| torch.distributed.gather( | |
| recv_tensor, gathered_list, group=self.tp_cpu_group | |
| ) | |
| gathered_id_list_per_worker = None | |
| if self.tp_rank == 0: | |
| gathered_id_list_per_worker = [] | |
| holding_tokens_list = [] | |
| for tensor in gathered_list: | |
| holding_tokens_list.append(tensor[0].item()) | |
| list_length = tensor[1].item() | |
| gathered_id_list_per_worker.append( | |
| tensor[2 : list_length + 2].tolist() | |
| ) | |
| return gathered_id_list_per_worker, holding_tokens_list | |
| def gather_dp_balance_info(current_worker_holding_tokens: int) -> Tuple[Optional[List[List[int]]], List[int]]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain the meaning of this variable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure, this variable corresponds to the balance_id in TokenizedGenerateReqInput.
We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received).