Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ def event_loop_overlap_disagg_decode(self: Scheduler):
self.last_batch_in_queue = last_batch_in_queue

def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
batch, _ = self.prepare_mlp_sync_batch(batch)
batch = self.prepare_mlp_sync_batch(batch)
result = None
if batch:
result = self.run_batch(batch)
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
batch = self.get_new_batch_prefill()

if require_mlp_sync(self.server_args):
batch, _ = self.prepare_mlp_sync_batch(batch)
batch = self.prepare_mlp_sync_batch(batch)
self.cur_batch = batch

if batch:
Expand Down Expand Up @@ -310,7 +310,7 @@ def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
batch = self.get_new_batch_prefill()

if require_mlp_sync(self.server_args):
batch, _ = self.prepare_mlp_sync_batch(batch)
batch = self.prepare_mlp_sync_batch(batch)
self.cur_batch = batch
if batch:
result = self.run_batch(batch)
Expand Down
8 changes: 5 additions & 3 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
)
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import (
DeepEPMode,
ceil_div,
Expand Down Expand Up @@ -1178,12 +1178,14 @@ def forward(
masked_m: torch.Tensor,
expected_m: int,
num_recv_tokens_per_expert: List[int],
forward_mode: ForwardMode,
forward_batch: ForwardBatch,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider renaming forward_batch to forward_batch_info for clarity.

Suggested change
forward_batch: ForwardBatch,
forward_batch_info: ForwardBatch,

):
if _use_aiter:
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
return self.forward_aiter(hidden_states, topk_idx, topk_weights)
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
resolved_deepep_mode = self.deepep_mode.resolve(
forward_batch.is_extend_in_batch
)
if resolved_deepep_mode == DeepEPMode.normal:
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
return self.forward_deepgemm_contiguous(
Expand Down
28 changes: 15 additions & 13 deletions python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
deepep_post_reorder_triton_kernel,
deepep_run_moe_deep_preprocess,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()

Expand Down Expand Up @@ -686,21 +686,21 @@ def dispatch_a(
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_mode: ForwardMode = None,
forward_batch: ForwardBatch,
):
self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
inner_state = self._get_impl(forward_mode).dispatch_a(
inner_state = self._get_impl(forward_batch).dispatch_a(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
)
self._dispatch_intermediate_state = forward_mode, inner_state
self._dispatch_intermediate_state = forward_batch, inner_state

def dispatch_b(self):
self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
forward_mode, inner_state = self._dispatch_intermediate_state
forward_batch, inner_state = self._dispatch_intermediate_state
del self._dispatch_intermediate_state
return self._get_impl(forward_mode).dispatch_b(*inner_state)
return self._get_impl(forward_batch).dispatch_b(*inner_state)

def combine(self, *args, **kwargs) -> Tuple:
self.combine_a(*args, **kwargs)
Expand All @@ -712,24 +712,26 @@ def combine_a(
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_mode: ForwardMode,
forward_batch: ForwardBatch,
):
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
inner_state = self._get_impl(forward_mode).combine_a(
inner_state = self._get_impl(forward_batch).combine_a(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
)
self._combine_intermediate_state = forward_mode, inner_state
self._combine_intermediate_state = forward_batch, inner_state

def combine_b(self):
self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
forward_mode, inner_state = self._combine_intermediate_state
forward_batch, inner_state = self._combine_intermediate_state
del self._combine_intermediate_state
return self._get_impl(forward_mode).combine_b(*inner_state)
return self._get_impl(forward_batch).combine_b(*inner_state)

def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
resolved_deepep_mode = self.deepep_mode.resolve(
forward_batch.is_extend_in_batch
)
if resolved_deepep_mode == DeepEPMode.normal:
return self._normal_dispatcher
elif resolved_deepep_mode == DeepEPMode.low_latency:
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# For DP attention
global_num_tokens: Optional[List[int]] = None
global_num_tokens_for_logprob: Optional[List[int]] = None
is_extend_in_batch: bool = False
can_run_dp_cuda_graph: bool = False
is_extend_in_batch: bool = False
tbo_split_seq_index: Optional[int] = None
Expand Down Expand Up @@ -1706,6 +1707,7 @@ def get_model_worker_batch(
token_ids_logprobs=self.token_ids_logprobs,
global_num_tokens=self.global_num_tokens,
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
is_extend_in_batch=self.is_extend_in_batch,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
tbo_split_seq_index=self.tbo_split_seq_index,
global_forward_mode=self.global_forward_mode,
Expand Down Expand Up @@ -1790,6 +1792,7 @@ class ModelWorkerBatch:
# For DP attention
global_num_tokens: Optional[List[int]]
global_num_tokens_for_logprob: Optional[List[int]]
is_extend_in_batch: bool
can_run_dp_cuda_graph: bool
tbo_split_seq_index: Optional[int]
global_forward_mode: Optional[ForwardMode]
Expand Down
7 changes: 3 additions & 4 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,7 +1490,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
if need_dp_attn_preparation and not self.spec_algorithm.is_none():
# In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
# We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
new_batch = self.prepare_mlp_sync_batch(new_batch)
need_dp_attn_preparation = new_batch is None

if new_batch is not None:
Expand All @@ -1506,7 +1506,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:

# Handle DP attention
if need_dp_attn_preparation:
ret, _ = self.prepare_mlp_sync_batch(ret)
ret = self.prepare_mlp_sync_batch(ret)

return ret

Expand Down Expand Up @@ -1923,8 +1923,7 @@ def prepare_mlp_sync_batch_raw(
if not disable_cuda_graph:
local_batch.can_run_dp_cuda_graph = can_cuda_graph

# TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
return local_batch, any(is_extend_in_batch)
return local_batch

def get_idle_batch(self):
idle_batch = ScheduleBatch.init_new(
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ class ForwardBatch:
dp_local_start_pos: Optional[torch.Tensor] = None # cached info at runtime
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
gathered_buffer: Optional[torch.Tensor] = None
is_extend_in_batch: bool = False
can_run_dp_cuda_graph: bool = False
global_forward_mode: Optional[ForwardMode] = None

Expand Down Expand Up @@ -299,6 +300,7 @@ def init_new(
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
token_ids_logprobs=batch.token_ids_logprobs,
is_extend_in_batch=batch.is_extend_in_batch,
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
global_forward_mode=batch.global_forward_mode,
lora_paths=batch.lora_paths,
Expand Down
14 changes: 7 additions & 7 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def forward_deepep(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_mode=forward_mode,
forward_batch=forward_batch,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider renaming forward_batch to forward_batch_info for clarity.

Suggested change
forward_batch=forward_batch,
forward_batch_info=forward_batch,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: wondering whether we should change the same thing in Qwen and other models

)
final_hidden_states = self.experts(
hidden_states=hidden_states,
Expand All @@ -569,14 +569,14 @@ def forward_deepep(
masked_m=masked_m,
expected_m=expected_m,
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
forward_mode=forward_mode,
forward_batch=forward_batch,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider renaming forward_batch to forward_batch_info for clarity.

Suggested change
forward_batch=forward_batch,
forward_batch_info=forward_batch,

)
if self.ep_size > 1:
final_hidden_states = self.deepep_dispatcher.combine(
hidden_states=final_hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_mode=forward_mode,
forward_batch=forward_batch,
)

if shared_output is not None:
Expand Down Expand Up @@ -651,7 +651,7 @@ def op_dispatch_a(self, state):
hidden_states=state.hidden_states_mlp_input,
topk_idx=state.pop("topk_idx_local"),
topk_weights=state.pop("topk_weights_local"),
forward_mode=state.forward_batch.forward_mode,
forward_batch=state.forward_batch,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider renaming forward_batch to forward_batch_info for clarity.

Suggested change
forward_batch=state.forward_batch,
forward_batch_info=state.forward_batch,

tbo_subbatch_index=state.get("tbo_subbatch_index"),
)

Expand Down Expand Up @@ -683,7 +683,7 @@ def op_experts(self, state):
masked_m=state.pop("masked_m"),
expected_m=state.pop("expected_m"),
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
forward_mode=state.forward_batch.forward_mode,
forward_batch=state.forward_batch,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider renaming forward_batch to forward_batch_info for clarity.

Suggested change
forward_batch=state.forward_batch,
forward_batch_info=state.forward_batch,

)

def op_combine_a(self, state):
Expand All @@ -692,7 +692,7 @@ def op_combine_a(self, state):
hidden_states=state.pop("hidden_states_experts_output"),
topk_idx=state.pop("topk_idx_dispatched"),
topk_weights=state.pop("topk_weights_dispatched"),
forward_mode=state.forward_batch.forward_mode,
forward_batch=state.forward_batch,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider renaming forward_batch to forward_batch_info for clarity.

Suggested change
forward_batch=state.forward_batch,
forward_batch_info=state.forward_batch,

tbo_subbatch_index=state.get("tbo_subbatch_index"),
)

Expand Down Expand Up @@ -1886,7 +1886,7 @@ def op_mlp(self, state):
and hidden_states.shape[0] == 0
):
state.hidden_states_mlp_output = self.mlp(
hidden_states, state.forward_batch.forward_mode
hidden_states, state.forward_batch
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider renaming forward_batch to forward_batch_info for clarity.

Suggested change
hidden_states, state.forward_batch
hidden_states, state.forward_batch_info

)
else:
state.hidden_states_mlp_output = hidden_states
Expand Down
4 changes: 0 additions & 4 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,10 +409,6 @@ def __post_init__(self):

# DeepEP MoE
if self.enable_deepep_moe:
if self.deepep_mode == "auto":
assert (
not self.enable_dp_attention
), "DeepEP MoE `auto` mode is not supported with DP Attention."
if self.deepep_mode == "normal":
logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
self.disable_cuda_graph = True
Expand Down
10 changes: 7 additions & 3 deletions python/sglang/srt/two_batch_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
Expand Down Expand Up @@ -272,7 +272,11 @@ def replay_prepare(

class TboDPAttentionPreparer:
def prepare_all_gather(
self, local_batch, deepep_mode, enable_deepep_moe, enable_two_batch_overlap
self,
local_batch: ScheduleBatch,
deepep_mode: DeepEPMode,
enable_deepep_moe: bool,
enable_two_batch_overlap: bool,
):
self.enable_two_batch_overlap = enable_two_batch_overlap

Expand All @@ -294,7 +298,7 @@ def prepare_all_gather(
extend_lens=local_batch.extend_lens,
token_num_per_seq=token_num_per_seq,
)
resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch)
local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
(
local_batch.forward_mode.is_extend()
Expand Down
8 changes: 4 additions & 4 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2104,14 +2104,14 @@ def enable_normal(self):
def enable_low_latency(self):
return self in [DeepEPMode.low_latency, DeepEPMode.auto]

def resolve(self, forward_mode):
def resolve(self, is_extend_in_batch: bool):
if self != DeepEPMode.auto:
return self

if forward_mode.is_decode():
return DeepEPMode.low_latency
else:
if is_extend_in_batch:
return DeepEPMode.normal
else:
return DeepEPMode.low_latency


def is_non_idle_and_non_empty(forward_mode, hidden_states):
Expand Down
Loading
Loading