
fix: support fp32 logits for fp8_per_tensor and fp8_block #2534

Open
yweng0828 wants to merge 2 commits into flashinfer-ai:main from yweng0828:yweng/add_fp32_logits_for_fp8_routing

Conversation


@yweng0828 yweng0828 commented Feb 10, 2026

📌 Description

This PR adds template instantiations to support FP32 logits for routing when using fp8_per_tensor and fp8_block quantization.

  • Differentiates between mDtypeScore and mDtypeExpW and adds the corresponding template instantiations.
  • Adds tests for different logits data types.
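To make the mDtypeScore/mDtypeExpW differentiation concrete, here is a hypothetical Python sketch of the dispatch the PR adds. The real code is C++ macro-based kernel instantiation in DevKernel.h; the enum values and kernel-name strings below are illustrative only.

```python
# Hypothetical sketch of the (mDtypeScore, mDtypeExpW) dispatch; the real
# implementation instantiates CUDA kernel templates via macros in DevKernel.h.
from enum import Enum


class Dtype(Enum):
    Fp32 = "fp32"
    Bfloat16 = "bf16"


# Each supported (score dtype, expert-weight dtype) pair maps to its own
# kernel instantiation; unsupported pairs are rejected up front.
SUPPORTED_COMBOS = {
    (Dtype.Fp32, Dtype.Fp32): "routing_kernel<float, float>",
    (Dtype.Fp32, Dtype.Bfloat16): "routing_kernel<float, __nv_bfloat16>",
    (Dtype.Bfloat16, Dtype.Bfloat16): "routing_kernel<__nv_bfloat16, __nv_bfloat16>",
}


def select_routing_kernel(dtype_score: Dtype, dtype_expw: Dtype) -> str:
    try:
        return SUPPORTED_COMBOS[(dtype_score, dtype_expw)]
    except KeyError:
        raise ValueError("Unsupported combination of mDtypeScore and mDtypeExpW")
```

Before this PR, only the score-dtype-equals-expert-weight-dtype cases were instantiated; the PR adds the mixed FP32-score combinations.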

🔍 Related Issues

#2469

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added support for configurable data types for routing scores (float32 and bfloat16).
    • Enhanced routing logits validation with runtime dtype checks.
  • Bug Fixes

    • Improved dtype consistency validation across MOE computation paths.
  • Tests

    • Extended test coverage to validate multiple data type combinations for routing logits.
    • Added parameterized testing for FP32 and BF16 logits across routing methods.

@gemini-code-assist
Contributor

Summary of Changes

Hello @yweng0828, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the fused MoE kernels to support FP32 logits, which is necessary for compatibility with certain models like DeepSeekV3. The changes involve modifications to the kernel launcher, runner, and test suite to accommodate the new data type. This ensures that the MoE kernels can handle a wider range of models and configurations.

Highlights

  • FP32 Logits Support: This PR introduces support for FP32 logits in FP8 per-tensor and FP8 block-scale fused MoE kernels, enabling compatibility with models like DeepSeekV3.
  • Code Modifications: The changes involve modifying the FusedMoeLauncher class to handle different data types for routing scores and updating kernel launch configurations to support FP32.
  • Test Updates: The PR includes updates to the test suite to incorporate FP32 logits testing, ensuring the new functionality works as expected.


Changelog
  • csrc/trtllm_fused_moe_kernel_launcher.cu
    • Added mDtypeScore member to FusedMoeLauncher class.
    • Modified kernel launch parameters to pass mDtypeScore.
    • Added logic to determine mDtypeScore based on routing_logits dtype.
  • csrc/trtllm_fused_moe_runner.cu
    • Modified run function to accept dtypeScore.
    • Assigned dtypeScore to routingData.mDtypeScore.
  • include/flashinfer/trtllm/fused_moe/DevKernel.h
    • Updated LAUNCH_ROUTING_WITH_NUM_EXPERTS macro to handle different combinations of mDtypeScore and mDtypeExpW.
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
    • Added mDtypeScore member to the Data struct.
  • include/flashinfer/trtllm/fused_moe/runner.h
    • Modified Runner::run to accept dtypeScore as a parameter.
  • tests/moe/test_dpsk_fused_moe_fp8.py
    • Added logits_dtype parameter to DPSKFusedMoEFp8 class.
  • tests/moe/test_trtllm_gen_fused_moe.py
    • Removed unnecessary dtype conversion for expert_logits.
    • Added logits_dtype parameter to run_moe_test function.
    • Modified run_moe_test to create expert_logits with the specified dtype.
    • Added logits_dtype parameterization to test functions.
  • tests/moe/utils.py
    • Added logits_dtype parameter to skip_checks function.
    • Added skip logic for incompatible logits_dtype and routing method/quant mode combinations.
Activity
  • The pull request introduces support for FP32 logits in FP8 fused MoE kernels.
  • The changes involve modifying the kernel launcher, runner, and test suite.
  • The test suite is updated to include FP32 logits testing.
  • Skip logic is added to the tests to handle incompatible configurations.
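The skip logic described above can be sketched as follows. This is an illustrative mirror of the conditions the PR adds to tests/moe/utils.py; the enum members and function name here are assumptions, not the exact library API.

```python
# Hypothetical mirror of the skip conditions added to tests/moe/utils.py:
# DeepSeekV3 routing requires FP32 logits, and FP32 logits are only wired up
# for certain quantization modes in this PR.
from enum import Enum, auto


class RoutingMethodType(Enum):
    DeepSeekV3 = auto()
    Renormalize = auto()
    TopK = auto()


class QuantMode(Enum):
    FP8_PER_TENSOR = auto()
    FP8_BLOCK_SCALE = auto()
    BF16 = auto()
    NVFP4 = auto()


def should_skip(logits_dtype: str, routing_method: RoutingMethodType,
                quant_mode: QuantMode) -> bool:
    # DeepSeekV3 routing requires float32 logits.
    if routing_method is RoutingMethodType.DeepSeekV3 and logits_dtype != "float32":
        return True
    # FP32 logits are only supported for these quant modes in this PR.
    if logits_dtype == "float32" and quant_mode not in (
        QuantMode.FP8_PER_TENSOR,
        QuantMode.FP8_BLOCK_SCALE,
        QuantMode.BF16,
    ):
        return True
    return False
```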

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces support for fp32 logits in the fused MoE kernels, specifically for fp8_per_tensor and fp8_block quantization modes. This is achieved by adding a mDtypeScore member to FusedMoeLauncher and routingRenormalize::Data structs, and updating the routing_runner.run calls and kernel dispatch macros to utilize this new dtype. The routing_logits dtype validation logic in trtllm_fp8_per_tensor_scale_moe and trtllm_fp8_block_scale_moe functions is relaxed to allow float32 where appropriate, while still enforcing float32 for DeepSeekV3 routing. Corresponding test cases are updated to parameterize logits_dtype and include new skip conditions to ensure compatibility. The changes are consistent across the codebase and align with the stated goal of the pull request.

workspace.token_scales = expert_weights.data_ptr(); // Consumed by permuteGemm1 kernel
}
if (routing_logits.has_value()) {
mDtypeScore =
Contributor


Should this piece of code be part of the FusedMoeLauncher class so that all child classes can share it? It seems that this logic is currently in the Fp8PerTensorLauncher class. Also, we might want to add an assertion to check the data type of routing_logits.

  TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16)
      << "BF16 MoE: routing_logits must be bfloat16 or float.";
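The suggested refactor can be illustrated with a Python sketch of the class shape (the real code is C++; class and method names below are assumptions for illustration): the shared dtype detection and validation move into the FusedMoeLauncher base class so every child launcher inherits them.

```python
# Illustrative shape of the suggested refactor: shared dtype detection lives
# in the base launcher so Fp8PerTensor/Fp8Block children inherit it unchanged.
class FusedMoeLauncher:
    def set_dtype_score(self, routing_logits_dtype: str) -> str:
        # Shared validation, mirroring the suggested TVM_FFI_ICHECK:
        # routing_logits must be bfloat16 or float32.
        if routing_logits_dtype not in ("float32", "bfloat16"):
            raise TypeError("routing_logits must be bfloat16 or float32")
        self.dtype_score = "Fp32" if routing_logits_dtype == "float32" else "Bfloat16"
        return self.dtype_score


class Fp8PerTensorLauncher(FusedMoeLauncher):
    pass  # inherits set_dtype_score from the base class


class Fp8BlockLauncher(FusedMoeLauncher):
    pass  # likewise shares the base-class logic
```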

Author


Thanks for pointing it out. I have refactored this part of the logic and moved it to the base class.

kernel, numBlocks, numThreads, smemSize, stream); \
} else { \
FLASHINFER_WARN("Unsupported dtypeExpW"); \
FLASHINFER_WARN("Unsupported mDtypeScore/mDtypeExpW combination"); \
Contributor


How about using this message: "Unsupported combination of mDtypeScore and mDtypeExpW"?

Author


Thanks, it has been updated.

@yweng0828 yweng0828 force-pushed the yweng/add_fp32_logits_for_fp8_routing branch from a62decc to 0c876d4 on February 12, 2026 at 07:17
@coderabbitai
Contributor

coderabbitai bot commented Feb 12, 2026

📝 Walkthrough

Walkthrough

The PR introduces explicit score dtype tracking (mDtypeScore) to the fused MoE routing system. A new parameter is added to track routing logits precision, defaulting to BF16 but forcing FP32 for DeepSeekV3. This parameter is threaded through launcher, runner, and kernel interfaces with accompanying validation logic and test parameterization.

Changes

Cohort / File(s) Summary
Launcher & Runner Core
csrc/trtllm_fused_moe_kernel_launcher.cu, csrc/trtllm_fused_moe_runner.cu
Added mDtypeScore field to launcher and updated routing invocations to pass score dtype before element dtype. Renamed check_routing_logits_shape() to check_routing_logits(). Introduced conditional logic forcing FP32 for DeepSeekV3 routing, defaulting to BF16 otherwise.
Kernel Interfaces
include/flashinfer/trtllm/fused_moe/runner.h, include/flashinfer/trtllm/fused_moe/RoutingKernel.h, include/flashinfer/trtllm/fused_moe/DevKernel.h
Updated Runner::run() signature to include dtypeScore parameter before dtypeElt. Added mDtypeScore member to routing Data struct. Modified kernel launch macros to condition on data.mDtypeScore alongside existing dtype checks.
Test Infrastructure
tests/moe/test_trtllm_gen_fused_moe.py, tests/moe/test_dpsk_fused_moe_fp8.py, tests/moe/utils.py
Introduced logits_dtype parameter across test harness to parameterize expert logits precision (FP32/BF16). Added test decorators for multi-precision execution. Implemented validation checks in utilities to ensure DeepSeekV3 uses FP32 logits and FP32 logits pair with supported quantization modes.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested labels

run-ci

Suggested reviewers

  • joker-eph
  • aleozlx
  • cyx-6
  • yzh119
  • djmmoss
  • wenscarl

Poem

🐰 A score of types now flows so clear,
With BF16 as default cheer,
But DeepSeekV3 calls for FP32's might,
The routing pipeline shines so bright,
Parameter threading left and right! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title 'fix: support fp32 logits for fp8_per_tensor and fp8_block' directly describes the main change: enabling FP32 logits support for specific quantization modes.
Description check ✅ Passed The description covers the main change (supporting FP32 logits), key technical details (differentiating mDtypeScore and mDtypeExpW), mentions related issue, and completes the required checklist items.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)

219-221: ⚠️ Potential issue | 🟡 Minor

Stale error message in LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG.

The else-branch error message still says "Unsupported dtypeExpW", but the macro now gates on mDtypeScore, mDtypeBias, and mDtypeExpW. Update it similar to line 269 for consistency and easier debugging.

Proposed fix
   } else {                                                                                        \
-    FLASHINFER_WARN("Unsupported dtypeExpW");                                                     \
+    FLASHINFER_WARN("Unsupported combination of mDtypeScore, mDtypeBias, and mDtypeExpW");        \
   }
tests/moe/test_dpsk_fused_moe_fp8.py (1)

615-624: ⚠️ Potential issue | 🔴 Critical

Missing routing_method_type key in routing_config will cause KeyError in the updated skip_checks.

The routing_config dicts defined at lines 510–548 don't contain a "routing_method_type" key, but the new check in skip_checks (line 148 of utils.py) accesses routing_config["routing_method_type"] unconditionally. This will crash every test case in this file.

Either add "routing_method_type" to each routing config dict, or use .get() with a default in skip_checks:

Option 1: Fix in utils.py (safer — handles callers that don't set the key)
     if (
-        routing_config["routing_method_type"] == RoutingMethodType.DeepSeekV3
+        routing_config.get("routing_method_type") == RoutingMethodType.DeepSeekV3
         and logits_dtype != torch.float32
     ):
Option 2: Fix in this test file (add routing_method_type to each config)

For the DSv3 config:

             {
                 "num_experts": 256,
                 "top_k": 8,
+                "routing_method_type": RoutingMethodType.DeepSeekV3,
                 ...
             },

And similarly for other configs with the appropriate RoutingMethodType.

tests/moe/test_trtllm_gen_fused_moe.py (1)

2883-2893: ⚠️ Potential issue | 🔴 Critical

Bug: logits_dtype and cache_permute_indices arguments are swapped.

The run_moe_test signature (line 2337) expects cache_permute_indices as the 8th positional arg and logits_dtype as the 9th. Here, they are passed in the opposite order. This will cause moe_impl._cache_permute_indices to be set to a torch.dtype and expert_logits.to(logits_dtype) to receive a dict, resulting in a runtime crash.

Compare with test_renormalize_routing (line 2695–2696), test_topk_routing (line 2975–2976), and test_llama4_routing (line 3056–3057), which all pass the arguments in the correct order.

🐛 Proposed fix
     run_moe_test(
         num_tokens,
         hidden_size,
         intermediate_size,
         moe_impl,
         routing_config,
         weight_processing,
         activation_type,
-        logits_dtype,
         cache_permute_indices,
+        logits_dtype,
     )
🤖 Fix all issues with AI agents
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 288-298: The code currently sets mDtypeScore based solely on
routing_method_type which forces BFloat16 for non-DeepSeekV3 even when
routing_logits are float32; update the block that runs when
routing_logits.has_value() so mDtypeScore is derived from
routing_logits.value().dtype(): if routing_method_type ==
RoutingMethodType::DeepSeekV3 keep the TVM_FFI_ICHECK_EQ asserting dtype is
dl_float32 and set mDtypeScore = btg::Dtype::Fp32; otherwise inspect
routing_logits.value().dtype() and set mDtypeScore = btg::Dtype::Fp32 for
dl_float32, btg::Dtype::Bfloat16 for dl_bfloat16 (and error/ICHECK for
unsupported dtypes). Reference symbols: mDtypeScore, routing_logits,
RoutingMethodType::DeepSeekV3.

In `@tests/moe/utils.py`:
- Around line 155-162: The condition incorrectly compares type(moe_impl) to
QuantMode enum values causing all FP32-logits tests to skip; change the check to
inspect moe_impl.quant_mode instead. Update the if-statement that currently
reads "if logits_dtype == torch.float32 and type(moe_impl) not in
[QuantMode...]" to use "moe_impl.quant_mode not in [QuantMode.FP8_PER_TENSOR,
QuantMode.FP8_BLOCK_SCALE, QuantMode.BF16]" so the pytest.skip call only
triggers for incompatible quant modes; keep the existing pytest.skip message and
variables (logits_dtype, moe_impl, QuantMode) unchanged.
🧹 Nitpick comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)

2865-2870: DeepSeekV3 routing is only parametrized with FP32 logits — intentional?

Unlike test_renormalize_routing and test_topk_routing which test both FP32 and BF16, this test only exercises FP32 logits. If BF16 logits are also a valid input for DeepSeekV3 routing in production, consider adding BF16 coverage here too.
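If BF16 coverage is added as suggested, the parameterization would follow the pattern already used by test_renormalize_routing. The sketch below is a minimal stand-in (the test name and body are hypothetical; the real test would build expert_logits with the requested dtype and invoke the DeepSeekV3 routing path).

```python
# Hypothetical pytest parameterization adding BF16 coverage alongside FP32,
# mirroring the pattern used by test_renormalize_routing/test_topk_routing.
import pytest


@pytest.mark.parametrize("logits_dtype", ["float32", "bfloat16"])
def test_deepseekv3_routing_logits_dtype(logits_dtype):
    # Placeholder body: the real test would create expert_logits with this
    # dtype and run the DeepSeekV3 routing kernel against a reference.
    assert logits_dtype in ("float32", "bfloat16")
```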

Comment on lines +288 to +298

// Set dtype of score
if (routing_logits.has_value()) {
if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
<< "routing_logits must be float.";
mDtypeScore = btg::Dtype::Fp32;
} else {
mDtypeScore = btg::Dtype::Bfloat16;
}
}
Contributor


⚠️ Potential issue | 🔴 Critical

mDtypeScore is derived from routing method, not from actual logits dtype — will misinterpret FP32 logits for non-DeepSeekV3 routing.

For non-DeepSeekV3 routing (e.g., Renormalize), mDtypeScore is unconditionally set to Bfloat16 regardless of the actual routing_logits dtype. If a caller passes float32 routing logits with Renormalize routing, the routing kernel will read 32-bit floats as 16-bit bfloat16 values, producing garbage results.

The score dtype should be derived from the actual tensor dtype, with DeepSeekV3 adding its own assertion on top:

Proposed fix
     // Set dtype of score
     if (routing_logits.has_value()) {
+      // Derive mDtypeScore from the actual routing_logits tensor dtype
+      if (routing_logits.value().dtype() == dl_float32) {
+        mDtypeScore = btg::Dtype::Fp32;
+      } else {
+        mDtypeScore = btg::Dtype::Bfloat16;
+      }
+
+      // DeepSeekV3 requires float32 routing logits
       if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
         TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
             << "routing_logits must be float.";
-        mDtypeScore = btg::Dtype::Fp32;
-      } else {
-        mDtypeScore = btg::Dtype::Bfloat16;
       }
     }

Comment on lines +155 to +162
if logits_dtype == torch.float32 and type(moe_impl) not in [
QuantMode.FP8_PER_TENSOR,
QuantMode.FP8_BLOCK_SCALE,
QuantMode.BF16,
]:
pytest.skip(
f"Incompatible: logits_dtype={logits_dtype} with {type(moe_impl).__name__} + {moe_impl.quant_mode}"
)
Contributor


⚠️ Potential issue | 🔴 Critical

Bug: type(moe_impl) is compared against QuantMode enum values — condition always skips.

type(moe_impl) returns the class (e.g., FP8BlockScaleMoe), not a QuantMode enum value. This comparison will never match, so all FP32 logits tests will be unconditionally skipped regardless of quant mode, silently defeating the purpose of this PR's test coverage.

The check should compare moe_impl.quant_mode instead, which is what the error message already references:

Proposed fix
-    if logits_dtype == torch.float32 and type(moe_impl) not in [
+    if logits_dtype == torch.float32 and moe_impl.quant_mode not in [
         QuantMode.FP8_PER_TENSOR,
         QuantMode.FP8_BLOCK_SCALE,
         QuantMode.BF16,
     ]:

@yweng0828
Author

/bot run

@flashinfer-bot
Collaborator

@yweng0828 is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@aleozlx
Collaborator

aleozlx commented Feb 18, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !321 has been created, and the CI pipeline #44270124 is currently running. I'll report back once the pipeline job completes.

@aleozlx
Collaborator

aleozlx commented Feb 18, 2026

Hi @yweng0828, thanks for the contribution. I'd like to sync with you:

  • Is this ready for review?
  • Are tests passing locally?

@aleozlx aleozlx self-assigned this Feb 18, 2026
@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #44270124: 16/20 passed

Collaborator

@aleozlx aleozlx left a comment


lgtm

tests clean, ready to merge

@aleozlx aleozlx added the run-ci label Feb 20, 2026
@aleozlx aleozlx enabled auto-merge (squash) February 20, 2026 19:20
@aleozlx aleozlx added the ready label Feb 20, 2026
@yweng0828
Author

Hi @aleozlx , thank you for your review. The PR is ready. Local testing has also passed.
Do I need to rebase to main? Or can we just merge it directly?

@wenscarl
Collaborator

wenscarl commented Mar 4, 2026

@yweng0828 Does the change also apply to trtllm_fp4_block_scale_moe?

