
Commit 979b759

ChristinaZ and jiahanc authored

Update the routing for TRTLLMGEN to support kimi k2 and qwen (#1831)
## 📌 Description

Update the routing code to align with the TRTLLM implementation and add support for Kimi K2 and Qwen. Also revise the unit test based on the Kimi K2 config (https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/config.json).

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

## Release Notes

* **New Features**
  * MoE operations now support optional routing parameters with automatic defaults for streamlined model configuration.
* **Refactor**
  * Optimized expert kernel routing and buffer management for improved flexibility across multiple routing strategies.
  * Enhanced top-K result handling with unified buffer interfaces.

---

Signed-off-by: jiahanc <[email protected]>
Co-authored-by: jiahanc <[email protected]>
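As context for the revised unit test, these are the routing-relevant fields of a DeepSeek-V3-style config such as Kimi K2 (values here are illustrative; the linked `config.json` is authoritative):

```python
# Routing-relevant fields of a DeepSeek-V3-style config (illustrative values;
# consult the linked Kimi-K2-Instruct config.json for authoritative numbers).
kimi_k2_routing = {
    "n_routed_experts": 384,       # total routed experts
    "num_experts_per_tok": 8,      # top_k per token
    "n_group": 1,                  # expert groups for grouped top-k
    "topk_group": 1,               # groups selected per token
    "routed_scaling_factor": 2.8,  # scale applied to routed expert outputs
}
```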
1 parent ffcc5f4 · commit 979b759

File tree

11 files changed: +1036 -711 lines


The following large diffs are not rendered by default:

csrc/trtllm_fused_moe_kernel_launcher.cu (+87 -51)
csrc/trtllm_fused_moe_routing_deepseek.cu (+253 -165)
csrc/trtllm_fused_moe_routing_llama4.cu (+143 -84)
csrc/trtllm_fused_moe_routing_renormalize.cu (+239 -83)

csrc/trtllm_fused_moe_runner.cu (+6 -6)
```diff
@@ -70,12 +70,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
   routingData.mUsePdl = true;
 
   // output:
-  routingData.mPtrExpertIdx = routingExpertIndexes;
+  routingData.mPtrTopKPacked = routingExpertIndexes;
   routingData.mPtrExpertCounts = expertCountHistogram;
   routingData.mPtrPermutedIdxSize = permutedIdxSize;
   routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
   routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
-  routingData.mPtrExpertWeights = expertWeights;
+  routingData.mPtrTopKWeights = expertWeights;
 
   routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx;
   routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit;
@@ -107,12 +107,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
   routingData.mUsePdl = true;
 
   // output:
-  routingData.mPtrExpertIdx = routingExpertIndexes;
+  routingData.mPtrTopKPacked = routingExpertIndexes;
   routingData.mPtrExpertCounts = expertCountHistogram;
   routingData.mPtrPermutedIdxSize = permutedIdxSize;
   routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
   routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
-  routingData.mPtrExpertWeights = expertWeights;
+  routingData.mPtrTopKWeights = expertWeights;
 
   routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx;
   routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit;
@@ -149,12 +149,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
   //
   // Outputs
   //
-  routingData.mPtrExpertIdx = routingExpertIndexes;
+  routingData.mPtrTopKPacked = routingExpertIndexes;
   routingData.mPtrExpertCounts = expertCountHistogram;
   routingData.mPtrPermutedIdxSize = permutedIdxSize;
   routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
   routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
-  routingData.mPtrExpertWeights = expertWeights;
+  routingData.mPtrTopKWeights = expertWeights;
 
   //
   // Grouped Gemm Launch Config Buffers
```
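The rename from `mPtrExpertIdx`/`mPtrExpertWeights` to `mPtrTopKPacked`/`mPtrTopKWeights` reflects the unified top-K buffer interface called out in the release notes. As a rough mental model only (a PyTorch sketch with a hypothetical `route_topk` helper, not the kernel's actual packed memory layout), the routing stage produces per-token top-K expert indices and weights, which these buffers carry into the permute step:

```python
import torch

def route_topk(routing_logits: torch.Tensor, top_k: int):
    """Per-token top-k expert selection: returns (weights, indices).

    routing_logits: [num_tokens, num_experts]
    """
    scores = torch.softmax(routing_logits.float(), dim=-1)
    top_k_weights, top_k_indices = torch.topk(scores, top_k, dim=-1)
    return top_k_weights, top_k_indices

logits = torch.randn(4, 384)         # 4 tokens, 384 experts (Kimi K2-scale)
weights, indices = route_topk(logits, top_k=8)
print(weights.shape, indices.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```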

flashinfer/fused_moe/core.py (+29 -24)
```diff
@@ -184,7 +184,9 @@ def _maybe_get_cached_w3_w1_permute_indices(
     epilogue_tile_m: int,
     num_elts_per_sf: Union[None, int] = None,
 ) -> torch.Tensor:
-    if dst_w3_w1_weight.shape not in _cache_permute_indices:
+    # Create a unique cache key (weight_type, weight_shape)
+    cache_key = ("w3_w1", dst_w3_w1_weight.shape)
+    if cache_key not in _cache_permute_indices:
         # Get permute indices and chain them together
         permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(dst_w3_w1_weight)
         if num_elts_per_sf is None:
@@ -198,10 +200,10 @@ def _maybe_get_cached_w3_w1_permute_indices(
             num_elts_per_sf=num_elts_per_sf,
         )
         # Memoize permute indices as recompute is **very** costly
-        _cache_permute_indices[dst_w3_w1_weight.shape] = permute0[permute1].to(
+        _cache_permute_indices[cache_key] = permute0[permute1].to(
             dst_w3_w1_weight.device
         )
-    permute_indices = _cache_permute_indices[dst_w3_w1_weight.shape]
+    permute_indices = _cache_permute_indices[cache_key]
     return permute_indices
 
 
@@ -211,7 +213,9 @@ def get_w2_permute_indices_with_cache(
     epilogue_tile_m: int,
     num_elts_per_sf: Union[None, int] = None,
 ) -> torch.Tensor:
-    if dst_w2_weight.shape not in _cache_permute_indices:
+    # Create a unique cache key (weight_type, weight_shape)
+    cache_key = ("w2", dst_w2_weight.shape)
+    if cache_key not in _cache_permute_indices:
         if num_elts_per_sf is None:
             permute_indices = get_shuffle_matrix_a_row_indices(
                 dst_w2_weight, epilogue_tile_m
@@ -223,8 +227,8 @@ def get_w2_permute_indices_with_cache(
                 num_elts_per_sf=num_elts_per_sf,
             ).to(dst_w2_weight.device)
         # Memoize permute indices as recompute is **very** costly
-        _cache_permute_indices[dst_w2_weight.shape] = permute_indices
-    permute_indices = _cache_permute_indices[dst_w2_weight.shape]
+        _cache_permute_indices[cache_key] = permute_indices
+    permute_indices = _cache_permute_indices[cache_key]
     return permute_indices
 
 
```
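The compound keys fix a latent collision in the memoization: a `w3_w1` weight and a `w2` weight can share the same tensor shape, in which case a shape-only key would return permute indices computed for the other weight kind. A minimal standalone sketch of the pattern (`expensive_compute` is a hypothetical stand-in for the costly index computation):

```python
_cache: dict = {}

def expensive_compute(kind: str, shape: tuple) -> tuple:
    # Stand-in for the costly permute-index computation.
    return (kind, shape)

def get_permute_indices(kind: str, shape: tuple) -> tuple:
    cache_key = (kind, shape)  # e.g. ("w3_w1", (512, 256)) vs ("w2", (512, 256))
    if cache_key not in _cache:
        _cache[cache_key] = expensive_compute(kind, shape)
    return _cache[cache_key]

# Same shape, different weight kinds: the entries no longer collide.
assert get_permute_indices("w3_w1", (512, 256)) != get_permute_indices("w2", (512, 256))
```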

```diff
@@ -1097,12 +1101,12 @@ def trtllm_fp8_per_tensor_scale_moe_op(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
@@ -1151,12 +1155,12 @@ def _fake_trtllm_fp8_per_tensor_scale_moe(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
@@ -1183,12 +1187,12 @@ def trtllm_fp8_block_scale_moe_op(
     output: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     tile_tokens_dim: int,
     routing_method_type: int,
     use_shuffled_weight: bool = False,
@@ -1197,6 +1201,7 @@ def trtllm_fp8_block_scale_moe_op(
 ) -> torch.Tensor:
     if enable_pdl is None:
         enable_pdl = device_support_pdl(hidden_states.device)
+
     # Call the C++ function for block scale MoE
     moe_op.trtllm_fp8_block_scale_moe(
         routing_logits,
```
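`device_support_pdl` gates Programmatic Dependent Launch (PDL) on hardware capability. A plausible sketch of such a check (an assumption for illustration, not necessarily the library's implementation; PDL is a Hopper-class, SM90+, feature):

```python
import torch

def device_support_pdl(device: torch.device) -> bool:
    # Programmatic Dependent Launch requires compute capability 9.0 or newer.
    if device.type != "cuda":
        return False
    major, _ = torch.cuda.get_device_capability(device)
    return major >= 9
```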
```diff
@@ -1238,12 +1243,12 @@ def _fake_trtllm_fp8_block_scale_moe(
     output: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
@@ -1503,12 +1508,12 @@ def trtllm_fp8_per_tensor_scale_moe(
     output2_scales_scalar: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
@@ -1576,12 +1581,12 @@ def trtllm_fp8_block_scale_moe(
     gemm2_weights_scale: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routed_scaling_factor: float,
+    routed_scaling_factor: Optional[float],
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
```
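With `n_group`, `topk_group`, and `routed_scaling_factor` now `Optional`, routing methods that don't group experts (e.g., Qwen-style renormalize routing) can pass `None`, while DeepSeek-style grouped routing (Kimi K2) still supplies explicit values. A hedged sketch of the "automatic defaults" behavior the release notes describe (a hypothetical helper, not the library's code):

```python
from typing import Optional, Tuple

def resolve_routing_params(
    n_group: Optional[int],
    topk_group: Optional[int],
    routed_scaling_factor: Optional[float],
) -> Tuple[int, int, float]:
    # Hypothetical defaulting: treat missing values as "no grouping, no scaling".
    return (
        n_group if n_group is not None else 1,
        topk_group if topk_group is not None else 1,
        routed_scaling_factor if routed_scaling_factor is not None else 1.0,
    )

print(resolve_routing_params(1, 1, 2.5))         # DeepSeek-style explicit values
print(resolve_routing_params(None, None, None))  # renormalize-style defaults
```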
