python/sglang/srt/layers/moe/topk.py (43 changes: 27 additions & 16 deletions)
@@ -21,6 +21,7 @@
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
+from sgl_kernel import moe_fused_gate
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
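
For readers who don't know the kernel: `moe_fused_gate` from sgl_kernel fuses the bias-corrected, group-limited top-k routing used by DeepSeek-V3-style MoE models, which the Python `biased_grouped_topk_impl` path computes step by step. The sketch below is a rough pure-PyTorch reference of that routing, not the kernel's exact contract; the helper name, the -inf masking, and the always-on renormalization are assumptions for illustration.

import torch

def moe_fused_gate_reference(
    gating_output: torch.Tensor,    # [num_tokens, num_experts] router logits
    correction_bias: torch.Tensor,  # [num_experts], e.g. DeepSeek-V3's e_score_correction_bias
    num_expert_group: int,
    topk_group: int,
    topk: int,
):
    # Bias-corrected scores drive expert selection; the uncorrected sigmoid
    # scores provide the final routing weights.
    scores = gating_output.sigmoid()
    num_tokens, num_experts = scores.shape
    scores_for_choice = scores + correction_bias.unsqueeze(0)
    # Rank each expert group by the sum of its two best corrected scores,
    # then keep only the topk_group best groups.
    group_scores = (
        scores_for_choice.view(num_tokens, num_expert_group, -1)
        .topk(2, dim=-1)
        .values.sum(dim=-1)
    )
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    experts_per_group = num_experts // num_expert_group
    score_mask = group_mask.repeat_interleave(experts_per_group, dim=1).bool()
    # Pick topk experts from the surviving groups and renormalize the weights.
    masked = scores_for_choice.masked_fill(~score_mask, float("-inf"))
    topk_ids = masked.topk(topk, dim=-1).indices
    topk_weights = scores.gather(1, topk_ids)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
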
@@ -220,23 +221,33 @@ def biased_grouped_topk(
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
 ):
-    biased_grouped_topk_fn = (
-        torch.compile(
-            biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+    # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
+    if n_share_experts_fusion == 0:
+        return moe_fused_gate(
+            gating_output,
+            correction_bias,
+            num_expert_group,
+            topk_group,
+            topk,
         )
-        if compiled
-        else biased_grouped_topk_impl
-    )
-    return biased_grouped_topk_fn(
-        hidden_states,
-        gating_output,
-        correction_bias,
-        topk,
-        renormalize,
-        num_expert_group,
-        topk_group,
-        n_share_experts_fusion=n_share_experts_fusion,
-    )
+    else:
+        biased_grouped_topk_fn = (
+            torch.compile(
+                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+            )
+            if compiled
+            else biased_grouped_topk_impl
+        )
+        return biased_grouped_topk_fn(
+            hidden_states,
+            gating_output,
+            correction_bias,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
+            n_share_experts_fusion=n_share_experts_fusion,
+        )
 
 
 def select_experts(
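
As a quick smoke test of the new dispatch, the call below takes the fused-kernel path because `n_share_experts_fusion` defaults to 0. The shapes, dtypes, and the DeepSeek-V3-like configuration (256 experts in 8 groups, top-4 groups, top-8 experts) are assumptions, and the `(topk_weights, topk_ids)` return pair follows the existing implementation's convention.

import torch
from sglang.srt.layers.moe.topk import biased_grouped_topk

num_tokens, num_experts = 16, 256
# hidden_states is unused by the fused path but is still part of the signature;
# the hidden size 7168 is an assumed value for illustration.
hidden_states = torch.randn(num_tokens, 7168, device="cuda", dtype=torch.bfloat16)
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)
correction_bias = torch.zeros(num_experts, device="cuda", dtype=torch.float32)

# n_share_experts_fusion defaults to 0, so this dispatches to sgl_kernel's
# moe_fused_gate CUDA kernel; any n_share_experts_fusion > 0 falls back to
# the (optionally torch.compile'd) Python implementation, per the TODO above.
topk_weights, topk_ids = biased_grouped_topk(
    hidden_states,
    gating_output,
    correction_bias,
    topk=8,
    renormalize=True,
    num_expert_group=8,
    topk_group=4,
)
print(topk_weights.shape, topk_ids.shape)  # expected: [16, 8] each

Note that the fused branch forwards neither `hidden_states` nor `renormalize`, so the kernel presumably bakes in renormalized sigmoid gating; that, plus the unsupported shared-expert fusion noted in the TODO, is why the torch.compile fallback remains.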