Skip to content
7 changes: 5 additions & 2 deletions flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,11 @@
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import rmsnorm as rmsnorm

from .norm import rmsnorm_fp4quant as rmsnorm_fp4quant
from .norm import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
try:
    # FP4-quantized RMSNorm variants are optional: they require the
    # nvidia-cutlass-dsl package, so failure to import is not an error.
    from .norm import (
        rmsnorm_fp4quant as rmsnorm_fp4quant,
        add_rmsnorm_fp4quant as add_rmsnorm_fp4quant,
    )
except (ImportError, AttributeError):
    pass
Comment on lines +112 to +116
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Expose quantized norm APIs at the package level.

Line 97-101 exports rmsnorm and fused_add_rmsnorm, but the new quantized variants (rmsnorm_quant, fused_add_rmsnorm_quant) from flashinfer.norm are still missing at the top level. Consider exporting them here so flashinfer.rmsnorm_quant works consistently.

✅ Suggested export additions
 from .norm import fused_add_rmsnorm as fused_add_rmsnorm
+from .norm import fused_add_rmsnorm_quant as fused_add_rmsnorm_quant
 from .norm import layernorm as layernorm
 from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
 from .norm import gemma_rmsnorm as gemma_rmsnorm
 from .norm import rmsnorm as rmsnorm
+from .norm import rmsnorm_quant as rmsnorm_quant

As per coding guidelines: Export new operations in `flashinfer/__init__.py` to make them available at package level.

🤖 Prompt for AI Agents
In `@flashinfer/__init__.py` around lines 103 - 107, The package-level exports for
the quantized norm variants are missing: import rmsnorm_fp4quant and
add_rmsnorm_fp4quant from flashinfer.norm (already attempted in the try block)
and then assign them to the public names used elsewhere (e.g., expose
rmsnorm_fp4quant as rmsnorm_quant and add_rmsnorm_fp4quant as
fused_add_rmsnorm_quant) so flashinfer.rmsnorm_quant and
flashinfer.fused_add_rmsnorm_quant resolve; update the try block in
flashinfer.__init__.py to perform these assignments (keep the existing
ImportError/AttributeError handling).

from .page import append_paged_kv_cache as append_paged_kv_cache
from .page import append_paged_mla_kv_cache as append_paged_mla_kv_cache
from .page import get_batch_indices_positions as get_batch_indices_positions
Expand Down
53 changes: 47 additions & 6 deletions flashinfer/cute_dsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,19 @@
This module provides high-performance GPU kernels implemented using NVIDIA CuTe-DSL.
"""

from .utils import is_cute_dsl_available, make_ptr, get_cutlass_dtype, get_num_sm
import importlib.util

# Conditionally import CuTe-DSL kernels

def is_cute_dsl_available() -> bool:
    """Return ``True`` when the optional CuTe-DSL (``cutlass``) package is importable.

    Checks ``cutlass`` before ``cutlass.cute`` so the submodule probe is
    skipped entirely when the parent package is absent.
    """
    for module_name in ("cutlass", "cutlass.cute"):
        if importlib.util.find_spec(module_name) is None:
            return False
    return True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Decorate and cache is_cute_dsl_available().

Line 24-28 defines a public API (exported in __all__) but it isn’t logged or cached. Adding @flashinfer_api and @functools.cache avoids repeated module discovery and aligns with the API logging/caching policy.

✅ Suggested update
-import importlib.util
+import functools
+import importlib.util
+
+from ..api_logging import flashinfer_api
@@
-def is_cute_dsl_available() -> bool:
+@functools.cache
+@flashinfer_api
+def is_cute_dsl_available() -> bool:
     return (
         importlib.util.find_spec("cutlass") is not None
         and importlib.util.find_spec("cutlass.cute") is not None
     )

As per coding guidelines: Python API functions should use @functools.cache decorator for module caching to avoid recompilation; Use @flashinfer_api decorator on Python functions for API logging with crash-safe input capture before execution.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
import importlib.util
# Conditionally import CuTe-DSL kernels
def is_cute_dsl_available() -> bool:
return (
importlib.util.find_spec("cutlass") is not None
and importlib.util.find_spec("cutlass.cute") is not None
)
import functools
import importlib.util
from ..api_logging import flashinfer_api
@functools.cache
@flashinfer_api
def is_cute_dsl_available() -> bool:
return (
importlib.util.find_spec("cutlass") is not None
and importlib.util.find_spec("cutlass.cute") is not None
)
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/__init__.py` around lines 21 - 28, The
is_cute_dsl_available() function should be decorated with `@flashinfer_api` and
`@functools.cache` to enable API logging and cache the module discovery; update
the top imports to import functools (or functools.cache) and import the
flashinfer_api decorator (or from its module) so you can apply `@flashinfer_api`
and `@functools.cache` directly above def is_cute_dsl_available to avoid repeated
find_spec calls and ensure API logging.



# Conditionally import CuTe-DSL kernels (including utils which requires cutlass)
if is_cute_dsl_available():
from .utils import make_ptr, get_cutlass_dtype, get_num_sm
from .blockscaled_gemm import (
grouped_gemm_nt_masked,
Sm100BlockScaledPersistentDenseGemmKernel,
Expand All @@ -36,16 +45,35 @@
AddRMSNormFP4QuantKernel,
)

# Backwards-compatible re-exports from flashinfer.norm.kernels submodule
from ..norm.kernels import (
# Kernel classes
RMSNormKernel,
QKRMSNormKernel,
RMSNormQuantKernel,
FusedAddRMSNormKernel,
FusedAddRMSNormQuantKernel,
LayerNormKernel,
# Python API functions
rmsnorm_cute,
qk_rmsnorm_cute,
rmsnorm_quant_cute,
fused_add_rmsnorm_cute,
fused_add_rmsnorm_quant_cute,
layernorm_cute,
)

__all__ = [
# Utils (always available)
# Always available
"is_cute_dsl_available",
"make_ptr",
"get_cutlass_dtype",
"get_num_sm",
]

if is_cute_dsl_available():
__all__ += [
# Utils (require cutlass)
"make_ptr",
"get_cutlass_dtype",
"get_num_sm",
# Blockscaled GEMM
"grouped_gemm_nt_masked",
"Sm100BlockScaledPersistentDenseGemmKernel",
Expand All @@ -56,4 +84,17 @@
# Add + RMSNorm + FP4 Quantization
"add_rmsnorm_fp4quant",
"AddRMSNormFP4QuantKernel",
# Norm kernels (CuTe DSL) - backwards-compatible re-exports
"RMSNormKernel",
"QKRMSNormKernel",
"RMSNormQuantKernel",
"FusedAddRMSNormKernel",
"FusedAddRMSNormQuantKernel",
"LayerNormKernel",
"rmsnorm_cute",
"qk_rmsnorm_cute",
"rmsnorm_quant_cute",
"fused_add_rmsnorm_cute",
"fused_add_rmsnorm_quant_cute",
"layernorm_cute",
]
155 changes: 125 additions & 30 deletions flashinfer/norm.py → flashinfer/norm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,53 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

FlashInfer Normalization Kernels
================================

This package provides high-performance normalization kernels:

- RMSNorm: Root Mean Square Normalization
- LayerNorm: Layer Normalization
- Fused Add + RMSNorm: Combined residual add and RMSNorm
- Quantized variants with FP8/FP4 output
"""

