@@ -42,9 +42,9 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
     TensorView gemm1_weights, TensorView output1_scales_scalar,
     TensorView output1_scales_gate_scalar, TensorView gemm2_weights,
     TensorView output2_scales_scalar, TensorView output, int64_t const num_experts,
-    int64_t const top_k, int64_t const n_group, int64_t const topk_group,
+    int64_t const top_k, Optional<int64_t> const n_group, Optional<int64_t> const topk_group,
     int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, double const routed_scaling_factor,
+    int64_t const local_num_experts, Optional<double> const routed_scaling_factor,
     bool const use_routing_scales_on_input, int64_t const tile_tokens_dim,
     int64_t const routing_method_type, bool enable_pdl) {
   static const std::tuple<int, int> device_props = [hidden_states] {
@@ -62,8 +62,11 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(

   if (use_routing_scales_on_input) {
     TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
-  } else {
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+             RoutingMethodType::DeepSeekV3) {
     TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float.";
+  } else {
+    TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
   }
   TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D.";
   TVM_FFI_ICHECK_EQ(routing_logits->shape[1], num_experts) << "routing_logits has incorrect shape.";
@@ -74,17 +77,32 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
         << "routing_bias has incorrect shape.";
   }

-  if (n_group <= 0 || topk_group <= 0) {
-    TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups) only supports top_k=1.";
-  } else {
-    TVM_FFI_ICHECK_LE(top_k, 8) << "Current routing kernel (with groups) only supports top_k<=8.";
-    TVM_FFI_ICHECK_LE(topk_group, 4)
-        << "Current routing kernel (with groups) only supports topk_group<=4.";
-    TVM_FFI_ICHECK_LE(topk_group, n_group) << "n_group must not be smaller than topk_group.";
-    TVM_FFI_ICHECK_EQ(num_experts % n_group, 0) << "num_experts must be divisible by n_group";
+  if (n_group.has_value() && n_group.value() != 0) {
+    TVM_FFI_ICHECK(static_cast<RoutingMethodType>(routing_method_type) ==
+                   RoutingMethodType::DeepSeekV3)
+        << "Routing kernel with groups implies DeepSeekV3 routing method.";
+    TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given";
+    TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0)
+        << "num_experts must be divisible by n_group";
+    TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
+        << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
+    TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0)
+        << "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0.";
+    TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value())
+        << "n_group must not be smaller than topk_group.";
     // This check ensures we have enough experts in the selected groups to handle the top_k routing
-    TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group))
+    TVM_FFI_ICHECK_LT(top_k, (topk_group.value() * num_experts / n_group.value()))
         << "top_k must be less than total number of experts in selected groups";
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+                 RoutingMethodType::Renormalize ||
+             static_cast<RoutingMethodType>(routing_method_type) ==
+                 RoutingMethodType::RenormalizeNaive) {
+    TVM_FFI_LOG_AND_THROW(NotImplementedError)
+        << "Don't support routing method type Renormalize(Naive).";
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+             RoutingMethodType::Llama4) {
+    TVM_FFI_ICHECK_EQ(top_k, 1)
+        << "Current routing kernel (no groups, Llama4) only supports top_k=1.";
   }
   TVM_FFI_ICHECK_EQ(num_experts % 4, 0)
       << "Routing kernel expects that num_experts must be divisible by 4";
@@ -122,11 +140,11 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
   args.hidden_size = hidden_states->shape[1];
   args.hidden_size_output = args.hidden_size;
   args.top_k = top_k;
-  args.n_group = n_group;
-  args.topk_group = topk_group;
+  args.n_group = n_group.has_value() ? n_group.value() : 0;
+  args.topk_group = topk_group.has_value() ? topk_group.value() : 0;
   args.local_expert_offset = local_expert_offset;
   args.local_num_experts = local_num_experts;
-  args.routed_scaling_factor = routed_scaling_factor;
+  args.routed_scaling_factor = routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
   args.intermediate_size = intermediate_size;
   args.mUseRoutingScalesOnInput = use_routing_scales_on_input;

@@ -282,8 +300,8 @@ void trtllm_fp8_per_tensor_scale_moe(
     TensorView gemm1_weights, TensorView output1_scales_scalar,
     TensorView output1_scales_gate_scalar, TensorView gemm2_weights,
     TensorView output2_scales_scalar, TensorView output, int64_t num_experts, int64_t top_k,
-    int64_t n_group, int64_t topk_group, int64_t intermediate_size, int64_t local_expert_offset,
-    int64_t local_num_experts, double routed_scaling_factor, bool use_routing_scales_on_input,
+    Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size, int64_t local_expert_offset,
+    int64_t local_num_experts, Optional<double> routed_scaling_factor, bool use_routing_scales_on_input,
     int64_t tile_tokens_dim, int64_t routing_method_type, bool enable_pdl) {
   auto dtype = hidden_states->dtype;
   if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) {
@@ -302,9 +320,9 @@ void trtllm_fp8_block_scale_moe_launcher(
     TensorView routing_logits, Optional<TensorView> routing_bias, TensorView hidden_states,
     TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale,
     TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output,
-    int64_t const num_experts, int64_t const top_k, int64_t const n_group, int64_t const topk_group,
+    int64_t const num_experts, int64_t const top_k, Optional<int64_t> const n_group, Optional<int64_t> const topk_group,
     int64_t const intermediate_size, int64_t const local_expert_offset,
-    int64_t const local_num_experts, double const routed_scaling_factor,
+    int64_t const local_num_experts, Optional<double> const routed_scaling_factor,
     int64_t const tile_tokens_dim, int64_t const routing_method_type,
     tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex,
     bool enable_pdl) {
@@ -321,7 +339,11 @@ void trtllm_fp8_block_scale_moe_launcher(
       << "This kernel requires 10.x architecture. Current device has SM "
       << std::get<0>(device_props) << std::get<1>(device_props);

-  TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float.";
+  if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
+    TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float.";
+  } else {
+    TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
+  }
   TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D.";
   TVM_FFI_ICHECK_EQ(routing_logits->shape[0], hidden_states->shape[0])
       << "routing_logits and hidden_states must have the same number of tokens.";
@@ -336,17 +358,32 @@ void trtllm_fp8_block_scale_moe_launcher(
         << "routing_bias has incorrect shape.";
   }

-  if (n_group <= 0 || topk_group <= 0) {
-    TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups) only supports top_k=1.";
-  } else {
-    TVM_FFI_ICHECK_LE(top_k, 8) << "Current routing kernel (with groups) only supports top_k<=8.";
-    TVM_FFI_ICHECK_LE(topk_group, 4)
-        << "Current routing kernel (with groups) only supports topk_group<=4.";
-    TVM_FFI_ICHECK_LE(topk_group, n_group) << "n_group must not be smaller than topk_group.";
-    TVM_FFI_ICHECK_EQ(num_experts % n_group, 0) << "num_experts must be divisible by n_group";
+  if (n_group.has_value() && n_group.value() != 0) {
+    TVM_FFI_ICHECK(static_cast<RoutingMethodType>(routing_method_type) ==
+                   RoutingMethodType::DeepSeekV3)
+        << "Routing kernel with groups implies DeepSeekV3 routing method.";
+    TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given";
+    TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0)
+        << "num_experts must be divisible by n_group";
+    TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
+        << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
+    TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0)
+        << "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0.";
+    TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value())
+        << "n_group must not be smaller than topk_group.";
     // This check ensures we have enough experts in the selected groups to handle the top_k routing
-    TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group))
+    TVM_FFI_ICHECK_LT(top_k, (topk_group.value() * num_experts / n_group.value()))
         << "top_k must be less than total number of experts in selected groups";
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+                 RoutingMethodType::Renormalize ||
+             static_cast<RoutingMethodType>(routing_method_type) ==
+                 RoutingMethodType::RenormalizeNaive) {
+    TVM_FFI_ICHECK(top_k <= 10 && top_k > 0)
+        << "Current routing kernel (no groups, renormalize) only supports top_k<=10 && top_k>0.";
+  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
+             RoutingMethodType::Llama4) {
+    TVM_FFI_ICHECK_EQ(top_k, 1)
+        << "Current routing kernel (no groups, Llama4) only supports top_k=1.";
   }
   TVM_FFI_ICHECK_EQ(num_experts % 4, 0)
       << "Routing kernel expects that num_experts must be divisible by 4";
@@ -383,11 +420,11 @@ void trtllm_fp8_block_scale_moe_launcher(
   args.hidden_size = hidden_states->shape[1];
   args.hidden_size_output = args.hidden_size;
   args.top_k = top_k;
-  args.n_group = n_group;
-  args.topk_group = topk_group;
+  args.n_group = n_group.has_value() ? n_group.value() : 0;
+  args.topk_group = topk_group.has_value() ? topk_group.value() : 0;
   args.local_expert_offset = local_expert_offset;
   args.local_num_experts = local_num_experts;
-  args.routed_scaling_factor = routed_scaling_factor;
+  args.routed_scaling_factor = routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
   args.intermediate_size = intermediate_size;
   args.mUseDeepSeekFp8 = true;

@@ -573,9 +610,9 @@ void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional<TensorView>
                                 TensorView gemm1_weights, TensorView gemm1_weights_scale,
                                 TensorView gemm2_weights, TensorView gemm2_weights_scale,
                                 TensorView output, int64_t num_experts, int64_t top_k,
-                                int64_t n_group, int64_t topk_group, int64_t intermediate_size,
+                                Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
                                 int64_t local_expert_offset, int64_t local_num_experts,
-                                double routed_scaling_factor, int64_t tile_tokens_dim,
+                                Optional<double> routed_scaling_factor, int64_t tile_tokens_dim,
                                 int64_t routing_method_type, bool use_shuffled_weight,
                                 int64_t weight_layout, bool enable_pdl) {
   auto dtype = hidden_states->dtype;
@@ -696,8 +733,8 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
     TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given";
     TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0)
         << "num_experts must be divisible by n_group";
-    TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
-        << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
+    // TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
+    //     << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
     TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0)
         << "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0.";
     TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value())
@@ -710,9 +747,9 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
              static_cast<RoutingMethodType>(routing_method_type) ==
                  RoutingMethodType::RenormalizeNaive ||
              static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::TopK) {
-    TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
-        << "Current routing kernel (no groups, renormalize/topk) only supports top_k<=8 && "
-           "top_k>0.";
+    // TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
+    //     << "Current routing kernel (no groups, renormalize/topk) only supports top_k<=8 && "
+    //        "top_k>0.";
   } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) {
     TVM_FFI_ICHECK_EQ(top_k, 1)
         << "Current routing kernel (no groups, Llama4) only supports top_k=1.";
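
For context, below is a minimal standalone sketch (not part of this commit) of how the now-optional routing parameters select the validation path that the launchers above enforce. The helper validate_routing, the reduced RoutingMethodType enum, and the use of std::optional plus assert in place of tvm::ffi::Optional and the TVM_FFI_ICHECK_* macros are assumptions for illustration only. It also does not model the launcher-specific differences: the fp8 per-tensor launcher rejects Renormalize(Naive) outright, while the fp8 block-scale launcher allows top_k<=10 for that method.

// Hypothetical sketch; validate_routing and this reduced RoutingMethodType enum do not
// exist in the library. std::optional and assert stand in for tvm::ffi::Optional and the
// TVM_FFI_ICHECK_* macros used by the real launchers.
#include <cassert>
#include <cstdint>
#include <optional>

enum class RoutingMethodType { DeepSeekV3, Renormalize, RenormalizeNaive, Llama4 };

void validate_routing(int64_t num_experts, int64_t top_k, std::optional<int64_t> n_group,
                      std::optional<int64_t> topk_group, RoutingMethodType method) {
  if (n_group.has_value() && *n_group != 0) {
    // Grouped routing is only meaningful for DeepSeekV3 and requires topk_group as well.
    assert(method == RoutingMethodType::DeepSeekV3);
    assert(topk_group.has_value());
    assert(num_experts % *n_group == 0);
    assert(top_k > 0 && top_k <= 8);
    assert(*topk_group > 0 && *topk_group <= 4 && *topk_group <= *n_group);
    // Enough experts must remain in the selected groups to satisfy top_k.
    assert(top_k < *topk_group * num_experts / *n_group);
  } else if (method == RoutingMethodType::Llama4) {
    assert(top_k == 1);  // Without groups, Llama4 routing supports only top_k=1.
  }
  assert(num_experts % 4 == 0);
}

int main() {
  // DeepSeekV3-style grouped routing: the optional group parameters carry values.
  validate_routing(/*num_experts=*/256, /*top_k=*/8, /*n_group=*/8, /*topk_group=*/4,
                   RoutingMethodType::DeepSeekV3);
  // Llama4-style routing: n_group and topk_group are simply left unset (std::nullopt),
  // and the launcher falls back to n_group=0, topk_group=0, routed_scaling_factor=1.0.
  validate_routing(/*num_experts=*/128, /*top_k=*/1, std::nullopt, std::nullopt,
                   RoutingMethodType::Llama4);
  return 0;
}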