Skip to content

Commit f156987

Browse files
authored
feat: add direct routing strategy to DP worker (#6884)
1 parent 3465d7a commit f156987

File tree

12 files changed

+78
-8
lines changed

12 files changed

+78
-8
lines changed

python/sglang/srt/disaggregation/common/conn.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,12 @@ def __init__(
109109
mgr: BaseKVManager,
110110
bootstrap_addr: str,
111111
bootstrap_room: Optional[int] = None,
112+
data_parallel_rank: Optional[int] = None,
112113
):
113114
self.bootstrap_room = bootstrap_room
114115
self.bootstrap_addr = bootstrap_addr
115116
self.kv_mgr = mgr
117+
self.data_parallel_rank = data_parallel_rank
116118

117119
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
118120
self.prefill_tp_size, self.prefill_dp_size = (
@@ -180,7 +182,11 @@ def __init__(
180182
self.target_tp_rank = self.target_tp_ranks[0]
181183
self.required_dst_info_num = 1
182184

183-
self.target_dp_group = bootstrap_room % self.prefill_dp_size
185+
if self.data_parallel_rank is not None:
186+
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
187+
self.target_dp_group = self.data_parallel_rank
188+
else:
189+
self.target_dp_group = bootstrap_room % self.prefill_dp_size
184190

185191
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
186192
bootstrap_key = (

python/sglang/srt/disaggregation/decode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def add(self, req: Req) -> None:
156156
mgr=self.kv_manager,
157157
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
158158
bootstrap_room=req.bootstrap_room,
159+
data_parallel_rank=req.data_parallel_rank,
159160
)
160161
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
161162

python/sglang/srt/disaggregation/fake/conn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
mgr: BaseKVManager,
5757
bootstrap_addr: str,
5858
bootstrap_room: Optional[int] = None,
59+
data_parallel_rank: Optional[int] = None,
5960
):
6061
self.has_init = False
6162

python/sglang/srt/disaggregation/mooncake/conn.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,13 +765,15 @@ def __init__(
765765
mgr: MooncakeKVManager,
766766
bootstrap_addr: str,
767767
bootstrap_room: Optional[int] = None,
768+
data_parallel_rank: Optional[int] = None,
768769
):
769770
self.bootstrap_room = bootstrap_room
770771
self.bootstrap_addr = bootstrap_addr
771772
self.kv_mgr = mgr
772773
self.session_id = self.kv_mgr.get_session_id()
773774
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
774775
self.conclude_state = None
776+
self.data_parallel_rank = data_parallel_rank
775777

776778
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
777779
self.prefill_tp_size, self.prefill_dp_size = (
@@ -845,7 +847,11 @@ def __init__(
845847
self.target_tp_rank = self.target_tp_ranks[0]
846848
self.required_dst_info_num = 1
847849

848-
self.target_dp_group = self.bootstrap_room % self.prefill_dp_size
850+
if self.data_parallel_rank is not None:
851+
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
852+
self.target_dp_group = self.data_parallel_rank
853+
else:
854+
self.target_dp_group = bootstrap_room % self.prefill_dp_size
849855

850856
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
851857
bootstrap_key = (

python/sglang/srt/disaggregation/nixl/conn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,9 +407,10 @@ def __init__(
407407
mgr: NixlKVManager,
408408
bootstrap_addr: str,
409409
bootstrap_room: Optional[int] = None,
410+
data_parallel_rank: Optional[int] = None,
410411
):
411412
self.started_transfer = False
412-
super().__init__(mgr, bootstrap_addr, bootstrap_room)
413+
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
413414

414415
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
415416
for bootstrap_info in self.bootstrap_infos:

python/sglang/srt/entrypoints/EngineBase.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ def generate(
2323
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
2424
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
2525
custom_logit_processor: Optional[Union[List[str], str]] = None,
26+
return_hidden_states: Optional[bool] = None,
27+
stream: Optional[bool] = None,
28+
bootstrap_host: Optional[Union[List[str], str]] = None,
29+
bootstrap_port: Optional[Union[List[int], int]] = None,
30+
bootstrap_room: Optional[Union[List[int], int]] = None,
31+
data_parallel_rank: Optional[int] = None,
2632
) -> Union[Dict, Iterator[Dict]]:
2733
"""Generate outputs based on given inputs."""
2834
pass

python/sglang/srt/entrypoints/engine.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,22 @@ def generate(
167167
bootstrap_host: Optional[Union[List[str], str]] = None,
168168
bootstrap_port: Optional[Union[List[int], int]] = None,
169169
bootstrap_room: Optional[Union[List[int], int]] = None,
170+
data_parallel_rank: Optional[int] = None,
170171
) -> Union[Dict, Iterator[Dict]]:
171172
"""
172173
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
173174
Please refer to `GenerateReqInput` for the documentation.
174175
"""
176+
if self.server_args.enable_dp_attention:
177+
if data_parallel_rank is None:
178+
logger.info("data_parallel_rank not provided, using default dispatch")
179+
elif data_parallel_rank < 0:
180+
raise ValueError("data_parallel_rank must be non-negative")
181+
elif data_parallel_rank >= self.server_args.dp_size:
182+
raise ValueError(
183+
f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}"
184+
)
185+
175186
obj = GenerateReqInput(
176187
text=prompt,
177188
input_ids=input_ids,
@@ -188,6 +199,7 @@ def generate(
188199
bootstrap_host=bootstrap_host,
189200
bootstrap_port=bootstrap_port,
190201
bootstrap_room=bootstrap_room,
202+
data_parallel_rank=data_parallel_rank,
191203
)
192204
loop = asyncio.get_event_loop()
193205
generator = self.tokenizer_manager.generate_request(obj, None)
@@ -237,11 +249,24 @@ async def async_generate(
237249
bootstrap_host: Optional[Union[List[str], str]] = None,
238250
bootstrap_port: Optional[Union[List[int], int]] = None,
239251
bootstrap_room: Optional[Union[List[int], int]] = None,
252+
data_parallel_rank: Optional[int] = None,
240253
) -> Union[Dict, AsyncIterator[Dict]]:
241254
"""
242255
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
243256
Please refer to `GenerateReqInput` for the documentation.
244257
"""
258+
259+
if self.server_args.enable_dp_attention:
260+
if data_parallel_rank is None:
261+
logger.info("data_parallel_rank not provided, using default dispatch")
262+
elif data_parallel_rank < 0:
263+
raise ValueError("data_parallel_rank must be non-negative")
264+
elif data_parallel_rank >= self.server_args.dp_size:
265+
raise ValueError(
266+
f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]"
267+
)
268+
269+
logger.info(f"data_parallel_rank: {data_parallel_rank}")
245270
obj = GenerateReqInput(
246271
text=prompt,
247272
input_ids=input_ids,
@@ -257,6 +282,7 @@ async def async_generate(
257282
bootstrap_host=bootstrap_host,
258283
bootstrap_port=bootstrap_port,
259284
bootstrap_room=bootstrap_room,
285+
data_parallel_rank=data_parallel_rank,
260286
)
261287
generator = self.tokenizer_manager.generate_request(obj, None)
262288

python/sglang/srt/managers/data_parallel_controller.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,20 @@ def launch_tensor_parallel_group(
248248

249249
def round_robin_scheduler(self, req: Req):
250250
if self.server_args.disaggregation_mode == "null":
251-
self.workers[self.round_robin_counter].send_pyobj(req)
252-
self.round_robin_counter = (self.round_robin_counter + 1) % len(
253-
self.workers
254-
)
251+
if req.data_parallel_rank is not None:
252+
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
253+
self.workers[req.data_parallel_rank].send_pyobj(req)
254+
else:
255+
self.workers[self.round_robin_counter].send_pyobj(req)
256+
self.round_robin_counter = (self.round_robin_counter + 1) % len(
257+
self.workers
258+
)
255259
else:
256-
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
260+
if req.data_parallel_rank is not None:
261+
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
262+
self.workers[req.data_parallel_rank].send_pyobj(req)
263+
else:
264+
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
257265

258266
def shortest_queue_scheduler(self, input_requests):
259267
raise NotImplementedError()

python/sglang/srt/managers/io_struct.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ class GenerateReqInput:
106106
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
107107
bootstrap_room: Optional[Union[List[int], int]] = None
108108

109+
# For data parallel rank routing
110+
data_parallel_rank: Optional[int] = None
111+
109112
def contains_mm_input(self) -> bool:
110113
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
111114

@@ -417,6 +420,9 @@ def __getitem__(self, i):
417420
bootstrap_room=(
418421
self.bootstrap_room[i] if self.bootstrap_room is not None else None
419422
),
423+
data_parallel_rank=(
424+
self.data_parallel_rank if self.data_parallel_rank is not None else None
425+
),
420426
)
421427

422428

@@ -464,6 +470,9 @@ class TokenizedGenerateReqInput:
464470
bootstrap_port: Optional[int] = None
465471
bootstrap_room: Optional[int] = None
466472

473+
# For data parallel rank routing
474+
data_parallel_rank: Optional[int] = None
475+
467476

468477
@dataclass
469478
class EmbeddingReqInput:

python/sglang/srt/managers/schedule_batch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,7 @@ def __init__(
451451
bootstrap_host: Optional[str] = None,
452452
bootstrap_port: Optional[int] = None,
453453
bootstrap_room: Optional[int] = None,
454+
data_parallel_rank: Optional[int] = None,
454455
):
455456
# Input and output info
456457
self.rid = rid
@@ -605,6 +606,9 @@ def __init__(
605606
self.bootstrap_room: Optional[int] = bootstrap_room
606607
self.disagg_kv_sender: Optional[BaseKVSender] = None
607608

609+
# For data parallel rank routing
610+
self.data_parallel_rank: Optional[int] = data_parallel_rank
611+
608612
# the start index of the sent kv cache
609613
# We want to send it chunk by chunk for chunked prefill.
610614
# After every chunk forward, we do the following:

0 commit comments

Comments
 (0)