
fix: trtllm_mxint4_block_scale_moe unit test to index output list #2627

Merged
aleozlx merged 1 commit into flashinfer-ai:main from jimmyzho:moe_quick_fix
Feb 27, 2026

Conversation

@jimmyzho
Contributor

@jimmyzho jimmyzho commented Feb 24, 2026

📌 Description

Fix unit test failures caused by the API change in #2548, which now returns a list of tensors.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Tests
    • Updated MoE test implementations to correctly extract return values from method calls.

@gemini-code-assist
Contributor

Summary of Changes

Hello @jimmyzho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a regression in the trtllm_mxint4_block_scale_moe unit test. The test was failing because a recent API update in FlashInfer changed the function's return type from a single tensor to a list of tensors. The fix ensures the test correctly accesses the intended output, restoring it to a passing state.

Highlights

  • Unit Test Fix: Modified the trtllm_mxint4_block_scale_moe unit test to correctly handle an API change where the output is now a list of tensors, by indexing the first element.


Changelog
  • tests/moe/test_trtllm_gen_fused_moe.py
  • Updated the call_moe function to access the first element of the output list, adapting to a recent API change (a minimal sketch follows below).
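
For reference, the fix amounts to a one-line indexing change. The following is a minimal sketch, not the verbatim test code; the argument list is elided and only the function name comes from this PR:

    # inside call_moe in tests/moe/test_trtllm_gen_fused_moe.py (sketch)
    output = trtllm_mxint4_block_scale_moe(...)  # the API now returns a list of tensors
    return output[0]  # index the first element instead of returning the whole list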

@coderabbitai
Contributor

coderabbitai bot commented Feb 24, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 26ef055 and b6a81f0.

📒 Files selected for processing (1)
  • tests/moe/test_trtllm_gen_fused_moe.py

📝 Walkthrough

The PR modifies test implementations in the MoE test file to adjust output extraction. Two call_moe methods now return the first element of the output list instead of the entire list, aligning with upstream API changes in the MoE entry points.

Changes

  • MoE Test Call Methods (tests/moe/test_trtllm_gen_fused_moe.py): updated FP4Moe.call_moe and MxInt4BlockScaleMoe.call_moe to return output[0] instead of the full output, reflecting the changed MoE return value structure.

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~5 minutes

Suggested labels

ready, run-ci, op: moe

Suggested reviewers

  • jiahanc
  • cyx-6
  • djmmoss
  • yzh119

Poem

🐰 A tensor once whole, now split in two,
The first slice returns, a cleaner view—
Where output[0] brings order to the test,
Our MoE calls now returning their best! ✨

🚥 Pre-merge checks (2 passed, 1 warning)

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: coverage is 50.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)
  • Title check ✅: the title accurately describes the main change, fixing a unit test to index the output list after an API change that now returns a list instead of a single tensor.
  • Description check ✅: the description clearly explains the fix (unit test failures from the PR #2548 API change) and notes that the API now returns a list of tensors.



Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

The pull request correctly addresses unit test failures for the MxInt4 block scale MoE implementation. The change ensures that the output of the trtllm_mxint4_block_scale_moe function is correctly indexed, as the API now returns a list of tensors instead of a single tensor. This fix aligns the MxInt4 test logic with the existing FP4 implementation and the updated library API.

@jimmyzho
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !345 has been created, and the CI pipeline #44676334 is currently running. I'll report back once the pipeline job completes.

@aleozlx added the v0.6.5 release blocker label Feb 24, 2026
@aleozlx
Collaborator

aleozlx commented Feb 24, 2026

We shouldn't have API-breaking changes without a deprecation process by default.

Do you think you can help make it backward compatible instead? @jimmyzho
b8fd612

Thanks for catching this.

Tagged as a v0.6.5 release blocker for now.
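
For context, one conventional way to keep such a change backward compatible is to gate the return shape on an opt-in flag so existing callers still receive a single tensor. The sketch below is a hypothetical illustration of that pattern, not FlashInfer code; _run_moe_kernels is a made-up stand-in for the real kernel dispatch:

    from typing import List, Union
    import torch

    def moe_op_compat(*args, do_finalize: bool = True, **kwargs) -> Union[torch.Tensor, List[torch.Tensor]]:
        outputs: List[torch.Tensor] = _run_moe_kernels(*args, **kwargs)  # hypothetical helper
        if do_finalize:
            return outputs[0]  # legacy behavior: a single finalized tensor
        return outputs  # opt-in behavior: the full list of output tensors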

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44676334: 14/20 passed

@aleozlx
Collaborator

aleozlx commented Feb 27, 2026

@jimmyzho ignore my comment above; this is purely a typo in the unit test, as @IwakuraRein pointed out.

I was worried that a non-backward-compatible API change might have slipped through, but I confirmed from git diff v0.6.4 that everything seems fine. The relevant hunks follow, with a short caller-side sketch after them:

@flashinfer_api
 def trtllm_fp8_per_tensor_scale_moe(
@@ -2257,10 +2323,11 @@ def trtllm_fp8_per_tensor_scale_moe(
     routed_scaling_factor: Optional[float],
     use_routing_scales_on_input: bool,
     routing_method_type: int = 0,
+    do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 8192,
     activation_type: int = ActivationType.Swiglu.value,
-) -> torch.Tensor:
+) -> Union[List[torch.Tensor], torch.Tensor]:
     """FP8 per tensor scale MoE operation.

     Args:
@@ -2282,22 +2349,21 @@ def trtllm_fp8_per_tensor_scale_moe(
         routed_scaling_factor: Scaling factor for routing
         use_routing_scales_on_input: Whether to use routing scales on input
         routing_method_type: Type of routing method to use (default: 0)
+        do_finalize: Whether to finalize the output (default: True).
         enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
--
 @flashinfer_api
 def trtllm_fp8_block_scale_moe(
@@ -2343,10 +2418,11 @@ def trtllm_fp8_block_scale_moe(
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
     weight_layout: int = 0,
+    do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
     tune_max_num_tokens: int = 8192,
     fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
-) -> torch.Tensor:
+) -> Union[List[torch.Tensor], torch.Tensor]:
     """FP8 block scale MoE operation.

     Args:
@@ -2374,16 +2450,18 @@ def trtllm_fp8_block_scale_moe(
         weight_layout: Weight layout format (default: WeightLayout.MajorK). Supported layouts:
             - 0: MajorK - K-major layout [Mn, K]
             - 2: BlockMajorK - Blocked along K dimension [K/blockK, Mn, blockK]
+        do_finalize: Whether to finalize the output (default: True).
         enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
--
 @flashinfer_api
 def trtllm_fp8_block_scale_routed_moe(
@@ -2433,11 +2520,12 @@ def trtllm_fp8_block_scale_routed_moe(
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
     weight_layout: int = 0,
+    do_finalize: bool = True,
     enable_pdl: Optional[bool] = None,
     output: Optional[torch.Tensor] = None,
     tune_max_num_tokens: int = 8192,
     fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
-) -> torch.Tensor:
+) -> Union[List[torch.Tensor], torch.Tensor]:
     """FP8 block scale MoE operation with pre-computed routing (packed format).

     This function is used when routing decisions have already been computed
@@ -2468,14 +2556,16 @@ def trtllm_fp8_block_scale_routed_moe(
         use_shuffled_weight: Whether to use shuffled weights
         weight_layout: Weight layout (0 = MajorK, 1 = BlockMajorK)
         enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
+        do_finalize: Whether to finalize the output (default: True).
--
 @flashinfer_api
 def trtllm_fp4_block_scale_moe(
@@ -2589,12 +2688,8 @@ def trtllm_fp4_block_scale_moe(
         do_finalize (bool): Whether to finalize the output (default: False)
         enable_pdl (Optional[bool]): Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
         activation_type (int): Type of activation function (default: 3 - Swiglu)
-            - 0: Gelu
-            - 1: Relu
-            - 2: Silu
             - 3: Swiglu
             - 4: Geglu
-            - 5: SwigluBias
             - 6: Relu2
             - 7: Identity
         tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192)
@@ -2726,12 +2821,8 @@ def trtllm_fp4_block_scale_routed_moe(
             - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
         do_finalize (bool): Whether to finalize the output (default: False)
         activation_type (int): Type of activation function (default: 3 - Swiglu)
-            - 0: Gelu
-            - 1: Relu
--
 from ..api_logging import flashinfer_api
-from flashinfer.jit import gen_dsv3_router_gemm_module
+from flashinfer.jit import gen_dsv3_router_gemm_module, gen_tinygemm2_module
 import functools
 from types import SimpleNamespace
+from typing import Optional
 import torch
 from flashinfer.utils import (
     register_custom_op,
@@ -222,3 +223,139 @@ def mm_M1_16_K7168_N256(
     get_dsv3_router_gemm_module().mm_M1_16_K7168_N256(
         mat_a, mat_b, out, launch_with_pdl
     )
+
+
+# ============================================================================
+# tinygemm2: SM90+ BF16 small GEMM with bias (from TensorRT-LLM)
+# Computes: output = input @ weight.T + bias  (equivalent to F.linear)
+# ============================================================================
+
+
--
+@flashinfer_api
+def tinygemm_bf16(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    out: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    use_pdl: bool = False,
+) -> None:
+    """SM90+ optimized small GEMM: out = input @ weight.T + bias (equivalent to F.linear).
+
+    A latency-optimized, warp-specialized GEMM designed for tiny batch sizes (ideally
+    1-8 rows, where a single TILE_N=8 tile covers the entire batch dimension) using
+    Ampere-style HMMA instructions. Uses TMA for async bulk data loads and
+    mma.sync.aligned.m16n8k16 tensor core instructions with BF16 input/weight/bias/output
+    and FP32 internal accumulation. The warp-specialized design (384 threads: 4 compute +
+    8 DMA warps) with 16 pipeline stages and 4x stage unroll trades off peak throughput
+    in favor of minimal latency.
+
+    From TensorRT-LLM tinygemm2 kernel.
+
+    Args:
--
@@ -23,40 +23,61 @@ from ..api_logging import flashinfer_api
 from ..jit.mamba import (
     gen_selective_state_update_module,
     gen_selective_state_update_sm90_module,
-    gen_selective_state_update_sm100_module,
 )
 from ..utils import get_compute_capability, register_custom_op, register_fake_op


 @functools.cache
-def get_selective_state_update_module_base():
-    """Get cached JIT-compiled selective_state_update module (base version)."""
-    return gen_selective_state_update_module().build_and_load()
-
-
-@functools.cache
-def get_selective_state_update_module_sm90():
-    """Get cached JIT-compiled selective_state_update module (SM90/Hopper version)."""
-    return gen_selective_state_update_sm90_module().build_and_load()
-
-
--
 @flashinfer_api
@@ -78,6 +99,7 @@ def selective_state_update(
     intermediate_states_buffer: Optional[torch.Tensor] = None,
     intermediate_state_indices: Optional[torch.Tensor] = None,
     cache_steps: int = 0,
+    algorithm: str = "auto",
 ) -> torch.Tensor:
     r"""Selective state update operation for Mamba layers (the generation phase).

@@ -126,6 +148,10 @@ def selective_state_update(
         with shape (batch,)
     cache_steps : int
         Number of steps/tokens to cache for speculative decoding
+    algorithm : str
+        Algorithm to use: "auto" (default, picks the best kernel based on GPU arch,
+        data types, and problem size), "simple" (all GPUs), "vertical" and "horizontal"
+        (SM90+ only). MTP mode only supports "auto" or "simple".

     Returns
     -------
@@ -178,6 +204,30 @@ def selective_state_update(
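
Given the widened Union[List[torch.Tensor], torch.Tensor] annotations in the diff above, a caller can normalize the result defensively. A minimal sketch, assuming only what the diff shows; the argument list is elided:

    import torch
    from typing import List, Union

    # result may be a single finalized tensor or a list of tensors, per the new annotation
    result: Union[List[torch.Tensor], torch.Tensor] = trtllm_fp8_block_scale_moe(...)
    # take the first tensor when a list comes back, mirroring the unit-test fix
    out = result[0] if isinstance(result, list) else result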

@aleozlx removed the v0.6.5 release blocker label Feb 27, 2026
@aleozlx aleozlx merged commit a337e42 into flashinfer-ai:main Feb 27, 2026
49 of 67 checks passed
@coderabbitai bot mentioned this pull request Mar 3, 2026
