Feat: Trtllm-gen MxFP8 MoE integration #2505
Conversation
Signed-off-by: Siyuan Fu <[email protected]>
Walkthrough

Adds an FP8 quantization enum and MxFP8 support across Python, C++ launchers, benchmarks, and tests.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
Summary of Changes

This pull request enhances the TensorRT-LLM fused Mixture-of-Experts (MoE) implementation by integrating MxFP8 quantization. It provides a new, flexible FP8 quantization option alongside the existing DeepSeek FP8, allowing fine-grained control over mixed-precision computation. The changes span core kernel logic, benchmarking, and testing, ensuring the new quantization mode is supported and validated across the system.

Highlights
Code Review
This pull request integrates mxfp8 support into the trtllm fused MoE kernels. The changes are extensive, touching benchmark scripts, C++ kernel launchers, and Python bindings. The introduction of Fp8QuantizationType is a good refactoring that makes the code more extensible. The tests have also been updated to cover the new quantization modes.
My review focuses on improving code maintainability by reducing duplication in the benchmark scripts and C++ kernel launcher. I've also pointed out some leftover debugging code and minor issues that should be addressed before merging.
```python
    print(f"No autotune: {ms:.3f} ms; with autotune: {ms_tuned:.3f} ms")


def bench_trtllm_gen_fused_moe_autotuner_mxint4(
```
This function `bench_trtllm_gen_fused_moe_autotuner_mxint4` is very similar to `bench_trtllm_gen_fused_moe_autotuner_fp8` and `bench_trtllm_gen_fused_moe_autotuner_fp4`. To improve maintainability and reduce code duplication, consider refactoring these into a single generic benchmark function or a base class that accepts the quantization functions and the specific MoE kernel as parameters, centralizing the common benchmarking logic.
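A minimal sketch of the refactor this comment suggests. All names here (`bench_fused_moe_autotuner`, `quantize`, `run_moe`) are illustrative stand-ins, not the repository's API; the idea is simply to hoist the shared warmup/timing loop out of the per-quant-mode functions:

```python
import time
from typing import Callable, Sequence

def bench_fused_moe_autotuner(
    quantize: Callable[[Sequence[float]], Sequence[float]],
    run_moe: Callable[[Sequence[float]], Sequence[float]],
    inputs: Sequence[float],
    warmup: int = 2,
    iters: int = 5,
) -> float:
    """Quantize once, then time `run_moe`; returns mean latency in ms.

    Each quant mode (fp8, fp4, mxint4, ...) only supplies its own
    `quantize` and `run_moe` callables instead of duplicating this loop.
    """
    q_inputs = quantize(inputs)
    for _ in range(warmup):  # warm up caches / JIT before timing
        run_moe(q_inputs)
    start = time.perf_counter()
    for _ in range(iters):
        run_moe(q_inputs)
    return (time.perf_counter() - start) / iters * 1e3
```

In the real benchmark the callables would wrap the CUDA kernels and a device-side synchronize would be needed around the timed region; this sketch only shows the shape of the shared harness.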
```diff
   FusedMoeLauncher::check_moe_common();

   TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_float8_e4m3fn) << "hidden_states must be fp8.";
-  TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32)
-      << "hidden_states_scale must be float.";
-  TVM_FFI_ICHECK_EQ(hidden_states_scale.ndim(), 2) << "hidden_states_scale must be 2D.";
-  TVM_FFI_ICHECK_EQ(hidden_states_scale.size(0), hidden_states.size(1) / 128)
-      << "hidden_states_scale dim0 must match hidden_states dim1 / 128.";
-  TVM_FFI_ICHECK_EQ(hidden_states_scale.size(1), args->num_tokens)
-      << "hidden_states_scale dim1 must match num_tokens.";
+  if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+    TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32)
+        << "hidden_states_scale must be float.";
+    TVM_FFI_ICHECK_EQ(hidden_states_scale.ndim(), 2) << "hidden_states_scale must be 2D.";
+    TVM_FFI_ICHECK_EQ(hidden_states_scale.size(0), hidden_states.size(1) / 128)
+        << "hidden_states_scale dim0 must match hidden_states dim1 / 128.";
+    TVM_FFI_ICHECK_EQ(hidden_states_scale.size(1), args->num_tokens)
+        << "hidden_states_scale dim1 must match num_tokens.";
+  } else if (quantization_type == Fp8QuantizationType::MxFp8) {
+    TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_uint8);
+  }

   TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8.";
   TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8.";

-  TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32)
-      << "gemm1_weights_scale must be float.";
-  TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D.";
-  TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts)
-      << "gemm1_weights_scale has incorrect shape.";
-  TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
-      << "intermediate_size must be a multiple of 128.";
-  TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), 2 * args->intermediate_size / 128)
-      << "gemm1_weights_scale has incorrect shape.";
-  TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args->hidden_size / 128)
-      << "gemm1_weights_scale has incorrect shape.";
+  if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+    TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32)
+        << "gemm1_weights_scale must be float.";
+    TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D.";
+    TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts)
+        << "gemm1_weights_scale has incorrect shape.";
+    TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
+        << "intermediate_size must be a multiple of 128.";
+    TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), 2 * args->intermediate_size / 128)
+        << "gemm1_weights_scale has incorrect shape.";
+    TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args->hidden_size / 128)
+        << "gemm1_weights_scale has incorrect shape.";
+  } else if (quantization_type == Fp8QuantizationType::MxFp8) {
+    TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_uint8)
+        << "gemm1_weights_scale must be uint8.";
+  }

-  TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32)
-      << "gemm2_weights_scale must be float.";
-  TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D.";
-  TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts)
-      << "gemm2_weights_scale has incorrect shape.";
-  TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128)
-      << "gemm2_weights_scale has incorrect shape.";
-  TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128)
-      << "gemm2_weights_scale has incorrect shape.";
+  if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+    TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32)
+        << "gemm2_weights_scale must be float.";
+    TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D.";
+    TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts)
+        << "gemm2_weights_scale has incorrect shape.";
+    TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128)
+        << "gemm2_weights_scale has incorrect shape.";
+    TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128)
+        << "gemm2_weights_scale has incorrect shape.";
+  } else if (quantization_type == Fp8QuantizationType::MxFp8) {
+    TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_uint8)
+        << "gemm2_weights_scale must be uint8.";
+  }

   check_weights_shape("gemm1");
   check_weights_shape("gemm2");
-  TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
-      << "intermediate_size must be a multiple of 128.";
+  if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+    TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
+        << "intermediate_size must be a multiple of 128.";
+  }
 }
```
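The dtype rule that the checks above enforce can be condensed into a few lines. An illustrative Python sketch (the function name and string tags are ours, not FlashInfer API; the MxFP8 uint8 scales are the E8M0 shared exponents of the OCP MX format):

```python
def expected_hidden_states_scale_dtype(quantization_type: str) -> str:
    """Mirror of the launcher's dtype checks: DeepSeek FP8 carries float32
    per-128-block scales, while MxFP8 carries uint8 (E8M0) scale exponents."""
    if quantization_type == "DeepSeekFp8":
        return "float32"
    if quantization_type == "MxFp8":
        return "uint8"
    raise ValueError(f"unexpected quantization type: {quantization_type}")
```

The same rule applies to `gemm1_weights_scale` and `gemm2_weights_scale`, which is why the diff repeats the branch three times.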
Hi @IwakuraRein. Currently we use this in sglang, but it seems we are missing cubins for some dims. I built from source on this branch at commit 1dc688d. Context: we are building the sglang MXFP8 trtllm_moe runner along with the mm_mxfp8 flashinfer modelopt linear, so this would be quite useful. If it turns out my usage is wrong, that's user error, but even after inspecting the cubins it seems this shape should be available. Do you have any ideas? Should there be a tileSize=64 cubin?
@vincentzed Hi. There are tile size 64 cubins for MxFP8. I tried your problem shape and cannot reproduce the error. Could you try pulling the latest commit? 1dc688d won't compile due to a typo, so maybe flashinfer is using the old JIT cache.
Force-pushed 0adc056 to aae1719
aleozlx left a comment:

Looks good overall. Posted a comment about `GatedActType`.
Caution

Some comments are outside the diff and can't be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

1073-1099: ⚠️ Potential issue | 🟠 Major

`getValidConfigs` uses a different `Runner` constructor than `prepare_moe_common` for MxFp8, causing a config mismatch.

For MxFp8, `getValidConfigs` creates the runner using the 5-param weights-only constructor (lines 1085–1088):

`Runner(dtype_weights, /*useDeepSeekFp8=*/false, tile_N, use_shuffled_weight, weight_layout)`

But at runtime, `prepare_moe_common` (lines 329–331) uses the 7-param act+weights constructor, because the condition at line 323 checks for `E4m3` (not `MxE4m3`), which is false for MxFp8:

`Runner(mDtypeAct, mDtypeWeights, /*useDeepSeekFp8=*/false, tile_N, activation_type, ...)`

These constructors have different signatures and parameters (the 5-param variant lacks `activationType`), so they may enumerate different kernel configs. This causes valid configs from autotuning to be rejected at runtime, potentially explaining the "No kernel found" errors for MxFp8 shapes.
🧹 Nitpick comments (2)

benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)

89-89: Nit: `scale_vec_size` is unused in the MxFP8 path.

When `quant_mode == "MxFP8xMxFP8"`, `scale_vec_size` is assigned 32 on this line but never referenced (it is only consumed inside the Fp8-Block branch). Consider moving the assignment into the `if quant_mode == "Fp8-Block"` block.

♻️ Suggested diff

```diff
-    scale_vec_size = 128 if quant_mode == "Fp8-Block" else 32
     if quant_mode == "Fp8-Block":
+        scale_vec_size = 128
         # block scale quantization is too slow, so we use per-tensor quantization for now
```

csrc/trtllm_fused_moe_kernel_launcher.cu (1)
44-63: C++ enum has `PerTensorFp8`, which is not present in the Python `Fp8QuantizationType`.

The Python enum in `flashinfer/fused_moe/core.py` defines `NoneFp8 = 0`, `DeepSeekFp8 = 1`, `MxFp8 = 2`, but the C++ side adds `PerTensorFp8 = 3`. If this variant isn't meant to be used from Python, consider adding a comment. Also, the `default` label in `fp8QuantizationTypeToString` falls through to `NoneFp8`, which silently masks unexpected values rather than flagging them.

Suggested: make the default case explicit

```diff
   switch (quantization_type) {
-    default:
-    case Fp8QuantizationType::NoneFp8:
+    case Fp8QuantizationType::NoneFp8:
       return "NoneFp8";
     case Fp8QuantizationType::DeepSeekFp8:
       return "DeepSeekFp8";
     case Fp8QuantizationType::MxFp8:
       return "MxFp8";
     case Fp8QuantizationType::PerTensorFp8:
       return "PerTensorFp8";
+    default:
+      return "Unknown(" + std::to_string(static_cast<int>(quantization_type)) + ")";
   }
```
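For comparison, a Python-side sketch of the mirrored enum with an explicit unknown case, in the same spirit as the C++ suggestion. This is illustrative only: the real Python enum lives in `flashinfer/fused_moe/core.py`, and `PerTensorFp8` is currently C++-only:

```python
from enum import IntEnum

class Fp8QuantizationType(IntEnum):
    NoneFp8 = 0
    DeepSeekFp8 = 1
    MxFp8 = 2
    PerTensorFp8 = 3  # C++-only today; mirrored here for illustration

def fp8_quantization_type_to_string(value: int) -> str:
    """Like the C++ helper, but flags unexpected values instead of
    silently mapping them to NoneFp8."""
    try:
        return Fp8QuantizationType(value).name
    except ValueError:
        return f"Unknown({value})"
```

An `IntEnum` keeps the numeric values interoperable with the C++ side while making unknown integers loud rather than silent.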
/bot run
Actionable comments posted: 2

Caution

Some comments are outside the diff and can't be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)

csrc/trtllm_fused_moe_kernel_launcher.cu (2)
1090-1101: ⚠️ Potential issue | 🔴 Critical

`getValidConfigs` uses the weights-only `Runner` constructor, but the MxFp8 runtime uses the two-dtype constructor, causing a config-index mismatch.

For MxFp8 (where `dtype_act == MxE4m3` and `dtype_weights == MxE4m3`), `getValidConfigs` at line 1091 creates the `Runner` with 5 parameters: `Runner(dtype_weights, useDeepSeekFp8, tile_N, use_shuffled_weight, weight_layout)`. However, in `prepare_moe_common` (lines 333–335), the same MxFp8 scenario matches the else branch (the condition at line 327 checks for E4m3, not MxE4m3), calling a different 7-parameter constructor: `Runner(dtype_act, dtype_weights, useDeepSeekFp8, tile_tokens_dim, activation_type, use_shuffled_weight, weight_layout)`. Different constructors produce different valid config indices, so the autotuner may select a config that the runtime runner rejects, causing "No kernel found" errors.
1020-1022: ⚠️ Potential issue | 🟡 Minor

Remove the unnecessary `static_cast<float*>` on lines 1020–1022.

The `args->hidden_states_scale`, `args->gemm1_weights_scale`, and `args->gemm2_weights_scale` fields in `MoERunnerArgs` are typed as `void*`, not `float*`. In the MxFp8 case, these hold `dl_uint8` tensor pointers, so casting to `float*` is both unnecessary and misleading. Other code paths (e.g., lines 1180, 1189, 1419, 1430) assign these same fields without casting. Remove the casts and assign `data_ptr()` directly.
🤖 Fix all issues with AI agents

In `csrc/trtllm_fused_moe_kernel_launcher.cu`:

- Around lines 179-180: In `check_routing_logits_shape()`, remove the unused local declaration `int64_t intermediate_size_factor = isGatedActivation(activation_type) ? 2 : 1;` that shadows the class member `intermediate_size_factor`, or change the usage to reference the member instead, so the dead/shadowing local is eliminated.
- Around lines 987-991: The MxFp8 branch under-allocates `gemm1_output_scale` by using `args->intermediate_size / 32` without accounting for `intermediate_size_factor` (causing under-allocation for gated activations). Update the `computeSwizzledLayoutSFSize` call in the `Fp8QuantizationType::MxFp8` branch to use `(intermediate_size_factor * args->intermediate_size) / 32`, i.e., the full swizzled width consistent with the `gemm1_output` allocation, so `gemm1_output_scale` matches the actual `gemm1_output` width. References: `gemm1_output_scale`, `computeSwizzledLayoutSFSize`, `max_num_padded_tokens_gemm1`, `args->intermediate_size`, `intermediate_size_factor`.
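The under-allocation in the second item comes down to simple arithmetic: gated activations (e.g. SwiGLU) double the gemm1 output width, and MxFP8 needs one uint8 scale per 32 elements of that full width. A hedged sketch of the intended count, ignoring any swizzled-layout padding that `computeSwizzledLayoutSFSize` may add, with a function name of our own choosing:

```python
def gemm1_output_scale_count(intermediate_size: int,
                             gated_activation: bool,
                             num_padded_tokens: int,
                             sf_vec_size: int = 32) -> int:
    """Scale factors needed for the MxFP8 gemm1 output: one uint8 scale
    per `sf_vec_size` elements of the (possibly doubled) row width."""
    width_factor = 2 if gated_activation else 1  # the intermediate_size_factor
    width = width_factor * intermediate_size
    assert width % sf_vec_size == 0
    return num_padded_tokens * (width // sf_vec_size)
```

Using `intermediate_size / 32` alone is the non-gated count; for a gated activation it yields exactly half the scales the gemm1 output actually needs.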
/bot run

[CANCELING] Pipeline #43998281: canceled
Caution

Some comments are outside the diff and can't be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)

csrc/trtllm_fused_moe_kernel_launcher.cu (1)
1079-1105: ⚠️ Potential issue | 🔴 Critical

`getValidConfigs` uses the wrong `Runner` constructor for MxFp8, causing a config mismatch with the runtime.

For MxFp8, `prepare_moe_common` (lines 326–335) constructs the `Runner` with the two-dtype constructor (passing `mDtypeAct`, `mDtypeWeights`, `activation_type`) when the condition `E4m3 && E4m3 && mUseDeepSeekFp8` is false. However, `getValidConfigs` always uses the weights-only constructor (lines 1091–1094), regardless of `quantization_type`. This means config enumeration and the actual kernel runner see different valid config sets, which is the root cause of "No kernel found" errors at runtime.

Proposed fix: branch `getValidConfigs` to match the `prepare_moe_common` logic

```diff
   for (int32_t tile_N : selected_tile_nums) {
-    auto moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
-        dtype_weights,  // dtype_weights for DeepSeek FP8
-        quantization_type == Fp8QuantizationType::DeepSeekFp8,  // useDeepSeekFp8
-        tile_N, use_shuffled_weight, static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
+    std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner> moe_runner;
+    if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+      moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
+          dtype_weights, true /* useDeepSeekFp8 */, tile_N, use_shuffled_weight,
+          static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
+    } else {
+      // MxFp8: match the two-dtype constructor used by prepare_moe_common
+      moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
+          dtype_weights, dtype_weights, false /* useDeepSeekFp8 */, tile_N,
+          ActivationType::Swiglu, use_shuffled_weight,
+          static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
+    }
     auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size,
                                                   num_local_experts, num_tokens);
```
🧹 Nitpick comments (2)

csrc/trtllm_fused_moe_kernel_launcher.cu (2)

1004-1012: MxFp8 path does not explicitly set `workspace.activation_output` / `workspace.activation_output_scale`.

Only the DeepSeekFp8 branch (lines 1007–1010) assigns these workspace pointers. The MxFp8 path relies on implicit zero-initialization. Consider explicitly setting them to `nullptr` to be safe against future refactors where `prepare_moe` might be re-entered or the workspace partially reused.

Proposed fix

```diff
   if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
     workspace.activation_output = activation_output.data_ptr();
     workspace.activation_output_scale =
         static_cast<float*>(activation_output_scale.data_ptr());
+  } else {
+    workspace.activation_output = nullptr;
+    workspace.activation_output_scale = nullptr;
   }
```
1006: `static_cast<float*>` on a `dl_uint8` tensor for MxFp8 is a type mismatch in the workspace pointer.

For MxFp8, `gemm1_output_scale` is allocated as `dl_uint8` (line 990), but line 1006 unconditionally casts it to `float*`. The kernel likely consumes only the raw address, but the cast is misleading and could mask bugs if the workspace struct gains type safety. Consider a `void*` intermediate or a comment noting the intentional reinterpretation.
Force-pushed 3e0dbdd to 03cac02
/bot run

[FAILED] Pipeline #44028049: 14/20 passed
Hey @IwakuraRein, we want to use it with Nemotron models:
Hi @danisereb, currently the cubins for Relu2 are not generated yet. We can add them in another PR.
📌 Description
Author: @nekorobov
Add the trtllm-gen MxFP8 MoE. It uses the existing `trtllm_fp8_block_scale_moe` API and can be selected by setting `fp8_quantization_type`.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Added or updated tests as needed (e.g., `unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit

- New Features
- Refactor
- Bug Fixes
- Tests
- Chores