python/sglang/srt/managers/scheduler.py — 41 changes: 11 additions & 30 deletions (merged)

@@ -1350,29 +1350,6 @@ def check_memory(self):
         self.metrics_collector.log_stats(self.stats)
         self._publish_kv_events()
 
-    def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]):
-        """Coordinate the DP attention batch."""
-
-        local_info = torch.tensor(
-            [
-                (new_batch is not None),
-            ],
-            dtype=torch.int64,
-        )
-        global_info = torch.empty(
-            (self.server_args.dp_size, self.attn_tp_size, 1),
-            dtype=torch.int64,
-        )
-        torch.distributed.all_gather_into_tensor(
-            global_info.flatten(),
-            local_info,
-            group=self.tp_cpu_group,
-        )
-        any_new_batch = any(
-            global_info[:, 0, 0].tolist()
-        )  # Any DP worker has forward batch
-        return any_new_batch
-
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         chunked_req_to_exclude = set()
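For context, the deleted helper existed only to answer one question on every scheduler iteration: does any DP rank currently have a new prefill batch? It did this with a dedicated all-gather over a CPU process group. Below is a minimal, self-contained sketch of that pattern — not SGLang code: it spins up a single-rank gloo group so it runs without a launcher, and the names `any_dp_rank_has_batch` and `has_local_batch` are illustrative.

```python
import torch
import torch.distributed as dist


def any_dp_rank_has_batch(has_local_batch: bool, group=None) -> bool:
    """All-gather one flag per rank; report whether any rank set it."""
    world_size = dist.get_world_size(group=group)
    local_info = torch.tensor([has_local_batch], dtype=torch.int64)
    global_info = torch.empty((world_size, 1), dtype=torch.int64)
    # Every rank contributes its flag; every rank sees all flags.
    dist.all_gather_into_tensor(global_info.flatten(), local_info, group=group)
    return any(global_info[:, 0].tolist())


if __name__ == "__main__":
    # Single-rank gloo group so the sketch runs standalone.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29501",
        rank=0,
        world_size=1,
    )
    print(any_dp_rank_has_batch(True))  # -> True
    dist.destroy_process_group()
```

In the scheduler this collective ran on `self.tp_cpu_group` every iteration; the refactor below drops it and instead lets `prepare_dp_attn_batch` handle the cross-rank coordination.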
@@ -1407,13 +1384,17 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
 
         new_batch = self.get_new_batch_prefill()
 
-        # TODO(ch-wan): minor refactor is needed here to improve readability
-        any_new_batch = (
-            self.server_args.enable_dp_attention
-            and not self.spec_algorithm.is_none()
-            and self.coordinate_spec_dp_attn_batch(new_batch)
-        )
-        if new_batch is not None or any_new_batch:
+        need_dp_attn_preparation = (
+            self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm
+        )
+
+        if need_dp_attn_preparation and not self.spec_algorithm.is_none():
+            # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
+            # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
+            new_batch, _ = self.prepare_dp_attn_batch(new_batch)
+            need_dp_attn_preparation = new_batch is None
+
+        if new_batch is not None:
             # Run prefill first if possible
             ret = new_batch
         else:
@@ -1425,7 +1406,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
             ret = None
 
         # Handle DP attention
-        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
+        if need_dp_attn_preparation:
             ret, _ = self.prepare_dp_attn_batch(ret)
 
         return ret
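After this change, `get_next_batch_to_run` computes a single `need_dp_attn_preparation` flag, prepares (possibly idle) DP-attention batches early on the speculative-decoding path, and falls through to the late `prepare_dp_attn_batch` call only when the early pass yielded nothing. The following condensed sketch mirrors that branching; the class and its stubbed methods are hypothetical stand-ins, not SGLang's actual API.

```python
from typing import Optional, Tuple


class SchedulerSketch:
    """Hypothetical stand-in: only the branching in get_next_batch_to_run
    mirrors the diff; every stub below is invented for illustration."""

    def __init__(self, enable_dp_attention: bool, enable_sp_layernorm: bool,
                 speculative: bool):
        self.enable_dp_attention = enable_dp_attention
        self.enable_sp_layernorm = enable_sp_layernorm
        self.speculative = speculative  # stand-in for `not spec_algorithm.is_none()`

    def get_new_batch_prefill(self) -> Optional[str]:
        return "prefill-batch"  # stub

    def update_running_batch(self) -> Optional[str]:
        return "decode-batch"  # stub

    def prepare_dp_attn_batch(self, batch: Optional[str]) -> Tuple[Optional[str], None]:
        # Stub: the real method coordinates batches across DP ranks and may
        # substitute an idle batch; here it passes the batch through.
        return batch, None

    def get_next_batch_to_run(self) -> Optional[str]:
        new_batch = self.get_new_batch_prefill()
        need_dp_attn_preparation = (
            self.enable_dp_attention or self.enable_sp_layernorm
        )

        if need_dp_attn_preparation and self.speculative:
            # Early preparation: prefill and decode batches cannot share a
            # DP attention group under speculative decoding.
            new_batch, _ = self.prepare_dp_attn_batch(new_batch)
            need_dp_attn_preparation = new_batch is None

        if new_batch is not None:
            ret: Optional[str] = new_batch  # run prefill first if possible
        else:
            ret = self.update_running_batch()

        # Late preparation for the non-speculative path, or when the early
        # pass produced no batch.
        if need_dp_attn_preparation:
            ret, _ = self.prepare_dp_attn_batch(ret)
        return ret


if __name__ == "__main__":
    sched = SchedulerSketch(enable_dp_attention=True,
                            enable_sp_layernorm=False,
                            speculative=True)
    print(sched.get_next_batch_to_run())  # -> prefill-batch
```

The net effect is one coordination mechanism instead of two: `prepare_dp_attn_batch` now serves both the early speculative path and the late general path, and the standalone all-gather helper becomes unnecessary.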