chore/feat: Add do_finalize to trtllm-gen fp8/f16 MoE APIs#2548
yzh119 merged 19 commits into `flashinfer-ai:main`
Conversation
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough

Entry points for fused MoE now accept a `do_finalize` flag and return arrays/lists of tensors instead of single tensors; `do_finalize` is threaded through Python wrappers, FFI entry points, and the CUDA launcher, changing whether final outputs or intermediate tensors are returned across BF16/FP8/FP4/MXInt4 variants.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as Python wrapper
    participant FFI as C++ FFI entry
    participant Launcher as CUDA launcher
    participant GPU as GPU kernels
    Py->>FFI: call trtllm_*_moe(..., do_finalize)
    FFI->>Launcher: build args + per-tile launchers, pass do_finalize
    Launcher->>GPU: run kernels (produce final or intermediate tensors)
    GPU-->>Launcher: return Array<Tensor>
    Launcher-->>FFI: forward Array<Tensor>
    FFI-->>Py: return Array<Tensor> (caller may select [0] when finalized)
```
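To make the return contract in the diagram concrete, here is a minimal Python sketch of the caller-side behavior. `run_moe` and the plain-list values are illustrative stand-ins for the real flashinfer entry points and tensors, not the actual API:

```python
# Hypothetical stand-in for a trtllm_*_moe entry point; plain lists play the
# role of tensors so the arity contract is easy to see.
def run_moe(hidden_states, do_finalize=True):
    gemm2_output = [1.0, 2.0]              # un-finalized GEMM2 result (stand-in)
    expert_weights = [0.6, 0.4]            # per-token expert weights (stand-in)
    expanded_idx_to_permuted_idx = [0, 1]  # permutation indices (stand-in)
    if do_finalize:
        # Finalized path: a single-element list holding the final output.
        return [[w * x for w, x in zip(expert_weights, gemm2_output)]]
    # Non-finalized path: three intermediate tensors for custom post-processing.
    return [gemm2_output, expert_weights, expanded_idx_to_permuted_idx]

final = run_moe(None)[0]                       # caller selects [0] when finalized
intermediates = run_moe(None, do_finalize=False)
assert len(intermediates) == 3
```

Either way the wrapper returns a list; only its length depends on `do_finalize`.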
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 2 failed (2 warnings)
Summary of Changes

Hello @IwakuraRein, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
Code Review
This pull request introduces a do_finalize parameter to the trtllm-gen fp8 and bf16 MoE APIs, aligning them with the existing fp4 API and enabling more flexible post-processing. The changes span both the C++ CUDA kernels and the Python interface, including updates to function signatures, return types, and docstrings.
While the changes for the do_finalize=True path seem correct and are reflected in the updated tests, I've identified several critical issues in the implementation for the do_finalize=False path that will prevent the new feature from working as intended for FP8 MoE variants. Specifically, the do_finalize parameter is ignored in the C++ implementation for FP8, and there's a return value mismatch between the C++ and Python layers for FP8 block scale MoE that will cause a runtime error. I've also suggested adding test cases for the do_finalize=False scenario to help catch such issues.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (6)
csrc/trtllm_fused_moe_kernel_launcher.cu (3)
655-659: ⚠️ Potential issue | 🔴 Critical

`do_finalize` is hardcoded to `true`, overriding the caller's value.

`Fp8PerTensorLauncher::prepare_moe()` unconditionally sets `args->do_finalize = true` at line 659, which overwrites the `do_finalize` value that was set by the caller (e.g., at line 1618). This makes the newly added `do_finalize=False` path unreachable for FP8 per-tensor scale MoE, silently ignoring the user's request.

Proposed fix:

```diff
- args->do_finalize = true;  // FP8 per-tensor scale always finalizes
+ // args->do_finalize is set by the caller
```
926-930: ⚠️ Potential issue | 🔴 Critical

Same `do_finalize` override issue as in `Fp8PerTensorLauncher`.

`Fp8BlockScaleLauncher::prepare_moe()` hardcodes `args->do_finalize = true` at line 930, overriding the caller's value (set at line 1716). The new `do_finalize` parameter for FP8 block-scale MoE is silently ignored.

Proposed fix:

```diff
- args->do_finalize = true;
+ // args->do_finalize is set by the caller
```
984-988: ⚠️ Potential issue | 🔴 Critical

Return element count mismatch with Python caller when `do_finalize=False`.

`Fp8BlockScaleLauncher::run()` returns 2 elements `{gemm2_output, expanded_idx_to_permuted_idx}` when `!do_finalize`, but the Python caller at `core.py` lines 1721-1726 unpacks 3 elements:

```python
gemm2_output, _, expanded_idx_to_permuted_idx = intermediate_output
```

This will crash with a `ValueError` once the hardcoded `do_finalize=true` in `prepare_moe()` is removed. The return should include `expert_weights` to match the base class pattern and the Python side.

Proposed fix to align with base class and Python expectation:

```diff
 if (args->do_finalize) {
   return {output};
 }
-return {gemm2_output, expanded_idx_to_permuted_idx};
+return {gemm2_output, FusedMoeLauncher::expert_weights, expanded_idx_to_permuted_idx};
```

flashinfer/fused_moe/core.py (2)
1390-1414: ⚠️ Potential issue | 🟠 Major

Fake op always returns a single-element list, ignoring `do_finalize`.

When `do_finalize=False`, the real op returns a 3-element list. The fake op should mirror this behavior for `torch.compile` tracing to produce correct graph shapes. The same issue applies to `_fake_trtllm_fp8_per_tensor_scale_moe` (line 1569) and `_fake_trtllm_fp8_block_scale_moe` (line 1759).

Proposed fix for BF16 fake op (apply a similar pattern to the other fake ops):

```diff
 def _fake_trtllm_bf16_moe(
     ...
     do_finalize: bool = True,
     ...
 ) -> List[torch.Tensor]:
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
-    return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)]
+    if do_finalize:
+        return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)]
+    else:
+        return [
+            hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16),
+            hidden_states.new_empty([seq_len, top_k], dtype=torch.bfloat16),
+            hidden_states.new_empty([seq_len * top_k], dtype=torch.int32),
+        ]
```
2109-2113: ⚠️ Potential issue | 🟠 Major

`trtllm_mxint4_block_scale_moe_op` returns a bare `Tensor` instead of `List[Tensor]`.

The function signature declares `-> List[torch.Tensor]` (line 2017), and the fake op returns a list (line 2143), but the real implementation returns `output` (a bare tensor) at line 2113. This is inconsistent and could cause type errors downstream.

Proposed fix:

```diff
- return output
+ return [output]
```

tests/moe/test_trtllm_gen_routed_fused_moe.py (1)
339-341: ⚠️ Potential issue | 🔴 Critical

Missing `[0]` indexing on the `trtllm_fp8_block_scale_moe` reference output (likely runtime crash).

Line 341 calls `trtllm_fp8_block_scale_moe(...)` and directly chains `.to(torch.float)`, but the same function is now indexed with `[0]` in `test_trtllm_gen_fused_moe.py` (line 1039). If the return type was changed to `List[Tensor]`, calling `.to(torch.float)` on a list will raise an `AttributeError`.

🐛 Proposed fix:

```diff
- ).to(torch.float)
+ )[0].to(torch.float)
```
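The failure mode flagged here can be reproduced without the real kernels; a tiny stand-in class (not the real `torch.Tensor`) is enough to show why the `[0]` index is required once the API returns a list:

```python
class FakeTensor:
    """Minimal stand-in for torch.Tensor exposing only a .to() method."""
    def to(self, dtype):
        return self

result = [FakeTensor()]  # the API now returns a list of tensors

# Old call pattern: chaining .to(...) directly on the return value.
try:
    result.to("float")   # a Python list has no .to() method
    crashed = False
except AttributeError:
    crashed = True
assert crashed

# Fixed pattern: index the first element before converting.
converted = result[0].to("float")
assert isinstance(converted, FakeTensor)
```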
🤖 Fix all issues with AI agents
In `@flashinfer/fused_moe/core.py`:
- Around line 1718-1726: The FP8 block-scale path currently returns or unpacks
intermediate_output without initializing expert_weights and assumes
Fp8BlockScaleLauncher::run() yields three values; fix this by matching the C++
change so Fp8BlockScaleLauncher::run() returns (gemm2_output, expert_weights,
expanded_idx_to_permuted_idx) when !do_finalize, then in the Python path where
do_finalize is False unpack intermediate_output into gemm2_output,
expert_weights, expanded_idx_to_permuted_idx (instead of the current two/three
mismatch) and ensure expert_weights is the value coming from that unpack (not an
uninitialized variable) before returning the torch.from_dlpack(gemm2_output),
expert_weights, torch.from_dlpack(expanded_idx_to_permuted_idx).
- Around line 1532-1540: The FP8 per-tensor return path discards the
C++-produced expert_weights by unpacking intermediate_output as (gemm2_output,
_, expanded_idx_to_permuted_idx), leaving the Python-side expert_weights
uninitialized; update the unpack to capture the C++ expert_weights (e.g.,
extract intermediate_output[1]) and convert it to a PyTorch tensor (like the
other dlpack conversions) before returning so the returned expert_weights is the
initialized tensor from the C++ kernel; adjust the non-do_finalize branch
handling of intermediate_output, gemm2_output, expanded_idx_to_permuted_idx and
expert_weights accordingly in fused_moe.core (symbols: intermediate_output,
gemm2_output, expanded_idx_to_permuted_idx, expert_weights, do_finalize).
- Around line 1379-1388: The branch that handles do_finalize=False currently
discards the expert_weights returned by the native routing kernel and instead
returns the locally allocated empty expert_weights, causing callers to receive
uninitialized data; modify the unpacking of intermediate_output so it captures
the C++-returned expert_weights (i.e., change "gemm2_output, _,
expanded_idx_to_permuted_idx = intermediate_output" to unpack the second element
into expert_weights) and return that expert_weights (the one from
intermediate_output) along with torch.from_dlpack(gemm2_output) and
torch.from_dlpack(expanded_idx_to_permuted_idx), rather than the local
placeholder expert_weights.
In `@tests/moe/test_trtllm_gen_fused_moe.py`:
- Line 1039: The implementation of trtllm_mxint4_block_scale_moe_op returns a
single torch.Tensor but is annotated as List[torch.Tensor], causing inconsistent
handling with other _op functions; fix by making the function's return shape
consistent with the others—either change the implementation to return [output]
(wrap the tensor in a list) so it matches trtllm_fp8_block_scale_moe_op,
trtllm_fp8_per_tensor_scale_moe_op, trtllm_fp4_block_scale_moe_op, and
trtllm_bf16_moe_op, or update the function's return type annotation to
torch.Tensor if you prefer single-tensor returns, and then update any
calling/tests to use the consistent access pattern.
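The unpacking mismatch described in the comments above is easy to demonstrate in isolation. The functions below are stand-ins for the C++ launcher's return and the Python-side unpack, not the real flashinfer code:

```python
def launcher_run(do_finalize):
    # Stand-in for Fp8BlockScaleLauncher::run(): only 2 elements when not
    # finalized (the reported bug); the Python side expects 3.
    if do_finalize:
        return ["output"]
    return ["gemm2_output", "expanded_idx_to_permuted_idx"]

def python_wrapper(do_finalize):
    result = launcher_run(do_finalize)
    if do_finalize:
        return result[0]
    # The Python caller unpacks 3 elements, matching the base-class contract.
    gemm2_output, expert_weights, expanded_idx_to_permuted_idx = result
    return gemm2_output, expert_weights, expanded_idx_to_permuted_idx

try:
    python_wrapper(do_finalize=False)
    mismatch = False
except ValueError:  # not enough values to unpack (expected 3, got 2)
    mismatch = True
assert mismatch
```

Fixing the C++ side to return `{gemm2_output, expert_weights, expanded_idx_to_permuted_idx}` makes the 3-way unpack succeed.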
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
1444-1449: ⚠️ Potential issue | 🟡 Minor

Inconsistent return value when `do_finalize=true` across launchers.

`FP4BlockScaleLauncher::run` returns an empty array `{}` when `do_finalize=true` (line 1446), while the base class `FusedMoeLauncher::run` (line 370) and `Fp8BlockScaleLauncher::run` (line 983) return `{output}`. Since this PR's goal is to align BF16/FP8 APIs with FP4, the return contract should be consistent. Either all launchers should return the finalized output tensor, or all should return empty (with the caller holding a reference to the output).
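A small sketch of why the mismatched `do_finalize=true` contracts bite callers. The two functions are illustrative stand-ins for the FP4 and FP8 launchers, not the real C++ code:

```python
def fp4_run(do_finalize):
    # Stand-in: the FP4 launcher returns an empty array when finalized.
    return [] if do_finalize else ["gemm2", "weights", "idx"]

def fp8_run(do_finalize):
    # Stand-in: the FP8 launcher (and base class) return [output] when finalized.
    return ["output"] if do_finalize else ["gemm2", "weights", "idx"]

# A generic caller that uniformly takes result[0] works for FP8...
assert fp8_run(True)[0] == "output"

# ...but raises IndexError for FP4, so the contract should be unified.
try:
    fp4_run(True)[0]
    consistent = True
except IndexError:
    consistent = False
assert not consistent
```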
🤖 Fix all issues with AI agents
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 982-986: The derived class Fp8BlockScaleLauncher has a private
TensorView expert_weights that shadows the base
FusedMoeLauncher::expert_weights, so when prepare_routing() allocates/populates
the base-class expert_weights (via workspace.expert_weights) the current return
returns the stale derived member; fix by returning the base-class member
explicitly (e.g., qualify the return as FusedMoeLauncher::expert_weights) or
rename the derived member (e.g., expert_weights_input) and update all uses in
Fp8BlockScaleLauncher (constructor and any references) so the return and any
consumers receive the populated base-class Tensor; ensure prepare_routing() and
any routing/kernel code still write to the base-class expert_weights.
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/fused_moe/core.py (2)
1186-1209: ⚠️ Potential issue | 🟡 Minor

Pass `do_finalize` into the MXINT4 tuning run.

Other variants forward `do_finalize`, but the MXINT4 branch omits it, so tuning runs the finalized path even when callers request intermediate outputs.

🛠️ Suggested fix:

```diff
     kwargs["routed_scaling_factor"],
     kwargs["routing_method_type"],
+    kwargs["do_finalize"],
     kwargs["enable_pdl"],
     output,
     [-1, -1] if tactic == -1 else tactic,
```
1391-1414: ⚠️ Potential issue | 🟠 Major

Fake ops should mirror the `do_finalize` return arity.

All three fake ops ignore `do_finalize` and always return a single output tensor, which can break tracing/shape inference for non-finalized paths that now return 3 tensors.

🧪 Example fix (apply to all three fake ops):

```diff
-    return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)]
+    if do_finalize:
+        return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)]
+    else:
+        return [
+            hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16),
+            hidden_states.new_empty([seq_len, top_k], dtype=routing_logits.dtype),
+            hidden_states.new_empty([seq_len * top_k], dtype=torch.int32),
+        ]
```

Also applies to: 1543-1569, 1729-1759
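One way to guard against this class of bug is a small arity check run over both branches. `real_op` and `fake_op` below are illustrative stand-ins for a registered custom op and its fake (meta) implementation, not the actual flashinfer ops:

```python
def real_op(do_finalize=True):
    # Stand-in for the real kernel wrapper.
    return ["out"] if do_finalize else ["gemm2", "weights", "idx"]

def fake_op(do_finalize=True):
    # Buggy fake: always returns one element regardless of do_finalize.
    return ["out"]

def fake_matches_real(real, fake):
    # torch.compile traces the fake op, so its output arity (and shapes)
    # must mirror the real op on every branch.
    return all(
        len(real(do_finalize=f)) == len(fake(do_finalize=f)) for f in (True, False)
    )

assert not fake_matches_real(real_op, fake_op)  # the bug is caught
```

A check like this could run in unit tests alongside the `do_finalize=False` test cases suggested above.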
🤖 Fix all issues with AI agents
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1974-1979: The free function currently captures result =
selected_launcher->run(config, enable_pdl) then tries to return variables
(gemm2_output, expanded_idx_to_permuted_idx, output) that are protected members
of FusedMoeLauncher and out of scope; change the function to simply return
result directly (i.e. return the value from selected_launcher->run) for both
do_finalize branches so behavior matches
trtllm_bf16_moe/trtllm_fp8_*/trtllm_fp4_block_scale_moe; if the mxint4 path
truly needs to return {output} when finalized, move that special-case logic into
MxInt4BlockScaleLauncher::run() instead of here.
- Around line 1480-1489: Add a missing output TensorView parameter to
trtllm_bf16_moe (and likewise to trtllm_fp8_per_tensor_scale_moe) so the
function can return the final MoE result when do_finalize=true; update the
function signatures to accept TensorView output, pass that output into the
launcher by setting args->output = output (instead of only allocating inside
Bf16MoeLauncher::prepare_moe), and ensure the caller-visible return uses this
provided output (matching the trtllm_fp4_block_scale_moe /
trtllm_mxint4_block_scale_moe pattern).
In `@flashinfer/fused_moe/core.py`:
- Around line 2116-2124: The non-finalized branch returns a locally allocated,
uninitialized expert_weights; instead unpack and use the kernel-produced
expert_weights from intermediate_output. Replace the unpacking "gemm2_output,
expanded_idx_to_permuted_idx = intermediate_output" with "gemm2_output,
expert_weights, expanded_idx_to_permuted_idx = intermediate_output" and return
that expert_weights (not the local variable) alongside
torch.from_dlpack(gemm2_output) and
torch.from_dlpack(expanded_idx_to_permuted_idx) in the else branch.
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
```diff
     enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 8192,
- ) -> torch.Tensor:
+ ) -> List[torch.Tensor]:
```
So I'm a little worried about changing the output type. I wonder if there is a way we could conditionally preserve the old output for compatibility?
@aleozlx We could make it return the single tensor when do_finalize is False, but that would make fp8/bf16 inconsistent with the fp4.
aleozlx left a comment:

looks great overall, posted one comment
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
/bot run

[FAILED] Pipeline #44336875: 9/20 passed
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
/bot run

[FAILED] Pipeline #44402035: 9/20 passed
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
tests are clean, g/b300 timed out

/bot run

[FAILED] Pipeline #44572775: 14/20 passed
Fix unit test failures caused by change in #2548 to the API that now returns a list of tensors.
📌 Description
This PR adds the `do_finalize` parameter to the trtllm-gen fp8/f16 MoE APIs to align them with the fp4 APIs. This also allows flexible post-processing of MoE. Additionally, it fixes a bug where the output tensors were allocated twice for bf16/fp8 MoE.
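In caller terms, the change can be sketched like this. `moe_api_after` is an illustrative stand-in using plain lists instead of tensors, not the real signature:

```python
def moe_api_after(hidden_states, do_finalize=True):
    # Stand-in for the updated API: the return value is always a list now.
    if do_finalize:
        return [[3.0]]              # [final_output]
    return [[1.0], [0.5], [0]]      # [gemm2_output, expert_weights, permuted_idx]

# Before this PR, callers wrote:  output = moe_api(...)
# After this PR, callers write:
output = moe_api_after(None)[0]

# And the new non-finalized path enables custom post-processing:
gemm2_output, expert_weights, permuted_idx = moe_api_after(None, do_finalize=False)
assert output == [3.0]
```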
The API changes include:

- Added a `do_finalize` parameter.
- Return type changed from `torch.Tensor` to `List[torch.Tensor]`.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit

- New Features
- Refactor
- Tests
- Documentation