feat: MxInt4 x Bf16 TRT-LLM Gen MoE support #2159
base: main
Conversation
Signed-off-by: Siyuan Fu <[email protected]>
Signed-off-by: Siyuan Fu <[email protected]>
Signed-off-by: Siyuan Fu <[email protected]>
Signed-off-by: Nikita Korobov <[email protected]>
Signed-off-by: Nikita Korobov <[email protected]>
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Templates were added to block-scale interleave kernels and host helpers to support uint8 and bfloat16; MXInt4 block-scale MoE support (launcher, config discovery, entry point, tests) was introduced; GEMM interfaces now use tg::CudaArch and include an EltwiseActType; artifact checksums were updated.

Changes
Sequence Diagram(s)

sequenceDiagram
participant Client
participant PyAPI as trtllm_mxint4_block_scale_moe
participant LauncherMgr as MxInt4BlockScaleLauncher (config)
participant Router as Routing/Prepare
participant Kernel as MXInt4 MoE Kernel
Client->>PyAPI: call trtllm_mxint4_block_scale_moe(...)
PyAPI->>LauncherMgr: getValidConfigs(top_k, hidden_size, ...)
LauncherMgr-->>PyAPI: return valid tile configs
PyAPI->>LauncherMgr: instantiate per-tile launchers & init args
PyAPI->>Router: prepare_routing()
Router-->>Router: compute assignments / expert offsets
PyAPI->>LauncherMgr: select launcher by config_index
LauncherMgr->>Kernel: launch selected MXInt4 kernel (MXInt4 weights, BF16 activations)
Kernel-->>LauncherMgr: output tensor
LauncherMgr-->>PyAPI: return results
PyAPI-->>Client: deliver MoE output
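For orientation, the entry point shown in the diagram is exported from the Python package (see flashinfer/fused_moe/__init__.py in this PR). A minimal sketch of locating it (actually invoking it requires MxInt4-packed uint8 weights and BF16 block scales, as the review below details):

```python
# Assumes a flashinfer build that includes this PR's MXInt4 kernels.
import inspect
from flashinfer.fused_moe import trtllm_mxint4_block_scale_moe

# Print the Python-level signature; constructing real inputs requires the
# MxInt4 weight packing/shuffling shown in the tests, so it is not done here.
print(inspect.signature(trtllm_mxint4_block_scale_moe))
```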
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (1)
🧰 Additional context used

🪛 Ruff (0.14.7)

flashinfer/fused_moe/core.py: ARG001 (unused function argument) reported at lines 2011-2012 and 2014-2032.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
🔇 Additional comments (2)
Summary of Changes

Hello @nekorobov, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly expands the quantization capabilities of the TRT-LLM Mixture-of-Experts (MoE) implementation by adding comprehensive support for MxInt4 x Bf16 quantization. This enables more efficient processing of large language models by leveraging lower-precision data types for weights while maintaining Bfloat16 for activations. The changes span kernel implementations, data type definitions, Python bindings, and testing infrastructure, ensuring a robust and performant integration of the new quantization scheme.

Highlights
Code Review
This pull request introduces support for MxInt4 x Bf16 TRT-LLM Gen MoE, expanding the quantization capabilities of the system. The changes involve templatizing CUDA kernels and host-side functions to handle __nv_bfloat16 types for block scale interleaving, updating Python bindings and test infrastructure to integrate the new MxInt4 mode, and refactoring internal batched_gemm components to use a more generalized CudaArch enum. Overall, the implementation appears consistent with the existing codebase, but I've identified a few areas for improvement regarding documentation clarity and a hardcoded constraint.
Actionable comments posted: 2
🧹 Nitpick comments (7)
flashinfer/utils.py (1)
786-803: Consider adding a docstring. The function lacks documentation explaining:

- The purpose and meaning of num_elts_per_sf
- Why specific dtypes (uint8, bfloat16) are required
- The relationship between epilogue_tile_m and the shuffle block size

Adding a docstring would improve maintainability, especially given the expanding type support; a sketch follows below.
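One possible shape for such a docstring is sketched here. The function name below is hypothetical, and the parameter descriptions only restate what this thread has established (uint8/bfloat16 inputs, 16- or 32-element scale blocks), so the wording should be adapted to the actual helper in flashinfer/utils.py:

```python
def _interleave_block_scales(input_tensor, epilogue_tile_m, num_elts_per_sf):  # hypothetical name
    """Interleave block scale factors into the layout expected by the TRT-LLM Gen epilogue.

    Args:
        input_tensor: Scale-factor tensor; uint8 for FP4-style scales or bfloat16 for
            MxInt4 scales (the only dtypes accepted by the underlying kernel).
        epilogue_tile_m: Height of the epilogue tile; it determines the shuffle block
            size used when interleaving rows of scale factors.
        num_elts_per_sf: Number of weight elements covered by one scale factor
            (16 for NVFP4-style scales, 32 for MX block formats).
    """
```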
csrc/nv_internal/cpp/kernels/quantization.cu (1)
250-250: Verify zero-initialization for __nv_bfloat16.

T sf = 0; may not work correctly for __nv_bfloat16 since it's a class type. Consider using explicit initialization:

    - T sf = 0;
    + T sf = T{};

This ensures proper default construction for both uint8_t and __nv_bfloat16.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (1)
186-205: LDG/LdgPlusSts K-tiling check has an unreachable branch

Given the invariant at Lines 262-268 that any mRouteSfsImpl != mRouteImpl must be Ldgsts/LdgPlusSts with mRouteImpl == Tma, the condition

    if (doesRouteImplUseLdgsts(options.mRouteImpl) && doesRouteImplUseLdgPlusSts(options.mRouteSfsImpl.value()))

can never be true (it requires mRouteImpl == Ldgsts and simultaneously mRouteImpl == Tma). The effective K-tiling guard is the second block that checks mRouteSfsImpl for Ldgsts/LdgPlusSts. You can safely drop or simplify the first if block to avoid dead code and rely on the SF-routing check alone.

Also applies to: 348-361
flashinfer/fused_moe/core.py (3)
1883-2007: MXInt4 custom op integration is consistent; fake op needs unused-arg suppression

The trtllm_mxint4_block_scale_moe_op custom op follows the BF16/FP4 patterns: it builds a MoERunner with dtype_act=Bfloat16, dtype_weights=MxInt4, uses WeightLayout.BlockMajorK, autotunes via tune_max_num_tokens, and then calls into the C++ trtllm_mxint4_block_scale_moe launcher. That wiring looks correct.

For _fake_trtllm_mxint4_block_scale_moe, Ruff is correctly flagging many unused parameters. Since the signature must mirror the real op, consider adding a no-op usage line instead of renaming/removing parameters, e.g.:

    # Keep signature in sync with real op; arguments are unused in the fake path.
    _ = (
        routing_logits,
        routing_bias,
        gemm1_weights,
        gemm1_weights_scale,
        gemm1_alpha,
        gemm1_beta,
        gemm1_clamp_limit,
        gemm2_weights,
        gemm2_weights_scale,
        num_experts,
        top_k,
        n_group,
        topk_group,
        intermediate_size,
        local_expert_offset,
        local_num_experts,
        routed_scaling_factor,
        routing_method_type,
        enable_pdl,
        output,
        tune_max_num_tokens,
    )

This will satisfy ARG001 without impacting behavior.
2542-2631: MXInt4 high-level API: type hint and docstring are inconsistent with implementation

- The function is annotated as -> List[torch.Tensor] but returns the single Tensor produced by trtllm_mxint4_block_scale_moe_op, so the type hint should be torch.Tensor.
- The docstring text is still tailored to FP4/NVFP4 (mentions "packed fp4" weights and float8 scales). The underlying C++ MXInt4 launcher validates gemm*_weights as uint8 and gemm*_weights_scale as BF16, so those dtype descriptions should be updated to match the MXInt4 path.

Adjusting the return annotation and docstring to the actual MXInt4 semantics will avoid confusion for users of this API.
184-215: Consider extending is_trtllm_moe_supported to cover MXInt4 weights

is_trtllm_moe_supported currently whitelists BF16/E4m3/E2m1/MxE2m1 weights only. With the new BF16-act + MxInt4 path wired through MoERunner and the MXInt4 kernel launcher, any callsites that rely on this helper to gate "is this config supported?" will still treat MXInt4 as unsupported. If the intent is for MXInt4 to be usable via the generic TRT-LLM MoE dispatcher (not just the dedicated trtllm_mxint4_block_scale_moe wrapper), you likely want to add DtypeTrtllmGen.MxInt4 here with appropriate dtype_act constraints.

csrc/trtllm_fused_moe_kernel_launcher.cu (1)
941-1093: MxInt4 launcher wiring is sensible; consider adding the same shape/layout checks as other launchers

The new MxInt4BlockScaleLauncher correctly:

- Forces BF16 activations and MXInt4 weights in init.
- Sets routing dtypes and allocates BF16 expert_weights in prepare_routing.
- Validates gemm1_weights/gemm2_weights as uint8 and the corresponding scales as BF16 in check_moe.
- Populates MoERunnerArgs fields and allocates gemm1_output/gemm2_output with BF16 in prepare_moe, wiring the workspace pointers consistently with the other launchers.

One difference vs. BF16/FP8/FP4 launchers is that check_moe here does not call FusedMoeLauncher::check_moe_common() or check_weights_shape("gemm1"/"gemm2"), so hidden-state and weight shapes/layouts are not validated on the MxInt4 path. For consistency and safer failure modes, it would be good to reuse those common checks, e.g.:

    void check_moe() const override {
      FusedMoeLauncher::check_moe_common();
      check_weights_shape("gemm1");
      check_weights_shape("gemm2");
      TVM_FFI_ICHECK(mDtypeAct == btg::Dtype::Bfloat16)
          << "Only Bfloat16 is supported by MxInt4 block scale MoE";
      TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_uint8) << "gemm1_weights must be uint8.";
      TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_bfloat16)
          << "gemm1_weights_scale must be bf16.";
      TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_uint8) << "gemm2_weights must be uint8.";
      TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_bfloat16)
          << "gemm2_weights_scale must be bf16.";
    }

That keeps MXInt4 validation aligned with the existing MoE launchers.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (2)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.h is excluded by !**/gen/**
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h is excluded by !**/gen/**
📒 Files selected for processing (19)
- csrc/nv_internal/cpp/kernels/quantization.cu (2 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/quantization.h (1 hunks)
- csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (3 hunks)
- csrc/trtllm_fused_moe_kernel_launcher.cu (5 hunks)
- flashinfer/artifacts.py (2 hunks)
- flashinfer/fp4_quantization.py (2 hunks)
- flashinfer/fused_moe/__init__.py (2 hunks)
- flashinfer/fused_moe/core.py (4 hunks)
- flashinfer/utils.py (1 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (5 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (6 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h (1 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (4 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (12 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (6 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1 hunks)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (2 hunks)
- tests/moe/test_trtllm_gen_fused_moe.py (9 hunks)
- tests/moe/utils.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h
🧬 Code graph analysis (11)
tests/moe/utils.py (1)
  include/flashinfer/trtllm/fused_moe/runner.h (2)
    intermediate_size (275-275), hidden_size (265-265)

csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (2)
  include/flashinfer/trtllm/fused_moe/runner.h (1)
    num_experts (263-263)
  csrc/nv_internal/cpp/kernels/quantization.cu (4)
    invokeBlockScaleInterleave (292-302), invokeBlockScaleInterleave (292-293), invokeBlockScaleInterleave (305-307), invokeBlockScaleInterleave (308-312)

flashinfer/fp4_quantization.py (1)
  flashinfer/fp8_quantization.py (1)
    _compute_swizzled_layout_sf_size (15-18)

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (2)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.h (1)
    CudaArch (36-93)
  include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h (1)
    mSm (393-394)

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (5)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (2)
    gemm (30-297), buildSfTmaDescriptor (194-289)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (2)
    Dtype (43-274), dtypeNumEltsPerSf (201-213)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h (1)
    SfLayout (37-91)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h (1)
    ceilDiv (42-44)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h (1)
    MmaKind (36-107)

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (1)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
    Dtype (43-274)

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (3)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
    Dtype (43-274)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.h (1)
    CudaArch (36-93)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h (2)
    doesRouteImplUseLdgsts (45-45), doesRouteImplUseLdgPlusSts (53-53)

csrc/nv_internal/tensorrt_llm/kernels/quantization.h (1)
  csrc/nv_internal/cpp/kernels/quantization.cu (6)
    void (244-268), void (270-288), invokeBlockScaleInterleave (292-302), invokeBlockScaleInterleave (292-293), invokeBlockScaleInterleave (305-307), invokeBlockScaleInterleave (308-312)

csrc/trtllm_fused_moe_kernel_launcher.cu (3)
  include/flashinfer/trtllm/fused_moe/runner.h (9)
    maybeGetMinTokenCount (55-60), top_k (270-270), hidden_size (265-265), intermediate_size (275-275), num_experts (263-263), n_group (271-271), topk_group (273-273), local_expert_offset (276-276), local_num_experts (277-277)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
    dtypeGetNumBits (89-92)
  flashinfer/fused_moe/core.py (1)
    trtllm_mxint4_block_scale_moe (2543-2631)

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (3)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (2)
    trtllm (38-277), gen (39-275)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h (2)
    trtllm (30-110), gen (31-108)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h (3)
    trtllm (28-90), gen (29-89), launchKernel (34-84)

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (2)
  include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h (2)
    SmVersion (361-1321), mSm (393-394)
  include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.h (2)
    CudaArch (36-93), isArchBlackwell (51-54)
🪛 Ruff (0.14.7)
tests/moe/test_trtllm_gen_fused_moe.py
613-613: Unused method argument: hidden_states_sample
(ARG002)
643-643: Unused method argument: unused_args
(ARG002)
652-652: Unused method argument: args_dequant
(ARG002)
654-654: Unused method argument: gemm1_weights_orig
(ARG002)
655-655: Unused method argument: gemm2_weights_orig
(ARG002)
656-656: Unused method argument: hidden_size
(ARG002)
657-657: Unused method argument: intermediate_size
(ARG002)
659-659: Unused method argument: weight_processing
(ARG002)
751-751: Unused method argument: hidden_states_scale_global
(ARG002)
flashinfer/fused_moe/core.py
2010-2010: Unused function argument: routing_logits
(ARG001)
2011-2011: Unused function argument: routing_bias
(ARG001)
2013-2013: Unused function argument: gemm1_weights
(ARG001)
2014-2014: Unused function argument: gemm1_weights_scale
(ARG001)
2015-2015: Unused function argument: gemm1_alpha
(ARG001)
2016-2016: Unused function argument: gemm1_beta
(ARG001)
2017-2017: Unused function argument: gemm1_clamp_limit
(ARG001)
2018-2018: Unused function argument: gemm2_weights
(ARG001)
2019-2019: Unused function argument: gemm2_weights_scale
(ARG001)
2020-2020: Unused function argument: num_experts
(ARG001)
2021-2021: Unused function argument: top_k
(ARG001)
2022-2022: Unused function argument: n_group
(ARG001)
2023-2023: Unused function argument: topk_group
(ARG001)
2024-2024: Unused function argument: intermediate_size
(ARG001)
2025-2025: Unused function argument: local_expert_offset
(ARG001)
2026-2026: Unused function argument: local_num_experts
(ARG001)
2027-2027: Unused function argument: routed_scaling_factor
(ARG001)
2028-2028: Unused function argument: routing_method_type
(ARG001)
2029-2029: Unused function argument: enable_pdl
(ARG001)
2030-2030: Unused function argument: output
(ARG001)
2031-2031: Unused function argument: tune_max_num_tokens
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (43)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)
292-292: LGTM. Blank line insertion for readability after the comment block.

flashinfer/utils.py (1)
789-790: The review is based on code that does not exist in the repository.

The snippet shown in your review comment does not match the actual code at lines 789-790 in flashinfer/utils.py. The assertions are NOT expanded as claimed:

- Actual line 789: assert input_tensor.dtype == torch.uint8 (only uint8)
- Actual line 790: assert num_elts_per_sf == 16 (only allows 16)

Your review snippet claims these are expanded to torch.uint8 or torch.bfloat16 and 16 or 32, but the actual code is more restrictive. Additionally, the num_elts_per_sf parameter is not unused—it is actively passed by callers in flashinfer/fused_moe/core.py (lines 233-236, 260-263), though the current assertion rejects any value other than 16. Verify you are reviewing the correct file version or branch.
Likely an incorrect or invalid review comment.
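If bf16 scales with 32-element blocks are eventually meant to flow through this helper, the relaxed checks would presumably look like the sketch below (whether and where to relax the current assertions is a maintainer decision; the helper name is hypothetical):

```python
import torch

def _check_block_scale_args(input_tensor: torch.Tensor, num_elts_per_sf: int) -> None:
    # Hypothetical relaxation of the two assertions discussed above.
    assert input_tensor.dtype in (torch.uint8, torch.bfloat16), (
        f"unsupported scale dtype: {input_tensor.dtype}"
    )
    assert num_elts_per_sf in (16, 32), f"unsupported num_elts_per_sf: {num_elts_per_sf}"
```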
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h (1)
90-107: LGTM! Well-documented enum additions. The new EltwiseActType enum and the Persistent scheduler option are clearly documented and follow the existing code patterns in this file.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (2)

54-66: LGTM! Bfloat16 and MxInt4 dtype support added correctly. The addition of Bfloat16 handling and the inclusion of MxInt4 in the same branch as MxE2m1 are consistent with the type definitions in DtypeDecl.h, where both are 4-bit block formats.

198-206: Good defensive initialization practice. Using brace initialization for tmaDataFormat ensures it has a defined value before the conditional branches, which is a good defensive programming practice.

flashinfer/artifacts.py (2)
92-93: Artifact path updated for new batched GEMM version. The artifact path has been updated to reflect the new compiled kernels that include MXInt4 support.

113-113: Correct the artifact repository URL reference and verify the TRTLLM_GEN_BMM checksum against the NVIDIA Artifactory source. The checksum hash b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc should be verified against the checksums.txt file published at the NVIDIA Artifactory repository (https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/ccae3ed120a12a2c6922b458086b460413dbf731/batched_gemm-0d275a2-9936841/checksums.txt), not the S3 URL. The artifact path in the code is correct and points to the right location.

tests/moe/utils.py (2)
33-33: LGTM! New quantization mode added. The MXINT4_BF16_BF16 quantization mode is properly added to support the MXInt4 path introduced in this PR.

89-94: Appropriate skip condition for MXInt4 alignment requirements. The skip condition correctly enforces that both intermediate_size and hidden_size must be multiples of 256 for MXInt4 quantization, which aligns with the hardware requirements for this format.
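Expressed as code, the constraint amounts to roughly the following (the mode string and helper name are assumptions for illustration; tests/moe/utils.py holds the real skip logic):

```python
import pytest

def maybe_skip_mxint4(quant_mode: str, intermediate_size: int, hidden_size: int) -> None:
    # MxInt4 block-scale MoE requires both dims to be multiples of 256.
    if quant_mode == "MXINT4_BF16_BF16" and (
        intermediate_size % 256 != 0 or hidden_size % 256 != 0
    ):
        pytest.skip("MxInt4 requires intermediate_size and hidden_size to be multiples of 256")
```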
34-34: LGTM! Public API export for MXInt4 MoE.The new
trtllm_mxint4_block_scale_moefunction is properly exported, making it available as part of the public API.Also applies to: 58-58
flashinfer/fp4_quantization.py (2)
261-282: LGTM! Generalized dtype support for block scale interleave.The function now correctly supports both
uint8andbfloat16input dtypes, with the output dtype mirroring the input. This aligns with the templated C++ kernel implementation.
692-717: Proper dtype validation for generalized interleave function.The assertion correctly validates that the input tensor is either
uint8orbfloat16, with a clear error message when the constraint is violated.include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (4)
21-21: LGTM! Required header for ModuleCache.The
<unordered_map>header is needed for theModuleCachetype definition on line 459.
126-138: Documentation updated for MXInt4 support.The comments now correctly document that MxInt4 format uses
Dtype::Bfloat16for scaling factors, which aligns with the implementation changes in TmaDescriptor.h and the type definitions in DtypeDecl.h.
581-593: Improved PDL safety check and error handling.The changes introduce two improvements:
PDL safety: The
pdlSafeboolean correctly determines when PDL can be safely enabled based on grid wait conditions, providing more precise control than the previous implementation.Error propagation: Returning the actual CUDA error code (
result) instead of a hardcoded-1provides more useful diagnostic information to callers.
725-726: Simplified architecture parameter passing.Using
config.mSmdirectly instead of computing an intermediateisBlackwellvariable simplifies the code while maintaining the same functionality.csrc/nv_internal/tensorrt_llm/kernels/quantization.h (1)
70-72: LGTM! Template generalization for multi-dtype support.Converting
invokeBlockScaleInterleaveto a template function allows it to support multiple types (uint8_t and __nv_bfloat16) in a type-safe manner. This aligns with the explicit template instantiations in the implementation file (quantization.cu).include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (2)
122-123: LGTM: Interface update from boolean to explicit CUDA architecture enum.The parameter change from
bool isBlackwelltotg::CudaArch cudaArchimproves type safety and extensibility for supporting multiple GPU architectures (Hopper, Blackwell, Blackwell Ultra).
215-215: Verify default architecture choice.The default
tg::CudaArch::Sm100a(Blackwell) is used here. Confirm this is the intended default for new configs, especially if Hopper (Sm90a) support is still needed.tests/moe/test_trtllm_gen_fused_moe.py (5)
684-688: Inconsistentnum_elts_per_sfbetween gemm1 and gemm2.For gemm1 scales (line 688),
num_elts_per_sf=32is used, but for gemm2 scales (line 715),num_elts_per_sf=16is used. This asymmetry is unexpected for MxInt4 where the block size should be consistent.Verify this is intentional based on kernel requirements, or if this should be
32for both.Also applies to: 711-715
2475-2481: All other MoE implementations are commented out.Only
MxInt4BlockScaleMoeis enabled in the test parametrization while others (BF16, FP8, FP4 variants) are commented out. This appears to be for focused testing during development.Ensure all implementations are re-enabled before merging to maintain test coverage.
1940-1942: LGTM: Extended dequant path to handle MxInt4xBf16.The comment correctly lists all applicable modes where activation output uses bf16 without additional quantization.
591-607: The scales reshape is correct. Themxint4_quantizefunction returns scales with shape(-1, sf_vec_size), where each element represents one block's scale factor. The caller inprepare_static_weights_for_mxint4immediately reshapes these scales to match kernel expectations: for gemm1, shape becomes(num_experts, 2*intermediate_size//sf_vec_size, hidden_size//sf_vec_size), and for gemm2, shape becomes(num_experts, hidden_size//sf_vec_size, intermediate_size//sf_vec_size). This correctly distributes one scale value persf_vec_size-element block across the weight matrix dimensions, which is the expected format for block-scaled quantization kernels.Likely an incorrect or invalid review comment.
750-786: Thetrtllm_mxint4_block_scale_moekernel signature acceptsrouting_biasandrouted_scaling_factorparameters, but the Python binding is not yet imported and the test implementation is incomplete.The C++ kernel signature in
csrc/trtllm_fused_moe_kernel_launcher.cuincludesOptional<TensorView> routing_biasandOptional<double> routed_scaling_factorparameters. However, the corresponding Python binding is not imported in the test file, andMxInt4BlockScaleMoe.call_moe()is currently just a TODO stub. When this implementation is completed, ensure these parameters are extracted from kwargs and passed to the kernel, consistent with theFP8BlockScaleMoeimplementation.csrc/nv_internal/cpp/kernels/quantization.cu (2)
270-288: Reverse interleave kernel not templated.
block_scale_interleave_reverse_kernelandinvokeBlockScaleInterleaveReverseremain hardcoded foruint8_t. If bfloat16 reverse interleaving is needed in the future, this will require similar templating.This is acceptable if reverse is only used for uint8 scales currently.
Also applies to: 315-324
304-312: LGTM: Template instantiations added for both supported types.Explicit instantiations for
uint8_tand__nv_bfloat16ensure the templated launcher is available for both dtypes used by the host code.include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (3)
88-91: LGTM: Cleaner defaulting for valid dimensions.The ternary-based defaulting is more concise and readable than the previous explicit
-1checks.
417-426: LGTM: MxInt4 dtype correctly maps to Bfloat16 for scaling factors.The dtype mapping logic for A's scaling factors now handles:
E2m1→E4m3(FP8 scales)MxInt4→Bfloat16(BF16 scales for INT4 weights)- Other MX types →
UE8m0This aligns with the MxInt4 block-scale design where scales are stored as BF16.
410-411: Valid dimension propagation for routed activations.The changes correctly pass valid dimensions when constructing TMA shape/stride for activation matrices. For routed activations (
useRouteAct), usingoptions.mNumTokensas the valid dimension ensures proper bounds.Also applies to: 459-460, 518-520
csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (2)
186-187: LGTM: Dtype validation added for supported scale types.The check correctly validates that block scales are either uint8 or bfloat16 before proceeding.
186-187: LGTM: Dtype validation added for supported scale types. The check correctly validates that block scales are either uint8 or bfloat16 before proceeding.

205-227: LGTM: Dtype-conditional dispatch for CUDA and CPU paths. The branching correctly dispatches to the appropriate templated function based on dtype for both CUDA kernel invocation and host-side processing.
85-133: Eltwise activation threading into GemmOptions looks consistentThe new
eltwiseActTypeparameter is plumbed fromBatchedGemmOptionsintogemm::GemmOptionsin the same relative position as in the updatedGemmOptionsconstructor, so the wiring here looks correct.
378-394: Switching BatchedGemmConfig.mSm to tg::CudaArch aligns with new arch handlingChanging
mSmtotg::CudaArchis consistent with the new CudaArch‑based validation APIs and should help keep arch handling uniform across GEMM/BatchedGEMM paths.flashinfer/fused_moe/core.py (2)
104-127: Keep DtypeTrtllmGen MxInt4 / UInt encodings in strict sync with C++ DtypeDecl*The new
MxInt4,UE8m0, andUInt*/VoidUIDs look reasonable, but they must exactly match the bitfields and UID ordering intrtllm/gen/DtypeDecl.h. Any drift will silently mis‑decode dtypes at the C++ level. Please double‑check the C++ enum to confirm these Python encodings are identical.
1157-1184: MoERunner MXInt4 branch wiring looks coherentThe new MXInt4 path in
MoERunner.forward(BF16 activations +DtypeTrtllmGen.MxInt4weights) passes the same kwargs (gemm1_weights[_scale],gemm1_alpha/beta/clamp_limit,gemm2_weights[_scale], routing params,output, tactic) as the C++ launcher expects. This matches the MXInt4 launcher signature incsrc/trtllm_fused_moe_kernel_launcher.cu, so the dispatcher logic here looks consistent.csrc/trtllm_fused_moe_kernel_launcher.cu (3)
1785-1866: C++ MXInt4 entrypoint is consistent with Python bindings and launcherThe new
trtllm_mxint4_block_scale_moefunction:
- Enforces routing logits dtypes/shapes and disallows
routing_biasfor MXInt4 (matching the Python wrapper’srouting_bias=None).- Requires
gemm*_weightsto beuint8andweight_scale_vec_size == 32, aligning with MXInt4 packing and the BF16‑scale checks inMxInt4BlockScaleLauncher::check_moe.- Builds one
MxInt4BlockScaleLauncherper selected tile_N, setsMoERunnerArgsfields (tokens, experts, hidden/intermediate sizes, local expert layout, routed_scaling_factor, do_finalize/output), and selects a config based onconfig_indexor defaults.This is in line with the BF16 / FP4 launcher patterns, and the TVM_FFI export at Line 1925 wires it into the FFI surface correctly.
1868-1895: MXInt4 getValidConfigs integration in trtllm_get_valid_moe_configs is correctThe early branch for
dtype_act=Bfloat16 && dtype_weights=MxInt4dispatches toMxInt4BlockScaleLauncher::getValidConfigs, which in turn instantiates aMoE::Runnerwith BF16/MxInt4, shuffled A, andBlockMajorK. This ensures the autotuner sees the same config space as the runtime launcher. The rest of the dtype combinations remain unchanged.
1921-1926: FFI export for trtllm_mxint4_block_scale_moe is in placeThe additional
TVM_FFI_DLL_EXPORT_TYPED_FUNCfortrtllm_mxint4_block_scale_moemakes the MXInt4 path available to the Python JIT module; nothing else to flag here.include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (5)
23-27: Eltwise activation type is correctly integrated into GemmOptions

Including CudaArchDecl.h and extending the GemmOptions constructor with EltwiseActType eltwiseActType (stored in mEltwiseActType and dumped via dumpOptions) is consistent with the rest of the GEMM stack and with the new BatchedGemmOptions ctor. No issues here.

Also applies to: 106-144

398-421: SmVersion aliasing and GemmConfig.mSm migration to tg::CudaArch look good

Replacing the local SmVersion enum with "using SmVersion = tg::CudaArch" and updating GemmConfig.mSm to tg::CudaArch aligns this header with the shared CUDA arch representation in CudaArchDecl.h, simplifying architecture checks.

611-617: checkAndUpdateGemmOptions: cudaArch-based isBlackwell and validM/N/K init are reasonable

- Switching checkAndUpdateGemmOptions to take tg::CudaArch cudaArch and deriving isBlackwell via tg::isArchBlackwell(cudaArch) centralizes arch logic and avoids out-of-band booleans.
- The new < 0 checks for mValidM/N/K preserve earlier semantics (defaulting to full M/N/K when unset) while handling any negative sentinel consistently.

Callers just need to ensure they pass the actual SM's CudaArch rather than a derived boolean.

Also applies to: 642-651

686-695: A-side cast check now correctly allows MxInt4→BF16

The updated A-cast constraint:

    ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1 ||
      options.mDtypeA == tg::Dtype::MxInt4) &&
     options.mDtypeMmaA == tg::Dtype::Bfloat16)

is exactly what the MXInt4 GEMM path needs (MxInt4 weights cast to BF16 MMA inputs) and remains compatible with the existing NVFP4/MxFP4 behavior.

786-808: New Blackwell-only and epilogue/scheduler constraints are consistent with hardware limits

- LDTM shape checks now branch on isBlackwell, constraining Hopper to 16dp256bit while allowing 16dp256bit or 32dp32bit on Blackwell, plus the extra guards for transposed outputs and epilogueTileM=64.
- Enforcing options.mMmaM == 128 or specific tile shapes for certain MxFp4/MxFp8 paths, and gating DeepSeek/block-scaled features on isBlackwell, protects against unsupported tensor core combinations.
- The new block

      if (isBlackwell && !options.mUseCustomMmaSchedule && !options.mUseDeepSeekFp8 &&
          options.mTileScheduler == TileScheduler::Persistent) { ... }

  ensures persistent scheduling on Blackwell always uses the custom MMA schedule, which matches the intent of the comments around custom scheduling.
- When mNumEpilogueWarps > 4, requiring TileN to be a multiple of EpilogueTileN * numEpilogueWrpGrps is a sensible layout constraint for multi-warp epilogues.

Overall these validations look correct and should fail fast on unsupported configs rather than letting kernels misbehave.

Also applies to: 987-1010, 1300-1318, 1456-1459
Signed-off-by: Nikita Korobov <[email protected]>
Actionable comments posted: 4
♻️ Duplicate comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)
788-793: LGTM - Docstring has been corrected. The docstring now correctly says "MXINT4-specific" instead of "FP4-specific", as noted in previous reviews.
🧹 Nitpick comments (2)
csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (1)
186-227: Add output tensor validation in BlockScaleInterleave to catch mismatched buffers early

The dtype guard on blockScale plus the CUDA/CPU dispatch by dtype is good. However, interleavedBlockScale is assumed to have matching dtype, contiguity, and at least num_experts * expert_out_size elements without any checks. A mismatched output tensor (e.g., wrong dtype or too small) would lead to hard-to-debug memory issues. Consider adding symmetric validation, e.g.:

      CHECK_CONTIGUOUS(blockScale);
      TVM_FFI_ICHECK(blockScale.dtype() == dl_uint8 || blockScale.dtype() == dl_bfloat16)
          << "Block Scale must be uint8 or bfloat16.";
    + CHECK_CONTIGUOUS(interleavedBlockScale);
    + TVM_FFI_ICHECK_EQ(interleavedBlockScale.dtype(), blockScale.dtype())
    +     << "interleavedBlockScale must have the same dtype as blockScale.";
    + auto blockScaleShape = blockScale.sizes();
      ...
    + TVM_FFI_ICHECK_EQ(interleavedBlockScale.numel(), num_experts * expert_out_size)
    +     << "interleavedBlockScale has incorrect size for the given blockScale.";

This keeps the public behavior the same when used correctly but fails fast with a clear error message if the caller wires up the wrong buffer.
tests/moe/test_trtllm_gen_fused_moe.py (1)
591-607: Clarify scales reshape logic.

The returned scales.reshape(-1, sf_vec_size) at line 606 is confusing. The scales tensor has shape (num_groups, 1) where num_groups = total_elements / sf_vec_size. The reshape attempts to produce shape (num_groups/sf_vec_size, sf_vec_size), which only works if num_groups is divisible by sf_vec_size. While this works for the current test cases (power-of-2 dimensions), the semantic meaning is unclear and the caller at line 625 immediately reshapes the result anyway.

Consider simplifying to return a flat tensor and let the caller handle the reshape:

    - return x_int4.reshape(*x.shape[:-1], x.shape[-1] // 2), scales.reshape(
    -     -1, sf_vec_size
    - )
    + return x_int4.reshape(*x.shape[:-1], x.shape[-1] // 2), scales.view(-1)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (3 hunks)
- flashinfer/fused_moe/core.py (4 hunks)
- tests/moe/test_trtllm_gen_fused_moe.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (2)
include/flashinfer/trtllm/fused_moe/runner.h (1)
num_experts (263-263)

csrc/nv_internal/cpp/kernels/quantization.cu (4)

invokeBlockScaleInterleave (292-302), invokeBlockScaleInterleave (292-293), invokeBlockScaleInterleave (305-307), invokeBlockScaleInterleave (308-312)
🪛 Ruff (0.14.7)
tests/moe/test_trtllm_gen_fused_moe.py
613-613: Unused method argument: hidden_states_sample
(ARG002)
643-643: Unused method argument: unused_args
(ARG002)
652-652: Unused method argument: args_dequant
(ARG002)
654-654: Unused method argument: gemm1_weights_orig
(ARG002)
655-655: Unused method argument: gemm2_weights_orig
(ARG002)
656-656: Unused method argument: hidden_size
(ARG002)
657-657: Unused method argument: intermediate_size
(ARG002)
659-659: Unused method argument: weight_processing
(ARG002)
751-751: Unused method argument: hidden_states_scale_global
(ARG002)
flashinfer/fused_moe/core.py
2010-2010: Unused function argument: routing_logits
(ARG001)
2011-2011: Unused function argument: routing_bias
(ARG001)
2013-2013: Unused function argument: gemm1_weights
(ARG001)
2014-2014: Unused function argument: gemm1_weights_scale
(ARG001)
2015-2015: Unused function argument: gemm1_alpha
(ARG001)
2016-2016: Unused function argument: gemm1_beta
(ARG001)
2017-2017: Unused function argument: gemm1_clamp_limit
(ARG001)
2018-2018: Unused function argument: gemm2_weights
(ARG001)
2019-2019: Unused function argument: gemm2_weights_scale
(ARG001)
2020-2020: Unused function argument: num_experts
(ARG001)
2021-2021: Unused function argument: top_k
(ARG001)
2022-2022: Unused function argument: n_group
(ARG001)
2023-2023: Unused function argument: topk_group
(ARG001)
2024-2024: Unused function argument: intermediate_size
(ARG001)
2025-2025: Unused function argument: local_expert_offset
(ARG001)
2026-2026: Unused function argument: local_num_experts
(ARG001)
2027-2027: Unused function argument: routed_scaling_factor
(ARG001)
2028-2028: Unused function argument: routing_method_type
(ARG001)
2029-2029: Unused function argument: enable_pdl
(ARG001)
2030-2030: Unused function argument: output
(ARG001)
2031-2031: Unused function argument: tune_max_num_tokens
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (8)
csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (1)
170-173: BF16 explicit instantiation: ensure proper header availability

The explicit instantiation for blockScaleInterleaveHost<__nv_bfloat16> looks correct and matches the CUDA-side templating. Just make sure the compilation unit always sees the definition of __nv_bfloat16 (via the appropriate CUDA header) on all supported toolchains; otherwise this instantiation could fail to compile in some environments.

tests/moe/test_trtllm_gen_fused_moe.py (3)
2176-2224: LGTM - Reference dequantization implementation is correct. The MxInt4 dequantization logic properly does the following (a reference sketch appears after this list):
- Unpacks two 4-bit values from each byte using bitwise operations
- Converts unsigned nibbles to signed two's complement values in [-8, 7]
- Applies block scales correctly
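For readers without the test file at hand, a compact reference sketch of that dequantization; the nibble order (low nibble first) and the one-bf16-scale-per-32-element-block layout are assumptions for illustration, and the test's own implementation remains authoritative:

```python
import torch

def dequant_mxint4_reference(packed: torch.Tensor, scales: torch.Tensor, sf_vec_size: int = 32) -> torch.Tensor:
    """Unpack uint8-packed int4 values and apply per-block bf16 scales (illustrative)."""
    # Split each byte into two nibbles; the low nibble is assumed to be the first element.
    lo = (packed & 0x0F).to(torch.int16)
    hi = (packed >> 4).to(torch.int16)
    vals = torch.stack((lo, hi), dim=-1).reshape(*packed.shape[:-1], packed.shape[-1] * 2)
    # Map unsigned nibbles [0, 15] to signed two's-complement values in [-8, 7].
    vals = torch.where(vals >= 8, vals - 16, vals).to(torch.float32)
    # Apply one scale per sf_vec_size contiguous elements along the last dimension;
    # scales is assumed to hold exactly one entry per block, in the same order.
    blocks = vals.reshape(*vals.shape[:-1], vals.shape[-1] // sf_vec_size, sf_vec_size)
    out = blocks * scales.to(torch.float32).reshape(*blocks.shape[:-1], 1)
    return out.reshape(vals.shape)
```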
2478-2582: LGTM - Test parametrization correctly includes MxInt4BlockScaleMoe. The MxInt4BlockScaleMoe is properly added to test parametrizations with compatible routing configs and weight processing (BlockMajorK layout with shuffled weights).
711-715: Inconsistent num_elts_per_sf between gemm1 and gemm2 scale permutation requires verification.

For gemm1 scales (line 688), num_elts_per_sf=32 is used, but for gemm2 scales here, num_elts_per_sf=16 is used. Verify whether this difference is intentional based on MxInt4's sf_vec_size, or whether both should use num_elts_per_sf=32 for consistency.

Reference: the FP4Moe implementation uses num_elts_per_sf=16 for both occurrences (lines 454 and 481), which aligns with its sf_vec_size=16.

flashinfer/fused_moe/core.py (4)
119-126: LGTM - MxInt4 enum value properly added with UID shift. The MxInt4 dtype is correctly added with bit format (1, 1, 1, 4, 14) and subsequent UIDs are properly incremented.

1157-1184: LGTM - MxInt4 dispatch path correctly added. The dispatch for dtype_act == BF16 and dtype_weights == MxInt4 is properly wired to call moe_op.trtllm_mxint4_block_scale_moe with the appropriate arguments.

2038-2044: LGTM - Module exports correctly updated. The SimpleNamespace return value properly includes trtllm_mxint4_block_scale_moe alongside other MoE operations.

193-198: No changes needed - MxInt4 is not a valid DtypeTrtllmGen enum value. The is_trtllm_moe_supported function correctly lists all supported weight types. The DtypeTrtllmGen enum does not include MxInt4; the only mixed-precision variants are MxE2m1 and MxE4m3, both of which are already in the supported list. References to MxInt4 in the test file are for test utilities only, not DtypeTrtllmGen enum members.

Likely an incorrect or invalid review comment.
    for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++) {
      T* interleavedBlockScalePtr =
          static_cast<T*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
      for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) {
        auto globalRowIdx = eIdx * rows + rIdx;
        T* blockScalePtr = static_cast<T*>(blockScale.data_ptr()) + globalRowIdx * cols;
        for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) {
          T sf_ori = 0;
          if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols)) {
            sf_ori = blockScalePtr[cIdx];
          }
          int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
                                        tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4);
          interleavedBlockScalePtr[sf_index] = sf_ori;
        }
      }
    }
Avoid out-of-bounds pointer arithmetic for padded rows in blockScaleInterleaveHost
For padded rows (rIdx >= rows), globalRowIdx = eIdx * rows + rIdx and the derived blockScalePtr can point past the end of the blockScale buffer, even though it’s not dereferenced when rIdx >= rows. This is technically undefined behavior and easy to avoid by only forming the row pointer when the row is valid.
You can also make zero-initialization of sf_ori more robust for all T by using value-initialization.
A safer layout:
- for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++) {
- T* interleavedBlockScalePtr =
- static_cast<T*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
- for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) {
- auto globalRowIdx = eIdx * rows + rIdx;
- T* blockScalePtr = static_cast<T*>(blockScale.data_ptr()) + globalRowIdx * cols;
- for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) {
- T sf_ori = 0;
- if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols)) {
- sf_ori = blockScalePtr[cIdx];
- }
- int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
- tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4);
- interleavedBlockScalePtr[sf_index] = sf_ori;
- }
- }
- }
+ T* blockScaleBasePtr = static_cast<T*>(blockScale.data_ptr());
+ for (int eIdx = 0; eIdx < static_cast<int>(num_experts); ++eIdx) {
+ T* interleavedBlockScalePtr =
+ static_cast<T*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
+ T* blockScaleExpertBasePtr = blockScaleBasePtr + eIdx * rows * cols;
+ for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) {
+ bool const valid_row = rIdx < static_cast<int>(rows);
+ T* blockScaleRowPtr = valid_row ? blockScaleExpertBasePtr + rIdx * cols : nullptr;
+ for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) {
+ T sf_ori{};
+ if (valid_row && cIdx < static_cast<int>(cols)) {
+ sf_ori = blockScaleRowPtr[cIdx];
+ }
+ int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
+ tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4);
+ interleavedBlockScalePtr[sf_index] = sf_ori;
+ }
+ }
+ }This keeps behavior the same while avoiding any out-of-bounds pointer values and strengthens default initialization for all template types.
🤖 Prompt for AI Agents
In csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp around lines 151-167, the code
computes globalRowIdx and forms blockScalePtr for padded rows (rIdx >= rows)
which can produce an out-of-bounds pointer even if not dereferenced; change the
loop so blockScalePtr (and globalRowIdx) are only computed when rIdx < rows, and
otherwise keep sf_ori value-initialized (e.g., T sf_ori{}), then use sf_ori for
writing into interleavedBlockScalePtr; this avoids undefined pointer arithmetic
while preserving behavior for valid rows.
    @register_custom_op(
        "flashinfer::trtllm_mxint4_block_scale_moe",
        mutates_args=(""),
    )
    def trtllm_mxint4_block_scale_moe_op(
        routing_logits: torch.Tensor,
        routing_bias: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        gemm1_weights: torch.Tensor,
        gemm1_weights_scale: torch.Tensor,
        gemm1_alpha: Optional[torch.Tensor],
        gemm1_beta: Optional[torch.Tensor],
        gemm1_clamp_limit: Optional[torch.Tensor],
        gemm2_weights: torch.Tensor,
        gemm2_weights_scale: torch.Tensor,
        num_experts: int,
        top_k: int,
        n_group: Optional[int],
        topk_group: Optional[int],
        intermediate_size: int,
        local_expert_offset: int,
        num_local_experts: int,
        routed_scaling_factor: Optional[float],
        routing_method_type: int,
        enable_pdl: Optional[bool] = None,
        output: Optional[torch.Tensor] = None,
        tune_max_num_tokens: int = 8192,
    ) -> List[torch.Tensor]:
Return type mismatch: annotation says List[torch.Tensor] but returns a plain tensor.
The function signature declares -> List[torch.Tensor] but line 2006 returns output (a torch.Tensor, not a list). This is inconsistent with the fake op (line 2036) which returns [hidden_states.new_empty(...)].
Either update the return type annotation or wrap the return value:
    - ) -> List[torch.Tensor]:
    + ) -> torch.Tensor:

Or alternatively:

    - return output
    + return [output]
🤖 Prompt for AI Agents
In flashinfer/fused_moe/core.py around lines 1883-1910 (issue also referenced at
line 2006 and fake-op at 2036): the function is annotated to return
List[torch.Tensor] but currently returns a plain torch.Tensor; make the
implementation and fake-op consistent with the annotation by returning a list of
tensors. Specifically, wrap the single tensor return value into a list (e.g.,
return [output]) and ensure the fake-op also returns a list with matching
shape/type; alternatively, if you prefer a single tensor API, change the
function return annotation to torch.Tensor and update the fake-op to return a
tensor—pick one option and apply it consistently across the function and its
fake-op.
    @flashinfer_api
    def trtllm_mxint4_block_scale_moe(
        routing_logits: torch.Tensor,
        hidden_states: torch.Tensor,
        gemm1_weights: torch.Tensor,
        gemm1_weights_scale: torch.Tensor,
        gemm1_alpha: Optional[torch.Tensor],
        gemm1_beta: Optional[torch.Tensor],
        gemm1_clamp_limit: Optional[torch.Tensor],
        gemm2_weights: torch.Tensor,
        gemm2_weights_scale: torch.Tensor,
        num_experts: int,
        top_k: int,
        n_group: Optional[int],
        topk_group: Optional[int],
        intermediate_size: int,
        local_expert_offset: int,
        local_num_experts: int,
        routed_scaling_factor: Optional[float],
        routing_method_type: int = 0,
        enable_pdl: Optional[bool] = None,
        output: Optional[torch.Tensor] = None,
        tune_max_num_tokens: int = 8192,
    ) -> List[torch.Tensor]:
Missing routing_bias parameter in public API wrapper.
The underlying op (trtllm_mxint4_block_scale_moe_op) accepts routing_bias as a parameter, but this public wrapper doesn't expose it. The wrapper always passes None (line 2610). This prevents users from using routing bias with MxInt4 MoE.
Other similar wrappers like trtllm_fp4_block_scale_moe include routing_bias as a parameter.
Add routing_bias to the function signature:
      @flashinfer_api
      def trtllm_mxint4_block_scale_moe(
          routing_logits: torch.Tensor,
    +     routing_bias: Optional[torch.Tensor],
          hidden_states: torch.Tensor,
          ...
      ) -> torch.Tensor:

And pass it to the op:

      return get_trtllm_moe_sm100_module().trtllm_mxint4_block_scale_moe(
          routing_logits,
    -     None,
    +     routing_bias,
          hidden_states,
          ...
      )
🤖 Prompt for AI Agents
In flashinfer/fused_moe/core.py around lines 2542-2565 (and also update the call
site around 2608-2611), the public wrapper trtllm_mxint4_block_scale_moe is
missing the routing_bias parameter; add routing_bias: Optional[torch.Tensor] =
None to the function signature and forward that parameter to
trtllm_mxint4_block_scale_moe_op instead of passing None so the underlying op
can receive a routing bias; ensure all internal calls/forwarding at lines
~2608-2611 pass the new routing_bias argument through.
        Returns:
            torch.Tensor: returns the final MoE output.
        """
Docstring return type conflicts with type annotation.
The docstring says Returns: torch.Tensor but the function signature declares -> List[torch.Tensor]. These should be consistent.
Update either the docstring or the type annotation to match (see the return type mismatch comment above for the op function).
🤖 Prompt for AI Agents
In flashinfer/fused_moe/core.py around lines 2605 to 2607, the docstring return
type says "torch.Tensor" but the function signature annotates "->
List[torch.Tensor]"; make them consistent by updating one to match the other: if
the function actually returns multiple tensors, change the docstring to
"Returns: List[torch.Tensor]" and describe each element if needed; if it returns
a single tensor, change the type annotation to "-> torch.Tensor" and update any
callers/tests accordingly; ensure the return description matches the chosen
type.
/bot run
jiahanc
left a comment
LGTM. Thanks for contribution!
IwakuraRein
left a comment
Thanks for your contributions!
[FAILED] Pipeline #39477569: 4/20 passed
Signed-off-by: Nikita Korobov <[email protected]>
📌 Description
Add the MxInt4 x BF16 TRT-LLM Gen MoE.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Unit tests added or updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Refactor
Tests
Chores