Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
# Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
from typing import Optional

from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import torch
from sgl_kernel import gelu_and_mul, silu_and_mul
from triton_kernels.matmul_ogs import matmul_ogs
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx

from sglang.srt.utils import direct_register_custom_op

if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput


def triton_kernel_moe_forward(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
topk_output: TopKOutput,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
Expand All @@ -30,9 +34,8 @@ def triton_kernel_moe_forward(
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:

if not renormalize:
gating_output = torch.softmax(gating_output, dim=-1)
routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
assert topk_output.format.is_triton_kernel()
routing_data, gather_idx, scatter_idx = topk_output

return triton_kernel_fused_experts(
hidden_states,
Expand Down
106 changes: 84 additions & 22 deletions python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from __future__ import annotations

import math
from typing import Callable, NamedTuple, Optional
from enum import Enum, auto
from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable

import torch
import torch.nn.functional as F
Expand All @@ -27,6 +28,7 @@
ExpertLocationDispatchInfo,
topk_ids_logical_to_physical,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
Expand All @@ -37,6 +39,12 @@
is_npu,
)

try:
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
except ImportError:
pass


_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
Expand All @@ -58,15 +66,58 @@
import torch_npu


class TopKOutput(NamedTuple):
# -------------------------------- TopKOutput ---------------------------------------


class TopKOutputFormat(Enum):
    """Discriminator for the concrete layout of a top-k routing result.

    Consumers branch on this to decide whether an output is the standard
    (weights, ids, logits) tuple or the triton-kernels routing structures.
    """

    STANDARD = auto()
    TRITON_KERNEL = auto()

    def is_standard(self) -> bool:
        """Return True when this is the standard (weights/ids/logits) format."""
        return self is TopKOutputFormat.STANDARD

    def is_triton_kernel(self) -> bool:
        """Return True when this is the triton-kernels routing format."""
        return self is TopKOutputFormat.TRITON_KERNEL


@runtime_checkable
class TopKOutput(Protocol):
    """Protocol for top-k outputs in different formats.

    Structural interface implemented by the concrete NamedTuple outputs
    (StandardTopKOutput, TritonKernelTopKOutput); callers inspect ``format``
    to dispatch on the concrete layout (e.g. ``format.is_triton_kernel()``).
    """

    @property
    def format(self) -> TopKOutputFormat:
        """The format of the output."""
        ...


class StandardTopKOutput(NamedTuple):
    """Standard top-k output format.

    Produced by ``select_experts``; satisfies the ``TopKOutput`` protocol via
    the ``format`` property.
    """

    # Weights of the selected experts (as returned by select_experts).
    topk_weights: torch.Tensor
    # Indices of the selected experts (as returned by select_experts).
    topk_ids: torch.Tensor
    # Raw router logits the selection was derived from.
    router_logits: torch.Tensor

    @property
    def format(self) -> TopKOutputFormat:
        # Identifies this tuple as the standard layout to protocol consumers.
        return TopKOutputFormat.STANDARD

class TopK(CustomOp):

# TODO(ch-wan): support triton_kernels
class TritonKernelTopKOutput(NamedTuple):
    """Triton kernel top-k output format.

    Wraps the structures produced by ``triton_kernels.routing.routing`` and
    unpacked by the triton-kernels fused MoE forward path. Satisfies the
    ``TopKOutput`` protocol via the ``format`` property.

    NOTE(review): the field types come from an optional import guarded by
    ``try/except ImportError: pass``; with ``from __future__ import
    annotations`` in effect the annotations stay lazy, so this class is safe
    to define even when ``triton_kernels`` is absent.
    """

    # Routing metadata from triton_kernels.routing.routing.
    routing_data: RoutingData
    # Gather indices for scattering tokens to experts.
    gather_indx: GatherIndx
    # Scatter indices for gathering expert outputs back to tokens.
    scatter_indx: ScatterIndx

    @property
    def format(self) -> TopKOutputFormat:
        # Identifies this tuple as the triton-kernels layout to protocol consumers.
        return TopKOutputFormat.TRITON_KERNEL


# -------------------------------- TopK ---------------------------------------


class TopK(CustomOp):

def __init__(
self,
Expand Down Expand Up @@ -97,6 +148,8 @@ def __init__(
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor

self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]

def forward_native(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -131,23 +184,29 @@ def forward_cuda(
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
torch_native = False
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
if self.use_triton_kernels:
routing_data, gather_idx, scatter_idx = routing(
router_logits, self.top_k, self.renormalize
)
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
else:
torch_native = False
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)

def forward_cpu(
self,
Expand Down Expand Up @@ -217,6 +276,9 @@ def forward_npu(
)


# ------------------------------- TopK implementation -------------------------------------


def fused_topk_torch_native(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
Expand Down Expand Up @@ -680,4 +742,4 @@ def select_experts(

get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

return TopKOutput(topk_weights, topk_ids, router_logits)
return StandardTopKOutput(topk_weights, topk_ids, router_logits)
24 changes: 14 additions & 10 deletions python/sglang/srt/layers/quantization/unquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,14 @@ def __init__(self, use_triton_kernels: bool = False):
super().__init__()
self.use_triton_kernels = use_triton_kernels

self.triton_kernel_moe_forward = None
if torch.cuda.is_available() and has_triton_kernels:
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_forward as _tk_forward,
)

self.triton_kernel_moe_forward = _tk_forward

def create_weights(
self,
layer: torch.nn.Module,
Expand Down Expand Up @@ -229,16 +237,12 @@ def forward_cuda(
) -> torch.Tensor:

if self.use_triton_kernels:
# TODO(ch-wan): re-enable the Triton kernel
raise NotImplementedError("The Triton kernel is temporarily disabled.")
# return triton_kernel_moe_forward(
# hidden_states=x,
# w1=layer.w13_weight,
# w2=layer.w2_weight,
# gating_output=router_logits,
# topk=top_k,
# renormalize=renormalize,
# )
return self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
)
else:
if _use_aiter:
assert not no_combine, "unsupported"
Expand Down
Loading