From 0394bf8b02a2eeeaa27f9403f399165cdac29635 Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Thu, 5 Jun 2025 00:37:05 +0000 Subject: [PATCH 01/10] init --- python/sglang/srt/entrypoints/EngineBase.py | 6 ++++++ python/sglang/srt/entrypoints/engine.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/python/sglang/srt/entrypoints/EngineBase.py b/python/sglang/srt/entrypoints/EngineBase.py index c7dfafd410fa..bfbc05ea3447 100644 --- a/python/sglang/srt/entrypoints/EngineBase.py +++ b/python/sglang/srt/entrypoints/EngineBase.py @@ -23,6 +23,12 @@ def generate( token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None, custom_logit_processor: Optional[Union[List[str], str]] = None, + return_hidden_states: Optional[bool] = None, + stream: Optional[bool] = None, + bootstrap_host: Optional[Union[List[str], str]] = None, + bootstrap_port: Optional[Union[List[int], int]] = None, + bootstrap_room: Optional[Union[List[int], int]] = None, + attn_dp_rank: Optional[int] = None, ) -> Union[Dict, Iterator[Dict]]: """Generate outputs based on given inputs.""" pass diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 813fc4c7d17c..270341454731 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -167,6 +167,7 @@ def generate( bootstrap_host: Optional[Union[List[str], str]] = None, bootstrap_port: Optional[Union[List[int], int]] = None, bootstrap_room: Optional[Union[List[int], int]] = None, + attn_dp_rank: Optional[int] = None, ) -> Union[Dict, Iterator[Dict]]: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. @@ -188,6 +189,7 @@ def generate( bootstrap_host=bootstrap_host, bootstrap_port=bootstrap_port, bootstrap_room=bootstrap_room, + attn_dp_rank=attn_dp_rank ) loop = asyncio.get_event_loop() generator = self.tokenizer_manager.generate_request(obj, None) From 7f864e3f5318ff4a7f9ac19ef8116a516a0d1c14 Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Thu, 5 Jun 2025 05:01:03 +0000 Subject: [PATCH 02/10] frontend before scheduler --- python/sglang/srt/entrypoints/EngineBase.py | 2 +- python/sglang/srt/entrypoints/engine.py | 24 +++++++++++++++++-- python/sglang/srt/managers/io_struct.py | 7 ++++++ .../sglang/srt/managers/tokenizer_manager.py | 1 + 4 files changed, 31 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/entrypoints/EngineBase.py b/python/sglang/srt/entrypoints/EngineBase.py index bfbc05ea3447..9ac68faa7a27 100644 --- a/python/sglang/srt/entrypoints/EngineBase.py +++ b/python/sglang/srt/entrypoints/EngineBase.py @@ -28,7 +28,7 @@ def generate( bootstrap_host: Optional[Union[List[str], str]] = None, bootstrap_port: Optional[Union[List[int], int]] = None, bootstrap_room: Optional[Union[List[int], int]] = None, - attn_dp_rank: Optional[int] = None, + data_parallel_rank: Optional[int] = None, ) -> Union[Dict, Iterator[Dict]]: """Generate outputs based on given inputs.""" pass diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 270341454731..20f00583ddca 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -167,12 +167,20 @@ def generate( bootstrap_host: Optional[Union[List[str], str]] = None, bootstrap_port: Optional[Union[List[int], int]] = None, bootstrap_room: Optional[Union[List[int], int]] = None, - attn_dp_rank: Optional[int] = None, + data_parallel_rank: Optional[int] = None, ) -> Union[Dict, Iterator[Dict]]: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. Please refer to `GenerateReqInput` for the documentation. """ + if self.server_args.enable_dp_attention: + if data_parallel_rank is None: + logger.info("data_parallel_rank not provided, using default dispatch") + elif data_parallel_rank < 0: + raise ValueError("data_parallel_rank must be non-negative") + elif data_parallel_rank >= self.server_args.dp_size: + raise ValueError(f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}") + obj = GenerateReqInput( text=prompt, input_ids=input_ids, @@ -189,7 +197,7 @@ def generate( bootstrap_host=bootstrap_host, bootstrap_port=bootstrap_port, bootstrap_room=bootstrap_room, - attn_dp_rank=attn_dp_rank + data_parallel_rank=data_parallel_rank, ) loop = asyncio.get_event_loop() generator = self.tokenizer_manager.generate_request(obj, None) @@ -239,11 +247,22 @@ async def async_generate( bootstrap_host: Optional[Union[List[str], str]] = None, bootstrap_port: Optional[Union[List[int], int]] = None, bootstrap_room: Optional[Union[List[int], int]] = None, + data_parallel_rank: Optional[int] = None, ) -> Union[Dict, AsyncIterator[Dict]]: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. Please refer to `GenerateReqInput` for the documentation. """ + + if self.server_args.enable_dp_attention: + if data_parallel_rank is None: + logger.info("data_parallel_rank not provided, using default dispatch") + elif data_parallel_rank < 0: + raise ValueError("data_parallel_rank must be non-negative") + elif data_parallel_rank >= self.server_args.dp_size: + raise ValueError(f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]") + + logger.info(f"data_parallel_rank: {data_parallel_rank}") obj = GenerateReqInput( text=prompt, input_ids=input_ids, @@ -259,6 +278,7 @@ async def async_generate( bootstrap_host=bootstrap_host, bootstrap_port=bootstrap_port, bootstrap_room=bootstrap_room, + data_parallel_rank=data_parallel_rank, ) generator = self.tokenizer_manager.generate_request(obj, None) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index f13b23b175a8..2d1dab5e2a4b 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -106,6 +106,9 @@ class GenerateReqInput: bootstrap_port: Optional[Union[List[Optional[int]], int]] = None bootstrap_room: Optional[Union[List[int], int]] = None + # For data parallel rank routing + data_parallel_rank: Optional[int] = None + def contains_mm_input(self) -> bool: return has_valid_data(self.image_data) or has_valid_data(self.audio_data) @@ -417,6 +420,7 @@ def __getitem__(self, i): bootstrap_room=( self.bootstrap_room[i] if self.bootstrap_room is not None else None ), + data_parallel_rank=self.data_parallel_rank if self.data_parallel_rank is not None else None, ) @@ -464,6 +468,9 @@ class TokenizedGenerateReqInput: bootstrap_port: Optional[int] = None bootstrap_room: Optional[int] = None + # For data parallel rank routing + data_parallel_rank: Optional[int] = None + @dataclass class EmbeddingReqInput: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index f1ee50e559c9..753ef739e3b5 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -569,6 +569,7 @@ def _create_tokenized_object( session_params=session_params, custom_logit_processor=obj.custom_logit_processor, return_hidden_states=obj.return_hidden_states, + data_parallel_rank=obj.data_parallel_rank, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( From 10eaa93beb48ed0135b54aa2507c913fd629b077 Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Thu, 5 Jun 2025 05:14:43 +0000 Subject: [PATCH 03/10] lol --- python/sglang/srt/managers/schedule_batch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 38ec548564b0..eaa5662d496f 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -446,6 +446,7 @@ def __init__( bootstrap_host: Optional[str] = None, bootstrap_port: Optional[int] = None, bootstrap_room: Optional[int] = None, + data_parallel_rank: Optional[int] = None, ): # Input and output info self.rid = rid @@ -600,6 +601,9 @@ def __init__( self.bootstrap_room: Optional[int] = bootstrap_room self.disagg_kv_sender: Optional[BaseKVSender] = None + # For data parallel rank routing + self.data_parallel_rank: Optional[int] = data_parallel_rank + # the start index of the sent kv cache # We want to send it chunk by chunk for chunked prefill. # After every chunk forward, we do the following: From d0d529b7eb9d5dfac0b55fffc3a7fa1adb5edf0c Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Thu, 5 Jun 2025 05:39:12 +0000 Subject: [PATCH 04/10] bump --- .../sglang/srt/managers/data_parallel_controller.py | 12 ++++++++---- python/sglang/srt/managers/scheduler.py | 1 + 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 876472312480..8e25ae2697d4 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -248,10 +248,14 @@ def launch_tensor_parallel_group( def round_robin_scheduler(self, req: Req): if self.server_args.disaggregation_mode == "null": - self.workers[self.round_robin_counter].send_pyobj(req) - self.round_robin_counter = (self.round_robin_counter + 1) % len( - self.workers - ) + if req.data_parallel_rank is not None: + logger.info(f"Direct routing to DP rank {req.data_parallel_rank}") + self.workers[req.data_parallel_rank].send_pyobj(req) + else: + self.workers[self.round_robin_counter].send_pyobj(req) + self.round_robin_counter = (self.round_robin_counter + 1) % len( + self.workers + ) else: self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5c2141d773cf..5f835ad45829 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -948,6 +948,7 @@ def handle_generate_request( bootstrap_host=recv_req.bootstrap_host, bootstrap_port=recv_req.bootstrap_port, bootstrap_room=recv_req.bootstrap_room, + data_parallel_rank=recv_req.data_parallel_rank, ) req.tokenizer = self.tokenizer From 7aacd0cf39554a1b5d7a877b6e4e950181dc103c Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Thu, 5 Jun 2025 21:57:26 +0000 Subject: [PATCH 05/10] ok this works lol --- python/sglang/srt/disaggregation/common/conn.py | 8 +++++++- python/sglang/srt/disaggregation/decode.py | 1 + python/sglang/srt/disaggregation/nixl/conn.py | 3 ++- python/sglang/srt/managers/data_parallel_controller.py | 6 +++++- 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 4d66c18af4e7..fdb18d906fdf 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -109,10 +109,12 @@ def __init__( mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, + target_dp_rank: Optional[int] = None, ): self.bootstrap_room = bootstrap_room self.bootstrap_addr = bootstrap_addr self.kv_mgr = mgr + self.target_dp_rank = target_dp_rank if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table: self.prefill_tp_size, self.prefill_dp_size = ( @@ -180,7 +182,11 @@ def __init__( self.target_tp_rank = self.target_tp_ranks[0] self.required_dst_info_num = 1 - self.target_dp_group = bootstrap_room % self.prefill_dp_size + if self.target_dp_rank is not None: + logger.info(f"[DISAGG] Reciever got rank {self.target_dp_rank}") + self.target_dp_group = self.target_dp_rank + else: + self.target_dp_group = bootstrap_room % self.prefill_dp_size # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank bootstrap_key = ( diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 7982f7b6305e..918ff1d0161f 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -156,6 +156,7 @@ def add(self, req: Req) -> None: mgr=self.kv_manager, bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}", bootstrap_room=req.bootstrap_room, + target_dp_rank=req.data_parallel_rank, ) self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver)) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 3ed021a6bdc2..ef041c706810 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -407,9 +407,10 @@ def __init__( mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, + target_dp_rank: Optional[int] = None, ): self.started_transfer = False - super().__init__(mgr, bootstrap_addr, bootstrap_room) + super().__init__(mgr, bootstrap_addr, bootstrap_room, target_dp_rank) def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): for bootstrap_info in self.bootstrap_infos: diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 8e25ae2697d4..aaa5d30cc8bd 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -257,7 +257,11 @@ def round_robin_scheduler(self, req: Req): self.workers ) else: - self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req) + if req.data_parallel_rank is not None: + logger.info(f"[DISAGG] Direct routing to DP rank {req.data_parallel_rank}") + self.workers[req.data_parallel_rank].send_pyobj(req) + else: + self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req) def shortest_queue_scheduler(self, input_requests): raise NotImplementedError() From d45826e927ad650a5b468ee6691d3d2f86c0ed90 Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Thu, 5 Jun 2025 22:14:03 +0000 Subject: [PATCH 06/10] logshehe --- python/sglang/srt/disaggregation/common/conn.py | 2 +- python/sglang/srt/managers/data_parallel_controller.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index fdb18d906fdf..73fcf3eb9839 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -183,7 +183,7 @@ def __init__( self.required_dst_info_num = 1 if self.target_dp_rank is not None: - logger.info(f"[DISAGG] Reciever got rank {self.target_dp_rank}") + logger.debug(f"Targeting DP rank: {self.target_dp_rank}") self.target_dp_group = self.target_dp_rank else: self.target_dp_group = bootstrap_room % self.prefill_dp_size diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index aaa5d30cc8bd..901579fc4bf5 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -258,7 +258,7 @@ def round_robin_scheduler(self, req: Req): ) else: if req.data_parallel_rank is not None: - logger.info(f"[DISAGG] Direct routing to DP rank {req.data_parallel_rank}") + logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}") self.workers[req.data_parallel_rank].send_pyobj(req) else: self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req) From c2c85254735367255bda9103eccba83efbe5be39 Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Thu, 5 Jun 2025 23:06:17 +0000 Subject: [PATCH 07/10] debug print --- python/sglang/srt/managers/data_parallel_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 901579fc4bf5..62c3800c2ef4 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -249,7 +249,7 @@ def launch_tensor_parallel_group( def round_robin_scheduler(self, req: Req): if self.server_args.disaggregation_mode == "null": if req.data_parallel_rank is not None: - logger.info(f"Direct routing to DP rank {req.data_parallel_rank}") + logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}") self.workers[req.data_parallel_rank].send_pyobj(req) else: self.workers[self.round_robin_counter].send_pyobj(req) From 11949bdfb8e6eebc7b60ce7a8642de7325be34d7 Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Fri, 6 Jun 2025 18:08:33 +0000 Subject: [PATCH 08/10] unify name --- python/sglang/srt/disaggregation/common/conn.py | 10 +++++----- python/sglang/srt/disaggregation/decode.py | 2 +- python/sglang/srt/disaggregation/nixl/conn.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 73fcf3eb9839..e6a6ad445b17 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -109,12 +109,12 @@ def __init__( mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - target_dp_rank: Optional[int] = None, + data_parallel_rank: Optional[int] = None, ): self.bootstrap_room = bootstrap_room self.bootstrap_addr = bootstrap_addr self.kv_mgr = mgr - self.target_dp_rank = target_dp_rank + self.data_parallel_rank = data_parallel_rank if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table: self.prefill_tp_size, self.prefill_dp_size = ( @@ -182,9 +182,9 @@ def __init__( self.target_tp_rank = self.target_tp_ranks[0] self.required_dst_info_num = 1 - if self.target_dp_rank is not None: - logger.debug(f"Targeting DP rank: {self.target_dp_rank}") - self.target_dp_group = self.target_dp_rank + if self.data_parallel_rank is not None: + logger.debug(f"Targeting DP rank: {self.data_parallel_rank}") + self.target_dp_group = self.data_parallel_rank else: self.target_dp_group = bootstrap_room % self.prefill_dp_size diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 918ff1d0161f..e206450b6e09 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -156,7 +156,7 @@ def add(self, req: Req) -> None: mgr=self.kv_manager, bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}", bootstrap_room=req.bootstrap_room, - target_dp_rank=req.data_parallel_rank, + data_parallel_rank=req.data_parallel_rank, ) self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver)) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index ef041c706810..f9a0e931cf50 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -407,10 +407,10 @@ def __init__( mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, - target_dp_rank: Optional[int] = None, + data_parallel_rank: Optional[int] = None, ): self.started_transfer = False - super().__init__(mgr, bootstrap_addr, bootstrap_room, target_dp_rank) + super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank) def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): for bootstrap_info in self.bootstrap_infos: From b9744b07edd8080d8dd3b9881d0137a9a2b4d2d5 Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Fri, 6 Jun 2025 14:02:33 -0700 Subject: [PATCH 09/10] pre --- python/sglang/srt/entrypoints/engine.py | 8 ++++++-- python/sglang/srt/managers/io_struct.py | 8 +++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 20f00583ddca..665eb521afc5 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -179,7 +179,9 @@ def generate( elif data_parallel_rank < 0: raise ValueError("data_parallel_rank must be non-negative") elif data_parallel_rank >= self.server_args.dp_size: - raise ValueError(f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}") + raise ValueError( + f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}" + ) obj = GenerateReqInput( text=prompt, @@ -260,7 +262,9 @@ async def async_generate( elif data_parallel_rank < 0: raise ValueError("data_parallel_rank must be non-negative") elif data_parallel_rank >= self.server_args.dp_size: - raise ValueError(f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]") + raise ValueError( + f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]" + ) logger.info(f"data_parallel_rank: {data_parallel_rank}") obj = GenerateReqInput( diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 2d1dab5e2a4b..fdee14ef8841 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -106,7 +106,7 @@ class GenerateReqInput: bootstrap_port: Optional[Union[List[Optional[int]], int]] = None bootstrap_room: Optional[Union[List[int], int]] = None - # For data parallel rank routing + # For data parallel rank routing data_parallel_rank: Optional[int] = None def contains_mm_input(self) -> bool: @@ -420,7 +420,9 @@ def __getitem__(self, i): bootstrap_room=( self.bootstrap_room[i] if self.bootstrap_room is not None else None ), - data_parallel_rank=self.data_parallel_rank if self.data_parallel_rank is not None else None, + data_parallel_rank=( + self.data_parallel_rank if self.data_parallel_rank is not None else None + ), ) @@ -468,7 +470,7 @@ class TokenizedGenerateReqInput: bootstrap_port: Optional[int] = None bootstrap_room: Optional[int] = None - # For data parallel rank routing + # For data parallel rank routing data_parallel_rank: Optional[int] = None From be0a34cd31126dc20b5ccbe7b5a846c9c4a31ffd Mon Sep 17 00:00:00 2001 From: ishandhanani Date: Sun, 8 Jun 2025 20:41:43 +0000 Subject: [PATCH 10/10] dp rank for mooncake and fake --- python/sglang/srt/disaggregation/fake/conn.py | 1 + python/sglang/srt/disaggregation/mooncake/conn.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py index 1e650753e7ee..d080c8e2ed19 100644 --- a/python/sglang/srt/disaggregation/fake/conn.py +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -56,6 +56,7 @@ def __init__( mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, + data_parallel_rank: Optional[int] = None, ): self.has_init = False diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 940a25d7423b..1779970435b2 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -758,6 +758,7 @@ def __init__( mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None, + data_parallel_rank: Optional[int] = None, ): self.bootstrap_room = bootstrap_room self.bootstrap_addr = bootstrap_addr @@ -765,6 +766,7 @@ def __init__( self.session_id = self.kv_mgr.get_session_id() self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) self.conclude_state = None + self.data_parallel_rank = data_parallel_rank if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table: self.prefill_tp_size, self.prefill_dp_size = ( @@ -838,7 +840,11 @@ def __init__( self.target_tp_rank = self.target_tp_ranks[0] self.required_dst_info_num = 1 - self.target_dp_group = self.bootstrap_room % self.prefill_dp_size + if self.data_parallel_rank is not None: + logger.debug(f"Targeting DP rank: {self.data_parallel_rank}") + self.target_dp_group = self.data_parallel_rank + else: + self.target_dp_group = bootstrap_room % self.prefill_dp_size # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank bootstrap_key = (