feat: cuteDSL fp4 moe for better DSR1 performance. #2398
nv-yunzheq merged 33 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

This PR introduces a comprehensive Mixture-of-Experts (MoE) pipeline for DeepSeek-V3 on Blackwell GPUs (SM100+), spanning CUDA kernels, TVM FFI bindings, Python CuTe-DSL wrappers, and integrated testing/benchmarking. It adds permute/unpermute/sort/activation kernels with multi-precision support (FP16, BF16, FP8, FP4) and high-level APIs for fused MoE GEMM operations with persistent tile scheduling and CUDA graph compatibility.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Python as Python/PyTorch
    participant Permute as moe_permute Kernel
    participant Gemm1 as GEMM1 + SwiGLU
    participant Sort as moe_sort (Routing)
    participant Gather as Gather Kernel
    participant Gemm2 as GEMM2 + Finalize
    participant Unpermute as moe_unpermute Kernel
    participant Output as Output Buffer

    Python->>Sort: Route tokens to experts<br/>(token_selected_experts, scores)
    Sort->>Permute: Compute routing mappings<br/>(tile indices, permutation)
    Python->>Permute: Call moe_permute<br/>(input → permuted layout)
    Permute->>Gemm1: Permuted tokens<br/>(FP4 quantized)
    Gemm1->>Gemm1: Per-expert GEMM + SwiGLU<br/>(fused activation)
    Gemm1->>Gather: Intermediate outputs
    Gather->>Gemm2: Gather by expert<br/>(without re-permutation)
    Gemm2->>Gemm2: Per-expert GEMM2<br/>(with finalize epilogue)
    Gemm2->>Unpermute: Expert outputs<br/>(FP4 → FP16/BF16)
    Python->>Unpermute: Call moe_unpermute<br/>(apply top-k scaling)
    Unpermute->>Output: Scatter to original<br/>token positions
    Output->>Python: Final MoE output<br/>(combined experts)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: 2 passed, 1 failed (warning)
Summary of Changes

This pull request significantly enhances FlashInfer's capabilities by adding a highly optimized, fused Mixture-of-Experts (MoE) implementation. By integrating CuTe-DSL kernels from TensorRT-LLM, the new pipeline leverages advanced GPU features to deliver efficient FP4 computation, reducing memory bandwidth and improving throughput for large language models. The focus is on providing a complete, end-to-end MoE solution with built-in auto-tuning for optimal performance on Blackwell GPUs.
Code Review
This pull request introduces a significant new feature: fused Mixture of Experts (MoE) kernels using CuteDSL for FP4 on Blackwell architecture. The implementation is adapted from TensorRT-LLM and includes several advanced techniques like warp specialization and persistent scheduling. The changes are extensive, adding new CUDA kernels, Python bindings, and high-level APIs with auto-tuning support.
My review focused on the integration and correctness of the new components. I've identified a few issues:

- A missing validation check in `moe_sort` that could lead to runtime errors if `tile_tokens_dim` is not a power of 2.
- A potential bug in a `wrapper` method within a GEMM kernel due to incorrect integer division.
- Some code duplication in utility files that should be refactored.

Overall, this is a substantial contribution that brings high-performance MoE capabilities to FlashInfer. The code is well-structured, though complex due to the nature of CuteDSL. Addressing the identified issues will improve the robustness and maintainability of this new feature.
```python
def moe_sort(
    token_selected_experts: torch.Tensor,
    token_final_scales: torch.Tensor,
    num_experts: int,
    top_k: int,
    local_expert_offset: int = 0,
    num_local_experts: Optional[int] = None,
    tile_tokens_dim: int = 128,
    enable_pdl: bool = False,
) -> Tuple[
```
The `moe_sort` function passes `tile_tokens_dim` to a CUDA kernel that expects it to be a power of two. The underlying C++ code uses a `computeLog2` function that returns -1 for non-power-of-two inputs, which can lead to undefined behavior or cryptic errors in the `routingDeepSeek` kernel. It's crucial to validate this parameter in the Python wrapper to prevent such issues.

Please add an assertion at the beginning of the function body:

```python
assert (tile_tokens_dim > 0) and ((tile_tokens_dim & (tile_tokens_dim - 1)) == 0), (
    "tile_tokens_dim must be a power of 2"
)
```

```python
    epilogue_op: cutlass.Constexpr = lambda x: x,
):
    scale_k = k // scaling_vector_size
    num_tiles = m // tile_size
```
The calculation of `num_tiles` using integer division `m // tile_size` is incorrect if `m` is not perfectly divisible by `tile_size`. This will result in an undersized `tile_idx_to_group_idx` tensor, leading to out-of-bounds access when the kernel scheduler processes the last partial tile. Use ceiling division to ensure the number of tiles is calculated correctly:

```diff
-    num_tiles = m // tile_size
+    num_tiles = (m + tile_size - 1) // tile_size
```
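To illustrate why the ceiling division matters here, a standalone sketch (plain Python, not the kernel code itself):

```python
def ceil_div(a: int, b: int) -> int:
    """Ceiling division: smallest integer >= a / b for positive a, b."""
    return (a + b - 1) // b

# With m = 130 tokens and tile_size = 128, floor division under-counts
# the tiles and the last partial tile would never get a scheduler slot.
m, tile_size = 130, 128
print(m // tile_size)          # floor division: 1 tile, misses the partial tile
print(ceil_div(m, tile_size))  # ceiling division: 2 tiles, covers all rows
```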
````python
        self._c_pointer = None
        assert int(self._pointer) % self._assumed_align == 0, (
            f"pointer must be {self._assumed_align} bytes aligned"
        )

    def size_in_bytes(self) -> int:
        return ctypes.sizeof(ctypes.c_void_p(int(self._pointer)))

    def __get_mlir_types__(self):
        return [self.mlir_type]

    def __c_pointers__(self):
        if self._c_pointer is None:
            self._desc = ctypes.c_void_p(int(self._pointer))
            self._c_pointer = ctypes.addressof(self._desc)
        return [self._c_pointer]

    def __new_from_mlir_values__(self, values):
        assert len(values) == 1
        return values[0]

    # Move mlir Type out of __init__ to decouple with mlir Context
    @property
    def mlir_type(self) -> ir.Type:
        return _cute_ir.PtrType.get(
            self._dtype.mlir_type, self._addr_space, self._assumed_align
        )

    @property
    def dtype(self) -> Type[Numeric]:
        return self._dtype

    @property
    def memspace(self):
        return self._addr_space

    def align(self, min_align: int, *, loc=None, ip=None) -> Pointer:
        raise NotImplementedError("align is not supported in runtime")

    def verify(self, expected_py_type):
        if expected_py_type is Pointer or (
            isinstance(expected_py_type, ir.Value) and expected_py_type.ty is Pointer
        ):
            return True
        return False

    def __str__(self) -> str:
        return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>"

    def __repr__(self):
        return self.__str__()


def make_ptr(
    dtype: Type[Numeric],
    value: Union[int, ctypes._Pointer],
    mem_space: AddressSpace = AddressSpace.generic,
    assumed_align=None,
) -> Pointer:
    """Creates a pointer from a memory address.

    Args:
        dtype (Type[Numeric]): Data type of the pointer elements.
        value (Union[int, ctypes._Pointer]): Memory address as an integer or ctypes pointer.
        mem_space (AddressSpace, optional): Memory address space. Defaults to AddressSpace.generic.
        assumed_align (int, optional): Alignment in bytes. Defaults to None.

    Returns:
        Pointer: A pointer object.

    Example:
        ```python
        import numpy as np
        import ctypes
        from cutlass import Float32
        from cutlass.cute.runtime import make_ptr

        # Create a numpy array
        a = np.random.randn(16, 32).astype(np.float32)
        # Get pointer address as ctypes pointer
        ptr_address = a.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
        # Create pointer from address
        y = make_ptr(cutlass.Float32, ptr_address)
        ```
    """
    # check if value is int or ctypes.POINTER
    if isinstance(value, int):
        address_value = value
    elif isinstance(value, ctypes._Pointer):
        # get address value
        address_value = ctypes.cast(value, ctypes.c_void_p).value
        assert address_value is not None, "Pointer address is None"
    else:
        raise TypeError(
            f"Expect int or ctypes.POINTER for value but got {type(value)=}"
        )

    return _Pointer(address_value, dtype, mem_space, assumed_align=assumed_align)
````
The _Pointer class and make_ptr function are duplicated from flashinfer/cute_dsl/utils.py. This duplicated code appears to be unused within the blackwell module, as other files import these utilities from the central flashinfer.cute_dsl.utils location. Removing this duplication will improve maintainability and prevent potential inconsistencies in the future.
Actionable comments posted: 8
In `@csrc/moe_utils_binding.cu`:
- Around line 273-338: The kernel assumes token_final_scales is bfloat16 but the
public API allows float32; in moe_sort validate the dtype for
token_final_scales_ptr (or enforce conversion in the Python wrapper) and set
routingData fields accordingly: if callers pass float32, either (A) convert the
tensor to bfloat16 before calling moe_sort (mirror how token_selected_experts is
converted to int32) or (B) detect float32 here and set routingData.mDtypeExpW
and routingData.mDtypeBias to Fp32 and ensure routingData.mPtrTopKWeights is
treated as float*; update the docstring if you choose to restrict to bfloat16.
Ensure checks reference token_final_scales_ptr, routingData.mDtypeExpW,
routingData.mDtypeBias, and routingData.mPtrTopKWeights (or the Python wrapper
conversion path used for token_selected_experts).
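A minimal sketch of the dtype dispatch being suggested; the `Dtype` enum and the mapping below are illustrative stand-ins, not the actual `batchedGemm::trtllm::gen::Dtype` enum:

```python
from enum import Enum

class Dtype(Enum):  # illustrative stand-in for the TensorRT-LLM dtype enum
    Bfloat16 = "bf16"
    Fp32 = "fp32"

def resolve_weight_dtype(torch_dtype_name: str) -> Dtype:
    """Map the token_final_scales dtype to a routing-data dtype field,
    rejecting anything the kernel cannot consume."""
    mapping = {"bfloat16": Dtype.Bfloat16, "float32": Dtype.Fp32}
    if torch_dtype_name not in mapping:
        raise ValueError(
            f"token_final_scales must be bfloat16 or float32, got {torch_dtype_name}"
        )
    return mapping[torch_dtype_name]

# Usage: the binding would set routingData.mDtypeExpW from this result
resolve_weight_dtype("float32")  # Dtype.Fp32
```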
In `@csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu`:
- Around line 119-121: The cached static smCount computed via
tensorrt_llm::common::getMultiProcessorCount() prevents correct behavior when
the process switches CUDA devices; change the declaration so smCount is
evaluated per-call (e.g., remove static or make it thread_local as done
elsewhere) and recompute it before calculating maxBlocksPerSM and blocks; update
all occurrences where smCount is defined (the instances that call
getMultiProcessorCount()) so each call queries the current device rather than
reusing a process-global cached value.
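The stale-cache hazard can be demonstrated outside CUDA; `get_device_sm_count` below is a hypothetical stand-in for `getMultiProcessorCount()`, with imagined per-device SM counts:

```python
_current_device = {"id": 0}
_SM_COUNTS = {0: 148, 1: 132}  # imagined SM counts for two devices

def get_device_sm_count() -> int:
    """Stand-in for querying the current device's multiprocessor count."""
    return _SM_COUNTS[_current_device["id"]]

# Buggy pattern: the value is captured once (like a C++ `static`) and reused.
_cached_sm_count = get_device_sm_count()

def blocks_buggy(max_blocks_per_sm: int) -> int:
    return _cached_sm_count * max_blocks_per_sm

def blocks_fixed(max_blocks_per_sm: int) -> int:
    # Re-query on every call so a device switch is observed.
    return get_device_sm_count() * max_blocks_per_sm

_current_device["id"] = 1  # the process switches CUDA devices
print(blocks_buggy(2))  # 296 - still uses device 0's SM count
print(blocks_fixed(2))  # 264 - uses device 1's SM count
```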
In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh`:
- Around line 56-69: The gate is only clamped from above which can let very
negative values pass into the sigmoid and cause numerical instability; in
SwigluBiasAdaptor::operator() clamp the gate symmetrically (e.g., use
cutlass::maximum<T>{}(cutlass::minimum<T>{}(gate, limit), -limit)) before
computing the sigmoid and using it for the gate multiplication so the sigmoid
receives a bounded input and numerical overflow/underflow is avoided.
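A self-contained sketch of the symmetric clamp, with plain Python floats standing in for the CUTLASS `maximum`/`minimum` functors (the SwiGLU details are simplified to the gate path only):

```python
import math

def sigmoid(x: float) -> float:
    # Numerically stable piecewise form
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def swiglu_gate(gate: float, limit: float) -> float:
    # Clamp symmetrically so the sigmoid input is bounded in [-limit, limit];
    # clamping only from above would let very negative values through.
    clamped = max(min(gate, limit), -limit)
    return clamped * sigmoid(clamped)

swiglu_gate(1e6, 7.0)   # behaves as if gate were +limit
swiglu_gate(-1e6, 7.0)  # bounded instead of feeding -1e6 into the sigmoid
```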
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 72-160: Add the missing `@flashinfer_api` decorator to the public
function create_finalize_fusion_tensors to enable FLASHINFER_LOGLEVEL-based API
logging; also modify its signature to accept an optional token_final_scales:
Optional[torch.Tensor] = None parameter (dtype final_scale_dtype, shape
(seq_len, topk)) and, if provided, validate shape/dtype and use it instead of
generating random values, otherwise keep the current randomized normalized
initialization but update the docstring to state these are placeholder/test
values; leave the existing _finalize_kernel_cache behavior unchanged.
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py`:
- Around line 241-272: The FP4 output scale allocation assumes permuted_m is
divisible by 128 and scale_intermediate_size (computed from intermediate_size //
sf_vec_size) is divisible by 4; add explicit validation in the generate_sfc
branch before allocating out_scale: check that permuted_m % 128 == 0 and
scale_intermediate_size % 4 == 0 (using the existing symbols generate_sfc,
out_scale, permuted_m, intermediate_size, sf_vec_size, scale_intermediate_size),
and raise a clear ValueError if either check fails so the buffer size cannot be
undersized and cause out-of-bounds writes.
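The suggested guard, sketched as a standalone helper (variable names mirror the review comment; the `ValueError` messages are illustrative):

```python
def validate_sfc_shapes(permuted_m: int, intermediate_size: int, sf_vec_size: int) -> None:
    """Reject shapes that would undersize the FP4 output-scale buffer."""
    scale_intermediate_size = intermediate_size // sf_vec_size
    if permuted_m % 128 != 0:
        raise ValueError(
            f"permuted_m ({permuted_m}) must be divisible by 128 for the FP4 scale layout"
        )
    if scale_intermediate_size % 4 != 0:
        raise ValueError(
            f"scale_intermediate_size ({scale_intermediate_size}) must be divisible by 4"
        )

validate_sfc_shapes(128, 1024, 16)  # ok: 128 % 128 == 0 and (1024 // 16) % 4 == 0
```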
- Around line 69-82: Replace the `@functools.lru_cache`(maxsize=None) decorator on
_get_compiled_swiglu_kernel with `@functools.cache` and remove the unused shape
parameters permuted_m, n, k, and num_experts from the function signature and all
call sites so they are not included in the cache key; update any callers that
pass those four parameters to stop supplying them and adjust the function
internals to use only the remaining parameters (ab_dtype_name, sf_dtype_name,
c_dtype_name, sf_vec_size, mma_tiler_mn, cluster_shape_mn, vectorized_f32). Also
make the same decorator/signature change for the other identical
_get_compiled_swiglu_kernel occurrence in this module.
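The caching point can be shown with `functools.cache`: parameters that do not affect compilation should not be part of the cache key. The compile function here is a counting stub, not the real kernel compiler:

```python
import functools

compile_calls = {"n": 0}

@functools.cache
def get_compiled_kernel(ab_dtype_name: str, sf_vec_size: int, mma_tiler_mn: tuple):
    compile_calls["n"] += 1  # count actual "compilations"
    return f"kernel[{ab_dtype_name},{sf_vec_size},{mma_tiler_mn}]"

# Shapes (permuted_m, n, k, num_experts) are NOT in the signature, so calls
# that differ only in problem shape reuse the same compiled kernel.
get_compiled_kernel("float4_e2m1fn", 16, (128, 128))
get_compiled_kernel("float4_e2m1fn", 16, (128, 128))
print(compile_calls["n"])  # 1 - compiled exactly once
```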
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py`:
- Around line 361-363: Replace the two assert statements that check tensor
device types (the checks using a.device.type == "cuda" and b.device.type ==
"cuda") with explicit runtime validation: use if statements that raise a clear
exception (e.g., ValueError or RuntimeError) when a or b are not on CUDA, and
include a helpful message mentioning which tensor failed the check; update the
validation block in blockscaled_contiguous_grouped_gemm.py to perform these
explicit checks instead of using assert so they are not stripped in optimized
Python.
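The assert-to-exception change, sketched generically; the motivation is that `python -O` strips `assert` statements entirely, while an explicit `raise` always runs (the helper name is illustrative):

```python
def check_on_cuda(name: str, device_type: str) -> None:
    # Explicit check survives `python -O`, unlike `assert`.
    if device_type != "cuda":
        raise ValueError(
            f"Input tensor {name!r} must be on CUDA device, got {device_type!r}"
        )

check_on_cuda("a", "cuda")  # passes silently
try:
    check_on_cuda("b", "cpu")
except ValueError as e:
    print(e)
```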
In `@flashinfer/cute_dsl/tuner.py`:
- Around line 260-266: The initializer lambda in tuner.py currently samples
expert indices with a hardcoded range of 0..7 (comment "num_experts=8 typical"),
which can produce invalid indices for models with different expert counts;
update the lambda used in the tuner initialization (the anonymous function that
calls torch.randint) to derive the upper bound from the actual model/runner
expert count (e.g., use a passed-in num_experts or runner.num_experts) or at
minimum clamp to min/max to avoid out-of-range values; locate the lambda in
tuner.py and replace the hardcoded 8 with the dynamic num_experts value (or a
safe expression like max(1, num_experts)) so sampled indices are always valid.
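The sampling-bound fix in stdlib form, with `random.randrange` standing in for `torch.randint`:

```python
import random

def sample_expert_indices(num_tokens: int, top_k: int, num_experts: int):
    # Derive the bound from the actual expert count instead of a hardcoded 8,
    # guarding against degenerate configs with max(1, ...).
    bound = max(1, num_experts)
    return [[random.randrange(bound) for _ in range(top_k)] for _ in range(num_tokens)]

idx = sample_expert_indices(num_tokens=4, top_k=2, num_experts=32)
assert all(0 <= e < 32 for row in idx for e in row)  # always in range
```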
🧹 Nitpick comments (28)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh (2)
28-40: Unused member variables `alpha`, `beta`, `limit` in `IdentityAdaptor`. These members are declared but never used in `operator()`. If they're for API consistency with other adaptors, consider documenting this intent. Otherwise, they add unnecessary storage overhead.
42-54: Same observation for `GLUAdaptor`: unused `alpha`, `beta`, `limit` members. The members are not referenced in the `operator()` implementation. If these are reserved for future use or API consistency, a brief comment would clarify intent.

flashinfer/cute_dsl/utils.py (1)
29-31: Duplicate `ceil_div` implementation; consider reusing the existing utility. This function is already defined in `flashinfer/utils.py` (with a proper docstring) and duplicated in several other files. Consider importing from `flashinfer.utils` to reduce duplication.

♻️ Suggested change

```diff
-def ceil_div(a: int, b: int) -> int:
-    """Ceiling division."""
-    return (a + b - 1) // b
+from flashinfer.utils import ceil_div
```

Based on relevant code snippets showing `ceil_div` exists in `flashinfer/utils.py` (lines 621-632).

flashinfer/cute_dsl/fused_moe.py (4)
110-125: Prefix unused variable with underscore. `total_num_padded_tokens` from `moe_sort` is unpacked but never used. Prefix it with an underscore to indicate intentional non-use.

♻️ Suggested fix

```diff
 (
     tile_idx_to_expert_idx,
     tile_idx_to_mn_limit,
     expanded_idx_to_permuted_idx,
     permuted_idx_to_expanded_idx,
-    total_num_padded_tokens,
+    _total_num_padded_tokens,
     num_non_exiting_tiles,
 ) = moe_sort(
```
128-136: Auxiliary stream creation without explicit lifecycle management. A new `torch.cuda.Stream()` is created at line 133 if `aux_stream is None`, but there's no mechanism to reuse or properly manage this stream across calls. Consider:

- Documenting that callers should pass a reusable stream for optimal performance
- Using a module-level cached stream to avoid repeated allocation
336-342: Duplicate output allocation logic. The `moe_output` allocation at lines 337-342 duplicates the logic in `_cute_dsl_fused_moe_nvfp4_impl` (lines 102-107). Since `moe_output` is passed to the runner, the allocation in the public API is necessary, but consider removing the duplicate in `_impl` or documenting why both are needed.
179-180: `memset_event.wait()` could be more explicit about stream context. The call at line 180 waits on the default stream after the `memset_event.record()` at line 177 is called on `aux_stream`. While this synchronization pattern is correct, consider explicitly documenting the stream interaction or adding error handling for potential failures in the `aux_stream` work (e.g., exceptions in `moe_output_memset`).

flashinfer/cute_dsl/tuner.py (3)
250-273: Mutable class attributes should be annotated with `ClassVar`. `dynamic_tensor_initializers` and `tuning_config` are class-level attributes that should be typed with `typing.ClassVar` to indicate they're shared across instances and not instance attributes.

♻️ Suggested fix

```diff
+from typing import Any, Callable, ClassVar, Dict, List, Tuple
+
 class CuteDslFusedMoENvfp4Runner(TunableRunner):
     ...
     # Tensor initializers for dynamic tensors (indices 0, 1, 2, 3, 11)
     # These create valid dummy tensors for profiling with different num_tokens
-    dynamic_tensor_initializers = [
+    dynamic_tensor_initializers: ClassVar[List[Callable]] = [
         ...
     ]
     # Tuning config with dynamic tensor specs for num_tokens dimension
-    tuning_config = TuningConfig(
+    tuning_config: ClassVar[TuningConfig] = TuningConfig(
         ...
     )
```
341-347: PEP 484 violation: implicit `Optional` type. `tactic: Tuple[Any, ...] = None` implicitly allows `None` but the type hint doesn't reflect this. Use explicit `Optional` or union syntax.

♻️ Suggested fix

```diff
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
 def forward(  # type: ignore[override]
     self,
     inputs: List[torch.Tensor],
-    tactic: Tuple[Any, ...] = None,  # type: ignore[assignment]
+    tactic: Optional[Tuple[Any, ...]] = None,
     do_preparation: bool = False,
     **kwargs: Any,
 ) -> torch.Tensor:
```
362-363: Handle the `tactic == -1` edge case. The condition `tactic is None or tactic == -1` suggests `-1` is a sentinel value. Document this behavior or use a more explicit sentinel (e.g., a named constant).

flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py (2)
144-150: Ambiguous variable name `l`; rename for clarity. The variable `l` (lowercase L) is easily confused with `1` (one). Rename to something more descriptive like `num_groups` or `batch_dim`.

♻️ Suggested fix

```diff
 def create_scale_factor_tensor(
-    l: int,
+    num_groups: int,
     mn: int,
     k: int,
     sf_vec_size: int,
     dtype: Type[cutlass.Numeric],
 ) -> Tuple[torch.Tensor, cute.Tensor, torch.Tensor]:
     """Create scale factor tensors in the MMA-compatible layout.
     ...
     Args:
-        l: Batch/expert dimension
+        num_groups: Batch/expert dimension
```

And update all references to `l` within the function.
176-177: Another duplicate `ceil_div` definition. This is the same utility already present in `flashinfer/utils.py` and `flashinfer/cute_dsl/utils.py`. Import from the canonical location instead.

♻️ Suggested fix

```diff
+from flashinfer.utils import ceil_div
+
 def create_scale_factor_tensor(...):
-    def ceil_div(a, b):
-        return (a + b - 1) // b
-
     sf_k = ceil_div(k, sf_vec_size)
```

csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu (2)
134-136: Missing error checking for `cudaLaunchKernelEx`. The return value of `cudaLaunchKernelEx` is not checked. Consider adding error handling to catch launch failures.

🔧 Suggested fix

```diff
-  cudaLaunchKernelEx(&config, kernel, input, permuted_output, input_sf, permuted_sf,
-                     tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles,
-                     hidden_size, top_k, tile_size);
+  TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel, input, permuted_output, input_sf, permuted_sf,
+                                     tile_idx_to_mn_limit, permuted_idx_to_expanded_idx,
+                                     num_non_exiting_tiles, hidden_size, top_k, tile_size));
```

This applies to all `cudaLaunchKernelEx` calls at lines 235, 334, and 463 as well.
56-62: Missing return value check for `cudaOccupancyMaxActiveBlocksPerMultiprocessor`. The CUDA API call result is discarded. If it fails, `numBlocks` remains 0, which could cause issues downstream.

🔧 Suggested fix

```diff
 template <typename KernelFunc>
 int32_t getMaxActiveBlocksPerSM(KernelFunc kernel, int32_t threadsPerBlock, size_t dynamicSmemSize) {
   int numBlocks = 0;
-  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, kernel, threadsPerBlock,
-                                                dynamicSmemSize);
+  TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, kernel, threadsPerBlock,
+                                                                dynamicSmemSize));
   return numBlocks;
 }
```

csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h (1)
108-113: Forward declaration without implementation. `moeActivationQuantize` is declared here but not implemented in the corresponding `.cu` file. Line 480 of `moeUtils.cu` notes this is deferred. Consider adding a comment here to indicate the function is not yet implemented, to avoid linker errors if called.

```diff
+// Note: Implementation deferred - will be added when NVFP4 output support is needed.
 template <typename InputType, typename OutputType, typename SFType>
 void moeActivationQuantize(InputType const* input, OutputType* output, float const* global_sf,
```

flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py (3)
293-318: Consider adding a `@functools.cache` decorator for module-level caching. Per coding guidelines, Python API functions should use the `@functools.cache` decorator to avoid recompilation. While the function uses `_gather_kernel_cache` internally, adding the decorator could provide additional caching benefits at the function level for repeated identical calls.
400-401: Replace `assert` with explicit exceptions for input validation. Using `assert` for input validation in public APIs can be disabled with the `-O` flag. Use explicit exceptions instead.

🔧 Suggested fix

```diff
-    assert a.device.type == "cuda", "Input tensors must be on CUDA device"
-    assert b.device.type == "cuda", "Input tensors must be on CUDA device"
+    if a.device.type != "cuda":
+        raise ValueError("Input tensor 'a' must be on CUDA device")
+    if b.device.type != "cuda":
+        raise ValueError("Input tensor 'b' must be on CUDA device")
```
189-191: Global kernel cache is not thread-safe. `_gather_kernel_cache` is a module-level mutable dictionary that could cause race conditions in multi-threaded scenarios. Consider using a `threading.Lock` or `functools.lru_cache` for thread-safe caching.

🔧 Suggested approach

```python
import threading

_gather_kernel_cache: Dict[Tuple, Any] = {}
_gather_kernel_cache_lock = threading.Lock()

# Then in _get_compiled_gather_kernel:
with _gather_kernel_cache_lock:
    if cache_key not in _gather_kernel_cache:
        # ... compile kernel ...
        _gather_kernel_cache[cache_key] = compiled_gemm
    return _gather_kernel_cache[cache_key]
```

flashinfer/cute_dsl/blackwell/custom_pipeline.py (2)
61-71: Unused parameter `cta_layout_vmnk`. The parameter is accepted but never used in the function body. If this is reserved for future use, consider documenting it or using an `_` prefix.

🔧 Suggested fix

```diff
-def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
+def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):  # noqa: ARG001
     """Initializes the mbarrier and synchronizes the threadblock or cluster.

     This function places a fence on the mbarrier initialization to ensure proper
     synchronization across the threadblock or cluster.

     Args:
-        cta_layout_vmnk (Optional[cute.Layout]): The CTA layout for VMNK. Defaults to None.
+        cta_layout_vmnk (Optional[cute.Layout]): Reserved for future cluster sync. Defaults to None.
     """
     cute.arch.mbarrier_init_fence()
```
184-187: Consider using `TypeError` for type validation. Static analysis suggests `TypeError` is more appropriate when checking instance types. This is a minor style improvement.

🔧 Suggested fix

```diff
 if not isinstance(barrier_storage, cute.Pointer):
-    raise ValueError(
+    raise TypeError(
         f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
     )
```

This applies to similar checks at lines 329-332 and 484-487.
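A standalone illustration of the convention, with a generic `Pointer` class standing in for `cute.Pointer`: `TypeError` signals "wrong type", while `ValueError` signals "right type, bad value".

```python
class Pointer:  # stand-in for cute.Pointer
    pass

def check_barrier_storage(barrier_storage) -> None:
    if not isinstance(barrier_storage, Pointer):
        # Wrong type entirely, so TypeError is the idiomatic choice.
        raise TypeError(
            f"Expected barrier_storage to be a Pointer, but got {type(barrier_storage)}"
        )

check_barrier_storage(Pointer())  # ok
try:
    check_barrier_storage(42)
except TypeError as e:
    print(type(e).__name__)  # TypeError
```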
flashinfer/cute_dsl/blackwell/utils.py (1)
61-196: Duplicated code from `flashinfer/cute_dsl/utils.py`. The `_Pointer` class (lines 62-149) and `make_ptr` function (lines 152-196) are nearly identical to those in `flashinfer/cute_dsl/utils.py`. The comment on line 61 mentions "WAR for CuTeDSL make_ptr implementation"; if this is a temporary workaround, consider adding a TODO to consolidate once the upstream issue is resolved.

```python
# WAR for CuTeDSL make_ptr implementation
# TODO: Remove this once upstream CuTeDSL provides the fix, and import from flashinfer.cute_dsl.utils
```

flashinfer/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm.py (2)
754-754: Unused unpacked variables `bidy` and `bidz`. These variables are unpacked but never used. Use an underscore prefix to indicate intentional discard.

🔧 Suggested fix

```diff
-    bidx, bidy, bidz = cute.arch.block_idx()
+    bidx, _bidy, _bidz = cute.arch.block_idx()
```
1864-1871: Unused parameter `tidx` in `epilog_gmem_copy_and_partition`. The `tidx` parameter is declared but not used in the method body. Consider removing it or adding `# noqa: ARG002` if it's part of a required interface.

🔧 Suggested fix

```diff
 def epilog_gmem_copy_and_partition(
     self,
-    tidx: cutlass.Int32,
+    tidx: cutlass.Int32,  # noqa: ARG002 - kept for interface consistency
     atom: Union[cute.CopyAtom, cute.TiledCopy],
```

flashinfer/jit/moe_utils.py (1)
17-26: Cache the JIT spec generator to avoid redundant registrations. A module-level cache keeps JitSpec creation idempotent and aligns with the repo caching guidance. As per coding guidelines, please add caching.

♻️ Proposed change

```diff
+import functools
+
-def gen_moe_utils_module() -> JitSpec:
+@functools.cache
+def gen_moe_utils_module() -> JitSpec:
```

flashinfer/cute_dsl/blackwell/__init__.py (1)
43-53: Optional: sort `__all__` to satisfy Ruff (RUF022). If Ruff is enforced, sorting will keep lint clean.

♻️ Proposed change

```diff
 __all__ = [
-    "Sm100BlockScaledContiguousGroupedGemmKernel",
-    "Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel",
-    "Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel",
     "BlockScaledContiguousGatherGroupedGemmKernel",
+    "Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel",
+    "Sm100BlockScaledContiguousGroupedGemmKernel",
+    "Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel",
+    "TRTLLM_ENABLE_PDL",
     "cvt_sf_MKL_to_M32x4xrm_K4xrk_L",
-    "TRTLLM_ENABLE_PDL",
     "griddepcontrol_launch_dependents",
     "griddepcontrol_wait",
     "is_power_of_2",
 ]
```

tests/moe/test_cute_dsl_fused_moe.py (2)
36-41: Consider using `flashinfer.utils.is_sm100a_supported()` for the GPU capability check. As per coding guidelines, tests should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures.

♻️ Suggested refactor

```diff
-def is_blackwell():
-    """Check if running on Blackwell GPU (SM100+)."""
-    if not torch.cuda.is_available():
-        return False
-    props = torch.cuda.get_device_properties(0)
-    return props.major >= 10
+from flashinfer.utils import is_sm100a_supported
+
+def is_blackwell():
+    """Check if running on Blackwell GPU (SM100+)."""
+    return is_sm100a_supported()
```
341-394: Consider adding a numerical accuracy check for expert parallelism tests.The test validates that results don't contain NaN/Inf but skips strict accuracy comparison. While the comment explains the semantic difference from filtering, a basic sanity check (e.g., output magnitude is reasonable) would strengthen the test.
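A stdlib sketch of the kind of sanity check being suggested; the magnitude threshold is an illustrative assumption, not a value taken from the tests:

```python
import math

def sanity_check_output(values, max_abs: float = 1e4) -> None:
    """Reject NaN/Inf and implausibly large magnitudes in a flat list of floats."""
    for v in values:
        if not math.isfinite(v):
            raise ValueError(f"output contains non-finite value: {v}")
        if abs(v) > max_abs:
            raise ValueError(f"output magnitude {v} exceeds sanity bound {max_abs}")

sanity_check_output([0.5, -1.25, 3.0])  # passes: finite and reasonably sized
```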
flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py (1)
256-278: Consider adding a `@functools.cache` decorator per coding guidelines. As per coding guidelines, Python API functions should use the `@functools.cache` decorator to implement module-level caching and avoid recompilation. The `@flashinfer_api` decorator is correctly used for debugging.

Note: the internal `_get_compiled_finalize_kernel` already implements caching, so this may be intentionally omitted to avoid double-caching. If so, a brief comment explaining this would help future maintainers.
```cpp
void moe_sort(
    // Inputs
    int64_t token_selected_experts_ptr,  // [num_tokens, top_k], int32
    int64_t token_final_scales_ptr,      // [num_tokens, top_k], float32 or bf16
    int32_t num_tokens, int32_t num_experts, int32_t top_k, int32_t local_expert_offset,
    int32_t num_local_experts, int32_t tile_tokens_dim, bool use_pdl,
    // Outputs (pre-allocated buffers)
    int64_t tile_idx_to_expert_idx_ptr, int64_t tile_idx_to_mn_limit_ptr,
    int64_t expanded_idx_to_permuted_idx_ptr, int64_t permuted_idx_to_expanded_idx_ptr,
    int64_t total_num_padded_tokens_ptr, int64_t num_non_exiting_tiles_ptr,
    // Optional: expert counts buffer for large token counts (>1024)
    // Should be size 2 * num_experts, int32
    int64_t expert_counts_ptr) {
  // Set up the routing data structure
  moe::dev::routing::routingDeepSeek::Data routingData;

  // Configure dtypes
  routingData.mDtypeExpW = batchedGemm::trtllm::gen::Dtype::Bfloat16;
  routingData.mDtypeBias = batchedGemm::trtllm::gen::Dtype::Bfloat16;
  routingData.mDtypeScore = batchedGemm::trtllm::gen::Dtype::Fp32;
  routingData.mUsePdl = use_pdl;

  // Input tensors (pre-computed expert selections)
  routingData.mPtrTopKIds = reinterpret_cast<int32_t*>(token_selected_experts_ptr);
  routingData.mPtrTopKWeights = reinterpret_cast<void*>(token_final_scales_ptr);
  routingData.mPtrScores = nullptr;       // Not using routing logits
  routingData.mPtrRoutingBias = nullptr;  // Not using bias

  // Output tensors
  routingData.mPtrCtaIdxXyToBatchIdx = reinterpret_cast<int32_t*>(tile_idx_to_expert_idx_ptr);
  routingData.mPtrCtaIdxXyToMnLimit = reinterpret_cast<int32_t*>(tile_idx_to_mn_limit_ptr);
  routingData.mPtrExpandedIdxToPermutedIdx =
      reinterpret_cast<int32_t*>(expanded_idx_to_permuted_idx_ptr);
  routingData.mPtrPermutedIdxToTokenIdx =
      reinterpret_cast<int32_t*>(permuted_idx_to_expanded_idx_ptr);
  routingData.mPtrPermutedIdxSize = reinterpret_cast<int32_t*>(total_num_padded_tokens_ptr);
  routingData.mPtrNumNonExitingCtas = reinterpret_cast<int32_t*>(num_non_exiting_tiles_ptr);

  // Not using packed format since we have explicit TopK IDs
  routingData.mPtrTopKPacked = nullptr;

  // Expert counts buffer: required when num_tokens > 1024
  // The kernel will set this to nullptr internally for small token counts
  routingData.mPtrExpertCounts = reinterpret_cast<int32_t*>(expert_counts_ptr);

  // Metadata
  routingData.mNumTokens = num_tokens;
  routingData.mNumExperts = num_experts;
  routingData.mTopK = top_k;
  routingData.mPaddingLog2 = computeLog2(tile_tokens_dim);
  routingData.mTileTokensDim = tile_tokens_dim;
  routingData.mLocalExpertsStartIdx = local_expert_offset;
  routingData.mLocalExpertsStrideLog2 = 0;
  routingData.mNumLocalExperts = num_local_experts;

  // DeepSeekV3 specific parameters
  // For moe_sort, we use n_group=1, topk_group=1 since experts are already selected
  routingData.mNumExpertGroups = 1;
  routingData.mNumLimitedGroups = 1;
  routingData.mRouteScale = 1.0f;
  routingData.mUseRoutingSoftmax = false;

  // Run the routing kernel
  cudaStream_t stream = get_current_stream();
  moe::dev::routing::routingDeepSeek::run(routingData, stream);
}
```
🧩 Analysis chain
🏁 Script executed:
cd /tmp && find . -name "*.py" -type f 2>/dev/null | head -20

🏁 Script executed:
# Search for moe_sort calls in Python files
find . -name "*.py" -type f -exec grep -l "moe_sort" {} \; 2>/dev/null

🏁 Script executed:
# Search more broadly for moe_sort references and how it's called
rg "moe_sort" --type=py -B 2 -A 5

🏁 Script executed:
# Also check the csrc directory structure for binding files
ls -la csrc/ 2>/dev/null | head -30

🏁 Script executed:
rg "def moe_sort" -A 50 flashinfer/moe_utils.py | head -100

🏁 Script executed:
# Get more context on token_final_scales handling
rg "token_final_scales" flashinfer/moe_utils.py -B 2 -A 2

🏁 Script executed:
# Check if there's dtype conversion or validation before calling the C++ kernel
sed -n '/def moe_sort/,/^def /p' flashinfer/moe_utils.py | head -120

🏁 Script executed:
# Check if there's dtype validation for token_final_scales anywhere
rg "token_final_scales" -B 5 -A 5 flashinfer/moe_utils.py | grep -A 15 "dtype"

🏁 Script executed:
# Check the full moe_sort function implementation to see if dtype is checked
sed -n '/def moe_sort/,/^ func(/p' flashinfer/moe_utils.py | tail -50

🏁 Script executed:
# Check if mLocalExpertsStrideLog2 is ever used non-zero in the codebase
rg "mLocalExpertsStrideLog2" -B 3 -A 3

🏁 Script executed:
# Check the routing kernel implementation to understand what this parameter means
rg "routingDeepSeek" csrc/ --type=cu -l

🏁 Script executed:
# Search for documentation or comments about stride in DeepSeek routing
rg "Stride\|stride\|STRIDE" csrc/ --type=cu -B 2 -A 2 | head -50
Add dtype validation for token_final_scales or document bfloat16-only requirement.
The Python API documents token_final_scales as supporting both torch.float32 and torch.bfloat16, but the C++ kernel hardcodes mDtypeExpW and mDtypeBias to Bfloat16 (lines 290-291). If callers pass float32 tensors, the kernel will misinterpret the bit patterns as bfloat16, causing silent correctness issues. Either enforce dtype conversion in the Python wrapper (similar to how token_selected_experts is converted to int32) or update the docstring to document that only bfloat16 is supported.
Note: mLocalExpertsStrideLog2 = 0 is intentional and correct for non-strided deployment (all experts local to a single GPU).
🤖 Prompt for AI Agents
In `@csrc/moe_utils_binding.cu` around lines 273 - 338, The kernel assumes
token_final_scales is bfloat16 but the public API allows float32; in moe_sort
validate the dtype for token_final_scales_ptr (or enforce conversion in the
Python wrapper) and set routingData fields accordingly: if callers pass float32,
either (A) convert the tensor to bfloat16 before calling moe_sort (mirror how
token_selected_experts is converted to int32) or (B) detect float32 here and set
routingData.mDtypeExpW and routingData.mDtypeBias to Fp32 and ensure
routingData.mPtrTopKWeights is treated as float*; update the docstring if you
choose to restrict to bfloat16. Ensure checks reference token_final_scales_ptr,
routingData.mDtypeExpW, routingData.mDtypeBias, and routingData.mPtrTopKWeights
(or the Python wrapper conversion path used for token_selected_experts).
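Option (A) from the prompt — restricting the wrapper to the dtypes the kernel actually handles — can be sketched as below. The helper name and the string-to-tag mapping are illustrative stand-ins (the real enum is the C++ `batchedGemm::trtllm::gen::Dtype`), not flashinfer API:

```python
# Hypothetical wrapper-side dtype guard; names are illustrative only.
_SUPPORTED_SCALE_DTYPES = {
    "bfloat16": "Bfloat16",  # mirrors Dtype::Bfloat16
    "float32": "Fp32",       # mirrors Dtype::Fp32
}

def resolve_scale_dtype_tag(dtype_name: str) -> str:
    """Map a torch dtype name to the routing kernel's dtype tag, failing loudly."""
    tag = _SUPPORTED_SCALE_DTYPES.get(dtype_name)
    if tag is None:
        raise ValueError(
            f"token_final_scales dtype {dtype_name!r} is not supported; "
            "expected bfloat16 or float32"
        )
    return tag
```

The point is simply to fail at the Python boundary instead of letting a float32 buffer be reinterpreted as bfloat16 bits inside the kernel.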
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM = getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# First, check if the file exists and get its structure
fd "moeUtils.cu" --type f | head -5

🏁 Script executed:
#!/bin/bash
# Read the specific lines mentioned in the review (119-121, 319, 448)
if [ -f "csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu" ]; then
  echo "=== Lines 119-121 ==="
  sed -n '119,121p' "csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu"
  echo ""
  echo "=== Lines 315-325 (around line 319) ==="
  sed -n '315,325p' "csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu"
  echo ""
  echo "=== Lines 444-454 (around line 448) ==="
  sed -n '444,454p' "csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu"
fi

🏁 Script executed:
#!/bin/bash
# Search for getMultiProcessorCount and understand its implementation
rg -n "getMultiProcessorCount" --type cpp -C 3

🏁 Script executed:
#!/bin/bash
# Check how this function is defined and where it comes from
rg -n "def getMultiProcessorCount|getMultiProcessorCount\s*\(" --type cpp --type h -A 5 | head -50
Static SM count caching prevents proper multi-GPU support.
The getMultiProcessorCount() function queries the current device dynamically, but caching its result in a static variable means it is evaluated only once per process. If the application switches to a different device, this cached value becomes incorrect. Consider using thread_local (as used elsewhere in the codebase, e.g., fp4Op.cpp) or querying per-call if multi-GPU support is required.
This pattern appears at lines 119, 319, and 448.
🤖 Prompt for AI Agents
In `@csrc/nv_internal/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu` around
lines 119 - 121, The cached static smCount computed via
tensorrt_llm::common::getMultiProcessorCount() prevents correct behavior when
the process switches CUDA devices; change the declaration so smCount is
evaluated per-call (e.g., remove static or make it thread_local as done
elsewhere) and recompute it before calculating maxBlocksPerSM and blocks; update
all occurrences where smCount is defined (the instances that call
getMultiProcessorCount()) so each call queries the current device rather than
reusing a process-global cached value.
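In Python terms, the suggested fix amounts to keying the cache on the device index rather than caching one process-global value. A minimal sketch with a stand-in device table (real code would query `cudaDevAttrMultiProcessorCount` per device):

```python
import functools

# Stand-in for per-device hardware queries; kernel code would instead call
# cudaDeviceGetAttribute(cudaDevAttrMultiProcessorCount, device).
_FAKE_SM_COUNTS = {0: 132, 1: 148}

@functools.lru_cache(maxsize=None)
def sm_count_for_device(device_index: int) -> int:
    # Keying the cache on the device index stays correct across device
    # switches; a single static value captured on first call does not.
    return _FAKE_SM_COUNTS[device_index]
```

The same reasoning motivates `thread_local` (or a per-device array) on the C++ side: the cached value must be invalidated or partitioned whenever the "current device" can change.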
struct SwigluBiasAdaptor {
  constexpr static bool IS_GLU = true;
  float alpha = 1.0f;
  float beta = 0.0f;
  float limit = std::numeric_limits<float>::infinity();

  template <class T>
  __device__ T operator()(T const& gate, T const& linear) const {
    cutlass::epilogue::thread::Sigmoid<T> fn{};
    T linear_clamped = cutlass::maximum<T>{}(cutlass::minimum<T>{}(linear, limit), -limit);
    T gate_clamped = cutlass::minimum<T>{}(gate, limit);
    return gate_clamped * fn(gate_clamped * alpha) * (linear_clamped + beta);
  }
};
🧩 Analysis chain
🌐 Web query:
SwiGLU activation function implementation clamping

💡 Result:
Short answer: clamp the gate input to a reasonable float32-safe range (e.g. [-50, 50] or [-88, 88]) before applying SiLU/Sigmoid when you implement SwiGLU manually or in custom kernels; most frameworks' builtin sigmoid/silu are already numerically stable, but clamping avoids exp overflow/underflow in custom code. [1][2][3][4]

Example (PyTorch):
gate, up = x.chunk(2, dim=-1)
gate = gate.clamp(-50.0, 50.0)  # or clamp(-88.0, 88.0)
out = torch.nn.functional.silu(gate) * up

Why: Sigmoid/SiLU involve exp(), and large inputs can overflow; numerically stable implementations or conditional formulas avoid this, but simple clamping is a safe, common shortcut. [2][3]

Sources:
- PyTorch SiLU docs (SiLU = Swish). [1]
- Notes / examples on numerically stable sigmoid implementations and clipping. [2][3]
- Megatron / fused SwiGLU kernel docs showing fused SwiGLU implementations (practical kernel context). [4]

🌐 Web query:
SwiGLU gate linear output sigmoid clamping numerical stability

💡 Result:
Short answer: SwiGLU itself isn't specially unstable, but practical numerical issues come from (a) computing sigmoid/SiLU (exp overflow/underflow on extreme inputs), and (b) the elementwise multiplication between a (bounded) gated path and an unbounded linear path, which can amplify large values and produce INF/NaN or tiny gradients in low precision. Use stable built-ins, clamp where needed, and apply the usual FFN safeguards. Sources: GLU-variants discussion and SwiGLU descriptions, and SiLU docs. [1][2][3]

Practical checklist (concise):
- Use a numerically stable implementation: call a fused/engine-optimized SiLU (torch.nn.functional.silu) rather than hand-writing x * torch.sigmoid(x) when possible (better perf and lower fp16 risk). [3]
- Prevent extreme inputs to sigmoid/exp: ensure reasonable weight init (Xavier/Kaiming) and consider activation normalization (LayerNorm) before the FFN to keep pre-activations in a safe range. [1][3]
- Avoid catastrophic multiplicative growth: after SwiGLU, use an output projection with appropriate weight scale and/or scalar gating (learned scale) to limit dynamic range. [1][2]
- Clamp only when necessary: if you must clamp sigmoid outputs (e.g., to avoid exact 0/1 in downstream log/div), use a small eps like 1e-6 (torch.clamp(sigmoid(x), eps, 1-eps)), but prefer fixing the sources of extreme values first. Over-clamping changes gradients and representational capacity. [1]
- FP16/AMP tips: use PyTorch's fused SiLU and mixed-precision autocast, keep LayerNorm in fp32 or use loss scaling to avoid underflow/overflow; avoid in-place ops on tensors used in backward. [3]
- Debugging NaNs/Infs: log max/min of the pre-activation (both paths), check gradient norms, temporarily clip gradients and activations to locate which path explodes (linear path vs gated path). [1][2]

If you want, I can:
- show compact PyTorch-safe code for SwiGLU with fp16/AMP best practices and optional sigmoid clamping; or
- run a short experiment to demonstrate where NaNs arise and how clamping/scale fixes them.

🏁 Script executed:
# Search for other SwiGLU implementations or references in the codebase
rg -i "swiglu|swiglu" --type cpp --type h --type cuh -B 2 -A 2

🏁 Script executed:
# Look for tests or documentation related to this kernel
fd -e "test" -e "md" | head -20 | xargs rg -l -i "swiglu\|moe_kernels" 2>/dev/null || true

🏁 Script executed:
# Check if there are comments or context around the clamping logic
cat -n csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh | sed -n '40,80p'
SwigluBiasAdaptor gate clamping deviates from standard SwiGLU numerical stability practices.
The gate value is only clamped from above (line 66: cutlass::minimum<T>{}(gate, limit)), while linear is clamped symmetrically to [-limit, limit]. Standard SwiGLU implementations clamp the gate input symmetrically (e.g., [-50, 50]) before applying sigmoid to prevent exp overflow/underflow in the sigmoid computation. Very negative gate values in this implementation could cause numerical instability.
🤖 Prompt for AI Agents
In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh`
around lines 56 - 69, The gate is only clamped from above which can let very
negative values pass into the sigmoid and cause numerical instability; in
SwigluBiasAdaptor::operator() clamp the gate symmetrically (e.g., use
cutlass::maximum<T>{}(cutlass::minimum<T>{}(gate, limit), -limit)) before
computing the sigmoid and using it for the gate multiplication so the sigmoid
receives a bounded input and numerical overflow/underflow is avoided.
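A scalar reference of the symmetrically clamped variant (plain Python floats, not the CUTLASS types the kernel uses; `alpha`, `beta`, and `limit` mirror the adaptor's fields, with a finite default limit for illustration):

```python
import math

def swiglu_ref(gate: float, linear: float,
               alpha: float = 1.0, beta: float = 0.0,
               limit: float = 50.0) -> float:
    # Clamp BOTH paths symmetrically so the sigmoid never sees an unbounded input.
    gate_c = max(min(gate, limit), -limit)
    linear_c = max(min(linear, limit), -limit)
    sigmoid = 1.0 / (1.0 + math.exp(-gate_c * alpha))
    return gate_c * sigmoid * (linear_c + beta)
```

With the symmetric clamp, `exp(-gate_c * alpha)` is bounded by `exp(limit)`, so even an extreme negative gate produces a finite (near-zero) activation instead of overflowing.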
def create_finalize_fusion_tensors(
    seq_len: int,
    topk: int,
    permuted_m: int,
    group_m_list: List[int],
    mma_tiler_mn: Tuple[int, int],
    final_scale_dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create tensors required for finalize fusion.

    This function creates the mapping tensor and final scale tensor needed
    for the fused finalize operation in GEMM2.

    Args:
        seq_len: Number of output tokens (original sequence length)
        topk: Number of experts per token
        permuted_m: Total permuted M dimension (sum of aligned group sizes)
        group_m_list: List of actual (unaligned) M values per expert
        mma_tiler_mn: MMA tile shape (M, N) for alignment
        final_scale_dtype: Data type for token final scales. Default: torch.float32

    Returns:
        Tuple of:
        - permuted_idx_to_expanded_idx: Mapping tensor, shape (permuted_m,), int32
          Maps permuted row index to expanded_idx = token_idx * topk + k_idx
          Invalid rows are marked with -1.
        - token_final_scales: Router scale tensor, shape (seq_len, topk), final_scale_dtype
          Normalized routing weights for each (token, topk) pair.

    Example:
        >>> seq_len, topk, num_experts = 4096, 8, 8
        >>> group_m_list = [512, 480, 256, 320, 640, 512, 384, 704]  # Tokens per expert
        >>> permuted_m = sum(align_to(m, 256) for m in group_m_list)  # Aligned total
        >>>
        >>> permuted_idx_to_expanded_idx, token_final_scales = create_finalize_fusion_tensors(
        ...     seq_len=seq_len,
        ...     topk=topk,
        ...     permuted_m=permuted_m,
        ...     group_m_list=group_m_list,
        ...     mma_tiler_mn=(256, 128),
        ... )
    """
    m_aligned = mma_tiler_mn[0]

    # Initialize mapping tensor with -1 (invalid)
    permuted_idx_to_expanded_idx = torch.empty(
        (permuted_m,), dtype=torch.int32, device="cuda"
    ).fill_(-1)

    # Create normalized token final scales
    token_final_scales = torch.rand(
        seq_len, topk, dtype=final_scale_dtype, device="cuda"
    )
    token_final_scales = token_final_scales / token_final_scales.sum(
        dim=1, keepdim=True
    )

    start_idx = 0
    for group_idx, m_per_group in enumerate(group_m_list):
        if m_per_group > 0:
            # Sequential/Blocked assignment for better atomic add memory access
            # Experts are grouped into sets of size topk.
            # Expert Set S (experts S*topk ... S*topk+topk-1) serves a contiguous block of tokens.
            # This ensures that within an expert, we process tokens T, T+1, T+2... sequentially.

            expert_set_idx = group_idx // topk
            k_in_set = group_idx % topk

            # Start token index for this expert set
            start_token = expert_set_idx * m_per_group

            # Generate sequential token indices for this expert
            token_indices = torch.arange(
                start_token, start_token + m_per_group, dtype=torch.int32, device="cuda"
            )
            token_indices = token_indices % seq_len

            # expanded_idx = token_idx * topk + k
            expanded_idx = token_indices * topk + k_in_set

            permuted_idx_to_expanded_idx[start_idx : (start_idx + m_per_group)] = (
                expanded_idx
            )

            # Move to next aligned group
            m_aligned_per_group = ((m_per_group + m_aligned - 1) // m_aligned) * m_aligned
            start_idx += m_aligned_per_group

    return permuted_idx_to_expanded_idx, token_final_scales
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check how create_finalize_fusion_tensors is used in the codebase
rg -n "create_finalize_fusion_tensors" --type=py -C 3

🏁 Script executed:
#!/bin/bash
# Read the file to see the function definition and decorators
head -n 180 flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py | tail -n +1

🏁 Script executed:
#!/bin/bash
# Read more of the file to see _get_compiled_finalize_kernel and the main API function
sed -n '180,250p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

🏁 Script executed:
#!/bin/bash
# Read the main API function to check for decorators
sed -n '250,400p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

🏁 Script executed:
#!/bin/bash
# Look for test files that might use this function
find . -name "*test*.py" -type f | head -20

🏁 Script executed:
#!/bin/bash
# Search for tests that use create_finalize_fusion_tensors
rg -n "create_finalize_fusion_tensors\|blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4" tests/ --type=py

🏁 Script executed:
#!/bin/bash
# Check if there are any actual uses of this function outside of documentation
rg -n "create_finalize_fusion_tensors" --type=py --exclude-dir=tests -C 5 | head -60

🏁 Script executed:
#!/bin/bash
# Check if there are any actual uses of this function outside of documentation
rg -n "create_finalize_fusion_tensors" --type=py -C 5 | grep -v "Example:" | grep -v ">>>" | head -80

🏁 Script executed:
#!/bin/bash
# Check other similar API functions in the same file to see decorator patterns
rg -n "@functools.cache|@flashinfer_api" flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

🏁 Script executed:
#!/bin/bash
# Check other functions that create tensors for similar kernels to see if they have decorators
rg -B 2 "def create_" flashinfer/cute_dsl/*.py --type=py | head -40

🏁 Script executed:
#!/bin/bash
# Check if create_gather_gemm_tensors has any decorators
rg -B 5 "def create_gather_gemm_tensors" flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py

🏁 Script executed:
#!/bin/bash
# Check the full context of create_gather_gemm_tensors to see if it also uses random values
sed -n '1,150p' flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py | grep -A 50 "def create_gather_gemm_tensors"

🏁 Script executed:
#!/bin/bash
# Look at the end of create_gather_gemm_tensors function
sed -n '50,150p' flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py | tail -80

🏁 Script executed:
#!/bin/bash
# Get the complete create_gather_gemm_tensors function
sed -n '40,150p' flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py

🏁 Script executed:
#!/bin/bash
# Get more of the create_gather_gemm_tensors function - the end
sed -n '120,180p' flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py

🏁 Script executed:
#!/bin/bash
# Check if there are imports of functools.cache or similar caching in the file
rg -n "functools|cache" flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
Add the @flashinfer_api decorator to the public API function and clarify the tensor initialization approach.
The function create_finalize_fusion_tensors is exported as a public API (line 68 in __all__) but lacks the @flashinfer_api decorator used by other public functions in the module. Add the decorator to enable API logging/debugging via FLASHINFER_LOGLEVEL.
Additionally, token_final_scales is initialized with random normalized values, with no way for users to supply pre-computed routing weights. The function should either:
- Accept an optional token_final_scales parameter for real routing weights, or
- Clearly document that the returned scales are placeholder values meant for testing/examples only.
The kernel caching approach using _finalize_kernel_cache is appropriate and does not need @functools.cache, since it correctly caches by tactic parameters only, allowing kernel reuse across different problem sizes.
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`
around lines 72 - 160, Add the missing `@flashinfer_api` decorator to the public
function create_finalize_fusion_tensors to enable FLASHINFER_LOGLEVEL-based API
logging; also modify its signature to accept an optional token_final_scales:
Optional[torch.Tensor] = None parameter (dtype final_scale_dtype, shape
(seq_len, topk)) and, if provided, validate shape/dtype and use it instead of
generating random values, otherwise keep the current randomized normalized
initialization but update the docstring to state these are placeholder/test
values; leave the existing _finalize_kernel_cache behavior unchanged.
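The optional-parameter route suggested above can be sketched shape-wise without torch; `resolve_final_scales_shape` is a hypothetical helper illustrating the validate-or-fallback logic, not a flashinfer API:

```python
from typing import Optional, Tuple

def resolve_final_scales_shape(
    seq_len: int,
    topk: int,
    provided_shape: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    # None -> caller generates placeholder scales of shape (seq_len, topk);
    # otherwise a user-supplied tensor must already match that shape.
    expected = (seq_len, topk)
    if provided_shape is None:
        return expected
    if provided_shape != expected:
        raise ValueError(
            f"token_final_scales must have shape {expected}, got {provided_shape}"
        )
    return provided_shape
```

In the real wrapper the same check would also cover dtype (final_scale_dtype) and device before the tensor is handed to the kernel.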
flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py (outdated review comment, resolved)
# Check if we're doing FP4 quantization
generate_sfc = c_dtype == "float4_e2m1fn"
if generate_sfc:
    if global_scale is None:
        raise ValueError("global_scale is required when c_dtype is 'float4_e2m1fn'")

# Create output tensor if not provided
if out is None:
    if generate_sfc:
        # FP4 output: 2 values per byte
        out = torch.empty(
            (permuted_m, intermediate_size // 2),
            dtype=torch.uint8,
            device=a.device,
        )
    else:
        out = torch.empty(
            (permuted_m, intermediate_size),
            dtype=cutlass_to_torch_dtype(c_dtype_cutlass),
            device=a.device,
        )

# Create output scale tensor if needed and not provided
if generate_sfc and out_scale is None:
    # Scale factor layout for output
    scale_intermediate_size = intermediate_size // sf_vec_size
    # MMA-compatible scale factor shape
    out_scale = torch.empty(
        (32, 4, permuted_m // 128, 4, scale_intermediate_size // 4, 1),
        dtype=torch.uint8,  # FP8 E4M3
        device=a.device,
    )
Validate FP4 output scale layout assumptions before allocation.
out_scale uses permuted_m // 128 and scale_intermediate_size // 4; without divisibility checks, the buffer can be undersized and the kernel may write out of bounds. Add explicit guards for FP4 output.
🐛 Proposed fix
  generate_sfc = c_dtype == "float4_e2m1fn"
  if generate_sfc:
      if global_scale is None:
          raise ValueError("global_scale is required when c_dtype is 'float4_e2m1fn'")
+     if permuted_m % 128 != 0:
+         raise ValueError("permuted_m must be a multiple of 128 for FP4 output scale layout")
+     if intermediate_size % (sf_vec_size * 4) != 0:
+         raise ValueError(
+             "intermediate_size must be divisible by sf_vec_size * 4 for FP4 output scale layout"
+         )

🧰 Tools
🪛 Ruff (0.14.13)
245-245: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py`
around lines 241 - 272, The FP4 output scale allocation assumes permuted_m is
divisible by 128 and scale_intermediate_size (computed from intermediate_size //
sf_vec_size) is divisible by 4; add explicit validation in the generate_sfc
branch before allocating out_scale: check that permuted_m % 128 == 0 and
scale_intermediate_size % 4 == 0 (using the existing symbols generate_sfc,
out_scale, permuted_m, intermediate_size, sf_vec_size, scale_intermediate_size),
and raise a clear ValueError if either check fails so the buffer size cannot be
undersized and cause out-of-bounds writes.
# Validate inputs
assert a.device.type == "cuda", "Input tensors must be on CUDA device"
assert b.device.type == "cuda", "Input tensors must be on CUDA device"
Replace assert with proper exceptions for runtime validation.
assert statements are stripped when Python runs with -O (optimized mode). For production code validating user inputs, use explicit if checks with ValueError/RuntimeError.
🔧 Suggested fix
- assert a.device.type == "cuda", "Input tensors must be on CUDA device"
- assert b.device.type == "cuda", "Input tensors must be on CUDA device"
+ if a.device.type != "cuda":
+ raise ValueError("Input tensor 'a' must be on CUDA device")
+ if b.device.type != "cuda":
+ raise ValueError("Input tensor 'b' must be on CUDA device")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
# Validate inputs
if a.device.type != "cuda":
    raise ValueError("Input tensor 'a' must be on CUDA device")
if b.device.type != "cuda":
    raise ValueError("Input tensor 'b' must be on CUDA device")
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py` around lines 361
- 363, Replace the two assert statements that check tensor device types (the
checks using a.device.type == "cuda" and b.device.type == "cuda") with explicit
runtime validation: use if statements that raise a clear exception (e.g.,
ValueError or RuntimeError) when a or b are not on CUDA, and include a helpful
message mentioning which tensor failed the check; update the validation block in
blockscaled_contiguous_grouped_gemm.py to perform these explicit checks instead
of using assert so they are not stripped in optimized Python.
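The distinction matters because `python -O` compiles `assert` statements away entirely; a minimal side-by-side, with plain device strings standing in for real tensors:

```python
def check_device_assert(device_type: str) -> None:
    # Disappears under `python -O`: the bad input sails through unchecked.
    assert device_type == "cuda", "Input tensors must be on CUDA device"

def check_device_raise(device_type: str) -> None:
    # Survives -O and gives the caller a catchable, descriptive error.
    if device_type != "cuda":
        raise ValueError("Input tensors must be on CUDA device")
```

Reserving `assert` for internal invariants and raising explicit exceptions for user-facing validation is the convention the suggested fix follows.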
lambda shapes, dtype, device: torch.randint(
    0,
    8,
    shapes,
    dtype=torch.int32,
    device=device,  # num_experts=8 typical
),
Hardcoded num_experts=8 assumption in initializer.
The comment says "num_experts=8 typical" but this initializer is used generically. If tuning profiles are generated with different expert counts, this could produce invalid expert indices.
🔧 Suggested fix
Consider parameterizing this or using the actual num_experts from the runner instance:
# 2: token_selected_experts - expert indices (int32, 0 to num_experts-1)
lambda shapes, dtype, device: torch.randint(
0,
- 8,
+ max(shapes[1], 1), # Use top_k dimension as proxy, or accept num_experts param
shapes,
dtype=torch.int32,
- device=device, # num_experts=8 typical
+ device=device,
),

Or document that this assumes a minimum of 8 experts during profiling.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
lambda shapes, dtype, device: torch.randint(
    0,
    max(shapes[1], 1),  # Use top_k dimension as proxy, or accept num_experts param
    shapes,
    dtype=torch.int32,
    device=device,
),
🧰 Tools
🪛 Ruff (0.14.13)
260-260: Unused lambda argument: dtype
(ARG005)
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/tuner.py` around lines 260 - 266, The initializer lambda
in tuner.py currently samples expert indices with a hardcoded range of 0..7
(comment "num_experts=8 typical"), which can produce invalid indices for models
with different expert counts; update the lambda used in the tuner initialization
(the anonymous function that calls torch.randint) to derive the upper bound from
the actual model/runner expert count (e.g., use a passed-in num_experts or
runner.num_experts) or at minimum clamp to min/max to avoid out-of-range values;
locate the lambda in tuner.py and replace the hardcoded 8 with the dynamic
num_experts value (or a safe expression like max(1, num_experts)) so sampled
indices are always valid.
Actionable comments posted: 7
🤖 Fix all issues with AI agents
In
`@flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`:
- Around line 482-489: The sm_count parameter is currently ignored; after
computing max_active_clusters via get_max_active_clusters(cluster_shape_mn[0] *
cluster_shape_mn[1]), clamp it to sm_count (which you already set from
get_num_sm(a.device) when None) by replacing max_active_clusters with
min(max_active_clusters, sm_count) so the API honor sm_count; reference symbols:
sm_count, get_num_sm, get_max_active_clusters, max_active_clusters,
cluster_shape_mn.
- Around line 471-481: The out_scale allocation in the
blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion code assumes permuted_m
and intermediate_size are divisible by the vectorization factors (128 and
sf_vec_size/4) and can cause OOB writes; add explicit validation before the
allocation (when generate_sfc is true and out_scale is None) to assert or raise
a clear error if permuted_m % 128 != 0 or intermediate_size % (sf_vec_size * 4)
!= 0 (or the equivalent divisibility used to compute scale_intermediate_size and
the shape dims), and adjust the calculation of scale_intermediate_size
accordingly to use integer division only after validation so out_scale has the
correct size for the rest of the code paths referencing out_scale.
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 416-423: The code computes max_active_clusters via
get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]) but currently
ignores the sm_count parameter; update the logic so that after computing
max_active_clusters you clamp it using sm_count (use get_num_sm(a.device) only
if sm_count is None), e.g., determine sm_count via sm_count = sm_count or
get_num_sm(a.device), then set max_active_clusters = min(max_active_clusters,
sm_count) so kernel scheduling respects the provided sm_count limit; adjust
references around sm_count, get_num_sm, get_max_active_clusters,
max_active_clusters and cluster_shape_mn accordingly.
- Around line 459-471: Validate token_final_scales.dtype explicitly before
mapping to Cutlass types: handle torch.float32, torch.bfloat16, and
torch.float16 (set token_scales_dtype to cutlass.Float32, cutlass.BFloat16,
cutlass.Float16 respectively) and raise a clear error if any other dtype is
passed; then call make_ptr(token_scales_dtype, token_final_scales.data_ptr(),
cute.AddressSpace.gmem, assumed_align=16). Locate symbols token_final_scales,
token_scales_dtype, make_ptr, and the cutlass type mappings to implement this
guard and error path.
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py`:
- Around line 416-437: The computed max_active_clusters from
get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]) is not being
clamped by the sm_count limit; after computing max_active_clusters, clamp it
with sm_count (e.g., max_active_clusters = min(max_active_clusters, sm_count))
so the API contract is respected. Ensure sm_count is defined (it may be set via
get_num_sm(a.device)) before the clamp and apply this change right after the
call to get_max_active_clusters in the block that includes sm_count, get_num_sm,
and cluster_shape_mn.
- Around line 120-130: The code allows padding when permuted_m > valid_m but
doesn't ensure the padding size is a multiple of the tile size (mma_tiler_m),
which can cause mismatched lengths and OOB access; inside the block that handles
permuted_m > valid_m (the same place where num_padding_tiles is computed and
tile_idx_to_group_idx_list is extended), validate that (permuted_m - valid_m) %
mma_tiler_m == 0 and if not raise a ValueError explaining that permuted_m -
valid_m must be divisible by mma_tiler_m; keep the existing behavior of
computing num_padding_tiles = (permuted_m - valid_m) // mma_tiler_m and
extending tile_idx_to_group_idx_list with the padding only after this check.
In `@flashinfer/cute_dsl/utils.py`:
- Around line 84-110: The cached HardwareInfo and get_max_active_clusters are
device-agnostic and will return wrong values on multi-GPU machines; update
caching to be device-aware: change get_hardware_info to accept an optional
device identifier (or obtain current device internally) and replace the single
_hardware_info_cache with a per-device cache (e.g., dict keyed by device id) for
the HardwareInfo singleton; also update get_max_active_clusters to include the
device id in its cache key (remove or replace the `@functools.cache` usage with a
device-keyed cache or make the function accept a device parameter so caching is
per-device). Ensure you reference and update the symbols _hardware_info_cache,
get_hardware_info, get_max_active_clusters, and the use of `@functools.cache`
accordingly so each GPU gets correct, device-specific values.
♻️ Duplicate comments (4)
flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py (1)
73-161: Public helper should be API‑logged and allow caller‑provided scales.
This mirrors earlier feedback: add `@flashinfer_api` and allow callers to pass real routing scales (or clearly document the randomized placeholders). As per coding guidelines, please add `@flashinfer_api` to public API helpers.
♻️ Suggested update
+@flashinfer_api
 def create_finalize_fusion_tensors(
     seq_len: int,
     topk: int,
     permuted_m: int,
     group_m_list: List[int],
     mma_tiler_mn: Tuple[int, int],
     final_scale_dtype: torch.dtype = torch.float32,
+    token_final_scales: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@
-    # Create normalized token final scales
-    token_final_scales = torch.rand(
-        seq_len, topk, dtype=final_scale_dtype, device="cuda"
-    )
-    token_final_scales = token_final_scales / token_final_scales.sum(
-        dim=1, keepdim=True
-    )
+    # Create or validate token final scales
+    if token_final_scales is None:
+        token_final_scales = torch.rand(
+            seq_len, topk, dtype=final_scale_dtype, device="cuda"
+        )
+        token_final_scales = token_final_scales / token_final_scales.sum(
+            dim=1, keepdim=True
+        )
+    else:
+        if token_final_scales.shape != (seq_len, topk):
+            raise ValueError("token_final_scales must have shape (seq_len, topk)")
+        if token_final_scales.dtype != final_scale_dtype:
+            raise ValueError("token_final_scales dtype must match final_scale_dtype")
flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py (2)
69-82: Remove unused shape args and switch to `@functools.cache`.
The cache key currently includes unused shape parameters, causing unbounded cache growth, and `lru_cache(maxsize=None)` should be replaced per guidelines. As per coding guidelines, use `@functools.cache` for module‑level caching.
🐛 Suggested fix
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def _get_compiled_swiglu_kernel(
-    permuted_m: int,
-    n: int,  # This is 2*intermediate_size
-    k: int,
-    num_experts: int,
     ab_dtype_name: str,
     sf_dtype_name: str,
     c_dtype_name: str,
     sf_vec_size: int,
     mma_tiler_mn: Tuple[int, int],
     cluster_shape_mn: Tuple[int, int],
     vectorized_f32: bool,
 ):
@@
-    gemm, _, _, _ = _get_compiled_swiglu_kernel(
-        permuted_m=permuted_m,
-        n=n,
-        k=k,
-        num_experts=num_experts,
+    gemm, _, _, _ = _get_compiled_swiglu_kernel(
         ab_dtype_name=ab_dtype,
         sf_dtype_name=sf_dtype,
         c_dtype_name=c_dtype,
         sf_vec_size=sf_vec_size,
         mma_tiler_mn=mma_tiler_mn,
         cluster_shape_mn=cluster_shape_mn,
         vectorized_f32=vectorized_f32,
     )
241-272: Validate FP4 out_scale layout divisibility.
The layout assumes `permuted_m` and `intermediate_size` alignment; without checks, the buffer can be undersized.
🔧 Suggested validation
 generate_sfc = c_dtype == "float4_e2m1fn"
 if generate_sfc:
     if global_scale is None:
         raise ValueError("global_scale is required when c_dtype is 'float4_e2m1fn'")
+    if permuted_m % 128 != 0:
+        raise ValueError("permuted_m must be a multiple of 128 for FP4 output scale layout")
+    if intermediate_size % (sf_vec_size * 4) != 0:
+        raise ValueError(
+            "intermediate_size must be divisible by sf_vec_size * 4 for FP4 output scale layout"
+        )
flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py (1)
258-286: Remove unused shape params and switch to `@functools.cache`.
Unused shape args inflate the cache key without effect, and `lru_cache(maxsize=None)` should be replaced per guidelines. As per coding guidelines, use `@functools.cache` for module‑level caching.
🐛 Suggested fix
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def _get_compiled_kernel(
-    permuted_m: int,
-    n: int,
-    k: int,
-    num_experts: int,
     ab_dtype_name: str,
     sf_dtype_name: str,
     c_dtype_name: str,
     sf_vec_size: int,
     mma_tiler_mn: Tuple[int, int],
     cluster_shape_mn: Tuple[int, int],
 ):
@@
-    gemm, _, _, _ = _get_compiled_kernel(
-        permuted_m=permuted_m,
-        n=n,
-        k=k,
-        num_experts=num_experts,
+    gemm, _, _, _ = _get_compiled_kernel(
         ab_dtype_name=ab_dtype,
         sf_dtype_name=sf_dtype,
         c_dtype_name=c_dtype,
         sf_vec_size=sf_vec_size,
         mma_tiler_mn=mma_tiler_mn,
         cluster_shape_mn=cluster_shape_mn,
     )
🧹 Nitpick comments (3)
flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py (1)
76-187: Add API logging to the public tensor‑creation helper.
`create_gather_gemm_tensors` is exported in `__all__` but lacks `@flashinfer_api`. As per coding guidelines, please add `@flashinfer_api` to public API helpers.
♻️ Suggested change
+@flashinfer_api
 def create_gather_gemm_tensors(
     seq_len: int,
     topk: int,
     group_m_list: List[int],
     mma_tiler_m: int,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, List[int]]:
flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py (2)
72-76: Add API logging for `create_tile_mapping`.
This helper is exported in `__all__` but isn't decorated. As per coding guidelines, please add `@flashinfer_api` to public API helpers.
♻️ Suggested change
+@flashinfer_api
 def create_tile_mapping(
     group_m_list: torch.Tensor,
     mma_tiler_m: int,
     permuted_m: Optional[int] = None,
 ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]:
145-152: Add API logging for `create_scale_factor_tensor`.
This exported helper should participate in standard API logging. As per coding guidelines, please add `@flashinfer_api` to public API helpers.
♻️ Suggested change
+@flashinfer_api
 def create_scale_factor_tensor(
     l: int,
     mn: int,
     k: int,
     sf_vec_size: int,
     dtype: Type[cutlass.Numeric],
 ) -> Tuple[torch.Tensor, cute.Tensor, torch.Tensor]:
# Create output scale tensor if needed and not provided
if generate_sfc and out_scale is None:
    # Scale factor layout for output
    scale_intermediate_size = intermediate_size // sf_vec_size
    # MMA-compatible scale factor shape
    out_scale = torch.empty(
        (32, 4, permuted_m // 128, 4, scale_intermediate_size // 4, 1),
        dtype=torch.uint8,  # FP8 E4M3
        device=a.device,
    )
Guard FP4 out_scale layout divisibility to prevent OOB writes.
out_scale sizing assumes permuted_m and intermediate_size are aligned; without checks the buffer can be undersized.
🔧 Suggested validation
 if generate_sfc and out_scale is None:
     # Scale factor layout for output
     scale_intermediate_size = intermediate_size // sf_vec_size
+    if permuted_m % 128 != 0:
+        raise ValueError("permuted_m must be a multiple of 128 for FP4 output scale layout")
+    if intermediate_size % (sf_vec_size * 4) != 0:
+        raise ValueError(
+            "intermediate_size must be divisible by sf_vec_size * 4 for FP4 output scale layout"
+        )
     # MMA-compatible scale factor shape
     out_scale = torch.empty(
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
# Create output scale tensor if needed and not provided
if generate_sfc and out_scale is None:
    # Scale factor layout for output
    scale_intermediate_size = intermediate_size // sf_vec_size
    if permuted_m % 128 != 0:
        raise ValueError("permuted_m must be a multiple of 128 for FP4 output scale layout")
    if intermediate_size % (sf_vec_size * 4) != 0:
        raise ValueError(
            "intermediate_size must be divisible by sf_vec_size * 4 for FP4 output scale layout"
        )
    # MMA-compatible scale factor shape
    out_scale = torch.empty(
        (32, 4, permuted_m // 128, 4, scale_intermediate_size // 4, 1),
        dtype=torch.uint8,  # FP8 E4M3
        device=a.device,
    )
🤖 Prompt for AI Agents
In
`@flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`
around lines 471 - 481, The out_scale allocation in the
blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion code assumes permuted_m
and intermediate_size are divisible by the vectorization factors (128 and
sf_vec_size/4) and can cause OOB writes; add explicit validation before the
allocation (when generate_sfc is true and out_scale is None) to assert or raise
a clear error if permuted_m % 128 != 0 or intermediate_size % (sf_vec_size * 4)
!= 0 (or the equivalent divisibility used to compute scale_intermediate_size and
the shape dims), and adjust the calculation of scale_intermediate_size
accordingly to use integer division only after validation so out_scale has the
correct size for the rest of the code paths referencing out_scale.
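For illustration, a runnable pure-Python sketch of the suggested guard, with the buffer-shape computation factored into a plain function (the helper name `out_scale_shape` is hypothetical; the `(32, 4, m // 128, 4, n_sf // 4, 1)` layout is taken from the snippet above):

```python
def out_scale_shape(permuted_m: int, intermediate_size: int, sf_vec_size: int) -> tuple:
    """Validate divisibility, then size the MMA-compatible output-scale buffer."""
    if permuted_m % 128 != 0:
        raise ValueError("permuted_m must be a multiple of 128 for FP4 output scale layout")
    if intermediate_size % (sf_vec_size * 4) != 0:
        raise ValueError(
            "intermediate_size must be divisible by sf_vec_size * 4 for FP4 output scale layout"
        )
    # Integer division is only safe after the checks above
    scale_intermediate_size = intermediate_size // sf_vec_size
    return (32, 4, permuted_m // 128, 4, scale_intermediate_size // 4, 1)


# Aligned inputs produce a well-defined shape:
print(out_scale_shape(256, 1024, 16))  # (32, 4, 2, 4, 16, 1)
```

Misaligned inputs (e.g. `permuted_m=100`) now raise before any undersized buffer can be written out of bounds.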
# Get SM count
if sm_count is None:
    sm_count = get_num_sm(a.device)

# Compute max active clusters (cached to avoid expensive HardwareInfo queries)
max_active_clusters = get_max_active_clusters(
    cluster_shape_mn[0] * cluster_shape_mn[1]
)
sm_count parameter is unused.
Honor the API by clamping max_active_clusters to sm_count.
🔧 Suggested fix
-    max_active_clusters = get_max_active_clusters(
-        cluster_shape_mn[0] * cluster_shape_mn[1]
-    )
+    max_active_clusters = min(
+        get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]),
+        sm_count,
+    )
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
# Get SM count
if sm_count is None:
    sm_count = get_num_sm(a.device)

# Compute max active clusters (cached to avoid expensive HardwareInfo queries)
max_active_clusters = min(
    get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]),
    sm_count,
)
🤖 Prompt for AI Agents
In
`@flashinfer/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`
around lines 482 - 489, The sm_count parameter is currently ignored; after
computing max_active_clusters via get_max_active_clusters(cluster_shape_mn[0] *
cluster_shape_mn[1]), clamp it to sm_count (which you already set from
get_num_sm(a.device) when None) by replacing max_active_clusters with
min(max_active_clusters, sm_count) so the API honors sm_count; reference symbols:
sm_count, get_num_sm, get_max_active_clusters, max_active_clusters,
cluster_shape_mn.
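The fix is the same two-step resolution in all three affected files. A small runnable sketch under assumed names (`hw_max_clusters` stands in for the `get_max_active_clusters(...)` result and `device_sm_count` for `get_num_sm(a.device)`; the function name is hypothetical):

```python
from typing import Optional


def resolve_max_active_clusters(
    hw_max_clusters: int,
    sm_count: Optional[int],
    device_sm_count: int,
) -> int:
    """Clamp the hardware cluster limit by the caller-provided sm_count."""
    if sm_count is None:
        sm_count = device_sm_count  # default: use every SM on the device
    return min(hw_max_clusters, sm_count)


# A caller-limited launch schedules fewer clusters than the hardware allows:
print(resolve_max_active_clusters(148, 64, 148))    # 64
# With sm_count=None, the hardware limit wins:
print(resolve_max_active_clusters(148, None, 148))  # 148
```

The `min` keeps both invariants: never exceed what the hardware can co-schedule, and never exceed what the caller reserved.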
# Get SM count
if sm_count is None:
    sm_count = get_num_sm(a.device)

# Compute max active clusters (cached to avoid expensive HardwareInfo queries)
max_active_clusters = get_max_active_clusters(
    cluster_shape_mn[0] * cluster_shape_mn[1]
)
sm_count is currently ignored.
The API advertises sm_count, but it doesn’t affect kernel scheduling. Clamp max_active_clusters with it to respect user limits and avoid oversubscription.
🔧 Suggested fix
-    max_active_clusters = get_max_active_clusters(
-        cluster_shape_mn[0] * cluster_shape_mn[1]
-    )
+    max_active_clusters = min(
+        get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]),
+        sm_count,
+    )
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`
around lines 416 - 423, The code computes max_active_clusters via
get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]) but currently
ignores the sm_count parameter; update the logic so that after computing
max_active_clusters you clamp it using sm_count (use get_num_sm(a.device) only
if sm_count is None), e.g., determine sm_count via sm_count = sm_count or
get_num_sm(a.device), then set max_active_clusters = min(max_active_clusters,
sm_count) so kernel scheduling respects the provided sm_count limit; adjust
references around sm_count, get_num_sm, get_max_active_clusters,
max_active_clusters and cluster_shape_mn accordingly.
# Token final scales - determine dtype and create pointer
if token_final_scales.dtype == torch.float32:
    token_scales_dtype = cutlass.Float32
elif token_final_scales.dtype == torch.bfloat16:
    token_scales_dtype = cutlass.BFloat16
else:
    token_scales_dtype = cutlass.Float16
token_scales_ptr = make_ptr(
    token_scales_dtype,
    token_final_scales.data_ptr(),
    cute.AddressSpace.gmem,
    assumed_align=16,
)
🧩 Analysis chain
🏁 Script executed:
# First, locate and examine the file
find . -type f -name "blockscaled_contiguous_grouped_gemm_finalize_fusion.py" | head -5
Repository: flashinfer-ai/flashinfer
Length of output: 231
🏁 Script executed:
# Check if file exists and examine the target lines
if [ -f "flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py" ]; then
wc -l "flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py"
fi
Repository: flashinfer-ai/flashinfer
Length of output: 146
🏁 Script executed:
# Read the file around the specified lines
sed -n '450,480p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
Repository: flashinfer-ai/flashinfer
Length of output: 1187
🏁 Script executed:
# Find the function that contains this code block
sed -n '400,480p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
Repository: flashinfer-ai/flashinfer
Length of output: 3049
🏁 Script executed:
# Search for calls to this function/module to understand dtype usage
rg "token_final_scales" flashinfer/ -B 2 -A 2 --max-count=20Repository: flashinfer-ai/flashinfer
Length of output: 27246
🏁 Script executed:
# Find the function containing lines 459-471
sed -n '380,475p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py | head -100
Repository: flashinfer-ai/flashinfer
Length of output: 3404
🏁 Script executed:
# Get full function definition starting from earlier lines
sed -n '340,480p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
Repository: flashinfer-ai/flashinfer
Length of output: 5175
🏁 Script executed:
# Search for function definition pattern around line 459
grep -n "^def " flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py | head -10Repository: flashinfer-ai/flashinfer
Length of output: 212
🏁 Script executed:
# Check for torch.float16 usage in other parts of the codebase
rg "torch\.float16" flashinfer/cute_dsl/ -B 2 -A 2 | head -40Repository: flashinfer-ai/flashinfer
Length of output: 2597
🏁 Script executed:
# Check if similar dtype validation exists elsewhere
rg "raise ValueError.*dtype" flashinfer/ -B 1 -A 1 | grep -A 2 -B 2 "float" | head -30Repository: flashinfer-ai/flashinfer
Length of output: 1939
🏁 Script executed:
# Check the function signature and any decorators
sed -n '258,280p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
Repository: flashinfer-ai/flashinfer
Length of output: 909
🏁 Script executed:
# Search for any decorators or imports related to caching/API decorators in this file
head -50 flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
Repository: flashinfer-ai/flashinfer
Length of output: 2196
🏁 Script executed:
# Check if `@functools.cache` or `@flashinfer_api` is used elsewhere in flashinfer/**/*.py files
rg "@functools\.cache|@flashinfer_api" flashinfer/ --type py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 858
🏁 Script executed:
# Check for validation patterns similar to what's suggested
rg "if.*\.dtype ==" flashinfer/cute_dsl/ --type py -B 1 -A 4 | head -40Repository: flashinfer-ai/flashinfer
Length of output: 1935
🏁 Script executed:
# Check the complete function to understand validation context
sed -n '258,550p' flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py | tail -50
Repository: flashinfer-ai/flashinfer
Length of output: 1452
🏁 Script executed:
# Look at how token_final_scales is used downstream to understand if defaulting to Float16 could cause issues
rg "token_final_scales" flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py -A 5 -B 2 | grep -A 10 "dtype"Repository: flashinfer-ai/flashinfer
Length of output: 2628
Validate token_final_scales dtype before selecting Cutlass type.
The code currently defaults to Float16 for any unrecognized dtype, which silently misinterprets data (e.g., if torch.float64 is passed). Additionally, torch.float16 is documented as supported but lacks explicit handling; it only works by falling through to the else clause. Gate supported dtypes explicitly and reject unsupported ones.
🔧 Suggested fix
-    if token_final_scales.dtype == torch.float32:
-        token_scales_dtype = cutlass.Float32
-    elif token_final_scales.dtype == torch.bfloat16:
-        token_scales_dtype = cutlass.BFloat16
-    else:
-        token_scales_dtype = cutlass.Float16
+    if token_final_scales.dtype == torch.float32:
+        token_scales_dtype = cutlass.Float32
+    elif token_final_scales.dtype == torch.bfloat16:
+        token_scales_dtype = cutlass.BFloat16
+    elif token_final_scales.dtype == torch.float16:
+        token_scales_dtype = cutlass.Float16
+    else:
+        raise ValueError(
+            "token_final_scales must be float16, bfloat16, or float32"
+        )
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`
around lines 459 - 471, Validate token_final_scales.dtype explicitly before
mapping to Cutlass types: handle torch.float32, torch.bfloat16, and
torch.float16 (set token_scales_dtype to cutlass.Float32, cutlass.BFloat16,
cutlass.Float16 respectively) and raise a clear error if any other dtype is
passed; then call make_ptr(token_scales_dtype, token_final_scales.data_ptr(),
cute.AddressSpace.gmem, assumed_align=16). Locate symbols token_final_scales,
token_scales_dtype, make_ptr, and the cutlass type mappings to implement this
guard and error path.
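The guard can also be written as a lookup table, which keeps the supported set in one place. A pure-Python sketch using dtype names as stand-ins for the torch and cutlass type objects (the dict and function names here are hypothetical, not flashinfer symbols):

```python
_SUPPORTED_TOKEN_SCALE_DTYPES = {
    "float32": "cutlass.Float32",
    "bfloat16": "cutlass.BFloat16",
    "float16": "cutlass.Float16",
}


def resolve_token_scales_dtype(dtype_name: str) -> str:
    """Map a supported dtype to its Cutlass type; reject everything else."""
    try:
        return _SUPPORTED_TOKEN_SCALE_DTYPES[dtype_name]
    except KeyError:
        raise ValueError(
            f"token_final_scales must be float16, bfloat16, or float32, got {dtype_name}"
        ) from None


print(resolve_token_scales_dtype("bfloat16"))  # cutlass.BFloat16
```

An unsupported dtype such as float64 now fails loudly instead of being silently reinterpreted as Float16.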
# Apply padding if requested (for CUDA graph support)
if permuted_m is not None:
    if permuted_m < valid_m:
        raise ValueError(
            f"permuted_m ({permuted_m}) must be >= valid_m ({valid_m}). "
            f"Cannot pad to a smaller size."
        )
    if permuted_m > valid_m:
        num_padding_tiles = (permuted_m - valid_m) // mma_tiler_m
        # Pad with invalid index (these tiles won't be accessed)
        tile_idx_to_group_idx_list.extend([int(-2e9)] * num_padding_tiles)
Validate permuted_m padding aligns to mma_tiler_m.
If (permuted_m - valid_m) isn’t a multiple of the tile size, the padded mapping length won’t match permuted_m, risking out‑of‑bounds indexing.
🔧 Suggested validation
 if permuted_m is not None:
     if permuted_m < valid_m:
         raise ValueError(
             f"permuted_m ({permuted_m}) must be >= valid_m ({valid_m}). "
             f"Cannot pad to a smaller size."
         )
+    if (permuted_m - valid_m) % mma_tiler_m != 0:
+        raise ValueError(
+            "permuted_m padding must be a multiple of mma_tiler_m"
+        )
     if permuted_m > valid_m:
         num_padding_tiles = (permuted_m - valid_m) // mma_tiler_m
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
# Apply padding if requested (for CUDA graph support)
if permuted_m is not None:
    if permuted_m < valid_m:
        raise ValueError(
            f"permuted_m ({permuted_m}) must be >= valid_m ({valid_m}). "
            f"Cannot pad to a smaller size."
        )
    if (permuted_m - valid_m) % mma_tiler_m != 0:
        raise ValueError(
            "permuted_m padding must be a multiple of mma_tiler_m"
        )
    if permuted_m > valid_m:
        num_padding_tiles = (permuted_m - valid_m) // mma_tiler_m
        # Pad with invalid index (these tiles won't be accessed)
        tile_idx_to_group_idx_list.extend([int(-2e9)] * num_padding_tiles)
🧰 Tools
🪛 Ruff (0.14.13)
123-126: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py` around lines 120
- 130, The code allows padding when permuted_m > valid_m but doesn't ensure the
padding size is a multiple of the tile size (mma_tiler_m), which can cause
mismatched lengths and OOB access; inside the block that handles permuted_m >
valid_m (the same place where num_padding_tiles is computed and
tile_idx_to_group_idx_list is extended), validate that (permuted_m - valid_m) %
mma_tiler_m == 0 and if not raise a ValueError explaining that permuted_m -
valid_m must be divisible by mma_tiler_m; keep the existing behavior of
computing num_padding_tiles = (permuted_m - valid_m) // mma_tiler_m and
extending tile_idx_to_group_idx_list with the padding only after this check.
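As a self-contained illustration of the padded mapping, here is a list-based sketch of the validate-then-extend logic (the helper name `pad_tile_mapping` is hypothetical; the real code works on the tensor path described above):

```python
def pad_tile_mapping(
    tile_idx_to_group_idx: list,
    valid_m: int,
    permuted_m: int,
    mma_tiler_m: int,
) -> list:
    """Extend the tile-to-group mapping so it covers permuted_m padded rows."""
    if permuted_m < valid_m:
        raise ValueError(f"permuted_m ({permuted_m}) must be >= valid_m ({valid_m})")
    if (permuted_m - valid_m) % mma_tiler_m != 0:
        raise ValueError("permuted_m padding must be a multiple of mma_tiler_m")
    num_padding_tiles = (permuted_m - valid_m) // mma_tiler_m
    # Pad with an invalid group index; these tiles are never accessed.
    return tile_idx_to_group_idx + [int(-2e9)] * num_padding_tiles


# 3 valid tiles covering 384 rows, padded to 640 rows with 128-row tiles:
mapping = pad_tile_mapping([0, 0, 1], valid_m=384, permuted_m=640, mma_tiler_m=128)
print(len(mapping))  # 5 tiles: 3 valid + 2 padding
```

Without the divisibility check, a misaligned `permuted_m` would truncate in the integer division and leave the mapping shorter than the tile grid the scheduler walks.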
# Get SM count
if sm_count is None:
    sm_count = get_num_sm(a.device)

# Get or compile the kernel
gemm, _, _, _ = _get_compiled_kernel(
    permuted_m=permuted_m,
    n=n,
    k=k,
    num_experts=num_experts,
    ab_dtype_name=ab_dtype,
    sf_dtype_name=sf_dtype,
    c_dtype_name=c_dtype,
    sf_vec_size=sf_vec_size,
    mma_tiler_mn=mma_tiler_mn,
    cluster_shape_mn=cluster_shape_mn,
)

# Compute max active clusters (cached to avoid expensive HardwareInfo queries)
max_active_clusters = get_max_active_clusters(
    cluster_shape_mn[0] * cluster_shape_mn[1]
)
sm_count limit isn’t applied to max_active_clusters.
Respect the API contract by clamping the computed value.
🔧 Suggested fix
-    max_active_clusters = get_max_active_clusters(
-        cluster_shape_mn[0] * cluster_shape_mn[1]
-    )
+    max_active_clusters = min(
+        get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]),
+        sm_count,
+    )
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py` around lines 416
- 437, The computed max_active_clusters from
get_max_active_clusters(cluster_shape_mn[0] * cluster_shape_mn[1]) is not being
clamped by the sm_count limit; after computing max_active_clusters, clamp it
with sm_count (e.g., max_active_clusters = min(max_active_clusters, sm_count))
so the API contract is respected. Ensure sm_count is defined (it may be set via
get_num_sm(a.device)) before the clamp and apply this change right after the
call to get_max_active_clusters in the block that includes sm_count, get_num_sm,
and cluster_shape_mn.
# Cache for HardwareInfo - it's expensive to create on every call
_hardware_info_cache: "cutlass.utils.HardwareInfo | None" = None


def get_hardware_info() -> "cutlass.utils.HardwareInfo":
    """Get cached HardwareInfo singleton.

    HardwareInfo queries CUDA device capabilities, which can be expensive.
    This function caches the singleton to avoid repeated queries.
    """
    global _hardware_info_cache
    if _hardware_info_cache is None:
        _hardware_info_cache = cutlass.utils.HardwareInfo()
    return _hardware_info_cache


@functools.cache
def get_max_active_clusters(cluster_size: int) -> int:
    """Get max active clusters for a given cluster size (cached).

    Args:
        cluster_size: Product of cluster_shape_mn dimensions.

    Returns:
        Maximum number of active clusters supported by hardware.
    """
    return get_hardware_info().get_max_active_clusters(cluster_size)
🧩 Analysis chain
🏁 Script executed:
# First, verify the file exists and check its structure
find . -name "utils.py" -path "*/cute_dsl/*" -type fRepository: flashinfer-ai/flashinfer
Length of output: 192
🏁 Script executed:
# Check the actual content around lines 84-110
if [ -f "flashinfer/cute_dsl/utils.py" ]; then
wc -l flashinfer/cute_dsl/utils.py
echo "=== Lines 84-110 ==="
sed -n '84,110p' flashinfer/cute_dsl/utils.py
echo "=== Full context (wider range) ==="
sed -n '1,120p' flashinfer/cute_dsl/utils.py
fi
Repository: flashinfer-ai/flashinfer
Length of output: 4907
🏁 Script executed:
# Check imports at the top of the file
head -50 flashinfer/cute_dsl/utils.py
Repository: flashinfer-ai/flashinfer
Length of output: 1545
🏁 Script executed:
# Search for device-related handling in the entire flashinfer directory
rg "cuda\.current_device|torch\.cuda\.device|device.*index" flashinfer/ -A 2 -B 2 | head -100Repository: flashinfer-ai/flashinfer
Length of output: 5755
🏁 Script executed:
# Check if HardwareInfo is documented or used elsewhere
rg "HardwareInfo" flashinfer/ --type py | head -30Repository: flashinfer-ai/flashinfer
Length of output: 1225
🏁 Script executed:
# Search for usages of get_max_active_clusters function
rg "get_max_active_clusters" flashinfer/ -B 2 -A 2Repository: flashinfer-ai/flashinfer
Length of output: 5382
🏁 Script executed:
# Search for usages of get_hardware_info function
rg "get_hardware_info" flashinfer/ -B 2 -A 2Repository: flashinfer-ai/flashinfer
Length of output: 606
🏁 Script executed:
# Check if there are any multi-device tests or examples
rg "cuda\.device_count|multi.*device|device.*0.*1" flashinfer/ --type py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 2114
🏁 Script executed:
# Look at how cutlass HardwareInfo is used in the blockscaled files
grep -A 10 "get_max_active_clusters" flashinfer/cute_dsl/blockscaled_contiguous_grouped_gemm.py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 635
🏁 Script executed:
# Check if there's any device switching context around the usages
rg -B 5 "get_max_active_clusters" flashinfer/cute_dsl/blockscaled_gemm.py | head -40Repository: flashinfer-ai/flashinfer
Length of output: 647
🏁 Script executed:
# Check if blockscaled_gemm.py has any device-specific initialization
grep -n "device\|cuda" flashinfer/cute_dsl/blockscaled_gemm.py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 834
🏁 Script executed:
# Check if cutlass.utils.HardwareInfo is mentioned in CUTLASS documentation patterns
web_search
CUTLASS HardwareInfo device-aware per-GPU capabilities
Repository: flashinfer-ai/flashinfer
Length of output: 164
🌐 Web query:
CUTLASS cutlass.utils.HardwareInfo device-specific GPU capabilities
💡 Result:
cutlass.utils.HardwareInfo is a small helper class in CUTLASS’s Python CuTe DSL that queries device-specific GPU capabilities at runtime. It exposes methods to obtain/inspect low-level hardware properties and driver compatibility used by CUTLASS kernels, for example:
- get_device_multiprocessor_count(): number of SMs (multiprocessors). [1]
- get_l2_cache_size_in_bytes(): L2 cache size. [1]
- get_max_active_clusters(cluster_size): compute maximum active clusters for a given cluster size. [1]
- Utility/driver-check helpers (internal): _checkCudaErrors, _cuda_driver_version_ge / _lt, _cudaGetErrorEnum, plus small host/device helper functions used to probe properties. [1]
- Related utility in cutlass.utils: get_smem_capacity_in_bytes(compute_capability) for shared-memory sizing by compute capability. [1]
For implementation and repository context see the CUTLASS source/docs (CuTe DSL utils / HardwareInfo). [1][2]
Sources:
- CUTLASS cutlass.utils documentation (HardwareInfo, get_smem_capacity_in_bytes). [docs.nvidia.com CUTLASS cutlass.utils]. [2]
- NVIDIA CUTLASS GitHub repository (project context and supported architectures). [github.com/NVIDIA/cutlass].
🏁 Script executed:
# Check the actual context around "on current device" comments
grep -B 5 -A 5 "on current device" flashinfer/cute_dsl/blockscaled_gemm.pyRepository: flashinfer-ai/flashinfer
Length of output: 602
🏁 Script executed:
# Look for any tests or examples that use multiple devices
find flashinfer/tests -name "*.py" -o -name "*.py" | xargs grep -l "device.*0.*device.*1\|cuda:0.*cuda:1" 2>/dev/null | head -5
Repository: flashinfer-ai/flashinfer
Length of output: 119
🏁 Script executed:
# Check if there's a specific device passed anywhere in the call stack
rg "MaskedBatchedMatmulCuteDSL|get_max_active_clusters" flashinfer/ -B 3 -A 3 | grep -i "device\|cuda" | head -30

Repository: flashinfer-ai/flashinfer
Length of output: 434
**Include device identity in the cache key to support multi-device setups.**

The global singleton `HardwareInfo` instance and `@functools.cache` keyed only by `cluster_size` will return incorrect values when running on different GPUs with different capabilities, as `get_max_active_clusters` has no device parameter. Compare with `get_num_sm(device)` in the same file, which correctly handles device-specific caching. For multi-device support, either add a device parameter to both `get_hardware_info()` and `get_max_active_clusters()`, or maintain a per-device cache dictionary.
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/utils.py` around lines 84 - 110, The cached HardwareInfo
and get_max_active_clusters are device-agnostic and will return wrong values on
multi-GPU machines; update caching to be device-aware: change get_hardware_info
to accept an optional device identifier (or obtain current device internally)
and replace the single _hardware_info_cache with a per-device cache (e.g., dict
keyed by device id) for the HardwareInfo singleton; also update
get_max_active_clusters to include the device id in its cache key (remove or
replace the `@functools.cache` usage with a device-keyed cache or make the
function accept a device parameter so caching is per-device). Ensure you
reference and update the symbols _hardware_info_cache, get_hardware_info,
get_max_active_clusters, and the use of `@functools.cache` accordingly so each GPU
gets correct, device-specific values.
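As a sketch of the per-device caching the comment asks for — the `HardwareInfo` class below is a hypothetical stand-in for `cutlass.utils.HardwareInfo` (the SM-count value is a placeholder, not the real CUTLASS query), while `_hardware_info_cache`, `get_hardware_info`, and `get_max_active_clusters` mirror the review's symbol names:

```python
import functools

# Hypothetical stand-in for cutlass.utils.HardwareInfo; the real class
# queries SM count, L2 size, etc. for the active GPU.
class HardwareInfo:
    def __init__(self, device_id: int):
        self.device_id = device_id

    def get_max_active_clusters(self, cluster_size: int) -> int:
        # Placeholder: the real answer depends on the GPU's SM count.
        return 148 // cluster_size

# Per-device cache: one HardwareInfo per device id, instead of a global singleton.
_hardware_info_cache: dict = {}

def get_hardware_info(device_id: int = 0) -> HardwareInfo:
    if device_id not in _hardware_info_cache:
        _hardware_info_cache[device_id] = HardwareInfo(device_id)
    return _hardware_info_cache[device_id]

@functools.cache
def get_max_active_clusters(cluster_size: int, device_id: int = 0) -> int:
    # device_id is part of the memoization key, so a value computed on
    # GPU 0 is never reused for GPU 1.
    return get_hardware_info(device_id).get_max_active_clusters(cluster_size)
```

The key change is that both the singleton and the memoized query take the device id into account, matching how `get_num_sm(device)` already behaves.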
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@benchmarks/bench_moe_deepseek.py`:
- Around line 245-267: The benchmark is converting tensors inside the run()
closure which inflates timing; precompute the converted tensors once before
defining or entering run(): create ti_int = ti.to(torch.int) and w1_long =
inputs["w1_fp4"].contiguous().view(torch.long) and w2_long =
inputs["w2_fp4"].contiguous().view(torch.long) (and any other .to/.view
conversions like bias.float() or similar) and then call fused_topk_deepseek(...)
and cutlass_fused_moe(hidden_fp4, ti_int, tv, w1_long, w2_long, torch.bfloat16,
quant_scales=quant_scales, input_sf=input_sf, output=output) from inside run().
- Around line 693-697: The current parsing of args.num_tokens into tokens fails
when --num-tokens "" is passed because "".split(",") yields [''] and int('')
raises ValueError; update the tokens assignment logic (where tokens is computed
and args.num_tokens and TOKEN_COUNTS are referenced) to first validate/trim
args.num_tokens and skip empty segments before int conversion (e.g., check
args.num_tokens is not empty/whitespace and filter split parts with x.strip()
before calling int), or fall back to TOKEN_COUNTS; optionally add a clear error
message if parsing still fails.
🧹 Nitpick comments (7)
tests/moe/test_cute_dsl_fused_moe.py (3)
**36-50: Use `flashinfer.utils.is_sm100a_supported()` instead of custom GPU check.** Per coding guidelines, test implementations should use `flashinfer.utils` functions for GPU architecture checks to ensure consistency across the test suite.

♻️ Suggested refactor

```diff
-def is_blackwell():
-    """Check if running on Blackwell GPU (SM100+)."""
-    if not torch.cuda.is_available():
-        return False
-    props = torch.cuda.get_device_properties(0)
-    return props.major >= 10
+from flashinfer.utils import is_sm100a_supported

 # Skip decorators
 cute_dsl_available = pytest.mark.skipif(
     not is_cute_dsl_available(), reason="CuteDSL not available"
 )
 blackwell_required = pytest.mark.skipif(
-    not is_blackwell(), reason="Requires Blackwell GPU (SM100+)"
+    not is_sm100a_supported(), reason="Requires Blackwell GPU (SM100+)"
 )
```
**186-194: Set CUDA seed for reproducible GPU tensor generation.** `torch.manual_seed()` only affects the CPU RNG. Since tensors are created directly on CUDA, also set `torch.cuda.manual_seed(seed)` for deterministic test behavior.

♻️ Suggested fix

```diff
 torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
 sf_vec_size = 16
```
**469-485: Add `num_local_experts` parameter for consistency with other tests.** The other test methods explicitly pass `num_local_experts`. While the parameter has a default, including it here improves consistency and makes the test configuration explicit.

♻️ Suggested fix

```diff
 with autotune(True):
     result = cute_dsl_fused_moe_nvfp4(
         x=tensors["x"],
         x_sf=tensors["x_sf"],
         token_selected_experts=tensors["token_selected_experts"],
         token_final_scales=tensors["token_final_scales"],
         w1_weight=tensors["w1_weight"],
         w1_weight_sf=tensors["w1_weight_sf"],
         w1_alpha=tensors["w1_alpha"],
         fc2_input_scale=tensors["fc2_input_scale"],
         w2_weight=tensors["w2_weight"],
         w2_weight_sf=tensors["w2_weight_sf"],
         w2_alpha=tensors["w2_alpha"],
         num_experts=num_experts,
         top_k=top_k,
+        num_local_experts=num_experts,
     )
```

benchmarks/bench_moe_deepseek.py (4)
**46-53: Consider adding a comment explaining the `bpe` calculation.** The `bpe = 0.5 + 1/16` value represents bytes per element for the FP4 format (0.5 bytes for the 4-bit value + 0.0625 bytes of scale-factor overhead), but this isn't immediately obvious to readers.

Suggested clarification

```diff
 def calc_bw(n, ms):
+    # bpe: 0.5 bytes for FP4 data + 1/16 bytes for block scale factor (1 scale per 16 elements)
     bpe = 0.5 + 1 / 16
```
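For illustration, the bandwidth arithmetic behind `calc_bw` can be checked in isolation; the helper name `calc_bw_gbps` and the 1M-element example below are ours, not the benchmark's actual code:

```python
def calc_bw_gbps(num_bytes: float, ms: float) -> float:
    """Effective bandwidth in GB/s given bytes moved and elapsed milliseconds."""
    return num_bytes / (ms * 1e-3) / 1e9

# FP4 with one scale per 16-element block:
# 4 bits of data (0.5 bytes) + 1/16 byte of scale per element.
bpe = 0.5 + 1 / 16  # = 0.5625 bytes per element

# e.g. moving 1M FP4 elements (with scales) in 0.01 ms
n = 1_000_000
print(calc_bw_gbps(n * bpe, 0.01))  # → 56.25
```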
**56-63: Missing validation for tensor dimension divisibility.** The `interleave` function assumes `M` is divisible by `gs * 2` (default 128). If this precondition fails, the error from `view()` will be cryptic. Consider adding an assertion.

Suggested improvement

```diff
 def interleave(x, gs=64):
     M, K = x.shape[-2], x.shape[-1]
+    assert M % (gs * 2) == 0, f"M ({M}) must be divisible by {gs * 2}"
     return (
         x.view(*x.shape[:-2], 2, M // (gs * 2), gs, K)
```
**66-70: Missing CUDA random seed for full reproducibility.** `torch.manual_seed(42)` only sets the CPU RNG. For reproducible GPU tensor generation (e.g., `torch.randn(..., device="cuda")`), you should also set the CUDA seed.

Suggested fix

```diff
 def create_inputs(n, dev="cuda"):
     """Create inputs for all backends (CuteDSL, CUTLASS, TRTLLM)."""
     from flashinfer.fp4_quantization import fp4_quantize

     torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
     sv = 16
```
**375-584: Consider extracting shared input preparation logic to reduce duplication.** The `run_autotune` function duplicates significant preparation logic from `bench_cute_dsl`, `bench_cutlass`, and `bench_trtllm`. The `prep()` (lines 513-519) and `shuf()` (lines 528-537) helpers are also duplicated from `bench_trtllm` (lines 297-303 and 312-321).

For a WIP benchmark script this is acceptable, but extracting common preparation into shared helper functions would improve maintainability.

Additionally, lines 486-489 have the same tensor conversion overhead issue noted for `bench_cutlass`:

```python
ti.to(torch.int),                                # line 486
inputs["w1_fp4"].contiguous().view(torch.long),  # line 488
inputs["w2_fp4"].contiguous().view(torch.long),  # line 489
```

These should be pre-computed outside the autotune loop for accurate profiling.
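The hoisting pattern the comment recommends can be sketched without torch; `expensive_convert` below merely simulates the `.to`/`.view` conversions, and the kernel is a trivial stand-in:

```python
import time

def expensive_convert(x):
    # Stand-in for conversions like ti.to(torch.int) or
    # w.contiguous().view(torch.long) in the benchmark.
    time.sleep(0.001)
    return list(x)

def kernel(a, b):
    # Stand-in for the actual fused-MoE call being profiled.
    return sum(a) + sum(b)

data_a, data_b = range(10), range(10)

# Bad: conversions are re-done inside the timed closure, inflating timings.
def run_naive():
    return kernel(expensive_convert(data_a), expensive_convert(data_b))

# Good: hoist conversions out so the timed region measures only the kernel.
a_conv = expensive_convert(data_a)
b_conv = expensive_convert(data_b)

def run_hoisted():
    return kernel(a_conv, b_conv)
```

Both closures compute the same result; only `run_hoisted` excludes the conversion cost from the measured region.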
benchmarks/bench_moe_deepseek.py (outdated)

```python
tokens = (
    [int(x) for x in args.num_tokens.split(",")]
    if args.num_tokens
    else TOKEN_COUNTS
)
```
**Handle edge case for empty `--num-tokens` argument.**

If `--num-tokens ""` is passed, `"".split(",")` returns `['']`, and `int('')` raises a `ValueError`. Consider validating the input or using a more robust parsing approach.

Suggested fix

```diff
 tokens = (
-    [int(x) for x in args.num_tokens.split(",")]
-    if args.num_tokens
+    [int(x.strip()) for x in args.num_tokens.split(",") if x.strip()]
+    if args.num_tokens and args.num_tokens.strip()
     else TOKEN_COUNTS
 )
+if not tokens:
+    tokens = TOKEN_COUNTS
```

🤖 Prompt for AI Agents
In `@benchmarks/bench_moe_deepseek.py` around lines 693 - 697, The current parsing
of args.num_tokens into tokens fails when --num-tokens "" is passed because
"".split(",") yields [''] and int('') raises ValueError; update the tokens
assignment logic (where tokens is computed and args.num_tokens and TOKEN_COUNTS
are referenced) to first validate/trim args.num_tokens and skip empty segments
before int conversion (e.g., check args.num_tokens is not empty/whitespace and
filter split parts with x.strip() before calling int), or fall back to
TOKEN_COUNTS; optionally add a clear error message if parsing still fails.
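A hedged sketch of the robust parsing described above; the `parse_num_tokens` helper and the `TOKEN_COUNTS` default values are illustrative, not the benchmark's actual code:

```python
TOKEN_COUNTS = [1, 8, 64, 512]  # assumed default list for illustration

def parse_num_tokens(arg):
    """Parse a comma-separated --num-tokens value, tolerating blanks.

    Falls back to TOKEN_COUNTS for empty/whitespace input, skips empty
    segments, and raises a clear error on non-integer segments.
    """
    if not arg or not arg.strip():
        return TOKEN_COUNTS
    parts = [p.strip() for p in arg.split(",")]
    try:
        tokens = [int(p) for p in parts if p]
    except ValueError as e:
        raise SystemExit(f"--num-tokens must be comma-separated integers: {e}")
    return tokens or TOKEN_COUNTS
```

With this, `parse_num_tokens("")`, `parse_num_tokens(None)`, and `parse_num_tokens(",")` all fall back to the defaults, while `parse_num_tokens("1, 2,")` yields `[1, 2]`.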
…-ai#2445)

<!-- .github/pull_request_template.md -->

## 📌 Description

Bugfix to flashinfer-ai#2093. The fundamental issue is that we should not write to `jit_env.FLASHINFER_CSRC_DIR` (which might be read-only) for the fused-moe module; instead we should use `FLASHINFER_GEN_SRC_DIR`, which is supposed to be writable.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Bug Fixes**
  * JIT kernel generation now correctly uses user-writable cache directories instead of package directories, resolving compatibility issues in post-installation and read-only environments.
* **Documentation**
  * Updated JIT directory rules clarifying which directories are writable versus read-only. Added details on GPU auto-detection and CUDA architecture environment variable controls.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/logging.h (1)
**33-39: ⚠️ Potential issue | 🟠 Major — Avoid global `spdlog::set_level` side effects.**

`spdlog::set_level(lvl)` modifies the level for all registered loggers in the process, not just the "flashinfer" logger. This creates an unintended global side effect that can alter application logging behavior. Set the level on the logger instance instead.

Suggested fix (scope level change to library logger only)

```diff
 inline void set_log_level(spdlog::level::level_enum lvl) {
   auto fmt = "[%Y-%m-%d %H:%M:%S.%f] [%n] [%^%l%$] %v";
   auto console_sink = std::make_shared<spdlog::sinks::stdout_color_sink_mt>();
   console_sink->set_pattern(fmt);
   console_sink->set_level(lvl);
-  spdlog::set_default_logger(std::make_shared<spdlog::logger>("flashinfer", console_sink));
-  spdlog::set_level(lvl);
+  auto logger = std::make_shared<spdlog::logger>("flashinfer", console_sink);
+  logger->set_level(lvl);
+  spdlog::set_default_logger(logger);
 }
```
🧹 Nitpick comments (1)
flashinfer/jit/fused_moe.py (1)

**174: Consider: `rglob` may pick up stale generated files from previous builds.**

If the generation logic changes and produces fewer or differently-named `.generated.cu` files, stale files from a previous run could still be included in the build, potentially causing compilation errors or linking unexpected kernels.

Consider clearing the output directory before regenerating, or using a more deterministic approach that returns the exact list of generated files from `generate_gemm_operations()`.
/bot run

[CANCELING] Pipeline #43116439: canceled

/bot run

[CANCELING] Pipeline #43121276: canceled

/bot run

[CANCELING] Pipeline #43210969: canceled

/bot run

[FAILED] Pipeline #43303418: 12/20 passed

/bot run
<!-- .github/pull_request_template.md -->

## 📌 Description

This PR is a follow-up to PR #2398, integrating [TRTLLM PR 10987](NVIDIA/TensorRT-LLM#10987). It uses TMA.RED to improve effective memory bandwidth.

Perf data (tested on GB200):

| Tokens | CuteDSL (main) ms | CuteDSL (TMA.RED) ms | TRTLLM gen ms | CUTLASS ms | Winner | CuteDSL Speedup (main/TMA.RED) |
| --- | --- | --- | --- | --- | --- | --- |
| 1 | 0.064 | 0.064 | 0.053 | 0.099 | TRTLLM | 1.000x |
| 2 | 0.077 | 0.077 | 0.063 | 0.107 | TRTLLM | 1.000x |
| 4 | 0.096 | 0.096 | 0.085 | 0.131 | TRTLLM | 1.000x |
| 8 | 0.096 | 0.096 | 0.091 | 0.131 | TRTLLM | 1.000x |
| 16 | 0.101 | 0.102 | 0.103 | 0.138 | CuteDSL | 0.990x |
| 32 | 0.114 | 0.114 | 0.142 | 0.152 | CuteDSL | 1.000x |
| 62 | 0.122 | 0.122 | 0.183 | 0.163 | CuteDSL | 1.000x |
| 128 | 0.133 | 0.132 | 0.173 | 0.220 | CuteDSL | 1.008x |
| 256 | 0.142 | 0.138 | 0.220 | 0.251 | CuteDSL | 1.029x |
| 512 | 0.190 | 0.183 | 0.271 | 0.333 | CuteDSL | 1.038x |
| 1024 | 0.286 | 0.278 | 0.576 | 0.482 | CuteDSL | 1.029x |
| 2048 | 0.472 | 0.461 | 0.555 | 0.723 | CuteDSL | 1.024x |
| 4096 | 0.855 | 0.824 | 0.873 | 1.278 | CuteDSL | 1.038x |
| 8192 | 1.764 | 1.713 | 1.653 | 2.383 | TRTLLM | 1.030x |

## Summary by CodeRabbit

### Release Notes

* **New Features**
  * Introduced block-reduction optimization in MOE finalization kernels for improved performance on latest hardware.
  * Added support for block-wise reduction operations across multiple data types (BF16, FP32, FP16).
* **Performance**
  * Optimized GPU memory utilization by reducing unnecessary cross-device data transfers during computation.
📌 Description
CuteDSL FP4 MoE from TRTLLM, with fusion. Issue #2259.

The collective-fusion performance improvements are collected from TRTLLM PR 8880, PR 9288, and PR 9618.

This PR introduces two new MoE APIs: `cute_dsl_fused_moe_nvfp4` and `CuteDslMoEWrapper`. The two APIs are functionally equivalent: the `cute_dsl_fused_moe_nvfp4` function runs the operation directly, while the wrapper API splits tensor setup and execution to better support CUDA graphs. Using the wrapper API is recommended.

The PR also introduces autotune functionality for this function.
Performance data:
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
Summary by CodeRabbit
Release Notes
New Features
Documentation