Skip to content
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 212 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,3 +707,215 @@ def grouped_gemm_triton(
**config,
)
return c


@triton.jit
def compute_masked_m_triton_kernel(seg_indptr, masked_m, num_experts, N):
    # One program per expert: writes the length of that expert's segment,
    # seg_indptr[e + 1] - seg_indptr[e], into masked_m[e].
    # `num_experts` and `N` are accepted but not read by this kernel.
    e = tl.program_id(0)
    seg_begin = tl.load(seg_indptr + e)
    seg_end = tl.load(seg_indptr + e + 1)
    seg_len = seg_end - seg_begin
    tl.store(masked_m + e, seg_len)


@triton.jit
def deepgemm_compute_src2dst_triton_kernel(
    topk_ids,
    reorder_ids,
    seg_indptr,
    src2dst,
    max_m,
    num_toks,
    BLOCK_SIZE: tl.constexpr,
):
    # For every position `dst_id` in the expert-sorted order, look up the
    # original flat (token, k) slot `src_id` and record where that slot lands
    # in the padded per-expert layout:
    #     src2dst[src_id] = expert_id * max_m + (dst_id - seg_indptr[expert_id])
    pid = tl.program_id(axis=0)
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks

    # Every load is guarded by the same lane mask and given a defined fallback.
    # Lanes past `num_toks` would otherwise carry undefined `src_id` /
    # `expert_id` values into the following address computations, and the
    # `seg_indptr` lookup in particular could read out of bounds.
    src_id = tl.load(reorder_ids + dst_id, mask=mask, other=0)
    expert_id = tl.load(topk_ids + src_id, mask=mask, other=0)
    expert_dst_start = tl.load(seg_indptr + expert_id, mask=mask, other=0)

    # Rank of this slot inside its expert's segment.
    expert_dst_offset = dst_id - expert_dst_start
    padded_dst = expert_id * max_m + expert_dst_offset
    tl.store(src2dst + src_id, padded_dst, mask=mask)


