Add cute-dsl backends to mxfp[8,4]_quantization for future refactor #2443
bkryu wants to merge 12 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

This PR reorganizes FlashInfer's quantization infrastructure by moving FP4 and FP8 quantization modules into a dedicated `flashinfer/quantization/` package.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: 2 passed, 1 failed (1 warning)
Summary of Changes (Gemini Code Assist)

This pull request refactors the MXFP8 quantization implementation by introducing a new, optimized backend based on CuTe-DSL. The change provides an alternative, potentially more performant path for quantization operations, enhancing the flexibility and efficiency of the FlashInfer library. The integration lets users switch between the CUDA and CuTe-DSL implementations, and comprehensive testing validates the correctness and caching mechanisms of the new kernels.

Highlights
Code Review
This pull request introduces a new cute-dsl backend for MXFP8 quantization, refactoring the existing CUDA implementation. The changes are well-structured, adding new CuTe-DSL kernels for both linear and swizzled layouts, and updating the public API, benchmarks, and tests accordingly. The new kernels correctly use M-agnostic compilation for better performance with varying batch sizes. My review includes a couple of suggestions to improve the maintainability of the new kernel code by explaining a magic number and refactoring a duplicated logic block. The accompanying test updates are comprehensive and include valuable checks for the compilation cache behavior.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/mxfp8_quantize.py`:
- Around lines 585-635: The code flattens inputs when `input.dim() > 2` but never restores the batch dimensions (`orig_shape` is computed but unused). After computing `fp8_output`/`scale_output` and before returning, reshape both back to the original batch shape:
  - `fp8_tensor`: view/reshape to `(*orig_shape[:-1], padded_k)`.
  - `scale_output`, linear path: convert the 1D buffer into per-row blocks (`total_sf_blocks` -> `view(m, num_sf_blocks_per_row)`), then reshape to `(*orig_shape[:-1], num_sf_blocks_per_row)`.
  - `scale_output`, swizzled path: `view(padded_m, padded_sf_cols)`, take the first `m` rows (`[:m, :padded_sf_cols]`), then reshape to `(*orig_shape[:-1], padded_sf_cols)`.

  Ensure you reference `orig_shape`, `padded_k`, `m`, `num_sf_blocks_per_row`, `padded_m`, and `padded_sf_cols` when making these changes.
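To illustrate the suggested fix, here is a minimal sketch of the shape restoration for the linear path, using NumPy reshapes to stand in for `torch.view` (the function name and argument list are hypothetical, not the actual flashinfer code):

```python
import numpy as np

def restore_batch_shape(fp8_flat, scale_flat, orig_shape, padded_k,
                        num_sf_blocks_per_row):
    """Reshape flattened quantization outputs back to the original
    batch shape (linear scale-factor layout)."""
    # Number of rows after the input was flattened to 2D
    m = int(np.prod(orig_shape[:-1]))
    # FP8 payload: (m, padded_k) -> (*batch_dims, padded_k)
    fp8 = fp8_flat.reshape(*orig_shape[:-1], padded_k)
    # Scale buffer: 1D -> per-row blocks -> (*batch_dims, blocks_per_row)
    scales = scale_flat.reshape(m, num_sf_blocks_per_row)
    scales = scales.reshape(*orig_shape[:-1], num_sf_blocks_per_row)
    return fp8, scales
```

The swizzled path would analogously `view(padded_m, padded_sf_cols)`, slice the first `m` rows, and then reshape to the batch dimensions.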
In `@flashinfer/cute_dsl/quantization_utils.py`:
- Around line 22-23: Remove the unused Uint8 import from the top-level imports
in quantization_utils.py: update the import line that currently reads "from
cutlass import Float32, Int32, Uint32, Uint64, Uint8" to exclude Uint8 so only
used symbols (Float32, Int32, Uint32, Uint64) remain; this will resolve the F401
lint error while leaving functions/classes that reference
Float32/Int32/Uint32/Uint64 untouched.
🧹 Nitpick comments (1)
tests/utils/test_fp8_quantize.py (1)
203-210: Silence unused `a_sf` warnings in denormal/zero/mixed tests. Ruff flags `a_sf` as unused in several tests. Consider replacing it with `_` (or `_a_sf`) to avoid lint noise; the same pattern applies to the other occurrences in this file.

♻️ Example fix

```diff
- a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, backend=backend)
+ a_fp8, _ = mxfp8_quantize(a, is_sf_swizzled_layout, backend=backend)
```
/bot run

/bot stop

The GitLab CI pipeline #42939528 has been cancelled.

/bot run

/bot cancel

Unknown command.

/bot stop

The GitLab CI pipeline #43311884 has been cancelled.

/bot run

[FAILED] Pipeline #43313609: 14/20 passed
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In `@flashinfer/quantization/__init__.py`:
- Around line 56-86: The __all__ list in flashinfer/quantization/__init__.py is
unsorted (RUF022); either alphabetically sort the symbols in the __all__ list
(e.g., ensure entries like "block_scale_interleave",
"e2m1_and_ufp8sf_scale_to_float", "get_fp4_quantization_module",
"mxfp4_quantize", "mxfp8_quantize_cute_dsl", etc. are in ASCII order) or if the
current grouped ordering is intentional add an explicit ruff suppression for
RUF022 (e.g., a module-level ruff noqa for RUF022) so the linter is satisfied.
Ensure the change targets the __all__ variable and preserves the conditional
addition when _cute_dsl_available is true.
In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 413-441: The variable name `l` is ambiguous (ruff E741); rename it
to a descriptive identifier (e.g., `batch_size` or `num_groups`) throughout this
function and the related functions to fix the lint error and improve
readability: update the unpacking line `l, m, k_by_2 = input.shape` to
`batch_size, m, k_by_2 = input.shape` (or `num_groups`), update all subsequent
uses of `l` (construction of `output`, `output_scales`, reshapes, the call to
module.silu_and_mul_scaled_nvfp4_experts_quantize, and the final permute/view
lines), and apply the same consistent rename in the other affected blocks (the
functions around lines 449-468, 498-530, 538-557) so every reference (e.g.,
output.view(l * m, ...), output_scales.view(..., l, ...), return output,
output_scales) uses the new identifier.
- Around line 563-600: The call to module.e2m1_and_ufp8sf_scale_to_float_sm100
unconditionally calls global_scale_tensor.cpu() but global_scale_tensor is
optional; guard it and pass a CPU tensor when it's None. Update the invocation
in e2m1_and_ufp8sf_scale_to_float_sm100 so that you pass
(global_scale_tensor.cpu() if global_scale_tensor is not None else a default CPU
float32 tensor, e.g. torch.tensor([1.0], dtype=torch.float32, device='cpu')),
ensuring the default matches the expected shape/dtype the custom op requires.
In `@flashinfer/quantization/fp8_quantization.py`:
- Around line 91-101: The fake op function _fake_mxfp8_quantize_sm100 is missing
the enable_pdl parameter present on the real implementation, causing a signature
mismatch; update the _fake_mxfp8_quantize_sm100 definition to add an enable_pdl:
bool = False parameter (keeping existing defaults for is_sf_swizzled_layout and
alignment) and ensure any callers or the returned tensors behavior remain
unchanged so the fake op signature matches the real mxfp8_quantize
implementation.
In `@tests/utils/test_fp4_quantize.py`:
- Around line 20-29: The helper is_fp4_supported in
tests/utils/test_fp4_quantize.py directly calls torch.cuda.get_device_capability
and should instead use flashinfer.utils.get_compute_capability; update the
function to import get_compute_capability and replace the torch call with
get_compute_capability(device) (keep the existing CUDA version parsing and the
same support logic), ensuring the rest of is_fp4_supported still uses
cuda_version from torch.version.cuda and the same major/minor comparisons.
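A sketch of the rewritten helper; the capability getter is injected as a parameter so the example is self-contained, and the exact support thresholds (sm_100+, CUDA 12.8) are illustrative assumptions rather than the test file's actual logic:

```python
def is_fp4_supported(device, get_compute_capability, cuda_version):
    """Route the capability query through the injected helper instead
    of calling torch.cuda.get_device_capability directly."""
    major, minor = get_compute_capability(device)
    if cuda_version is None:
        return False
    cuda_major, cuda_minor = map(int, cuda_version.split(".")[:2])
    # Assumed support logic: Blackwell-class hardware + recent CUDA
    return major >= 10 and (cuda_major, cuda_minor) >= (12, 8)
```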
🧹 Nitpick comments (6)
flashinfer/quantization/kernels/__init__.py (1)
39-45: Consider sorting `__all__` for consistency. Static analysis flagged that `__all__` is not sorted. While minor, sorting it alphabetically improves readability and maintainability.

🔧 Suggested fix

```diff
 __all__ = [
     "MXFP4QuantizeSwizzledKernel",
+    "MXFP8QuantizeLinearKernel",
+    "MXFP8QuantizeSwizzledKernel",
     "mxfp4_quantize_cute_dsl",
-    "MXFP8QuantizeLinearKernel",
-    "MXFP8QuantizeSwizzledKernel",
     "mxfp8_quantize_cute_dsl",
 ]
```

flashinfer/quantization/kernels/mxfp4_quantize.py (2)
98-116: Redundant condition in the thread count optimization. Line 103's condition `if threads_per_row <= _MAX_THREADS` is always true because lines 98-100 already handle the case where `threads_per_row >= _MAX_THREADS` and return early. The `if` block can be simplified.

🔧 Suggested simplification

```diff
     if threads_per_row >= _MAX_THREADS:
         # Large K: use max threads, will need column loop
         return _MAX_THREADS
-    # threads_per_block should be a multiple of threads_per_row
-    if threads_per_row <= _MAX_THREADS:
-        # Find largest multiple of threads_per_row <= _MAX_THREADS
-        threads = (_MAX_THREADS // threads_per_row) * threads_per_row
-        if threads >= _MIN_THREADS:
-            return threads
-        # If largest multiple is below _MIN_THREADS, use the smallest valid one
-        threads = threads_per_row
-        while threads < _MIN_THREADS:
-            threads += threads_per_row
-        if threads <= _MAX_THREADS:
-            return threads
+    # threads_per_block should be a multiple of threads_per_row
+    # Find largest multiple of threads_per_row <= _MAX_THREADS
+    threads = (_MAX_THREADS // threads_per_row) * threads_per_row
+    if threads >= _MIN_THREADS:
+        return threads
+    # If largest multiple is below _MIN_THREADS, use the smallest valid one
+    threads = threads_per_row
+    while threads < _MIN_THREADS:
+        threads += threads_per_row
+    if threads <= _MAX_THREADS:
+        return threads
     # Fallback to default
     return _DEFAULT_THREADS
```
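The simplified heuristic is easy to exercise in isolation; the constant values below are assumptions for illustration (the real ones live in `mxfp4_quantize.py`):

```python
_MAX_THREADS = 1024     # assumed value
_MIN_THREADS = 128      # assumed value
_DEFAULT_THREADS = 256  # assumed value

def pick_threads_per_block(threads_per_row: int) -> int:
    """Choose a block size that is a multiple of threads_per_row."""
    if threads_per_row >= _MAX_THREADS:
        # Large K: use max threads; the kernel loops over columns
        return _MAX_THREADS
    # Largest multiple of threads_per_row that fits in a block
    threads = (_MAX_THREADS // threads_per_row) * threads_per_row
    if threads >= _MIN_THREADS:
        return threads
    # Otherwise take the smallest multiple >= _MIN_THREADS
    threads = threads_per_row
    while threads < _MIN_THREADS:
        threads += threads_per_row
    if threads <= _MAX_THREADS:
        return threads
    return _DEFAULT_THREADS
```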
155-168: Use explicit `None` union syntax for type hints. PEP 484 prohibits implicit `Optional`. The `target_grid` parameter should use explicit union syntax for consistency with the rest of the codebase (e.g., line 467 uses `bool | None`).

🔧 Suggested fix

```diff
 def __init__(
     self,
     dtype: cutlass.Numeric,
     K: int,
     enable_pdl: bool = False,
-    target_grid: int = None,
+    target_grid: int | None = None,
 ):
```

flashinfer/quantization/kernels/mxfp8_quantize.py (2)
75-88: Consider consolidating the duplicated `_get_target_grid` function. This function is identical to `_get_target_grid` in `mxfp4_quantize.py` (lines 58-71). Consider moving it to `quantization_cute_dsl_utils.py` to avoid code duplication.

```bash
#!/bin/bash
# Verify the duplication
echo "=== mxfp4_quantize.py _get_target_grid ==="
rg -A 15 "def _get_target_grid" flashinfer/quantization/kernels/mxfp4_quantize.py
echo ""
echo "=== mxfp8_quantize.py _get_target_grid ==="
rg -A 15 "def _get_target_grid" flashinfer/quantization/kernels/mxfp8_quantize.py
```
162-176: Use explicit `None` union syntax for type hints. For consistency with other parts of the codebase (line 678 uses `bool | None`), update the `target_grid` parameter type annotation.

🔧 Suggested fix

```diff
 def __init__(
     self,
     dtype: cutlass.Numeric,
     K: int,
     enable_pdl: bool = False,
-    target_grid: int = None,
+    target_grid: int | None = None,
 ):
```

Apply the same change to `MXFP8QuantizeSwizzledKernel.__init__` (line 314), `_get_compiled_kernel_linear` (line 583), and `_get_compiled_kernel_swizzled` (line 630).

flashinfer/quantization/quantization_cute_dsl_utils.py (1)
964-1002: Consider sorting `__all__` for maintainability. While the current organization by category (constants, intrinsics, helpers) is logical, sorting alphabetically, or at minimum keeping a consistent ordering, would help with maintainability as the module grows.
```python
__all__ = [
    # Packbits
    "packbits",
    "segment_packbits",
    # JIT module generator
    "gen_quantization_module",
    # FP8
    "mxfp8_quantize",
    "mxfp8_dequantize_host",
    # FP4
    "SfLayout",
    "block_scale_interleave",
    "nvfp4_block_scale_interleave",
    "e2m1_and_ufp8sf_scale_to_float",
    "fp4_quantize",
    "mxfp4_dequantize_host",
    "mxfp4_dequantize",
    "mxfp4_quantize",
    "nvfp4_quantize",
    "nvfp4_batched_quantize",
    "shuffle_matrix_a",
    "shuffle_matrix_sf_a",
    "scaled_fp4_grouped_quantize",
    "get_fp4_quantization_module",
]


if _cute_dsl_available:
    __all__ += [
        "mxfp8_quantize_cute_dsl",
        "mxfp4_quantize_cute_dsl",
    ]
```
Ruff RUF022: __all__ is not sorted.
Consider sorting to satisfy lint, or explicitly suppress if the grouped ordering is intentional.
🔧 Optional suppression to keep grouped ordering

```diff
-__all__ = [
+__all__ = [  # noqa: RUF022 - keep grouped exports
     # Packbits
     "packbits",
     "segment_packbits",
@@
-if _cute_dsl_available:
-    __all__ += [
+if _cute_dsl_available:
+    __all__ += [  # noqa: RUF022 - keep grouped exports
         "mxfp8_quantize_cute_dsl",
         "mxfp4_quantize_cute_dsl",
     ]
```
🧰 Tools: 🪛 Ruff (0.14.14)

[warning] 56-80: `__all__` is not sorted; apply an isort-style sorting to `__all__` (RUF022)
[warning] 83-86: `__all__` is not sorted; apply an isort-style sorting to `__all__` (RUF022)
@bkryu Q:
kahyunnam left a comment:
LGTM, I just left a few questions about compute capability heuristic
```diff
-    "10.0": ["cuda"],
+    "10.0": ["cuda", "cute-dsl"],
     "10.3": ["cuda"],
     "12.0": ["cuda"],
```
Just curious, why is cute-dsl only enabled above 10.0?
Is it just a future to-do for more testing/benchmarking for <10.0 before enabling?
Hardware-accelerated MXFP8-related instructions are a feature of the Blackwell generation. Hopper is fine for (non-MX) FP8 but cannot run these kernels.
As such, on Hopper or earlier we do not expect users to use MXFP8 (software-emulated MXFP8 is possible, but performance would likely be unsatisfactory).
Oh ... this makes a lot of sense 😅
```python
    It may change or be removed in future versions without notice.
    Use at your own risk for production workloads.
    """
    if backend == "cute-dsl":
```
Should we also add a compute capability check here for the current compilation context (or current device) being compute >= 10.0, since it seems from benchmarking that we're only testing cute-dsl on 10.0 and above?
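If such a guard were added, it could look like the following sketch (the function and error message are hypothetical; the >= 10.0 threshold matches the backend table discussed above):

```python
def check_cute_dsl_backend(device_capability, backend):
    """Reject the cute-dsl backend below compute capability 10.0.

    device_capability is a (major, minor) tuple, e.g. (10, 0).
    """
    if backend == "cute-dsl" and device_capability < (10, 0):
        raise ValueError(
            "cute-dsl backend requires compute capability >= 10.0, "
            f"got {device_capability[0]}.{device_capability[1]}"
        )
```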
📌 Description
This PR adds CuTe-DSL backend support for MXFP8 and MXFP4 quantization kernels as alternatives to JIT-compiled CUDA backends
Key changes:
File Structure Reorganization
Quantization files are now organized in `flashinfer/quantization/`.

Performance

CuTe DSL kernels are strong compared to their CUDA counterparts (see `bench_mxfp4_quantize_backend_comparison.py` and `bench_mxfp8_quantize_backend_comparison.py`). Expand below for performance heatmaps:
CuTe DSL Backend outperforms the CUDA backend on every single case benchmarked in `bench_mxfp8_quantize_backend_comparison.py`. Click to see performance comparison data:

- BF16 input; Swizzled cases. > 1.0 means CuTe DSL is faster. [heatmap image]
- BF16 input; Linear cases. > 1.0 means CuTe DSL is faster. [heatmap image]
- BF16 input; Swizzled cases. Annotated values are achieved TB/s. [heatmap image]
- BF16 input; Linear cases. Annotated values are achieved TB/s. [heatmap image]

CuTe DSL Backend outperforms the CUDA backend on every single case benchmarked in `bench_mxfp4_quantize_backend_comparison.py`. Click to see performance comparison data:

- BF16 input; Swizzled cases. > 1.0 means CuTe DSL is faster. [heatmap image]
- BF16 input; Swizzled cases. Annotated values are achieved TB/s. [heatmap image]
🔍 Related Issues
#2496
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (e.g., `unittest`, etc.).

Reviewer Notes