Skip to content

Commit 1a94ecb

Browse files
committed
Update the routing for TRTLLMGEN to support Kimi K2 and Qwen
Remove the cudaDeviceSync; add support for num_experts=16; clean up unit tests. Signed-off-by: Christina Zhang <[email protected]> Signed-off-by: jiahanc <[email protected]>
1 parent d728bcd commit 1a94ecb

12 files changed

+1028
-535
lines changed

csrc/trtllm_batched_gemm_runner.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ void TrtllmGenBatchedGemmRunner::run(
169169
auto const configs = bmm.getBatchedGemmConfigs();
170170

171171
auto const& config = configs[configIndex];
172-
172+
std::cout << "config.mFunctionName: " << config.mFunctionName << std::endl;
173173
FLASHINFER_CHECK(numBatches > 0, "Batched GEMM requires numBatches > 0");
174174
if (!mOptions.staticBatch) {
175175
FLASHINFER_CHECK(totalNumPaddedTokens,

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 76 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
4242
TensorView gemm1_weights, TensorView output1_scales_scalar,
4343
TensorView output1_scales_gate_scalar, TensorView gemm2_weights,
4444
TensorView output2_scales_scalar, TensorView output, int64_t const num_experts,
45-
int64_t const top_k, int64_t const n_group, int64_t const topk_group,
45+
int64_t const top_k, Optional<int64_t> const n_group, Optional<int64_t> const topk_group,
4646
int64_t const intermediate_size, int64_t const local_expert_offset,
47-
int64_t const local_num_experts, double const routed_scaling_factor,
47+
int64_t const local_num_experts, Optional<double> const routed_scaling_factor,
4848
bool const use_routing_scales_on_input, int64_t const tile_tokens_dim,
4949
int64_t const routing_method_type, bool enable_pdl) {
5050
static const std::tuple<int, int> device_props = [hidden_states] {
@@ -62,8 +62,11 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
6262

6363
if (use_routing_scales_on_input) {
6464
TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
65-
} else {
65+
} else if (static_cast<RoutingMethodType>(routing_method_type) ==
66+
RoutingMethodType::DeepSeekV3) {
6667
TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float.";
68+
} else {
69+
TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
6770
}
6871
TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D.";
6972
TVM_FFI_ICHECK_EQ(routing_logits->shape[1], num_experts) << "routing_logits has incorrect shape.";
@@ -74,17 +77,32 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
7477
<< "routing_bias has incorrect shape.";
7578
}
7679

77-
if (n_group <= 0 || topk_group <= 0) {
78-
TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups) only supports top_k=1.";
79-
} else {
80-
TVM_FFI_ICHECK_LE(top_k, 8) << "Current routing kernel (with groups) only supports top_k<=8.";
81-
TVM_FFI_ICHECK_LE(topk_group, 4)
82-
<< "Current routing kernel (with groups) only supports topk_group<=4.";
83-
TVM_FFI_ICHECK_LE(topk_group, n_group) << "n_group must not be smaller than topk_group.";
84-
TVM_FFI_ICHECK_EQ(num_experts % n_group, 0) << "num_experts must be divisible by n_group";
80+
if (n_group.has_value() && n_group.value() != 0) {
81+
TVM_FFI_ICHECK(static_cast<RoutingMethodType>(routing_method_type) ==
82+
RoutingMethodType::DeepSeekV3)
83+
<< "Routing kernel with groups implies DeepSeekV3 routing method.";
84+
TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given";
85+
TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0)
86+
<< "num_experts must be divisible by n_group";
87+
TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
88+
<< "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
89+
TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0)
90+
<< "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0.";
91+
TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value())
92+
<< "n_group must not be smaller than topk_group.";
8593
// This check ensures we have enough experts in the selected groups to handle the top_k routing
86-
TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group))
94+
TVM_FFI_ICHECK_LT(top_k, (topk_group.value() * num_experts / n_group.value()))
8795
<< "top_k must be less than total number of experts in selected groups";
96+
} else if (static_cast<RoutingMethodType>(routing_method_type) ==
97+
RoutingMethodType::Renormalize ||
98+
static_cast<RoutingMethodType>(routing_method_type) ==
99+
RoutingMethodType::RenormalizeNaive) {
100+
TVM_FFI_LOG_AND_THROW(NotImplementedError)
101+
<< "Don't support routing method type Renormalize(Naive).";
102+
} else if (static_cast<RoutingMethodType>(routing_method_type) ==
103+
RoutingMethodType::Llama4) {
104+
TVM_FFI_ICHECK_EQ(top_k, 1)
105+
<< "Current routing kernel (no groups, Llama4) only supports top_k=1.";
88106
}
89107
TVM_FFI_ICHECK_EQ(num_experts % 4, 0)
90108
<< "Routing kernel expects that num_experts must be divisible by 4";
@@ -122,11 +140,11 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
122140
args.hidden_size = hidden_states->shape[1];
123141
args.hidden_size_output = args.hidden_size;
124142
args.top_k = top_k;
125-
args.n_group = n_group;
126-
args.topk_group = topk_group;
143+
args.n_group = n_group.has_value() ? n_group.value() : 0;
144+
args.topk_group = topk_group.has_value() ? topk_group.value() : 0;
127145
args.local_expert_offset = local_expert_offset;
128146
args.local_num_experts = local_num_experts;
129-
args.routed_scaling_factor = routed_scaling_factor;
147+
args.routed_scaling_factor = routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
130148
args.intermediate_size = intermediate_size;
131149
args.mUseRoutingScalesOnInput = use_routing_scales_on_input;
132150