@triton.jit
def fill_gateup_input_triton_kernel(
    input_ptr,
    scale_ptr,
    gateup_input_ptr,
    gateup_input_scale_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    m_max,
    hidden_size,
    scale_size,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per source token. For each of the token's `topk` routed
    # experts that falls inside [start_expert_id, end_expert_id], copy the
    # token's row of `input_ptr` and its row of `scale_ptr` into the row of
    # the grouped buffers selected by src2dst.

    # Row indices are widened to int64 before being multiplied by the row
    # width: the grouped buffer holds (local experts * m_max) rows, so a
    # 32-bit `row * hidden_size` product can wrap for large batches.
    src_idx = tl.program_id(0).to(tl.int64)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    src_ptr = input_ptr + src_idx * hidden_size
    scale_src_ptr = scale_ptr + src_idx * scale_size

    for idx in range(topk):
        expert_id = tl.load(topk_ids_ptr + idx)
        if expert_id >= start_expert_id and expert_id <= end_expert_id:
            dst_idx = tl.load(src2dst_ptr + idx).to(tl.int64)
            # src2dst is expressed over all experts; rebase it so the first
            # local expert starts at row 0 of the local buffer.
            dst_idx = dst_idx - start_expert_id * m_max

            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                offset = start_offset + tl.arange(0, BLOCK_SIZE)
                mask = offset < hidden_size
                in_data = tl.load(src_ptr + offset, mask=mask)
                tl.store(dst_ptr + offset, in_data, mask=mask)

            scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
            for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
                offset = start_offset + tl.arange(0, BLOCK_SIZE)
                mask = offset < scale_size
                in_scale = tl.load(scale_src_ptr + offset, mask=mask)
                tl.store(scale_dst_ptr + offset, in_scale, mask=mask)


def exp2_upper(num: int) -> int:
    """Round ``num`` up to a power of two between ``2**2`` and ``2**30``.

    Returns the smallest power of two in ``[4, 2**30]`` that is ``>= num``.
    Values larger than ``2**30`` are returned unchanged.
    """
    # NOTE(review): a reviewer asked why the search starts at 2**2 = 4 rather
    # than at a smaller power; the surrounding code does not state a reason.
    candidate = 1 << 2
    limit = 1 << 30
    while candidate <= limit:
        if num <= candidate:
            return candidate
        candidate <<= 1
    return num


def moe_ep_deepgemm_preproess(
    topk_ids: torch.Tensor,
    num_experts: int,
    hidden_states: torch.Tensor,
    top_k: int,
    start_expert_id,
    end_expert_id,
    block_shape,
    output_dtype: torch.dtype = torch.float8_e4m3fn,
):
    """Group tokens by expert and build the padded inputs for the masked GEMM.

    Sorts the flattened ``topk_ids``, derives per-expert segment boundaries and
    token counts, quantizes ``hidden_states`` per token group, and scatters the
    quantized rows and their scales into buffers shaped
    ``(local_experts, m_max, ...)``.

    Args:
        topk_ids: Expert ids selected for each token.
        num_experts: Total number of experts.
        hidden_states: Activations; dimension 0 is used as the token count.
        top_k: Number of experts selected per token.
        start_expert_id: First expert id handled locally (inclusive).
        end_expert_id: Last expert id handled locally (inclusive).
        block_shape: Two-element quantization block shape; the second element
            is the group size used for per-token-group quantization.
        output_dtype: Element type of the grouped activation buffer.

    Returns:
        A tuple ``(m_max, masked_m, expected_m, src2dst, gateup_input,
        gateup_input_scale)`` where ``masked_m`` holds the token counts of the
        local experts and ``expected_m`` is ``ceil(topk_ids.numel() /
        num_experts)``.

    Raises:
        ValueError: If ``block_shape`` is ``None``. The scale buffer is built
            from the block-quantization output, so this path cannot run
            without it.
    """
    # Fail fast with a clear message instead of hitting an unbound `scale`
    # further down after kernels have already been launched.
    if block_shape is None:
        raise ValueError(
            "moe_ep_deepgemm_preproess requires block-wise quantization; "
            "block_shape must not be None"
        )
    assert len(block_shape) == 2
    block_k = block_shape[1]

    num_slots = topk_ids.numel()
    device = topk_ids.device

    # A stable sort keeps slots of the same expert in their original order.
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    seg_indptr = torch.zeros(num_experts + 1, device=device, dtype=torch.int64)
    src2dst = torch.empty(num_slots, device=device, dtype=torch.int32)
    masked_m = torch.zeros(num_experts, device=device, dtype=torch.int32)

    compute_seg_indptr_triton_kernel[(num_experts,)](
        reorder_topk_ids, seg_indptr, num_slots
    )

    grid = lambda meta: (triton.cdiv(num_slots, meta["BLOCK_SIZE"]),)
    compute_masked_m_triton_kernel[(num_experts,)](
        seg_indptr, masked_m, num_experts, reorder_topk_ids.numel()
    )

    # Rows reserved per expert: the token count rounded up to a power of two.
    m_max = exp2_upper(hidden_states.size(0))
    expected_m = (num_slots + num_experts - 1) // num_experts
    gateup_input = torch.empty(
        (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
        device=hidden_states.device,
        dtype=output_dtype,
    )

    deepgemm_compute_src2dst_triton_kernel[grid](
        topk_ids,
        reorder_ids,
        seg_indptr,
        src2dst,
        m_max,
        num_slots,
        BLOCK_SIZE=256,
    )

    if _is_cuda:
        hidden_states, scale = sglang_per_token_group_quant_fp8(
            hidden_states, block_k
        )
    else:
        hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)

    gateup_input_scale = torch.empty(
        (gateup_input.size(0), gateup_input.size(1), scale.size(1)),
        device=hidden_states.device,
        dtype=scale.dtype,
    )

    fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
        hidden_states,
        scale,
        gateup_input,
        gateup_input_scale,
        src2dst,
        topk_ids,
        start_expert_id,
        end_expert_id,
        top_k,
        m_max,
        hidden_states.size(1),
        scale.size(1),
        BLOCK_SIZE=1024,
    )

    return (
        m_max,
        masked_m[start_expert_id : (end_expert_id + 1)],
        expected_m,
        src2dst,
        gateup_input,
        gateup_input_scale,
    )


