1616 ep_scatter ,
1717 gelu_and_mul_triton_kernel ,
1818 grouped_gemm_triton ,
19+ moe_ep_deepgemm_preprocess ,
1920 post_reorder_triton_kernel ,
2021 pre_reorder_triton_kernel ,
2122 run_moe_ep_preproess ,
@@ -178,6 +179,7 @@ def __init__(
178179 assert (
179180 num_fused_shared_experts == 0
180181 ), "num_fused_shared_experts is not supported in EP"
182+ self .num_fused_shared_experts = num_fused_shared_experts
181183 self .num_experts_per_partition = self .num_experts // self .tp_size
182184 self .start_expert_id = self .tp_rank * self .num_experts_per_partition
183185 self .end_expert_id = self .start_expert_id + self .num_experts_per_partition - 1
@@ -227,13 +229,182 @@ def __init__(
227229
228230 self .grouped_gemm_runner = None
229231
232+ self .w13_weight_fp8 = (
233+ self .w13_weight ,
234+ (
235+ self .w13_weight_scale_inv
236+ if self .use_block_quant
237+ else self .w13_weight_scale
238+ ),
239+ )
240+ self .w2_weight_fp8 = (
241+ self .w2_weight ,
242+ self .w2_weight_scale_inv if self .use_block_quant else self .w2_weight_scale ,
243+ )
244+
230245 def forward (self , hidden_states : torch .Tensor , router_logits : torch .Tensor ):
246+ if deep_gemm_wrapper .ENABLE_JIT_DEEPGEMM and self .use_fp8_w8a8 :
247+ return self .forward_deepgemm (hidden_states , router_logits )
248+ else :
249+ return self .forward_normal (hidden_states , router_logits )
250+
251+ def forward_deepgemm (
252+ self , hidden_states : torch .Tensor , router_logits : torch .Tensor
253+ ):
254+ assert self .quant_method is not None
255+ assert self .activation == "silu"
231256 hidden_states_shape = hidden_states .shape
232257 hidden_states_dtype = hidden_states .dtype
233258 hidden_states_device = hidden_states .device
259+ topk_weights , topk_ids = select_experts (
260+ hidden_states = hidden_states ,
261+ router_logits = router_logits ,
262+ top_k = self .top_k ,
263+ use_grouped_topk = self .use_grouped_topk ,
264+ renormalize = self .renormalize ,
265+ topk_group = self .topk_group ,
266+ num_expert_group = self .num_expert_group ,
267+ num_fused_shared_experts = self .num_fused_shared_experts ,
268+ correction_bias = self .correction_bias ,
269+ custom_routing_function = self .custom_routing_function ,
270+ routed_scaling_factor = self .routed_scaling_factor ,
271+ )
234272
235- assert self .quant_method is not None
273+ if not self .use_block_quant :
274+ # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
275+ scale_block_size = 128
276+ w13_weight_scale_n = 2 * (
277+ (self .intermediate_size + scale_block_size - 1 ) // scale_block_size
278+ )
279+ w13_weight_scale_k = (
280+ hidden_states_shape [- 1 ] + scale_block_size - 1
281+ ) // scale_block_size
282+ w13_weight_scale = (
283+ self .w13_weight_scale .unsqueeze (1 )
284+ .repeat_interleave (w13_weight_scale_n , dim = 1 )
285+ .unsqueeze (2 )
286+ .repeat_interleave (w13_weight_scale_k , dim = 2 )
287+ )
288+ self .w13_weight_fp8 = (
289+ self .w13_weight ,
290+ w13_weight_scale ,
291+ )
292+ w2_weight_scale_n = (
293+ hidden_states_shape [- 1 ] + scale_block_size - 1
294+ ) // scale_block_size
295+ w2_weight_scale_k = (
296+ self .intermediate_size + scale_block_size - 1
297+ ) // scale_block_size
298+ w2_weight_scale = (
299+ self .w2_weight_scale .unsqueeze (1 )
300+ .repeat_interleave (w2_weight_scale_n , dim = 1 )
301+ .unsqueeze (2 )
302+ .repeat_interleave (w2_weight_scale_k , dim = 2 )
303+ )
304+ self .w2_weight_fp8 = (
305+ self .w2_weight ,
306+ w2_weight_scale ,
307+ )
236308
309+ # PreReorder
310+ m_max , masked_m , expected_m , src2dst , gateup_input , gateup_input_scale = (
311+ moe_ep_deepgemm_preprocess (
312+ topk_ids ,
313+ self .num_experts ,
314+ hidden_states ,
315+ self .top_k ,
316+ self .start_expert_id ,
317+ self .end_expert_id ,
318+ self .block_shape ,
319+ )
320+ )
321+
322+ dispose_tensor (hidden_states )
323+
324+ # GroupGemm-0
325+ gateup_input_fp8 = (
326+ gateup_input ,
327+ deep_gemm_wrapper .get_col_major_tma_aligned_tensor (gateup_input_scale ),
328+ )
329+ num_groups , m , k = gateup_input_fp8 [0 ].size ()
330+ n = self .w13_weight .size (1 )
331+ gateup_output = torch .empty (
332+ (num_groups , m , n ), device = hidden_states_device , dtype = torch .bfloat16
333+ )
334+ deep_gemm_wrapper .grouped_gemm_nt_f8f8bf16_masked (
335+ gateup_input_fp8 , self .w13_weight_fp8 , gateup_output , masked_m , expected_m
336+ )
337+ del gateup_input
338+ del gateup_input_fp8
339+
340+ # Act
341+ down_input = torch .empty (
342+ (
343+ gateup_output .shape [0 ],
344+ gateup_output .shape [1 ],
345+ gateup_output .shape [2 ] // 2 ,
346+ ),
347+ device = hidden_states_device ,
348+ dtype = self .fp8_dtype ,
349+ )
350+ scale_block_size = 128
351+ down_input_scale = torch .empty (
352+ (
353+ gateup_output .shape [0 ],
354+ gateup_output .shape [1 ],
355+ gateup_output .shape [2 ] // 2 // scale_block_size ,
356+ ),
357+ device = hidden_states_device ,
358+ dtype = torch .float32 ,
359+ )
360+ silu_and_mul_masked_post_quant_fwd (
361+ gateup_output ,
362+ down_input ,
363+ down_input_scale ,
364+ scale_block_size ,
365+ masked_m ,
366+ )
367+ del gateup_output
368+
369+ # GroupGemm-1
370+ n = self .w2_weight .size (1 )
371+ down_input_fp8 = (
372+ down_input ,
373+ deep_gemm_wrapper .get_col_major_tma_aligned_tensor (down_input_scale ),
374+ )
375+ down_output = torch .empty (
376+ (num_groups , m , n ), device = hidden_states_device , dtype = torch .bfloat16
377+ )
378+ deep_gemm_wrapper .grouped_gemm_nt_f8f8bf16_masked (
379+ down_input_fp8 , self .w2_weight_fp8 , down_output , masked_m , expected_m
380+ )
381+ del down_input
382+ del down_input_fp8
383+
384+ # PostReorder
385+ output = torch .empty (
386+ hidden_states_shape , dtype = hidden_states_dtype , device = hidden_states_device
387+ )
388+ post_reorder_triton_kernel [(hidden_states_shape [0 ],)](
389+ down_output ,
390+ output ,
391+ src2dst ,
392+ topk_ids ,
393+ topk_weights ,
394+ self .start_expert_id ,
395+ self .end_expert_id ,
396+ self .top_k ,
397+ hidden_states_shape [1 ],
398+ m_max * self .start_expert_id ,
399+ BLOCK_SIZE = 512 ,
400+ )
401+ return output
402+
403+ def forward_normal (self , hidden_states : torch .Tensor , router_logits : torch .Tensor ):
404+ assert self .quant_method is not None
405+ hidden_states_shape = hidden_states .shape
406+ hidden_states_dtype = hidden_states .dtype
407+ hidden_states_device = hidden_states .device
237408 if self .grouped_gemm_runner is None :
238409 self .grouped_gemm_runner = GroupedGemmRunner (
239410 hidden_states .device ,
@@ -249,6 +420,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
249420 renormalize = self .renormalize ,
250421 topk_group = self .topk_group ,
251422 num_expert_group = self .num_expert_group ,
423+ num_fused_shared_experts = self .num_fused_shared_experts ,
252424 correction_bias = self .correction_bias ,
253425 custom_routing_function = self .custom_routing_function ,
254426 routed_scaling_factor = self .routed_scaling_factor ,
@@ -440,6 +612,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
440612 self .end_expert_id ,
441613 self .top_k ,
442614 hidden_states_shape [1 ],
615+ 0 ,
443616 BLOCK_SIZE = 512 ,
444617 )
445618 return output
0 commit comments