
Add cute-dsl backends to mxfp[8,4]_quantization for future refactor#2443

Open
bkryu wants to merge 12 commits into flashinfer-ai:main from bkryu:cute-dsl-mxfp8

Conversation

@bkryu
Collaborator

@bkryu bkryu commented Jan 30, 2026

📌 Description

This PR adds CuTe-DSL backend support for the MXFP8 and MXFP4 quantization kernels, as an alternative to the JIT-compiled CUDA backends.

Key changes:

  • Add CuTe-DSL MXFP8 and MXFP4 quantization kernels
  • Reorganize quantization module structure for better maintainability
  • Add benchmarks and unit tests for backend comparison
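The backend selection this PR introduces can be sketched as follows. This is an illustrative stand-in, not the exact flashinfer API: the private helper and kernel names below are assumptions, and the kernels are replaced by placeholders so the sketch runs without a GPU.

```python
# Sketch of backend dispatch with an availability check for the optional
# CuTe-DSL dependency. Names prefixed with "_" are hypothetical stand-ins.
def _cute_dsl_available() -> bool:
    try:
        import cutlass  # noqa: F401  # provided by the nvidia-cutlass-dsl package
        return True
    except ImportError:
        return False

def mxfp8_quantize(x, is_sf_swizzled_layout=True, backend="cuda"):
    if backend == "cute-dsl":
        if not _cute_dsl_available():
            raise RuntimeError("backend='cute-dsl' requires nvidia-cutlass-dsl")
        return _mxfp8_quantize_cute_dsl(x, is_sf_swizzled_layout)
    if backend == "cuda":
        return _mxfp8_quantize_cuda(x, is_sf_swizzled_layout)
    raise ValueError(f"unsupported backend: {backend!r}")

# placeholder kernels so the sketch is self-contained
def _mxfp8_quantize_cuda(x, swizzled):
    return ("cuda", x)

def _mxfp8_quantize_cute_dsl(x, swizzled):
    return ("cute-dsl", x)
```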

File Structure Reorganization
Quantization files are now organized in flashinfer/quantization/:

flashinfer/quantization/
├── __init__.py                    # Package exports
├── fp4_quantization.py            # MXFP4 public API
├── fp8_quantization.py            # MXFP8 public API  
├── packbits.py                    # Utility functions
├── quantization_cute_dsl_utils.py # Shared PTX intrinsics
└── kernels/
    ├── __init__.py                # Kernel exports (EXPERIMENTAL)
    ├── mxfp4_quantize.py          # MXFP4 CuTe-DSL kernel
    └── mxfp8_quantize.py          # MXFP8 CuTe-DSL kernel

Performance
The CuTe-DSL kernels compare favorably to their CUDA counterparts:

  • mxfp4_quantization: geomean 12x speedup; beats the CUDA backend in every case in bench_mxfp4_quantize_backend_comparison.py
  • mxfp8_quantization: geomean ~1.3x speedup; beats the CUDA backend in every case in bench_mxfp8_quantize_backend_comparison.py
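The geomean figures above are geometric means of the per-case speedup ratios; a quick stdlib sketch of how such a summary number is computed:

```python
import math

def geomean(speedups):
    # geometric mean: exp of the mean of the logs, equivalent to
    # (x1 * x2 * ... * xn) ** (1/n) but numerically safer for many factors
    return math.exp(sum(math.log(x) for x in speedups) / len(speedups))

# hypothetical per-shape speedups from a backend-comparison sweep
assert abs(geomean([1.2, 1.3, 1.4]) - (1.2 * 1.3 * 1.4) ** (1 / 3)) < 1e-9
```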

Expand below for performance heatmaps:

The CuTe-DSL backend outperforms the CUDA backend in every case benchmarked in bench_mxfp8_quantize_backend_comparison.py.

BF16 input; Swizzled cases. > 1.0 means CuTe DSL is faster
sm100_mxfp8_swizzled_bfloat16

BF16 input; Linear cases. > 1.0 means CuTe DSL is faster
sm100_mxfp8_linear_bfloat16

BF16 input; Swizzled cases. Annotated values are achieved TB/s
sm100_mxfp8_bandwidth_linear_bfloat16

BF16 input; Linear cases. Annotated values are achieved TB/s
sm100_mxfp8_bandwidth_swizzled_bfloat16

The CuTe-DSL backend outperforms the CUDA backend in every case benchmarked in bench_mxfp4_quantize_backend_comparison.py.

BF16 input; Swizzled cases. > 1.0 means CuTe DSL is faster
sm100_mxfp4_comparison_bfloat16

BF16 input; Swizzled cases. Annotated values are achieved TB/s
sm100_mxfp4_bandwidth_bfloat16

🔍 Related Issues

#2496

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added CuTe-DSL backend support for MXFP8 and MXFP4 quantization operations, available alongside the existing CUDA backend.
    • Added comprehensive benchmarking scripts for comparing quantization backend performance and correctness.
  • Tests

    • Extended quantization test coverage with backend-specific validation and benchmarking utilities.

@coderabbitai
Contributor

coderabbitai bot commented Jan 30, 2026

📝 Walkthrough

Walkthrough

This PR reorganizes FlashInfer's quantization infrastructure by moving FP4 and FP8 quantization modules into a dedicated flashinfer/quantization/ subpackage and adds support for the "cute-dsl" backend alongside existing CUDA backends. It introduces CuTe-DSL kernel implementations for both MXFP4 and MXFP8 quantization with caching mechanisms, extends tests and benchmarks to validate both backends, and maintains backward compatibility through wrapper modules.
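The backward-compatibility wrappers mentioned above follow a plain re-export pattern: the old module path keeps working because it re-binds the symbols now defined in the new subpackage. A minimal stdlib illustration (module names mirror the PR's layout, but the placeholder implementation is a stand-in):

```python
import types

# stand-in for the new location, flashinfer/quantization/fp8_quantization.py
impl = types.ModuleType("flashinfer.quantization.fp8_quantization")
impl.mxfp8_quantize = lambda x: ("quantized", x)  # placeholder implementation

# stand-in for the old location, flashinfer/fp8_quantization.py, now a wrapper
wrapper = types.ModuleType("flashinfer.fp8_quantization")
wrapper.mxfp8_quantize = impl.mxfp8_quantize  # re-export keeps old imports working
```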

Changes

Cohort / File(s) Summary
Module Reorganization
flashinfer/fp4_quantization.py, flashinfer/fp8_quantization.py
Converted to backward-compatibility wrappers; now re-export symbols from flashinfer.quantization.fp4_quantization and flashinfer.quantization.fp8_quantization respectively. Internal implementations moved to new locations.
Core Quantization Implementations
flashinfer/quantization/fp4_quantization.py, flashinfer/quantization/fp8_quantization.py
New comprehensive FP4 and FP8 quantization implementations with factory functions, module generation, and public APIs (mxfp4_quantize, nvfp4_quantize, mxfp8_quantize, etc.) supporting CUDA and experimental CuTe-DSL backends.
CuTe-DSL Kernel Implementations
flashinfer/quantization/kernels/mxfp4_quantize.py, flashinfer/quantization/kernels/mxfp8_quantize.py
New CuTe-DSL kernel classes (MXFP4QuantizeSwizzledKernel, MXFP8QuantizeLinearKernel, MXFP8QuantizeSwizzledKernel) with TVM-FFI compilation, caching, and kernel dispatch logic for both small and large K paths.
CuTe-DSL Utilities
flashinfer/quantization/quantization_cute_dsl_utils.py
New module providing low-level SIMD intrinsics, max-reduction helpers, FP8/FP4 conversion primitives, and swizzled indexing utilities for CuTe-DSL quantization kernels.
Package Infrastructure
flashinfer/quantization/__init__.py, flashinfer/quantization/kernels/__init__.py, flashinfer/__init__.py
Updated package exports to re-export public quantization APIs from new locations; added conditional exports for experimental CuTe-DSL kernels; updated main module imports.
Quantization Configuration Updates
flashinfer/quantization/packbits.py, flashinfer/activation.py
Updated relative import paths from . to .. to reflect new package structure; no logic changes.
Benchmark Infrastructure
benchmarks/bench_mxfp4_quantize_backend_comparison.py, benchmarks/bench_mxfp8_quantize_backend_comparison.py
New comprehensive benchmarking scripts for MXFP4 and MXFP8 quantization comparing CUDA vs CuTe-DSL backends, including correctness verification, performance measurement, bandwidth analysis, and heatmap generation.
Benchmark Configurations
benchmarks/routines/flashinfer_benchmark_utils.py, benchmarks/routines/quantization.py
Extended backend mappings and test routines to support "cute-dsl" backend; added enable_pdl parameter tracking and simplified backend dispatch logic.
Test Extensions
tests/utils/test_fp4_quantize.py, tests/utils/test_fp8_quantize.py
Added is_cute_dsl_available() utility and is_fp4_supported() helper; extended tests with backend parameterization, CuTe-DSL availability checks, compilation cache tests, and comprehensive backend parity and roundtrip validation.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

v0.6.2

Suggested reviewers

  • kaixih
  • aleozlx
  • cyx-6
  • nvmbreughe
  • kahyunnam
  • jimmyzho

Poem

🐰 From CUDA's realm to CuTe's bright DSL,
We shuffle scales and swizzle them well,
Two backends now dance in quantized delight,
With kernels cached tight and benchmarks alight—
FlashInfer hops forward, faster in flight! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 71.07%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (2 passed)
  • Title check: ✅ Passed. The title 'Add cute-dsl backends to mxfp[8,4]_quantization for future refactor' clearly and concisely describes the main change: adding CuTe-DSL backend support for MXFP8 and MXFP4 quantization, which aligns with the core objective and changeset.
  • Description check: ✅ Passed. The PR description is comprehensive and follows the template structure with all major sections: a detailed 📌 Description explaining the changes, 🔍 Related Issues linking to #2496, and a completed 🚀 Pull Request Checklist with pre-commit checks marked done and tests updated.


@gemini-code-assist
Contributor

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the MXFP8 quantization implementation by introducing a new, highly optimized backend based on CuTe-DSL. This change provides an alternative, potentially more performant, path for quantization operations, enhancing the flexibility and efficiency of the FlashInfer library. The integration ensures that users can seamlessly switch between CUDA and CuTe-DSL implementations, while comprehensive testing validates the correctness and caching mechanisms of the new kernels.

Highlights

  • New CuTe-DSL MXFP8 Quantization Kernels: Introduced new high-performance MXFP8 quantization kernels implemented using CuTe-DSL, supporting both linear and swizzled (128x4) scale factor layouts. These kernels feature Half2/BFloat2 SIMD for max-abs computation, 4-thread cooperation per scale factor block, vectorized 128-bit global loads/stores, and M-agnostic compilation.
  • Backend Selection for MXFP8 Quantization: The main mxfp8_quantize API now accepts a backend argument, allowing users to choose between the existing JIT-compiled CUDA kernel ('cuda') and the new CuTe-DSL kernel ('cute-dsl'). A runtime check ensures CuTe-DSL is available if selected.
  • Shared Quantization Utilities and Intrinsics: A new quantization_utils.py module was added to house common constants (e.g., SF_VEC_SIZE, FLOAT8_E4M3_MAX) and PTX intrinsics for efficient GPU operations, including Half2/BFloat2 SIMD for max reduction, fast UE8M0 conversion, FP8 conversion with scaling, and warp shuffle for 4-thread reduction.
  • Enhanced Benchmarking and Testing Infrastructure: The benchmarking and testing utilities were updated to support and validate the new 'cute-dsl' backend. This includes adding 'cute-dsl' as a choice in benchmark arguments and extending unit tests to cover the new backend, including specific tests for CuTe-DSL's M-agnostic, K-specific, and dtype-specific compilation caching behavior.
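The UE8M0 conversion mentioned above can be illustrated with a stdlib-only reference: each 32-element block (SF_VEC_SIZE = 32) gets a power-of-two scale whose biased exponent is stored as a byte, chosen so the scaled values fit the FP8 E4M3 range (max 448). Rounding the exponent up, mirroring PTX conversion with round-to-plus-infinity, is one common convention; the exact kernel behavior may differ.

```python
import math

SF_VEC_SIZE = 32
FLOAT8_E4M3_MAX = 448.0

def ue8m0_scale(block):
    """Return (scale, biased_exponent_byte) for one 32-element block.

    The exponent is rounded up so amax/scale never exceeds the E4M3 max.
    UE8M0 stores only a biased (bias 127) power-of-two exponent, no mantissa.
    """
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0, 127  # scale 2**0, stored with bias 127
    exp = math.ceil(math.log2(amax / FLOAT8_E4M3_MAX))
    return 2.0 ** exp, exp + 127

# example: one large element forces the block scale up to 2**2
block = [0.5] * (SF_VEC_SIZE - 1) + [900.0]
scale, sf_byte = ue8m0_scale(block)
assert max(abs(v) for v in block) / scale <= FLOAT8_E4M3_MAX
```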


Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@bkryu bkryu marked this pull request as draft January 30, 2026 17:08
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new cute-dsl backend for MXFP8 quantization, refactoring the existing CUDA implementation. The changes are well-structured, adding new CuTe-DSL kernels for both linear and swizzled layouts, and updating the public API, benchmarks, and tests accordingly. The new kernels correctly use M-agnostic compilation for better performance with varying batch sizes. My review includes a couple of suggestions to improve the maintainability of the new kernel code by explaining a magic number and refactoring a duplicated logic block. The accompanying test updates are comprehensive and include valuable checks for the compilation cache behavior.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/mxfp8_quantize.py`:
- Around line 585-635: The code flattens inputs when input.dim()>2 but never
restores batch dimensions or uses orig_shape; after computing
fp8_output/scale_output and before returning, reshape fp8_tensor and
scale_output back to the original batch shape using orig_shape: for fp8_tensor,
view/reshape to (*orig_shape[:-1], padded_k); for scale_output, convert the 1D
buffer into per-row blocks and then reshape to (*orig_shape[:-1],
num_sf_blocks_per_row) for the linear path (use total_sf_blocks -> view(m,
num_sf_blocks_per_row)), and for the swizzled path convert scale_output via
view(padded_m, padded_sf_cols) then take the first m rows ([:m,
:padded_sf_cols]) and reshape to (*orig_shape[:-1], padded_sf_cols); ensure you
reference orig_shape, padded_k, m, num_sf_blocks_per_row, padded_m and
padded_sf_cols when making these changes.

In `@flashinfer/cute_dsl/quantization_utils.py`:
- Around line 22-23: Remove the unused Uint8 import from the top-level imports
in quantization_utils.py: update the import line that currently reads "from
cutlass import Float32, Int32, Uint32, Uint64, Uint8" to exclude Uint8 so only
used symbols (Float32, Int32, Uint32, Uint64) remain; this will resolve the F401
lint error while leaving functions/classes that reference
Float32/Int32/Uint32/Uint64 untouched.
🧹 Nitpick comments (1)
tests/utils/test_fp8_quantize.py (1)

203-210: Silence unused a_sf warnings in denormal/zero/mixed tests.

Ruff flags a_sf as unused in several tests. Consider replacing it with _ (or _a_sf) to avoid lint noise; same pattern applies to the other occurrences in this file.

♻️ Example fix
-    a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, backend=backend)
+    a_fp8, _ = mxfp8_quantize(a, is_sf_swizzled_layout, backend=backend)

@bkryu
Collaborator Author

bkryu commented Jan 30, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !283 has been created, and the CI pipeline #42939528 is currently running. I'll report back once the pipeline job completes.

@bkryu
Collaborator Author

bkryu commented Jan 31, 2026

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #42939528 has been cancelled.

@bkryu bkryu changed the title refactor: refactor mxfp8 quantization cuda code to cute-dsl Add mxfp8 and mxfp4 quantization cute-dsl backends for future refactor Feb 4, 2026
@bkryu bkryu changed the title Add mxfp8 and mxfp4 quantization cute-dsl backends for future refactor Add cute-dsl backends to mxfp[8,4]_quantization for future refactor Feb 5, 2026
@bkryu
Collaborator Author

bkryu commented Feb 5, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !283 has been updated with latest changes, and the CI pipeline #43311884 is currently running. I'll report back once the pipeline job completes.

@bkryu
Collaborator Author

bkryu commented Feb 5, 2026

/bot cancel

@flashinfer-bot
Collaborator

Unknown Command

Command /bot cancel is not recognized.

Use /bot help for available commands.

@bkryu
Collaborator Author

bkryu commented Feb 5, 2026

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #43311884 has been cancelled.

@bkryu
Collaborator Author

bkryu commented Feb 5, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !283 has been updated with latest changes, and the CI pipeline #43313609 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #43313609: 14/20 passed

@bkryu bkryu marked this pull request as ready for review February 5, 2026 17:04
@bkryu bkryu requested a review from kahyunnam as a code owner February 5, 2026 17:04
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Fix all issues with AI agents
In `@flashinfer/quantization/__init__.py`:
- Around line 56-86: The __all__ list in flashinfer/quantization/__init__.py is
unsorted (RUF022); either alphabetically sort the symbols in the __all__ list
(e.g., ensure entries like "block_scale_interleave",
"e2m1_and_ufp8sf_scale_to_float", "get_fp4_quantization_module",
"mxfp4_quantize", "mxfp8_quantize_cute_dsl", etc. are in ASCII order) or if the
current grouped ordering is intentional add an explicit ruff suppression for
RUF022 (e.g., a module-level ruff noqa for RUF022) so the linter is satisfied.
Ensure the change targets the __all__ variable and preserves the conditional
addition when _cute_dsl_available is true.

In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 413-441: The variable name `l` is ambiguous (ruff E741); rename it
to a descriptive identifier (e.g., `batch_size` or `num_groups`) throughout this
function and the related functions to fix the lint error and improve
readability: update the unpacking line `l, m, k_by_2 = input.shape` to
`batch_size, m, k_by_2 = input.shape` (or `num_groups`), update all subsequent
uses of `l` (construction of `output`, `output_scales`, reshapes, the call to
module.silu_and_mul_scaled_nvfp4_experts_quantize, and the final permute/view
lines), and apply the same consistent rename in the other affected blocks (the
functions around lines 449-468, 498-530, 538-557) so every reference (e.g.,
output.view(l * m, ...), output_scales.view(..., l, ...), return output,
output_scales) uses the new identifier.
- Around line 563-600: The call to module.e2m1_and_ufp8sf_scale_to_float_sm100
unconditionally calls global_scale_tensor.cpu() but global_scale_tensor is
optional; guard it and pass a CPU tensor when it's None. Update the invocation
in e2m1_and_ufp8sf_scale_to_float_sm100 so that you pass
(global_scale_tensor.cpu() if global_scale_tensor is not None else a default CPU
float32 tensor, e.g. torch.tensor([1.0], dtype=torch.float32, device='cpu')),
ensuring the default matches the expected shape/dtype the custom op requires.

In `@flashinfer/quantization/fp8_quantization.py`:
- Around line 91-101: The fake op function _fake_mxfp8_quantize_sm100 is missing
the enable_pdl parameter present on the real implementation, causing a signature
mismatch; update the _fake_mxfp8_quantize_sm100 definition to add an enable_pdl:
bool = False parameter (keeping existing defaults for is_sf_swizzled_layout and
alignment) and ensure any callers or the returned tensors behavior remain
unchanged so the fake op signature matches the real mxfp8_quantize
implementation.

In `@tests/utils/test_fp4_quantize.py`:
- Around line 20-29: The helper is_fp4_supported in
tests/utils/test_fp4_quantize.py directly calls torch.cuda.get_device_capability
and should instead use flashinfer.utils.get_compute_capability; update the
function to import get_compute_capability and replace the torch call with
get_compute_capability(device) (keep the existing CUDA version parsing and the
same support logic), ensuring the rest of is_fp4_supported still uses
cuda_version from torch.version.cuda and the same major/minor comparisons.
🧹 Nitpick comments (6)
flashinfer/quantization/kernels/__init__.py (1)

39-45: Consider sorting __all__ for consistency.

Static analysis flagged that __all__ is not sorted. While minor, sorting it alphabetically improves readability and maintainability.

🔧 Suggested fix
 __all__ = [
     "MXFP4QuantizeSwizzledKernel",
+    "MXFP8QuantizeLinearKernel",
+    "MXFP8QuantizeSwizzledKernel",
     "mxfp4_quantize_cute_dsl",
-    "MXFP8QuantizeLinearKernel",
-    "MXFP8QuantizeSwizzledKernel",
     "mxfp8_quantize_cute_dsl",
 ]
flashinfer/quantization/kernels/mxfp4_quantize.py (2)

98-116: Redundant condition in thread count optimization.

Line 103's condition if threads_per_row <= _MAX_THREADS is always true because line 98-100 already handles the case when threads_per_row >= _MAX_THREADS and returns early. The if block can be simplified.

🔧 Suggested simplification
     if threads_per_row >= _MAX_THREADS:
         # Large K: use max threads, will need column loop
         return _MAX_THREADS

-    # threads_per_block should be a multiple of threads_per_row
-    if threads_per_row <= _MAX_THREADS:
-        # Find largest multiple of threads_per_row <= _MAX_THREADS
-        threads = (_MAX_THREADS // threads_per_row) * threads_per_row
-        if threads >= _MIN_THREADS:
-            return threads
-        # If largest multiple is below _MIN_THREADS, use the smallest valid one
-        threads = threads_per_row
-        while threads < _MIN_THREADS:
-            threads += threads_per_row
-        if threads <= _MAX_THREADS:
-            return threads
+    # threads_per_block should be a multiple of threads_per_row
+    # Find largest multiple of threads_per_row <= _MAX_THREADS
+    threads = (_MAX_THREADS // threads_per_row) * threads_per_row
+    if threads >= _MIN_THREADS:
+        return threads
+    # If largest multiple is below _MIN_THREADS, use the smallest valid one
+    threads = threads_per_row
+    while threads < _MIN_THREADS:
+        threads += threads_per_row
+    if threads <= _MAX_THREADS:
+        return threads

     # Fallback to default
     return _DEFAULT_THREADS

155-168: Use explicit None union syntax for type hints.

PEP 484 prohibits implicit Optional. The target_grid parameter should use explicit union syntax for consistency with the rest of the codebase (e.g., line 467 uses bool | None).

🔧 Suggested fix
     def __init__(
         self,
         dtype: cutlass.Numeric,
         K: int,
         enable_pdl: bool = False,
-        target_grid: int = None,
+        target_grid: int | None = None,
     ):
flashinfer/quantization/kernels/mxfp8_quantize.py (2)

75-88: Consider consolidating duplicated _get_target_grid function.

This function is identical to _get_target_grid in mxfp4_quantize.py (lines 58-71). Consider moving it to quantization_cute_dsl_utils.py to avoid code duplication.

#!/bin/bash
# Verify the duplication
echo "=== mxfp4_quantize.py _get_target_grid ==="
rg -A 15 "def _get_target_grid" flashinfer/quantization/kernels/mxfp4_quantize.py

echo ""
echo "=== mxfp8_quantize.py _get_target_grid ==="
rg -A 15 "def _get_target_grid" flashinfer/quantization/kernels/mxfp8_quantize.py

162-176: Use explicit None union syntax for type hints.

For consistency with other parts of the codebase (line 678 uses bool | None), update the target_grid parameter type annotation.

🔧 Suggested fix
     def __init__(
         self,
         dtype: cutlass.Numeric,
         K: int,
         enable_pdl: bool = False,
-        target_grid: int = None,
+        target_grid: int | None = None,
     ):

Apply the same change to MXFP8QuantizeSwizzledKernel.__init__ (line 314), _get_compiled_kernel_linear (line 583), and _get_compiled_kernel_swizzled (line 630).

flashinfer/quantization/quantization_cute_dsl_utils.py (1)

964-1002: Consider sorting __all__ for maintainability.

While the current organization by category (constants, intrinsics, helpers) is logical, sorting alphabetically or at minimum keeping consistent ordering would help with maintainability as the module grows.

Comment on lines +56 to +86
__all__ = [
    # Packbits
    "packbits",
    "segment_packbits",
    # JIT module generator
    "gen_quantization_module",
    # FP8
    "mxfp8_quantize",
    "mxfp8_dequantize_host",
    # FP4
    "SfLayout",
    "block_scale_interleave",
    "nvfp4_block_scale_interleave",
    "e2m1_and_ufp8sf_scale_to_float",
    "fp4_quantize",
    "mxfp4_dequantize_host",
    "mxfp4_dequantize",
    "mxfp4_quantize",
    "nvfp4_quantize",
    "nvfp4_batched_quantize",
    "shuffle_matrix_a",
    "shuffle_matrix_sf_a",
    "scaled_fp4_grouped_quantize",
    "get_fp4_quantization_module",
]

if _cute_dsl_available:
    __all__ += [
        "mxfp8_quantize_cute_dsl",
        "mxfp4_quantize_cute_dsl",
    ]

⚠️ Potential issue | 🟡 Minor

Ruff RUF022: __all__ is not sorted.
Consider sorting to satisfy lint, or explicitly suppress if the grouped ordering is intentional.

🔧 Optional suppression to keep grouped ordering
-__all__ = [
+__all__ = [  # noqa: RUF022 - keep grouped exports
     # Packbits
     "packbits",
     "segment_packbits",
@@
-if _cute_dsl_available:
-    __all__ += [
+if _cute_dsl_available:
+    __all__ += [  # noqa: RUF022 - keep grouped exports
         "mxfp8_quantize_cute_dsl",
         "mxfp4_quantize_cute_dsl",
     ]
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 56-80: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)


[warning] 83-86: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)


@vincentzed
Contributor

@bkryu Question: were you able to get 256-bit loads working in CuTe-DSL? I found that it did not seem to be possible. For comparison, a pure CUTLASS example:

https://github.com/HydraQYH/hp_rms_norm/blob/master/hp_rms_norm/csrc/cuda/hp_rms_norm.cuh

Collaborator

@kahyunnam kahyunnam left a comment


LGTM; I just left a few questions about the compute-capability heuristic.

"10.0": ["cuda"],
"10.3": ["cuda"],
"12.0": ["cuda"],
"10.0": ["cuda", "cute-dsl"],
Collaborator


Just curious, why is cute-dsl only enabled above 10.0?

Is it just a future to-do for more testing/benchmarking for <10.0 before enabling?

Collaborator Author


Hardware-accelerated MXFP8 instructions are a Blackwell-generation feature. Hopper supports (non-MX) FP8, but it cannot run these kernels.

As such, on Hopper and earlier we do not expect users to use MXFP8 (software-emulated MXFP8 is possible, but performance would likely be unsatisfactory).

Collaborator

@kahyunnam kahyunnam Feb 9, 2026


Oh ... this makes a lot of sense 😅

It may change or be removed in future versions without notice.
Use at your own risk for production workloads.
"""
if backend == "cute-dsl":
Collaborator


Should we also add a compute-capability check here, verifying that the current compilation context (or current device) is compute capability >= 10.0? From the benchmarks it looks like we are only testing cute-dsl on 10.0 and above.
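A minimal version of such a guard, parsing a compute-capability string (in the real code the string would come from querying the current device, e.g. via flashinfer's get_compute_capability; the function name below is a stand-in):

```python
def supports_mxfp8_cute_dsl(compute_capability: str) -> bool:
    # Hardware-accelerated MXFP8 instructions arrive with Blackwell (SM 10.0+),
    # so gate the cute-dsl backend on compute capability >= 10.0.
    major, _, minor = compute_capability.partition(".")
    return (int(major), int(minor or 0)) >= (10, 0)

assert supports_mxfp8_cute_dsl("10.0")
assert not supports_mxfp8_cute_dsl("9.0")   # Hopper
```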
