fix: trtllm_mxint4_block_scale_moe unit test to index output list #2627
aleozlx merged 1 commit into flashinfer-ai:main from
Conversation
CodeRabbit walkthrough: The PR modifies test implementations in the MoE test file to adjust output extraction. Two call_moe methods now return the first element of the output instead of the entire output, aligning with upstream API changes in MoE entry points. Estimated review effort: 1 (Trivial), ~5 minutes.
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
The pull request correctly addresses unit test failures for the MxInt4 block scale MoE implementation. The change ensures that the output of the trtllm_mxint4_block_scale_moe function is correctly indexed, as the API now returns a list of tensors instead of a single tensor. This fix aligns the MxInt4 test logic with the existing FP4 implementation and the updated library API.
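The fix described above can be sketched in plain Python. This is a hypothetical stand-in, not the actual test or kernel: `fake_moe` only imitates the new API shape, where the entry point returns a list of tensors and the caller takes element 0.

```python
# Toy stand-in for the MoE entry point after the API change in the upstream
# PR: it returns a list of output tensors rather than a single tensor.
def fake_moe(x):
    return [x * 2]                  # finalized output is the first element


# The unit-test fix: call_moe now indexes the returned list instead of
# treating the return value as a single tensor.
def call_moe(moe_fn, *args):
    return moe_fn(*args)[0]


print(call_moe(fake_moe, 21))  # → 42
```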
/bot run
[FAILED] Pipeline #44676334: 14/20 passed
@jimmyzho ignore my comment above; this is purely a typo in the unit test. I was worried about a potentially escaped non-compatible API change, but confirmed it from the updated signatures:

@flashinfer_api
def trtllm_fp8_per_tensor_scale_moe(
@@ -2257,10 +2323,11 @@ def trtllm_fp8_per_tensor_scale_moe(
routed_scaling_factor: Optional[float],
use_routing_scales_on_input: bool,
routing_method_type: int = 0,
+ do_finalize: bool = True,
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 8192,
activation_type: int = ActivationType.Swiglu.value,
-) -> torch.Tensor:
+) -> Union[List[torch.Tensor], torch.Tensor]:
"""FP8 per tensor scale MoE operation.
Args:
@@ -2282,22 +2349,21 @@ def trtllm_fp8_per_tensor_scale_moe(
routed_scaling_factor: Scaling factor for routing
use_routing_scales_on_input: Whether to use routing scales on input
routing_method_type: Type of routing method to use (default: 0)
+ do_finalize: Whether to finalize the output (default: True).
enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
--
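The diff above adds a `do_finalize` flag and widens the return type to `Union[List[torch.Tensor], torch.Tensor]`. The following toy sketch illustrates that calling convention only; the real kernel's behavior for `do_finalize=False` (what the unfinalized intermediates contain) is an assumption made purely for illustration.

```python
from typing import List, Union


def moe_sketch(x: float, do_finalize: bool = True) -> Union[List[float], float]:
    # Pretend per-expert partial outputs (illustrative values only).
    partials = [x * 0.5, x * 1.5]
    if do_finalize:
        # Finalized result is combined but still wrapped in a list, which is
        # why the fixed unit test indexes element 0 of the return value.
        return [sum(partials)]
    # Assumed: unfinalized mode exposes the per-stage intermediates.
    return partials


print(moe_sketch(10.0)[0])  # → 20.0
```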
@flashinfer_api
def trtllm_fp8_block_scale_moe(
@@ -2343,10 +2418,11 @@ def trtllm_fp8_block_scale_moe(
routing_method_type: int = 0,
use_shuffled_weight: bool = False,
weight_layout: int = 0,
+ do_finalize: bool = True,
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 8192,
fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
-) -> torch.Tensor:
+) -> Union[List[torch.Tensor], torch.Tensor]:
"""FP8 block scale MoE operation.
Args:
@@ -2374,16 +2450,18 @@ def trtllm_fp8_block_scale_moe(
weight_layout: Weight layout format (default: WeightLayout.MajorK). Supported layouts:
- 0: MajorK - K-major layout [Mn, K]
- 2: BlockMajorK - Blocked along K dimension [K/blockK, Mn, blockK]
+ do_finalize: Whether to finalize the output (default: True).
enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
--
@flashinfer_api
def trtllm_fp8_block_scale_routed_moe(
@@ -2433,11 +2520,12 @@ def trtllm_fp8_block_scale_routed_moe(
routing_method_type: int = 0,
use_shuffled_weight: bool = False,
weight_layout: int = 0,
+ do_finalize: bool = True,
enable_pdl: Optional[bool] = None,
output: Optional[torch.Tensor] = None,
tune_max_num_tokens: int = 8192,
fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8,
-) -> torch.Tensor:
+) -> Union[List[torch.Tensor], torch.Tensor]:
"""FP8 block scale MoE operation with pre-computed routing (packed format).
This function is used when routing decisions have already been computed
@@ -2468,14 +2556,16 @@ def trtllm_fp8_block_scale_routed_moe(
use_shuffled_weight: Whether to use shuffled weights
weight_layout: Weight layout (0 = MajorK, 1 = BlockMajorK)
enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
+ do_finalize: Whether to finalize the output (default: True).
--
@flashinfer_api
def trtllm_fp4_block_scale_moe(
@@ -2589,12 +2688,8 @@ def trtllm_fp4_block_scale_moe(
do_finalize (bool): Whether to finalize the output (default: False)
enable_pdl (Optional[bool]): Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
activation_type (int): Type of activation function (default: 3 - Swiglu)
- - 0: Gelu
- - 1: Relu
- - 2: Silu
- 3: Swiglu
- 4: Geglu
- - 5: SwigluBias
- 6: Relu2
- 7: Identity
tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192)
@@ -2726,12 +2821,8 @@ def trtllm_fp4_block_scale_routed_moe(
- 4: RenormalizeNaive (Softmax -> TopK -> Renormalize)
do_finalize (bool): Whether to finalize the output (default: False)
activation_type (int): Type of activation function (default: 3 - Swiglu)
- - 0: Gelu
- - 1: Relu
--
from ..api_logging import flashinfer_api
-from flashinfer.jit import gen_dsv3_router_gemm_module
+from flashinfer.jit import gen_dsv3_router_gemm_module, gen_tinygemm2_module
import functools
from types import SimpleNamespace
+from typing import Optional
import torch
from flashinfer.utils import (
register_custom_op,
@@ -222,3 +223,139 @@ def mm_M1_16_K7168_N256(
get_dsv3_router_gemm_module().mm_M1_16_K7168_N256(
mat_a, mat_b, out, launch_with_pdl
)
+
+
+# ============================================================================
+# tinygemm2: SM90+ BF16 small GEMM with bias (from TensorRT-LLM)
+# Computes: output = input @ weight.T + bias (equivalent to F.linear)
+# ============================================================================
+
+
--
+@flashinfer_api
+def tinygemm_bf16(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ out: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ use_pdl: bool = False,
+) -> None:
+ """SM90+ optimized small GEMM: out = input @ weight.T + bias (equivalent to F.linear).
+
+ A latency-optimized, warp-specialized GEMM designed for tiny batch sizes (ideally
+ 1-8 rows, where a single TILE_N=8 tile covers the entire batch dimension) using
+ Ampere-style HMMA instructions. Uses TMA for async bulk data loads and
+ mma.sync.aligned.m16n8k16 tensor core instructions with BF16 input/weight/bias/output
+ and FP32 internal accumulation. The warp-specialized design (384 threads: 4 compute +
+ 8 DMA warps) with 16 pipeline stages and 4x stage unroll trades off peak throughput
+ in favor of minimal latency.
+
+ From TensorRT-LLM tinygemm2 kernel.
+
+ Args:
--
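The docstring above states that tinygemm_bf16 computes `out = input @ weight.T + bias`, i.e. the same math as `F.linear`. Below is a minimal pure-Python reference of that formula (nested lists, no GPU or torch required) to make the shape convention concrete: `weight` is stored `[N, K]` and multiplied via its transpose.

```python
# Reference for the math only, not the kernel: out[m][n] = sum_k inp[m][k] *
# weight[n][k] + bias[n], matching F.linear's weight-transposed convention.
def linear_ref(inp, weight, bias):
    M, K = len(inp), len(inp[0])
    N = len(weight)                          # weight is [N, K]
    return [
        [sum(inp[m][k] * weight[n][k] for k in range(K)) + bias[n]
         for n in range(N)]
        for m in range(M)
    ]


print(linear_ref([[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]], [0.5, -0.5]))
# → [[11.5, 16.5]]
```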
@@ -23,40 +23,61 @@ from ..api_logging import flashinfer_api
from ..jit.mamba import (
gen_selective_state_update_module,
gen_selective_state_update_sm90_module,
- gen_selective_state_update_sm100_module,
)
from ..utils import get_compute_capability, register_custom_op, register_fake_op
@functools.cache
-def get_selective_state_update_module_base():
- """Get cached JIT-compiled selective_state_update module (base version)."""
- return gen_selective_state_update_module().build_and_load()
-
-
-@functools.cache
-def get_selective_state_update_module_sm90():
- """Get cached JIT-compiled selective_state_update module (SM90/Hopper version)."""
- return gen_selective_state_update_sm90_module().build_and_load()
-
-
--
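The removed getters above use the `@functools.cache` pattern, which this sketch isolates: the decorated function builds once per process and every later call with the same arguments hits the cache. The `build_count` counter and the string return value are illustrative stand-ins for the expensive `build_and_load()` step.

```python
import functools

build_count = 0


@functools.cache
def get_module(arch: str):
    # Stand-in for an expensive JIT build_and_load(); runs once per arch.
    global build_count
    build_count += 1
    return f"module-for-{arch}"


get_module("sm90")
get_module("sm90")      # cache hit: no rebuild
get_module("sm100")     # new argument: builds again
print(build_count)  # → 2
```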
@flashinfer_api
@@ -78,6 +99,7 @@ def selective_state_update(
intermediate_states_buffer: Optional[torch.Tensor] = None,
intermediate_state_indices: Optional[torch.Tensor] = None,
cache_steps: int = 0,
+ algorithm: str = "auto",
) -> torch.Tensor:
r"""Selective state update operation for Mamba layers (the generation phase).
@@ -126,6 +148,10 @@ def selective_state_update(
with shape (batch,)
cache_steps : int
Number of steps/tokens to cache for speculative decoding
+ algorithm : str
+ Algorithm to use: "auto" (default, picks the best kernel based on GPU arch,
+ data types, and problem size), "simple" (all GPUs), "vertical" and "horizontal"
+ (SM90+ only). MTP mode only supports "auto" or "simple".
Returns
-------
@@ -178,6 +204,30 @@ def selective_state_update(
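The new `algorithm` parameter documented in the hunk above admits "auto", "simple", "vertical", and "horizontal", with the latter two SM90+ only and MTP mode restricted to "auto"/"simple". The validation sketch below encodes just those documented constraints; what "auto" actually resolves to is the library's internal decision, so the fallback chosen here is purely illustrative.

```python
def pick_algorithm(algorithm: str, sm_version: int, mtp_mode: bool) -> str:
    valid = {"auto", "simple", "vertical", "horizontal"}
    if algorithm not in valid:
        raise ValueError(f"unknown algorithm {algorithm!r}")
    if mtp_mode and algorithm not in {"auto", "simple"}:
        raise ValueError("MTP mode only supports 'auto' or 'simple'")
    if algorithm in {"vertical", "horizontal"} and sm_version < 90:
        raise ValueError(f"{algorithm!r} requires SM90+")
    if algorithm == "auto":
        # Illustrative resolution only; the real heuristic also weighs
        # data types and problem size.
        return "vertical" if sm_version >= 90 else "simple"
    return algorithm


print(pick_algorithm("auto", 90, False))  # → vertical
```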
📌 Description
Fix unit test failures caused by the change in #2548, which updated the API to return a list of tensors.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used my preferred method).
- I have installed the hooks with pre-commit install.
- I have run pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- All tests are passing (unittest, etc.).

Reviewer Notes