
feat: cuteDSL fp4 moe for better DSR1 performance. #2398

Merged
nv-yunzheq merged 33 commits into flashinfer-ai:main from nv-yunzheq:cuteDSL_moe
Feb 7, 2026

Conversation


@nv-yunzheq (Collaborator) commented Jan 22, 2026

📌 Description

CuteDSL fp4 MoE from TRT-LLM with fusion. Issue #2259.
The collective-fusion performance improvements are collected from TRT-LLM PR 8880, PR 9288, and PR 9618.

This PR introduces two new MoE APIs:
cute_dsl_fused_moe_nvfp4 and CuteDslMoEWrapper.
The two APIs are functionally equivalent. The cute_dsl_fused_moe_nvfp4 function runs the operation directly, while the wrapper API splits tensor setup from execution to better support CUDA graph capture. Using the wrapper API is recommended.

The PR also introduces autotuning for these APIs.

Performance data:

$ python benchmarks/bench_moe_deepseek.py --num-tokens 1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384
====================================================================================================
DeepSeek-V3 MoE Benchmark: CuteDSL vs CUTLASS vs TRTLLM (EP=1)
====================================================================================================
Model: hidden=7168, intermediate=2048, experts=256, top_k=8
EP Config: 256 local experts (simulating 1-way parallelism)
CUDA Graph: enabled, CUPTI: enabled
----------------------------------------------------------------------------------------------------
Tokens |     CuteDSL     |     CUTLASS     |     TRTLLM      | Speedup (CuteDSL/X) |  Winner
       |      ms  TFLOPS |      ms  TFLOPS |      ms  TFLOPS |  CUTLASS   TRTLLM |
----------------------------------------------------------------------------------------------------
     1 |   0.064    10.9 |   0.098     7.2 |   0.053    13.2 |    1.53x    0.83x |  TRTLLM
     2 |   0.077    18.2 |   0.109    13.0 |   0.064    21.9 |    1.40x    0.83x |  TRTLLM
     4 |   0.097    28.9 |   0.131    21.5 |   0.086    32.7 |    1.35x    0.88x |  TRTLLM
     8 |   0.098    57.5 |   0.133    42.4 |   0.092    61.4 |    1.36x    0.94x |  TRTLLM
    16 |   0.102   110.3 |   0.138    82.0 |   0.104   108.5 |    1.35x    1.02x | CuteDSL
    32 |   0.115   196.3 |   0.153   146.9 |   0.138   163.4 |    1.34x    1.20x | CuteDSL
    64 |   0.123   365.7 |   0.168   269.0 |   0.154   292.3 |    1.36x    1.25x | CuteDSL
   128 |   0.134   674.3 |   0.217   416.0 |   0.174   519.4 |    1.62x    1.30x | CuteDSL
   256 |   0.144  1252.3 |   0.250   723.0 |   0.220   821.3 |    1.73x    1.52x | CuteDSL
   512 |   0.200  1802.3 |   0.336  1073.4 |   0.271  1331.9 |    1.68x    1.35x | CuteDSL
  1024 |   0.286  2520.1 |   0.480  1501.8 |   0.570  1265.7 |    1.68x    1.99x | CuteDSL
  2048 |   0.469  3073.8 |   0.722  1998.8 |   0.542  2660.6 |    1.54x    1.16x | CuteDSL
  4096 |   0.852  3387.2 |   1.274  2266.0 |   0.845  3417.2 |    1.49x    0.99x |  TRTLLM
  8192 |   1.676  3444.7 |   2.364  2441.4 |   1.590  3630.7 |    1.41x    0.95x |  TRTLLM
 16384 |   3.290  3509.3 |   4.639  2488.8 |   3.383  3412.5 |    1.41x    1.03x | CuteDSL
----------------------------------------------------------------------------------------------------

$ python benchmarks/bench_moe_deepseek.py --num-tokens 1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384 --ep 16
====================================================================================================
DeepSeek-V3 MoE Benchmark: CuteDSL vs CUTLASS vs TRTLLM (EP=16)
====================================================================================================
Model: hidden=7168, intermediate=2048, experts=256, top_k=8
EP Config: 16 local experts (simulating 16-way parallelism)
CUDA Graph: enabled, CUPTI: enabled
----------------------------------------------------------------------------------------------------
Tokens |     CuteDSL     |     CUTLASS     |     TRTLLM      | Speedup (CuteDSL/X) |  Winner
       |      ms  TFLOPS |      ms  TFLOPS |      ms  TFLOPS |  CUTLASS   TRTLLM |
----------------------------------------------------------------------------------------------------
     1 |   0.021     2.1 |   0.024     1.8 |   0.018     2.5 |    1.16x    0.85x |  TRTLLM
     2 |   0.020     4.3 |   0.023     3.8 |   0.018     4.8 |    1.13x    0.90x |  TRTLLM
     4 |   0.020     8.6 |   0.024     7.4 |   0.019     9.2 |    1.16x    0.93x |  TRTLLM
     8 |   0.020    17.5 |   0.023    15.4 |   0.020    17.7 |    1.14x    0.99x |  TRTLLM
    16 |   0.021    33.9 |   0.024    29.9 |   0.022    32.8 |    1.13x    1.03x | CuteDSL
    32 |   0.020    69.4 |   0.024    59.8 |   0.025    57.3 |    1.16x    1.21x | CuteDSL
    64 |   0.022   129.2 |   0.024   118.7 |   0.025   112.1 |    1.09x    1.15x | CuteDSL
   128 |   0.021   267.7 |   0.024   236.9 |   0.026   216.7 |    1.13x    1.24x | CuteDSL
   256 |   0.023   487.3 |   0.025   456.4 |   0.027   411.6 |    1.07x    1.18x | CuteDSL
   512 |   0.025   917.5 |   0.031   731.0 |   0.030   753.2 |    1.26x    1.22x | CuteDSL
  1024 |   0.028  1638.6 |   0.036  1243.9 |   0.040  1137.0 |    1.32x    1.44x | CuteDSL
  2048 |   0.036  2494.3 |   0.045  1988.4 |   0.061  1481.9 |    1.25x    1.68x | CuteDSL
  4096 |   0.050  3597.4 |   0.061  2934.5 |   0.101  1794.7 |    1.23x    2.00x | CuteDSL
  8192 |   0.103  3490.5 |   0.150  2409.6 |   0.209  1723.1 |    1.45x    2.03x | CuteDSL
 16384 |   0.140  5160.4 |   0.198  3638.0 |   0.364  1980.6 |    1.42x    2.61x | CuteDSL
----------------------------------------------------------------------------------------------------

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • New Features

    • Added comprehensive Mixture-of-Experts (MoE) support with FP4 quantization for Blackwell GPUs.
    • Introduced multi-precision MoE kernels supporting FP16, BF16, FP8, and FP4 formats.
    • Added CUDA graph-compatible MoE APIs for improved performance and flexibility.
    • Implemented automatic kernel tuning for MoE operations.
    • Added expert parallelism support for distributed MoE inference.
  • Documentation

    • Added extensive test suite validating MoE functionality across precision formats and configurations.


coderabbitai bot commented Jan 22, 2026

📝 Walkthrough

This PR introduces a comprehensive Mixture-of-Experts (MoE) pipeline for DeepSeek-V3 on Blackwell GPUs (SM100+), spanning CUDA kernels, TVM FFI bindings, Python CuTe-DSL wrappers, and integrated testing/benchmarking. It adds permute/unpermute/sort/activation kernels with multi-precision support (FP16, BF16, FP8, FP4) and high-level APIs for fused MoE GEMM operations with persistent tile scheduling and CUDA graph compatibility.

Changes

Cohort / File(s) Summary
CUDA C++ Core MoE Kernels
csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h, moeUtils.cu
Introduces templated MoE operations: moePermute, moeUnpermute, moeOutputMemset, moeActivation with kernel/host launcher implementations. Supports multiple data types (half, bf16, fp8, fp4) with explicit instantiations and occupancy-aware kernel configuration.
TVM FFI Bindings
csrc/moe_utils_binding.cu
Exposes MoE operations via TVM FFI with function variants for fp16, bf16, fp8, fp4 (gated by feature flags). Includes moe_sort binding for DeepSeek-V3 routing and helper utilities (computeLog2).
CUTLASS Activation Adaptors
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh
Adds IdentityAdaptor, GLUAdaptor, SwigluBiasAdaptor template classes for activation operations in CUTLASS epilogue logic.
Configuration Headers
csrc/nv_internal/include/tensorrt_llm/common/config.h
Introduces ABI namespace macros (TRTLLM_ABI_NAMESPACE, TRTLLM_NAMESPACE_BEGIN/END) for versioned API exposure.
JIT Module Generator
flashinfer/jit/moe_utils.py, __init__.py
Implements gen_moe_utils_module() for JIT-compiling MoE bindings, including NVCC flags and include paths for FP8/FP4/BF16 support.
CuTe-DSL Utilities
flashinfer/cute_dsl/utils.py, __init__.py
Adds scale-factor layout utilities (convert_sf_to_mma_layout, convert_sf_from_mma_layout, get_mma_sf_shape) and hardware caching (get_max_active_clusters).
Blackwell Kernel Implementations
flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm.py, blockscaled_contiguous_grouped_gemm_swiglu_fusion.py, blockscaled_contiguous_grouped_gemm_finalize_fusion.py
Full-featured persistent-tile-scheduled GEMM kernels with SwiGLU/finalize fusions, TMA operations, and dynamic compilation caching. Extensive validation and tensor layout handling.
CuTe-DSL Pipeline & Utilities
flashinfer/fused_moe/cute_dsl/blackwell/custom_pipeline.py, utils.py, blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
Advanced pipeline constructs (PipelineTmaUmma, PipelineUmmaAsync, PipelineCpAsyncUmma) with barrier synchronization, atomic ops, and gather-based GEMM. Grid-dependency control and function generation for SILU/sigmoid.
High-Level MoE APIs
flashinfer/fused_moe/cute_dsl/fused_moe.py, moe_utils.py, tuner.py, __init__.py
Functional and wrapper APIs (cute_dsl_fused_moe_nvfp4, CuteDslMoEWrapper) with CUDA graph support, auto-tuning framework, and comprehensive MoE operation wrappers (permute, unpermute, sort, activation variants).
Top-Level Integration
flashinfer/__init__.py, flashinfer/fused_moe/__init__.py
Conditional imports of CuteDSL MoE APIs with graceful fallback when CuTe-DSL is unavailable.
Testing Suite
tests/moe/test_cute_dsl_fused_moe.py, __init__.py
Comprehensive test module with FP4 quantization reference, MoE computation validation, functional/wrapper API tests, CUDA graph capture/replay, and consistency checks across implementations.
Benchmark Suite
benchmarks/bench_moe_deepseek.py
Feature-rich DeepSeek-V3 MoE benchmark comparing CuteDSL, CUTLASS, and TRT-LLM backends across token counts with performance metrics (TFLOPS, latency).
Hardware Optimization
flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py
Minor update to cache max_active_clusters computation for Blackwell.

Sequence Diagram(s)

sequenceDiagram
    participant Python as Python/PyTorch
    participant Permute as moe_permute Kernel
    participant Gemm1 as GEMM1 + SwiGLU
    participant Sort as moe_sort (Routing)
    participant Gather as Gather Kernel
    participant Gemm2 as GEMM2 + Finalize
    participant Unpermute as moe_unpermute Kernel
    participant Output as Output Buffer

    Python->>Sort: Route tokens to experts<br/>(token_selected_experts, scores)
    Sort->>Permute: Compute routing mappings<br/>(tile indices, permutation)
    Python->>Permute: Call moe_permute<br/>(input → permuted layout)
    Permute->>Gemm1: Permuted tokens<br/>(FP4 quantized)
    Gemm1->>Gemm1: Per-expert GEMM + SwiGLU<br/>(fused activation)
    Gemm1->>Gather: Intermediate outputs
    Gather->>Gemm2: Gather by expert<br/>(without re-permutation)
    Gemm2->>Gemm2: Per-expert GEMM2<br/>(with finalize epilogue)
    Gemm2->>Unpermute: Expert outputs<br/>(FP4 → FP16/BF16)
    Python->>Unpermute: Call moe_unpermute<br/>(apply top-k scaling)
    Unpermute->>Output: Scatter to original<br/>token positions
    Output->>Python: Final MoE output<br/>(combined experts)
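The flow above can be mirrored in plain Python to make the routing bookkeeping concrete. This is an illustrative sketch of the permute/unpermute index arithmetic only; the function names are hypothetical and do not correspond to the flashinfer API or the CUDA kernels:

```python
# Hypothetical, simplified reference of the routing arithmetic sketched in
# the diagram: expand (token, k) pairs, sort by expert (the "permute"), and
# scatter expert outputs back with top-k scaling (the "unpermute").

def moe_route(token_experts, token_scales):
    """Expand tokens into (expert, token, scale) triples and sort by expert."""
    expanded = []
    for t, (experts, scales) in enumerate(zip(token_experts, token_scales)):
        for e, s in zip(experts, scales):
            expanded.append((e, t, s))
    # permuted_idx -> expanded_idx, grouped by expert (stable sort)
    order = sorted(range(len(expanded)), key=lambda i: expanded[i][0])
    return expanded, order

def moe_combine(expert_outputs, expanded, order, num_tokens):
    """Unpermute: scatter per-expert outputs to token positions with scaling."""
    out = [0.0] * num_tokens
    for permuted_idx, expanded_idx in enumerate(order):
        _, t, s = expanded[expanded_idx]
        out[t] += s * expert_outputs[permuted_idx]
    return out

# Two tokens, top_k=2; dummy expert "outputs" equal to the expert id.
expanded, order = moe_route([[1, 0], [0, 2]], [[0.6, 0.4], [0.5, 0.5]])
expert_out = [float(expanded[i][0]) for i in order]
print(moe_combine(expert_out, expanded, order, 2))  # → [0.6, 1.0]
```

The real kernels perform the same mapping with tile-padded layouts (e.g. expanded_idx_to_permuted_idx) entirely on the GPU.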

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

  • Adds DeepSeek-V3 routing/moe_sort binding which integrates routingDeepSeek::run kernel for token-to-expert sorting and mapping, overlapping with DeepSeek routing kernel changes.
  • Modifies CUTLASS MoE GEMM codepath with new kernels, adaptors, and template instantiations that interact with existing CUTLASS dispatch and configuration mechanisms.
  • Updates CuTe-DSL kernel stack with Blackwell-specific persistent-tile-scheduled GEMMs and shared memory optimizations that align with broader Blackwell optimization efforts.

Suggested labels

run-ci

Suggested reviewers

  • aleozlx
  • nvmbreughe
  • djmmoss
  • yzh119
  • bkryu
  • cyx-6
  • jiahanc

Poem

🐰 Whiskers twitch with glee!
A MoE pipeline blooms so bright,
Blackwell kernels dance with light—
Permute, sort, and GEMM unite! ✨
DeepSeek flows to expert heights!

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 warning

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 69.77%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title clearly and concisely describes the main change: adding a cuteDSL fp4 MoE implementation for performance improvements. It directly relates to the bulk of changes in the PR.
  • Description check ✅ Passed: The PR description includes the required Description section with issue reference, Related Issues, and Pull Request Checklist. However, key checklist items (pre-commit checks and tests) remain unchecked, and the description could better address test coverage for the extensive new code.


@gemini-code-assist

Summary of Changes

Hello @nv-yunzheq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances FlashInfer's capabilities by adding a highly optimized, fused Mixture-of-Experts (MoE) implementation. By integrating CuTe-DSL kernels from TensorRT-LLM, the new pipeline leverages advanced GPU features to deliver efficient FP4 computations, reducing memory bandwidth and improving throughput for large language models. The focus is on providing a complete, end-to-end MoE solution with built-in auto-tuning for optimal performance on Blackwell GPUs.

Highlights

  • FP4 Mixture-of-Experts (MoE) Support: Introduced comprehensive support for FP4 Mixture-of-Experts (MoE) computations leveraging NVIDIA's CuTe-DSL kernels, adapted from TensorRT-LLM, specifically optimized for Blackwell (SM100) GPUs.
  • Fused MoE Pipeline: Implemented a high-level cute_dsl_fused_moe_nvfp4 API that orchestrates the entire MoE pipeline, including token sorting, GEMM1 with fused gather and SwiGLU activation, and GEMM2 with fused finalize (unpermute, scaling, and atomic scatter-reduce).
  • CuTe-DSL Kernel Integration: Integrated several new CuTe-DSL kernels for specialized MoE operations: blockscaled_contiguous_grouped_gemm, blockscaled_contiguous_grouped_gemm_swiglu_fusion, blockscaled_contiguous_grouped_gemm_finalize_fusion, and blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.
  • Core MoE Utilities: Added C++/CUDA kernels and Python wrappers for fundamental MoE utilities such as moe_sort (token routing), moe_permute (data reordering), moe_unpermute (reverse reordering), moe_output_memset (output buffer initialization), and moe_activation (various activation functions including SwiGLU, GeGLU, GELU, SiLU, ReLU).
  • Auto-Tuning Framework: Introduced a new auto-tuning framework (flashinfer.cute_dsl.tuner) for the CuTe-DSL fused MoE kernels, allowing for dynamic performance optimization across different GEMM tactics.
  • Blackwell (SM100) Optimizations: The new CuTe-DSL kernels and pipeline components are specifically designed and optimized for the NVIDIA Blackwell architecture, utilizing features like persistent tile scheduling and warp specialization.


@gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant new feature: fused Mixture of Experts (MoE) kernels using CuteDSL for FP4 on Blackwell architecture. The implementation is adapted from TensorRT-LLM and includes several advanced techniques like warp specialization and persistent scheduling. The changes are extensive, adding new CUDA kernels, Python bindings, and high-level APIs with auto-tuning support.

My review focused on the integration and correctness of the new components. I've identified a few issues:

  • A missing validation check in moe_sort that could lead to runtime errors if tile_tokens_dim is not a power of 2.
  • A potential bug in a wrapper method within a GEMM kernel due to incorrect integer division.
  • Some code duplication in utility files that should be refactored.

Overall, this is a substantial contribution that brings high-performance MoE capabilities to FlashInfer. The code is well-structured, though complex due to the nature of CuteDSL. Addressing the identified issues will improve the robustness and maintainability of this new feature.

Comment on lines 323 to 332
def moe_sort(
    token_selected_experts: torch.Tensor,
    token_final_scales: torch.Tensor,
    num_experts: int,
    top_k: int,
    local_expert_offset: int = 0,
    num_local_experts: Optional[int] = None,
    tile_tokens_dim: int = 128,
    enable_pdl: bool = False,
) -> Tuple[

high

The moe_sort function passes tile_tokens_dim to a CUDA kernel that expects it to be a power of two. The underlying C++ code uses a computeLog2 function that will return -1 for non-power-of-two inputs, which can lead to undefined behavior or cryptic errors in the routingDeepSeek kernel. It's crucial to validate this parameter in the Python wrapper to prevent such issues.

Please add an assertion at the beginning of the function body:

assert (tile_tokens_dim > 0) and ((tile_tokens_dim & (tile_tokens_dim - 1)) == 0), "tile_tokens_dim must be a power of 2"
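For reference, here is the bit trick behind the suggested assertion, together with a Python analogue of the described computeLog2 behavior (the C++ helper itself is not shown in this review):

```python
def is_power_of_two(n: int) -> bool:
    # n & (n - 1) clears the lowest set bit; the result is 0 exactly when
    # n has a single set bit, i.e. n is a power of two.
    return n > 0 and (n & (n - 1)) == 0

def compute_log2(n: int) -> int:
    """Python analogue of the described computeLog2: -1 for non-powers-of-two."""
    return n.bit_length() - 1 if is_power_of_two(n) else -1

print([compute_log2(n) for n in (128, 96, 1)])  # → [7, -1, 0]
```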

    epilogue_op: cutlass.Constexpr = lambda x: x,
):
    scale_k = k // scaling_vector_size
    num_tiles = m // tile_size

high

The calculation of num_tiles using integer division m // tile_size is incorrect if m is not perfectly divisible by tile_size. This will result in an undersized tile_idx_to_group_idx tensor, leading to out-of-bounds access when the kernel scheduler processes the last partial tile. You should use ceiling division to ensure the number of tiles is calculated correctly.

Suggested change:
-    num_tiles = m // tile_size
+    num_tiles = (m + tile_size - 1) // tile_size
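A quick sanity check of the two formulas in plain Python:

```python
def num_tiles_floor(m: int, tile_size: int) -> int:
    return m // tile_size  # drops the trailing partial tile

def num_tiles_ceil(m: int, tile_size: int) -> int:
    return (m + tile_size - 1) // tile_size  # counts the partial tile

# 130 tokens with 128-token tiles need 2 tiles; floor division finds only 1.
print(num_tiles_floor(130, 128), num_tiles_ceil(130, 128))  # → 1 2
```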

Comment on lines +61 to +197
        self._c_pointer = None
        assert int(self._pointer) % self._assumed_align == 0, (
            f"pointer must be {self._assumed_align} bytes aligned"
        )

    def size_in_bytes(self) -> int:
        return ctypes.sizeof(ctypes.c_void_p(int(self._pointer)))

    def __get_mlir_types__(self):
        return [self.mlir_type]

    def __c_pointers__(self):
        if self._c_pointer is None:
            self._desc = ctypes.c_void_p(int(self._pointer))
            self._c_pointer = ctypes.addressof(self._desc)
        return [self._c_pointer]

    def __new_from_mlir_values__(self, values):
        assert len(values) == 1
        return values[0]

    # Move mlir Type out of __init__ to decouple with mlir Context
    @property
    def mlir_type(self) -> ir.Type:
        return _cute_ir.PtrType.get(
            self._dtype.mlir_type, self._addr_space, self._assumed_align
        )

    @property
    def dtype(self) -> Type[Numeric]:
        return self._dtype

    @property
    def memspace(self):
        return self._addr_space

    def align(self, min_align: int, *, loc=None, ip=None) -> Pointer:
        raise NotImplementedError("align is not supported in runtime")

    def verify(self, expected_py_type):
        if expected_py_type is Pointer or (
            isinstance(expected_py_type, ir.Value) and expected_py_type.ty is Pointer
        ):
            return True

        return False

    def __str__(self) -> str:
        return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>"

    def __repr__(self):
        return self.__str__()


def make_ptr(
    dtype: Type[Numeric],
    value: Union[int, ctypes._Pointer],
    mem_space: AddressSpace = AddressSpace.generic,
    assumed_align=None,
) -> Pointer:
    """Creates a pointer from a memory address.

    Args:
        dtype (Type[Numeric]): Data type of the pointer elements.
        value (Union[int, ctypes._Pointer]): Memory address as an integer or ctypes pointer.
        mem_space (AddressSpace, optional): Memory address space. Defaults to AddressSpace.generic.
        assumed_align (int, optional): Alignment in bytes. Defaults to None.

    Returns:
        Pointer: A pointer object.

    Example:
        ```python
        import numpy as np
        import ctypes
        from cutlass import Float32
        from cutlass.cute.runtime import make_ptr

        # Create a numpy array
        a = np.random.randn(16, 32).astype(np.float32)
        # Get pointer address as ctypes pointer
        ptr_address = a.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
        # Create pointer from address
        y = make_ptr(cutlass.Float32, ptr_address)
        ```
    """
    # check if value is int or ctypes.POINTER
    if isinstance(value, int):
        address_value = value
    elif isinstance(value, ctypes._Pointer):
        # get address value
        address_value = ctypes.cast(value, ctypes.c_void_p).value
        assert address_value is not None, "Pointer address is None"
    else:
        raise TypeError(
            f"Expect int or ctypes.POINTER for value but got {type(value)=}"
        )

    return _Pointer(address_value, dtype, mem_space, assumed_align=assumed_align)


medium

The _Pointer class and make_ptr function are duplicated from flashinfer/cute_dsl/utils.py. This duplicated code appears to be unused within the blackwell module, as other files import these utilities from the central flashinfer.cute_dsl.utils location. Removing this duplication will improve maintainability and prevent potential inconsistencies in the future.

@coderabbitai bot left a comment

Actionable comments posted: 8

🤖 Fix all issues with AI agents
In `@csrc/moe_utils_binding.cu`:
- Around line 273-338: The kernel assumes token_final_scales is bfloat16 but the
public API allows float32; in moe_sort validate the dtype for
token_final_scales_ptr (or enforce conversion in the Python wrapper) and set
routingData fields accordingly: if callers pass float32, either (A) convert the
tensor to bfloat16 before calling moe_sort (mirror how token_selected_experts is
converted to int32) or (B) detect float32 here and set routingData.mDtypeExpW
and routingData.mDtypeBias to Fp32 and ensure routingData.mPtrTopKWeights is
treated as float*; update the docstring if you choose to restrict to bfloat16.
Ensure checks reference token_final_scales_ptr, routingData.mDtypeExpW,
routingData.mDtypeBias, and routingData.mPtrTopKWeights (or the Python wrapper
conversion path used for token_selected_experts).

In `@csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu`:
- Around line 119-121: The cached static smCount computed via
tensorrt_llm::common::getMultiProcessorCount() prevents correct behavior when
the process switches CUDA devices; change the declaration so smCount is
evaluated per-call (e.g., remove static or make it thread_local as done
elsewhere) and recompute it before calculating maxBlocksPerSM and blocks; update
all occurrences where smCount is defined (the instances that call
getMultiProcessorCount()) so each call queries the current device rather than
reusing a process-global cached value.

In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh`:
- Around line 56-69: The gate is only clamped from above which can let very
negative values pass into the sigmoid and cause numerical instability; in
SwigluBiasAdaptor::operator() clamp the gate symmetrically (e.g., use
cutlass::maximum<T>{}(cutlass::minimum<T>{}(gate, limit), -limit)) before
computing the sigmoid and using it for the gate multiplication so the sigmoid
receives a bounded input and numerical overflow/underflow is avoided.
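A scalar Python sketch of the symmetric clamp being suggested (the real adaptor operates on CUTLASS numeric types; the `limit` name follows the comment above):

```python
import math

def swiglu_gate(gate: float, linear: float, limit: float) -> float:
    # Clamp on both sides so the sigmoid always sees a bounded input;
    # clamping only from above would still let exp(-gate) overflow for
    # very negative gates.
    g = max(min(gate, limit), -limit)
    return linear * g * (1.0 / (1.0 + math.exp(-g)))  # g * sigmoid(g)
```

With limit=10, inputs of ±1e6 produce the same result as ±10 instead of overflowing.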

In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 72-160: Add the missing `@flashinfer_api` decorator to the public
function create_finalize_fusion_tensors to enable FLASHINFER_LOGLEVEL-based API
logging; also modify its signature to accept an optional token_final_scales:
Optional[torch.Tensor] = None parameter (dtype final_scale_dtype, shape
(seq_len, topk)) and, if provided, validate shape/dtype and use it instead of
generating random values, otherwise keep the current randomized normalized
initialization but update the docstring to state these are placeholder/test
values; leave the existing _finalize_kernel_cache behavior unchanged.

In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py`:
- Around line 241-272: The FP4 output scale allocation assumes permuted_m is
divisible by 128 and scale_intermediate_size (computed from intermediate_size //
sf_vec_size) is divisible by 4; add explicit validation in the generate_sfc
branch before allocating out_scale: check that permuted_m % 128 == 0 and
scale_intermediate_size % 4 == 0 (using the existing symbols generate_sfc,
out_scale, permuted_m, intermediate_size, sf_vec_size, scale_intermediate_size),
and raise a clear ValueError if either check fails so the buffer size cannot be
undersized and cause out-of-bounds writes.
- Around line 69-82: Replace the `@functools.lru_cache`(maxsize=None) decorator on
_get_compiled_swiglu_kernel with `@functools.cache` and remove the unused shape
parameters permuted_m, n, k, and num_experts from the function signature and all
call sites so they are not included in the cache key; update any callers that
pass those four parameters to stop supplying them and adjust the function
internals to use only the remaining parameters (ab_dtype_name, sf_dtype_name,
c_dtype_name, sf_vec_size, mma_tiler_mn, cluster_shape_mn, vectorized_f32). Also
make the same decorator/signature change for the other identical
_get_compiled_swiglu_kernel occurrence in this module.

In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py`:
- Around line 361-363: Replace the two assert statements that check tensor
device types (the checks using a.device.type == "cuda" and b.device.type ==
"cuda") with explicit runtime validation: use if statements that raise a clear
exception (e.g., ValueError or RuntimeError) when a or b are not on CUDA, and
include a helpful message mentioning which tensor failed the check; update the
validation block in blockscaled_contiguous_grouped_gemm.py to perform these
explicit checks instead of using assert so they are not stripped in optimized
Python.

In `@flashinfer/cute_dsl/tuner.py`:
- Around line 260-266: The initializer lambda in tuner.py currently samples
expert indices with a hardcoded range of 0..7 (comment "num_experts=8 typical"),
which can produce invalid indices for models with different expert counts;
update the lambda used in the tuner initialization (the anonymous function that
calls torch.randint) to derive the upper bound from the actual model/runner
expert count (e.g., use a passed-in num_experts or runner.num_experts) or at
minimum clamp to min/max to avoid out-of-range values; locate the lambda in
tuner.py and replace the hardcoded 8 with the dynamic num_experts value (or a
safe expression like max(1, num_experts)) so sampled indices are always valid.
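The fix amounts to threading the model's expert count into the sampler. A minimal sketch with Python's random module standing in for torch.randint (names are illustrative):

```python
import random

def sample_expert_indices(num_tokens: int, top_k: int, num_experts: int):
    # Bound derived from the actual expert count, never a hardcoded 8.
    bound = max(1, num_experts)
    return [
        [random.randrange(bound) for _ in range(top_k)] for _ in range(num_tokens)
    ]
```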
🧹 Nitpick comments (28)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh (2)

28-40: Unused member variables alpha, beta, limit in IdentityAdaptor.

These members are declared but never used in operator(). If they're for API consistency with other adaptors, consider documenting this intent. Otherwise, they add unnecessary storage overhead.


42-54: Same observation for GLUAdaptor - unused alpha, beta, limit members.

The members are not referenced in the operator() implementation. If these are reserved for future use or API consistency, a brief comment would clarify intent.

flashinfer/cute_dsl/utils.py (1)

29-31: Duplicate ceil_div implementation - consider reusing existing utility.

This function is already defined in flashinfer/utils.py (with proper docstring) and duplicated in several other files. Consider importing from flashinfer.utils to reduce duplication.

♻️ Suggested change
-def ceil_div(a: int, b: int) -> int:
-    """Ceiling division."""
-    return (a + b - 1) // b
+from flashinfer.utils import ceil_div

Based on relevant code snippets showing ceil_div exists in flashinfer/utils.py (lines 621-632).

flashinfer/cute_dsl/fused_moe.py (4)

110-125: Prefix unused variable with underscore.

total_num_padded_tokens from moe_sort is unpacked but never used. Prefix with underscore to indicate intentional non-use.

♻️ Suggested fix
     (
         tile_idx_to_expert_idx,
         tile_idx_to_mn_limit,
         expanded_idx_to_permuted_idx,
         permuted_idx_to_expanded_idx,
-        total_num_padded_tokens,
+        _total_num_padded_tokens,
         num_non_exiting_tiles,
     ) = moe_sort(

128-136: Auxiliary stream creation without explicit lifecycle management.

A new torch.cuda.Stream() is created at line 133 if aux_stream is None, but there's no mechanism to reuse or properly manage this stream across calls. Consider:

  1. Documenting that callers should pass a reusable stream for optimal performance
  2. Using a module-level cached stream to avoid repeated allocation

336-342: Duplicate output allocation logic.

The moe_output allocation at lines 337-342 duplicates the logic in _cute_dsl_fused_moe_nvfp4_impl (lines 102-107). Since moe_output is passed to the runner, the allocation in the public API is necessary, but consider removing the duplicate in _impl or documenting why both are needed.


179-180: memset_event.wait() could be more explicit about stream context.

The call at line 180 waits on the default stream after the memset_event.record() at line 177 is called on aux_stream. While this synchronization pattern is correct, consider explicitly documenting the stream interaction or adding error handling for potential failures in the aux_stream work (e.g., exceptions in moe_output_memset).

flashinfer/cute_dsl/tuner.py (3)

250-273: Mutable class attributes should be annotated with ClassVar.

dynamic_tensor_initializers and tuning_config are class-level attributes that should be typed with typing.ClassVar to indicate they're shared across instances and not instance attributes.

♻️ Suggested fix
+from typing import Any, Callable, ClassVar, Dict, List, Tuple
+
 class CuteDslFusedMoENvfp4Runner(TunableRunner):
     ...
     # Tensor initializers for dynamic tensors (indices 0, 1, 2, 3, 11)
     # These create valid dummy tensors for profiling with different num_tokens
-    dynamic_tensor_initializers = [
+    dynamic_tensor_initializers: ClassVar[List[Callable]] = [
         ...
     ]

     # Tuning config with dynamic tensor specs for num_tokens dimension
-    tuning_config = TuningConfig(
+    tuning_config: ClassVar[TuningConfig] = TuningConfig(
         ...
     )

341-347: PEP 484 violation: implicit Optional type.

tactic: Tuple[Any, ...] = None implicitly allows None but the type hint doesn't reflect this. Use explicit Optional or union syntax.

♻️ Suggested fix
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
     def forward(  # type: ignore[override]
         self,
         inputs: List[torch.Tensor],
-        tactic: Tuple[Any, ...] = None,  # type: ignore[assignment]
+        tactic: Optional[Tuple[Any, ...]] = None,
         do_preparation: bool = False,
         **kwargs: Any,
     ) -> torch.Tensor:

362-363: Handle tactic == -1 edge case.

The condition tactic is None or tactic == -1 suggests -1 is a sentinel value. Document this behavior or use a more explicit sentinel (e.g., a constant).
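A named sentinel makes the intent explicit at call sites. `DEFAULT_TACTIC` and `resolve_tactic` are illustrative names, not part of the existing runner:

```python
from typing import Any, Optional, Tuple

DEFAULT_TACTIC = -1  # named stand-in for the bare -1 sentinel

def resolve_tactic(tactic: Optional[Tuple[Any, ...]]) -> Optional[Tuple[Any, ...]]:
    # Both None and the legacy -1 sentinel mean "use the default kernel
    # config"; resolve_tactic is an illustrative helper, not runner code.
    if tactic is None or tactic == DEFAULT_TACTIC:
        return None
    return tactic
```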

flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py (2)

144-150: Ambiguous variable name l - rename for clarity.

The variable l (lowercase L) is easily confused with 1 (one). Rename to something more descriptive like num_groups or batch_dim.

♻️ Suggested fix
 def create_scale_factor_tensor(
-    l: int,
+    num_groups: int,
     mn: int,
     k: int,
     sf_vec_size: int,
     dtype: Type[cutlass.Numeric],
 ) -> Tuple[torch.Tensor, cute.Tensor, torch.Tensor]:
     """Create scale factor tensors in the MMA-compatible layout.
     ...
     Args:
-        l: Batch/expert dimension
+        num_groups: Batch/expert dimension

And update all references to l within the function.


176-177: Another duplicate ceil_div definition.

This is the same utility already present in flashinfer/utils.py and flashinfer/cute_dsl/utils.py. Import from the canonical location instead.

♻️ Suggested fix
+from flashinfer.utils import ceil_div
+
 def create_scale_factor_tensor(...):
-    def ceil_div(a, b):
-        return (a + b - 1) // b
-
     sf_k = ceil_div(k, sf_vec_size)
csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu (2)

134-136: Missing error checking for cudaLaunchKernelEx.

The return value of cudaLaunchKernelEx is not checked. Consider adding error handling to catch launch failures.

🔧 Suggested fix
-  cudaLaunchKernelEx(&config, kernel, input, permuted_output, input_sf, permuted_sf,
-                     tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles,
-                     hidden_size, top_k, tile_size);
+  TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel, input, permuted_output, input_sf, permuted_sf,
+                     tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles,
+                     hidden_size, top_k, tile_size));

This applies to all cudaLaunchKernelEx calls at lines 235, 334, and 463 as well.


56-62: Missing return value check for cudaOccupancyMaxActiveBlocksPerMultiprocessor.

The CUDA API call result is discarded. If it fails, numBlocks remains 0, which could cause issues downstream.

🔧 Suggested fix
 template <typename KernelFunc>
 int32_t getMaxActiveBlocksPerSM(KernelFunc kernel, int32_t threadsPerBlock,
                                 size_t dynamicSmemSize) {
   int numBlocks = 0;
-  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, kernel, threadsPerBlock,
-                                                dynamicSmemSize);
+  TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, kernel, threadsPerBlock,
+                                                dynamicSmemSize));
   return numBlocks;
 }
csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h (1)

108-113: Forward declaration without implementation.

moeActivationQuantize is declared here but not implemented in the corresponding .cu file. Line 480 of moeUtils.cu notes this is deferred. Consider adding a comment here to indicate the function is not yet implemented to avoid linker errors if called.

+// Note: Implementation deferred - will be added when NVFP4 output support is needed.
 template <typename InputType, typename OutputType, typename SFType>
 void moeActivationQuantize(InputType const* input, OutputType* output, float const* global_sf,
flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py (3)

293-318: Consider adding @functools.cache decorator for module-level caching.

Per coding guidelines, Python API functions should use @functools.cache decorator to avoid recompilation. While the function uses _gather_kernel_cache internally, adding the decorator could provide additional caching benefits at the function level for repeated identical calls.


400-401: Replace assert with explicit exception for input validation.

Using assert for input validation in public APIs can be disabled with -O flag. Use explicit exceptions instead.

🔧 Suggested fix
-    assert a.device.type == "cuda", "Input tensors must be on CUDA device"
-    assert b.device.type == "cuda", "Input tensors must be on CUDA device"
+    if a.device.type != "cuda":
+        raise ValueError("Input tensor 'a' must be on CUDA device")
+    if b.device.type != "cuda":
+        raise ValueError("Input tensor 'b' must be on CUDA device")

189-191: Global kernel cache is not thread-safe.

_gather_kernel_cache is a module-level mutable dictionary that could cause race conditions in multi-threaded scenarios. Consider using threading.Lock or functools.lru_cache for thread-safe caching.

🔧 Suggested approach
import threading

_gather_kernel_cache: Dict[Tuple, Any] = {}
_gather_kernel_cache_lock = threading.Lock()

# Then in _get_compiled_gather_kernel:
with _gather_kernel_cache_lock:
    if cache_key not in _gather_kernel_cache:
        # ... compile kernel ...
        _gather_kernel_cache[cache_key] = compiled_gemm
    return _gather_kernel_cache[cache_key]
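A self-contained, runnable variant of that locking pattern; `get_or_compile` and `compile_fn` are illustrative stand-ins for `_get_compiled_gather_kernel`'s internals:

```python
import threading
from typing import Any, Callable, Dict, Tuple

_kernel_cache: Dict[Tuple, Any] = {}
_kernel_cache_lock = threading.Lock()

def get_or_compile(cache_key: Tuple, compile_fn: Callable[[], Any]) -> Any:
    # Hold the lock across lookup and insert so two threads racing on the
    # same key cannot both trigger compilation.
    with _kernel_cache_lock:
        if cache_key not in _kernel_cache:
            _kernel_cache[cache_key] = compile_fn()
        return _kernel_cache[cache_key]
```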
flashinfer/cute_dsl/blackwell/custom_pipeline.py (2)

61-71: Unused parameter cta_layout_vmnk.

The parameter is accepted but never used in the function body. If this is reserved for future use, consider documenting it or using _ prefix.

🔧 Suggested fix
-def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
+def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):  # noqa: ARG001
     """Initializes the mbarrier and synchronizes the threadblock or cluster.
 
     This function places a fence on the mbarrier initialization to ensure
     proper synchronization across the threadblock or cluster.
 
     Args:
-        cta_layout_vmnk (Optional[cute.Layout]): The CTA layout for VMNK. Defaults to None.
+        cta_layout_vmnk (Optional[cute.Layout]): Reserved for future cluster sync. Defaults to None.
     """
     cute.arch.mbarrier_init_fence()

184-187: Consider using TypeError for type validation.

Static analysis suggests TypeError is more appropriate when checking instance types. This is a minor style improvement.

🔧 Suggested fix
         if not isinstance(barrier_storage, cute.Pointer):
-            raise ValueError(
+            raise TypeError(
                 f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
             )

This applies to similar checks at lines 329-332 and 484-487.

flashinfer/cute_dsl/blackwell/utils.py (1)

61-196: Duplicated code from flashinfer/cute_dsl/utils.py.

The _Pointer class (lines 62-149) and make_ptr function (lines 152-196) are nearly identical to those in flashinfer/cute_dsl/utils.py (see relevant_code_snippets). The comment on line 61 mentions "WAR for CuTeDSL make_ptr implementation" - if this is a temporary workaround, consider adding a TODO to consolidate once the upstream issue is resolved.

# WAR for CuTeDSL make_ptr implementation
# TODO: Remove this once upstream CuTeDSL provides the fix, and import from flashinfer.cute_dsl.utils
flashinfer/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm.py (2)

754-754: Unused unpacked variables bidy and bidz.

These variables are unpacked but never used. Use underscore prefix to indicate intentional discard.

🔧 Suggested fix
-        bidx, bidy, bidz = cute.arch.block_idx()
+        bidx, _bidy, _bidz = cute.arch.block_idx()

1864-1871: Unused parameter tidx in epilog_gmem_copy_and_partition.

The tidx parameter is declared but not used in the method body. Consider removing it or adding # noqa: ARG002 if it's part of a required interface.

🔧 Suggested fix
     def epilog_gmem_copy_and_partition(
         self,
-        tidx: cutlass.Int32,
+        tidx: cutlass.Int32,  # noqa: ARG002 - kept for interface consistency
         atom: Union[cute.CopyAtom, cute.TiledCopy],
flashinfer/jit/moe_utils.py (1)

17-26: Cache the JIT spec generator to avoid redundant registrations.

A module-level cache keeps JitSpec creation idempotent and aligns with the repo caching guidance. As per coding guidelines, please add caching.

♻️ Proposed change
+import functools
+
+@functools.cache
 def gen_moe_utils_module() -> JitSpec:
flashinfer/cute_dsl/blackwell/__init__.py (1)

43-53: Optional: sort __all__ to satisfy Ruff (RUF022).

If Ruff is enforced, sorting will keep lint clean.

♻️ Proposed change
 __all__ = [
-    "Sm100BlockScaledContiguousGroupedGemmKernel",
-    "Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel",
-    "Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel",
     "BlockScaledContiguousGatherGroupedGemmKernel",
+    "Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel",
+    "Sm100BlockScaledContiguousGroupedGemmKernel",
+    "Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel",
+    "TRTLLM_ENABLE_PDL",
     "cvt_sf_MKL_to_M32x4xrm_K4xrk_L",
-    "TRTLLM_ENABLE_PDL",
     "griddepcontrol_launch_dependents",
     "griddepcontrol_wait",
     "is_power_of_2",
 ]
tests/moe/test_cute_dsl_fused_moe.py (2)

36-41: Consider using flashinfer.utils.is_sm100a_supported() for GPU capability check.

As per coding guidelines, tests should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures.

♻️ Suggested refactor
-def is_blackwell():
-    """Check if running on Blackwell GPU (SM100+)."""
-    if not torch.cuda.is_available():
-        return False
-    props = torch.cuda.get_device_properties(0)
-    return props.major >= 10
+from flashinfer.utils import is_sm100a_supported
+
+def is_blackwell():
+    """Check if running on Blackwell GPU (SM100+)."""
+    return is_sm100a_supported()

341-394: Consider adding a numerical accuracy check for expert parallelism tests.

The test validates that results don't contain NaN/Inf but skips strict accuracy comparison. While the comment explains the semantic difference from filtering, a basic sanity check (e.g., output magnitude is reasonable) would strengthen the test.

flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py (1)

256-278: Consider adding @functools.cache decorator per coding guidelines.

As per coding guidelines, Python API functions should use @functools.cache decorator to implement module-level caching and avoid recompilation. The @flashinfer_api decorator is correctly used for debugging.

Note: The internal _get_compiled_finalize_kernel already implements caching, so this may be intentionally omitted to avoid double-caching. If so, a brief comment explaining this would help future maintainers.

Comment on lines 273 to 338
void moe_sort(
    // Inputs
    int64_t token_selected_experts_ptr,  // [num_tokens, top_k], int32
    int64_t token_final_scales_ptr,      // [num_tokens, top_k], float32 or bf16
    int32_t num_tokens, int32_t num_experts, int32_t top_k, int32_t local_expert_offset,
    int32_t num_local_experts, int32_t tile_tokens_dim, bool use_pdl,
    // Outputs (pre-allocated buffers)
    int64_t tile_idx_to_expert_idx_ptr, int64_t tile_idx_to_mn_limit_ptr,
    int64_t expanded_idx_to_permuted_idx_ptr, int64_t permuted_idx_to_expanded_idx_ptr,
    int64_t total_num_padded_tokens_ptr, int64_t num_non_exiting_tiles_ptr,
    // Optional: expert counts buffer for large token counts (>1024)
    // Should be size 2 * num_experts, int32
    int64_t expert_counts_ptr) {
  // Set up the routing data structure
  moe::dev::routing::routingDeepSeek::Data routingData;

  // Configure dtypes
  routingData.mDtypeExpW = batchedGemm::trtllm::gen::Dtype::Bfloat16;
  routingData.mDtypeBias = batchedGemm::trtllm::gen::Dtype::Bfloat16;
  routingData.mDtypeScore = batchedGemm::trtllm::gen::Dtype::Fp32;
  routingData.mUsePdl = use_pdl;

  // Input tensors (pre-computed expert selections)
  routingData.mPtrTopKIds = reinterpret_cast<int32_t*>(token_selected_experts_ptr);
  routingData.mPtrTopKWeights = reinterpret_cast<void*>(token_final_scales_ptr);
  routingData.mPtrScores = nullptr;       // Not using routing logits
  routingData.mPtrRoutingBias = nullptr;  // Not using bias

  // Output tensors
  routingData.mPtrCtaIdxXyToBatchIdx = reinterpret_cast<int32_t*>(tile_idx_to_expert_idx_ptr);
  routingData.mPtrCtaIdxXyToMnLimit = reinterpret_cast<int32_t*>(tile_idx_to_mn_limit_ptr);
  routingData.mPtrExpandedIdxToPermutedIdx =
      reinterpret_cast<int32_t*>(expanded_idx_to_permuted_idx_ptr);
  routingData.mPtrPermutedIdxToTokenIdx =
      reinterpret_cast<int32_t*>(permuted_idx_to_expanded_idx_ptr);
  routingData.mPtrPermutedIdxSize = reinterpret_cast<int32_t*>(total_num_padded_tokens_ptr);
  routingData.mPtrNumNonExitingCtas = reinterpret_cast<int32_t*>(num_non_exiting_tiles_ptr);

  // Not using packed format since we have explicit TopK IDs
  routingData.mPtrTopKPacked = nullptr;

  // Expert counts buffer: required when num_tokens > 1024
  // The kernel will set this to nullptr internally for small token counts
  routingData.mPtrExpertCounts = reinterpret_cast<int32_t*>(expert_counts_ptr);

  // Metadata
  routingData.mNumTokens = num_tokens;
  routingData.mNumExperts = num_experts;
  routingData.mTopK = top_k;
  routingData.mPaddingLog2 = computeLog2(tile_tokens_dim);
  routingData.mTileTokensDim = tile_tokens_dim;
  routingData.mLocalExpertsStartIdx = local_expert_offset;
  routingData.mLocalExpertsStrideLog2 = 0;
  routingData.mNumLocalExperts = num_local_experts;

  // DeepSeekV3 specific parameters
  // For moe_sort, we use n_group=1, topk_group=1 since experts are already selected
  routingData.mNumExpertGroups = 1;
  routingData.mNumLimitedGroups = 1;
  routingData.mRouteScale = 1.0f;
  routingData.mUseRoutingSoftmax = false;

  // Run the routing kernel
  cudaStream_t stream = get_current_stream();
  moe::dev::routing::routingDeepSeek::run(routingData, stream);
}

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cd /tmp && find . -name "*.py" -type f 2>/dev/null | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 192


🏁 Script executed:

# Search for moe_sort calls in Python files
find . -name "*.py" -type f -exec grep -l "moe_sort" {} \; 2>/dev/null



🏁 Script executed:

# Search more broadly for moe_sort references and how it's called
rg "moe_sort" --type=py -B 2 -A 5



🏁 Script executed:

# Also check the csrc directory structure for binding files
ls -la csrc/ 2>/dev/null | head -30



🏁 Script executed:

rg "def moe_sort" -A 50 flashinfer/moe_utils.py | head -100



🏁 Script executed:

# Get more context on token_final_scales handling
rg "token_final_scales" flashinfer/moe_utils.py -B 2 -A 2



🏁 Script executed:

# Check if there's dtype conversion or validation before calling the C++ kernel
sed -n '/def moe_sort/,/^def /p' flashinfer/moe_utils.py | head -120



🏁 Script executed:

# Check if there's dtype validation for token_final_scales anywhere
rg "token_final_scales" -B 5 -A 5 flashinfer/moe_utils.py | grep -A 15 "dtype"



🏁 Script executed:

# Check the full moe_sort function implementation to see if dtype is checked
sed -n '/def moe_sort/,/^    func(/p' flashinfer/moe_utils.py | tail -50



🏁 Script executed:

# Check if mLocalExpertsStrideLog2 is ever used non-zero in the codebase
rg "mLocalExpertsStrideLog2" -B 3 -A 3



🏁 Script executed:

# Check the routing kernel implementation to understand what this parameter means
rg "routingDeepSeek" csrc/ --type=cu -l



🏁 Script executed:

# Search for documentation or comments about stride in DeepSeek routing
rg "Stride\|stride\|STRIDE" csrc/ --type=cu -B 2 -A 2 | head -50



Add dtype validation for token_final_scales or document bfloat16-only requirement.

The Python API documents token_final_scales as supporting both torch.float32 and torch.bfloat16, but the C++ kernel hardcodes mDtypeExpW and mDtypeBias to Bfloat16 (lines 290-291). If callers pass float32 tensors, the kernel will misinterpret the bit patterns as bfloat16, causing silent correctness issues. Either enforce dtype conversion in the Python wrapper (similar to how token_selected_experts is converted to int32) or update the docstring to document that only bfloat16 is supported.

Note: mLocalExpertsStrideLog2 = 0 is intentional and correct for non-strided deployment (all experts local to a single GPU).

🤖 Prompt for AI Agents
In `@csrc/moe_utils_binding.cu` around lines 273 - 338, The kernel assumes
token_final_scales is bfloat16 but the public API allows float32; in moe_sort
validate the dtype for token_final_scales_ptr (or enforce conversion in the
Python wrapper) and set routingData fields accordingly: if callers pass float32,
either (A) convert the tensor to bfloat16 before calling moe_sort (mirror how
token_selected_experts is converted to int32) or (B) detect float32 here and set
routingData.mDtypeExpW and routingData.mDtypeBias to Fp32 and ensure
routingData.mPtrTopKWeights is treated as float*; update the docstring if you
choose to restrict to bfloat16. Ensure checks reference token_final_scales_ptr,
routingData.mDtypeExpW, routingData.mDtypeBias, and routingData.mPtrTopKWeights
(or the Python wrapper conversion path used for token_selected_experts).
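Option (A), mirroring the existing int32 conversion for `token_selected_experts`, could look like this on the Python side; `prepare_token_final_scales` is a hypothetical wrapper helper, not existing API:

```python
import torch

def prepare_token_final_scales(token_final_scales: torch.Tensor) -> torch.Tensor:
    # The routing kernel reads the weights as bfloat16 (mDtypeExpW), so
    # convert float32 inputs before handing the pointer to moe_sort.
    if token_final_scales.dtype == torch.float32:
        return token_final_scales.to(torch.bfloat16)
    if token_final_scales.dtype != torch.bfloat16:
        raise TypeError(
            f"token_final_scales must be float32 or bfloat16, got {token_final_scales.dtype}"
        )
    return token_final_scales

scales = prepare_token_final_scales(torch.rand(4, 8, dtype=torch.float32))
```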

Comment on lines +119 to +121
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM = getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, check if the file exists and get its structure
fd "moeUtils.cu" --type f | head -5



🏁 Script executed:

#!/bin/bash
# Read the specific lines mentioned in the review (119-121, 319, 448)
if [ -f "csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu" ]; then
  echo "=== Lines 119-121 ==="
  sed -n '119,121p' "csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu"
  echo ""
  echo "=== Lines 315-325 (around line 319) ==="
  sed -n '315,325p' "csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu"
  echo ""
  echo "=== Lines 444-454 (around line 448) ==="
  sed -n '444,454p' "csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu"
fi



🏁 Script executed:

#!/bin/bash
# Search for getMultiProcessorCount and understand its implementation
rg -n "getMultiProcessorCount" --type cpp -C 3



🏁 Script executed:

#!/bin/bash
# Check how this function is defined and where it comes from
rg -n "def getMultiProcessorCount|getMultiProcessorCount\s*\(" --type cpp --type h -A 5 | head -50



Static SM count caching prevents proper multi-GPU support.

The getMultiProcessorCount() function queries the current device dynamically, but caching its result in a static variable means it is evaluated only once per process. If the application switches to a different device, this cached value becomes incorrect. Consider using thread_local (as used elsewhere in the codebase, e.g., fp4Op.cpp) or querying per-call if multi-GPU support is required.

This pattern appears at lines 119, 319, and 448.

🤖 Prompt for AI Agents
In `@csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu` around
lines 119 - 121, The cached static smCount computed via
tensorrt_llm::common::getMultiProcessorCount() prevents correct behavior when
the process switches CUDA devices; change the declaration so smCount is
evaluated per-call (e.g., remove static or make it thread_local as done
elsewhere) and recompute it before calculating maxBlocksPerSM and blocks; update
all occurrences where smCount is defined (the instances that call
getMultiProcessorCount()) so each call queries the current device rather than
reusing a process-global cached value.

Comment on lines +56 to +69
struct SwigluBiasAdaptor {
  constexpr static bool IS_GLU = true;
  float alpha = 1.0f;
  float beta = 0.0f;
  float limit = std::numeric_limits<float>::infinity();

  template <class T>
  __device__ T operator()(T const& gate, T const& linear) const {
    cutlass::epilogue::thread::Sigmoid<T> fn{};
    T linear_clamped = cutlass::maximum<T>{}(cutlass::minimum<T>{}(linear, limit), -limit);
    T gate_clamped = cutlass::minimum<T>{}(gate, limit);
    return gate_clamped * fn(gate_clamped * alpha) * (linear_clamped + beta);
  }
};

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🌐 Web query:

SwiGLU activation function implementation clamping

💡 Result:

Short answer: clamp the gate input to a reasonable float32-safe range (e.g. [-50, 50] or [-88, 88]) before applying SiLU/Sigmoid when you implement SwiGLU manually or in custom kernels; most frameworks' builtin sigmoid/silu are already numerically-stable but clamping avoids exp overflow/underflow in custom code. [1][2][3][4]

Example (PyTorch):

gate, up = x.chunk(2, dim=-1)
gate = gate.clamp(-50.0, 50.0)        # or clamp(-88.0, 88.0)
out = torch.nn.functional.silu(gate) * up

Why: Sigmoid/SiLU involve exp() and large inputs can overflow; numerical-stable implementations or conditional formulas avoid this, but simple clamping is a safe, common shortcut. [2][3]

Sources:

  • PyTorch SiLU docs (SiLU = Swish). [1]
  • Notes / examples on numerically-stable sigmoid implementations and clipping. [2][3]
  • Megatron / fused SwiGLU kernel docs showing fused SwiGLU implementations (practical kernel context). [4]

🌐 Web query:

SwiGLU gate linear output sigmoid clamping numerical stability

💡 Result:

Short answer: SwiGLU itself isn’t specially unstable, but practical numerical issues come from (a) computing sigmoid/SiLU (exp overflow/underflow in extreme inputs), and (b) the elementwise multiplication between a (bounded) gated path and an unbounded linear path which can amplify large values and produce INF/NaN or tiny gradients in low-precision. Use stable built-ins, clamp where needed, and apply usual FFN safeguards. Sources: GLU-variants discussion and SwiGLU descriptions, and SiLU docs. [1][2][3]

Practical checklist (concise):

  • Use a numerically stable implementation: call a fused/engine-optimized SiLU/torch.nn.functional.silu rather than hand-writing x * torch.sigmoid(x) when possible (better perf and lower fp16 risk). [3]
  • Prevent extreme inputs to sigmoid/exp: ensure reasonable weight init (Xavier/Kaiming) and consider activation normalization (LayerNorm) before FFN to keep pre-activations in safe range. [1][3]
  • Avoid catastrophic multiplicative growth: after SwiGLU, use an output projection with appropriate weight scale and/or scalar gating (learned scale) to limit dynamic range. [1][2]
  • Clamp only when necessary: if you must clamp sigmoid outputs (e.g., to avoid exact 0/1 in downstream log/div), use small eps like 1e-6 (torch.clamp(sigmoid(x), eps, 1-eps)), but prefer fixing sources of extreme values first. Over-clamping changes gradients and representational capacity. [1]
  • FP16/AMP tips: use PyTorch’s fused SiLU and mixed-precision autocast, keep LayerNorm in fp32 or use loss-scaling to avoid underflow/overflow; avoid in-place ops on tensors used in backward. [3]
  • Debugging NaNs/Infs: log max/min of pre-activation (both paths), check gradient norms, temporarily clip gradients and activations to locate which path explodes (linear path vs gated path). [1][2]

If you want, I can:

  • show compact PyTorch-safe code for SwiGLU with fp16/AMP best practices and optional sigmoid clamping; or
  • run a short experiment to demonstrate where NaNs arise and how clamping/scale fixes them.

🏁 Script executed:

# Search for other SwiGLU implementations or references in the codebase
rg -i "swiglu|swiglu" --type cpp --type h --type cuh -B 2 -A 2



🏁 Script executed:

# Look for tests or documentation related to this kernel
fd -e "test" -e "md" | head -20 | xargs rg -l -i "swiglu\|moe_kernels" 2>/dev/null || true



🏁 Script executed:

# Check if there are comments or context around the clamping logic
cat -n csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh | sed -n '40,80p'



SwigluBiasAdaptor gate clamping deviates from standard SwiGLU numerical stability practices.

The gate value is only clamped from above (line 66: cutlass::minimum<T>{}(gate, limit)), while linear is clamped symmetrically to [-limit, limit]. Standard SwiGLU implementations clamp the gate input symmetrically (e.g., [-50, 50]) before applying sigmoid to prevent exp overflow/underflow in the sigmoid computation. Very negative gate values in this implementation could cause numerical instability.

🤖 Prompt for AI Agents
In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh`
around lines 56 - 69, The gate is only clamped from above which can let very
negative values pass into the sigmoid and cause numerical instability; in
SwigluBiasAdaptor::operator() clamp the gate symmetrically (e.g., use
cutlass::maximum<T>{}(cutlass::minimum<T>{}(gate, limit), -limit)) before
computing the sigmoid and using it for the gate multiplication so the sigmoid
receives a bounded input and numerical overflow/underflow is avoided.
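A reference sketch of the symmetrically clamped math in PyTorch. This mirrors the adaptor's formula for illustration only; it is not the CUDA code, and the `limit` default here is the float32-safe bound from the web-query summary above rather than the adaptor's infinity default:

```python
import torch

def swiglu_bias(gate: torch.Tensor, linear: torch.Tensor,
                alpha: float = 1.0, beta: float = 0.0,
                limit: float = 50.0) -> torch.Tensor:
    # Symmetric clamp on BOTH paths so the sigmoid never sees unbounded
    # input; limit=50 is an assumed float32-safe bound, not the adaptor's
    # default of infinity.
    gate = gate.clamp(-limit, limit)
    linear = linear.clamp(-limit, limit)
    return gate * torch.sigmoid(gate * alpha) * (linear + beta)

y = swiglu_bias(torch.tensor([-1e4, 0.0, 1e4]), torch.ones(3))
```

Even with extreme gate inputs the result stays finite, which is the property the one-sided clamp fails to guarantee.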

Comment on lines +72 to +160
def create_finalize_fusion_tensors(
seq_len: int,
topk: int,
permuted_m: int,
group_m_list: List[int],
mma_tiler_mn: Tuple[int, int],
final_scale_dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Create tensors required for finalize fusion.

This function creates the mapping tensor and final scale tensor needed
for the fused finalize operation in GEMM2.

Args:
seq_len: Number of output tokens (original sequence length)
topk: Number of experts per token
permuted_m: Total permuted M dimension (sum of aligned group sizes)
group_m_list: List of actual (unaligned) M values per expert
mma_tiler_mn: MMA tile shape (M, N) for alignment
        final_scale_dtype: Data type for token final scales. Default: torch.float32

    Returns:
        Tuple of:
        - permuted_idx_to_expanded_idx: Mapping tensor, shape (permuted_m,), int32.
          Maps permuted row index to expanded_idx = token_idx * topk + k_idx.
          Invalid rows are marked with -1.
        - token_final_scales: Router scale tensor, shape (seq_len, topk), final_scale_dtype.
          Normalized routing weights for each (token, topk) pair.

    Example:
        >>> seq_len, topk, num_experts = 4096, 8, 8
        >>> group_m_list = [512, 480, 256, 320, 640, 512, 384, 704]  # Tokens per expert
        >>> permuted_m = sum(align_to(m, 256) for m in group_m_list)  # Aligned total
        >>>
        >>> permuted_idx_to_expanded_idx, token_final_scales = create_finalize_fusion_tensors(
        ...     seq_len=seq_len,
        ...     topk=topk,
        ...     permuted_m=permuted_m,
        ...     group_m_list=group_m_list,
        ...     mma_tiler_mn=(256, 128),
        ... )
    """
    m_aligned = mma_tiler_mn[0]

    # Initialize mapping tensor with -1 (invalid)
    permuted_idx_to_expanded_idx = torch.empty(
        (permuted_m,), dtype=torch.int32, device="cuda"
    ).fill_(-1)

    # Create normalized token final scales
    token_final_scales = torch.rand(
        seq_len, topk, dtype=final_scale_dtype, device="cuda"
    )
    token_final_scales = token_final_scales / token_final_scales.sum(
        dim=1, keepdim=True
    )

    start_idx = 0
    for group_idx, m_per_group in enumerate(group_m_list):
        if m_per_group > 0:
            # Sequential/Blocked assignment for better atomic add memory access.
            # Experts are grouped into sets of size topk.
            # Expert set S (experts S*topk ... S*topk + topk - 1) serves a contiguous block of tokens.
            # This ensures that within an expert, we process tokens T, T+1, T+2, ... sequentially.

            expert_set_idx = group_idx // topk
            k_in_set = group_idx % topk

            # Start token index for this expert set
            start_token = expert_set_idx * m_per_group

            # Generate sequential token indices for this expert
            token_indices = torch.arange(
                start_token, start_token + m_per_group, dtype=torch.int32, device="cuda"
            )
            token_indices = token_indices % seq_len

            # expanded_idx = token_idx * topk + k
            expanded_idx = token_indices * topk + k_in_set

            permuted_idx_to_expanded_idx[start_idx : (start_idx + m_per_group)] = (
                expanded_idx
            )

        # Move to next aligned group
        m_aligned_per_group = ((m_per_group + m_aligned - 1) // m_aligned) * m_aligned
        start_idx += m_aligned_per_group

    return permuted_idx_to_expanded_idx, token_final_scales
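For intuition, the index arithmetic above can be traced on the CPU without torch. This is a small illustrative sketch (not part of the library) of the same permuted-to-expanded mapping, with padding rows left at -1:

```python
def build_mapping(seq_len, topk, group_m_list, m_aligned):
    """Pure-Python sketch of the permuted -> expanded index mapping above."""
    align = lambda m: ((m + m_aligned - 1) // m_aligned) * m_aligned
    permuted_m = sum(align(m) for m in group_m_list)
    mapping = [-1] * permuted_m  # padding rows stay -1
    start_idx = 0
    for group_idx, m_per_group in enumerate(group_m_list):
        expert_set_idx, k_in_set = divmod(group_idx, topk)
        start_token = expert_set_idx * m_per_group
        for i in range(m_per_group):
            token_idx = (start_token + i) % seq_len
            mapping[start_idx + i] = token_idx * topk + k_in_set
        start_idx += align(m_per_group)  # advance by the tile-aligned group size
    return mapping

# Two experts, topk=2, tile size 4: rows past each group's valid range stay -1.
print(build_mapping(seq_len=8, topk=2, group_m_list=[3, 5], m_aligned=4))
# [0, 2, 4, -1, 1, 3, 5, 7, 9, -1, -1, -1]
```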

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check how create_finalize_fusion_tensors is used in the codebase
rg -n "create_finalize_fusion_tensors" --type=py -C 3

Repository: flashinfer-ai/flashinfer

Length of output: 5255


🏁 Script executed:

#!/bin/bash
# Read the file to see the function definition and decorators
head -n 180 flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py | tail -n +1

Repository: flashinfer-ai/flashinfer

Length of output: 6791


🏁 Script executed:

#!/bin/bash
# Read more of the file to see _get_compiled_finalize_kernel and the main API function
sed -n '180,250p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 2432


🏁 Script executed:

#!/bin/bash
# Read the main API function to check for decorators
sed -n '250,400p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 6876


🏁 Script executed:

#!/bin/bash
# Look for test files that might use this function
find . -name "*test*.py" -type f | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 925


🏁 Script executed:

#!/bin/bash
# Search for tests that use create_finalize_fusion_tensors
rg -n "create_finalize_fusion_tensors\|blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4" tests/ --type=py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Check if there are any actual uses of this function outside of documentation
rg -n "create_finalize_fusion_tensors" --type=py --exclude-dir=tests -C 5 | head -60

Repository: flashinfer-ai/flashinfer

Length of output: 483


🏁 Script executed:

#!/bin/bash
# Check if there are any actual uses of this function outside of documentation
rg -n "create_finalize_fusion_tensors" --type=py -C 5 | grep -v "Example:" | grep -v ">>>" | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 6026


🏁 Script executed:

#!/bin/bash
# Check other similar API functions in the same file to see decorator patterns
rg -n "@functools.cache|@flashinfer_api" flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 87


🏁 Script executed:

#!/bin/bash
# Check other functions that create tensors for similar kernels to see if they have decorators
rg -B 2 "def create_" flashinfer/cute_dsl/*.py --type=py | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 1259


🏁 Script executed:

#!/bin/bash
# Check if create_gather_gemm_tensors has any decorators
rg -B 5 "def create_gather_gemm_tensors" flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 207


🏁 Script executed:

#!/bin/bash
# Check the full context of create_gather_gemm_tensors to see if it also uses random values
sed -n '1,150p' flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py | grep -A 50 "def create_gather_gemm_tensors"

Repository: flashinfer-ai/flashinfer

Length of output: 2249


🏁 Script executed:

#!/bin/bash
# Look at the end of create_gather_gemm_tensors function
sed -n '50,150p' flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py | tail -80

Repository: flashinfer-ai/flashinfer

Length of output: 3232


🏁 Script executed:

#!/bin/bash
# Get the complete create_gather_gemm_tensors function
sed -n '40,150p' flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 4151


🏁 Script executed:

#!/bin/bash
# Get more of the create_gather_gemm_tensors function - the end
sed -n '120,180p' flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 2381


🏁 Script executed:

#!/bin/bash
# Check if there are imports of functools.cache or similar caching in the file
rg -n "functools|cache" flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 881


Add @flashinfer_api decorator to public API function and clarify tensor initialization approach.

The function create_finalize_fusion_tensors is exported as a public API (line 68 in __all__) but lacks the @flashinfer_api decorator used by other public functions in the module. Add the decorator to enable API logging/debugging via FLASHINFER_LOGLEVEL.

Additionally, token_final_scales are initialized with random normalized values without any way for users to provide pre-computed routing weights. The function should either:

  1. Accept optional token_final_scales parameter for real routing weights, or
  2. Clearly document that returned scales are placeholder values meant for testing/examples only

The kernel caching approach using _finalize_kernel_cache is appropriate and does not need @functools.cache since it correctly caches by tactic parameters only, allowing kernel reuse across different problem sizes.

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`
around lines 72 - 160, Add the missing `@flashinfer_api` decorator to the public
function create_finalize_fusion_tensors to enable FLASHINFER_LOGLEVEL-based API
logging; also modify its signature to accept an optional token_final_scales:
Optional[torch.Tensor] = None parameter (dtype final_scale_dtype, shape
(seq_len, topk)) and, if provided, validate shape/dtype and use it instead of
generating random values, otherwise keep the current randomized normalized
initialization but update the docstring to state these are placeholder/test
values; leave the existing _finalize_kernel_cache behavior unchanged.

Comment on lines +241 to +272
# Check if we're doing FP4 quantization
generate_sfc = c_dtype == "float4_e2m1fn"
if generate_sfc:
if global_scale is None:
raise ValueError("global_scale is required when c_dtype is 'float4_e2m1fn'")

# Create output tensor if not provided
if out is None:
if generate_sfc:
# FP4 output: 2 values per byte
out = torch.empty(
(permuted_m, intermediate_size // 2),
dtype=torch.uint8,
device=a.device,
)
else:
out = torch.empty(
(permuted_m, intermediate_size),
dtype=cutlass_to_torch_dtype(c_dtype_cutlass),
device=a.device,
)

# Create output scale tensor if needed and not provided
if generate_sfc and out_scale is None:
# Scale factor layout for output
scale_intermediate_size = intermediate_size // sf_vec_size
# MMA-compatible scale factor shape
out_scale = torch.empty(
(32, 4, permuted_m // 128, 4, scale_intermediate_size // 4, 1),
dtype=torch.uint8, # FP8 E4M3
device=a.device,
)

⚠️ Potential issue | 🟠 Major

Validate FP4 output scale layout assumptions before allocation.

out_scale uses permuted_m // 128 and scale_intermediate_size // 4; without divisibility checks, the buffer can be undersized and the kernel may write out of bounds. Add explicit guards for FP4 output.

🐛 Proposed fix
     generate_sfc = c_dtype == "float4_e2m1fn"
     if generate_sfc:
         if global_scale is None:
             raise ValueError("global_scale is required when c_dtype is 'float4_e2m1fn'")
+        if permuted_m % 128 != 0:
+            raise ValueError("permuted_m must be a multiple of 128 for FP4 output scale layout")
+        if intermediate_size % (sf_vec_size * 4) != 0:
+            raise ValueError(
+                "intermediate_size must be divisible by sf_vec_size * 4 for FP4 output scale layout"
+            )
🧰 Tools
🪛 Ruff (0.14.13)

245-245: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py`
around lines 241 - 272, The FP4 output scale allocation assumes permuted_m is
divisible by 128 and scale_intermediate_size (computed from intermediate_size //
sf_vec_size) is divisible by 4; add explicit validation in the generate_sfc
branch before allocating out_scale: check that permuted_m % 128 == 0 and
scale_intermediate_size % 4 == 0 (using the existing symbols generate_sfc,
out_scale, permuted_m, intermediate_size, sf_vec_size, scale_intermediate_size),
and raise a clear ValueError if either check fails so the buffer size cannot be
undersized and cause out-of-bounds writes.
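To see why the divisibility guards matter, the scale-factor element count can be checked against the expected one-scale-per-`sf_vec_size`-values layout. This standalone sketch mirrors the shape used above (the guard messages are illustrative):

```python
import math

def fp4_out_scale_shape(permuted_m, intermediate_size, sf_vec_size=16):
    # Mirrors the MMA-compatible layout above:
    # (32, 4, permuted_m // 128, 4, scale_intermediate_size // 4, 1)
    if permuted_m % 128 != 0:
        raise ValueError("permuted_m must be a multiple of 128")
    if intermediate_size % (sf_vec_size * 4) != 0:
        raise ValueError("intermediate_size must be divisible by sf_vec_size * 4")
    scale_intermediate_size = intermediate_size // sf_vec_size
    return (32, 4, permuted_m // 128, 4, scale_intermediate_size // 4, 1)

shape = fp4_out_scale_shape(256, 2048)
# 32*4*2*4*32*1 == 32768 == 256 * (2048 // 16): one scale per 16-value block
assert math.prod(shape) == 256 * (2048 // 16)
```

Without the guards, e.g. `permuted_m=100` would silently floor-divide to an undersized buffer instead of raising.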

Comment on lines +361 to +363
# Validate inputs
assert a.device.type == "cuda", "Input tensors must be on CUDA device"
assert b.device.type == "cuda", "Input tensors must be on CUDA device"

⚠️ Potential issue | 🟡 Minor

Replace assert with proper exceptions for runtime validation.

assert statements are stripped when Python runs with -O (optimized mode). For production code validating user inputs, use explicit if checks with ValueError/RuntimeError.

🔧 Suggested fix
-    assert a.device.type == "cuda", "Input tensors must be on CUDA device"
-    assert b.device.type == "cuda", "Input tensors must be on CUDA device"
+    if a.device.type != "cuda":
+        raise ValueError("Input tensor 'a' must be on CUDA device")
+    if b.device.type != "cuda":
+        raise ValueError("Input tensor 'b' must be on CUDA device")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Validate inputs
assert a.device.type == "cuda", "Input tensors must be on CUDA device"
assert b.device.type == "cuda", "Input tensors must be on CUDA device"
# Validate inputs
if a.device.type != "cuda":
raise ValueError("Input tensor 'a' must be on CUDA device")
if b.device.type != "cuda":
raise ValueError("Input tensor 'b' must be on CUDA device")
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py` around lines 361
- 363, Replace the two assert statements that check tensor device types (the
checks using a.device.type == "cuda" and b.device.type == "cuda") with explicit
runtime validation: use if statements that raise a clear exception (e.g.,
ValueError or RuntimeError) when a or b are not on CUDA, and include a helpful
message mentioning which tensor failed the check; update the validation block in
blockscaled_contiguous_grouped_gemm.py to perform these explicit checks instead
of using assert so they are not stripped in optimized Python.

Comment on lines +260 to +266
lambda shapes, dtype, device: torch.randint(
0,
8,
shapes,
dtype=torch.int32,
device=device, # num_experts=8 typical
),

⚠️ Potential issue | 🟡 Minor

Hardcoded num_experts=8 assumption in initializer.

The comment says "num_experts=8 typical" but this initializer is used generically. If tuning profiles are generated with different expert counts, this could produce invalid expert indices.

🔧 Suggested fix

Consider parameterizing this or using the actual num_experts from the runner instance:

         # 2: token_selected_experts - expert indices (int32, 0 to num_experts-1)
         lambda shapes, dtype, device: torch.randint(
             0,
-            8,
+            max(shapes[1], 1),  # Use top_k dimension as proxy, or accept num_experts param
             shapes,
             dtype=torch.int32,
-            device=device,  # num_experts=8 typical
+            device=device,
         ),

Or document that this assumes a minimum of 8 experts during profiling.

📝 Committable suggestion


Suggested change
lambda shapes, dtype, device: torch.randint(
0,
8,
shapes,
dtype=torch.int32,
device=device, # num_experts=8 typical
),
lambda shapes, dtype, device: torch.randint(
0,
max(shapes[1], 1), # Use top_k dimension as proxy, or accept num_experts param
shapes,
dtype=torch.int32,
device=device,
),
🧰 Tools
🪛 Ruff (0.14.13)

260-260: Unused lambda argument: dtype

(ARG005)

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/tuner.py` around lines 260 - 266, The initializer lambda
in tuner.py currently samples expert indices with a hardcoded range of 0..7
(comment "num_experts=8 typical"), which can produce invalid indices for models
with different expert counts; update the lambda used in the tuner initialization
(the anonymous function that calls torch.randint) to derive the upper bound from
the actual model/runner expert count (e.g., use a passed-in num_experts or
runner.num_experts) or at minimum clamp to min/max to avoid out-of-range values;
locate the lambda in tuner.py and replace the hardcoded 8 with the dynamic
num_experts value (or a safe expression like max(1, num_experts)) so sampled
indices are always valid.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

🤖 Fix all issues with AI agents
In
`@flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`:
- Around line 482-489: The sm_count parameter is currently ignored; after
computing max_active_clusters via get_max_active_clusters(cluster_shape_mn[0] *
cluster_shape_mn[1]), clamp it to sm_count (which you already set from
get_num_sm(a.device) when None) by replacing max_active_clusters with
min(max_active_clusters, sm_count) so the API honors sm_count; reference symbols:
sm_count, get_num_sm, get_max_active_clusters, max_active_clusters,
cluster_shape_mn.
- Around line 471-481: The out_scale allocation in the
blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion code assumes permuted_m
and intermediate_size are divisible by the vectorization factors (128 and
sf_vec_size/4) and can cause OOB writes; add explicit validation before the
allocation (when generate_sfc is true and out_scale is None) to assert or raise
a clear error if permuted_m % 128 != 0 or intermediate_size % (sf_vec_size * 4)
!= 0 (or the equivalent divisibility used to compute scale_intermediate_size and
the shape dims), and adjust the calculation of scale_intermediate_size
accordingly to use integer division only after validation so out_scale has the
correct size for the rest of the code paths referencing out_scale.

In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 416-423: The code computes max_active_clusters via
get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]) but currently
ignores the sm_count parameter; update the logic so that after computing
max_active_clusters you clamp it using sm_count (use get_num_sm(a.device) only
if sm_count is None), e.g., determine sm_count via sm_count = sm_count or
get_num_sm(a.device), then set max_active_clusters = min(max_active_clusters,
sm_count) so kernel scheduling respects the provided sm_count limit; adjust
references around sm_count, get_num_sm, get_max_active_clusters,
max_active_clusters and cluster_shape_mn accordingly.
- Around line 459-471: Validate token_final_scales.dtype explicitly before
mapping to Cutlass types: handle torch.float32, torch.bfloat16, and
torch.float16 (set token_scales_dtype to cutlass.Float32, cutlass.BFloat16,
cutlass.Float16 respectively) and raise a clear error if any other dtype is
passed; then call make_ptr(token_scales_dtype, token_final_scales.data_ptr(),
cute.AddressSpace.gmem, assumed_align=16). Locate symbols token_final_scales,
token_scales_dtype, make_ptr, and the cutlass type mappings to implement this
guard and error path.

In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py`:
- Around line 416-437: The computed max_active_clusters from
get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]) is not being
clamped by the sm_count limit; after computing max_active_clusters, clamp it
with sm_count (e.g., max_active_clusters = min(max_active_clusters, sm_count))
so the API contract is respected. Ensure sm_count is defined (it may be set via
get_num_sm(a.device)) before the clamp and apply this change right after the
call to get_max_active_clusters in the block that includes sm_count, get_num_sm,
and cluster_shape_mn.
- Around line 120-130: The code allows padding when permuted_m > valid_m but
doesn't ensure the padding size is a multiple of the tile size (mma_tiler_m),
which can cause mismatched lengths and OOB access; inside the block that handles
permuted_m > valid_m (the same place where num_padding_tiles is computed and
tile_idx_to_group_idx_list is extended), validate that (permuted_m - valid_m) %
mma_tiler_m == 0 and if not raise a ValueError explaining that permuted_m -
valid_m must be divisible by mma_tiler_m; keep the existing behavior of
computing num_padding_tiles = (permuted_m - valid_m) // mma_tiler_m and
extending tile_idx_to_group_idx_list with the padding only after this check.

In `@flashinfer/cute_dsl/utils.py`:
- Around line 84-110: The cached HardwareInfo and get_max_active_clusters are
device-agnostic and will return wrong values on multi-GPU machines; update
caching to be device-aware: change get_hardware_info to accept an optional
device identifier (or obtain current device internally) and replace the single
_hardware_info_cache with a per-device cache (e.g., dict keyed by device id) for
the HardwareInfo singleton; also update get_max_active_clusters to include the
device id in its cache key (remove or replace the `@functools.cache` usage with a
device-keyed cache or make the function accept a device parameter so caching is
per-device). Ensure you reference and update the symbols _hardware_info_cache,
get_hardware_info, get_max_active_clusters, and the use of `@functools.cache`
accordingly so each GPU gets correct, device-specific values.
♻️ Duplicate comments (4)
flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py (1)

73-161: Public helper should be API‑logged and allow caller‑provided scales.
This mirrors earlier feedback: add @flashinfer_api and allow callers to pass real routing scales (or clearly document randomized placeholders).

♻️ Suggested update
+@flashinfer_api
 def create_finalize_fusion_tensors(
     seq_len: int,
     topk: int,
     permuted_m: int,
     group_m_list: List[int],
     mma_tiler_mn: Tuple[int, int],
     final_scale_dtype: torch.dtype = torch.float32,
+    token_final_scales: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@
-    # Create normalized token final scales
-    token_final_scales = torch.rand(
-        seq_len, topk, dtype=final_scale_dtype, device="cuda"
-    )
-    token_final_scales = token_final_scales / token_final_scales.sum(
-        dim=1, keepdim=True
-    )
+    # Create or validate token final scales
+    if token_final_scales is None:
+        token_final_scales = torch.rand(
+            seq_len, topk, dtype=final_scale_dtype, device="cuda"
+        )
+        token_final_scales = token_final_scales / token_final_scales.sum(
+            dim=1, keepdim=True
+        )
+    else:
+        if token_final_scales.shape != (seq_len, topk):
+            raise ValueError("token_final_scales must have shape (seq_len, topk)")
+        if token_final_scales.dtype != final_scale_dtype:
+            raise ValueError("token_final_scales dtype must match final_scale_dtype")
As per coding guidelines, please add `@flashinfer_api` to public API helpers.
flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py (2)

69-82: Remove unused shape args and switch to @functools.cache.
The cache key currently includes unused shape parameters, causing unbounded cache growth; and lru_cache(maxsize=None) should be replaced per guidelines.

🐛 Suggested fix
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def _get_compiled_swiglu_kernel(
-    permuted_m: int,
-    n: int,  # This is 2*intermediate_size
-    k: int,
-    num_experts: int,
     ab_dtype_name: str,
     sf_dtype_name: str,
     c_dtype_name: str,
     sf_vec_size: int,
     mma_tiler_mn: Tuple[int, int],
     cluster_shape_mn: Tuple[int, int],
     vectorized_f32: bool,
 ):
-    gemm, _, _, _ = _get_compiled_swiglu_kernel(
-        permuted_m=permuted_m,
-        n=n,
-        k=k,
-        num_experts=num_experts,
+    gemm, _, _, _ = _get_compiled_swiglu_kernel(
         ab_dtype_name=ab_dtype,
         sf_dtype_name=sf_dtype,
         c_dtype_name=c_dtype,
         sf_vec_size=sf_vec_size,
         mma_tiler_mn=mma_tiler_mn,
         cluster_shape_mn=cluster_shape_mn,
         vectorized_f32=vectorized_f32,
     )
As per coding guidelines, use `@functools.cache` for module‑level caching.

241-272: Validate FP4 out_scale layout divisibility.
The layout assumes permuted_m and intermediate_size alignment; without checks, the buffer can be undersized.

🔧 Suggested validation
     generate_sfc = c_dtype == "float4_e2m1fn"
     if generate_sfc:
         if global_scale is None:
             raise ValueError("global_scale is required when c_dtype is 'float4_e2m1fn'")
+        if permuted_m % 128 != 0:
+            raise ValueError("permuted_m must be a multiple of 128 for FP4 output scale layout")
+        if intermediate_size % (sf_vec_size * 4) != 0:
+            raise ValueError(
+                "intermediate_size must be divisible by sf_vec_size * 4 for FP4 output scale layout"
+            )
flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py (1)

258-286: Remove unused shape params and switch to @functools.cache.
Unused shape args inflate the cache key without effect, and lru_cache(maxsize=None) should be replaced per guidelines.

🐛 Suggested fix
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def _get_compiled_kernel(
-    permuted_m: int,
-    n: int,
-    k: int,
-    num_experts: int,
     ab_dtype_name: str,
     sf_dtype_name: str,
     c_dtype_name: str,
     sf_vec_size: int,
     mma_tiler_mn: Tuple[int, int],
     cluster_shape_mn: Tuple[int, int],
 ):
-    gemm, _, _, _ = _get_compiled_kernel(
-        permuted_m=permuted_m,
-        n=n,
-        k=k,
-        num_experts=num_experts,
+    gemm, _, _, _ = _get_compiled_kernel(
         ab_dtype_name=ab_dtype,
         sf_dtype_name=sf_dtype,
         c_dtype_name=c_dtype,
         sf_vec_size=sf_vec_size,
         mma_tiler_mn=mma_tiler_mn,
         cluster_shape_mn=cluster_shape_mn,
     )
As per coding guidelines, use `@functools.cache` for module‑level caching.
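The cited guideline can be illustrated with a minimal example. `compiled_kernel` here is a placeholder for the real compile step, not flashinfer's actual function:

```python
import functools

@functools.cache  # Python 3.9+: same behavior as lru_cache(maxsize=None), clearer intent
def compiled_kernel(ab_dtype: str, sf_vec_size: int, mma_tiler_mn: tuple):
    # Placeholder for an expensive JIT compile, keyed only on tactic parameters
    # (dtypes and tile shapes), never on problem sizes like m/n/k.
    return ("kernel", ab_dtype, sf_vec_size, mma_tiler_mn)

a = compiled_kernel("float4_e2m1fn", 16, (256, 128))
b = compiled_kernel("float4_e2m1fn", 16, (256, 128))
assert a is b  # second call hits the cache, no recompile
```

Dropping the unused shape arguments from the key is what keeps the cache bounded: every distinct problem size would otherwise create a new entry holding an identical compiled kernel.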
🧹 Nitpick comments (3)
flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py (1)

76-187: Add API logging to the public tensor‑creation helper.
create_gather_gemm_tensors is exported in __all__ but lacks @flashinfer_api.

♻️ Suggested change
+@flashinfer_api
 def create_gather_gemm_tensors(
     seq_len: int,
     topk: int,
     group_m_list: List[int],
     mma_tiler_m: int,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, List[int]]:
As per coding guidelines, please add `@flashinfer_api` to public API helpers.
flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py (2)

72-76: Add API logging for create_tile_mapping.
This helper is exported in __all__ but isn’t decorated.

♻️ Suggested change
+@flashinfer_api
 def create_tile_mapping(
     group_m_list: torch.Tensor,
     mma_tiler_m: int,
     permuted_m: Optional[int] = None,
 ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]:
As per coding guidelines, please add `@flashinfer_api` to public API helpers.

145-152: Add API logging for create_scale_factor_tensor.
This exported helper should participate in standard API logging.

♻️ Suggested change
+@flashinfer_api
 def create_scale_factor_tensor(
     l: int,
     mn: int,
     k: int,
     sf_vec_size: int,
     dtype: Type[cutlass.Numeric],
 ) -> Tuple[torch.Tensor, cute.Tensor, torch.Tensor]:
As per coding guidelines, please add `@flashinfer_api` to public API helpers.

Comment on lines +471 to +481
# Create output scale tensor if needed and not provided
if generate_sfc and out_scale is None:
# Scale factor layout for output
scale_intermediate_size = intermediate_size // sf_vec_size
# MMA-compatible scale factor shape
out_scale = torch.empty(
(32, 4, permuted_m // 128, 4, scale_intermediate_size // 4, 1),
dtype=torch.uint8, # FP8 E4M3
device=a.device,
)


⚠️ Potential issue | 🟠 Major

Guard FP4 out_scale layout divisibility to prevent OOB writes.
out_scale sizing assumes permuted_m and intermediate_size are aligned; without checks the buffer can be undersized.

🔧 Suggested validation
     if generate_sfc and out_scale is None:
         # Scale factor layout for output
         scale_intermediate_size = intermediate_size // sf_vec_size
+        if permuted_m % 128 != 0:
+            raise ValueError("permuted_m must be a multiple of 128 for FP4 output scale layout")
+        if intermediate_size % (sf_vec_size * 4) != 0:
+            raise ValueError(
+                "intermediate_size must be divisible by sf_vec_size * 4 for FP4 output scale layout"
+            )
         # MMA-compatible scale factor shape
         out_scale = torch.empty(
📝 Committable suggestion


Suggested change
# Create output scale tensor if needed and not provided
if generate_sfc and out_scale is None:
# Scale factor layout for output
scale_intermediate_size = intermediate_size // sf_vec_size
# MMA-compatible scale factor shape
out_scale = torch.empty(
(32, 4, permuted_m // 128, 4, scale_intermediate_size // 4, 1),
dtype=torch.uint8, # FP8 E4M3
device=a.device,
)
# Create output scale tensor if needed and not provided
if generate_sfc and out_scale is None:
# Scale factor layout for output
scale_intermediate_size = intermediate_size // sf_vec_size
if permuted_m % 128 != 0:
raise ValueError("permuted_m must be a multiple of 128 for FP4 output scale layout")
if intermediate_size % (sf_vec_size * 4) != 0:
raise ValueError(
"intermediate_size must be divisible by sf_vec_size * 4 for FP4 output scale layout"
)
# MMA-compatible scale factor shape
out_scale = torch.empty(
(32, 4, permuted_m // 128, 4, scale_intermediate_size // 4, 1),
dtype=torch.uint8, # FP8 E4M3
device=a.device,
)
🤖 Prompt for AI Agents
In
`@flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`
around lines 471 - 481, The out_scale allocation in the
blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion code assumes permuted_m
and intermediate_size are divisible by the vectorization factors (128 and
sf_vec_size/4) and can cause OOB writes; add explicit validation before the
allocation (when generate_sfc is true and out_scale is None) to assert or raise
a clear error if permuted_m % 128 != 0 or intermediate_size % (sf_vec_size * 4)
!= 0 (or the equivalent divisibility used to compute scale_intermediate_size and
the shape dims), and adjust the calculation of scale_intermediate_size
accordingly to use integer division only after validation so out_scale has the
correct size for the rest of the code paths referencing out_scale.

Comment on lines +482 to +489
# Get SM count
if sm_count is None:
sm_count = get_num_sm(a.device)

# Compute max active clusters (cached to avoid expensive HardwareInfo queries)
max_active_clusters = get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1]
)

⚠️ Potential issue | 🟠 Major

sm_count parameter is unused.
Honor the API by clamping max_active_clusters to sm_count.

🔧 Suggested fix
-    max_active_clusters = get_max_active_clusters(
-        cluster_shape_mn[0] * cluster_shape_mn[1]
-    )
+    max_active_clusters = min(
+        get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]),
+        sm_count,
+    )
📝 Committable suggestion


Suggested change
# Get SM count
if sm_count is None:
sm_count = get_num_sm(a.device)
# Compute max active clusters (cached to avoid expensive HardwareInfo queries)
max_active_clusters = get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1]
)
# Get SM count
if sm_count is None:
sm_count = get_num_sm(a.device)
# Compute max active clusters (cached to avoid expensive HardwareInfo queries)
max_active_clusters = min(
get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]),
sm_count,
)
🤖 Prompt for AI Agents
In
`@flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`
around lines 482 - 489, The sm_count parameter is currently ignored; after
computing max_active_clusters via get_max_active_clusters(cluster_shape_mn[0] *
cluster_shape_mn[1]), clamp it to sm_count (which you already set from
get_num_sm(a.device) when None) by replacing max_active_clusters with
min(max_active_clusters, sm_count) so the API honors sm_count; reference symbols:
sm_count, get_num_sm, get_max_active_clusters, max_active_clusters,
cluster_shape_mn.

Comment on lines +416 to +423
# Get SM count
if sm_count is None:
sm_count = get_num_sm(a.device)

# Compute max active clusters (cached to avoid expensive HardwareInfo queries)
max_active_clusters = get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1]
)

⚠️ Potential issue | 🟠 Major

sm_count is currently ignored.
The API advertises sm_count, but it doesn’t affect kernel scheduling. Clamp max_active_clusters with it to respect user limits and avoid oversubscription.

🔧 Suggested fix
-    max_active_clusters = get_max_active_clusters(
-        cluster_shape_mn[0] * cluster_shape_mn[1]
-    )
+    max_active_clusters = min(
+        get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]),
+        sm_count,
+    )
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`
around lines 416 - 423, The code computes max_active_clusters via
get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]) but currently
ignores the sm_count parameter; update the logic so that after computing
max_active_clusters you clamp it using sm_count (use get_num_sm(a.device) only
if sm_count is None), e.g., determine sm_count via sm_count = sm_count or
get_num_sm(a.device), then set max_active_clusters = min(max_active_clusters,
sm_count) so kernel scheduling respects the provided sm_count limit; adjust
references around sm_count, get_num_sm, get_max_active_clusters,
max_active_clusters and cluster_shape_mn accordingly.

Comment on lines +459 to +471
# Token final scales - determine dtype and create pointer
if token_final_scales.dtype == torch.float32:
token_scales_dtype = cutlass.Float32
elif token_final_scales.dtype == torch.bfloat16:
token_scales_dtype = cutlass.BFloat16
else:
token_scales_dtype = cutlass.Float16
token_scales_ptr = make_ptr(
token_scales_dtype,
token_final_scales.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, locate and examine the file
find . -type f -name "blockscaled_contiguous_grouped_gemm_finalize_fusion.py" | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 231


🏁 Script executed:

# Check if file exists and examine the target lines
if [ -f "flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py" ]; then
  wc -l "flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py"
fi

Repository: flashinfer-ai/flashinfer

Length of output: 146


🏁 Script executed:

# Read the file around the specified lines
sed -n '450,480p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 1187


🏁 Script executed:

# Find the function that contains this code block
sed -n '400,480p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 3049


🏁 Script executed:

# Search for calls to this function/module to understand dtype usage
rg "token_final_scales" flashinfer/ -B 2 -A 2 --max-count=20

Repository: flashinfer-ai/flashinfer

Length of output: 27246


🏁 Script executed:

# Find the function containing lines 459-471
sed -n '380,475p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 3404


🏁 Script executed:

# Get full function definition starting from earlier lines
sed -n '340,480p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 5175


🏁 Script executed:

# Search for function definition pattern around line 459
grep -n "^def " flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 212


🏁 Script executed:

# Check for torch.float16 usage in other parts of the codebase
rg "torch\.float16" flashinfer/cute_dsl/ -B 2 -A 2 | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 2597


🏁 Script executed:

# Check if similar dtype validation exists elsewhere
rg "raise ValueError.*dtype" flashinfer/ -B 1 -A 1 | grep -A 2 -B 2 "float" | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 1939


🏁 Script executed:

# Check the function signature and any decorators
sed -n '258,280p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 909


🏁 Script executed:

# Search for any decorators or imports related to caching/API decorators in this file
head -50 flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Repository: flashinfer-ai/flashinfer

Length of output: 2196


🏁 Script executed:

# Check if `@functools.cache` or `@flashinfer_api` is used elsewhere in flashinfer/**/*.py files
rg "@functools\.cache|@flashinfer_api" flashinfer/ --type py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 858


🏁 Script executed:

# Check for validation patterns similar to what's suggested
rg "if.*\.dtype ==" flashinfer/cute_dsl/ --type py -B 1 -A 4 | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 1935


🏁 Script executed:

# Check the complete function to understand validation context
sed -n '258,550p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py | tail -50

Repository: flashinfer-ai/flashinfer

Length of output: 1452


🏁 Script executed:

# Look at how token_final_scales is used downstream to understand if defaulting to Float16 could cause issues
rg "token_final_scales" flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py -A 5 -B 2 | grep -A 10 "dtype"

Repository: flashinfer-ai/flashinfer

Length of output: 2628


Validate token_final_scales dtype before selecting Cutlass type.

The code currently defaults to Float16 for any unrecognized dtype, which silently misinterprets data (e.g., if torch.float64 is passed). Additionally, torch.float16 is documented as supported but lacks explicit handling; it only works because it falls through to the else clause. Gate supported dtypes explicitly and reject unsupported ones.

🔧 Suggested fix
-    if token_final_scales.dtype == torch.float32:
-        token_scales_dtype = cutlass.Float32
-    elif token_final_scales.dtype == torch.bfloat16:
-        token_scales_dtype = cutlass.BFloat16
-    else:
-        token_scales_dtype = cutlass.Float16
+    if token_final_scales.dtype == torch.float32:
+        token_scales_dtype = cutlass.Float32
+    elif token_final_scales.dtype == torch.bfloat16:
+        token_scales_dtype = cutlass.BFloat16
+    elif token_final_scales.dtype == torch.float16:
+        token_scales_dtype = cutlass.Float16
+    else:
+        raise ValueError(
+            "token_final_scales must be float16, bfloat16, or float32"
+        )
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`
around lines 459 - 471, Validate token_final_scales.dtype explicitly before
mapping to Cutlass types: handle torch.float32, torch.bfloat16, and
torch.float16 (set token_scales_dtype to cutlass.Float32, cutlass.BFloat16,
cutlass.Float16 respectively) and raise a clear error if any other dtype is
passed; then call make_ptr(token_scales_dtype, token_final_scales.data_ptr(),
cute.AddressSpace.gmem, assumed_align=16). Locate symbols token_final_scales,
token_scales_dtype, make_ptr, and the cutlass type mappings to implement this
guard and error path.

Comment on lines +120 to +130
# Apply padding if requested (for CUDA graph support)
if permuted_m is not None:
if permuted_m < valid_m:
raise ValueError(
f"permuted_m ({permuted_m}) must be >= valid_m ({valid_m}). "
f"Cannot pad to a smaller size."
)
if permuted_m > valid_m:
num_padding_tiles = (permuted_m - valid_m) // mma_tiler_m
# Pad with invalid index (these tiles won't be accessed)
tile_idx_to_group_idx_list.extend([int(-2e9)] * num_padding_tiles)

⚠️ Potential issue | 🟠 Major

Validate permuted_m padding aligns to mma_tiler_m.
If (permuted_m - valid_m) isn’t a multiple of the tile size, the padded mapping length won’t match permuted_m, risking out‑of‑bounds indexing.

🔧 Suggested validation
     if permuted_m is not None:
         if permuted_m < valid_m:
             raise ValueError(
                 f"permuted_m ({permuted_m}) must be >= valid_m ({valid_m}). "
                 f"Cannot pad to a smaller size."
             )
+        if (permuted_m - valid_m) % mma_tiler_m != 0:
+            raise ValueError(
+                "permuted_m padding must be a multiple of mma_tiler_m"
+            )
         if permuted_m > valid_m:
             num_padding_tiles = (permuted_m - valid_m) // mma_tiler_m
🧰 Tools
🪛 Ruff (0.14.13)

123-126: Avoid specifying long messages outside the exception class

(TRY003)
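The Ruff TRY003 note above can be addressed with a dedicated exception type that builds its own message, so the `raise` site stays short. A minimal sketch; the class and function names here are illustrative, not symbols from the codebase:

```python
# Ruff TRY003 discourages long inline messages in `raise ValueError(...)`;
# one idiomatic fix is an exception subclass that formats its own message.
class PaddingAlignmentError(ValueError):
    def __init__(self, permuted_m: int, valid_m: int, tile: int):
        super().__init__(
            f"permuted_m - valid_m ({permuted_m - valid_m}) must be a "
            f"multiple of mma_tiler_m ({tile})"
        )


def check_padding(permuted_m: int, valid_m: int, mma_tiler_m: int) -> int:
    """Validate the padding alignment and return the number of padding tiles."""
    if (permuted_m - valid_m) % mma_tiler_m != 0:
        raise PaddingAlignmentError(permuted_m, valid_m, mma_tiler_m)
    return (permuted_m - valid_m) // mma_tiler_m
```

The call site then raises with no message argument at all, satisfying the lint rule while keeping the diagnostic detail.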

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py` around lines 120
- 130, The code allows padding when permuted_m > valid_m but doesn't ensure the
padding size is a multiple of the tile size (mma_tiler_m), which can cause
mismatched lengths and OOB access; inside the block that handles permuted_m >
valid_m (the same place where num_padding_tiles is computed and
tile_idx_to_group_idx_list is extended), validate that (permuted_m - valid_m) %
mma_tiler_m == 0 and if not raise a ValueError explaining that permuted_m -
valid_m must be divisible by mma_tiler_m; keep the existing behavior of
computing num_padding_tiles = (permuted_m - valid_m) // mma_tiler_m and
extending tile_idx_to_group_idx_list with the padding only after this check.

Comment on lines +416 to +437
# Get SM count
if sm_count is None:
sm_count = get_num_sm(a.device)

# Get or compile the kernel
gemm, _, _, _ = _get_compiled_kernel(
permuted_m=permuted_m,
n=n,
k=k,
num_experts=num_experts,
ab_dtype_name=ab_dtype,
sf_dtype_name=sf_dtype,
c_dtype_name=c_dtype,
sf_vec_size=sf_vec_size,
mma_tiler_mn=mma_tiler_mn,
cluster_shape_mn=cluster_shape_mn,
)

# Compute max active clusters (cached to avoid expensive HardwareInfo queries)
max_active_clusters = get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1]
)

⚠️ Potential issue | 🟠 Major

sm_count limit isn’t applied to max_active_clusters.
Respect the API contract by clamping the computed value.

🔧 Suggested fix
-    max_active_clusters = get_max_active_clusters(
-        cluster_shape_mn[0] * cluster_shape_mn[1]
-    )
+    max_active_clusters = min(
+        get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]),
+        sm_count,
+    )
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py` around lines 416
- 437, The computed max_active_clusters from
get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]) is not being
clamped by the sm_count limit; after computing max_active_clusters, clamp it
with sm_count (e.g., max_active_clusters = min(max_active_clusters, sm_count))
so the API contract is respected. Ensure sm_count is defined (it may be set via
get_num_sm(a.device)) before the clamp and apply this change right after the
call to get_max_active_clusters in the block that includes sm_count, get_num_sm,
and cluster_shape_mn.

Comment on lines +84 to +110
# Cache for HardwareInfo - it's expensive to create on every call
_hardware_info_cache: "cutlass.utils.HardwareInfo | None" = None


def get_hardware_info() -> "cutlass.utils.HardwareInfo":
"""Get cached HardwareInfo singleton.

HardwareInfo queries CUDA device capabilities, which can be expensive.
This function caches the singleton to avoid repeated queries.
"""
global _hardware_info_cache
if _hardware_info_cache is None:
_hardware_info_cache = cutlass.utils.HardwareInfo()
return _hardware_info_cache


@functools.cache
def get_max_active_clusters(cluster_size: int) -> int:
"""Get max active clusters for a given cluster size (cached).

Args:
cluster_size: Product of cluster_shape_mn dimensions.

Returns:
Maximum number of active clusters supported by hardware.
"""
return get_hardware_info().get_max_active_clusters(cluster_size)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, verify the file exists and check its structure
find . -name "utils.py" -path "*/cute_dsl/*" -type f

Repository: flashinfer-ai/flashinfer

Length of output: 192


🏁 Script executed:

# Check the actual content around lines 84-110
if [ -f "flashinfer/cute_dsl/utils.py" ]; then
  wc -l flashinfer/cute_dsl/utils.py
  echo "=== Lines 84-110 ===" 
  sed -n '84,110p' flashinfer/cute_dsl/utils.py
  echo "=== Full context (wider range) ==="
  sed -n '1,120p' flashinfer/cute_dsl/utils.py
fi

Repository: flashinfer-ai/flashinfer

Length of output: 4907


🏁 Script executed:

# Check imports at the top of the file
head -50 flashinfer/cute_dsl/utils.py

Repository: flashinfer-ai/flashinfer

Length of output: 1545


🏁 Script executed:

# Search for device-related handling in the entire flashinfer directory
rg "cuda\.current_device|torch\.cuda\.device|device.*index" flashinfer/ -A 2 -B 2 | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 5755


🏁 Script executed:

# Check if HardwareInfo is documented or used elsewhere
rg "HardwareInfo" flashinfer/ --type py | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 1225


🏁 Script executed:

# Search for usages of get_max_active_clusters function
rg "get_max_active_clusters" flashinfer/ -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 5382


🏁 Script executed:

# Search for usages of get_hardware_info function
rg "get_hardware_info" flashinfer/ -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 606


🏁 Script executed:

# Check if there are any multi-device tests or examples
rg "cuda\.device_count|multi.*device|device.*0.*1" flashinfer/ --type py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 2114


🏁 Script executed:

# Look at how cutlass HardwareInfo is used in the blockscaled files
grep -A 10 "get_max_active_clusters" flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 635


🏁 Script executed:

# Check if there's any device switching context around the usages
rg -B 5 "get_max_active_clusters" flashinfer/cute_dsl/blockscaled_gemm.py | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 647


🏁 Script executed:

# Check if blockscaled_gemm.py has any device-specific initialization
grep -n "device\|cuda" flashinfer/cute_dsl/blockscaled_gemm.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 834




🌐 Web query:

CUTLASS cutlass.utils.HardwareInfo device-specific GPU capabilities

💡 Result:

cutlass.utils.HardwareInfo is a small helper class in CUTLASS’s Python CuTe DSL that queries device-specific GPU capabilities at runtime. It exposes methods to obtain/inspect low-level hardware properties and driver compatibility used by CUTLASS kernels, for example:

  • get_device_multiprocessor_count(): number of SMs (multiprocessors). [1]
  • get_l2_cache_size_in_bytes(): L2 cache size. [1]
  • get_max_active_clusters(cluster_size): compute maximum active clusters for a given cluster size. [1]
  • Utility/driver-check helpers (internal): _checkCudaErrors, _cuda_driver_version_ge / _lt, _cudaGetErrorEnum, plus small host/device helper functions used to probe properties. [1]
  • Related utility in cutlass.utils: get_smem_capacity_in_bytes(compute_capability) for shared-memory sizing by compute capability. [1]

For implementation and repository context see the CUTLASS source/docs (CuTe DSL utils / HardwareInfo). [1][2]

Sources:

  1. CUTLASS cutlass.utils documentation (HardwareInfo, get_smem_capacity_in_bytes). [docs.nvidia.com CUTLASS cutlass.utils]. [2]
  2. NVIDIA CUTLASS GitHub repository (project context and supported architectures). [github.com/NVIDIA/cutlass].

🏁 Script executed:

# Check the actual context around "on current device" comments
grep -B 5 -A 5 "on current device" flashinfer/cute_dsl/blockscaled_gemm.py

Repository: flashinfer-ai/flashinfer

Length of output: 602


🏁 Script executed:

# Look for any tests or examples that use multiple devices
find flashinfer/tests -name "*.py" -o -name "*.py" | xargs grep -l "device.*0.*device.*1\|cuda:0.*cuda:1" 2>/dev/null | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 119


🏁 Script executed:

# Check if there's a specific device passed anywhere in the call stack
rg "MaskedBatchedMatmulCuteDSL|get_max_active_clusters" flashinfer/ -B 3 -A 3 | grep -i "device\|cuda" | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 434


Include device identity in the cache key to support multi-device setups.
The global singleton HardwareInfo instance and @functools.cache keyed only by cluster_size will return incorrect values when running on different GPUs with different capabilities, as get_max_active_clusters has no device parameter. Compare with get_num_sm(device) in the same file, which correctly handles device-specific caching. For multi-device support, either add a device parameter to both get_hardware_info() and get_max_active_clusters(), or maintain a per-device cache dictionary.

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/utils.py` around lines 84 - 110, The cached HardwareInfo
and get_max_active_clusters are device-agnostic and will return wrong values on
multi-GPU machines; update caching to be device-aware: change get_hardware_info
to accept an optional device identifier (or obtain current device internally)
and replace the single _hardware_info_cache with a per-device cache (e.g., dict
keyed by device id) for the HardwareInfo singleton; also update
get_max_active_clusters to include the device id in its cache key (remove or
replace the `@functools.cache` usage with a device-keyed cache or make the
function accept a device parameter so caching is per-device). Ensure you
reference and update the symbols _hardware_info_cache, get_hardware_info,
get_max_active_clusters, and the use of `@functools.cache` accordingly so each GPU
gets correct, device-specific values.
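A minimal sketch of the device-keyed caching described above. Since `cutlass.utils.HardwareInfo` queries the CUDA driver, a stub with a placeholder formula stands in for it here; real code would construct the actual class per device:

```python
import functools


# Hypothetical stand-in for cutlass.utils.HardwareInfo, which needs a live
# CUDA context; the formula below is a placeholder, not a real driver query.
class _HardwareInfoStub:
    def __init__(self, device_id: int):
        self.device_id = device_id

    def get_max_active_clusters(self, cluster_size: int) -> int:
        return 148 // cluster_size  # placeholder for the hardware query


# Per-device cache: one HardwareInfo instance per GPU index.
_hardware_info_cache: dict = {}


def get_hardware_info(device_id: int = 0) -> _HardwareInfoStub:
    if device_id not in _hardware_info_cache:
        _hardware_info_cache[device_id] = _HardwareInfoStub(device_id)
    return _hardware_info_cache[device_id]


@functools.cache
def get_max_active_clusters(cluster_size: int, device_id: int = 0) -> int:
    # device_id is part of the functools.cache key, so different GPUs
    # get separately cached answers instead of sharing one value.
    return get_hardware_info(device_id).get_max_active_clusters(cluster_size)
```

Callers would pass the index of the tensor's device (e.g. `a.device.index`) so both the singleton and the memoized query become per-GPU.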

@nv-yunzheq nv-yunzheq requested a review from Anerudhan as a code owner January 23, 2026 00:41
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@benchmarks/bench_moe_deepseek.py`:
- Around line 245-267: The benchmark is converting tensors inside the run()
closure which inflates timing; precompute the converted tensors once before
defining or entering run(): create ti_int = ti.to(torch.int) and w1_long =
inputs["w1_fp4"].contiguous().view(torch.long) and w2_long =
inputs["w2_fp4"].contiguous().view(torch.long) (and any other .to/.view
conversions like bias.float() or similar) and then call fused_topk_deepseek(...)
and cutlass_fused_moe(hidden_fp4, ti_int, tv, w1_long, w2_long, torch.bfloat16,
quant_scales=quant_scales, input_sf=input_sf, output=output) from inside run().
- Around line 693-697: The current parsing of args.num_tokens into tokens fails
when --num-tokens "" is passed because "".split(",") yields [''] and int('')
raises ValueError; update the tokens assignment logic (where tokens is computed
and args.num_tokens and TOKEN_COUNTS are referenced) to first validate/trim
args.num_tokens and skip empty segments before int conversion (e.g., check
args.num_tokens is not empty/whitespace and filter split parts with x.strip()
before calling int), or fall back to TOKEN_COUNTS; optionally add a clear error
message if parsing still fails.
🧹 Nitpick comments (7)
tests/moe/test_cute_dsl_fused_moe.py (3)

36-50: Use flashinfer.utils.is_sm100a_supported() instead of custom GPU check.

Per coding guidelines, test implementations should use flashinfer.utils functions for GPU architecture checks to ensure consistency across the test suite.

♻️ Suggested refactor
-def is_blackwell():
-    """Check if running on Blackwell GPU (SM100+)."""
-    if not torch.cuda.is_available():
-        return False
-    props = torch.cuda.get_device_properties(0)
-    return props.major >= 10
+from flashinfer.utils import is_sm100a_supported


 # Skip decorators
 cute_dsl_available = pytest.mark.skipif(
     not is_cute_dsl_available(), reason="CuteDSL not available"
 )
 blackwell_required = pytest.mark.skipif(
-    not is_blackwell(), reason="Requires Blackwell GPU (SM100+)"
+    not is_sm100a_supported(), reason="Requires Blackwell GPU (SM100+)"
 )

186-194: Set CUDA seed for reproducible GPU tensor generation.

torch.manual_seed() only affects CPU RNG. Since tensors are created directly on CUDA, also set torch.cuda.manual_seed(seed) for deterministic test behavior.

♻️ Suggested fix
     torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     sf_vec_size = 16

469-485: Add num_local_experts parameter for consistency with other tests.

The other test methods explicitly pass num_local_experts. While the parameter has a default, including it here improves consistency and makes the test configuration explicit.

♻️ Suggested fix
         with autotune(True):
             result = cute_dsl_fused_moe_nvfp4(
                 x=tensors["x"],
                 x_sf=tensors["x_sf"],
                 token_selected_experts=tensors["token_selected_experts"],
                 token_final_scales=tensors["token_final_scales"],
                 w1_weight=tensors["w1_weight"],
                 w1_weight_sf=tensors["w1_weight_sf"],
                 w1_alpha=tensors["w1_alpha"],
                 fc2_input_scale=tensors["fc2_input_scale"],
                 w2_weight=tensors["w2_weight"],
                 w2_weight_sf=tensors["w2_weight_sf"],
                 w2_alpha=tensors["w2_alpha"],
                 num_experts=num_experts,
                 top_k=top_k,
+                num_local_experts=num_experts,
             )
benchmarks/bench_moe_deepseek.py (4)

46-53: Consider adding a comment explaining the bpe calculation.

The bpe = 0.5 + 1/16 value represents bytes per element for FP4 format (0.5 bytes for 4-bit value + 0.0625 bytes for scale factor overhead), but this isn't immediately obvious to readers.

Suggested clarification
 def calc_bw(n, ms):
+    # bpe: 0.5 bytes for FP4 data + 1/16 bytes for block scale factor (1 scale per 16 elements)
     bpe = 0.5 + 1 / 16

56-63: Missing validation for tensor dimension divisibility.

The interleave function assumes M is divisible by gs * 2 (default 128). If this precondition fails, the error from view() will be cryptic. Consider adding an assertion.

Suggested improvement
 def interleave(x, gs=64):
     M, K = x.shape[-2], x.shape[-1]
+    assert M % (gs * 2) == 0, f"M ({M}) must be divisible by {gs * 2}"
     return (
         x.view(*x.shape[:-2], 2, M // (gs * 2), gs, K)

66-70: Missing CUDA random seed for full reproducibility.

torch.manual_seed(42) only sets the CPU RNG. For reproducible GPU tensor generation (e.g., torch.randn(..., device="cuda")), you should also set the CUDA seed.

Suggested fix
 def create_inputs(n, dev="cuda"):
     """Create inputs for all backends (CuteDSL, CUTLASS, TRTLLM)."""
     from flashinfer.fp4_quantization import fp4_quantize

     torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
     sv = 16

375-584: Consider extracting shared input preparation logic to reduce duplication.

The run_autotune function duplicates significant preparation logic from bench_cute_dsl, bench_cutlass, and bench_trtllm. The prep() (lines 513-519) and shuf() (lines 528-537) helpers are also duplicated from bench_trtllm (lines 297-303 and 312-321).

For a WIP benchmark script this is acceptable, but extracting common preparation into shared helper functions would improve maintainability.

Additionally, lines 486-489 have the same tensor conversion overhead issue noted for bench_cutlass:

ti.to(torch.int),  # Line 486
inputs["w1_fp4"].contiguous().view(torch.long),  # Line 488
inputs["w2_fp4"].contiguous().view(torch.long),  # Line 489

These should be pre-computed outside the autotune loop for accurate profiling.

Comment on lines 693 to 697
tokens = (
[int(x) for x in args.num_tokens.split(",")]
if args.num_tokens
else TOKEN_COUNTS
)

⚠️ Potential issue | 🟡 Minor

Handle edge case for empty --num-tokens argument.

If --num-tokens "" is passed, "".split(",") returns [''], and int('') raises a ValueError. Consider validating the input or using a more robust parsing approach.

Suggested fix
     tokens = (
-        [int(x) for x in args.num_tokens.split(",")]
-        if args.num_tokens
+        [int(x.strip()) for x in args.num_tokens.split(",") if x.strip()]
+        if args.num_tokens and args.num_tokens.strip()
         else TOKEN_COUNTS
     )
+    if not tokens:
+        tokens = TOKEN_COUNTS
🤖 Prompt for AI Agents
In `@benchmarks/bench_moe_deepseek.py` around lines 693 - 697, The current parsing
of args.num_tokens into tokens fails when --num-tokens "" is passed because
"".split(",") yields [''] and int('') raises ValueError; update the tokens
assignment logic (where tokens is computed and args.num_tokens and TOKEN_COUNTS
are referenced) to first validate/trim args.num_tokens and skip empty segments
before int conversion (e.g., check args.num_tokens is not empty/whitespace and
filter split parts with x.strip() before calling int), or fall back to
TOKEN_COUNTS; optionally add a clear error message if parsing still fails.

…-ai#2445)

<!-- .github/pull_request_template.md -->

## 📌 Description

bugfix to flashinfer-ai#2093, the fundamental issue is we should not write to
`jit_env.FLASHINFER_CSRC_DIR` (which might be read-only) for fused-moe
module, instead we should use `FLASHINFER_GEN_SRC_DIR` which is supposed
to be writable.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* JIT kernel generation now correctly uses user-writable cache
directories instead of package directories, resolving compatibility
issues in post-installation and read-only environments.

* **Documentation**
* Updated JIT directory rules clarifying which directories are writable
versus read-only. Added details on GPU auto-detection and CUDA
architecture environment variable controls.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/logging.h (1)

33-39: ⚠️ Potential issue | 🟠 Major

Avoid global spdlog::set_level side effects.

spdlog::set_level(lvl) modifies the level for all registered loggers in the process, not just the "flashinfer" logger. This creates an unintended global side effect that can alter application logging behavior. Set the level on the logger instance instead.

Suggested fix (scope level change to library logger only)
 inline void set_log_level(spdlog::level::level_enum lvl) {
   auto fmt = "[%Y-%m-%d %H:%M:%S.%f] [%n] [%^%l%$] %v";
   auto console_sink = std::make_shared<spdlog::sinks::stdout_color_sink_mt>();
   console_sink->set_pattern(fmt);
   console_sink->set_level(lvl);
-  spdlog::set_default_logger(std::make_shared<spdlog::logger>("flashinfer", console_sink));
-  spdlog::set_level(lvl);
+  auto logger = std::make_shared<spdlog::logger>("flashinfer", console_sink);
+  logger->set_level(lvl);
+  spdlog::set_default_logger(logger);
 }
🧹 Nitpick comments (1)
flashinfer/jit/fused_moe.py (1)

174-174: Consider: rglob may pick up stale generated files from previous builds.

If the generation logic changes and produces fewer or differently-named .generated.cu files, stale files from a previous run could still be included in the build, potentially causing compilation errors or linking unexpected kernels.

Consider clearing the output directory before regenerating, or using a more deterministic approach that returns the exact list of generated files from generate_gemm_operations().
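Both options can be combined in one helper: delete stale files first, then return the exact paths written. A minimal sketch under stated assumptions; the function name, the `sources` mapping, and the `.generated.cu` handling are illustrative, not the actual `generate_gemm_operations()` API:

```python
import pathlib
import tempfile


def regenerate(out_dir: pathlib.Path, sources: dict) -> list:
    """Write fresh *.generated.cu files, removing stale ones first.

    Clearing leftovers before writing means a later rglob over out_dir
    cannot pick up files from a previous, differently-shaped build.
    """
    for stale in out_dir.rglob("*.generated.cu"):
        stale.unlink()
    written = []
    for name, body in sources.items():
        path = out_dir / f"{name}.generated.cu"
        path.write_text(body)
        written.append(path)
    # Returning the exact list also lets the build consume these paths
    # directly instead of re-scanning the directory.
    return written
```
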

@nv-yunzheq
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !266 has been updated with latest changes, and the CI pipeline #43116439 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #43116439: canceled

@nv-yunzheq
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !266 has been updated with latest changes, and the CI pipeline #43121276 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #43121276: canceled

@yzh119 yzh119 left a comment


Failed UTs should be fixed by #2468 , LGTM otherwise.

@yzh119
Collaborator

yzh119 commented Feb 3, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !266 has been updated with latest changes, and the CI pipeline #43210969 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #43210969: canceled

@nv-yunzheq
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !266 has been updated with latest changes, and the CI pipeline #43303418 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #43303418: 12/20 passed

@nv-yunzheq
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !266 has been updated with latest changes, and the CI pipeline #43447420 is currently running. I'll report back once the pipeline job completes.

@nv-yunzheq nv-yunzheq enabled auto-merge (squash) February 6, 2026 21:56
@nv-yunzheq nv-yunzheq disabled auto-merge February 6, 2026 21:56
@nv-yunzheq nv-yunzheq merged commit 99562e5 into flashinfer-ai:main Feb 7, 2026
28 checks passed
yzh119 pushed a commit that referenced this pull request Feb 16, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

This PR is a follow-up to PR #2398, integrating [TRTLLM PR
10987](NVIDIA/TensorRT-LLM#10987). It uses TMA.RED
to improve effective memory bandwidth.

Performance data (tested on GB200):

Tokens | CuteDSL (main) ms | CuteDSL (TMA.RED) ms | TRTLLM gen ms | CUTLASS ms | Winner | CuteDSL Speedup (main/TMA.RED)
-- | -- | -- | -- | -- | -- | --
1 | 0.064 | 0.064 | 0.053 | 0.099 | TRTLLM | 1.000x
2 | 0.077 | 0.077 | 0.063 | 0.107 | TRTLLM | 1.000x
4 | 0.096 | 0.096 | 0.085 | 0.131 | TRTLLM | 1.000x
8 | 0.096 | 0.096 | 0.091 | 0.131 | TRTLLM | 1.000x
16 | 0.101 | 0.102 | 0.103 | 0.138 | CuteDSL | 0.990x
32 | 0.114 | 0.114 | 0.142 | 0.152 | CuteDSL | 1.000x
64 | 0.122 | 0.122 | 0.183 | 0.163 | CuteDSL | 1.000x
128 | 0.133 | 0.132 | 0.173 | 0.220 | CuteDSL | 1.008x
256 | 0.142 | 0.138 | 0.220 | 0.251 | CuteDSL | 1.029x
512 | 0.190 | 0.183 | 0.271 | 0.333 | CuteDSL | 1.038x
1024 | 0.286 | 0.278 | 0.576 | 0.482 | CuteDSL | 1.029x
2048 | 0.472 | 0.461 | 0.555 | 0.723 | CuteDSL | 1.024x
4096 | 0.855 | 0.824 | 0.873 | 1.278 | CuteDSL | 1.038x
8192 | 1.764 | 1.713 | 1.653 | 2.383 | TRTLLM | 1.030x



## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Release Notes

* **New Features**
* Introduced block-reduction optimization in MOE finalization kernels
for improved performance on latest hardware.
* Added support for block-wise reduction operations across multiple data
types (BF16, FP32, FP16).

* **Performance**
* Optimized GPU memory utilization by reducing unnecessary cross-device
data transfers during computation.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
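For context, the MoE finalization step these kernels optimize is, mathematically, a routing-weighted sum of each token's top-k expert outputs; TMA.RED moves that accumulation into the memory system as outputs are written back. A NumPy sketch of the reference math follows. Shapes and names here are illustrative, not the FlashInfer API:

```python
import numpy as np

def moe_finalize(expert_out, topk_weights):
    """Reference MoE finalization: weighted sum of top-k expert outputs.

    expert_out:   [num_tokens, top_k, hidden] partial outputs, one per
                  selected expert, already gathered per token.
    topk_weights: [num_tokens, top_k] routing weights.
    """
    # The TMA.RED path performs this reduction in-flight during the
    # global-memory write; here it is a plain einsum over the top_k axis.
    return np.einsum("tkh,tk->th", expert_out, topk_weights)

num_tokens, top_k, hidden = 4, 8, 16
rng = np.random.default_rng(0)
expert_out = rng.standard_normal((num_tokens, top_k, hidden)).astype(np.float32)
topk_weights = rng.random((num_tokens, top_k)).astype(np.float32)
out = moe_finalize(expert_out, topk_weights)
```

Because each token's reduction is independent and associative, the kernel is free to accumulate partial expert outputs in any order, which is what makes a hardware-side reduction like TMA.RED applicable.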