180 changes: 150 additions & 30 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
-from typing import List, Optional, Tuple
+import logging
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
@@ -50,6 +52,13 @@
     next_power_of_2,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
+        DeepEPLLOutput,
+        DeepEPNormalOutput,
+        DispatchOutput,
+    )
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
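
Editor's note: the new `TYPE_CHECKING` block, combined with the `from __future__ import annotations` already at the top of this file, means these names are evaluated only by static type checkers, never at runtime, which is the usual way to annotate against `token_dispatcher` types without risking a circular import. A minimal sketch of the pattern (the module name here is a placeholder):

```python
from __future__ import annotations  # annotations stay strings, never evaluated

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only type checkers (mypy, pyright) follow this import, so a runtime
    # import cycle between the two modules cannot occur.
    from token_dispatcher import DispatchOutput  # placeholder module


def moe_impl(dispatch_output: DispatchOutput) -> None:
    # At runtime this annotation is just the string "DispatchOutput".
    ...
```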
@@ -797,6 +806,24 @@ def __init__(
                 "alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable."
             )
 
+        # TODO: move to the beginning of the file
+        from sglang.srt.distributed.parallel_state import get_tp_group
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
+
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
+            group=get_tp_group().device_group,
+            router_topk=self.top_k,
+            permute_fusion=True,
+            num_experts=self.num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
+            async_finish=True,  # TODO
+            return_recv_hook=True,
+        )
+
         if self.deepep_mode.enable_low_latency():
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
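
Editor's note: the layer now owns its `MaybeTboDeepEPDispatcher` instead of having dispatch state threaded in through `forward`; the function-level imports (flagged by the in-code TODO) keep module load free of an import cycle for now. Restated as a plain sketch, the protocol the rest of this diff builds on is dispatch, then local expert compute, then combine. Method names are taken from the diff; the wrapper function `run_moe_layer` is mine:

```python
def run_moe_layer(layer, hidden_states, topk_idx, topk_weights, forward_batch):
    # 1) All-to-all: route each token to the ranks owning its top-k experts.
    dispatch_output = layer.deepep_dispatcher.dispatch(
        hidden_states=hidden_states,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
        forward_batch=forward_batch,
    )
    # 2) Local expert computation on whatever format the dispatcher produced.
    expert_output = layer.moe_impl(dispatch_output)
    # 3) All-to-all back: weighted-combine expert outputs per source token.
    return layer.deepep_dispatcher.combine(
        hidden_states=expert_output,
        topk_idx=dispatch_output.topk_idx,
        topk_weights=dispatch_output.topk_weights,
        forward_batch=forward_batch,
    )
```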
@@ -837,37 +864,128 @@ def forward(
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
-        masked_m: torch.Tensor,
-        expected_m: int,
-        num_recv_tokens_per_expert: List[int],
         forward_batch: ForwardBatch,
     ):
+        dispatch_output = self.dispatch(
+            hidden_states, topk_idx, topk_weights, forward_batch
+        )
+        hidden_states = self.moe_impl(dispatch_output)
+        hidden_states = self.combine(
+            hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+            forward_batch,
+        )
+        return hidden_states
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        return self.deepep_dispatcher.dispatch(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def moe_impl(self, dispatch_output: DispatchOutput):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
-            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
-        resolved_deepep_mode = self.deepep_mode.resolve(
-            forward_batch.is_extend_in_batch
-        )
-        if resolved_deepep_mode == DeepEPMode.normal:
+            return self.forward_aiter(dispatch_output)
+        if dispatch_output.format.is_deepep_normal():
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-                return self.forward_deepgemm_contiguous(
-                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
-                )
+                return self.forward_deepgemm_contiguous(dispatch_output)
             else:
-                return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
-            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
+                return self.forward_normal(dispatch_output)
+        elif dispatch_output.format.is_deepep_ll():
+            return self.forward_deepgemm_masked(dispatch_output)
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

-    def forward_normal(
+    def combine(
         self,
         hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        return self.deepep_dispatcher.combine(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def _prepare_for_normal(
+        self,
+        hidden_states: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
+        topk_idx: torch.Tensor,
     ):
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            deepep_permute_triton_kernel,
+            deepep_run_moe_deep_preprocess,
+        )
+
+        if hidden_states.shape[0] == 0:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,),
+                device=hidden_states.device,
+                dtype=torch.int64,
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states
+        else:
+            if _use_aiter:
+                # skip permutation here as aiter fused_moe has fused inside
+                reorder_topk_ids = torch.empty(
+                    (0,), device=hidden_states.device, dtype=torch.int64
+                )
+                seg_indptr = torch.zeros(
+                    (self.num_experts + 1,),
+                    device=hidden_states.device,
+                    dtype=torch.int64,
+                )
+                return reorder_topk_ids, seg_indptr, hidden_states
+
+            reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+                topk_idx, self.num_experts
+            )
+            num_total_tokens = reorder_topk_ids.numel()
+            gateup_input = torch.empty(
+                (int(num_total_tokens), hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            # PreReorder
+            deepep_permute_triton_kernel[(hidden_states.shape[0],)](
+                hidden_states,
+                gateup_input,
+                self.src2dst,
+                topk_idx,
+                None,
+                self.router_topk,
+                hidden_states.shape[1],
+                BLOCK_SIZE=512,
+            )
+            return reorder_topk_ids, seg_indptr, gateup_input
+
+    def forward_normal(
+        self,
+        dispatch_output: DeepEPNormalOutput,
+    ):
+        hidden_states, topk_idx = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+        )
+        reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
+            hidden_states, topk_idx
+        )
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
 
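Editor's note on the layout `_prepare_for_normal` produces: `reorder_topk_ids` holds the expert id of each (token, expert) pair sorted by expert, and `seg_indptr` is a CSR-style offset array with one segment per expert. A rough CPU analogue of the preprocessing, purely illustrative (the real work is done by the Triton kernels above, and the real `src2dst` maps source rows to destination rows rather than being plain sort indices):

```python
import torch

# Toy inputs: 2 tokens routed to top-2 experts each, 4 experts total.
num_experts = 4
topk_idx = torch.tensor([[0, 2], [2, 3]])

flat = topk_idx.flatten()                        # [0, 2, 2, 3]
reorder_topk_ids, src2dst = flat.sort()          # sorted expert ids + positions
counts = torch.bincount(flat, minlength=num_experts)
seg_indptr = torch.zeros(num_experts + 1, dtype=torch.int64)
seg_indptr[1:] = counts.cumsum(0)                # [0, 1, 1, 3, 4]
# Expert e owns rows seg_indptr[e]:seg_indptr[e+1] of the permuted activations.
```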
@@ -983,10 +1101,13 @@ def forward_normal
 
     def forward_aiter(
         self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states, topk_idx, topk_weights = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+        )
         if hidden_states.shape[0] == 0:
             return hidden_states
         # in original deepep, idx == -1 meaning invalid and will not be processed.
@@ -1014,11 +1135,11 @@
 
     def forward_deepgemm_contiguous(
         self,
-        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
-        topk_idx,
-        topk_weights,
-        num_recv_tokens_per_expert: List[int],
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
+            dispatch_output
+        )
         hidden_states_fp8, hidden_states_scale = hidden_states_fp8
         assert self.quant_method is not None
         assert self.activation == "silu"
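
Editor's note: `forward_deepgemm_contiguous` unpacks the dispatch output positionally, which works because the dispatcher outputs are NamedTuple-like per this diff's usage: attribute access in `forward_aiter`/`forward_normal`, iteration here. A hedged reconstruction of the normal-mode output implied by these call sites (the real definition lives in `sglang.srt.layers.moe.ep_moe.token_dispatcher`; field names are inferred):

```python
from typing import List, NamedTuple, Tuple, Union

import torch


class DeepEPNormalOutput(NamedTuple):
    # Either a plain tensor or an (fp8_tensor, scale) pair, matching the
    # `hidden_states_fp8, hidden_states_scale = hidden_states_fp8` unpack above.
    hidden_states: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
    topk_idx: torch.Tensor
    topk_weights: torch.Tensor
    num_recv_tokens_per_expert: List[int]
```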
@@ -1138,10 +1259,9 @@ def forward_deepgemm_contiguous
 
     def forward_deepgemm_masked(
         self,
-        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
-        masked_m: torch.Tensor,
-        expected_m: int,
+        dispatch_output: DeepEPLLOutput,
    ):
+        hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.activation == "silu"
 
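Editor's note: similarly, the `hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output` unpack pins down the low-latency output as a 5-field tuple, with the two skipped slots presumably `topk_idx` and `topk_weights`. A matching sketch, again inferred rather than the canonical definition:

```python
from typing import NamedTuple, Tuple

import torch


class DeepEPLLOutput(NamedTuple):
    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
    topk_idx: torch.Tensor      # skipped as `_` in forward_deepgemm_masked
    topk_weights: torch.Tensor  # skipped as `_` in forward_deepgemm_masked
    masked_m: torch.Tensor      # per-expert valid-token counts (assumed)
    expected_m: int             # expected tokens per expert (assumed)
```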