refactor: refactoring cuda code to cute-dsl (part 1) #2428

yzh119 wants to merge 17 commits into flashinfer-ai:main
Conversation
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request marks the initial phase of refactoring the project's normalization kernels to leverage CuTe-DSL, aiming to enhance JIT compilation speed and overall kernel performance. It introduces a comprehensive set of CuTe-DSL-based normalization kernels and integrates them into the existing API with a conditional dispatch mechanism, paving the way for more efficient GPU computations.

Highlights
📝 Walkthrough

Adds CuTe-DSL normalization kernels (RMSNorm, QK RMSNorm, FP8-quantized and fused variants, LayerNorm), new norm utilities, and runtime dispatch between CUDA-JIT and CuTe-DSL; conditions imports/exports on CuTe-DSL availability and adjusts some TVM-FFI host-call argument representations.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant App as Application
    participant NormAPI as flashinfer.norm
    participant Dispatcher as Dispatcher
    participant CUDAJIT as CUDA JIT (gen_norm_module)
    participant CuTeDSL as CuTe-DSL Path
    participant Compiled as Compiled Kernel (TVM/ptx)
    participant GPU as GPU Device
    App->>NormAPI: call rmsnorm(...)
    NormAPI->>Dispatcher: check FLASHINFER_USE_CUDA_NORM / is_cute_dsl_available()
    alt CUDA JIT selected
        Dispatcher->>CUDAJIT: request/jit module
        CUDAJIT->>Compiled: produce kernel
        Compiled->>GPU: execute kernel
    else CuTe-DSL selected
        Dispatcher->>CuTeDSL: request compiled CuTe kernel
        CuTeDSL->>Compiled: produce kernel (TVM-FFI)
        Compiled->>GPU: execute kernel
    end
    GPU-->>Compiled: result
    Compiled-->>NormAPI: output tensor
    NormAPI-->>App: return
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
/bot run
Code Review
This pull request is a significant refactoring effort, moving normalization kernels from a custom CUDA JIT implementation to the CuTe-DSL. This is a commendable step towards improving performance and maintainability. The new flashinfer/cute_dsl/norm.py file is extensive and well-structured. My review has identified a few critical and high-severity issues that need to be addressed, including a bug in the FP8 quantization logic, incorrect API parameter naming, and inefficient shared memory usage. Once these issues are resolved, this will be a solid improvement.
flashinfer/cute_dsl/norm.py
Outdated
```ptx
.reg .b16 fp8_pair;
.reg .f32 zero;
mov.f32 zero, 0f00000000;
cvt.rn.satfinite.e4m3x2.f32 fp8_pair, zero, $0;
```
There is a bug in the PTX inline assembly. The cvt.rn.satfinite.e4m3x2.f32 instruction converts the second source operand and stores it in the upper half of the destination register. The st.global.b8 instruction then stores the lower 8 bits of the register. As written, this will store the converted zero value, not the intended val ($0).
To fix this, you should swap the source operands in the cvt instruction to place the converted value in the lower half of the fp8_pair register.
```diff
- cvt.rn.satfinite.e4m3x2.f32 fp8_pair, zero, $0;
+ cvt.rn.satfinite.e4m3x2.f32 fp8_pair, $0, zero;
```
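To make the reviewer's point concrete, here is a hypothetical pure-Python model of a packed two-to-one conversion followed by an 8-bit store: whichever source operand lands in the lower half of `fp8_pair` is the byte that `st.global.b8` actually writes. (Which `cvt` source maps to which register half should be verified against the PTX ISA for the target architecture; the bytes below stand in for already-converted e4m3 values.)

```python
def cvt_e4m3x2(src_hi: int, src_lo: int) -> int:
    """Model of a packed conversion: src_hi fills the upper byte of the
    16-bit destination register, src_lo fills the lower byte."""
    return ((src_hi & 0xFF) << 8) | (src_lo & 0xFF)

def st_b8(reg16: int) -> int:
    """Model of st.global.b8: only the lowest 8 bits reach memory."""
    return reg16 & 0xFF

zero, val = 0x00, 0x5A  # stand-ins for converted f32 values
# If val is placed in the lower half, the store writes val:
assert st_b8(cvt_e4m3x2(zero, val)) == val
# If the operands are swapped, the store writes zero instead:
assert st_b8(cvt_e4m3x2(val, zero)) == zero
```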
flashinfer/cute_dsl/norm.py
Outdated
```python
self.cols_per_tile_f32 * 4 * 2
+ self.cols_per_tile * elem_bytes * 2
+ 2 * self.num_warps * 4
```
The shared memory calculation for LayerNormKernel includes space for gamma/beta in the input dtype, but these shared memory tiles (sGamma, sBeta) are allocated and partitioned but never actually used in the kernel. The kernel reads gamma and beta values directly from the float32 shared memory tiles (sGamma_f32, sBeta_f32).
This wastes a significant amount of shared memory, which can negatively impact performance by reducing occupancy.
You should remove the allocation of sGamma and sBeta (lines 1483-1492) and their partitioning (lines 1565-1566) in the kernel method, and update this shared memory size calculation.
```diff
  self.cols_per_tile_f32 * 4 * 2
- + self.cols_per_tile * elem_bytes * 2
  + 2 * self.num_warps * 4
```
flashinfer/cute_dsl/norm.py
Outdated
```python
def tensor_api(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    B: int,
    N: int,
    eps: float,
    num_blocks: int,
) -> None:
    compiled_kernel(
        input,
        weight,
        output,
        Int32(B),
        Int32(N),
        Float32(eps),
        Int32(num_blocks),
    )
```
The enable_pdl parameter is not being passed to the compiled kernel. The qk_rmsnorm_cute function accepts enable_pdl, but it's lost because the tensor_api wrapper doesn't accept it and pass it to the compiled_kernel call.
This is a bug that prevents Programmatic Dependent Launch from being used with this kernel. You should update tensor_api to accept enable_pdl and pass it through. You'll also need to update the call to kernel in qk_rmsnorm_cute (line 2087) to pass this new argument.
```diff
 def tensor_api(
     input: torch.Tensor,
     weight: torch.Tensor,
     output: torch.Tensor,
     B: int,
     N: int,
     eps: float,
+    enable_pdl: bool,
     num_blocks: int,
 ) -> None:
     compiled_kernel(
         input,
         weight,
         output,
         Int32(B),
         Int32(N),
         Float32(eps),
+        enable_pdl,
         Int32(num_blocks),
     )
```
flashinfer/cute_dsl/norm.py
Outdated
```python
def predicate_k_3d(tXcX: cute.Tensor, limit: int) -> cute.Tensor:
    """Create predicate tensor for bounds checking (3D tensors).

    For 3D tensors after local_tile, the last coordinate [2] is the head_dim dimension.
    """
    tXpX = cute.make_rmem_tensor(
        cute.make_layout(
            (
                cute.size(tXcX, mode=[0, 1]),
                cute.size(tXcX, mode=[1]),
                cute.size(tXcX, mode=[2]),
            ),
            stride=(cute.size(tXcX, mode=[2]), 0, 1),
        ),
        cutlass.Boolean,
    )
    for rest_v in cutlass.range_constexpr(tXpX.shape[0]):
        for rest_k in cutlass.range_constexpr(tXpX.shape[2]):
            # For 3D tensor, coordinate[2] is the head_dim index
            tXpX[rest_v, 0, rest_k] = cute.elem_less(
                tXcX[(0, rest_v), 0, rest_k][2], limit
            )
    return tXpX
```
flashinfer/cute_dsl/norm.py
Outdated

```python
idX = cute.make_identity_tensor(mX.shape)
gX = cute.local_tile(mX, tiler_mn, (bidx, 0))
cute.local_tile(mY, tiler_mn, (bidx, 0))
```
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/norm.py`:
- Around line 2044-2088: The qk_rmsnorm_cute function accepts enable_pdl but
never forwards it to the kernel compilation (kernel created via
_get_compiled_qk_rmsnorm_kernel uses a hardcoded value); update qk_rmsnorm_cute
to pass the enable_pdl flag into _get_compiled_qk_rmsnorm_kernel (or else remove
enable_pdl from qk_rmsnorm_cute's signature) so the compiled kernel respects PDL
support — locate the _get_compiled_qk_rmsnorm_kernel call in qk_rmsnorm_cute and
change its arguments to include enable_pdl (and ensure any downstream kernel
invocation/signature matches this added parameter).
🧹 Nitpick comments (8)
flashinfer/cute_dsl/norm.py (8)
858-862: Dead code: `cute.local_tile(mY, ...)` result is unused. The result of `cute.local_tile(mY, tiler_mn, (bidx, 0))` at line 860 is not assigned to a variable. The FP8 output is stored using PTX scalar stores later (lines 920-922), which access `mY` directly with computed offsets. This call appears to be unnecessary.

♻️ Proposed fix

```diff
  idX = cute.make_identity_tensor(mX.shape)
  gX = cute.local_tile(mX, tiler_mn, (bidx, 0))
- cute.local_tile(mY, tiler_mn, (bidx, 0))
  cX = cute.local_tile(idX, tiler_mn, (bidx, 0))
```
1231-1236: Same issue: `cute.local_tile(mY, ...)` result is unused. Same dead code pattern as in `RMSNormQuantKernel`.

♻️ Proposed fix

```diff
  idX = cute.make_identity_tensor(mX.shape)
- cute.local_tile(mY, tiler_mn, (bidx, 0))
  gX = cute.local_tile(mX, tiler_mn, (bidx, 0))
  gR = cute.local_tile(mR, tiler_mn, (bidx, 0))
  cX = cute.local_tile(idX, tiler_mn, (bidx, 0))
```
1564-1567: Dead code: `partition_D` results are unused. The results of `thr_copy_load.partition_D(sGamma)` and `thr_copy_load.partition_D(sBeta)` are not assigned to variables. Gamma/beta are loaded directly from `sGamma_f32`/`sBeta_f32` at lines 1634-1635.

♻️ Proposed fix

```diff
- # Partitions for gamma/beta (input dtype)
- thr_copy_load.partition_D(sGamma)
- thr_copy_load.partition_D(sBeta)
  # Register fragments - initialize to zero for proper handling of out-of-bounds threads
```
2016-2042: Missing `@flashinfer_api` decorator on public API function. The `rmsnorm_cute` function is exported in `__all__` and thus part of the public API, but it lacks the `@flashinfer_api` decorator required by coding guidelines.

Additionally, the `enable_pdl` parameter is accepted but completely ignored. The kernel is compiled with a hardcoded `False` value at line 1764. This breaks the API contract with callers who expect PDL to be honored.

♻️ Proposed fix for decorator

```diff
+from ..api_logging import flashinfer_api
+
+@flashinfer_api
 def rmsnorm_cute(
     input: torch.Tensor,
```

As per coding guidelines: "Use `@flashinfer_api` decorator for debugging API calls."
2090-2113: Same issues: missing `@flashinfer_api` decorator and unused `enable_pdl`. `rmsnorm_quant_cute` has the same issues as `rmsnorm_cute`.

2116-2135: Same issues: missing `@flashinfer_api` decorator and unused `enable_pdl`. `fused_add_rmsnorm_cute` has the same issues.

2138-2170: Same issues: missing `@flashinfer_api` decorator and unused `enable_pdl`. `fused_add_rmsnorm_quant_cute` has the same issues.

2173-2192: Missing `@flashinfer_api` decorator. `layernorm_cute` is missing the `@flashinfer_api` decorator. Note that this function doesn't have an `enable_pdl` parameter, which is consistent since it doesn't expose PDL functionality.
flashinfer/cute_dsl/norm.py
Outdated
```python
def qk_rmsnorm_cute(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
    enable_pdl: bool = False,
) -> None:
    """CuTe DSL QKRMSNorm for 3D tensors [batch, heads, head_dim].

    Supports arbitrary stride - no need to call contiguous().
    Each warp processes one (batch, head) pair independently using warp-only reduction.

    Args:
        input: Input tensor of shape [batch_size, num_heads, head_dim].
            Last dimension must be contiguous (stride[-1] == 1).
        weight: Weight tensor of shape [head_dim].
        output: Output tensor (same shape as input).
        eps: Small constant for numerical stability.
        weight_bias: Bias added to weight (0 for standard RMSNorm, 1 for Gemma).
        enable_pdl: Enable Programmatic Dependent Launch for SM90+ GPUs.
    """
    assert input.dim() == 3, "QKRMSNorm expects 3D input [batch, heads, head_dim]"

    batch_size, num_heads, head_dim = input.shape
    M = batch_size * num_heads

    # Kernel configuration
    num_warps = 4

    # Calculate grid size based on SM count and estimated occupancy
    num_sms = get_num_sm(input.device)
    blocks_per_sm = 16  # Theoretical max for 128-thread blocks
    max_blocks = num_sms * blocks_per_sm
    needed_blocks = (M + num_warps - 1) // num_warps
    num_blocks = min(max_blocks, needed_blocks)

    dtype_str = _torch_dtype_to_str(input.dtype)
    kernel = _get_compiled_qk_rmsnorm_kernel(
        dtype_str, head_dim, weight_bias, num_warps
    )

    # Pass 3D tensors directly - kernel handles arbitrary stride
    kernel(input, weight, output, batch_size, num_heads, eps, num_blocks)
```
enable_pdl parameter is accepted but not effectively used.
The qk_rmsnorm_cute function accepts enable_pdl but the compiled kernel at line 1764 uses a hardcoded enable_pdl=False. The kernel supports PDL (lines 617-618, 747-748), but the parameter isn't being passed through during compilation.
🔧 Proposed fix to support PDL
To properly support PDL, the compilation would need to be done at runtime with the actual enable_pdl value, or the parameter should be removed from the API signature if PDL is intentionally disabled for CuTe-DSL kernels.
If PDL is intentionally disabled, consider removing the parameter:
```diff
 def qk_rmsnorm_cute(
     input: torch.Tensor,
     weight: torch.Tensor,
     output: torch.Tensor,
     eps: float = 1e-6,
     weight_bias: float = 0.0,
-    enable_pdl: bool = False,
 ) -> None:
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def qk_rmsnorm_cute(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
) -> None:
    """CuTe DSL QKRMSNorm for 3D tensors [batch, heads, head_dim].

    Supports arbitrary stride - no need to call contiguous().
    Each warp processes one (batch, head) pair independently using warp-only reduction.

    Args:
        input: Input tensor of shape [batch_size, num_heads, head_dim].
            Last dimension must be contiguous (stride[-1] == 1).
        weight: Weight tensor of shape [head_dim].
        output: Output tensor (same shape as input).
        eps: Small constant for numerical stability.
        weight_bias: Bias added to weight (0 for standard RMSNorm, 1 for Gemma).
    """
    assert input.dim() == 3, "QKRMSNorm expects 3D input [batch, heads, head_dim]"

    batch_size, num_heads, head_dim = input.shape
    M = batch_size * num_heads

    # Kernel configuration
    num_warps = 4

    # Calculate grid size based on SM count and estimated occupancy
    num_sms = get_num_sm(input.device)
    blocks_per_sm = 16  # Theoretical max for 128-thread blocks
    max_blocks = num_sms * blocks_per_sm
    needed_blocks = (M + num_warps - 1) // num_warps
    num_blocks = min(max_blocks, needed_blocks)

    dtype_str = _torch_dtype_to_str(input.dtype)
    kernel = _get_compiled_qk_rmsnorm_kernel(
        dtype_str, head_dim, weight_bias, num_warps
    )

    # Pass 3D tensors directly - kernel handles arbitrary stride
    kernel(input, weight, output, batch_size, num_heads, eps, num_blocks)
```
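The grid-size heuristic above caps the launch at an occupancy-derived block budget. As a standalone sketch of the same arithmetic (the SM count and row counts below are hypothetical example values):

```python
def compute_num_blocks(M: int, num_sms: int, num_warps: int = 4,
                       blocks_per_sm: int = 16) -> int:
    """Mirror of the grid-size heuristic: one warp per (batch, head) row,
    capped by an estimated per-device block budget."""
    max_blocks = num_sms * blocks_per_sm
    needed_blocks = (M + num_warps - 1) // num_warps  # ceil(M / num_warps)
    return min(max_blocks, needed_blocks)

# Small workload: the row count, not the device, is the limit.
assert compute_num_blocks(M=100, num_sms=132) == 25      # ceil(100 / 4)
# Large workload: capped at num_sms * blocks_per_sm.
assert compute_num_blocks(M=1_000_000, num_sms=132) == 2112
```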
🧰 Tools
🪛 Ruff (0.14.14)
2050-2050: Unused function argument: enable_pdl
(ARG001)
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 2044 - 2088, The qk_rmsnorm_cute
function accepts enable_pdl but never forwards it to the kernel compilation
(kernel created via _get_compiled_qk_rmsnorm_kernel uses a hardcoded value);
update qk_rmsnorm_cute to pass the enable_pdl flag into
_get_compiled_qk_rmsnorm_kernel (or else remove enable_pdl from
qk_rmsnorm_cute's signature) so the compiled kernel respects PDL support —
locate the _get_compiled_qk_rmsnorm_kernel call in qk_rmsnorm_cute and change
its arguments to include enable_pdl (and ensure any downstream kernel
invocation/signature matches this added parameter).
[FAILED] Pipeline #42732703: 1/20 passed

/bot run
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/norm.py`:
- Around line 1875-1882: The parameter enable_pdl in rmsnorm_cute is unused and
triggers ARG001; explicitly mark it as intentionally unused by adding a no-op
assignment (e.g., _ = enable_pdl) or a targeted noqa comment inside rmsnorm_cute
to show API-parity intent, and apply the same change to the other wrapper
functions mentioned in the review so each unused enable_pdl is acknowledged
rather than left unused.
- Around line 1875-2051: The public CuTe-DSL wrapper functions (rmsnorm_cute,
qk_rmsnorm_cute, rmsnorm_quant_cute, fused_add_rmsnorm_cute,
fused_add_rmsnorm_quant_cute, layernorm_cute) need the `@flashinfer_api` decorator
added and the decorator imported from the project’s standard utilities; add a
single import for flashinfer_api near other imports and prepend `@flashinfer_api`
above each of these function definitions so all public entry points are traced
for API-call logging (keep existing signatures and bodies unchanged).
- Around line 371-379: Rename the unused kernel parameter M to _M in the kernel
signatures to silence Ruff ARG002 (e.g., change the argument name in
cute_dsl.norm.LayerNormKernel.kernel and the other flagged kernels
RMSNormQuantKernel.kernel, FusedAddRMSNormKernel.kernel,
FusedAddRMSNormQuantKernel.kernel); update the parameter name only in the
function signature (or alternatively add a targeted "# noqa: ARG002" comment) so
the intent is clear and linters stop reporting the unused argument.
- Around line 1188-1200: The scalar FP8 store computes out_offset assuming
row-major contiguous layout (out_offset = bidx * H + idx), which fails for
non-contiguous mY; update the store in the block that calls
cvt_and_store_f32_to_e4m3/get_ptr_as_int64 to compute the correct linear offset
using the output tensor's stride (e.g., out_offset = bidx * mY.stride[0] + idx)
or mirror the non-quantized kernels by using CuTe's local_tile/partition_D logic
(as in FusedAddRMSNormKernel) to derive the physical address; ensure you
reference mY.stride and preserve idx calculation so cvt_and_store_f32_to_e4m3
receives the correct out_ptr for any layout.
- Around line 835-847: The FP8 store currently computes out_offset as bidx * H +
idx which assumes a contiguous row stride; update the offset calculation to use
the actual row stride (sym_row_stride_y) so stores respect arbitrary output
tensor strides—replace the use of H in out_offset with sym_row_stride_y (i.e.,
compute out_offset = bidx * sym_row_stride_y + idx) in the block that calls
get_ptr_as_int64(mY, Int32(out_offset)) and cvt_and_store_f32_to_e4m3; ensure
any alternative tiled layout approach mirrors how inputs are handled so the
store remains stride-aware.
flashinfer/cute_dsl/norm.py
Outdated
```python
def kernel(
    self,
    mX: cute.Tensor,
    mW: cute.Tensor,
    mY: cute.Tensor,
    M: Int32,
    eps: Float32,
    tv_layout: cute.Layout,
    tiler_mn: cute.Shape,
```
Silence unused M kernel args to keep Ruff clean.
Ruff reports ARG002 for M in kernel signatures. Since M is not used inside kernels, rename it to _M (or add a targeted # noqa: ARG002) to document intent and satisfy lint. Apply the same pattern to the other kernel methods flagged by Ruff (RMSNormQuantKernel.kernel, FusedAddRMSNormKernel.kernel, FusedAddRMSNormQuantKernel.kernel, LayerNormKernel.kernel).
♻️ Example fix (apply similarly to other kernels)

```diff
-    M: Int32,
+    _M: Int32,
```

🧰 Tools
🪛 Ruff (0.14.14)
376-376: Unused method argument: M
(ARG002)
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 371 - 379, Rename the unused kernel
parameter M to _M in the kernel signatures to silence Ruff ARG002 (e.g., change
the argument name in cute_dsl.norm.LayerNormKernel.kernel and the other flagged
kernels RMSNormQuantKernel.kernel, FusedAddRMSNormKernel.kernel,
FusedAddRMSNormQuantKernel.kernel); update the parameter name only in the
function signature (or alternatively add a targeted "# noqa: ARG002" comment) so
the intent is clear and linters stop reporting the unused argument.
flashinfer/cute_dsl/norm.py
Outdated
```python
col_offset = tidx * vec_size
for v in cutlass.range_constexpr(num_vec_blocks):
    for e in cutlass.range_constexpr(vec_size):
        idx = col_offset + v * threads_per_row * vec_size + e
        if idx < H:
            # Clamp and convert - use flat index for register tensor
            flat_idx = v * vec_size + e
            clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX))
            clamped = min(clamped, Float32(FLOAT8_E4M3_MAX))
            # Use PTX to convert and store FP8 byte
            out_offset = bidx * H + idx
            out_ptr = get_ptr_as_int64(mY, Int32(out_offset))
            cvt_and_store_f32_to_e4m3(clamped, out_ptr)
```
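The two-step clamp above saturates the value to the representable e4m3 range before conversion. In plain Python, using 448.0 (the largest finite value in the e4m3 format) as `FLOAT8_E4M3_MAX`:

```python
FLOAT8_E4M3_MAX = 448.0  # largest finite e4m3 value

def clamp_for_e4m3(x: float) -> float:
    """Saturate to the e4m3 range, mirroring the kernel's two-step clamp."""
    clamped = max(x, -FLOAT8_E4M3_MAX)
    return min(clamped, FLOAT8_E4M3_MAX)

assert clamp_for_e4m3(1e6) == 448.0
assert clamp_for_e4m3(-1e6) == -448.0
assert clamp_for_e4m3(1.5) == 1.5
```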
Use stride-aware offset calculation for FP8 output store.
Line 96: out_offset = bidx * H + idx assumes contiguous row stride equal to H, which breaks for arbitrary-stride outputs declared in the tensor layout (stride = sym_row_stride_y). Replace with out_offset = bidx * sym_row_stride_y + idx, or apply consistent tiled layout to the output tensor (as done for input) to automatically respect strides.
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 835 - 847, The FP8 store currently
computes out_offset as bidx * H + idx which assumes a contiguous row stride;
update the offset calculation to use the actual row stride (sym_row_stride_y) so
stores respect arbitrary output tensor strides—replace the use of H in
out_offset with sym_row_stride_y (i.e., compute out_offset = bidx *
sym_row_stride_y + idx) in the block that calls get_ptr_as_int64(mY,
Int32(out_offset)) and cvt_and_store_f32_to_e4m3; ensure any alternative tiled
layout approach mirrors how inputs are handled so the store remains
stride-aware.
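The failure mode is easy to demonstrate on the host side with a non-contiguous (padded) output buffer; `row_stride` here plays the role of `mY.stride[0]` / `sym_row_stride_y` in the kernel:

```python
H = 5           # logical row width (hidden size)
row_stride = 8  # physical row pitch of a padded, non-contiguous buffer
rows = 3
buf = [0] * (rows * row_stride)

for bidx in range(rows):
    for idx in range(H):
        # Contiguous assumption would be: out_offset = bidx * H + idx  (wrong here)
        # Stride-aware offset, as the review suggests:
        out_offset = bidx * row_stride + idx
        buf[out_offset] = bidx * 10 + idx

# Row 1 starts at offset row_stride == 8, not at offset H == 5:
assert buf[1 * row_stride + 0] == 10
# The contiguous formula would have landed row 1 inside row 0's padding:
assert (1 * H + 0) != (1 * row_stride + 0)
```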
flashinfer/cute_dsl/norm.py
Outdated
```python
col_offset = tidx * vec_size
for v in cutlass.range_constexpr(num_vec_blocks):
    for e in cutlass.range_constexpr(vec_size):
        idx = col_offset + v * threads_per_row * vec_size + e
        if idx < H:
            # Clamp and convert - use flat index for register tensor
            flat_idx = v * vec_size + e
            clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX))
            clamped = min(clamped, Float32(FLOAT8_E4M3_MAX))
            # Use PTX to convert and store FP8 byte
            out_offset = bidx * H + idx
            out_ptr = get_ptr_as_int64(mY, Int32(out_offset))
            cvt_and_store_f32_to_e4m3(clamped, out_ptr)
```
The FP8 scalar store path assumes row-major contiguous output layout.
The hardcoded out_offset = bidx * H + idx breaks non-contiguous outputs. Use CuTe's local_tile and partition_D like the non-quantized kernels (e.g., FusedAddRMSNormKernel), or query the output tensor's stride and compute out_offset = bidx * mY.stride[0] + idx.
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 1188 - 1200, The scalar FP8 store
computes out_offset assuming row-major contiguous layout (out_offset = bidx * H
+ idx), which fails for non-contiguous mY; update the store in the block that
calls cvt_and_store_f32_to_e4m3/get_ptr_as_int64 to compute the correct linear
offset using the output tensor's stride (e.g., out_offset = bidx * mY.stride[0]
+ idx) or mirror the non-quantized kernels by using CuTe's
local_tile/partition_D logic (as in FusedAddRMSNormKernel) to derive the
physical address; ensure you reference mY.stride and preserve idx calculation so
cvt_and_store_f32_to_e4m3 receives the correct out_ptr for any layout.
flashinfer/cute_dsl/norm.py
Outdated
    def rmsnorm_cute(
        input: torch.Tensor,
        weight: torch.Tensor,
        out: torch.Tensor,
        eps: float = 1e-6,
        weight_bias: float = 0.0,
        enable_pdl: bool = False,
    ) -> None:
enable_pdl is unused in most wrappers.
Ruff flags ARG001 for these functions. If the parameter is only for API parity, make the intent explicit (e.g., _ = enable_pdl or a targeted # noqa: ARG001). Otherwise, plumb it through once those kernels support PDL.
✅ Example (apply similarly to other wrappers)
 def rmsnorm_cute(
     input: torch.Tensor,
     weight: torch.Tensor,
     out: torch.Tensor,
     eps: float = 1e-6,
     weight_bias: float = 0.0,
     enable_pdl: bool = False,
 ) -> None:
+    _ = enable_pdl  # reserved for future PDL support
     """CuTe DSL RMSNorm implementation.

Also applies to: 1949-1957, 1975-1982, 1997-2006
🧰 Tools
🪛 Ruff (0.14.14)
1881-1881: Unused function argument: enable_pdl
(ARG001)
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 1875 - 1882, The parameter
enable_pdl in rmsnorm_cute is unused and triggers ARG001; explicitly mark it as
intentionally unused by adding a no-op assignment (e.g., _ = enable_pdl) or a
targeted noqa comment inside rmsnorm_cute to show API-parity intent, and apply
the same change to the other wrapper functions mentioned in the review so each
unused enable_pdl is acknowledged rather than left unused.
flashinfer/cute_dsl/norm.py
Outdated
    def rmsnorm_cute(
        input: torch.Tensor,
        weight: torch.Tensor,
        out: torch.Tensor,
        eps: float = 1e-6,
        weight_bias: float = 0.0,
        enable_pdl: bool = False,
    ) -> None:
        """CuTe DSL RMSNorm implementation.

        Supports arbitrary stride - no need to call contiguous().
        Last dimension must be contiguous (stride[-1] == 1).
        """
        H = input.shape[-1]
        if input.dim() == 3:
            M = input.shape[0] * input.shape[1]
            input_2d = input.view(M, H)
            out_2d = out.view(M, H)
        else:
            M = input.shape[0]
            input_2d = input
            out_2d = out

        dtype_str = _torch_dtype_to_str(input.dtype)
        kernel = _get_compiled_rmsnorm_kernel(dtype_str, H, weight_bias)
        kernel(input_2d, weight, out_2d, M, eps)


    def qk_rmsnorm_cute(
        input: torch.Tensor,
        weight: torch.Tensor,
        output: torch.Tensor,
        eps: float = 1e-6,
        weight_bias: float = 0.0,
        enable_pdl: bool = False,
    ) -> None:
        """CuTe DSL QKRMSNorm for 3D tensors [batch, heads, head_dim].

        Supports arbitrary stride - no need to call contiguous().
        Each warp processes one (batch, head) pair independently using warp-only reduction.

        Args:
            input: Input tensor of shape [batch_size, num_heads, head_dim].
                Last dimension must be contiguous (stride[-1] == 1).
            weight: Weight tensor of shape [head_dim].
            output: Output tensor (same shape as input).
            eps: Small constant for numerical stability.
            weight_bias: Bias added to weight (0 for standard RMSNorm, 1 for Gemma).
            enable_pdl: Enable Programmatic Dependent Launch for SM90+ GPUs.
        """
        assert input.dim() == 3, "QKRMSNorm expects 3D input [batch, heads, head_dim]"

        batch_size, num_heads, head_dim = input.shape
        M = batch_size * num_heads

        # Kernel configuration
        num_warps = 4

        # Calculate grid size based on SM count and estimated occupancy
        num_sms = get_num_sm(input.device)
        blocks_per_sm = 16  # Theoretical max for 128-thread blocks
        max_blocks = num_sms * blocks_per_sm
        needed_blocks = (M + num_warps - 1) // num_warps
        num_blocks = min(max_blocks, needed_blocks)

        dtype_str = _torch_dtype_to_str(input.dtype)
        kernel = _get_compiled_qk_rmsnorm_kernel(
            dtype_str, head_dim, weight_bias, num_warps
        )

        # Pass 3D tensors directly - kernel handles arbitrary stride
        kernel(input, weight, output, batch_size, num_heads, eps, num_blocks)


    def rmsnorm_quant_cute(
        out: torch.Tensor,
        input: torch.Tensor,
        weight: torch.Tensor,
        scale: float,
        eps: float = 1e-6,
        weight_bias: float = 0.0,
        enable_pdl: bool = False,
    ) -> None:
        """CuTe DSL RMSNorm + FP8 quantization implementation.

        Supports arbitrary stride - no need to call contiguous().
        Last dimension must be contiguous (stride[-1] == 1).
        """
        H = input.shape[-1]
        M = input.shape[0]

        dtype_str = _torch_dtype_to_str(input.dtype)
        out_dtype_str = _torch_dtype_to_str(out.dtype)
        kernel = _get_compiled_rmsnorm_quant_kernel(
            dtype_str, out_dtype_str, H, weight_bias
        )
        kernel(out, input, weight, M, scale, eps)


    def fused_add_rmsnorm_cute(
        input: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        eps: float = 1e-6,
        weight_bias: float = 0.0,
        enable_pdl: bool = False,
    ) -> None:
        """CuTe DSL Fused Add + RMSNorm implementation.

        Supports arbitrary stride - no need to call contiguous().
        Last dimension must be contiguous (stride[-1] == 1).
        """
        H = input.shape[-1]
        M = input.shape[0]

        dtype_str = _torch_dtype_to_str(input.dtype)
        kernel = _get_compiled_fused_add_rmsnorm_kernel(dtype_str, H, weight_bias)
        kernel(input, residual, weight, M, eps)


    def fused_add_rmsnorm_quant_cute(
        out: torch.Tensor,
        input: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        scale: float,
        eps: float = 1e-6,
        weight_bias: float = 0.0,
        enable_pdl: bool = False,
    ) -> None:
        """CuTe DSL Fused Add + RMSNorm + FP8 quantization implementation.

        Supports arbitrary stride - no need to call contiguous().
        Last dimension must be contiguous (stride[-1] == 1).
        """
        H = input.shape[-1]
        M = input.shape[0]

        dtype_str = _torch_dtype_to_str(input.dtype)
        out_dtype_str = _torch_dtype_to_str(out.dtype)
        kernel = _get_compiled_fused_add_rmsnorm_quant_kernel(
            dtype_str, out_dtype_str, H, weight_bias
        )
        kernel(
            out,
            input,
            residual,
            weight,
            M,
            scale,
            eps,
        )


    def layernorm_cute(
        out: torch.Tensor,
        input: torch.Tensor,
        gamma: torch.Tensor,
        beta: torch.Tensor,
        eps: float = 1e-6,
    ) -> None:
        """CuTe DSL LayerNorm implementation.

        Supports arbitrary stride - no need to call contiguous().
        Last dimension must be contiguous (stride[-1] == 1).
        """
        H = input.shape[-1]
        M = input.shape[0]

        dtype_str = _torch_dtype_to_str(input.dtype)
        gamma_dtype_str = _torch_dtype_to_str(gamma.dtype)
        kernel = _get_compiled_layernorm_kernel(dtype_str, gamma_dtype_str, H)
        kernel(out, input, gamma, beta, M, eps)
🛠️ Refactor suggestion | 🟠 Major
Add @flashinfer_api on public CuTe-DSL API wrappers.
The public Python entry points here (e.g., rmsnorm_cute, qk_rmsnorm_cute, rmsnorm_quant_cute, fused_add_rmsnorm_cute, fused_add_rmsnorm_quant_cute, layernorm_cute) should be decorated for API-call tracing. Please apply @flashinfer_api (and import it from the standard location used elsewhere in the repo).
As per coding guidelines: Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats).
🧰 Tools
🪛 Ruff (0.14.14)
1881-1881: Unused function argument: enable_pdl
(ARG001)
1909-1909: Unused function argument: enable_pdl
(ARG001)
1956-1956: Unused function argument: enable_pdl
(ARG001)
1981-1981: Unused function argument: enable_pdl
(ARG001)
2005-2005: Unused function argument: enable_pdl
(ARG001)
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 1875 - 2051, The public CuTe-DSL
wrapper functions (rmsnorm_cute, qk_rmsnorm_cute, rmsnorm_quant_cute,
fused_add_rmsnorm_cute, fused_add_rmsnorm_quant_cute, layernorm_cute) need the
`@flashinfer_api` decorator added and the decorator imported from the project’s
standard utilities; add a single import for flashinfer_api near other imports
and prepend `@flashinfer_api` above each of these function definitions so all
public entry points are traced for API-call logging (keep existing signatures
and bodies unchanged).
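A minimal sketch of what the decorated wrappers could look like. The decorator below is a hypothetical stand-in for the real `flashinfer_api`, which should be imported from wherever the rest of the repo imports it; only the wiring pattern is shown:

```python
import functools
import os

def flashinfer_api(fn):
    # Stand-in: the real decorator logs API calls when FLASHINFER_LOGLEVEL >= 1.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if int(os.environ.get("FLASHINFER_LOGLEVEL", "0")) >= 1:
            print(f"[flashinfer] calling {fn.__name__}")
        return fn(*args, **kwargs)
    return wrapper

@flashinfer_api
def rmsnorm_cute(input, weight, out, eps=1e-6, weight_bias=0.0, enable_pdl=False):
    """Body unchanged; only the decorator is added."""
    return None

# functools.wraps preserves the wrapped function's metadata for tooling.
assert rmsnorm_cute.__name__ == "rmsnorm_cute"
```

Applying the same one-line decoration to each public wrapper keeps signatures and bodies unchanged while making calls traceable.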
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/norm.py`:
- Around line 468-471: The unconditional imports of rmsnorm_fp4quant and
add_rmsnorm_fp4quant from .cute_dsl can raise when cutlass/cute_dsl isn't
available; wrap these imports in a try/except (or check
cute_dsl.is_cute_dsl_available()) and fall back to setting rmsnorm_fp4quant =
None and add_rmsnorm_fp4quant = None so the module can still import and the
existing _USE_CUDA_NORM/CUDA fallback logic can decide at runtime; update any
usage sites to handle None accordingly.
flashinfer/norm.py
Outdated
 # CuTe-DSL fused RMSNorm + FP4 Quantization kernels
-# These require CuTe-DSL to be available and SM100+ (Blackwell) GPUs
-try:
-    from .cute_dsl import rmsnorm_fp4quant, add_rmsnorm_fp4quant
-except ImportError:
-    # CuTe-DSL not available
-    rmsnorm_fp4quant = None  # type: ignore[misc,assignment]
-    add_rmsnorm_fp4quant = None  # type: ignore[misc,assignment]
+# These require SM100+ (Blackwell) GPUs
+from .cute_dsl import rmsnorm_fp4quant as rmsnorm_fp4quant
+from .cute_dsl import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
Wrap FP4 imports in try/except to prevent import failures on systems without cutlass.
The rmsnorm_fp4quant and add_rmsnorm_fp4quant imports are unconditional, but cute_dsl/__init__.py only exports them when is_cute_dsl_available() returns True (i.e., when cutlass is installed). Without error handling, importing flashinfer.norm will fail on systems without cutlass, even though the module provides CUDA-based fallbacks via _USE_CUDA_NORM. This breaks backward compatibility.
Suggested pattern for graceful degradation
 # CuTe-DSL fused RMSNorm + FP4 Quantization kernels
 # These require SM100+ (Blackwell) GPUs
-from .cute_dsl import rmsnorm_fp4quant as rmsnorm_fp4quant
-from .cute_dsl import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
+try:
+    from .cute_dsl import rmsnorm_fp4quant as rmsnorm_fp4quant
+    from .cute_dsl import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
+except ImportError:
+    rmsnorm_fp4quant = None
+    add_rmsnorm_fp4quant = None

🤖 Prompt for AI Agents
In `@flashinfer/norm.py` around lines 468 - 471, The unconditional imports of
rmsnorm_fp4quant and add_rmsnorm_fp4quant from .cute_dsl can raise when
cutlass/cute_dsl isn't available; wrap these imports in a try/except (or check
cute_dsl.is_cute_dsl_available()) and fall back to setting rmsnorm_fp4quant =
None and add_rmsnorm_fp4quant = None so the module can still import and the
existing _USE_CUDA_NORM/CUDA fallback logic can decide at runtime; update any
usage sites to handle None accordingly.
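The graceful-degradation pattern suggested above can be demonstrated with a self-contained sketch; the module name below is deliberately nonexistent so the `ImportError` branch fires, and `dispatch` is a hypothetical caller, not a real FlashInfer function:

```python
try:
    from _nonexistent_cute_dsl import rmsnorm_fp4quant  # type: ignore
except ImportError:
    rmsnorm_fp4quant = None  # CuTe-DSL (cutlass) not installed

def dispatch(use_fp4: bool) -> str:
    # Callers gate on availability at runtime instead of failing at import time.
    if use_fp4 and rmsnorm_fp4quant is None:
        raise RuntimeError("FP4 path requires CuTe-DSL (cutlass) to be installed")
    return "fp4" if use_fp4 else "cuda-fallback"

assert rmsnorm_fp4quant is None
assert dispatch(False) == "cuda-fallback"
```

This keeps `import flashinfer.norm` working on machines without cutlass while still surfacing a clear error the moment the FP4 path is actually requested.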
|
wanna ask you @yzh119 about the reason we put these things all in cute_dsl |
We should categorize kernels by functionalities, not sources. All kernels inside For this specific PR, let me make |
|
cc @kahyunnam |
|
[CANCELED] Pipeline #42752005: canceled |
@yzh119 @cyx-6 were there any updates for this tvm-ffi investigation? (Just curious, since I was also working on a cute-dsl implementation of RoPE) |
♻️ Duplicate comments (5)
flashinfer/norm/kernels/rmsnorm.py (2)
871-877: ⚠️ Potential issue | 🔴 Critical — Fix runtime argument order in rmsnorm_quant_cute.
Line 876 passes (out, input, weight, ...), but the compiled kernel expects (input, weight, out, ...). This can produce invalid reads/writes or immediate failure.
Suggested fix
-    kernel(out, input, weight, M, scale, eps)
+    kernel(input, weight, out, M, scale, eps)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/norm/kernels/rmsnorm.py` around lines 871 - 877, The runtime call to the compiled kernel uses the wrong argument order: change the kernel invocation in rmsnorm_quant_cute (where kernel is obtained via _get_compiled_rmsnorm_quant_kernel) from kernel(out, input, weight, M, scale, eps) to kernel(input, weight, out, M, scale, eps) so the compiled kernel receives (input, weight, out, ...) as expected and avoids invalid reads/writes.
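As a side note on prevention, keyword arguments would have turned this class of bug into an immediate error; a hypothetical sketch (names are illustrative stand-ins for the compiled kernel entry point):

```python
def compiled_rmsnorm_quant(input, weight, out, M, scale, eps):
    # Stand-in for the compiled kernel; just echoes which buffers it received.
    return ("input", input, "out", out)

# A positional call with the wrong order still "works", silently treating the
# output buffer as the input:
bad = compiled_rmsnorm_quant("OUT_BUF", "W", "IN_BUF", 1, 1.0, 1e-6)
assert bad[1] == "OUT_BUF"

# A keyword call is order-independent and self-documenting:
good = compiled_rmsnorm_quant(input="IN_BUF", weight="W", out="OUT_BUF",
                              M=1, scale=1.0, eps=1e-6)
assert good == ("input", "IN_BUF", "out", "OUT_BUF")
```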
562-564: ⚠️ Potential issue | 🟠 Major — FP8 quant store still assumes contiguous output rows.
Line 619 computes bidx * H + idx, which ignores output row stride and breaks non-contiguous output tensors.
Suggested fix
     idX = cute.make_identity_tensor(mX.shape)
     gX = cute.local_tile(mX, tiler_mn, (bidx, 0))
+    gY = cute.local_tile(mY, tiler_mn, (bidx, 0))
     cX = cute.local_tile(idX, tiler_mn, (bidx, 0))
@@
-    out_offset = bidx * H + idx
-    out_ptr = get_ptr_as_int64(mY, Int32(out_offset))
+    out_ptr = get_ptr_as_int64(gY, Int32(idx))
     cvt_and_store_f32_to_e4m3(clamped, out_ptr)
Also applies to: 619-621
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/norm/kernels/rmsnorm.py` around lines 562 - 564, The FP8 quant store logic assumes contiguous output rows by computing global row index as bidx * H + idx (breaking for non-contiguous outputs); update the index computation to use the actual output row stride instead of H and use the tensor's stride/shape information when creating identity and local tiles (references: idX from cute.make_identity_tensor, gX/cX from cute.local_tile, and the bidx/idx calculations). Replace any hard-coded H multipliers with a derived row_stride (or call that returns the logical row offset) so gX/cX and the FP8 store address calculations work with non-contiguous output tensors. Ensure the same change is applied at the other occurrence noted (lines 619–621).
flashinfer/norm/kernels/layernorm.py (2)
366-372: ⚠️ Potential issue | 🟠 Major — gamma_dtype_str is accepted in the compile path but not honored by the kernel.
Line 370 parses gamma_dtype, but Line 371 constructs LayerNormKernel(dtype, H) and the kernel data path is fixed to Float32 gamma/beta. Non-fp32 gamma/beta inputs can break behavior.
Minimal safety fix (enforce fp32 gamma/beta at API boundary)
 def layernorm_cute(
@@
-    dtype_str = _torch_dtype_to_str(input.dtype)
-    gamma_dtype_str = _torch_dtype_to_str(gamma.dtype)
-    kernel = _get_compiled_layernorm_kernel(dtype_str, gamma_dtype_str, H, enable_pdl)
+    dtype_str = _torch_dtype_to_str(input.dtype)
+    if gamma.dtype != torch.float32 or beta.dtype != torch.float32:
+        raise TypeError(
+            "layernorm_cute currently requires gamma/beta to be torch.float32"
+        )
+    kernel = _get_compiled_layernorm_kernel(dtype_str, "float32", H, enable_pdl)
Also applies to: 431-434
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/norm/kernels/layernorm.py` around lines 366 - 372, The function is parsing gamma_dtype_str but not using it—LayerNormKernel is being constructed with only dtype and H so gamma/beta are forced to fp32; update the compile path to propagate gamma_dtype into the kernel creation or enforce fp32 at the API boundary. Specifically, change the call that constructs LayerNormKernel(dtype, H) to include the parsed gamma_dtype (e.g., LayerNormKernel(dtype, gamma_dtype, H) or equivalent constructor) and update the LayerNormKernel constructor/signature to accept and store gamma/beta dtype, or alternatively validate gamma_dtype_str early and raise/convert to fp32 before proceeding; repeat the same fix for the second occurrence around the other compile path where gamma_dtype_str is parsed (the similar block at the other reported location).
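The "enforce fp32 at the API boundary" option can be illustrated with a pure-Python sketch; the function name and dtype strings below are illustrative, not the real API:

```python
def check_affine_dtypes(gamma_dtype: str, beta_dtype: str) -> str:
    """Validate gamma/beta dtypes before they reach the compiled kernel."""
    if gamma_dtype != "float32" or beta_dtype != "float32":
        raise TypeError(
            "layernorm_cute currently requires gamma/beta to be torch.float32"
        )
    return "float32"  # dtype string forwarded to the kernel cache key

# fp32 parameters pass through; anything else fails fast with a clear message
# instead of silently hitting the fixed-Float32 kernel data path.
assert check_affine_dtypes("float32", "float32") == "float32"
```

Failing fast here is cheaper than threading an unused dtype through the compile cache until the kernel actually honors it.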
90-100: ⚠️ Potential issue | 🟠 Major — Remove dead gamma/beta scratch allocations to cut SMEM pressure.
Lines 196-206 and Lines 270-279 allocate/partition sGamma/sBeta and tXrGamma/tXrBeta, but they are never read. Line 90 still accounts for this unused footprint.
Suggested cleanup
 def _smem_size_in_bytes(self) -> int:
-    # Shared memory for:
-    # - gamma/beta f32 tiles: cols_per_tile_f32 * 4 * 2
-    # - gamma/beta input dtype tiles: cols_per_tile * elem_bytes * 2
-    # - reduction buffers: 2 * num_warps * 4
-    elem_bytes = self.dtype.width // 8
-    return (
-        self.cols_per_tile_f32 * 4 * 2
-        + self.cols_per_tile * elem_bytes * 2
-        + 2 * self.num_warps * 4
-    )
+    # Shared memory for:
+    # - gamma/beta f32 tiles: cols_per_tile_f32 * 4 * 2
+    # - reduction buffers: 2 * num_warps * 4
+    return self.cols_per_tile_f32 * 4 * 2 + 2 * self.num_warps * 4
@@
-    # Shared memory tiles for gamma, beta in input dtype (for matching shape with x)
-    sGamma = smem.allocate_tensor(
-        mX.element_type,
-        cute.make_ordered_layout(tiler_mn, order=(1, 0)),
-        byte_alignment=16,
-    )
-    sBeta = smem.allocate_tensor(
-        mX.element_type,
-        cute.make_ordered_layout(tiler_mn, order=(1, 0)),
-        byte_alignment=16,
-    )
@@
-    # Partitions for gamma/beta (input dtype)
-    thr_copy_load.partition_D(sGamma)
-    thr_copy_load.partition_D(sBeta)
@@
-    tXrGamma = cute.make_rmem_tensor(tXgX.shape, mX.element_type)
-    tXrBeta = cute.make_rmem_tensor(tXgX.shape, mX.element_type)
     tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type))
-    tXrGamma.store(cute.zeros_like(tXrGamma, dtype=mX.element_type))
-    tXrBeta.store(cute.zeros_like(tXrBeta, dtype=mX.element_type))
Also applies to: 196-206, 270-279
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/norm/kernels/layernorm.py` around lines 90 - 100, The shared-memory size calculation in _smem_size_in_bytes() is overstating required SMEM because scratch allocations for sGamma/sBeta and tXrGamma/tXrBeta are never used; remove those unused allocations and corresponding size from the computation. Update _smem_size_in_bytes() to exclude the gamma/beta f32 and input-dtype tile terms (the expressions using cols_per_tile_f32 and cols_per_tile * elem_bytes) and adjust any code that partitions/allocates sGamma, sBeta, tXrGamma, tXrBeta (remove those variables/partitions) so only the actual reduction buffer space (2 * num_warps * 4) and any remaining live tiles are accounted for. Ensure references to cols_per_tile_f32, cols_per_tile, elem_bytes, and num_warps remain correct for the retained allocations.
flashinfer/norm/kernels/fused_add_rmsnorm.py (1)
339-341: ⚠️ Potential issue | 🟠 Major — Respect output row stride in the FP8 scalar store path.
Line 415 hardcodes contiguous-row addressing (bidx * H + idx). That breaks the arbitrary-stride contract for non-contiguous outputs.
Suggested fix
     gX = cute.local_tile(mX, tiler_mn, (bidx, 0))
     gR = cute.local_tile(mR, tiler_mn, (bidx, 0))
+    gY = cute.local_tile(mY, tiler_mn, (bidx, 0))
     cX = cute.local_tile(idX, tiler_mn, (bidx, 0))
@@
-    out_offset = bidx * H + idx
-    out_ptr = get_ptr_as_int64(mY, Int32(out_offset))
+    out_ptr = get_ptr_as_int64(gY, Int32(idx))
     cvt_and_store_f32_to_e4m3(clamped, out_ptr)
Also applies to: 415-417
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/norm/kernels/fused_add_rmsnorm.py` around lines 339 - 341, The FP8 scalar store path currently computes target element addresses using a hardcoded contiguous-row formula (bidx * H + idx) which ignores arbitrary output row strides; update the scalar store logic to use the output buffer's actual row stride instead of H — derive the row stride from the tile/descriptor associated with idX/cX (or use the same stride used when creating gX/gR/cX via cute.local_tile/tiler_mn) and compute the base offset as bidx * out_row_stride + idx; apply the same fix to the other occurrences in the FP8 path (the three places corresponding to the contiguous address calculation).
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
flashinfer/__init__.py
flashinfer/cute_dsl/__init__.py
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
flashinfer/norm/kernels/fused_add_rmsnorm.py
flashinfer/norm/kernels/layernorm.py
flashinfer/norm/kernels/rmsnorm.py
flashinfer/norm/utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
- flashinfer/__init__.py
- flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
|
/bot run |
|
[FAILED] Pipeline #45155756: 7/20 passed |
|
/bot run |
|
[CANCELING] Pipeline #45179095: canceled |
Thanks @DevashishLal-CB. I absorbed #2459 in a newly added commit a7690ec by essentially bringing over all the changes. Is this what you were looking for? |
|
/bot run |
|
[CANCELING] Pipeline #45268078: canceled |
|
/bot run |
Yup thanks, LGTM! |
|
[SUCCESS] Pipeline #45342593: 10/20 passed |
📌 Description
We prioritize DSL over CUDA for kernel development because of its faster JIT compilation.
This PR is the first in a series that refactors the simple normalization kernels to CuTe-DSL.
The CUDA code should be ready to remove once end-to-end testing is complete.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
I have installed pre-commit by running pip install pre-commit (or used your preferred method).
I have installed the hooks with pre-commit install.
I have run the hooks with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
New and existing unit tests pass locally (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
Chores
Bug Fixes