Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/sglang/bench_one_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,13 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
batch,
dp_size=model_runner.server_args.dp_size,
attn_tp_size=1,
tp_cpu_group=model_runner.tp_group.cpu_group,
tp_group=model_runner.tp_group,
get_idle_batch=None,
disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
spec_algorithm=SpeculativeAlgorithm.NONE,
speculative_num_draft_tokens=None,
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
)


Expand Down
16 changes: 13 additions & 3 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1826,7 +1826,7 @@ def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
local_batch,
dp_size=self.server_args.dp_size,
attn_tp_size=self.attn_tp_size,
tp_cpu_group=self.tp_cpu_group,
tp_group=self.tp_group,
get_idle_batch=self.get_idle_batch,
disable_cuda_graph=self.server_args.disable_cuda_graph,
spec_algorithm=self.spec_algorithm,
Expand All @@ -1835,14 +1835,15 @@ def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
enable_deepep_moe=self.server_args.enable_deepep_moe,
deepep_mode=DeepEPMode[self.server_args.deepep_mode],
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
)

@staticmethod
def prepare_mlp_sync_batch_raw(
local_batch: ScheduleBatch,
dp_size,
attn_tp_size: int,
tp_cpu_group,
tp_group,
get_idle_batch,
disable_cuda_graph: bool,
spec_algorithm,
Expand All @@ -1851,6 +1852,7 @@ def prepare_mlp_sync_batch_raw(
enable_deepep_moe: bool,
deepep_mode: DeepEPMode,
require_mlp_tp_gather: bool,
disable_overlap_schedule: bool,
):
# Check if other DP workers have running batches
if local_batch is None:
Expand Down Expand Up @@ -1881,6 +1883,12 @@ def prepare_mlp_sync_batch_raw(
)

tbo_preparer = TboDPAttentionPreparer()
if disable_overlap_schedule:
group = tp_group.device_group
device = tp_group.device
else:
group = tp_group.cpu_group
device = "cpu"

local_info = torch.tensor(
[
Expand All @@ -1896,15 +1904,17 @@ def prepare_mlp_sync_batch_raw(
),
],
dtype=torch.int64,
device=device,
)
global_info = torch.empty(
(dp_size, attn_tp_size, 6),
dtype=torch.int64,
device=device,
)
torch.distributed.all_gather_into_tensor(
global_info.flatten(),
local_info,
group=tp_cpu_group,
group=group,
)
global_num_tokens = global_info[:, 0, 0].tolist()
can_cuda_graph = min(global_info[:, 0, 1].tolist())
Expand Down
Loading