
fix: add DeepSeek routing for Bf16xBf16 and MxIntxBf16 TRT-LLM Gen MoE #2234

Merged
yzh119 merged 1 commit into flashinfer-ai:main from nekorobov:nkorobov/bf16-mxint-bf16-moe-ds-routing
Dec 18, 2025

Conversation

@nekorobov
Collaborator

@nekorobov nekorobov commented Dec 17, 2025

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added optional routing scaling factor support to Mixture of Experts operations, enabling more granular control over routing behavior.
    • Made routing bias optional for MXInt4-based MoE operations; previously required, can now be omitted.
  • Tests

    • Expanded test coverage to include additional MoE implementations and routing configurations.


@coderabbitai
Contributor

coderabbitai bot commented Dec 17, 2025

Walkthrough

Added optional routed_scaling_factor parameter to BF16 MoE operations and relaxed routing_bias validation in MXInt4 MoE path. Changes thread these parameters through the C++ kernel launcher, Python API, and test suite to enhance routing configurability.

Changes

Cohort / File(s) Summary
C++ Kernel Launcher
csrc/trtllm_fused_moe_kernel_launcher.cu
Made routed_scaling_factor optional with default 1.0 in BF16 MoE and updated launcher to populate it in args structure. Relaxed MXInt4 MoE routing_bias validation to accept optional bias with dtype/shape checks instead of unconditional rejection.
Python API
flashinfer/fused_moe/core.py
Added routed_scaling_factor: Optional[float] to BF16 MoE operations (trtllm_bf16_moe_op, trtllm_bf16_moe). Added routing_bias: Optional[torch.Tensor] to MXInt4 block-scale MoE operations (trtllm_mxint4_block_scale_moe_op, trtllm_mxint4_block_scale_moe). Threaded parameters through internal calls and updated docstrings.
Tests
tests/moe/test_trtllm_gen_fused_moe.py
Implemented runtime handling of routing_bias from kwargs and propagated routed_scaling into kernel invocations. Extended parametrizations to include MxInt4BlockScaleMoe and BF16Moe in routing scenarios and updated compatible implementation lists.
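The defaulting behavior described above can be mimicked in a short, self-contained sketch. This is hypothetical illustration code, not FlashInfer's actual launcher: the helper name `resolve_routed_scaling_factor` is invented, and it simply mirrors the C++ `value_or(1.0)` semantics, where an omitted factor falls back to 1.0 so prior behavior is preserved.

```python
from typing import Optional

def resolve_routed_scaling_factor(routed_scaling_factor: Optional[float]) -> float:
    # Mirrors the C++ `routed_scaling_factor.value_or(1.0)` defaulting:
    # None on the Python side (nullopt in C++) falls back to 1.0, so
    # callers that omit the argument see unchanged behavior.
    return 1.0 if routed_scaling_factor is None else routed_scaling_factor

print(resolve_routed_scaling_factor(None))  # 1.0 (default preserved)
print(resolve_routed_scaling_factor(2.5))   # 2.5 (DeepSeekV3-style value)
```

Keeping the default in one layer (here, the kernel launcher) means every caller that passes `None` gets identical behavior.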

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Validation logic changes in C++ for optional routing_bias parameter handling
  • Parameter threading complexity across multiple language boundaries (C++ ↔ Python)
  • Test parametrization updates requiring verification of coverage across multiple MoE implementations

Possibly related PRs

Suggested reviewers

  • aleozlx
  • djmmoss
  • cyx-6
  • yongwww
  • wenscarl
  • yzh119

Poem

🐰 Scaling factors hop through the routing maze,
Bias now optional, in flexing ways,
From kernel to Python, parameters align,
MoE grows configurable, oh so fine!
Expert selection gets a gentle hand,
With routing scaled just as planned!

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Description check ⚠️ Warning PR description is entirely composed of the template with no actual content filled in—all sections remain as placeholders with unchecked checkboxes and no implementation details provided. Fill in the Description section explaining the routing changes, list any related issues, and verify pre-commit and test checklist items have been completed.
Docstring Coverage ⚠️ Warning Docstring coverage is 44.44% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Check name Status Explanation
Title check ✅ Passed Title specifically describes the main change (adding DeepSeek routing support for BF16 and MxInt variants), directly matching the code modifications in both CUDA kernel and Python API layers.

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between db2aacb and ecd2be2.

📒 Files selected for processing (3)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (3 hunks)
  • flashinfer/fused_moe/core.py (10 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
include/flashinfer/trtllm/fused_moe/runner.h (2)
  • local_num_experts (277-277)
  • num_experts (263-263)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • enable_pdl (220-220)
flashinfer/fused_moe/core.py (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
  • routing_bias (158-164)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (9)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)

1406-1454: BF16 MoE: routed_scaling_factor threading looks correct and backward‑compatible

The added Optional<double> routed_scaling_factor in trtllm_bf16_moe and initialization via args->routed_scaling_factor = routed_scaling_factor.value_or(1.0); cleanly align BF16 with the existing FP8/FP4 paths. Defaulting to 1.0 keeps prior behavior when callers pass None, and the parameter is correctly propagated into MoERunnerArgs before launcher initialization.


1807-1816: MxInt4 MoE: optional routing_bias validation is consistent with Python API

Making routing_bias optional and only validating dtype/shape when has_value() is true is the right relaxation. The checks enforce [num_experts] and bfloat16, matching the updated Python docstring and tests, and should not affect cases where no bias is used.
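A pure-Python sketch of the relaxed validation pattern described above, with invented names (`validate_routing_bias`) and dtype represented as a string; the real checks live in the C++ launcher and operate on tensors. The point is that the bias is validated only when it is actually supplied.

```python
from typing import Optional, Tuple

def validate_routing_bias(bias_shape: Optional[Tuple[int, ...]],
                          bias_dtype: Optional[str],
                          num_experts: int) -> None:
    if bias_shape is None:
        # No bias supplied: nothing to validate, mirroring the
        # has_value() guard instead of an unconditional rejection.
        return
    if bias_dtype != "bfloat16":
        raise TypeError("routing_bias must be bfloat16")
    if bias_shape != (num_experts,):
        raise ValueError("routing_bias must have shape [num_experts]")

validate_routing_bias(None, None, 8)        # OK: bias omitted
validate_routing_bias((8,), "bfloat16", 8)  # OK: valid bias
```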

tests/moe/test_trtllm_gen_fused_moe.py (3)

750-788: MxInt4 test harness correctly wires routing_bias and routed_scaling into kernel call

MxInt4BlockScaleMoe.call_moe now forwards routing_bias and a routed_scaling (with a safe default of 1.0 when absent) into trtllm_mxint4_block_scale_moe, matching the core API and C++ launcher signature. This gives the tests full coverage of optional bias and scaling behavior for MXInt4 DeepSeek‑style routing.


1304-1339: BF16 test harness now exercises routed_scaling_factor end‑to‑end

BF16Moe.call_moe plumbs kwargs["routed_scaling"] through to trtllm_bf16_moe, so Renormalize/TopK configs can keep None (defaulting to 1.0 in the launcher) while DeepSeekV3 configs use a non‑trivial value (e.g., 2.5). This matches the BF16 core API and ensures the new routing scaling support is actually tested.


2628-2724: DeepSeekV3 parametrization updates appropriately include BF16 and MxInt4 implementations

Adding MxInt4BlockScaleMoe() and BF16Moe() to the moe_impl list and to compatible_moe_impls for the DSv3 routing and Shuffled_BlockMajorK weight layout ensures these implementations are exercised only in configurations they can support. The changes are consistent with the underlying kernels (BF16/MxInt4, BlockMajorK) and with the new routing_bias/routed_scaling plumbing.

flashinfer/fused_moe/core.py (4)

1100-1116: BF16 MoE: routed_scaling_factor is now correctly wired through MoERunner to the C++ op

The BF16 branch in MoERunner.forward now forwards kwargs["routed_scaling_factor"] into moe_op.trtllm_bf16_moe, matching the updated custom op and C++ signatures. Together with the C++ defaulting logic (value_or(1.0)), this cleanly enables BF16 DeepSeek‑style scaling while preserving previous behavior when the factor is omitted.


1921-2043: MxInt4 trtllm path: routing_bias support is threaded correctly from Python to C++

trtllm_mxint4_block_scale_moe_op and its wrapper now accept routing_bias: Optional[torch.Tensor] and pass it through to moe_op.trtllm_mxint4_block_scale_moe, matching the updated C++ launcher. The autotuner kwargs and fake op were updated accordingly, so tests can now exercise optional bias for MXInt4 without API mismatch.


2084-2165: Public BF16 API doc correctly documents routed_scaling_factor

The BF16 wrapper’s docstring and signature now expose routed_scaling_factor: Optional[float] = None and explain its semantics in the routing section. Combined with the underlying defaulting to 1.0 when None, this provides a clear and stable public API for DeepSeek‑style scaling in BF16.


2582-2674: MxInt4 Python wrapper aligns with new routing_bias and routed_scaling_factor behavior

The updated trtllm_mxint4_block_scale_moe wrapper adds routing_bias to the signature and forwards both routing_bias and routed_scaling_factor into the SM100 module’s op in the correct order. The docstring now specifies that routing_bias must be bf16, matching the C++ checks and the updated tests.



@gemini-code-assist
Contributor

Summary of Changes

Hello @nekorobov, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the flexibility and performance of TRT-LLM generated Mixture-of-Experts (MoE) operations by integrating DeepSeekV3 routing capabilities for BF16xBF16 and MxIntxBf16 data types. It introduces new parameters to control routing behavior and bias, ensuring broader compatibility and more granular control over expert selection within the framework.

Highlights

  • DeepSeek Routing Support: Enabled DeepSeekV3 routing for both BF16xBF16 and MxIntxBf16 TRT-LLM Mixture-of-Experts (MoE) implementations, expanding the range of supported routing methods.
  • New routed_scaling_factor for BF16 MoE: Introduced an optional routed_scaling_factor parameter for BF16 MoE operations, allowing for dynamic scaling of routing logits to fine-tune expert selection.
  • Support for routing_bias in MxInt4 MoE: Added support for an optional routing_bias parameter in MxInt4 block-scale MoE, which must be of bfloat16 data type, providing more control over expert routing.
  • Expanded Test Coverage: Updated unit tests to validate the new routed_scaling_factor and routing_bias parameters and to confirm DeepSeekV3 routing compatibility for the newly supported BF16 and MxInt4 MoE types.
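For intuition, here is a hedged sketch of what a routed scaling factor typically does in DeepSeek-style routing: after the top-k routing weights are normalized, they are multiplied by the factor before expert outputs are combined. The function name and the exact normalization are illustrative assumptions, not FlashInfer's kernel code.

```python
def scale_topk_weights(weights, routed_scaling_factor=1.0):
    # Normalize the selected experts' routing weights to sum to 1,
    # then apply the global routed scaling factor.
    total = sum(weights)
    normalized = [w / total for w in weights]
    return [w * routed_scaling_factor for w in normalized]

print(scale_topk_weights([0.2, 0.3, 0.5], 2.5))
```

With `routed_scaling_factor=1.0` (the default) the weights are left unscaled, which is why making the parameter optional is backward-compatible.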

@jiahanc
Collaborator

jiahanc commented Dec 17, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !199 has been created, and the CI pipeline #40395384 is currently running. I'll report back once the pipeline job completes.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/fused_moe/core.py (1)

1275-1375: BF16 custom op and wrapper expose routed_scaling_factor consistently, but fake op is now out of sync

The new routed_scaling_factor: Optional[float] parameter in trtllm_bf16_moe_op, its propagation into the autotuner kwargs, and forwarding to moe_op.trtllm_bf16_moe plus the top‑level trtllm_bf16_moe wrapper are all consistent and match the C++ launcher.

However, _fake_trtllm_bf16_moe still uses the old signature (missing routed_scaling_factor), so any path that invokes the fake op (e.g., fake tensor / Inductor / AOT flows) will raise a TypeError due to an unexpected positional argument.

Apply this diff to align the fake op signature with the real one:

 @register_fake_op("flashinfer::trtllm_bf16_moe")
 def _fake_trtllm_bf16_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     gemm1_weights: torch.Tensor,
     gemm2_weights: torch.Tensor,
     num_experts: int,
     top_k: int,
     n_group: Optional[int],
     topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routing_method_type: int,
+    routed_scaling_factor: Optional[float],
+    routing_method_type: int,
     use_shuffled_weight: bool,
     weight_layout: int,
     enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 8192,
 ):

This keeps argument order matching the custom op and preserves existing defaults.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for DeepSeek routing for Bf16xBf16 and MxIntxBf16 MoE layers in TRT-LLM. The changes are well-structured, introducing routed_scaling_factor and routing_bias parameters and plumbing them through from the Python API to the C++ kernels. The test suite has also been updated to cover this new functionality. The implementation looks solid. I have a few minor suggestions to improve the robustness and consistency of the test code.

):
"""Call MoE with runtime input quantization + kernel execution (done at runtime)."""
expert_logits = kwargs["expert_logits"]
routing_bias = kwargs["routing_bias"]

medium

For robustness and consistency with other parts of the code (e.g., enable_autotune), it's better to use kwargs.get("routing_bias") instead of direct access. This will prevent a KeyError if the key is missing.

Suggested change
routing_bias = kwargs["routing_bias"]
routing_bias = kwargs.get("routing_bias")

intermediate_size = kwargs["intermediate_size"]
routing_method_type = kwargs["routing_method_type"]
enable_autotune = kwargs.get("enable_autotune", True)
routed_scaling = kwargs.get("routed_scaling", 1.0)

medium

The default value of 1.0 is already handled in the C++ layer. To maintain a single source of truth for default values and for consistency, it's better to remove the default value here. kwargs.get("routed_scaling") will return None if the key is missing, and the C++ layer will correctly use its default of 1.0.

Suggested change
routed_scaling = kwargs.get("routed_scaling", 1.0)
routed_scaling = kwargs.get("routed_scaling")
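A tiny sketch of the behavior this suggestion relies on: `dict.get` without a default returns `None` for a missing key instead of raising `KeyError`, letting the downstream layer (here, the C++ launcher's `value_or(1.0)`) remain the single source of truth for the default.

```python
kwargs = {"intermediate_size": 1024}  # "routed_scaling" not supplied

routed_scaling = kwargs.get("routed_scaling")  # -> None, no KeyError
print(routed_scaling)  # None; the C++ layer then applies its 1.0 default
```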

n_groups = kwargs["n_groups"]
top_k_groups = kwargs["top_k_groups"]
intermediate_size = kwargs["intermediate_size"]
routed_scaling = kwargs["routed_scaling"]

medium

For robustness and consistency, it's better to use kwargs.get("routed_scaling") instead of direct access. This will prevent a KeyError if the key is missing and will return None, which is a valid value for this optional parameter and is handled correctly by the downstream C++ function.

Suggested change
routed_scaling = kwargs["routed_scaling"]
routed_scaling = kwargs.get("routed_scaling")

@yzh119
Collaborator

yzh119 commented Dec 18, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !199 has been created, and the CI pipeline #40411215 is currently running. I'll report back once the pipeline job completes.
