fix: support fp32 logits for fp8_per_tensor and fp8_block #2534
yweng0828 wants to merge 2 commits into flashinfer-ai:main
Conversation
Summary of Changes

Hello @yweng0828, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the fused MoE kernels to support FP32 logits, which is necessary for compatibility with certain models like DeepSeekV3. The changes involve modifications to the kernel launcher, runner, and test suite to accommodate the new data type, ensuring the MoE kernels can handle a wider range of models and configurations.
Activity
Code Review
The pull request introduces support for fp32 logits in the fused MoE kernels, specifically for fp8_per_tensor and fp8_block quantization modes. This is achieved by adding an mDtypeScore member to the FusedMoeLauncher and routingRenormalize::Data structs, and updating the routing_runner.run calls and kernel dispatch macros to utilize this new dtype. The routing_logits dtype validation logic in trtllm_fp8_per_tensor_scale_moe and trtllm_fp8_block_scale_moe is relaxed to allow float32 where appropriate, while still enforcing float32 for DeepSeekV3 routing. Corresponding test cases are updated to parameterize logits_dtype and include new skip conditions to ensure compatibility. The changes are consistent across the codebase and align with the stated goal of the pull request.
    workspace.token_scales = expert_weights.data_ptr();  // Consumed by permuteGemm1 kernel
  }
  if (routing_logits.has_value()) {
    mDtypeScore =
Should this piece of code be part of the FusedMoeLauncher class so that all child classes can share it? It seems that this logic is currently in the Fp8PerTensorLauncher class. Also, we might want to add an assertion to check the data type of routing_logits.
TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16)
<< "BF16 MoE: routing_logits must be bfloat16 or float.";
Thanks for pointing it out. I have refactored this part of the logic and moved it to the base class.
        kernel, numBlocks, numThreads, smemSize, stream);              \
    } else {                                                           \
-     FLASHINFER_WARN("Unsupported dtypeExpW");                        \
+     FLASHINFER_WARN("Unsupported mDtypeScore/mDtypeExpW combination"); \
How about using this info: "Unsupported combination of mDtypeScore and mDtypeExpW"?
Force-pushed from a62decc to 0c876d4
📝 Walkthrough

The PR introduces explicit score dtype tracking (mDtypeScore) across the fused MoE kernel launcher, routing data structures, and tests.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)
219-221: ⚠️ Potential issue | 🟡 Minor

Stale error message in LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG.

The else-branch error message still says "Unsupported dtypeExpW", but the macro now gates on mDtypeScore, mDtypeBias, and mDtypeExpW. Update it similarly to line 269 for consistency and easier debugging.

Proposed fix

  } else {                                                           \
-     FLASHINFER_WARN("Unsupported dtypeExpW");                      \
+     FLASHINFER_WARN("Unsupported combination of mDtypeScore, mDtypeBias, and mDtypeExpW"); \
  }

tests/moe/test_dpsk_fused_moe_fp8.py (1)
615-624: ⚠️ Potential issue | 🔴 Critical

Missing routing_method_type key in routing_config will cause KeyError in the updated skip_checks.

The routing_config dicts defined at lines 510–548 don't contain a "routing_method_type" key, but the new check in skip_checks (line 148 of utils.py) accesses routing_config["routing_method_type"] unconditionally. This will crash every test case in this file.

Either add "routing_method_type" to each routing config dict, or use .get() with a default in skip_checks:

Option 1: Fix in utils.py (safer — handles callers that don't set the key)

  if (
-     routing_config["routing_method_type"] == RoutingMethodType.DeepSeekV3
+     routing_config.get("routing_method_type") == RoutingMethodType.DeepSeekV3
      and logits_dtype != torch.float32
  ):

Option 2: Fix in this test file (add routing_method_type to each config)

For the DSv3 config:

  {
      "num_experts": 256,
      "top_k": 8,
+     "routing_method_type": RoutingMethodType.DeepSeekV3,
      ...
  },

And similarly for other configs with the appropriate RoutingMethodType.
tests/moe/test_trtllm_gen_fused_moe.py (1)

2883-2893: ⚠️ Potential issue | 🔴 Critical

Bug: logits_dtype and cache_permute_indices arguments are swapped.

The run_moe_test signature (line 2337) expects cache_permute_indices as the 8th positional arg and logits_dtype as the 9th. Here, they are passed in the opposite order. This will cause moe_impl._cache_permute_indices to be set to a torch.dtype and expert_logits.to(logits_dtype) to receive a dict, resulting in a runtime crash.

Compare with test_renormalize_routing (lines 2695–2696), test_topk_routing (lines 2975–2976), and test_llama4_routing (lines 3056–3057), which all pass the arguments in the correct order.

🐛 Proposed fix

  run_moe_test(
      num_tokens,
      hidden_size,
      intermediate_size,
      moe_impl,
      routing_config,
      weight_processing,
      activation_type,
-     logits_dtype,
      cache_permute_indices,
+     logits_dtype,
  )
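One way to make this class of bug impossible at the call site is to pass the trailing arguments by keyword. A toy sketch, with a deliberately simplified signature (the real run_moe_test takes more parameters):

```python
def run_moe_test(num_tokens, hidden_size, cache_permute_indices, logits_dtype):
    # Simplified stand-in: just report which slot each value landed in.
    return {"cache": cache_permute_indices, "dtype": logits_dtype}

# Keyword arguments are order-proof, so writing them in either order
# at the call site binds each value to the right parameter:
result = run_moe_test(4, 128, logits_dtype="float32", cache_permute_indices=[0, 2, 1])
assert result["cache"] == [0, 2, 1]
assert result["dtype"] == "float32"
```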
🤖 Fix all issues with AI agents
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 288-298: The code currently sets mDtypeScore based solely on
routing_method_type which forces BFloat16 for non-DeepSeekV3 even when
routing_logits are float32; update the block that runs when
routing_logits.has_value() so mDtypeScore is derived from
routing_logits.value().dtype(): if routing_method_type ==
RoutingMethodType::DeepSeekV3 keep the TVM_FFI_ICHECK_EQ asserting dtype is
dl_float32 and set mDtypeScore = btg::Dtype::Fp32; otherwise inspect
routing_logits.value().dtype() and set mDtypeScore = btg::Dtype::Fp32 for
dl_float32, btg::Dtype::Bfloat16 for dl_bfloat16 (and error/ICHECK for
unsupported dtypes). Reference symbols: mDtypeScore, routing_logits,
RoutingMethodType::DeepSeekV3.
In `@tests/moe/utils.py`:
- Around line 155-162: The condition incorrectly compares type(moe_impl) to
QuantMode enum values causing all FP32-logits tests to skip; change the check to
inspect moe_impl.quant_mode instead. Update the if-statement that currently
reads "if logits_dtype == torch.float32 and type(moe_impl) not in
[QuantMode...]" to use "moe_impl.quant_mode not in [QuantMode.FP8_PER_TENSOR,
QuantMode.FP8_BLOCK_SCALE, QuantMode.BF16]" so the pytest.skip call only
triggers for incompatible quant modes; keep the existing pytest.skip message and
variables (logits_dtype, moe_impl, QuantMode) unchanged.
🧹 Nitpick comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)
2865-2870: DeepSeekV3 routing is only parametrized with FP32 logits — intentional?

Unlike test_renormalize_routing and test_topk_routing, which test both FP32 and BF16, this test only exercises FP32 logits. If BF16 logits are also a valid input for DeepSeekV3 routing in production, consider adding BF16 coverage here too.
  // Set dtype of score
  if (routing_logits.has_value()) {
    if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
      TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
          << "routing_logits must be float.";
      mDtypeScore = btg::Dtype::Fp32;
    } else {
      mDtypeScore = btg::Dtype::Bfloat16;
    }
  }
mDtypeScore is derived from routing method, not from actual logits dtype — will misinterpret FP32 logits for non-DeepSeekV3 routing.
For non-DeepSeekV3 routing (e.g., Renormalize), mDtypeScore is unconditionally set to Bfloat16 regardless of the actual routing_logits dtype. If a caller passes float32 routing logits with Renormalize routing, the routing kernel will read 32-bit floats as 16-bit bfloat16 values, producing garbage results.
The score dtype should be derived from the actual tensor dtype, with DeepSeekV3 adding its own assertion on top:
Proposed fix
// Set dtype of score
if (routing_logits.has_value()) {
+ // Derive mDtypeScore from the actual routing_logits tensor dtype
+ if (routing_logits.value().dtype() == dl_float32) {
+ mDtypeScore = btg::Dtype::Fp32;
+ } else {
+ mDtypeScore = btg::Dtype::Bfloat16;
+ }
+
+ // DeepSeekV3 requires float32 routing logits
if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
<< "routing_logits must be float.";
- mDtypeScore = btg::Dtype::Fp32;
- } else {
- mDtypeScore = btg::Dtype::Bfloat16;
}
}🤖 Prompt for AI Agents
In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 288 - 298, The code
currently sets mDtypeScore based solely on routing_method_type which forces
BFloat16 for non-DeepSeekV3 even when routing_logits are float32; update the
block that runs when routing_logits.has_value() so mDtypeScore is derived from
routing_logits.value().dtype(): if routing_method_type ==
RoutingMethodType::DeepSeekV3 keep the TVM_FFI_ICHECK_EQ asserting dtype is
dl_float32 and set mDtypeScore = btg::Dtype::Fp32; otherwise inspect
routing_logits.value().dtype() and set mDtypeScore = btg::Dtype::Fp32 for
dl_float32, btg::Dtype::Bfloat16 for dl_bfloat16 (and error/ICHECK for
unsupported dtypes). Reference symbols: mDtypeScore, routing_logits,
RoutingMethodType::DeepSeekV3.
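The logic of the proposed fix can be mirrored in a small Python sketch for illustration. Dtype and the string dtype names here are stand-ins for btg::Dtype and the DLPack dtype codes, not the real types:

```python
from enum import Enum

# Stand-in for btg::Dtype.
class Dtype(Enum):
    Fp32 = 1
    Bfloat16 = 2

def derive_score_dtype(logits_dtype, is_deepseek_v3):
    # Derive the score dtype from the actual tensor dtype first...
    if logits_dtype == "float32":
        score = Dtype.Fp32
    elif logits_dtype == "bfloat16":
        score = Dtype.Bfloat16
    else:
        raise ValueError(f"unsupported routing_logits dtype: {logits_dtype}")
    # ...then let DeepSeekV3 assert its stricter requirement on top.
    if is_deepseek_v3 and logits_dtype != "float32":
        raise ValueError("routing_logits must be float.")
    return score

# FP32 logits with non-DeepSeekV3 routing now keep their real dtype
# instead of being forced to Bfloat16:
assert derive_score_dtype("float32", is_deepseek_v3=False) is Dtype.Fp32
assert derive_score_dtype("bfloat16", is_deepseek_v3=False) is Dtype.Bfloat16
```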
  if logits_dtype == torch.float32 and type(moe_impl) not in [
      QuantMode.FP8_PER_TENSOR,
      QuantMode.FP8_BLOCK_SCALE,
      QuantMode.BF16,
  ]:
      pytest.skip(
          f"Incompatible: logits_dtype={logits_dtype} with {type(moe_impl).__name__} + {moe_impl.quant_mode}"
      )
Bug: type(moe_impl) is compared against QuantMode enum values — condition always skips.
type(moe_impl) returns the class (e.g., FP8BlockScaleMoe), not a QuantMode enum value. This comparison will never match, so all FP32 logits tests will be unconditionally skipped regardless of quant mode, silently defeating the purpose of this PR's test coverage.
The check should compare moe_impl.quant_mode instead, which is what the error message already references:
Proposed fix
- if logits_dtype == torch.float32 and type(moe_impl) not in [
+ if logits_dtype == torch.float32 and moe_impl.quant_mode not in [
QuantMode.FP8_PER_TENSOR,
QuantMode.FP8_BLOCK_SCALE,
QuantMode.BF16,
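The pitfall is easy to reproduce in isolation. In this sketch, QuantMode and FP8BlockScaleMoe are toy stand-ins for the test suite's classes:

```python
from enum import Enum

class QuantMode(Enum):
    FP8_PER_TENSOR = 1
    FP8_BLOCK_SCALE = 2
    BF16 = 3

class FP8BlockScaleMoe:
    quant_mode = QuantMode.FP8_BLOCK_SCALE

impl = FP8BlockScaleMoe()
# type(impl) is the class object, never a QuantMode member, so the
# original membership test is unconditionally True (always skips):
assert type(impl) not in [QuantMode.FP8_PER_TENSOR, QuantMode.FP8_BLOCK_SCALE, QuantMode.BF16]
# Comparing the attribute the skip message already prints behaves as intended:
assert impl.quant_mode in [QuantMode.FP8_PER_TENSOR, QuantMode.FP8_BLOCK_SCALE, QuantMode.BF16]
```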
/bot run

@yweng0828 is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

/bot run

hi @yweng0828 thx for the contrib

[SUCCESS] Pipeline #44270124: 16/20 passed

Hi @aleozlx, thank you for your review. The PR is ready. Local testing has also passed.

@yweng0828 Does the change also apply to
📌 Description
This PR adds support for FP32 logits for routing when using fp8_per_tensor and fp8_block quantization. It introduces mDtypeScore alongside mDtypeExpW and adds more template instantiations.
#2469
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed, and all tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Bug Fixes
Tests