import functools
import os
from typing import Optional

import torch

from .api_logging import flashinfer_api
from .jit.norm import gen_norm_module
from .utils import device_support_pdl, register_custom_op, register_fake_op
from ..api_logging import flashinfer_api
from ..utils import device_support_pdl, register_custom_op, register_fake_op

# Always import gen_norm_module for JIT warmup and CUDA fallback
from ..jit.norm import gen_norm_module

# Use CUDA JIT implementation instead of CuTe DSL (for debugging/fallback)
# Also fallback to CUDA JIT if nvidia-cutlass-dsl is not installed
_USE_CUDA_NORM = os.environ.get("FLASHINFER_USE_CUDA_NORM", "0") == "1"

@functools.cache
def get_norm_module():
return gen_norm_module().build_and_load()
if not _USE_CUDA_NORM:
try:
from .kernels import (
rmsnorm_cute,
qk_rmsnorm_cute,
rmsnorm_quant_cute,
fused_add_rmsnorm_cute,
fused_add_rmsnorm_quant_cute,
layernorm_cute,
)
except (ImportError, AttributeError):
# nvidia-cutlass-dsl not installed or incompatible version
_USE_CUDA_NORM = True

if _USE_CUDA_NORM:

@functools.cache
def get_norm_module():
return gen_norm_module().build_and_load()


@flashinfer_api
Expand Down Expand Up @@ -60,16 +92,14 @@ def rmsnorm(
output: torch.Tensor
Normalized tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size).
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
if out is None:
out = torch.empty_like(input)
_rmsnorm(out, input, weight, eps, enable_pdl)
_rmsnorm_impl(out, input, weight, eps, enable_pdl)
return out


