
Conversation

@nekorobov
Collaborator

@nekorobov nekorobov commented Dec 2, 2025

📌 Description

Add the MxInt4 x BF16 TRT-LLM Gen MoE path.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used my preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • MXInt4 MoE inference path with public API and test coverage; MXInt4 + BF16 supported end-to-end.
    • Exposed new MXInt4 op and helper in the package exports.
  • Refactor

    • Block-scale/interleave routines generalized to support uint8 and bfloat16 inputs and outputs.
    • GEMM/BatchedGemm configs now include an element-wise activation option and are arch-aware (CUDA arch).
  • Tests

    • Added MXInt4 quantization and runtime tests for MoE.
  • Chores

    • Updated packaged artifact path/checksum.


IwakuraRein and others added 5 commits November 25, 2025 15:37
@coderabbitai
Contributor

coderabbitai bot commented Dec 2, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Templates were added to block-scale interleave kernels and host helpers to support uint8 and bfloat16; MXInt4 block-scale MoE support (launcher, config discovery, entry point, tests) was introduced; GEMM interfaces now use tg::CudaArch and include an EltwiseActType; artifact checksums were updated.

Changes

Cohort / File(s) Change Summary
Block Scale Interleave Templating
csrc/nv_internal/cpp/kernels/quantization.cu, csrc/nv_internal/tensorrt_llm/kernels/quantization.h
Converted block_scale_interleave_kernel and invokeBlockScaleInterleave to templates <typename T> with explicit instantiations for uint8_t and __nv_bfloat16.
Host CPU Interleave & Dispatch
csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp
Added blockScaleInterleaveHost<T> template, expanded dtype branching (uint8 / bfloat16) for CPU/CUDA paths, and added explicit instantiations; unsupported dtypes throw.
MXInt4 MoE Launcher & Entry
csrc/trtllm_fused_moe_kernel_launcher.cu
Added MxInt4BlockScaleLauncher (config enumeration, init, routing, per-tile launchers) and public trtllm_mxint4_block_scale_moe entry; integrated MXInt4 dispatch.
Python MXInt4 MoE API
flashinfer/fused_moe/core.py, flashinfer/fused_moe/__init__.py
Added DtypeTrtllmGen.MxInt4, new op/wrapper _fake op and trtllm_mxint4_block_scale_moe API, routed forward path to MXInt4 when dtype_act==Bfloat16 and dtype_weights==MxInt4, re-exported API.
Batched GEMM Arch & Activation Updates
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/{BatchedGemmInterface,BatchedGemmOptions,Enums,GemmGatedActOptions,GemmOptions,KernelParams,TmaDescriptor}.h
Replaced boolean isBlackwell with tg::CudaArch cudaArch, added EltwiseActType and threaded eltwiseActType into constructors/options, changed config storage to tg::CudaArch, and updated validation/dtype handling (including MxInt4/Bfloat16).
Quantization & Utilities
flashinfer/fp4_quantization.py, flashinfer/utils.py
Relaxed dtype assertions to accept uint8 or bfloat16; outputs mirror input dtype; allowed num_elts_per_sf 16 or 32; updated docstrings and checks.
Artifact & Checksum Updates
flashinfer/artifacts.py
Updated TRTLLM_GEN_BMM artifact path and checksum constants.
MXInt4 MoE Tests & Helpers
tests/moe/test_trtllm_gen_fused_moe.py, tests/moe/utils.py
Added MxInt4BlockScaleMoe test class, mxint4_quantize and reference helpers, added MXINT4_BF16_BF16 quant mode and skip conditions, integrated MXInt4 into test parametrizations.
Minor / Formatting & Descriptor Updates
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/{KernelTraits.h,TmaDescriptor.h,BatchedGemmInterface.h}
Small formatting, include adjustments, extended dtype-to-TMA mappings (Bfloat16/MxInt4), and minor kernel-launch error-handling refinements.
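
As a rough, self-contained illustration of the isBlackwell-to-CudaArch and EltwiseActType changes summarized in the Batched GEMM row above (everything except the names tg::CudaArch, tg::isArchBlackwell, EltwiseActType, mSm, and mEltwiseActType is a hypothetical simplification, not the PR's actual code):

#include <cstdio>

namespace tg {
// Simplified stand-in for trtllm/gen/CudaArchDecl.h; the real enumerators and helper may differ.
enum class CudaArch { Sm90a, Sm100a, Sm103a };
inline bool isArchBlackwell(CudaArch arch) { return arch != CudaArch::Sm90a; }
}  // namespace tg

// Element-wise activation applied in the GEMM epilogue; the actual enumerators are not shown in this PR.
enum class EltwiseActType { None, SiLu };

struct BatchedGemmConfigSketch {
  tg::CudaArch mSm;                // previously: bool isBlackwell
  EltwiseActType mEltwiseActType;  // new option threaded through the constructors
};

int main() {
  BatchedGemmConfigSketch cfg{tg::CudaArch::Sm100a, EltwiseActType::None};
  std::printf("isBlackwell=%d\n", static_cast<int>(tg::isArchBlackwell(cfg.mSm)));
  return 0;
}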

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant PyAPI as trtllm_mxint4_block_scale_moe
    participant LauncherMgr as MxInt4BlockScaleLauncher (config)
    participant Router as Routing/Prepare
    participant Kernel as MXInt4 MoE Kernel

    Client->>PyAPI: call trtllm_mxint4_block_scale_moe(...)
    PyAPI->>LauncherMgr: getValidConfigs(top_k, hidden_size, ...)
    LauncherMgr-->>PyAPI: return valid tile configs
    PyAPI->>LauncherMgr: instantiate per-tile launchers & init args
    PyAPI->>Router: prepare_routing()
    Router-->>Router: compute assignments / expert offsets
    PyAPI->>LauncherMgr: select launcher by config_index
    LauncherMgr->>Kernel: launch selected MXInt4 kernel (MXInt4 weights, BF16 activations)
    Kernel-->>LauncherMgr: output tensor
    LauncherMgr-->>PyAPI: return results
    PyAPI-->>Client: deliver MoE output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Areas needing extra attention:
    • Batched GEMM constructor/signature and validation refactor (EltwiseActType, tg::CudaArch propagation).
    • MxInt4 launcher: config enumeration, init semantics, routing and per-tile dispatch correctness.
    • Templated block-scale interleave kernels and host/CUDA instantiations (memory layout and types).
    • Tests and artifact checksum correctness for new MXInt4 paths.

Possibly related PRs

Suggested reviewers

  • joker-eph
  • aleozlx
  • djmmoss
  • yongwww
  • yzh119
  • cyx-6
  • nvmbreughe

Poem

🐰 A tiny hop, a templated tweak,
uint8 and bfloat16 now speak.
MXInt4 marches into the field,
Launchers ready, configs revealed.
GEMMs align — the tests applaud and peek.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 34.72%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check (❓ Inconclusive): The PR description is minimal and uses only the template header 'Add the MxInt4 x BF16 TRTLLM GEN moe' without substantive explanation, context, or details about the implementation beyond confirming checklist items. Resolution: expand the description to explain the key changes (e.g., templated quantization kernels, new MoE launcher support), rationale, and any testing notes; fill in the 'Related Issues' section and consider adding specific reviewer guidance.
✅ Passed checks (1 passed)
  • Title check (✅ Passed): The title 'feat: MxInt4 x Bf16 TRT-LLM Gen MoE support' clearly and concisely describes the main change: adding MxInt4 x BF16 support for TRT-LLM Gen MoE, which aligns with the extensive changes across CUDA kernels, Python wrappers, and test files.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8222437 and a5b7681.

📒 Files selected for processing (1)
  • flashinfer/fused_moe/core.py (5 hunks)
🧰 Additional context used
🪛 Ruff (0.14.7)
flashinfer/fused_moe/core.py

ARG001 (unused function argument) at lines 2011-2012 and 2014-2032, one per line: routing_logits, routing_bias, gemm1_weights, gemm1_weights_scale, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, routing_method_type, enable_pdl, output, tune_max_num_tokens
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/fused_moe/core.py (2)

119-126: LGTM! MxInt4 enum entry and UID shifts are correct.

The MxInt4 entry is properly defined with block_format_bit=1, signed_bit=1, integer_bit=1, num_bits=4, and the subsequent UID shifts maintain consistency across all enum values.


1158-1185: LGTM! MxInt4 forward path correctly implemented.

The new conditional block properly handles the MxInt4 × Bfloat16 combination and forwards all required parameters to the underlying C++ operation.



@gemini-code-assist
Contributor

Summary of Changes

Hello @nekorobov, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the quantization capabilities of the TRT-LLM Mixture-of-Experts (MoE) implementation by adding comprehensive support for MxInt4 x Bf16 quantization. This enables more efficient processing of large language models by leveraging lower-precision data types for weights while maintaining Bfloat16 for activations. The changes span across kernel implementations, data type definitions, Python bindings, and testing infrastructure, ensuring a robust and performant integration of the new quantization scheme.

Highlights

  • MxInt4 x Bf16 MoE Support: Introduced a new MxInt4BlockScaleLauncher class and trtllm_mxint4_block_scale_moe function to enable Mixture-of-Experts (MoE) operations with MxInt4 weights and Bfloat16 activations.
  • Templated Quantization Kernels: The block_scale_interleave_kernel and invokeBlockScaleInterleave functions in quantization.cu have been templated to support both uint8_t and __nv_bfloat16 types, enhancing flexibility for different quantization schemes.
  • Updated Data Type Handling: The DtypeTrtllmGen enum in flashinfer/fused_moe/core.py and Dtype enum in trtllm/gen/DtypeDecl.h have been extended to include MxInt4, along with corresponding updates to utility functions for data type properties and block scale types.
  • Python API and Testing Integration: New Python API functions (trtllm_mxint4_block_scale_moe) and test cases (MxInt4BlockScaleMoe) have been added to flashinfer/fused_moe/core.py and tests/moe/test_trtllm_gen_fused_moe.py to ensure proper functionality and validation of the new MxInt4 support.
  • Refactored Batched GEMM Options: The Batched GEMM options and related checks have been updated to use a new CudaArch enum instead of a boolean isBlackwell flag, providing more granular control and clarity for architecture-specific optimizations.
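
For orientation, here is a hedged usage sketch of the new public entry point. The keyword names follow the custom-op signature quoted later in this thread; the tensor shapes, dtypes, and routing settings below are illustrative assumptions, and real inputs must already be quantized and shuffled into the BlockMajorK layout the kernel expects (a CUDA device is also assumed).

import torch
from flashinfer.fused_moe import trtllm_mxint4_block_scale_moe

num_tokens, hidden_size, intermediate_size = 128, 1024, 1024  # both sizes multiples of 256
num_experts, top_k, sf_vec_size = 8, 2, 32

hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
routing_logits = torch.randn(num_tokens, num_experts, dtype=torch.bfloat16, device="cuda")

# Packed int4 weights (two values per uint8) with one bf16 scale per 32-element block.
# The shapes below are assumptions for illustration only.
gemm1_weights = torch.randint(0, 256, (num_experts, 2 * intermediate_size, hidden_size // 2), dtype=torch.uint8, device="cuda")
gemm1_scales = torch.randn(num_experts, 2 * intermediate_size, hidden_size // sf_vec_size, dtype=torch.bfloat16, device="cuda")
gemm2_weights = torch.randint(0, 256, (num_experts, hidden_size, intermediate_size // 2), dtype=torch.uint8, device="cuda")
gemm2_scales = torch.randn(num_experts, hidden_size, intermediate_size // sf_vec_size, dtype=torch.bfloat16, device="cuda")

output = trtllm_mxint4_block_scale_moe(
    routing_logits=routing_logits,
    routing_bias=None,  # the MXInt4 path disallows a routing bias
    hidden_states=hidden_states,
    gemm1_weights=gemm1_weights,
    gemm1_weights_scale=gemm1_scales,
    gemm1_alpha=None, gemm1_beta=None, gemm1_clamp_limit=None,
    gemm2_weights=gemm2_weights,
    gemm2_weights_scale=gemm2_scales,
    num_experts=num_experts,
    top_k=top_k,
    n_group=None, topk_group=None,
    intermediate_size=intermediate_size,
    local_expert_offset=0,
    num_local_experts=num_experts,
    routed_scaling_factor=None,
    routing_method_type=0,  # routing method enum value; the meaning depends on RoutingMethodType
)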

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for MxInt4 x Bf16 TRT-LLM Gen MoE, expanding the quantization capabilities of the system. The changes involve templatizing CUDA kernels and host-side functions to handle __nv_bfloat16 types for block scale interleaving, updating Python bindings and test infrastructure to integrate the new MxInt4 mode, and refactoring internal batched_gemm components to use a more generalized CudaArch enum. Overall, the implementation appears consistent with the existing codebase, but I've identified a few areas for improvement regarding documentation clarity and a hardcoded constraint.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (7)
flashinfer/utils.py (1)

786-803: Consider adding a docstring.

The function lacks documentation explaining:

  • The purpose and meaning of num_elts_per_sf
  • Why specific dtypes (uint8, bfloat16) are required
  • The relationship between epilogue_tile_m and shuffle block size

Adding a docstring would improve maintainability, especially given the expanding type support.
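
A possible shape for such a docstring (the helper's real name and exact parameter list are not shown in this thread, so the signature below is a placeholder):

def block_scale_interleave_helper(input_tensor, num_elts_per_sf, epilogue_tile_m):
    """Interleave block scale factors into the layout expected by the TRT-LLM GEN epilogue.

    Args:
        input_tensor: Block-scale tensor; must be torch.uint8 or torch.bfloat16, and the
            output mirrors the input dtype.
        num_elts_per_sf: Number of elements covered by one scale factor (16 or 32).
        epilogue_tile_m: Epilogue tile height, used to derive the shuffle block size.
    """
    ...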

csrc/nv_internal/cpp/kernels/quantization.cu (1)

250-250: Verify zero-initialization for __nv_bfloat16.

T sf = 0; may not work correctly for __nv_bfloat16 since it's a class type. Consider using explicit initialization:

-        T sf = 0;
+        T sf = T{};

This ensures proper default construction for both uint8_t and __nv_bfloat16.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (1)

186-205: LDG/LdgPlusSts K‑tiling check has an unreachable branch

Given the invariant at Lines 262–268 that any mRouteSfsImpl != mRouteImpl must be Ldgsts/LdgPlusSts with mRouteImpl == Tma, the condition

if (doesRouteImplUseLdgsts(options.mRouteImpl) &&
    doesRouteImplUseLdgPlusSts(options.mRouteSfsImpl.value()))

can never be true (it requires mRouteImpl == Ldgsts and simultaneously mRouteImpl == Tma). The effective K‑tiling guard is the second block that checks mRouteSfsImpl for Ldgsts/LdgPlusSts.

You can safely drop or simplify the first if block to avoid dead code and rely on the SF‑routing check alone.

Also applies to: 348-361

flashinfer/fused_moe/core.py (3)

1883-2007: MXInt4 custom op integration is consistent; fake op needs unused‑arg suppression

The trtllm_mxint4_block_scale_moe_op custom op follows the BF16/FP4 patterns: it builds a MoERunner with dtype_act=Bfloat16, dtype_weights=MxInt4, uses WeightLayout.BlockMajorK, and autotunes via tune_max_num_tokens, then calls into the C++ trtllm_mxint4_block_scale_moe launcher. That wiring looks correct.

For _fake_trtllm_mxint4_block_scale_moe, Ruff is correctly flagging many unused parameters. Since the signature must mirror the real op, consider adding a no‑op usage line instead of renaming/removing parameters, e.g.:

# Keep signature in sync with real op; arguments are unused in the fake path.
_ = (
    routing_logits,
    routing_bias,
    gemm1_weights,
    gemm1_weights_scale,
    gemm1_alpha,
    gemm1_beta,
    gemm1_clamp_limit,
    gemm2_weights,
    gemm2_weights_scale,
    num_experts,
    top_k,
    n_group,
    topk_group,
    intermediate_size,
    local_expert_offset,
    local_num_experts,
    routed_scaling_factor,
    routing_method_type,
    enable_pdl,
    output,
    tune_max_num_tokens,
)

This will satisfy ARG001 without impacting behavior.


2542-2631: MXInt4 high‑level API: type hint and docstring are inconsistent with implementation

  • The function is annotated as -> List[torch.Tensor] but returns the single Tensor produced by trtllm_mxint4_block_scale_moe_op, so the type hint should be torch.Tensor.
  • The docstring text is still tailored to FP4/NVFP4 (mentions “packed fp4” weights and float8 scales). The underlying C++ MXInt4 launcher validates gemm*_weights as uint8 and gemm*_weights_scale as BF16, so those dtype descriptions should be updated to match the MXInt4 path.

Adjusting the return annotation and docstring to the actual MXInt4 semantics will avoid confusion for users of this API.
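
A minimal sketch of the suggested fix (parameter list elided; only the return annotation and the dtype notes change):

import torch

def trtllm_mxint4_block_scale_moe(routing_logits: torch.Tensor, *args, **kwargs) -> torch.Tensor:  # was List[torch.Tensor]
    """MxInt4 x BF16 block-scale MoE.

    gemm1_weights / gemm2_weights: torch.uint8 tensors holding int4 weights packed two per byte.
    gemm1_weights_scale / gemm2_weights_scale: torch.bfloat16 block scales (not float8, unlike the FP4 path).
    """
    ...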


184-215: Consider extending is_trtllm_moe_supported to cover MXInt4 weights

is_trtllm_moe_supported currently whitelists BF16/E4m3/E2m1/MxE2m1 weights only. With the new BF16‑act + MxInt4 path wired through MoERunner and the MXInt4 kernel launcher, any callsites that rely on this helper to gate “is this config supported?” will still treat MXInt4 as unsupported.

If the intent is for MXInt4 to be usable via the generic TRT‑LLM MoE dispatcher (not just the dedicated trtllm_mxint4_block_scale_moe wrapper), you likely want to add DtypeTrtllmGen.MxInt4 here with appropriate dtype_act constraints.

csrc/trtllm_fused_moe_kernel_launcher.cu (1)

941-1093: MxInt4 launcher wiring is sensible; consider adding the same shape/layout checks as other launchers

The new MxInt4BlockScaleLauncher correctly:

  • Forces BF16 activations and MXInt4 weights in init.
  • Sets routing dtypes and allocates BF16 expert_weights in prepare_routing.
  • Validates gemm1_weights/gemm2_weights as uint8 and the corresponding scales as BF16 in check_moe.
  • Populates MoERunnerArgs fields and allocates gemm1_output / gemm2_output with BF16 in prepare_moe, wiring the workspace pointers consistently with the other launchers.

One difference vs. BF16/FP8/FP4 launchers is that check_moe here does not call FusedMoeLauncher::check_moe_common() or check_weights_shape("gemm1"/"gemm2"), so hidden‑state and weight shapes/layouts are not validated on the MxInt4 path. For consistency and safer failure modes, it would be good to reuse those common checks, e.g.:

void check_moe() const override {
  FusedMoeLauncher::check_moe_common();
  check_weights_shape("gemm1");
  check_weights_shape("gemm2");
  TVM_FFI_ICHECK(mDtypeAct == btg::Dtype::Bfloat16)
      << "Only Bfloat16 is supported by MxInt4 block scale MoE";

  TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_uint8) << "gemm1_weights must be uint8.";
  TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_bfloat16)
      << "gemm1_weights_scale must be bf16.";
  TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_uint8) << "gemm2_weights must be uint8.";
  TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_bfloat16)
      << "gemm2_weights_scale must be bf16.";
}

That keeps MXInt4 validation aligned with the existing MoE launchers.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 23ff744 and 7e9ff16.

⛔ Files ignored due to path filters (2)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h is excluded by !**/gen/**
📒 Files selected for processing (19)
  • csrc/nv_internal/cpp/kernels/quantization.cu (2 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/quantization.h (1 hunks)
  • csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (3 hunks)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (5 hunks)
  • flashinfer/artifacts.py (2 hunks)
  • flashinfer/fp4_quantization.py (2 hunks)
  • flashinfer/fused_moe/__init__.py (2 hunks)
  • flashinfer/fused_moe/core.py (4 hunks)
  • flashinfer/utils.py (1 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (5 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (6 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h (1 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (4 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (12 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (6 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (2 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (9 hunks)
  • tests/moe/utils.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h
🧬 Code graph analysis (11)
tests/moe/utils.py (1)
include/flashinfer/trtllm/fused_moe/runner.h (2)
  • intermediate_size (275-275)
  • hidden_size (265-265)
csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (2)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • num_experts (263-263)
csrc/nv_internal/cpp/kernels/quantization.cu (4)
  • invokeBlockScaleInterleave (292-302)
  • invokeBlockScaleInterleave (292-293)
  • invokeBlockScaleInterleave (305-307)
  • invokeBlockScaleInterleave (308-312)
flashinfer/fp4_quantization.py (1)
flashinfer/fp8_quantization.py (1)
  • _compute_swizzled_layout_sf_size (15-18)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (2)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.h (1)
  • CudaArch (36-93)
include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h (1)
  • mSm (393-394)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (5)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (2)
  • gemm (30-297)
  • buildSfTmaDescriptor (194-289)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (2)
  • Dtype (43-274)
  • dtypeNumEltsPerSf (201-213)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h (1)
  • SfLayout (37-91)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h (1)
  • ceilDiv (42-44)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h (1)
  • MmaKind (36-107)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
  • Dtype (43-274)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (3)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
  • Dtype (43-274)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.h (1)
  • CudaArch (36-93)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h (2)
  • doesRouteImplUseLdgsts (45-45)
  • doesRouteImplUseLdgPlusSts (53-53)
csrc/nv_internal/tensorrt_llm/kernels/quantization.h (1)
csrc/nv_internal/cpp/kernels/quantization.cu (6)
  • void (244-268)
  • void (270-288)
  • invokeBlockScaleInterleave (292-302)
  • invokeBlockScaleInterleave (292-293)
  • invokeBlockScaleInterleave (305-307)
  • invokeBlockScaleInterleave (308-312)
csrc/trtllm_fused_moe_kernel_launcher.cu (3)
include/flashinfer/trtllm/fused_moe/runner.h (9)
  • maybeGetMinTokenCount (55-60)
  • top_k (270-270)
  • hidden_size (265-265)
  • intermediate_size (275-275)
  • num_experts (263-263)
  • n_group (271-271)
  • topk_group (273-273)
  • local_expert_offset (276-276)
  • local_num_experts (277-277)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
  • dtypeGetNumBits (89-92)
flashinfer/fused_moe/core.py (1)
  • trtllm_mxint4_block_scale_moe (2543-2631)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (3)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (2)
  • trtllm (38-277)
  • gen (39-275)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h (2)
  • trtllm (30-110)
  • gen (31-108)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h (3)
  • trtllm (28-90)
  • gen (29-89)
  • launchKernel (34-84)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (2)
include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h (2)
  • SmVersion (361-1321)
  • mSm (393-394)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.h (2)
  • CudaArch (36-93)
  • isArchBlackwell (51-54)
🪛 Ruff (0.14.7)
tests/moe/test_trtllm_gen_fused_moe.py

ARG002 (unused method argument): hidden_states_sample (613), unused_args (643), args_dequant (652), gemm1_weights_orig (654), gemm2_weights_orig (655), hidden_size (656), intermediate_size (657), weight_processing (659), hidden_states_scale_global (751)

flashinfer/fused_moe/core.py

ARG001 (unused function argument) at lines 2010-2011 and 2013-2031, one per line: routing_logits, routing_bias, gemm1_weights, gemm1_weights_scale, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, routing_method_type, enable_pdl, output, tune_max_num_tokens

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (43)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)

292-292: LGTM. Blank line insertion for readability after the comment block.

flashinfer/utils.py (1)

789-790: The review is based on code that does not exist in the repository.

The snippet shown in your review comment does not match the actual code at lines 789-790 in flashinfer/utils.py. The assertions are NOT expanded as claimed:

  • Actual line 789: assert input_tensor.dtype == torch.uint8 (only uint8)
  • Actual line 790: assert num_elts_per_sf == 16 (only allows 16)

Your review snippet claims these are expanded to torch.uint8 or torch.bfloat16 and 16 or 32, but the actual code is more restrictive. Additionally, the num_elts_per_sf parameter is not unused—it is actively passed by callers in flashinfer/fused_moe/core.py (lines 233-236, 260-263), though the current assertion rejects any value other than 16.

Verify you are reviewing the correct file version or branch.

Likely an incorrect or invalid review comment.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h (1)

90-107: LGTM! Well-documented enum additions.

The new EltwiseActType enum and the Persistent scheduler option are clearly documented and follow the existing code patterns in this file.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (2)

54-66: LGTM! Bfloat16 and MxInt4 dtype support added correctly.

The addition of Bfloat16 handling and the inclusion of MxInt4 in the same branch as MxE2m1 are consistent with the type definitions in DtypeDecl.h, where both are 4-bit block formats.


198-206: Good defensive initialization practice.

Using brace initialization for tmaDataFormat ensures it has a defined value before the conditional branches, which is a good defensive programming practice.

flashinfer/artifacts.py (2)

92-93: Artifact path updated for new batched GEMM version.

The artifact path has been updated to reflect the new compiled kernels that include MXInt4 support.


113-113: Correct the artifact repository URL reference and verify the TRTLLM_GEN_BMM checksum against the NVIDIA Artifactory source.

The checksum hash b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc should be verified against the checksums.txt file published at the NVIDIA Artifactory repository (https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/ccae3ed120a12a2c6922b458086b460413dbf731/batched_gemm-0d275a2-9936841/checksums.txt), not the S3 URL. The artifact path in the code is correct and points to the right location.

tests/moe/utils.py (2)

33-33: LGTM! New quantization mode added.

The MXINT4_BF16_BF16 quantization mode is properly added to support the MXInt4 path introduced in this PR.


89-94: Appropriate skip condition for MXInt4 alignment requirements.

The skip condition correctly enforces that both intermediate_size and hidden_size must be multiples of 256 for MXInt4 quantization, which aligns with the hardware requirements for this format.
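
An illustrative form of that guard (helper name and exact message assumed; the real check lives in tests/moe/utils.py):

import pytest

def maybe_skip_mxint4(intermediate_size: int, hidden_size: int) -> None:
    # MXINT4_BF16_BF16 requires both dimensions to be multiples of 256.
    if intermediate_size % 256 != 0 or hidden_size % 256 != 0:
        pytest.skip("MxInt4 requires intermediate_size and hidden_size to be multiples of 256")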

flashinfer/fused_moe/__init__.py (1)

34-34: LGTM! Public API export for MXInt4 MoE.

The new trtllm_mxint4_block_scale_moe function is properly exported, making it available as part of the public API.

Also applies to: 58-58

flashinfer/fp4_quantization.py (2)

261-282: LGTM! Generalized dtype support for block scale interleave.

The function now correctly supports both uint8 and bfloat16 input dtypes, with the output dtype mirroring the input. This aligns with the templated C++ kernel implementation.


692-717: Proper dtype validation for generalized interleave function.

The assertion correctly validates that the input tensor is either uint8 or bfloat16, with a clear error message when the constraint is violated.
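
Illustratively, the relaxed check amounts to something like the following (function and variable names assumed, not the file's exact code):

import torch

def _check_block_scale_dtype(x: torch.Tensor) -> None:
    # Accept either uint8 or bfloat16 block scales; anything else is rejected with a clear message.
    assert x.dtype in (torch.uint8, torch.bfloat16), f"Input must be uint8 or bfloat16, got {x.dtype}"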

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (4)

21-21: LGTM! Required header for ModuleCache.

The <unordered_map> header is needed for the ModuleCache type definition on line 459.


126-138: Documentation updated for MXInt4 support.

The comments now correctly document that MxInt4 format uses Dtype::Bfloat16 for scaling factors, which aligns with the implementation changes in TmaDescriptor.h and the type definitions in DtypeDecl.h.


581-593: Improved PDL safety check and error handling.

The changes introduce two improvements:

  1. PDL safety: The pdlSafe boolean correctly determines when PDL can be safely enabled based on grid wait conditions, providing more precise control than the previous implementation.

  2. Error propagation: Returning the actual CUDA error code (result) instead of a hardcoded -1 provides more useful diagnostic information to callers.


725-726: Simplified architecture parameter passing.

Using config.mSm directly instead of computing an intermediate isBlackwell variable simplifies the code while maintaining the same functionality.

csrc/nv_internal/tensorrt_llm/kernels/quantization.h (1)

70-72: LGTM! Template generalization for multi-dtype support.

Converting invokeBlockScaleInterleave to a template function allows it to support multiple types (uint8_t and __nv_bfloat16) in a type-safe manner. This aligns with the explicit template instantiations in the implementation file (quantization.cu).
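
A minimal sketch of the templating pattern (the real launcher's parameters and kernel body are not shown here; only the template form and the uint8_t / __nv_bfloat16 instantiations come from this PR):

#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_runtime.h>

template <typename T>
void invokeBlockScaleInterleave(T const* input, T* output, int rows, int cols, cudaStream_t stream) {
  // In the real code this launches block_scale_interleave_kernel<T> on `stream`.
}

// Explicit instantiations keep the definition in a single .cu file while callers see only the declaration.
template void invokeBlockScaleInterleave<uint8_t>(uint8_t const*, uint8_t*, int, int, cudaStream_t);
template void invokeBlockScaleInterleave<__nv_bfloat16>(__nv_bfloat16 const*, __nv_bfloat16*, int, int, cudaStream_t);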

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (2)

122-123: LGTM: Interface update from boolean to explicit CUDA architecture enum.

The parameter change from bool isBlackwell to tg::CudaArch cudaArch improves type safety and extensibility for supporting multiple GPU architectures (Hopper, Blackwell, Blackwell Ultra).


215-215: Verify default architecture choice.

The default tg::CudaArch::Sm100a (Blackwell) is used here. Confirm this is the intended default for new configs, especially if Hopper (Sm90a) support is still needed.

tests/moe/test_trtllm_gen_fused_moe.py (5)

684-688: Inconsistent num_elts_per_sf between gemm1 and gemm2.

For gemm1 scales (line 688), num_elts_per_sf=32 is used, but for gemm2 scales (line 715), num_elts_per_sf=16 is used. This asymmetry is unexpected for MxInt4 where the block size should be consistent.

Verify this is intentional based on kernel requirements, or if this should be 32 for both.

Also applies to: 711-715


2475-2481: All other MoE implementations are commented out.

Only MxInt4BlockScaleMoe is enabled in the test parametrization while others (BF16, FP8, FP4 variants) are commented out. This appears to be for focused testing during development.

Ensure all implementations are re-enabled before merging to maintain test coverage.


1940-1942: LGTM: Extended dequant path to handle MxInt4xBf16.

The comment correctly lists all applicable modes where activation output uses bf16 without additional quantization.


591-607: The scales reshape is correct. The mxint4_quantize function returns scales with shape (-1, sf_vec_size), where each element represents one block's scale factor. The caller in prepare_static_weights_for_mxint4 immediately reshapes these scales to match kernel expectations: for gemm1, shape becomes (num_experts, 2*intermediate_size//sf_vec_size, hidden_size//sf_vec_size), and for gemm2, shape becomes (num_experts, hidden_size//sf_vec_size, intermediate_size//sf_vec_size). This correctly distributes one scale value per sf_vec_size-element block across the weight matrix dimensions, which is the expected format for block-scaled quantization kernels.

Likely an incorrect or invalid review comment.


750-786: The trtllm_mxint4_block_scale_moe kernel signature accepts routing_bias and routed_scaling_factor parameters, but the Python binding is not yet imported and the test implementation is incomplete.

The C++ kernel signature in csrc/trtllm_fused_moe_kernel_launcher.cu includes Optional<TensorView> routing_bias and Optional<double> routed_scaling_factor parameters. However, the corresponding Python binding is not imported in the test file, and MxInt4BlockScaleMoe.call_moe() is currently just a TODO stub. When this implementation is completed, ensure these parameters are extracted from kwargs and passed to the kernel, consistent with the FP8BlockScaleMoe implementation.

csrc/nv_internal/cpp/kernels/quantization.cu (2)

270-288: Reverse interleave kernel not templated.

block_scale_interleave_reverse_kernel and invokeBlockScaleInterleaveReverse remain hardcoded for uint8_t. If bfloat16 reverse interleaving is needed in the future, this will require similar templating.

This is acceptable if reverse is only used for uint8 scales currently.

Also applies to: 315-324


304-312: LGTM: Template instantiations added for both supported types.

Explicit instantiations for uint8_t and __nv_bfloat16 ensure the templated launcher is available for both dtypes used by the host code.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (3)

88-91: LGTM: Cleaner defaulting for valid dimensions.

The ternary-based defaulting is more concise and readable than the previous explicit -1 checks.


417-426: LGTM: MxInt4 dtype correctly maps to Bfloat16 for scaling factors.

The dtype mapping logic for A's scaling factors now handles:

  • E2m1E4m3 (FP8 scales)
  • MxInt4Bfloat16 (BF16 scales for INT4 weights)
  • Other MX types → UE8m0

This aligns with the MxInt4 block-scale design where scales are stored as BF16.
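
Conceptually, the selection reads like the following sketch (a local mock enum; the real logic lives in KernelParams.h and this helper name is hypothetical):

enum class Dtype { E2m1, E4m3, MxE2m1, MxE4m3, MxInt4, Bfloat16, UE8m0 };

// Pick the dtype used for A's scaling factors, mirroring the mapping described above.
inline Dtype sfDtypeForA(Dtype dtypeA) {
  if (dtypeA == Dtype::E2m1) return Dtype::E4m3;        // NVFP4 weights use FP8 scales
  if (dtypeA == Dtype::MxInt4) return Dtype::Bfloat16;  // MxInt4 weights use BF16 scales
  return Dtype::UE8m0;                                  // remaining (MX) formats use UE8m0 scales in this sketch
}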


410-411: Valid dimension propagation for routed activations.

The changes correctly pass valid dimensions when constructing TMA shape/stride for activation matrices. For routed activations (useRouteAct), using options.mNumTokens as the valid dimension ensures proper bounds.

Also applies to: 459-460, 518-520

csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (2)

186-187: LGTM: Dtype validation added for supported scale types.

The check correctly validates that block scales are either uint8 or bfloat16 before proceeding.


205-227: LGTM: Dtype-conditional dispatch for CUDA and CPU paths.

The branching correctly dispatches to the appropriate templated function based on dtype for both CUDA kernel invocation and host-side processing.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (2)

85-133: Eltwise activation threading into GemmOptions looks consistent

The new eltwiseActType parameter is plumbed from BatchedGemmOptions into gemm::GemmOptions in the same relative position as in the updated GemmOptions constructor, so the wiring here looks correct.


378-394: Switching BatchedGemmConfig.mSm to tg::CudaArch aligns with new arch handling

Changing mSm to tg::CudaArch is consistent with the new CudaArch‑based validation APIs and should help keep arch handling uniform across GEMM/BatchedGEMM paths.

flashinfer/fused_moe/core.py (2)

104-127: Keep DtypeTrtllmGen MxInt4 / UInt encodings in strict sync with C++ DtypeDecl*

The new MxInt4, UE8m0, and UInt* / Void UIDs look reasonable, but they must exactly match the bitfields and UID ordering in trtllm/gen/DtypeDecl.h. Any drift will silently mis‑decode dtypes at the C++ level. Please double‑check the C++ enum to confirm these Python encodings are identical.


1157-1184: MoERunner MXInt4 branch wiring looks coherent

The new MXInt4 path in MoERunner.forward (BF16 activations + DtypeTrtllmGen.MxInt4 weights) passes the same kwargs (gemm1_weights[_scale], gemm1_alpha/beta/clamp_limit, gemm2_weights[_scale], routing params, output, tactic) as the C++ launcher expects. This matches the MXInt4 launcher signature in csrc/trtllm_fused_moe_kernel_launcher.cu, so the dispatcher logic here looks consistent.

csrc/trtllm_fused_moe_kernel_launcher.cu (3)

1785-1866: C++ MXInt4 entrypoint is consistent with Python bindings and launcher

The new trtllm_mxint4_block_scale_moe function:

  • Enforces routing logits dtypes/shapes and disallows routing_bias for MXInt4 (matching the Python wrapper’s routing_bias=None).
  • Requires gemm*_weights to be uint8 and weight_scale_vec_size == 32, aligning with MXInt4 packing and the BF16‑scale checks in MxInt4BlockScaleLauncher::check_moe.
  • Builds one MxInt4BlockScaleLauncher per selected tile_N, sets MoERunnerArgs fields (tokens, experts, hidden/intermediate sizes, local expert layout, routed_scaling_factor, do_finalize/output), and selects a config based on config_index or defaults.

This is in line with the BF16 / FP4 launcher patterns, and the TVM_FFI export at Line 1925 wires it into the FFI surface correctly.


1868-1895: MXInt4 getValidConfigs integration in trtllm_get_valid_moe_configs is correct

The early branch for dtype_act=Bfloat16 && dtype_weights=MxInt4 dispatches to MxInt4BlockScaleLauncher::getValidConfigs, which in turn instantiates a MoE::Runner with BF16/MxInt4, shuffled A, and BlockMajorK. This ensures the autotuner sees the same config space as the runtime launcher. The rest of the dtype combinations remain unchanged.


1921-1926: FFI export for trtllm_mxint4_block_scale_moe is in place

The additional TVM_FFI_DLL_EXPORT_TYPED_FUNC for trtllm_mxint4_block_scale_moe makes the MXInt4 path available to the Python JIT module; nothing else to flag here.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (5)

23-27: Eltwise activation type is correctly integrated into GemmOptions

Including CudaArchDecl.h and extending the GemmOptions constructor with EltwiseActType eltwiseActType (stored in mEltwiseActType and dumped via dumpOptions) is consistent with the rest of the GEMM stack and with the new BatchedGemmOptions ctor. No issues here.

Also applies to: 106-144


398-421: SmVersion aliasing and GemmConfig.mSm migration to tg::CudaArch look good

Replacing the local SmVersion enum with using SmVersion = tg::CudaArch and updating GemmConfig.mSm to tg::CudaArch aligns this header with the shared CUDA arch representation in CudaArchDecl.h, simplifying architecture checks.


611-617: checkAndUpdateGemmOptions: cudaArch‑based isBlackwell and validM/N/K init are reasonable

  • Switching checkAndUpdateGemmOptions to take tg::CudaArch cudaArch and deriving isBlackwell via tg::isArchBlackwell(cudaArch) centralizes arch logic and avoids out‑of‑band booleans.
  • The new < 0 checks for mValidM/N/K preserve earlier semantics (defaulting to full M/N/K when unset) while handling any negative sentinel consistently.

Callers just need to ensure they pass the actual SM’s CudaArch rather than a derived boolean.

Also applies to: 642-651


686-695: A‑side cast check now correctly allows MxInt4→BF16

The updated A‑cast constraint:

((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1 ||
  options.mDtypeA == tg::Dtype::MxInt4) &&
 options.mDtypeMmaA == tg::Dtype::Bfloat16)

is exactly what the MXInt4 GEMM path needs (MxInt4 weights cast to BF16 MMA inputs) and remains compatible with the existing NVFP4/MxFP4 behavior.


786-808: New Blackwell‑only and epilogue/scheduler constraints are consistent with hardware limits

  • LDTM shape checks now branch on isBlackwell, constraining Hopper to 16dp256bit while allowing 16dp256bit or 32dp32bit on Blackwell, plus the extra guards for transposed outputs and epilogueTileM=64.
  • Enforcing options.mMmaM == 128 or specific tile shapes for certain MxFp4/MxFp8 paths, and gating DeepSeek/block‑scaled features on isBlackwell, protects against unsupported tensor core combinations.
  • The new block:
if (isBlackwell && !options.mUseCustomMmaSchedule && !options.mUseDeepSeekFp8 &&
    options.mTileScheduler == TileScheduler::Persistent) { ... }

ensures persistent scheduling on Blackwell always uses the custom MMA schedule, which matches the intent of the comments around custom scheduling.

  • When mNumEpilogueWarps > 4, requiring TileN to be a multiple of EpilogueTileN * numEpilogueWrpGrps is a sensible layout constraint for multi‑warp epilogues.

Overall these validations look correct and should fail fast on unsupported configs rather than letting kernels misbehave.

Also applies to: 987-1010, 1300-1318, 1456-1459

Signed-off-by: Nikita Korobov <[email protected]>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

♻️ Duplicate comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)

788-793: LGTM - Docstring has been corrected.

The docstring now correctly says "MXINT4-specific" instead of "FP4-specific" as noted in previous reviews.

🧹 Nitpick comments (2)
csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (1)

186-227: Add output tensor validation in BlockScaleInterleave to catch mismatched buffers early

The dtype guard on blockScale plus the CUDA/CPU dispatch by dtype is good. However, interleavedBlockScale is assumed to have matching dtype, contiguity, and at least num_experts * expert_out_size elements without any checks. A mismatched output tensor (e.g., wrong dtype or too small) would lead to hard-to-debug memory issues.

Consider adding symmetric validation, e.g.:

   CHECK_CONTIGUOUS(blockScale);
   TVM_FFI_ICHECK(blockScale.dtype() == dl_uint8 || blockScale.dtype() == dl_bfloat16)
       << "Block Scale must be uint8 or bfloat16.";
+  CHECK_CONTIGUOUS(interleavedBlockScale);
+  TVM_FFI_ICHECK_EQ(interleavedBlockScale.dtype(), blockScale.dtype())
+      << "interleavedBlockScale must have the same dtype as blockScale.";
+
   auto blockScaleShape = blockScale.sizes();
   ...
+  TVM_FFI_ICHECK_EQ(interleavedBlockScale.numel(), num_experts * expert_out_size)
+      << "interleavedBlockScale has incorrect size for the given blockScale.";

This keeps the public behavior the same when used correctly but fails fast with a clear error message if the caller wires up the wrong buffer.

tests/moe/test_trtllm_gen_fused_moe.py (1)

591-607: Clarify scales reshape logic.

The returned scales.reshape(-1, sf_vec_size) at line 606 is confusing. The scales tensor has shape (num_groups, 1) where num_groups = total_elements / sf_vec_size. The reshape attempts to produce shape (num_groups/sf_vec_size, sf_vec_size), which only works if num_groups is divisible by sf_vec_size.

While this works for the current test cases (power-of-2 dimensions), the semantic meaning is unclear and the caller at line 625 immediately reshapes the result anyway.

Consider simplifying to return a flat tensor and let the caller handle the reshape:

-    return x_int4.reshape(*x.shape[:-1], x.shape[-1] // 2), scales.reshape(
-        -1, sf_vec_size
-    )
+    return x_int4.reshape(*x.shape[:-1], x.shape[-1] // 2), scales.view(-1)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7e9ff16 and 8222437.

📒 Files selected for processing (3)
  • csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (3 hunks)
  • flashinfer/fused_moe/core.py (4 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (2)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • num_experts (263-263)
csrc/nv_internal/cpp/kernels/quantization.cu (4)
  • invokeBlockScaleInterleave (292-302)
  • invokeBlockScaleInterleave (292-293)
  • invokeBlockScaleInterleave (305-307)
  • invokeBlockScaleInterleave (308-312)
🪛 Ruff (0.14.7)
tests/moe/test_trtllm_gen_fused_moe.py

ARG002 (unused method argument): hidden_states_sample (613), unused_args (643), args_dequant (652), gemm1_weights_orig (654), gemm2_weights_orig (655), hidden_size (656), intermediate_size (657), weight_processing (659), hidden_states_scale_global (751)

flashinfer/fused_moe/core.py

ARG001 (unused function argument) at lines 2010-2011 and 2013-2031, one per line: routing_logits, routing_bias, gemm1_weights, gemm1_weights_scale, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, routing_method_type, enable_pdl, output, tune_max_num_tokens

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (8)
csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp (1)

170-173: BF16 explicit instantiation: ensure proper header availability

The explicit instantiation for blockScaleInterleaveHost<__nv_bfloat16> looks correct and matches the CUDA-side templating. Just make sure the compilation unit always sees the definition of __nv_bfloat16 (via the appropriate CUDA header) on all supported toolchains; otherwise this instantiation could fail to compile in some environments.

tests/moe/test_trtllm_gen_fused_moe.py (3)

2176-2224: LGTM - Reference dequantization implementation is correct.

The MxInt4 dequantization logic properly:

  1. Unpacks two 4-bit values from each byte using bitwise operations
  2. Converts unsigned nibbles to signed two's complement values in [-8, 7]
  3. Applies block scales correctly
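
For reference, a self-contained sketch of that unpack-and-scale logic (the helper name and the low-nibble-first ordering are assumptions of this sketch, not the test's exact code):

import torch

def mxint4_dequant_ref(packed: torch.Tensor, scales: torch.Tensor, sf_vec_size: int = 32) -> torch.Tensor:
    """Dequantize uint8-packed int4 values using one scale per `sf_vec_size`-element block."""
    lo = (packed & 0x0F).to(torch.int16)  # low nibble
    hi = (packed >> 4).to(torch.int16)    # high nibble
    # Interleave the nibbles back into element order (low nibble first is an assumption here).
    vals = torch.stack((lo, hi), dim=-1).flatten(-2)
    vals = torch.where(vals >= 8, vals - 16, vals)  # unsigned nibble -> signed two's complement in [-8, 7]
    out = vals.to(torch.float32).reshape(-1, sf_vec_size)
    out = out * scales.to(torch.float32).reshape(-1, 1)  # apply the per-block scale
    return out.reshape(*packed.shape[:-1], packed.shape[-1] * 2)

Under these assumptions, values quantized by the tests' mxint4_quantize helper should dequantize back to the original tensor up to quantization error.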

2478-2582: LGTM - Test parametrization correctly includes MxInt4BlockScaleMoe.

The MxInt4BlockScaleMoe is properly added to test parametrizations with compatible routing configs and weight processing (BlockMajorK layout with shuffled weights).


711-715: Inconsistent num_elts_per_sf between gemm1 and gemm2 scale permutation requires verification.

For gemm1 scales (line 688), num_elts_per_sf=32 is used, but for gemm2 scales here, num_elts_per_sf=16 is used. Verify if this difference is intentional based on MxInt4's sf_vec_size or if both should use num_elts_per_sf=32 for consistency.

Reference: FP4Moe implementation uses num_elts_per_sf=16 for both occurrences (lines 454 and 481), which aligns with its sf_vec_size=16.

flashinfer/fused_moe/core.py (4)

119-126: LGTM - MxInt4 enum value properly added with UID shift.

The MxInt4 dtype is correctly added with bit format (1, 1, 1, 4, 14) and subsequent UIDs are properly incremented.


1157-1184: LGTM - MxInt4 dispatch path correctly added.

The dispatch for dtype_act == BF16 and dtype_weights == MxInt4 is properly wired to call moe_op.trtllm_mxint4_block_scale_moe with the appropriate arguments.


2038-2044: LGTM - Module exports correctly updated.

The SimpleNamespace return value properly includes trtllm_mxint4_block_scale_moe alongside other MoE operations.


193-198: No changes needed - MxInt4 is not a valid DtypeTrtllmGen enum value.

The is_trtllm_moe_supported function correctly lists all supported weight types. The DtypeTrtllmGen enum does not include MxInt4; the only mixed-precision variants are MxE2m1 and MxE4m3, both of which are already in the supported list. References to MxInt4 in the test file are for test utilities only, not DtypeTrtllmGen enum members.

Likely an incorrect or invalid review comment.

Comment on lines +151 to +167
  for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++) {
    T* interleavedBlockScalePtr =
        static_cast<T*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
    for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) {
      auto globalRowIdx = eIdx * rows + rIdx;
      T* blockScalePtr = static_cast<T*>(blockScale.data_ptr()) + globalRowIdx * cols;
      for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) {
        T sf_ori = 0;
        if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols)) {
          sf_ori = blockScalePtr[cIdx];
        }
        int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
                                      tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4);
        interleavedBlockScalePtr[sf_index] = sf_ori;
      }
    }
  }
Contributor


⚠️ Potential issue | 🟠 Major

Avoid out-of-bounds pointer arithmetic for padded rows in blockScaleInterleaveHost

For padded rows (rIdx >= rows), globalRowIdx = eIdx * rows + rIdx and the derived blockScalePtr can point past the end of the blockScale buffer, even though it’s not dereferenced when rIdx >= rows. This is technically undefined behavior and easy to avoid by only forming the row pointer when the row is valid.

You can also make zero-initialization of sf_ori more robust for all T by using value-initialization.

A safer layout:

-  for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++) {
-    T* interleavedBlockScalePtr =
-        static_cast<T*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
-    for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) {
-      auto globalRowIdx = eIdx * rows + rIdx;
-      T* blockScalePtr = static_cast<T*>(blockScale.data_ptr()) + globalRowIdx * cols;
-      for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) {
-        T sf_ori = 0;
-        if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols)) {
-          sf_ori = blockScalePtr[cIdx];
-        }
-        int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
-                                      tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4);
-        interleavedBlockScalePtr[sf_index] = sf_ori;
-      }
-    }
-  }
+  T* blockScaleBasePtr = static_cast<T*>(blockScale.data_ptr());
+  for (int eIdx = 0; eIdx < static_cast<int>(num_experts); ++eIdx) {
+    T* interleavedBlockScalePtr =
+        static_cast<T*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
+    T* blockScaleExpertBasePtr = blockScaleBasePtr + eIdx * rows * cols;
+    for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) {
+      bool const valid_row = rIdx < static_cast<int>(rows);
+      T* blockScaleRowPtr = valid_row ? blockScaleExpertBasePtr + rIdx * cols : nullptr;
+      for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) {
+        T sf_ori{};
+        if (valid_row && cIdx < static_cast<int>(cols)) {
+          sf_ori = blockScaleRowPtr[cIdx];
+        }
+        int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
+                                      tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4);
+        interleavedBlockScalePtr[sf_index] = sf_ori;
+      }
+    }
+  }

This keeps behavior the same while avoiding any out-of-bounds pointer values and strengthens default initialization for all template types.

🤖 Prompt for AI Agents
In csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp around lines 151-167, the code
computes globalRowIdx and forms blockScalePtr for padded rows (rIdx >= rows)
which can produce an out-of-bounds pointer even if not dereferenced; change the
loop so blockScalePtr (and globalRowIdx) are only computed when rIdx < rows, and
otherwise keep sf_ori value-initialized (e.g., T sf_ori{}), then use sf_ori for
writing into interleavedBlockScalePtr; this avoids undefined pointer arithmetic
while preserving behavior for valid rows.

Comment on lines +1883 to +1910
@register_custom_op(
    "flashinfer::trtllm_mxint4_block_scale_moe",
    mutates_args=(""),
)
def trtllm_mxint4_block_scale_moe_op(
    routing_logits: torch.Tensor,
    routing_bias: Optional[torch.Tensor],
    hidden_states: torch.Tensor,
    gemm1_weights: torch.Tensor,
    gemm1_weights_scale: torch.Tensor,
    gemm1_alpha: Optional[torch.Tensor],
    gemm1_beta: Optional[torch.Tensor],
    gemm1_clamp_limit: Optional[torch.Tensor],
    gemm2_weights: torch.Tensor,
    gemm2_weights_scale: torch.Tensor,
    num_experts: int,
    top_k: int,
    n_group: Optional[int],
    topk_group: Optional[int],
    intermediate_size: int,
    local_expert_offset: int,
    num_local_experts: int,
    routed_scaling_factor: Optional[float],
    routing_method_type: int,
    enable_pdl: Optional[bool] = None,
    output: Optional[torch.Tensor] = None,
    tune_max_num_tokens: int = 8192,
) -> List[torch.Tensor]:
Contributor

⚠️ Potential issue | 🟡 Minor

Return type mismatch: annotation says List[torch.Tensor] but returns a plain tensor.

The function signature declares -> List[torch.Tensor] but line 2006 returns output (a torch.Tensor, not a list). This is inconsistent with the fake op (line 2036) which returns [hidden_states.new_empty(...)].

Either update the return type annotation or wrap the return value:

-    ) -> List[torch.Tensor]:
+    ) -> torch.Tensor:

Or alternatively:

-        return output
+        return [output]

Also applies to: 2006-2006

🤖 Prompt for AI Agents
In flashinfer/fused_moe/core.py around lines 1883-1910 (issue also referenced at
line 2006 and fake-op at 2036): the function is annotated to return
List[torch.Tensor] but currently returns a plain torch.Tensor; make the
implementation and fake-op consistent with the annotation by returning a list of
tensors. Specifically, wrap the single tensor return value into a list (e.g.,
return [output]) and ensure the fake-op also returns a list with matching
shape/type; alternatively, if you prefer a single tensor API, change the
function return annotation to torch.Tensor and update the fake-op to return a
tensor—pick one option and apply it consistently across the function and its
fake-op.
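
For reference, here is a minimal sketch of the single-tensor option, with the op, the fake op, and the annotation kept in agreement. The bodies and the fake-op name are placeholders for illustration, not the actual FlashInfer implementation:

from typing import Optional

import torch


def trtllm_mxint4_block_scale_moe_op(
    hidden_states: torch.Tensor,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # The real op would launch the MoE kernel into `output`;
    # only the return-type contract is shown here.
    if output is None:
        output = torch.empty_like(hidden_states)
    return output  # a single tensor, matching the -> torch.Tensor annotation


def _fake_trtllm_mxint4_block_scale_moe_op(
    hidden_states: torch.Tensor,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # The fake op mirrors the real op: one tensor with the same shape and dtype.
    return hidden_states.new_empty(hidden_states.shape)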

Comment on lines +2542 to +2565
@flashinfer_api
def trtllm_mxint4_block_scale_moe(
    routing_logits: torch.Tensor,
    hidden_states: torch.Tensor,
    gemm1_weights: torch.Tensor,
    gemm1_weights_scale: torch.Tensor,
    gemm1_alpha: Optional[torch.Tensor],
    gemm1_beta: Optional[torch.Tensor],
    gemm1_clamp_limit: Optional[torch.Tensor],
    gemm2_weights: torch.Tensor,
    gemm2_weights_scale: torch.Tensor,
    num_experts: int,
    top_k: int,
    n_group: Optional[int],
    topk_group: Optional[int],
    intermediate_size: int,
    local_expert_offset: int,
    local_num_experts: int,
    routed_scaling_factor: Optional[float],
    routing_method_type: int = 0,
    enable_pdl: Optional[bool] = None,
    output: Optional[torch.Tensor] = None,
    tune_max_num_tokens: int = 8192,
) -> List[torch.Tensor]:
Contributor

⚠️ Potential issue | 🟠 Major

Missing routing_bias parameter in public API wrapper.

The underlying op (trtllm_mxint4_block_scale_moe_op) accepts routing_bias as a parameter, but this public wrapper doesn't expose it. The wrapper always passes None (line 2610). This prevents users from using routing bias with MxInt4 MoE.

Other similar wrappers like trtllm_fp4_block_scale_moe include routing_bias as a parameter.

Add routing_bias to the function signature:

 @flashinfer_api
 def trtllm_mxint4_block_scale_moe(
     routing_logits: torch.Tensor,
+    routing_bias: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     ...
 ) -> torch.Tensor:

And pass it to the op:

     return get_trtllm_moe_sm100_module().trtllm_mxint4_block_scale_moe(
         routing_logits,
-        None,
+        routing_bias,
         hidden_states,
         ...
     )

Also applies to: 2608-2611

🤖 Prompt for AI Agents
In flashinfer/fused_moe/core.py around lines 2542-2565 (and also update the call
site around 2608-2611), the public wrapper trtllm_mxint4_block_scale_moe is
missing the routing_bias parameter; add routing_bias: Optional[torch.Tensor] =
None to the function signature and forward that parameter to
trtllm_mxint4_block_scale_moe_op instead of passing None so the underlying op
can receive a routing bias; ensure all internal calls/forwarding at lines
~2608-2611 pass the new routing_bias argument through.
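
As a rough sketch of that change, with unrelated parameters elided (get_trtllm_moe_sm100_module is the module accessor already used in the diff above; everything else shown here is a placeholder):

from typing import Optional

import torch


def trtllm_mxint4_block_scale_moe(
    routing_logits: torch.Tensor,
    routing_bias: Optional[torch.Tensor],
    hidden_states: torch.Tensor,
    # ... remaining parameters unchanged ...
) -> torch.Tensor:
    # Forward the user-provided bias instead of hard-coding None.
    return get_trtllm_moe_sm100_module().trtllm_mxint4_block_scale_moe(
        routing_logits,
        routing_bias,
        hidden_states,
        # ... remaining arguments unchanged ...
    )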

Comment on lines +2605 to +2607
    Returns:
        torch.Tensor: returns the final MoE output.
    """
Contributor

⚠️ Potential issue | 🟡 Minor

Docstring return type conflicts with type annotation.

The docstring says Returns: torch.Tensor but the function signature declares -> List[torch.Tensor]. These should be consistent.

Update either the docstring or the type annotation to match (see the return type mismatch comment above for the op function).

🤖 Prompt for AI Agents
In flashinfer/fused_moe/core.py around lines 2605 to 2607, the docstring return
type says "torch.Tensor" but the function signature annotates "->
List[torch.Tensor]"; make them consistent by updating one to match the other: if
the function actually returns multiple tensors, change the docstring to
"Returns: List[torch.Tensor]" and describe each element if needed; if it returns
a single tensor, change the type annotation to "-> torch.Tensor" and update any
callers/tests accordingly; ensure the return description matches the chosen
type.
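
Concretely, if the single-tensor option is chosen, the annotation and docstring would line up like this (a trimmed, hypothetical signature, not the full parameter list):

import torch


def trtllm_mxint4_block_scale_moe(hidden_states: torch.Tensor) -> torch.Tensor:
    """Run the MxInt4 x BF16 block-scale MoE (parameters trimmed for brevity).

    Returns:
        torch.Tensor: the final MoE output.
    """
    ...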

@jiahanc
Collaborator

jiahanc commented Dec 2, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !171 has been created, and the CI pipeline #39477569 is currently running. I'll report back once the pipeline job completes.

Collaborator

@jiahanc jiahanc left a comment

LGTM. Thanks for the contribution!

Collaborator

@IwakuraRein IwakuraRein left a comment

Thanks for your contributions!

@flashinfer-bot
Collaborator

[FAILED] Pipeline #39477569: 4/20 passed

Signed-off-by: Nikita Korobov <[email protected]>