Skip to content

Commit 506c492

Browse files
xutizhou, TianQiLin666666, and ch-wan
authored
feat: integrate deepgemm into EPMoE (#6821)
Co-authored-by: tianqilin.99 <[email protected]> Co-authored-by: TianQiLin666666 <[email protected]> Co-authored-by: Cheng Wan <[email protected]>
1 parent 30ceccc commit 506c492

5 files changed

Lines changed: 336 additions & 3 deletions

File tree

python/sglang/srt/layers/moe/ep_moe/kernels.py

Lines changed: 159 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -478,11 +478,13 @@ def post_reorder_triton_kernel(
478478
end_expert_id,
479479
topk,
480480
hidden_size,
481+
dst_start,
481482
BLOCK_SIZE: tl.constexpr,
482483
):
483484
InDtype = down_output_ptr.dtype.element_ty
484485

485-
src_idx = tl.program_id(0)
486+
src_idx_int32 = tl.program_id(0)
487+
src_idx = src_idx_int32.to(tl.int64)
486488
src2dst_ptr = src2dst_ptr + src_idx * topk
487489
topk_ids_ptr = topk_ids_ptr + src_idx * topk
488490
topk_weights_ptr = topk_weights_ptr + src_idx * topk
@@ -501,7 +503,9 @@ def post_reorder_triton_kernel(
501503
expert_id = tl.load(topk_ids_ptr + idx)
502504
if expert_id >= start_expert_id and expert_id <= end_expert_id:
503505
computed = True
504-
dst_idx = tl.load(src2dst_ptr + idx)
506+
dst_idx_int32 = tl.load(src2dst_ptr + idx)
507+
dst_idx = dst_idx_int32.to(tl.int64)
508+
dst_idx = dst_idx - dst_start
505509
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
506510
load_ptr = down_output_ptr + dst_idx * hidden_size
507511
in_data = tl.load(load_ptr + offset, mask=mask)
@@ -1086,3 +1090,156 @@ def tma_align_input_scale(input_scale: torch.Tensor):
10861090
BLOCK_SIZE_K=BLOCK_SIZE_K,
10871091
)
10881092
return output.t()[:m]
1093+
1094+
1095+
@triton.jit
1096+
def compute_masked_m_triton_kernel(seg_indptr, masked_m):
1097+
expert_id = tl.program_id(0)
1098+
start = tl.load(seg_indptr + expert_id)
1099+
end = tl.load(seg_indptr + expert_id + 1)
1100+
tl.store(masked_m + expert_id, (end - start))
1101+
1102+
1103+
@triton.jit
1104+
def deepgemm_compute_src2dst_triton_kernel(
1105+
topk_ids,
1106+
reorder_ids,
1107+
seg_indptr,
1108+
src2dst,
1109+
m_max,
1110+
num_toks,
1111+
BLOCK_SIZE: tl.constexpr,
1112+
):
1113+
pid = tl.program_id(axis=0)
1114+
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
1115+
mask = dst_id < num_toks
1116+
src_id = tl.load(reorder_ids + dst_id, mask=mask)
1117+
expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
1118+
expert_dst_start = tl.load(seg_indptr + expert_id)
1119+
expert_dst_offset = dst_id - expert_dst_start
1120+
dst_id = expert_id * m_max + expert_dst_offset
1121+
tl.store(src2dst + src_id, dst_id, mask=mask)
1122+
1123+
1124+
@triton.jit
1125+
def fill_gateup_input_triton_kernel(
1126+
input_ptr,
1127+
scale_ptr,
1128+
gateup_input_ptr,
1129+
gateup_input_scale_ptr,
1130+
src2dst_ptr,
1131+
topk_ids_ptr,
1132+
start_expert_id,
1133+
end_expert_id,
1134+
topk,
1135+
m_max,
1136+
hidden_size,
1137+
scale_size,
1138+
BLOCK_SIZE: tl.constexpr,
1139+
):
1140+
1141+
src_idx_int32 = tl.program_id(0)
1142+
src_idx = src_idx_int32.to(tl.int64)
1143+
src2dst_ptr = src2dst_ptr + src_idx * topk
1144+
topk_ids_ptr = topk_ids_ptr + src_idx * topk
1145+
src_ptr = input_ptr + src_idx * hidden_size
1146+
scale_src_ptr = scale_ptr + src_idx * scale_size
1147+
1148+
vec = tl.arange(0, BLOCK_SIZE)
1149+
for idx in range(topk):
1150+
expert_id = tl.load(topk_ids_ptr + idx)
1151+
if expert_id >= start_expert_id and expert_id <= end_expert_id:
1152+
dst_idx_int32 = tl.load(src2dst_ptr + idx)
1153+
dst_idx = dst_idx_int32.to(tl.int64)
1154+
dst_idx = dst_idx - start_expert_id * m_max
1155+
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
1156+
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
1157+
offset = start_offset + vec
1158+
mask = offset < hidden_size
1159+
in_data = tl.load(src_ptr + offset, mask=mask)
1160+
tl.store(dst_ptr + offset, in_data, mask=mask)
1161+
scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
1162+
for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
1163+
offset = start_offset + vec
1164+
mask = offset < scale_size
1165+
in_scale = tl.load(scale_src_ptr + offset, mask=mask)
1166+
tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
1167+
1168+
1169+
def moe_ep_deepgemm_preprocess(
1170+
topk_ids: torch.Tensor,
1171+
num_experts: int,
1172+
hidden_states: torch.Tensor,
1173+
top_k: int,
1174+
start_expert_id,
1175+
end_expert_id,
1176+
block_shape,
1177+
output_dtype: torch.dtype = torch.float8_e4m3fn,
1178+
):
1179+
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
1180+
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
1181+
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
1182+
masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
1183+
1184+
compute_seg_indptr_triton_kernel[(num_experts,)](
1185+
reorder_topk_ids, seg_indptr, topk_ids.numel()
1186+
)
1187+
1188+
grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
1189+
compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
1190+
1191+
# For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
1192+
m_max = (hidden_states.size(0) + 255) // 256 * 256
1193+
expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
1194+
gateup_input = torch.empty(
1195+
(int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
1196+
device=hidden_states.device,
1197+
dtype=output_dtype,
1198+
)
1199+
1200+
deepgemm_compute_src2dst_triton_kernel[grid](
1201+
topk_ids,
1202+
reorder_ids,
1203+
seg_indptr,
1204+
src2dst,
1205+
m_max,
1206+
topk_ids.numel(),
1207+
BLOCK_SIZE=256,
1208+
)
1209+
1210+
if block_shape is None:
1211+
block_shape = [128, 128]
1212+
assert len(block_shape) == 2
1213+
block_n, block_k = block_shape[0], block_shape[1]
1214+
hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
1215+
1216+
gateup_input_scale = torch.empty(
1217+
(gateup_input.size(0), gateup_input.size(1), scale.size(1)),
1218+
device=hidden_states.device,
1219+
dtype=scale.dtype,
1220+
)
1221+
1222+
fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
1223+
hidden_states,
1224+
scale,
1225+
gateup_input,
1226+
gateup_input_scale,
1227+
src2dst,
1228+
topk_ids,
1229+
start_expert_id,
1230+
end_expert_id,
1231+
top_k,
1232+
m_max,
1233+
hidden_states.size(1),
1234+
scale.size(1),
1235+
BLOCK_SIZE=1024,
1236+
)
1237+
1238+
return (
1239+
m_max,
1240+
masked_m[start_expert_id : (end_expert_id + 1)],
1241+
expected_m,
1242+
src2dst,
1243+
gateup_input,
1244+
gateup_input_scale,
1245+
)

python/sglang/srt/layers/moe/ep_moe/layer.py

Lines changed: 174 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ep_scatter,
1717
gelu_and_mul_triton_kernel,
1818
grouped_gemm_triton,
19+
moe_ep_deepgemm_preprocess,
1920
post_reorder_triton_kernel,
2021
pre_reorder_triton_kernel,
2122
run_moe_ep_preproess,
@@ -178,6 +179,7 @@ def __init__(
178179
assert (
179180
num_fused_shared_experts == 0
180181
), "num_fused_shared_experts is not supported in EP"
182+
self.num_fused_shared_experts = num_fused_shared_experts
181183
self.num_experts_per_partition = self.num_experts // self.tp_size
182184
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
183185
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
@@ -227,13 +229,182 @@ def __init__(
227229

228230
self.grouped_gemm_runner = None
229231

232+
self.w13_weight_fp8 = (
233+
self.w13_weight,
234+
(
235+
self.w13_weight_scale_inv
236+
if self.use_block_quant
237+
else self.w13_weight_scale
238+
),
239+
)
240+
self.w2_weight_fp8 = (
241+
self.w2_weight,
242+
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
243+
)
244+
230245
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
246+
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
247+
return self.forward_deepgemm(hidden_states, router_logits)
248+
else:
249+
return self.forward_normal(hidden_states, router_logits)
250+
251+
def forward_deepgemm(
252+
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
253+
):
254+
assert self.quant_method is not None
255+
assert self.activation == "silu"
231256
hidden_states_shape = hidden_states.shape
232257
hidden_states_dtype = hidden_states.dtype
233258
hidden_states_device = hidden_states.device
259+
topk_weights, topk_ids = select_experts(
260+
hidden_states=hidden_states,
261+
router_logits=router_logits,
262+
top_k=self.top_k,
263+
use_grouped_topk=self.use_grouped_topk,
264+
renormalize=self.renormalize,
265+
topk_group=self.topk_group,
266+
num_expert_group=self.num_expert_group,
267+
num_fused_shared_experts=self.num_fused_shared_experts,
268+
correction_bias=self.correction_bias,
269+
custom_routing_function=self.custom_routing_function,
270+
routed_scaling_factor=self.routed_scaling_factor,
271+
)
234272

235-
assert self.quant_method is not None
273+
if not self.use_block_quant:
274+
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
275+
scale_block_size = 128
276+
w13_weight_scale_n = 2 * (
277+
(self.intermediate_size + scale_block_size - 1) // scale_block_size
278+
)
279+
w13_weight_scale_k = (
280+
hidden_states_shape[-1] + scale_block_size - 1
281+
) // scale_block_size
282+
w13_weight_scale = (
283+
self.w13_weight_scale.unsqueeze(1)
284+
.repeat_interleave(w13_weight_scale_n, dim=1)
285+
.unsqueeze(2)
286+
.repeat_interleave(w13_weight_scale_k, dim=2)
287+
)
288+
self.w13_weight_fp8 = (
289+
self.w13_weight,
290+
w13_weight_scale,
291+
)
292+
w2_weight_scale_n = (
293+
hidden_states_shape[-1] + scale_block_size - 1
294+
) // scale_block_size
295+
w2_weight_scale_k = (
296+
self.intermediate_size + scale_block_size - 1
297+
) // scale_block_size
298+
w2_weight_scale = (
299+
self.w2_weight_scale.unsqueeze(1)
300+
.repeat_interleave(w2_weight_scale_n, dim=1)
301+
.unsqueeze(2)
302+
.repeat_interleave(w2_weight_scale_k, dim=2)
303+
)
304+
self.w2_weight_fp8 = (
305+
self.w2_weight,
306+
w2_weight_scale,
307+
)
236308

309+
# PreReorder
310+
m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
311+
moe_ep_deepgemm_preprocess(
312+
topk_ids,
313+
self.num_experts,
314+
hidden_states,
315+
self.top_k,
316+
self.start_expert_id,
317+
self.end_expert_id,
318+
self.block_shape,
319+
)
320+
)
321+
322+
dispose_tensor(hidden_states)
323+
324+
# GroupGemm-0
325+
gateup_input_fp8 = (
326+
gateup_input,
327+
deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
328+
)
329+
num_groups, m, k = gateup_input_fp8[0].size()
330+
n = self.w13_weight.size(1)
331+
gateup_output = torch.empty(
332+
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
333+
)
334+
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
335+
gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
336+
)
337+
del gateup_input
338+
del gateup_input_fp8
339+
340+
# Act
341+
down_input = torch.empty(
342+
(
343+
gateup_output.shape[0],
344+
gateup_output.shape[1],
345+
gateup_output.shape[2] // 2,
346+
),
347+
device=hidden_states_device,
348+
dtype=self.fp8_dtype,
349+
)
350+
scale_block_size = 128
351+
down_input_scale = torch.empty(
352+
(
353+
gateup_output.shape[0],
354+
gateup_output.shape[1],
355+
gateup_output.shape[2] // 2 // scale_block_size,
356+
),
357+
device=hidden_states_device,
358+
dtype=torch.float32,
359+
)
360+
silu_and_mul_masked_post_quant_fwd(
361+
gateup_output,
362+
down_input,
363+
down_input_scale,
364+
scale_block_size,
365+
masked_m,
366+
)
367+
del gateup_output
368+
369+
# GroupGemm-1
370+
n = self.w2_weight.size(1)
371+
down_input_fp8 = (
372+
down_input,
373+
deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
374+
)
375+
down_output = torch.empty(
376+
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
377+
)
378+
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
379+
down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
380+
)
381+
del down_input
382+
del down_input_fp8
383+
384+
# PostReorder
385+
output = torch.empty(
386+
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
387+
)
388+
post_reorder_triton_kernel[(hidden_states_shape[0],)](
389+
down_output,
390+
output,
391+
src2dst,
392+
topk_ids,
393+
topk_weights,
394+
self.start_expert_id,
395+
self.end_expert_id,
396+
self.top_k,
397+
hidden_states_shape[1],
398+
m_max * self.start_expert_id,
399+
BLOCK_SIZE=512,
400+
)
401+
return output
402+
403+
def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
404+
assert self.quant_method is not None
405+
hidden_states_shape = hidden_states.shape
406+
hidden_states_dtype = hidden_states.dtype
407+
hidden_states_device = hidden_states.device
237408
if self.grouped_gemm_runner is None:
238409
self.grouped_gemm_runner = GroupedGemmRunner(
239410
hidden_states.device,
@@ -249,6 +420,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
249420
renormalize=self.renormalize,
250421
topk_group=self.topk_group,
251422
num_expert_group=self.num_expert_group,
423+
num_fused_shared_experts=self.num_fused_shared_experts,
252424
correction_bias=self.correction_bias,
253425
custom_routing_function=self.custom_routing_function,
254426
routed_scaling_factor=self.routed_scaling_factor,
@@ -440,6 +612,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
440612
self.end_expert_id,
441613
self.top_k,
442614
hidden_states_shape[1],
615+
0,
443616
BLOCK_SIZE=512,
444617
)
445618
return output

0 commit comments

Comments (0)