@register_custom_op("flashinfer::rmsnorm", mutates_args=("out",))
def _rmsnorm(
def _rmsnorm_impl(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
Expand All @@ -78,11 +108,21 @@ def _rmsnorm(
) -> None:
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
get_norm_module().rmsnorm(out, input, weight, eps, enable_pdl)
if _USE_CUDA_NORM:
get_norm_module().rmsnorm(out, input, weight, eps, enable_pdl)
else:
if input.dim() == 3:
qk_rmsnorm_cute(
input, weight, out, eps, weight_bias=0.0, enable_pdl=enable_pdl
)
else:
rmsnorm_cute(
input, weight, out, eps, weight_bias=0.0, enable_pdl=enable_pdl
)


@register_fake_op("flashinfer::rmsnorm")
def _rmsnorm_fake(
def _rmsnorm_impl_fake(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
Expand Down Expand Up @@ -129,7 +169,12 @@ def rmsnorm_quant(
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
get_norm_module().rmsnorm_quant(out, input, weight, scale, eps, enable_pdl)
if _USE_CUDA_NORM:
get_norm_module().rmsnorm_quant(out, input, weight, scale, eps, enable_pdl)
else:
rmsnorm_quant_cute(
out, input, weight, scale, eps, weight_bias=0.0, enable_pdl=enable_pdl
)


@register_fake_op("flashinfer::rmsnorm_quant")
Expand Down Expand Up @@ -177,7 +222,12 @@ def fused_add_rmsnorm(
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
get_norm_module().fused_add_rmsnorm(input, residual, weight, eps, enable_pdl)
if _USE_CUDA_NORM:
get_norm_module().fused_add_rmsnorm(input, residual, weight, eps, enable_pdl)
else:
fused_add_rmsnorm_cute(
input, residual, weight, eps, weight_bias=0.0, enable_pdl=enable_pdl
)


@register_fake_op("flashinfer::fused_add_rmsnorm")
Expand Down Expand Up @@ -232,9 +282,21 @@ def fused_add_rmsnorm_quant(
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
get_norm_module().fused_add_rmsnorm_quant(
out, input, residual, weight, scale, eps, enable_pdl
)
if _USE_CUDA_NORM:
get_norm_module().fused_add_rmsnorm_quant(
out, input, residual, weight, scale, eps, enable_pdl
)
else:
fused_add_rmsnorm_quant_cute(
out,
input,
residual,
weight,
scale,
eps,
weight_bias=0.0,
enable_pdl=enable_pdl,
)


@register_fake_op("flashinfer::fused_add_rmsnorm_quant")
Expand Down Expand Up @@ -281,16 +343,14 @@ def gemma_rmsnorm(
output: torch.Tensor
Gemma Normalized tensor, shape (batch_size, hidden_size).
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
if out is None:
out = torch.empty_like(input)
_gemma_rmsnorm(out, input, weight, eps, enable_pdl)
_gemma_rmsnorm_impl(out, input, weight, eps, enable_pdl)
return out


@register_custom_op("flashinfer::gemma_rmsnorm", mutates_args=("out",))
def _gemma_rmsnorm(
def _gemma_rmsnorm_impl(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
Expand All @@ -299,11 +359,21 @@ def _gemma_rmsnorm(
) -> None:
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
get_norm_module().gemma_rmsnorm(out, input, weight, eps, enable_pdl)
if _USE_CUDA_NORM:
get_norm_module().gemma_rmsnorm(out, input, weight, eps, enable_pdl)
else:
if input.dim() == 3:
qk_rmsnorm_cute(
input, weight, out, eps, weight_bias=1.0, enable_pdl=enable_pdl
)
else:
rmsnorm_cute(
input, weight, out, eps, weight_bias=1.0, enable_pdl=enable_pdl
)


@register_fake_op("flashinfer::gemma_rmsnorm")
def _gemma_rmsnorm_fake(
def _gemma_rmsnorm_impl_fake(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
Expand Down Expand Up @@ -348,7 +418,14 @@ def gemma_fused_add_rmsnorm(
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
get_norm_module().gemma_fused_add_rmsnorm(input, residual, weight, eps, enable_pdl)
if _USE_CUDA_NORM:
get_norm_module().gemma_fused_add_rmsnorm(
input, residual, weight, eps, enable_pdl
)
else:
fused_add_rmsnorm_cute(
input, residual, weight, eps, weight_bias=1.0, enable_pdl=enable_pdl
)


@register_fake_op("flashinfer::gemma_fused_add_rmsnorm")
Expand Down Expand Up @@ -388,7 +465,10 @@ def layernorm(
Layer Normalized tensor, shape (batch_size, hidden_size). Same dtype as input.
"""
out = torch.empty_like(input)
get_norm_module().layernorm(out, input, gemma, beta, eps)
if _USE_CUDA_NORM:
get_norm_module().layernorm(out, input, gemma, beta, eps)
else:
layernorm_cute(out, input, gemma, beta, eps)
return out


Expand All @@ -404,10 +484,25 @@ def _layernorm_fake(


# CuTe-DSL fused RMSNorm + FP4 Quantization kernels
# These require CuTe-DSL to be available and SM100+ (Blackwell) GPUs
# These require SM100+ (Blackwell) GPUs and nvidia-cutlass-dsl
try:
from .cute_dsl import rmsnorm_fp4quant, add_rmsnorm_fp4quant
from ..cute_dsl import rmsnorm_fp4quant as rmsnorm_fp4quant
from ..cute_dsl import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
except ImportError:
# CuTe-DSL not available
rmsnorm_fp4quant = None # type: ignore[misc,assignment]
add_rmsnorm_fp4quant = None # type: ignore[misc,assignment]
# nvidia-cutlass-dsl not installed, these functions will not be available
pass


# Public API exports
__all__ = [
# JIT module generator (always available)
"gen_norm_module",
# Public APIs
"rmsnorm",
"rmsnorm_quant",
"fused_add_rmsnorm",
"fused_add_rmsnorm_quant",
"gemma_rmsnorm",
"gemma_fused_add_rmsnorm",
"layernorm",
]
Loading
Loading