@triton.jit
def deepgemm_post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    max_m,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per source token. Gathers the token's rows from the grouped
    # down-projection output for every routed expert inside
    # [start_expert_id, end_expert_id], scales each by its top-k weight, and
    # writes the sum to the token's row of `output_ptr`.
    InDtype = down_output_ptr.dtype.element_ty

    # Row indices are widened to int64 before being multiplied by
    # `hidden_size`: the grouped buffer holds (local experts * max_m) rows, so
    # a 32-bit `row * hidden_size` product can wrap for large batches.
    src_idx = tl.program_id(0).to(tl.int64)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    topk_weights_ptr = topk_weights_ptr + src_idx * topk

    computed = False
    store_ptr = output_ptr + src_idx * hidden_size
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < hidden_size

        # The accumulator uses the same element type as the GEMM output.
        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
        for idx in range(topk):
            expert_id = tl.load(topk_ids_ptr + idx)
            if expert_id >= start_expert_id and expert_id <= end_expert_id:
                computed = True
                dst_idx = tl.load(src2dst_ptr + idx).to(tl.int64)
                # Rebase the global padded index onto the local expert range.
                dst_idx = dst_idx - start_expert_id * max_m
                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                load_ptr = down_output_ptr + dst_idx * hidden_size
                in_data = tl.load(load_ptr + offset, mask=mask)
                sum_vec += in_data * weigh_scale
        tl.store(store_ptr + offset, sum_vec, mask=mask)

    # No routed expert was local: write an all-zero row for this token.
    if computed == False:
        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < hidden_size
            tl.store(
                store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
            )
135 changes: 134 additions & 1 deletion python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
deepgemm_post_reorder_triton_kernel,
gelu_and_mul_triton_kernel,
grouped_gemm_triton,
moe_ep_deepgemm_preproess,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
run_moe_ep_preproess,
Expand All @@ -38,7 +40,13 @@
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
from sglang.srt.utils import (
DeepEPMode,
get_bool_env_var,
is_cuda,
is_hip,
set_weight_attrs,
)

_is_hip = is_hip()

Expand All @@ -47,6 +55,8 @@

logger = logging.getLogger(__name__)

# Flag taken from the EPMOE_USE_DEEPGEMM environment variable. forward()
# dispatches to forward_deepgemm only when this and use_deep_gemm are both set.
epmoe_use_deepgemm = get_bool_env_var("EPMOE_USE_DEEPGEMM")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


We might import it directly.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, do you mean we should just replace EPMOE_USE_DEEPGEMM with _ENABLE_JIT_DEEPGEMM?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, enabling _ENABLE_JIT_DEEPGEMM will set deepgemm at epmoe as the default configuration.



class GroupedGemmRunner(torch.nn.Module):
flashinfer_gemm_warpper = None
Expand Down Expand Up @@ -198,7 +208,130 @@ def __init__(

self.grouped_gemm_runner = None

self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)

def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
    """Run the MoE layer, choosing between the DeepGEMM and default paths.

    The DeepGEMM path is taken only when both ``use_deep_gemm`` and
    ``epmoe_use_deepgemm`` are truthy; otherwise ``forward_normal`` runs.
    """
    # NOTE(review): reviewers asked why a second switch is needed on top of
    # use_deep_gemm. The author found no case where the Triton GEMM in
    # forward_normal outperforms DeepGEMM, but noted DeepGEMM may occupy more
    # GPU memory. The thread agreed to drop epmoe_use_deepgemm and the
    # EPMOE_USE_DEEPGEMM environment variable in favour of
    # _ENABLE_JIT_DEEPGEMM; this revision still checks both flags.
    take_deepgemm_path = use_deep_gemm and epmoe_use_deepgemm
    if not take_deepgemm_path:
        return self.forward_normal(hidden_states, router_logits)
    return self.forward_deepgemm(hidden_states, router_logits)

