fix: add DeepSeek routing for Bf16xBf16 and MxIntxBf16 TRT-LLM Gen MoE #2234
Signed-off-by: Nikita Korobov <[email protected]>
Walkthrough: Added optional `routed_scaling_factor` (BF16 path) and `routing_bias` (MxInt4 path) parameters to the TRT-LLM Gen MoE launchers, Python wrappers, and tests.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Summary of Changes

Hello @nekorobov, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the flexibility and performance of TRT-LLM generated Mixture-of-Experts (MoE) operations by integrating DeepSeekV3 routing capabilities for BF16xBF16 and MxIntxBf16 data types. It introduces new parameters to control routing behavior and bias, ensuring broader compatibility and more granular control over expert selection within the framework.

Highlights
/bot run
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/fused_moe/core.py (1)
1275-1375: BF16 custom op and wrapper expose `routed_scaling_factor` consistently, but the fake op is now out of sync

The new `routed_scaling_factor: Optional[float]` parameter in `trtllm_bf16_moe_op`, its propagation into the autotuner kwargs, and its forwarding to `moe_op.trtllm_bf16_moe` plus the top-level `trtllm_bf16_moe` wrapper are all consistent and match the C++ launcher.

However, `_fake_trtllm_bf16_moe` still uses the old signature (missing `routed_scaling_factor`), so any path that invokes the fake op (e.g., fake tensor / Inductor / AOT flows) will raise a `TypeError` due to an unexpected positional argument.

Apply this diff to align the fake op signature with the real one:

```diff
 @register_fake_op("flashinfer::trtllm_bf16_moe")
 def _fake_trtllm_bf16_moe(
     routing_logits: torch.Tensor,
     routing_bias: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     gemm1_weights: torch.Tensor,
     gemm2_weights: torch.Tensor,
     num_experts: int,
     top_k: int,
     n_group: Optional[int],
     topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
-    routing_method_type: int,
+    routed_scaling_factor: Optional[float],
+    routing_method_type: int,
     use_shuffled_weight: bool,
     weight_layout: int,
     enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 8192,
 ):
```

This keeps argument order matching the custom op and preserves existing defaults.
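The failure mode described above can be reproduced in miniature with plain Python functions (a simplified stand-in; `real_op` and `stale_fake_op` are hypothetical names, not the actual flashinfer registrations): a caller that builds positional arguments against the new signature passes one extra argument to the stale fake, which raises `TypeError`.

```python
# Simplified stand-in for a real/fake custom-op pair whose signatures drifted.
def real_op(routing_logits, local_num_experts, routed_scaling_factor, routing_method_type):
    # The real op consumes the newly added parameter.
    return routed_scaling_factor

def stale_fake_op(routing_logits, local_num_experts, routing_method_type):
    # The fake op predates routed_scaling_factor, so it takes one argument fewer.
    return None

# Arguments built against the *new* positional order.
args = ("logits", 8, 2.5, 1)

real_op(*args)  # fine: signature matches

try:
    stale_fake_op(*args)  # one unexpected positional argument
except TypeError as exc:
    print(f"fake op rejected the call: {exc}")
```

This is why keeping the fake op's positional order identical to the real op matters even though the fake never executes the kernel.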
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- `csrc/trtllm_fused_moe_kernel_launcher.cu` (3 hunks)
- `flashinfer/fused_moe/core.py` (10 hunks)
- `tests/moe/test_trtllm_gen_fused_moe.py` (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
include/flashinfer/trtllm/fused_moe/runner.h (2)
- `local_num_experts` (277-277)
- `num_experts` (263-263)

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)

- `enable_pdl` (220-220)
flashinfer/fused_moe/core.py (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
- `routing_bias` (158-164)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (9)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
1406-1454: BF16 MoE: `routed_scaling_factor` threading looks correct and backward-compatible

The added `Optional<double> routed_scaling_factor` in `trtllm_bf16_moe` and initialization via `args->routed_scaling_factor = routed_scaling_factor.value_or(1.0);` cleanly align BF16 with the existing FP8/FP4 paths. Defaulting to `1.0` keeps prior behavior when callers pass `None`, and the parameter is correctly propagated into `MoERunnerArgs` before launcher initialization.
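The `value_or(1.0)` defaulting can be sketched in Python for clarity (an illustrative mirror of the C++ semantics, not flashinfer code; note the explicit `is None` check, which keeps a legitimate factor of `0.0` from being silently replaced):

```python
from typing import Optional

def resolve_routed_scaling_factor(routed_scaling_factor: Optional[float]) -> float:
    # Mirrors C++ routed_scaling_factor.value_or(1.0): None means "no scaling".
    return 1.0 if routed_scaling_factor is None else routed_scaling_factor

print(resolve_routed_scaling_factor(None))  # 1.0, the backward-compatible default
print(resolve_routed_scaling_factor(2.5))   # 2.5, a DeepSeekV3-style factor
```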
1807-1816: MxInt4 MoE: optional `routing_bias` validation is consistent with the Python API

Making `routing_bias` optional and only validating dtype/shape when `has_value()` is true is the right relaxation. The checks enforce `[num_experts]` and `bfloat16`, matching the updated Python docstring and tests, and should not affect cases where no bias is used.

tests/moe/test_trtllm_gen_fused_moe.py (3)
750-788: MxInt4 test harness correctly wires `routing_bias` and `routed_scaling` into the kernel call

`MxInt4BlockScaleMoe.call_moe` now forwards `routing_bias` and a `routed_scaling` (with a safe default of `1.0` when absent) into `trtllm_mxint4_block_scale_moe`, matching the core API and C++ launcher signature. This gives the tests full coverage of optional bias and scaling behavior for MxInt4 DeepSeek-style routing.
1304-1339: BF16 test harness now exercises `routed_scaling_factor` end-to-end

`BF16Moe.call_moe` plumbs `kwargs["routed_scaling"]` through to `trtllm_bf16_moe`, so Renormalize/TopK configs can keep `None` (defaulting to 1.0 in the launcher) while DeepSeekV3 configs use a non-trivial value (e.g., 2.5). This matches the BF16 core API and ensures the new routing scaling support is actually tested.
2628-2724: DeepSeekV3 parametrization updates appropriately include BF16 and MxInt4 implementations

Adding `MxInt4BlockScaleMoe()` and `BF16Moe()` to the `moe_impl` list and to `compatible_moe_impls` for the `DSv3` routing and `Shuffled_BlockMajorK` weight layout ensures these implementations are exercised only in configurations they can support. The changes are consistent with the underlying kernels (BF16/MxInt4, BlockMajorK) and with the new `routing_bias`/`routed_scaling` plumbing.

flashinfer/fused_moe/core.py (4)
1100-1116: BF16 MoE: `routed_scaling_factor` is now correctly wired through `MoERunner` to the C++ op

The BF16 branch in `MoERunner.forward` now forwards `kwargs["routed_scaling_factor"]` into `moe_op.trtllm_bf16_moe`, matching the updated custom op and C++ signatures. Together with the C++ defaulting logic (`value_or(1.0)`), this cleanly enables BF16 DeepSeek-style scaling while preserving previous behavior when the factor is omitted.
1921-2043: MxInt4 trtllm path: `routing_bias` support is threaded correctly from Python to C++

`trtllm_mxint4_block_scale_moe_op` and its wrapper now accept `routing_bias: Optional[torch.Tensor]` and pass it through to `moe_op.trtllm_mxint4_block_scale_moe`, matching the updated C++ launcher. The autotuner kwargs and fake op were updated accordingly, so tests can now exercise optional bias for MxInt4 without API mismatch.
2084-2165: Public BF16 API doc correctly documents `routed_scaling_factor`

The BF16 wrapper's docstring and signature now expose `routed_scaling_factor: Optional[float] = None` and explain its semantics in the routing section. Combined with the underlying defaulting to 1.0 when `None`, this provides a clear and stable public API for DeepSeek-style scaling in BF16.
2582-2674: MxInt4 Python wrapper aligns with new `routing_bias` and `routed_scaling_factor` behavior

The updated `trtllm_mxint4_block_scale_moe` wrapper adds `routing_bias` to the signature and forwards both `routing_bias` and `routed_scaling_factor` into the SM100 module's op in the correct order. The docstring now specifies that `routing_bias` must be bf16, matching the C++ checks and the updated tests.
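The validate-only-when-present pattern described above for `routing_bias` can be sketched without torch (a hedged illustration; `FakeTensor` and `validate_routing_bias` are invented stand-ins, not flashinfer APIs):

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class FakeTensor:
    # Minimal stand-in exposing only the attributes the check inspects.
    shape: Tuple[int, ...]
    dtype: str

def validate_routing_bias(routing_bias: Optional[FakeTensor], num_experts: int) -> None:
    # Mirror of the C++ pattern: check dtype/shape only when a bias is supplied.
    if routing_bias is None:
        return  # no bias: routing proceeds unbiased, nothing to validate
    if routing_bias.dtype != "bfloat16":
        raise TypeError("routing_bias must be bfloat16")
    if routing_bias.shape != (num_experts,):
        raise ValueError(f"routing_bias must have shape ({num_experts},)")

validate_routing_bias(None, 256)                            # OK: bias omitted
validate_routing_bias(FakeTensor((256,), "bfloat16"), 256)  # OK: valid bias
```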
Code Review
This pull request adds support for DeepSeek routing for Bf16xBf16 and MxIntxBf16 MoE layers in TRT-LLM. The changes are well-structured, introducing routed_scaling_factor and routing_bias parameters and plumbing them through from the Python API to the C++ kernels. The test suite has also been updated to cover this new functionality. The implementation looks solid. I have a few minor suggestions to improve the robustness and consistency of the test code.
```python
):
    """Call MoE with runtime input quantization + kernel execution (done at runtime)."""
    expert_logits = kwargs["expert_logits"]
    routing_bias = kwargs["routing_bias"]
```
For robustness and consistency with other parts of the code (e.g., `enable_autotune`), it's better to use `kwargs.get("routing_bias")` instead of direct access. This will prevent a `KeyError` if the key is missing.
```diff
-routing_bias = kwargs["routing_bias"]
+routing_bias = kwargs.get("routing_bias")
```
```python
    intermediate_size = kwargs["intermediate_size"]
    routing_method_type = kwargs["routing_method_type"]
    enable_autotune = kwargs.get("enable_autotune", True)
    routed_scaling = kwargs.get("routed_scaling", 1.0)
```
The default value of `1.0` is already handled in the C++ layer. To maintain a single source of truth for default values and for consistency, it's better to remove the default value here. `kwargs.get("routed_scaling")` will return `None` if the key is missing, and the C++ layer will correctly use its default of `1.0`.
```diff
-routed_scaling = kwargs.get("routed_scaling", 1.0)
+routed_scaling = kwargs.get("routed_scaling")
```
```python
    n_groups = kwargs["n_groups"]
    top_k_groups = kwargs["top_k_groups"]
    intermediate_size = kwargs["intermediate_size"]
    routed_scaling = kwargs["routed_scaling"]
```
For robustness and consistency, it's better to use `kwargs.get("routed_scaling")` instead of direct access. This will prevent a `KeyError` if the key is missing and will return `None`, which is a valid value for this optional parameter and is handled correctly by the downstream C++ function.
```diff
-routed_scaling = kwargs["routed_scaling"]
+routed_scaling = kwargs.get("routed_scaling")
```
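The difference between the two kwargs access styles these suggestions contrast is plain-Python behavior and easy to verify in isolation (`call_moe_direct` and `call_moe_robust` are illustrative names, not test-harness functions):

```python
def call_moe_direct(**kwargs):
    # Direct indexing raises KeyError when the key is absent.
    return kwargs["routed_scaling"]

def call_moe_robust(**kwargs):
    # .get() returns None when absent, which downstream code maps to its 1.0 default.
    return kwargs.get("routed_scaling")

print(call_moe_robust(top_k=8))  # None: safe when the caller omits the key

try:
    call_moe_direct(top_k=8)
except KeyError:
    print("direct access raised KeyError")
```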
/bot run
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks

- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features

- Optional routed scaling factor for the BF16 TRT-LLM Gen MoE path and optional routing bias for the MxInt4 path, enabling DeepSeekV3-style routing.

Tests

- DeepSeekV3 routing parametrization extended to cover the BF16 and MxInt4 MoE implementations.