feat: Expose unpacked topk weights for routed moe (fp4)#2425

Open
aleozlx wants to merge 21 commits into flashinfer-ai:main from aleozlx:topk_weights

Conversation


@aleozlx aleozlx commented Jan 27, 2026

📌 Description

Summary

Add support for pre-computed routing with unpacked format in trtllm_fp4_block_scale_routed_moe, and
improve code clarity with explicit routing mode enum.

Changes

New Feature: Unpacked Pre-computed Routing (Mode 3)

trtllm_fp4_block_scale_routed_moe now accepts routing input in two formats:

Packed format (existing behavior)

trtllm_fp4_block_scale_routed_moe(packed_tensor, ...)

Unpacked format (new - pass tuple)

trtllm_fp4_block_scale_routed_moe((topk_ids, topk_weights), ...)

This is a backwards-compatible enhancement - existing code continues to work unchanged.
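
The tuple-vs-tensor dispatch can be pictured with a small sketch. The enum values and the helper name here are illustrative assumptions, not taken from flashinfer's actual core.py:

```python
from enum import IntEnum


class RoutingInputMode(IntEnum):
    # Mirrors the three modes described in this PR; the concrete integer
    # values are assumptions for illustration.
    FromLogits = 0
    PackedPrecomputed = 1
    UnpackedPrecomputed = 2


def classify_routing_input(routing_input) -> RoutingInputMode:
    """Hypothetical dispatch: a (topk_ids, topk_weights) tuple selects the
    new unpacked path, while a single packed tensor keeps the old path."""
    if isinstance(routing_input, tuple):
        return RoutingInputMode.UnpackedPrecomputed
    return RoutingInputMode.PackedPrecomputed
```

Because a plain tensor still maps to the packed path, existing callers hit the same code path as before, which is what makes the change backwards compatible.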

Code Clarity Improvements

  • Added RoutingInputMode enum in both C++ and Python to explicitly define three routing modes:
    • FromLogits: Compute routing from logits (Mode 1)
    • PackedPrecomputed: Pre-computed with packed (score << 16 | id) format (Mode 2)
    • UnpackedPrecomputed: Pre-computed with separate tensors (Mode 3)
  • Renamed expert_weights → topk_weights in public APIs for consistency (MLP weights like
    fc1_expert_weights unchanged)
  • Added routing_input_mode as first parameter to internal C++ function for explicit mode selection
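
For intuition, the packed format named above stores each expert slot as a single 32-bit value, (score << 16 | id). A pure-Python sketch of the round-trip, assuming the score occupies the upper 16 bits as a bfloat16 bit pattern (an assumption; the exact layout should be confirmed against flashinfer/fused_moe/core.py):

```python
import struct


def _bf16_bits(x: float) -> int:
    # Truncate a float32 to its bfloat16 bit pattern (round-toward-zero).
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16


def _bf16_to_float(bits: int) -> float:
    # Reinterpret a 16-bit bfloat16 pattern as float by zero-padding the mantissa.
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]


def pack_entry(expert_id: int, score: float) -> int:
    """Pack one (id, score) pair as (score << 16) | id."""
    assert 0 <= expert_id < (1 << 16), "expert id must fit in 16 bits"
    return (_bf16_bits(score) << 16) | expert_id


def unpack_entry(packed: int) -> tuple:
    """Recover (expert_id, score) from a packed 32-bit entry."""
    return packed & 0xFFFF, _bf16_to_float(packed >> 16)
```

Under this assumed layout, Mode 3 skips the round-trip entirely: the caller hands the kernel separate topk_ids and topk_weights tensors instead of squeezing the score into 16 bits.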

Files Modified

  • flashinfer/fused_moe/core.py - Python API with RoutingInputMode enum and tuple support
  • csrc/trtllm_fused_moe_kernel_launcher.cu - C++ launcher with RoutingInputMode enum
  • tests/moe/test_trtllm_gen_routed_fused_moe.py - Extended tests for both formats
  • INTEGRATION_STATUS.md - Documentation (to be removed after merge)

Test plan

  • test_trtllm_gen_fused_moe.py - All passed (Mode 1: FromLogits)
  • test_trtllm_gen_routed_fused_moe.py with routing_format=packed - All passed (Mode 2)
  • test_trtllm_gen_routed_fused_moe.py with routing_format=unpacked - All passed (Mode 3)

pytest tests/moe/test_trtllm_gen_routed_fused_moe.py -v

🤖 Generated with https://claude.ai/code (fully reviewed)

🔍 Related Issues

#2373

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).
$ pytest tests/moe/test_trtllm_gen_routed_fused_moe.py
3456 passed in 336.66s (0:05:36)

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added a RoutingInputMode enum with three modes and end-to-end support for packed or unpacked precomputed routing inputs; intermediate routed outputs now use topk_weights.
  • Public API

    • Core APIs now accept precomputed routing IDs/weights and an explicit routing mode to enable mode-aware routing paths.
  • Tests

    • Parameterized tests added for "packed" and "unpacked" routing formats.
  • Documentation

    • Docstrings and return descriptions updated to reflect new routing formats and topk_weights semantics.


coderabbitai bot commented Jan 27, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a RoutingInputMode enum and propagates multi-mode routing support end-to-end: introduces packed/unpacked precomputed pathways, replaces expert_* buffers with topk_ids/topk_weights across CUDA launchers, Runner APIs, and Python bindings, and updates tests to validate both routing formats.

Changes

  • CUDA Launchers & Kernels (csrc/trtllm_fused_moe_kernel_launcher.cu): Add RoutingInputMode; FP4/FP8/BF16/MXInt4 launchers accept and store mode plus topk_ids/topk_weights; routing buffer selection and init updated for FromLogits/PackedPrecomputed/UnpackedPrecomputed.
  • Runner Implementation (csrc/trtllm_fused_moe_runner.cu, include/flashinfer/trtllm/fused_moe/runner.h): Runner::run signature gains int32_t* expertIds; routingData initialization conditionally uses expertIds (topk ids) vs routing logits and sets topk buffers accordingly.
  • Python Core / Bindings (flashinfer/fused_moe/core.py): Add RoutingInputMode enum; propagate routing_input_mode through ops; replace expert_weights with topk_weights; accept packed (Tensor) or unpacked ((ids, weights)) routing inputs and pass appropriate buffers to launchers.
  • API Entry Points & Launcher Hooks (csrc/...trtllm_fp4_block_scale_moe*, .../FP4BlockScaleLauncher, ...getValidConfigs): Signatures updated to include routing_input_mode, topk_ids, topk_weights; config selection and launcher creation become mode-aware across data types and precision paths.
  • Tests (tests/moe/test_trtllm_gen_routed_fused_moe.py): Parametrize tests for routing_format ("packed", "unpacked"); construct routing_input as packed or unpacked; rename expert_weights → topk_weights and validate both precomputed formats for FP4/FP8 routed paths.

Sequence Diagram(s)

sequenceDiagram
    participant Python as Python API
    participant Launcher as FP4BlockScaleLauncher
    participant Runner as Routing::Runner
    participant CUDA as CUDA Kernels

    Python->>Launcher: call(routing_input_mode, routing logits/topk_ids/topk_weights)
    Launcher->>Runner: prepare routingData (select buffers by mode)

    alt FromLogits
        Runner->>CUDA: compute topk_ids/topk_weights from logits
    else PackedPrecomputed
        Runner->>CUDA: unpack packed topk (id+score) or use packed ids -> produce topk_weights
    else UnpackedPrecomputed
        Runner->>CUDA: consume provided topk_ids and topk_weights
    end

    CUDA-->>Runner: routing/topk results
    Runner-->>Launcher: routing metadata
    Launcher-->>Python: return outputs (including topk_weights)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • cyx-6
  • djmmoss
  • yzh119
  • jiahanc
  • bkryu
  • nvmbreughe

Poem

🐇 I hopped through code with whiskers twitching bright,
Three routing modes now guide the flight,
IDs and weights in packed or free,
Top-k paths hum — a rabbit's glee,
Small hops, big routes — the kernels ignite.

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 23.08%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: the title accurately describes the main feature being introduced: adding support for unpacked topk weights in routed MoE for fp4 operations.
  • Description check ✅ Passed: the PR description comprehensively covers the changes, includes related issues, test results, and completion of the pre-commit checklist and testing requirements.



@gemini-code-assist
Contributor

Summary of Changes

Hello @aleozlx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the flexibility of the fused Mixture of Experts (MoE) kernel by enabling it to process pre-computed routing decisions. This allows for scenarios where expert selection and weighting are determined externally, rather than solely relying on the kernel's internal routing logits computation. The change facilitates integration with advanced routing strategies and optimizes the MoE execution pipeline by providing direct control over expert assignments.

Highlights

  • Pre-computed Routing Support: Introduced a new mode for the fused Mixture of Experts (MoE) kernel to accept pre-computed routing information (expert IDs and weights) directly, bypassing the need to compute routing from raw logits.
  • Python API Update: The trtllm_fp4_block_scale_routed_moe Python function now explicitly accepts topk_weights as a separate input, alongside topk_ids, clarifying the input structure for pre-computed routing.
  • C++ Kernel Interface Enhancement: The underlying C++ Runner::run function has been updated to include an expertIds parameter, enabling the kernel to utilize pre-determined expert assignments.
  • Conditional Routing Logic: Modified the C++ routing execution logic to dynamically choose between computing routing from provided logits or using the newly introduced expertIds for pre-computed routing.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces support for pre-computed routing weights and expert IDs in the FP4 block scale MoE kernels. The changes are consistently applied across the C++ and Python code, updating function signatures, call sites, and internal logic to accommodate the new parameters. The accompanying documentation updates in the Python file are clear and helpful. The implementation appears to be functionally sound and well-integrated.

@aleozlx aleozlx changed the title from "[wip] Topk weights" to "feat: Expose unpacked topk weights for routed moe (fp4)" Feb 3, 2026
@aleozlx aleozlx marked this pull request as ready for review February 3, 2026 02:10

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

1783-1799: ⚠️ Potential issue | 🟡 Minor

Validate routing_input_mode parameter.

The routing_input_mode parameter is cast directly to RoutingInputMode enum at Line 1895 without validation. Invalid values (not 0, 1, or 2) would cause undefined behavior in the launcher's switch statement.

🛡️ Proposed fix
 Array<Tensor> trtllm_fp4_block_scale_moe(
     int64_t routing_input_mode, Optional<TensorView> routing_logits, TensorView topk_ids,
     TensorView topk_weights, Optional<TensorView> routing_bias, TensorView hidden_states,
     // ... remaining parameters ...
     TensorView output, Array<int64_t> config_index) {
+  // Validate routing_input_mode
+  TVM_FFI_ICHECK(routing_input_mode >= 0 && routing_input_mode <= 2)
+      << "routing_input_mode must be 0 (FromLogits), 1 (PackedPrecomputed), or 2 (UnpackedPrecomputed).";
+
   // Determine data types based on input format
🤖 Fix all issues with AI agents
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1445-1463: The switch over routing_input_mode_ in the
trtllm_fp4_block_scale_moe kernel launcher can leave expert_ids_param and
expert_weights_param uninitialized for invalid enum values; add a default case
to the switch that sets expert_ids_param = nullptr and expert_weights_param =
nullptr, emits an error (or uses an assertion) indicating an invalid
RoutingInputMode, and returns or aborts early from trtllm_fp4_block_scale_moe to
avoid undefined behavior; reference the switch on routing_input_mode_, and the
variables expert_ids_param and expert_weights_param when making the change.

Comment on lines +1445 to +1463
    switch (routing_input_mode_) {
      case RoutingInputMode::FromLogits:
        // Mode 1: Kernel computes routing, writes weights to expert_weights_param (OUTPUT)
        expert_ids_param = nullptr;
        expert_weights_param = topk_weights.data_ptr();
        break;

      case RoutingInputMode::PackedPrecomputed:
        // Mode 2: Kernel unpacks from topk_ids, writes weights to expert_weights_param (OUTPUT)
        expert_ids_param = nullptr;
        expert_weights_param = topk_weights.data_ptr();
        break;

      case RoutingInputMode::UnpackedPrecomputed:
        // Mode 3: Both are INPUTS, kernel uses them directly
        expert_ids_param = static_cast<int32_t*>(topk_ids.data_ptr());
        expert_weights_param = topk_weights.data_ptr();
        break;
    }

⚠️ Potential issue | 🟡 Minor

Add a default case to handle invalid routing modes.

The switch statement lacks a default case. If routing_input_mode_ contains an invalid value (e.g., due to an unchecked cast from the int64_t parameter in trtllm_fp4_block_scale_moe), this results in undefined behavior with uninitialized expert_ids_param and expert_weights_param.

🛡️ Proposed fix
       case RoutingInputMode::UnpackedPrecomputed:
         // Mode 3: Both are INPUTS, kernel uses them directly
         expert_ids_param = static_cast<int32_t*>(topk_ids.data_ptr());
         expert_weights_param = topk_weights.data_ptr();
         break;
+
+      default:
+        TVM_FFI_LOG_AND_THROW(ValueError)
+            << "Invalid routing_input_mode: " << static_cast<int>(routing_input_mode_);
     }


aleozlx commented Feb 3, 2026

@claude hello


claude bot commented Feb 3, 2026

Claude encountered an error (View job)

Command failed: git fetch origin --depth=20 topk_weights

I'll analyze this and get back to you.


aleozlx commented Feb 3, 2026

@claude can you make a similar PR but for fp8 routed_moe?


claude bot commented Feb 3, 2026

Claude encountered an error (View job)

Command failed: git fetch origin --depth=20 topk_weights

I'll analyze this and get back to you.


aleozlx commented Feb 3, 2026

That last fix (dafc67d) passed all tests

 pytest tests/moe/test_trtllm_gen_routed_fused_moe.py -x
 3488 passed in 363.64s (0:06:03)


aleozlx commented Feb 3, 2026

@flashinfer-bot run


aleozlx commented Feb 3, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !292 has been created, and the CI pipeline #43153719 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #43153719: 6/20 passed


aleozlx commented Feb 12, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !292 has been updated with latest changes, and the CI pipeline #43845118 is currently running. I'll report back once the pipeline job completes.


aleozlx commented Feb 12, 2026

posting latest test result again for good measure

$ git log -1
commit a06a0fd23a4b469fc838f23fd1ac084f3f107402 (HEAD -> topk_weights, origin/topk_weights)
Merge: dafc67d9 1d350ae5
Author: Alex Yang <[email protected]>
Date:   Wed Feb 11 18:25:20 2026 -0800

    Merge branch 'main' into topk_weights

$ pytest tests/moe/test_trtllm_gen_routed_fused_moe.py
3488 passed in 954.71s (0:15:54)

@flashinfer-bot
Collaborator

[FAILED] Pipeline #43845118: 14/20 passed


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/fused_moe/core.py (2)

1993-2026: ⚠️ Potential issue | 🟡 Minor

Suppress ARG001 in the fake op signature.

Ruff flags unused args in _fake_trtllm_fp4_block_scale_moe. A single # noqa: ARG001 on the def line keeps the stub clean without renaming parameters.

🧹 Suggested fix
-def _fake_trtllm_fp4_block_scale_moe(
+def _fake_trtllm_fp4_block_scale_moe(  # noqa: ARG001
     routing_input_mode: int,
     routing_logits: torch.Tensor,
     topk_ids: Optional[torch.Tensor],
     topk_weights: Optional[torch.Tensor],
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 1993 - 2026, Ruff flags unused
arguments in the stub for _fake_trtllm_fp4_block_scale_moe; add a "# noqa:
ARG001" comment to the function definition line of
_fake_trtllm_fp4_block_scale_moe to suppress the unused-argument warning instead
of renaming parameters so the signature stays intact and the linter is
satisfied.

1055-1104: ⚠️ Potential issue | 🟡 Minor

Silence unused topk_weights unpacking to keep Ruff clean.

topk_weights is unpacked but unused in both get_valid_tactics and forward, triggering RUF059. Rename it to _topk_weights (or _) in both spots.

🧹 Suggested fix
@@
-            (
-                output,
-                routing_logits,
-                topk_ids,
-                topk_weights,
-                hidden_states,
-                *extra_inputs,
-            ) = inputs
+            (
+                output,
+                routing_logits,
+                topk_ids,
+                _topk_weights,
+                hidden_states,
+                *extra_inputs,
+            ) = inputs
@@
-            (
-                output,
-                routing_logits,
-                topk_ids,
-                topk_weights,
-                hidden_states,
-                *extra_inputs,
-            ) = inputs
+            (
+                output,
+                routing_logits,
+                topk_ids,
+                _topk_weights,
+                hidden_states,
+                *extra_inputs,
+            ) = inputs
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 1055 - 1104, In both
MoERunner.get_valid_tactics and MoERunner.forward, rename the unpacked variable
topk_weights to a throwaway name (e.g., _topk_weights or _) where inputs are
destructured—specifically in the tuple unpacking lines inside get_valid_tactics
and forward that currently read "(output, routing_logits, topk_ids,
topk_weights, hidden_states, *extra_inputs,)"—so the unused value is ignored and
the RUF059 lint warning is resolved.

Duplicate comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1536-1554: The switch over routing_input_mode_ must include a
default branch to avoid leaving expert_ids_param and expert_weights_param
uninitialized; add a default case in the switch that sets expert_ids_param =
nullptr and expert_weights_param = nullptr and then throws a descriptive
exception (e.g., std::invalid_argument or runtime_error) that includes the
invalid routing_input_mode_ value so callers can diagnose the bad enum. Ensure
the symbols routing_input_mode_, expert_ids_param, and expert_weights_param are
referenced exactly as in the diff.

Comment on lines 1842 to 1861

         # workspace buffers required by trtllm-gen
-        if topk_ids is None:
-            topk_ids = torch.empty(
-                num_tokens, top_k, dtype=torch.int32, device=hidden_states.device
+        # For Mode 3 (UnpackedPrecomputed), topk_ids and topk_weights are user-provided INPUTS
+        if routing_input_mode == RoutingInputMode.UnpackedPrecomputed:
+            assert topk_ids is not None, (
+                "topk_ids must be provided for UnpackedPrecomputed mode"
             )
-        if expert_weights is None:
-            expert_weights = torch.empty(
-                num_tokens, top_k, dtype=routing_dtype, device=hidden_states.device
+            assert topk_weights is not None, (
+                "topk_weights must be provided for UnpackedPrecomputed mode"
             )
+        else:
+            # For Mode 1 (FromLogits) and Mode 2 (PackedPrecomputed), allocate OUTPUT buffers
+            if topk_ids is None:
+                topk_ids = torch.empty(
+                    num_tokens, top_k, dtype=torch.int32, device=hidden_states.device
+                )
+            if topk_weights is None:
+                topk_weights = torch.empty(
+                    num_tokens, top_k, dtype=routing_dtype, device=hidden_states.device
+                )
         if enable_pdl is None:

⚠️ Potential issue | 🟡 Minor

Validate UnpackedPrecomputed topk_weights shape/dtype.

When routing_input_mode is UnpackedPrecomputed, user-provided topk_weights is passed straight to the kernel. Add explicit shape/dtype checks to prevent silent misinterpretation (and apply the same validation at Line 2788 in trtllm_fp4_block_scale_routed_moe, where the tuple is unpacked).

✅ Suggested checks
 if routing_input_mode == RoutingInputMode.UnpackedPrecomputed:
     assert topk_ids is not None, (
         "topk_ids must be provided for UnpackedPrecomputed mode"
     )
     assert topk_weights is not None, (
         "topk_weights must be provided for UnpackedPrecomputed mode"
     )
+    assert topk_ids.shape == (num_tokens, top_k), (
+        f"topk_ids must be shape ({num_tokens}, {top_k})"
+    )
+    assert topk_weights.shape == (num_tokens, top_k), (
+        f"topk_weights must be shape ({num_tokens}, {top_k})"
+    )
+    assert topk_weights.dtype == routing_dtype, (
+        f"topk_weights must be {routing_dtype}"
+    )
 else:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 1842 - 1861, When
routing_input_mode == RoutingInputMode.UnpackedPrecomputed ensure user-provided
topk_weights (and topk_ids) are validated before passing to the kernel: check
topk_weights.ndim == 2 and topk_weights.shape == (num_tokens, top_k) and
topk_weights.dtype == routing_dtype (and check topk_ids.ndim == 2,
topk_ids.shape == (num_tokens, top_k), topk_ids.dtype == torch.int32), raising a
clear AssertionError if mismatched; apply the identical shape/dtype checks in
trtllm_fp4_block_scale_routed_moe where the packed tuple is unpacked so both
call sites enforce the same validation.

@aleozlx aleozlx added the ready label Feb 20, 2026

aleozlx commented Feb 24, 2026

/bot run


aleozlx commented Feb 24, 2026

resolved conflicts again, removing ready label pending new test results

@aleozlx aleozlx removed the ready label Feb 24, 2026
@flashinfer-bot
Collaborator

GitLab MR !292 has been updated with latest changes, and the CI pipeline #44679293 is currently running. I'll report back once the pipeline job completes.


aleozlx commented Feb 24, 2026

@IwakuraRein could you give an approval for code review? (as I can't give approval to my own PR)

then i'll tag others who have permission to merge after new testing

decoupling and trying to unstick this PR that's been here for a while...

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44679293: 9/20 passed