def forward_deepgemm(
    self, hidden_states: torch.Tensor, router_logits: torch.Tensor
):
    """MoE forward pass built on the masked grouped FP8 GEMM.

    Steps: select experts, regroup and quantize tokens per expert, run the
    gate/up grouped GEMM, apply the masked SiLU-and-mul with re-quantization,
    run the down grouped GEMM, then gather and weight the per-expert rows back
    into token order. Returns a tensor shaped like ``hidden_states``.
    """
    assert self.quant_method is not None

    topk_weights, topk_ids = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=self.top_k,
        use_grouped_topk=self.use_grouped_topk,
        renormalize=self.renormalize,
        topk_group=self.topk_group,
        num_expert_group=self.num_expert_group,
        correction_bias=self.correction_bias,
        custom_routing_function=self.custom_routing_function,
        routed_scaling_factor=self.routed_scaling_factor,
    )

    # PreReorder: scatter tokens into (local_experts, m_max, hidden) buffers.
    preprocessed = moe_ep_deepgemm_preproess(
        topk_ids,
        self.num_experts,
        hidden_states,
        self.top_k,
        self.start_expert_id,
        self.end_expert_id,
        self.block_shape,
    )
    m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
        preprocessed
    )
    gateup_input_fp8 = (
        gateup_input,
        get_col_major_tma_aligned_tensor(gateup_input_scale),
    )

    # GroupGemm-0: gate/up projection.
    num_groups, m, _ = gateup_input_fp8[0].size()
    gateup_width = self.w13_weight.size(1)
    gateup_output = torch.empty(
        (num_groups, m, gateup_width),
        device=hidden_states.device,
        dtype=torch.bfloat16,
    )
    m_grouped_gemm_fp8_fp8_bf16_nt_masked(
        gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
    )

    # Act: the activation halves the last dimension and emits one float32
    # scale per `scale_block_size` output columns.
    scale_block_size = 128
    act_groups, act_rows, gateup_cols = gateup_output.shape
    act_cols = gateup_cols // 2
    act_device = gateup_output.device
    down_input = torch.empty(
        (act_groups, act_rows, act_cols),
        device=act_device,
        dtype=self.fp8_dtype,
    )
    down_input_scale = torch.empty(
        (act_groups, act_rows, act_cols // scale_block_size),
        device=act_device,
        dtype=torch.float32,
    )
    silu_and_mul_masked_post_quant_fwd(
        gateup_output,
        down_input,
        down_input_scale,
        scale_block_size,
        masked_m,
    )

    # GroupGemm-1: down projection.
    down_width = self.w2_weight.size(1)
    down_input_fp8 = (
        down_input,
        get_col_major_tma_aligned_tensor(down_input_scale),
    )
    down_output = torch.empty(
        (num_groups, m, down_width),
        device=down_input.device,
        dtype=torch.bfloat16,
    )
    m_grouped_gemm_fp8_fp8_bf16_nt_masked(
        down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
    )

    # PostReorder: one program per token combines its expert rows.
    output = torch.empty_like(hidden_states)
    num_tokens = hidden_states.size(0)
    deepgemm_post_reorder_triton_kernel[(num_tokens,)](
        down_output,
        output,
        src2dst,
        topk_ids,
        topk_weights,
        self.start_expert_id,
        self.end_expert_id,
        self.top_k,
        hidden_states.size(1),
        m_max,
        BLOCK_SIZE=512,
    )
    return output

def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None

if self.grouped_gemm_runner is None:
Expand Down