@@ -282,8 +300,8 @@ void trtllm_fp8_per_tensor_scale_moe(
282300
TensorView gemm1_weights, TensorView output1_scales_scalar,
283301
TensorView output1_scales_gate_scalar, TensorView gemm2_weights,
284302
TensorView output2_scales_scalar, TensorView output, int64_t num_experts, int64_t top_k,
285-
int64_t n_group, int64_t topk_group, int64_t intermediate_size, int64_t local_expert_offset,
286-
int64_t local_num_experts, double routed_scaling_factor, bool use_routing_scales_on_input,
303+
Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size, int64_t local_expert_offset,
304+
int64_t local_num_experts, Optional<double> routed_scaling_factor, bool use_routing_scales_on_input,
287305
int64_t tile_tokens_dim, int64_t routing_method_type, bool enable_pdl) {
288306
auto dtype = hidden_states->dtype;
289307
if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) {
@@ -302,9 +320,9 @@ void trtllm_fp8_block_scale_moe_launcher(
302320
TensorView routing_logits, Optional<TensorView> routing_bias, TensorView hidden_states,
303321
TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale,
304322
TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output,
305-
int64_t const num_experts, int64_t const top_k, int64_t const n_group, int64_t const topk_group,
323+
int64_t const num_experts, int64_t const top_k, Optional<int64_t> const n_group, Optional<int64_t> const topk_group,
306324
int64_t const intermediate_size, int64_t const local_expert_offset,
307-
int64_t const local_num_experts, double const routed_scaling_factor,
325+
int64_t const local_num_experts, Optional<double> const routed_scaling_factor,
308326
int64_t const tile_tokens_dim, int64_t const routing_method_type,
309327
tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex,
310328
bool enable_pdl) {
@@ -321,7 +339,11 @@ void trtllm_fp8_block_scale_moe_launcher(
321339
<< "This kernel requires 10.x architecture. Current device has SM "
322340
<< std::get<0>(device_props) << std::get<1>(device_props);
323341

324-
TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float.";
342+
if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
343+
TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float.";
344+
} else {
345+
TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16.";
346+
}
325347
TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D.";
326348
TVM_FFI_ICHECK_EQ(routing_logits->shape[0], hidden_states->shape[0])
327349
<< "routing_logits and hidden_states must have the same number of tokens.";
@@ -336,17 +358,32 @@ void trtllm_fp8_block_scale_moe_launcher(
336358
<< "routing_bias has incorrect shape.";
337359
}
338360

339-
if (n_group <= 0 || topk_group <= 0) {
340-
TVM_FFI_ICHECK_EQ(top_k, 1) << "Current routing kernel (no groups) only supports top_k=1.";
341-
} else {
342-
TVM_FFI_ICHECK_LE(top_k, 8) << "Current routing kernel (with groups) only supports top_k<=8.";
343-
TVM_FFI_ICHECK_LE(topk_group, 4)
344-
<< "Current routing kernel (with groups) only supports topk_group<=4.";
345-
TVM_FFI_ICHECK_LE(topk_group, n_group) << "n_group must not be smaller than topk_group.";
346-
TVM_FFI_ICHECK_EQ(num_experts % n_group, 0) << "num_experts must be divisible by n_group";
361+
if (n_group.has_value() && n_group.value() != 0) {
362+
TVM_FFI_ICHECK(static_cast<RoutingMethodType>(routing_method_type) ==
363+
RoutingMethodType::DeepSeekV3)
364+
<< "Routing kernel with groups implies DeepSeekV3 routing method.";
365+
TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given";
366+
TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0)
367+
<< "num_experts must be divisible by n_group";
368+
TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
369+
<< "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
370+
TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0)
371+
<< "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0.";
372+
TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value())
373+
<< "n_group must not be smaller than topk_group.";
347374
// This check ensures we have enough experts in the selected groups to handle the top_k routing
348-
TVM_FFI_ICHECK_LT(top_k, (topk_group * num_experts / n_group))
375+
TVM_FFI_ICHECK_LT(top_k, (topk_group.value() * num_experts / n_group.value()))
349376
<< "top_k must be less than total number of experts in selected groups";
377+
} else if (static_cast<RoutingMethodType>(routing_method_type) ==
378+
RoutingMethodType::Renormalize ||
379+
static_cast<RoutingMethodType>(routing_method_type) ==
380+
RoutingMethodType::RenormalizeNaive) {
381+
TVM_FFI_ICHECK(top_k <= 10 && top_k > 0)
382+
<< "Current routing kernel (no groups, renormalize) only supports top_k<=10 && top_k>0.";
383+
} else if (static_cast<RoutingMethodType>(routing_method_type) ==
384+
RoutingMethodType::Llama4) {
385+
TVM_FFI_ICHECK_EQ(top_k, 1)
386+
<< "Current routing kernel (no groups, Llama4) only supports top_k=1.";
350387
}
351388
TVM_FFI_ICHECK_EQ(num_experts % 4, 0)
352389
<< "Routing kernel expects that num_experts must be divisible by 4";
@@ -383,11 +420,11 @@ void trtllm_fp8_block_scale_moe_launcher(
383420
args.hidden_size = hidden_states->shape[1];
384421
args.hidden_size_output = args.hidden_size;
385422
args.top_k = top_k;
386-
args.n_group = n_group;
387-
args.topk_group = topk_group;
423+
args.n_group = n_group.has_value() ? n_group.value() : 0;
424+
args.topk_group = topk_group.has_value() ? topk_group.value() : 0;
388425
args.local_expert_offset = local_expert_offset;
389426
args.local_num_experts = local_num_experts;
390-
args.routed_scaling_factor = routed_scaling_factor;
427+
args.routed_scaling_factor = routed_scaling_factor.has_value() ? routed_scaling_factor.value() : 1.0;
391428
args.intermediate_size = intermediate_size;
392429
args.mUseDeepSeekFp8 = true;
393430

@@ -573,9 +610,9 @@ void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional<TensorView>
573610
TensorView gemm1_weights, TensorView gemm1_weights_scale,
574611
TensorView gemm2_weights, TensorView gemm2_weights_scale,
575612
TensorView output, int64_t num_experts, int64_t top_k,
576-
int64_t n_group, int64_t topk_group, int64_t intermediate_size,
613+
Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
577614
int64_t local_expert_offset, int64_t local_num_experts,
578-
double routed_scaling_factor, int64_t tile_tokens_dim,
615+
Optional<double> routed_scaling_factor, int64_t tile_tokens_dim,
579616
int64_t routing_method_type, bool use_shuffled_weight,
580617
int64_t weight_layout, bool enable_pdl) {
581618
auto dtype = hidden_states->dtype;
@@ -696,8 +733,8 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
696733
TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given";
697734
TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0)
698735
<< "num_experts must be divisible by n_group";
699-
TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
700-
<< "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
736+
// TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
737+
// << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
701738
TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0)
702739
<< "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0.";
703740
TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value())
@@ -710,9 +747,9 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
710747
static_cast<RoutingMethodType>(routing_method_type) ==
711748
RoutingMethodType::RenormalizeNaive ||
712749
static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::TopK) {
713-
TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
714-
<< "Current routing kernel (no groups, renormalize/topk) only supports top_k<=8 && "
715-
"top_k>0.";
750+
// TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
751+
// << "Current routing kernel (no groups, renormalize/topk) only supports top_k<=8 && "
752+
// "top_k>0.";
716753
} else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) {
717754
TVM_FFI_ICHECK_EQ(top_k, 1)
718755
<< "Current routing kernel (no groups, Llama4) only supports top_k=1.";

0 commit comments

Comments (0)