feat: Expose unpacked topk weights for routed moe (fp4) #2425

aleozlx wants to merge 21 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

Adds a `RoutingInputMode` enum and propagates multi-mode routing support end-to-end: introduces packed/unpacked precomputed pathways, replaces `expert_*` buffers with `topk_ids`/`topk_weights` across CUDA launchers, Runner APIs, and Python bindings, and updates tests to validate both routing formats.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Python as Python API
    participant Launcher as FP4BlockScaleLauncher
    participant Runner as Routing::Runner
    participant CUDA as CUDA Kernels
    Python->>Launcher: call(routing_input_mode, routing logits/topk_ids/topk_weights)
    Launcher->>Runner: prepare routingData (select buffers by mode)
    alt FromLogits
        Runner->>CUDA: compute topk_ids/topk_weights from logits
    else PackedPrecomputed
        Runner->>CUDA: unpack packed topk (id+score) or use packed ids -> produce topk_weights
    else UnpackedPrecomputed
        Runner->>CUDA: consume provided topk_ids and topk_weights
    end
    CUDA-->>Runner: routing/topk results
    Runner-->>Launcher: routing metadata
    Launcher-->>Python: return outputs (including topk_weights)
```
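The mode-dependent buffer handling in the diagram can be sketched in plain Python. This is illustrative only: the enum values 0/1/2 match the encoding described in this PR, but `select_routing_buffers` is a hypothetical helper, not the actual launcher API.

```python
from enum import IntEnum


class RoutingInputMode(IntEnum):
    FromLogits = 0           # kernel computes routing from logits
    PackedPrecomputed = 1    # kernel unpacks precomputed (id+score) entries
    UnpackedPrecomputed = 2  # caller supplies topk_ids and topk_weights directly


def select_routing_buffers(mode, topk_ids=None, topk_weights=None):
    """Mirror the launcher's buffer selection: in the first two modes the
    kernel writes topk_weights as an OUTPUT; in the last, both are INPUTS."""
    mode = RoutingInputMode(mode)  # raises ValueError on invalid values
    if mode is RoutingInputMode.UnpackedPrecomputed:
        if topk_ids is None or topk_weights is None:
            raise ValueError("UnpackedPrecomputed requires topk_ids and topk_weights")
        return topk_ids, topk_weights  # consumed directly by the kernel
    # FromLogits / PackedPrecomputed: ids derived by the kernel,
    # weights buffer filled as an output
    return None, topk_weights
```

Note that constructing `RoutingInputMode(mode)` from an `IntEnum` gives input validation for free, which is the same concern the review comments below this diagram raise for the C++ side.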
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

Hello @aleozlx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the flexibility of the fused Mixture of Experts (MoE) kernel by enabling it to process pre-computed routing decisions. This allows for scenarios where expert selection and weighting are determined externally, rather than relying solely on the kernel's internal routing-logits computation. The change facilitates integration with advanced routing strategies and optimizes the MoE execution pipeline by providing direct control over expert assignments.
Code Review
The pull request introduces support for pre-computed routing weights and expert IDs in the FP4 block scale MoE kernels. The changes are consistently applied across the C++ and Python code, updating function signatures, call sites, and internal logic to accommodate the new parameters. The accompanying documentation updates in the Python file are clear and helpful. The implementation appears to be functionally sound and well-integrated.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
1783-1799: ⚠️ Potential issue | 🟡 Minor: Validate the `routing_input_mode` parameter.

The `routing_input_mode` parameter is cast directly to the `RoutingInputMode` enum at line 1895 without validation. Invalid values (not 0, 1, or 2) would cause undefined behavior in the launcher's switch statement.

🛡️ Proposed fix

```diff
 Array<Tensor> trtllm_fp4_block_scale_moe(
     int64_t routing_input_mode, Optional<TensorView> routing_logits,
     TensorView topk_ids, TensorView topk_weights,
     Optional<TensorView> routing_bias, TensorView hidden_states,
     // ... remaining parameters ...
     TensorView output, Array<int64_t> config_index) {
+  // Validate routing_input_mode
+  TVM_FFI_ICHECK(routing_input_mode >= 0 && routing_input_mode <= 2)
+      << "routing_input_mode must be 0 (FromLogits), 1 (PackedPrecomputed), or 2 (UnpackedPrecomputed).";
   // Determine data types based on input format
```
🤖 Fix all issues with AI agents
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1445-1463: The switch over routing_input_mode_ in the
trtllm_fp4_block_scale_moe kernel launcher can leave expert_ids_param and
expert_weights_param uninitialized for invalid enum values; add a default case
to the switch that sets expert_ids_param = nullptr and expert_weights_param =
nullptr, emits an error (or uses an assertion) indicating an invalid
RoutingInputMode, and returns or aborts early from trtllm_fp4_block_scale_moe to
avoid undefined behavior; reference the switch on routing_input_mode_, and the
variables expert_ids_param and expert_weights_param when making the change.
```cpp
switch (routing_input_mode_) {
  case RoutingInputMode::FromLogits:
    // Mode 1: Kernel computes routing, writes weights to expert_weights_param (OUTPUT)
    expert_ids_param = nullptr;
    expert_weights_param = topk_weights.data_ptr();
    break;

  case RoutingInputMode::PackedPrecomputed:
    // Mode 2: Kernel unpacks from topk_ids, writes weights to expert_weights_param (OUTPUT)
    expert_ids_param = nullptr;
    expert_weights_param = topk_weights.data_ptr();
    break;

  case RoutingInputMode::UnpackedPrecomputed:
    // Mode 3: Both are INPUTS, kernel uses them directly
    expert_ids_param = static_cast<int32_t*>(topk_ids.data_ptr());
    expert_weights_param = topk_weights.data_ptr();
    break;
}
```
Add a default case to handle invalid routing modes.
The switch statement lacks a default case. If routing_input_mode_ contains an invalid value (e.g., due to an unchecked cast from the int64_t parameter in trtllm_fp4_block_scale_moe), this results in undefined behavior with uninitialized expert_ids_param and expert_weights_param.
🛡️ Proposed fix

```diff
   case RoutingInputMode::UnpackedPrecomputed:
     // Mode 3: Both are INPUTS, kernel uses them directly
     expert_ids_param = static_cast<int32_t*>(topk_ids.data_ptr());
     expert_weights_param = topk_weights.data_ptr();
     break;
+
+  default:
+    TVM_FFI_LOG_AND_THROW(ValueError)
+        << "Invalid routing_input_mode: " << static_cast<int>(routing_input_mode_);
 }
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
switch (routing_input_mode_) {
  case RoutingInputMode::FromLogits:
    // Mode 1: Kernel computes routing, writes weights to expert_weights_param (OUTPUT)
    expert_ids_param = nullptr;
    expert_weights_param = topk_weights.data_ptr();
    break;
  case RoutingInputMode::PackedPrecomputed:
    // Mode 2: Kernel unpacks from topk_ids, writes weights to expert_weights_param (OUTPUT)
    expert_ids_param = nullptr;
    expert_weights_param = topk_weights.data_ptr();
    break;
  case RoutingInputMode::UnpackedPrecomputed:
    // Mode 3: Both are INPUTS, kernel uses them directly
    expert_ids_param = static_cast<int32_t*>(topk_ids.data_ptr());
    expert_weights_param = topk_weights.data_ptr();
    break;
  default:
    TVM_FFI_LOG_AND_THROW(ValueError)
        << "Invalid routing_input_mode: " << static_cast<int>(routing_input_mode_);
}
```
@claude hello

Claude encountered an error (View job). I'll analyze this and get back to you.

@claude can you make a similar PR but for fp8 routed_moe?

Claude encountered an error (View job). I'll analyze this and get back to you.

That last fix (dafc67d) passed all tests.

@flashinfer-bot run

/bot run

[FAILED] Pipeline #43153719: 6/20 passed

/bot run

Posting the latest test result again for good measure.

[FAILED] Pipeline #43845118: 14/20 passed
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/fused_moe/core.py (2)
1993-2026: ⚠️ Potential issue | 🟡 Minor: Suppress ARG001 in the fake op signature.

Ruff flags unused args in `_fake_trtllm_fp4_block_scale_moe`. A single `# noqa: ARG001` on the `def` line keeps the stub clean without renaming parameters.

🧹 Suggested fix

```diff
-def _fake_trtllm_fp4_block_scale_moe(
+def _fake_trtllm_fp4_block_scale_moe(  # noqa: ARG001
     routing_input_mode: int,
     routing_logits: torch.Tensor,
     topk_ids: Optional[torch.Tensor],
     topk_weights: Optional[torch.Tensor],
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/core.py` around lines 1993 - 2026, Ruff flags unused arguments in the stub for _fake_trtllm_fp4_block_scale_moe; add a "# noqa: ARG001" comment to the function definition line of _fake_trtllm_fp4_block_scale_moe to suppress the unused-argument warning instead of renaming parameters so the signature stays intact and the linter is satisfied.
1055-1104: ⚠️ Potential issue | 🟡 Minor: Silence unused `topk_weights` unpacking to keep Ruff clean.

`topk_weights` is unpacked but unused in both `get_valid_tactics` and `forward`, triggering RUF059. Rename it to `_topk_weights` (or `_`) in both spots.

🧹 Suggested fix

```diff
         (
             output,
             routing_logits,
             topk_ids,
-            topk_weights,
+            _topk_weights,
             hidden_states,
             *extra_inputs,
         ) = inputs
```

(apply the same rename in both `get_valid_tactics` and `forward`)
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/core.py` around lines 1055 - 1104, In both MoERunner.get_valid_tactics and MoERunner.forward, rename the unpacked variable topk_weights to a throwaway name (e.g., _topk_weights or _) where inputs are destructured—specifically in the tuple unpacking lines inside get_valid_tactics and forward that currently read "(output, routing_logits, topk_ids, topk_weights, hidden_states, *extra_inputs,)"—so the unused value is ignored and the RUF059 lint warning is resolved.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 1842-1861: When routing_input_mode ==
RoutingInputMode.UnpackedPrecomputed ensure user-provided topk_weights (and
topk_ids) are validated before passing to the kernel: check topk_weights.ndim ==
2 and topk_weights.shape == (num_tokens, top_k) and topk_weights.dtype ==
routing_dtype (and check topk_ids.ndim == 2, topk_ids.shape == (num_tokens,
top_k), topk_ids.dtype == torch.int32), raising a clear AssertionError if
mismatched; apply the identical shape/dtype checks in
trtllm_fp4_block_scale_routed_moe where the packed tuple is unpacked so both
call sites enforce the same validation.
---
Outside diff comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 1993-2026: Ruff flags unused arguments in the stub for
_fake_trtllm_fp4_block_scale_moe; add a "# noqa: ARG001" comment to the function
definition line of _fake_trtllm_fp4_block_scale_moe to suppress the
unused-argument warning instead of renaming parameters so the signature stays
intact and the linter is satisfied.
- Around line 1055-1104: In both MoERunner.get_valid_tactics and
MoERunner.forward, rename the unpacked variable topk_weights to a throwaway name
(e.g., _topk_weights or _) where inputs are destructured—specifically in the
tuple unpacking lines inside get_valid_tactics and forward that currently read
"(output, routing_logits, topk_ids, topk_weights, hidden_states,
*extra_inputs,)"—so the unused value is ignored and the RUF059 lint warning is
resolved.
---
Duplicate comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1536-1554: The switch over routing_input_mode_ must include a
default branch to avoid leaving expert_ids_param and expert_weights_param
uninitialized; add a default case in the switch that sets expert_ids_param =
nullptr and expert_weights_param = nullptr and then throws a descriptive
exception (e.g., std::invalid_argument or runtime_error) that includes the
invalid routing_input_mode_ value so callers can diagnose the bad enum. Ensure
the symbols routing_input_mode_, expert_ids_param, and expert_weights_param are
referenced exactly as in the diff.
```diff
         # workspace buffers required by trtllm-gen
-        if topk_ids is None:
-            topk_ids = torch.empty(
-                num_tokens, top_k, dtype=torch.int32, device=hidden_states.device
-            )
-        if expert_weights is None:
-            expert_weights = torch.empty(
-                num_tokens, top_k, dtype=routing_dtype, device=hidden_states.device
-            )
+        # For Mode 3 (UnpackedPrecomputed), topk_ids and topk_weights are user-provided INPUTS
+        if routing_input_mode == RoutingInputMode.UnpackedPrecomputed:
+            assert topk_ids is not None, (
+                "topk_ids must be provided for UnpackedPrecomputed mode"
+            )
+            assert topk_weights is not None, (
+                "topk_weights must be provided for UnpackedPrecomputed mode"
+            )
+        else:
+            # For Mode 1 (FromLogits) and Mode 2 (PackedPrecomputed), allocate OUTPUT buffers
+            if topk_ids is None:
+                topk_ids = torch.empty(
+                    num_tokens, top_k, dtype=torch.int32, device=hidden_states.device
+                )
+            if topk_weights is None:
+                topk_weights = torch.empty(
+                    num_tokens, top_k, dtype=routing_dtype, device=hidden_states.device
+                )
         if enable_pdl is None:
```
Validate UnpackedPrecomputed topk_weights shape/dtype.
When routing_input_mode is UnpackedPrecomputed, user-provided topk_weights is passed straight to the kernel. Add explicit shape/dtype checks to prevent silent misinterpretation (and apply the same validation at Line 2788 in trtllm_fp4_block_scale_routed_moe, where the tuple is unpacked).
✅ Suggested checks

```diff
     if routing_input_mode == RoutingInputMode.UnpackedPrecomputed:
         assert topk_ids is not None, (
             "topk_ids must be provided for UnpackedPrecomputed mode"
         )
         assert topk_weights is not None, (
             "topk_weights must be provided for UnpackedPrecomputed mode"
         )
+        assert topk_ids.shape == (num_tokens, top_k), (
+            f"topk_ids must be shape ({num_tokens}, {top_k})"
+        )
+        assert topk_weights.shape == (num_tokens, top_k), (
+            f"topk_weights must be shape ({num_tokens}, {top_k})"
+        )
+        assert topk_weights.dtype == routing_dtype, (
+            f"topk_weights must be {routing_dtype}"
+        )
     else:
```

📝 Committable suggestion
else:📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
        # workspace buffers required by trtllm-gen
        # For Mode 3 (UnpackedPrecomputed), topk_ids and topk_weights are user-provided INPUTS
        if routing_input_mode == RoutingInputMode.UnpackedPrecomputed:
            assert topk_ids is not None, (
                "topk_ids must be provided for UnpackedPrecomputed mode"
            )
            assert topk_weights is not None, (
                "topk_weights must be provided for UnpackedPrecomputed mode"
            )
            assert topk_ids.shape == (num_tokens, top_k), (
                f"topk_ids must be shape ({num_tokens}, {top_k})"
            )
            assert topk_weights.shape == (num_tokens, top_k), (
                f"topk_weights must be shape ({num_tokens}, {top_k})"
            )
            assert topk_weights.dtype == routing_dtype, (
                f"topk_weights must be {routing_dtype}"
            )
        else:
            # For Mode 1 (FromLogits) and Mode 2 (PackedPrecomputed), allocate OUTPUT buffers
            if topk_ids is None:
                topk_ids = torch.empty(
                    num_tokens, top_k, dtype=torch.int32, device=hidden_states.device
                )
            if topk_weights is None:
                topk_weights = torch.empty(
                    num_tokens, top_k, dtype=routing_dtype, device=hidden_states.device
                )
        if enable_pdl is None:
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fused_moe/core.py` around lines 1842 - 1861, When
routing_input_mode == RoutingInputMode.UnpackedPrecomputed ensure user-provided
topk_weights (and topk_ids) are validated before passing to the kernel: check
topk_weights.ndim == 2 and topk_weights.shape == (num_tokens, top_k) and
topk_weights.dtype == routing_dtype (and check topk_ids.ndim == 2,
topk_ids.shape == (num_tokens, top_k), topk_ids.dtype == torch.int32), raising a
clear AssertionError if mismatched; apply the identical shape/dtype checks in
trtllm_fp4_block_scale_routed_moe where the packed tuple is unpacked so both
call sites enforce the same validation.
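The checks the prompt asks for could be factored into one shared helper so both call sites stay in sync. The sketch below is hypothetical (the name `validate_unpacked_routing` is invented here) and operates on plain shape/dtype metadata rather than real tensors, which keeps the idea framework-neutral:

```python
def validate_unpacked_routing(ids_shape, ids_dtype, w_shape, w_dtype,
                              num_tokens, top_k, routing_dtype):
    """Collect validation errors for user-provided topk_ids / topk_weights
    in UnpackedPrecomputed mode. Returns [] when everything matches."""
    errors = []
    if tuple(ids_shape) != (num_tokens, top_k):
        errors.append(
            f"topk_ids must be shape ({num_tokens}, {top_k}), got {tuple(ids_shape)}"
        )
    if ids_dtype != "int32":
        errors.append(f"topk_ids must be int32, got {ids_dtype}")
    if tuple(w_shape) != (num_tokens, top_k):
        errors.append(
            f"topk_weights must be shape ({num_tokens}, {top_k}), got {tuple(w_shape)}"
        )
    if w_dtype != routing_dtype:
        errors.append(f"topk_weights must be {routing_dtype}, got {w_dtype}")
    return errors
```

Returning a list of errors (instead of asserting on the first mismatch) lets the caller raise one `AssertionError` that reports every problem at once, which is friendlier when both shape and dtype are wrong.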
/bot run

Resolved conflicts again, removing

@IwakuraRein could you give an approval for code review? (I can't approve my own PR.) Then I'll tag others who have permission to merge. After the new testing decoupling, trying to unstick this PR that's been here for a while...

[FAILED] Pipeline #44679293: 9/20 passed
📌 Description
Summary
Add support for pre-computed routing in unpacked format to `trtllm_fp4_block_scale_routed_moe`, and improve code clarity with an explicit routing-mode enum.
Changes
New Feature: Unpacked Pre-computed Routing (Mode 3)
trtllm_fp4_block_scale_routed_moe now accepts routing input in two formats:
Packed format (existing behavior):

```python
trtllm_fp4_block_scale_routed_moe(packed_tensor, ...)
```

Unpacked format (new, pass a tuple):

```python
trtllm_fp4_block_scale_routed_moe((topk_ids, topk_weights), ...)
```

This is a backwards-compatible enhancement: existing code continues to work unchanged.
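For illustration, here is one way a packed entry could combine an expert id and a bf16 score in a single int32. The bit layout below (bf16 score in the high 16 bits, expert id in the low 16 bits) is an assumption for this sketch; the kernel's actual packing may differ.

```python
import struct


def pack_entry(expert_id: int, weight: float) -> int:
    """Hypothetical packing: bf16 weight in the high 16 bits,
    expert id in the low 16 bits of one int32. Illustrative only."""
    f32_bits, = struct.unpack("<I", struct.pack("<f", weight))
    bf16_bits = f32_bits >> 16  # truncate float32 mantissa -> bfloat16
    return (bf16_bits << 16) | (expert_id & 0xFFFF)


def unpack_entry(packed: int) -> tuple:
    """Inverse of pack_entry: recover (expert_id, approximate weight)."""
    expert_id = packed & 0xFFFF
    weight, = struct.unpack("<I", struct.pack("<I", (packed >> 16) << 16))
    weight, = struct.unpack("<f", struct.pack("<I", (packed >> 16) << 16))
    return expert_id, weight
```

The round trip is lossy in general (bf16 keeps only 8 mantissa bits), which is exactly why the unpacked format added in this PR is useful when full-precision `topk_weights` must reach the kernel unchanged.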
Code Clarity Improvements
Renamed routing buffers from `expert_*` to `topk_ids`/`topk_weights` (`fc1_expert_weights` unchanged)
Files Modified
Test plan
```bash
pytest tests/moe/test_trtllm_gen_routed_fused_moe.py -v
```
🤖 Generated with https://claude.ai/code (fully reviewed)
🔍 Related Issues
#2373
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used my preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed, and all tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Public API
Tests
Documentation