
feat: Add MXFP8 GEMM mm_mxfp8 (cutlass)#2464

Merged
yongwww merged 11 commits into flashinfer-ai:main from danisereb:support_mm_mxfp8 on Feb 12, 2026

Conversation

@danisereb (Contributor) commented Feb 2, 2026

📌 Description

Add new API mm_mxfp8 to support MXFP8 GEMM (using cutlass).

The logic/flow of mm_mxfp8 was based on mm_fp4.

The MXFP8 CUTLASS code is based on the existing Flashinfer NVFP4 code.

Will be used for MXFP8 support in vLLM (with ModelOpt MXFP8 checkpoints).

Performance results (Blackwell, B200)

Comparison of MXFP8 and FP8

MXFP8

python3 benchmarks/flashinfer_benchmark.py -vv \
--routine mm_mxfp8 \
--backends cutlass \
--autotune \
--input_dtype bfloat16 \
--out_dtype bfloat16 \
--m 8192 --n 8192 --k 4096

Results:

--m 1024 --n 2048 --k 1024
[PERF] cutlass_autotun:: median time 0.005 ms; std 0.000 ms; 
achieved tflops 789.981 TFLOPs/sec; achieved tb_per_sec 1.350 TB/sec

--m 1024 --n 4096 --k 4096
[PERF] cutlass_autotun:: median time 0.017 ms; std 0.000 ms; 
achieved tflops 2014.903 TFLOPs/sec; achieved tb_per_sec 1.722 TB/sec

--m 8192 --n 8192 --k 4096
[PERF] cutlass_autotun:: median time 0.172 ms; std 0.001 ms; 
achieved tflops 3200.870 TFLOPs/sec; achieved tb_per_sec 1.172 TB/sec

--m 16384 --n 16384 --k 16384
[PERF] cutlass_autotun:: median time 2.838 ms; std 0.105 ms; 
achieved tflops 3099.690 TFLOPs/sec; achieved tb_per_sec 0.378 TB/sec
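
As a sanity check, the reported TFLOPs figures line up with the usual 2·m·n·k operation count for GEMM divided by the median time (a back-of-the-envelope check, not necessarily the benchmark's exact accounting):

```python
def achieved_tflops(m: int, n: int, k: int, time_ms: float) -> float:
    """A GEMM performs 2*m*n*k floating-point ops (one multiply + one add
    per inner-product element)."""
    flops = 2 * m * n * k
    return flops / (time_ms * 1e-3) / 1e12

# Largest MXFP8 case above: median time 2.838 ms
print(f"{achieved_tflops(16384, 16384, 16384, 2.838):.1f} TFLOPs/sec")
# ≈ 3099.4, matching the reported 3099.690 up to rounding of the median time
```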

FP8

python3 benchmarks/flashinfer_benchmark.py -vv \
--routine bmm_fp8 \
--backends cutlass \
--autotune \
--input_dtype fp8_e4m3 \
--out_dtype bfloat16 \
--batch_size 1 \
--m 8192 --n 8192 --k 4096

Results:

--m 1024 --n 2048 --k 1024
[PERF] cutlass_autotun:: median time 0.005 ms; std 0.000 ms; 
achieved tflops 858.444 TFLOPs/sec; achieved tb_per_sec 1.467 TB/sec

--m 1024 --n 4096 --k 4096
[PERF] cutlass_autotun:: median time 0.016 ms; std 0.000 ms; 
achieved tflops 2087.976 TFLOPs/sec; achieved tb_per_sec 1.784 TB/sec

--m 8192 --n 8192 --k 4096
[PERF] cutlass_autotun:: median time 0.176 ms; std 0.004 ms; 
achieved tflops 3116.445 TFLOPs/sec; achieved tb_per_sec 1.141 TB/sec

--m 16384 --n 16384 --k 16384
[PERF] cutlass_autotun:: median time 3.352 ms; std 0.109 ms; 
achieved tflops 2624.382 TFLOPs/sec; achieved tb_per_sec 0.320 TB/sec

More about MXFP8 performance (BF16 / FP8 / MXFP8)

https://cursor.com/ja/blog/kernels

When scale factors must also reside in TMEM, however, we can only execute a single 256x32x256 tcgen05.mma instruction at a time using just a 128x256 region of TMEM. As a result, performance degradation is unavoidable. 

For example, the throughput of a 16,384x16,384x16,384 FP8 matrix multiplication drops from 3,200 TFLOP/s to 3,040 TFLOP/s under this constraint.

These throughput numbers apply only to pure FP8 matrix multiplication.
With MXFP8 block scaling, throughput inevitably decreases further due to TMEM pipelining overhead.

In practice, we achieve around 2,750 TFLOP/s with L2 cache clearance for block-scaled MXFP8 matrix multiplication kernels.
Even so, this remains ~1.83x faster than standard BF16 matrix multiplication, which typically reaches 1,500~1,550 TFLOP/s on optimal shapes.
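
The quoted ~1.83x figure corresponds to the lower end of the BF16 range (a quick arithmetic check):

```python
mxfp8_tflops = 2750.0          # block-scaled MXFP8, per the blog post
bf16_range = (1500.0, 1550.0)  # typical BF16 throughput on optimal shapes
print([f"{mxfp8_tflops / t:.2f}x" for t in bf16_range])
# ['1.83x', '1.77x'] — the headline ~1.83x uses the 1,500 TFLOP/s end
```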

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Adds mm_mxfp8: MXFP8 quantized matrix multiply with native Cutlass CUDA path, BF16/FP16 outputs, autotuning, workspace management, backend selection, and performance metrics.
  • Tests

    • Adds comprehensive end-to-end tests covering correctness, layouts/swizzle, scale handling, dtypes, error cases, large workloads, and realistic model simulations.
  • Chores

    • Updates benchmarks, CLI listings, public exports, JIT generation, and backend capability mappings to include mm_mxfp8.

@coderabbitai bot commented Feb 2, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


Adds end-to-end MXFP8 GEMM support: new public Python API mm_mxfp8, Cutlass SM100 JIT generator and runner, CUDA/Jinja kernel sources and headers/templates, benchmark registration and harness, autotuning integration, and comprehensive unit tests.

Changes

  • Benchmarks — benchmarks/routines/flashinfer_benchmark_utils.py, benchmarks/routines/gemm.py: Register mm_mxfp8 in benchmark APIs and CC→backend mapping; add testMmMxfp8 harness, autotune/help updates, timing, and reference-validation logic.
  • Python exports — flashinfer/__init__.py, flashinfer/gemm/__init__.py: Export mm_mxfp8 at package and subpackage levels and include in __all__.
  • GEMM core & API (Python) — flashinfer/gemm/gemm_base.py: Add CUTLASS MXFP8 module creator/runner, backend selection/heuristic, problem/scale validation, autotuning config (_MM_MXFP8_TUNING_CONFIG), workspace handling, _mxfp8_swizzled_scale_len, and the public mm_mxfp8 API with backend/autotune integration.
  • JIT generator — flashinfer/jit/gemm/__init__.py, flashinfer/jit/gemm/core.py: Add gen_gemm_sm100_module_cutlass_mxfp8 JIT spec/source generation for SM100 MXFP8 modules with NVCC flags and BF16 defines.
  • CUDA + FFI — csrc/mxfp8_gemm_cutlass.cu, csrc/mxfp8_gemm_cutlass.jinja: New Cutlass MXFP8 CUDA runner with TVM/FFI exports (mxfp8_gemm, mxfp8_gemm_tactic_num), tactic/config selection, input/scale validation, workspace allocation, and kernel instantiations.
  • C++ headers & templates — include/flashinfer/gemm/mxfp8_gemm_cutlass.h, include/flashinfer/gemm/mxfp8_gemm_cutlass_template.h, include/flashinfer/gemm/mxfp8_gemm_template_sm100.h: Introduce the MXFP8 runner interface, templated Cutlass dispatch, SM100-specific kernel launchers, config enumeration, workspace sizing, and instantiation macros/templates.
  • Tests — tests/gemm/test_mm_mxfp8.py: Add extensive unit tests and helpers covering correctness, layout/scale semantics, error cases, large dims, cosine-similarity validation, autotuning, and LLM-like simulations.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant User
    participant API as mm_mxfp8_API
    participant Validator as ProblemValidator
    participant Tuner as AutoTuner
    participant Backend as CutlassRunner
    participant CUDA as CUDA_mxfp8_gemm

    User->>API: call mm_mxfp8(A,B,a_descale,b_descale[,out])
    API->>Validator: validate shapes, dtypes, scales, layout
    Validator-->>API: ok / error
    API->>Tuner: select backend & tactic (heuristic / autotune)
    Tuner->>Backend: query tactics/configs & workspace requirements
    Backend-->>Tuner: tactics + workspace needs
    Tuner->>Backend: choose tactic, request workspace
    Backend->>CUDA: launch mxfp8_gemm(tactic, workspace, inputs)
    CUDA-->>Backend: execution status + output
    Backend-->>API: return result tensor
    API-->>User: deliver output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

v0.6.2

Suggested reviewers

  • aleozlx
  • djmmoss
  • yongwww
  • cyx-6
  • bkryu
  • yzh119

"I hopped through code with nimble feet,
Cutlass tiles aligned in rows so neat,
MXFP8 kernels learned to hum,
Autotuned tactics beat the drum,
I munch a carrot — tests all green!" 🐇✨

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 30.91%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Title check ✅ Passed — The title clearly and specifically describes the main change: adding a new MXFP8 GEMM API (mm_mxfp8) using CUTLASS, which aligns with the substantial changeset across multiple files.
  • Description check ✅ Passed — The description provides a clear overview of what was added (mm_mxfp8 API for MXFP8 GEMM support using CUTLASS), explains the motivation (vLLM integration with ModelOpt MXFP8 checkpoints), includes performance benchmarks comparing MXFP8 vs FP8, and confirms all checklist items are completed.


@gemini-code-assist bot commented:

Summary of Changes

Hello @danisereb, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances FlashInfer's capabilities by introducing support for MXFP8 GEMM operations. This new functionality, powered by CUTLASS, is crucial for enabling efficient mixed-precision computations, particularly for large language models (LLMs) and their deployment with vLLM and ModelOpt MXFP8 checkpoints. The changes include core kernel implementations, Python API exposure, and robust testing to ensure high accuracy and performance on modern NVIDIA architectures.

Highlights

  • New MXFP8 GEMM Functionality: Introduced a new mm_mxfp8 function to perform General Matrix Multiplication (GEMM) using MXFP8 (mixed-precision floating-point 8) data types, leveraging the CUTLASS library for optimized performance on NVIDIA GPUs.
  • CUTLASS Integration: Integrated CUTLASS kernels for MXFP8 GEMM, supporting various CTA and cluster shapes, and providing an autotuning mechanism to select the most efficient kernel configuration for given problem sizes.
  • Flexible Scale Factor Handling: The mm_mxfp8 API now supports both 1D swizzled and 2D non-swizzled block scaling factors, with internal logic to ensure correct memory layout and contiguity for CUTLASS and cuDNN backends.
  • Enhanced cuDNN Support for MXFP8: Updated the cuDNN integration to correctly handle MXFP8 scales, including logic to expand 2D tensors to 3D (batch=1) and dynamically set reordering types based on whether scales are swizzled.
  • Comprehensive Benchmarking and Testing: Added a new benchmarking routine (testMmMxfp8) and extensive unit tests covering various dimensions, value ranges, invalid inputs, and realistic LLM layer simulations to ensure accuracy and robustness of the mm_mxfp8 operation.



@gemini-code-assist bot left a comment

Code Review

This pull request adds support for MXFP8 GEMM using the CUTLASS library. The changes include new C++/CUDA kernels, Python bindings, benchmark routines, and unit tests. The implementation looks solid and follows the existing patterns in the repository. I've identified a couple of bugs in the new benchmark code and a minor issue in a test case that should be addressed. Overall, this is a great addition to the library.

@danisereb changed the title from "Add MXFP8 GEMM mm_mxfp8 (cutlass)" to "feat: Add MXFP8 GEMM mm_mxfp8 (cutlass)" on Feb 2, 2026
@danisereb marked this pull request as ready for review on February 2, 2026 at 20:53
@coderabbitai bot left a comment

Actionable comments posted: 8

🤖 Fix all issues with AI agents
In `@benchmarks/routines/flashinfer_benchmark_utils.py`:
- Around line 300-309: The mm_mxfp8 backend map is missing an entry for SM110 so
filter_backends_by_compute_capability() will drop the CUTLASS backend on compute
capability 11.x; update the "mm_mxfp8" dict (the map shown) to include an "11.0"
key with the same backend list as other supported SM versions (e.g., add "11.0":
["cutlass"]) so CUTLASS is not filtered out for SM110.

In `@benchmarks/routines/gemm.py`:
- Around line 1268-1305: The mm_mxfp8 routine is being passed the default
backends (e.g., ["cudnn"]) which get filtered out and cause an early exit;
update the argument handling so that when args.routine == "mm_mxfp8" you force
the backends variable to include only "cutlass" before calling
filter_backends_by_compute_capability and before any autotune filtering (modify
the backends assignment/handling around the existing backends variable and the
autotune_supported_backends list); this ensures backends contains "cutlass" for
mm_mxfp8 and avoids the early return while preserving downstream filtering
logic.

In `@csrc/mxfp8_gemm_cutlass.cu`:
- Around line 94-157: mxfp8_bmm_impl currently only type-checks mat1Scale and
mat2Scale; add shape validation using the sfVecSize=32 contract to prevent OOB
scale access: compute sfVecSize=32 and expected scale lengths as scale_len =
(dim + sfVecSize - 1) / sfVecSize, then for the 2D path require
mat1Scale.ndim()==1 and mat1Scale.size(0)==scale_len(m) and mat2Scale.ndim()==1
and mat2Scale.size(0)==scale_len(n); for the 3D path require mat1Scale.ndim()==2
and mat1Scale.size(0)==b and mat1Scale.size(1)==scale_len(m) (and similarly
mat2Scale.size(0)==b and mat2Scale.size(1)==scale_len(n)); use TVM_FFI_ICHECK_EQ
with clear messages referencing mxfp8_bmm_impl, mat1Scale and mat2Scale so
incorrect shapes fail before kernel launch.

In `@flashinfer/gemm/gemm_base.py`:
- Around line 2482-2517: The function _check_mm_mxfp8_problem_size currently
ignores a user-provided out tensor; add validation when out is not None to
ensure out is 2D with shape (a.shape[0], b.shape[1]) (remember b is passed
transposed as [k, n]), on the same device as a (or b), and has dtype equal to
out_dtype (after calling _validate_mxfp8_output_dtype). If any of these checks
fail, raise ValueError with a clear message referencing expected vs actual
shape/device/dtype.
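
The suggested check could look roughly like this shape-level sketch (pure Python, helper name illustrative; the real validator would also compare torch devices and run after _validate_mxfp8_output_dtype):

```python
def check_out_tensor(a_shape, b_shape, out_shape, out_dtype, expected_dtype):
    """Validate a user-provided `out` against the problem size.
    `b` is passed transposed as [k, n], so out must be 2D with shape
    (m, n) = (a_shape[0], b_shape[1])."""
    expected_shape = (a_shape[0], b_shape[1])
    if len(out_shape) != 2 or tuple(out_shape) != expected_shape:
        raise ValueError(
            f"out must be 2D with shape {expected_shape}, got {tuple(out_shape)}"
        )
    if out_dtype != expected_dtype:
        raise ValueError(f"out must have dtype {expected_dtype}, got {out_dtype}")
```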

In `@flashinfer/jit/gemm/core.py`:
- Around line 239-244: In gen_gemm_sm100_module_cutlass_mxfp8(), copy the
CUTLASS source mxfp8_gemm_cutlass.cu from jit_env.FLASHINFER_CSRC_DIR into
gen_directory (e.g., using shutil.copy2 or Path operations) before constructing
source_paths, and then add the copied file path (located under gen_directory) to
source_paths instead of referencing the original FLASHINFER_CSRC_DIR file; this
ensures builds reference the hermetic, cached file in gen_directory while
leaving existing os.makedirs(...) intact.

In `@include/flashinfer/gemm/mxfp8_gemm_cutlass_template.h`:
- Around line 119-171: The dispatcher dispatchMXFP8xMXFP8GemmCTAShapeSm100
currently accepts CutlassTileConfigSM100::CtaShape128x64x128B (N=64) which
contradicts the comment "Cta N should be one of 128/192/256"; either remove that
case or make the code and comment consistent. To fix: decide which behavior is
correct, then (A) if N=64 is invalid, remove the CtaShape128x64x128B case and
replace it with a branch that throws a clear std::runtime_error (same style as
other error branches), or (B) if N=64 is supported, update the top comment to
document that N=64 is allowed for MXFP8 and add a short inline comment above the
CtaShape128x64x128B case explaining why N=64 is permitted; reference function
name dispatchMXFP8xMXFP8GemmCTAShapeSm100 and the enum value
CutlassTileConfigSM100::CtaShape128x64x128B when making the change.
- Around line 271-299: The static workspace_hashmap in CutlassMxfp8GemmRunner<T,
mxfp8GemmType>::getWorkspaceSize is mutated without synchronization causing data
races; add a std::mutex (include <mutex>) scoped at the same translation-unit
scope as workspace_hashmap (e.g., static std::mutex workspace_mutex) and wrap
all accesses and modifications of workspace_hashmap in a
std::lock_guard<std::mutex> to serialize lookups and inserts (both the find, the
call to getWorkspaceSizeImpl, and the subsequent assignment/get). Ensure the
mutex protects uses of MNK/MNKHash keys created via
std::make_tuple(m,n,k,batch_count) so getWorkspaceSize is thread-safe.
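
The fix itself is C++ (std::mutex plus std::lock_guard around the static map); the same cache-under-lock pattern, sketched in Python for illustration:

```python
import threading

_workspace_cache = {}  # keyed by (m, n, k, batch_count)
_workspace_lock = threading.Lock()

def get_workspace_size(m, n, k, batch_count, compute_impl):
    """Serialize lookups and inserts so concurrent callers never race on
    the shared cache; compute_impl stands in for getWorkspaceSizeImpl."""
    key = (m, n, k, batch_count)
    with _workspace_lock:
        if key not in _workspace_cache:
            _workspace_cache[key] = compute_impl(m, n, k, batch_count)
        return _workspace_cache[key]
```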

In `@tests/gemm/test_mm_mxfp8.py`:
- Around line 376-387: The test currently catches any RuntimeError from mm_mxfp8
and unconditionally skips, which hides real failures; change the except block
around mm_mxfp8 to inspect the exception message (or specific exception type if
available) and only call pytest.skip for known unsupported/backends cases (e.g.,
message contains "unsupported", "cutlass not available", or other
project-specific sentinel), otherwise re-raise the exception so real regressions
surface; keep references to mm_mxfp8 and the contextual m, n, k values in the
skip message.
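
A sketch of the suggested except-block policy (the sentinel strings are placeholders, not the backend's actual messages):

```python
# Hypothetical sentinels for "environment can't run this", not real failures.
_KNOWN_SKIP_MARKERS = ("unsupported", "cutlass not available")

def should_skip(exc: Exception) -> bool:
    """True only for known environment limitations; real regressions
    should be re-raised rather than skipped."""
    msg = str(exc).lower()
    return any(marker in msg for marker in _KNOWN_SKIP_MARKERS)

# Usage inside the test body:
#     try:
#         out = mm_mxfp8(...)
#     except RuntimeError as e:
#         if should_skip(e):
#             pytest.skip(f"mm_mxfp8 unsupported for m={m}, n={n}, k={k}: {e}")
#         raise
```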
🧹 Nitpick comments (5)
csrc/mxfp8_gemm_cutlass.cu (1)

158-168: Consider replacing with DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16 macro for consistency.

The manual switch on dtype works correctly, but other GEMM kernel launchers in the project (e.g., csrc/gemm_groupwise_sm100.cu) use the standard dispatch macros defined in csrc/tvm_ffi_utils.h. Replace the switch with DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(out.dtype(), c_type, [&] { runGemm<c_type>(...); }) to align with project conventions.

tests/gemm/test_mm_mxfp8.py (4)

45-66: Avoid duplicating GPU‑compatibility checks.
Keep the skip policy in one place so future updates don’t drift.

♻️ Suggested refactor
 def _run_mm_mxfp8(
     m,
     n,
     k,
     input_dtype,
     is_sf_swizzled_layout,
     out_dtype,
     backend,
     auto_tuning,
     provide_out,
 ):
-    compute_capability = get_compute_capability(torch.device("cuda"))
-    if compute_capability[0] in [11, 12]:
-        pytest.skip("Not tested on SM110/SM120/SM121")
-    if compute_capability[0] < 10:
-        pytest.skip("mm_mxfp8 is only supported on SM100 and above GPUs.")
+    _skip_if_unsupported()

Also applies to: 161-166


114-158: Consider trimming the parameter grid to keep runtime manageable.
The Cartesian product here is large; consider a smaller smoke matrix plus a separate extended/slow suite to keep CI stable.


334-340: Prefer next(...) over single‑element list indexing.
Avoids materializing a list and matches common Python idioms.

♻️ Suggested tweak
-        min_scale = [scale for scale, sim in results if sim == min_sim][0]
+        min_scale = next(scale for scale, sim in results if sim == min_sim)

483-525: Clarify the non‑contiguous scale expectation.
The test name implies a contiguity requirement, but the assertions only check finiteness. If non‑contig is unsupported, assert it raises; if supported, compare outputs and rename the test accordingly.

Comment on lines +239 to +244
def gen_gemm_sm100_module_cutlass_mxfp8() -> JitSpec:
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100_cutlass_mxfp8"
    os.makedirs(gen_directory, exist_ok=True)
    source_paths = [
        jit_env.FLASHINFER_CSRC_DIR / "mxfp8_gemm_cutlass.cu",
    ]
@coderabbitai bot commented Feb 2, 2026

⚠️ Potential issue | 🟠 Major

Copy the CUTLASS MXFP8 source into the gen directory before adding to source_paths.
The JIT generators should reference a copied source under gen_directory to keep generated builds hermetic and cacheable.

🔧 Suggested fix
-    source_paths = [
-        jit_env.FLASHINFER_CSRC_DIR / "mxfp8_gemm_cutlass.cu",
-    ]
+    src_path = jit_env.FLASHINFER_CSRC_DIR / "mxfp8_gemm_cutlass.cu"
+    dest_path = gen_directory / "mxfp8_gemm_cutlass.cu"
+    with open(src_path, "r") as f:
+        source = f.read()
+    write_if_different(dest_path, source)
+    source_paths = [dest_path]
As per coding guidelines: Copy source files from FLASHINFER_CSRC_DIR to gen_directory before adding to sources list in gen_*_module() functions.

@danisereb (author) commented:

@dhiraj113 @aleozlx
What do you think? Is this suggested fix required?

@coderabbitai bot commented:

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

A collaborator commented:

I am not sure about this. We can leave it as is for now and take a look later if required.

@coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2482-2531: The function _check_mm_mxfp8_problem_size currently
only checks dtypes for a_descale/b_descale; add shape validation to ensure the
scale tensors cannot underflow the kernel: for each of a_descale and b_descale
accept exactly two layouts — swizzled 1D or non‑swizzled 2D — and assert their
shapes match the matrix tiling used by the kernel (for a_descale validate
against A's row partitioning derived from a.shape[0], for b_descale validate
against B's column/column‑block partitioning derived from b.shape[1]);
specifically, if a_descale.ndim == 1 ensure a_descale.shape[0] ==
num_blocks_for_rows(a.shape[0]) (and likewise for b_descale with
num_blocks_for_cols(b.shape[1])), and if ndim == 2 ensure its shape equals
(num_blocks, block_scale_width) as used by the kernel; use the same block/tile
constants/functions the implementation uses to compute num_blocks (don’t
hardcode magic numbers) and raise ValueError with a clear message referencing
a_descale.shape / b_descale.shape when mismatched.
🧹 Nitpick comments (2)
tests/gemm/test_mm_mxfp8.py (2)

56-60: DRY the GPU capability skip logic.

_run_mm_mxfp8 duplicates _skip_if_unsupported; calling the helper keeps skip criteria consistent and avoids divergence.

♻️ Proposed refactor
 def _run_mm_mxfp8(
     m,
     n,
     k,
     input_dtype,
     is_sf_swizzled_layout,
     out_dtype,
     backend,
     auto_tuning,
     provide_out,
 ):
-    compute_capability = get_compute_capability(torch.device("cuda"))
-    if compute_capability[0] in [11, 12]:
-        pytest.skip("Not tested on SM110/SM120/SM121")
-    if compute_capability[0] < 10:
-        pytest.skip("mm_mxfp8 is only supported on SM100 and above GPUs.")
+    _skip_if_unsupported()

338-340: Use next(...) to avoid building a list for min_scale.

Slightly clearer and avoids allocating a list.

♻️ Suggested tweak
-        min_scale = [scale for scale, sim in results if sim == min_sim][0]
+        min_scale = next(scale for scale, sim in results if sim == min_sim)

@danisereb force-pushed the support_mm_mxfp8 branch 2 times, most recently from b57b6bf to 2dd82ea, on February 3, 2026 at 09:45
@coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@csrc/mxfp8_gemm_cutlass.cu`:
- Around line 180-214: The validation is checking scale lengths along M/N
instead of K; update the mat1Scale/mat2Scale non-swizzled checks in
mxfp8_bmm_impl so the per-scale dimension is scale_len(k) (not
scale_len(m)/scale_len(n)) and the other dims match either the unbatched 2D
shape ([M, scale_len(k)] / [N, scale_len(k)]) or the batched shapes ([B, M,
scale_len(k)] or flattened [B*M, scale_len(k)] / [B, N, scale_len(k)] or
flattened [B*N, scale_len(k)]). Concretely, change the branches that reference
scale_len(m) and scale_len(n) to validate scale_len(k) and ensure
mat1Scale.size(0)/size(1) correspond to M or B*M (use symbols mat1Scale,
mat2Scale, scale_len, swizzled_len, mxfp8_bmm_impl, and variables b,m,n,k to
locate the checks).
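
In other words, MXFP8 scales run along K (one scale per 32 elements), so the expected non-swizzled shapes can be sketched as follows (helper names illustrative, not the actual csrc symbols):

```python
SF_VEC_SIZE = 32  # MXFP8 block size along K

def scale_len(dim):
    """Ceil-divide: number of block scales covering `dim` elements."""
    return (dim + SF_VEC_SIZE - 1) // SF_VEC_SIZE

def expected_scale_shapes(m, n, k, b=None):
    """Non-swizzled 2D scale shapes: [M, scale_len(k)] for mat1 and
    [N, scale_len(k)] for mat2. Batched problems prepend B or flatten to
    [B*M, ...] / [B*N, ...] (flattened form shown here)."""
    rows_a = m if b is None else b * m
    rows_b = n if b is None else b * n
    return (rows_a, scale_len(k)), (rows_b, scale_len(k))
```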

@coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@tests/gemm/test_mm_mxfp8.py`:
- Around line 118-163: The test parametrizations for test_mm_mxfp8 and
test_mm_mxfp8_large_dimensions create too many heavy GPU runs and should be
gated: mark the large/slow matrix set with a pytest marker (e.g.,
`@pytest.mark.slow` or `@pytest.mark.xslow`) or wrap the large-parameter test in a
skipif driven by an environment variable (e.g., os.getenv("RUN_SLOW_TESTS") ==
"1"), and keep the default parametrization small; update the two test functions
(test_mm_mxfp8, test_mm_mxfp8_large_dimensions) or the helper _run_mm_mxfp8 call
sites to apply the marker/skip condition, and add the marker declaration to
pytest.ini if using a custom marker so CI can opt-in to run the extended tests.
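
One way to implement the opt-in gating (shape lists here are illustrative, not the PR's actual parametrization):

```python
import os

# Small default matrix for CI; heavy shapes only run when opted in.
SMOKE_SHAPES = [(128, 256, 256), (1024, 1024, 1024)]
SLOW_SHAPES = [(8192, 8192, 4096), (16384, 16384, 16384)]

def shapes_to_run():
    """Extended shapes are gated behind RUN_SLOW_TESTS=1 so CI stays fast
    by default; equivalently, mark them `@pytest.mark.slow` and declare the
    marker in pytest.ini so CI can opt in."""
    shapes = list(SMOKE_SHAPES)
    if os.getenv("RUN_SLOW_TESTS") == "1":
        shapes += SLOW_SHAPES
    return shapes
```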
🧹 Nitpick comments (2)
tests/gemm/test_mm_mxfp8.py (2)

60-65: Deduplicate GPU capability skip logic.

_run_mm_mxfp8 repeats the same capability checks already handled by _skip_if_unsupported, which risks divergence over time. Prefer calling the helper to keep policy centralized.

♻️ Proposed refactor
 def _run_mm_mxfp8(
@@
 ):
-    compute_capability = get_compute_capability(torch.device("cuda"))
-    if compute_capability[0] in [11, 12]:
-        pytest.skip("Not tested on SM110/SM120/SM121")
-    if compute_capability[0] < 10:
-        pytest.skip("mm_mxfp8 is only supported on SM100 and above GPUs.")
+    _skip_if_unsupported()

366-368: Use next(...) for min-scale lookup.

Avoid building an intermediate list to extract a single element; next(...) is clearer and cheaper.

🧹 Proposed change
-        min_scale = [scale for scale, sim in results if sim == min_sim][0]
+        min_scale = next(scale for scale, sim in results if sim == min_sim)

@coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2482-2553: The non-swizzled MXFP8 scale checks in
_check_mm_mxfp8_problem_size currently compute scale shapes using integer
division by sf_vec_size (32) but don't verify K is divisible by 32; add explicit
divisibility checks and raise a ValueError if not. Specifically, when handling
a_descale.ndim == 2 ensure a.shape[1] % sf_vec_size == 0 before comparing
expected_shape, and when handling b_descale.ndim == 2 ensure b.shape[0] %
sf_vec_size == 0 before comparing expected_shape; use the same error
style/messages as the existing checks mentioning a_descale/b_descale and the
required divisibility by 32.
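
The divisibility pre-check is a one-liner; sketched here with an illustrative helper name:

```python
SF_VEC_SIZE = 32  # MXFP8 scale granularity along K

def check_k_divisible(k, which):
    """Reject K values that integer division by 32 would silently truncate.
    `which` names the offending tensor (a_descale / b_descale) in the error."""
    if k % SF_VEC_SIZE != 0:
        raise ValueError(
            f"{which}: K dimension {k} must be divisible by {SF_VEC_SIZE} "
            f"for non-swizzled MXFP8 scales"
        )
```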

@coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@benchmarks/routines/gemm.py`:
- Around lines 1268-1271: The backend check rejects `"auto"` even though `mm_mxfp8` accepts it (it maps to `"cutlass"`). Update the conditional on `backends` to allow either value (e.g. `"cutlass" in backends or "auto" in backends`), keeping the existing `ValueError` only when neither is present, so calls with `backend="auto"` succeed.
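A sketch of the corrected condition (standalone version; the real check sits inside the benchmark routine):

```python
def validate_mxfp8_backends(backends: list) -> list:
    # "auto" maps to cutlass for mm_mxfp8, so both spellings must pass.
    if "cutlass" not in backends and "auto" not in backends:
        raise ValueError("mm_mxfp8 only supports the cutlass backend (or 'auto')")
    return backends
```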
🧹 Nitpick comments (2)
tests/gemm/test_mm_mxfp8.py (2)

345-347: Nit: prefer next(...) over single-element slice.

Per Ruff RUF015, use a generator with next() instead of building a list and indexing [0]:

🔧 Suggested fix
-        min_scale = [scale for scale, sim in results if sim == min_sim][0]
+        min_scale = next(scale for scale, sim in results if sim == min_sim)

488-530: Consider comparing outputs between non-contiguous and contiguous scale paths.

Currently both paths only assert isfinite. Comparing the two outputs against each other (or both against a reference) would catch silent correctness divergence between the contiguous and non-contiguous code paths.
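One way to express the suggested strengthening, sketched with NumPy and hypothetical callables wrapping the two scale layouts (not the test's actual helpers):

```python
import numpy as np

def assert_paths_agree(run_contig, run_noncontig, a, b, rtol=1e-2, atol=1e-2):
    """Beyond per-path isfinite checks, compare the contiguous and
    non-contiguous outputs directly so a silent divergence fails loudly."""
    out_c = run_contig(a, b)
    out_nc = run_noncontig(a, b)
    assert np.isfinite(out_c).all() and np.isfinite(out_nc).all()
    np.testing.assert_allclose(out_nc, out_c, rtol=rtol, atol=atol)
```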

Signed-off-by: Daniel Serebrenik <[email protected]>
@aleozlx

aleozlx commented Feb 9, 2026

/bot run

@flashinfer-bot

GitLab MR !301 has been created, and the CI pipeline #43624652 is currently running. I'll report back once the pipeline job completes.

@dhiraj113 left a comment

Approving the change. Check that the CI/CD is passing before submission.

)


def test_mm_mxfp8_invalid_ndim():

Where is this function used/called?

Comment on lines +239 to +244
def gen_gemm_sm100_module_cutlass_mxfp8() -> JitSpec:
gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100_cutlass_mxfp8"
os.makedirs(gen_directory, exist_ok=True)
source_paths = [
jit_env.FLASHINFER_CSRC_DIR / "mxfp8_gemm_cutlass.cu",
]

I am not sure about this. We can leave it as is for now and take a look later if required.

@flashinfer-bot

[CANCELING] Pipeline #43624652: canceled

@bkryu
Copy link
Collaborator

bkryu commented Feb 10, 2026

/bot run

@flashinfer-bot

GitLab MR !301 has been created, and the CI pipeline #43649889 is currently running. I'll report back once the pipeline job completes.

Signed-off-by: Daniel Serebrenik <[email protected]>
@coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/gemm/gemm_base.py`:
- Line 2648: Fix the docstring typo where the B operand is described: "Input Btensor, shape (k, n), should be column major, mxfp8 e4m3." should read "Input B tensor, shape (k, n), should be column major, mxfp8 e4m3."
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)

2709-2722: Redundant assert statements duplicate _check_mm_mxfp8_problem_size validations.

@backend_requirement already invokes _check_mm_mxfp8_problem_size (which performs these same checks and raises ValueError) before the function body executes. These asserts are redundant and will silently disappear under `python -O`.

Consider removing them to reduce noise, or converting to raise ValueError if you want belt-and-suspenders validation inside the body.

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
@aleozlx left a comment

reviewed b46e87c..35a5635

mainly removed min_m check

@flashinfer-bot

[FAILED] Pipeline #43649889: 9/20 passed

@aleozlx aleozlx added the run-ci label Feb 11, 2026
@aleozlx aleozlx enabled auto-merge (squash) February 11, 2026 22:12
@yongwww yongwww disabled auto-merge February 12, 2026 19:47
@yongwww yongwww enabled auto-merge (squash) February 12, 2026 19:47
@yongwww yongwww disabled auto-merge February 12, 2026 19:48
@yongwww yongwww merged commit b492d3f into flashinfer-ai:main Feb 12, 2026
33 of 34 checks passed
extra_cflags=[
"-DFAST_BUILD",
],
)


def gen_gemm_sm100_module_cutlass_bf16() -> JitSpec:
gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100_cutlass_bf16"
def gen_gemm_sm100_module_cutlass_mxfp8() -> JitSpec:

Can you also add this module to https://github.com/flashinfer-ai/flashinfer/blob/579435f982856a7be744793ddcf1a6d12b31c46b/flashinfer/aot.py to make sure we compile this module in the flashinfer-jit-cache package?

@yzh119

yzh119 commented Feb 12, 2026

I just noticed this PR was merged, @yongwww @danisereb would you mind creating another PR fixing the issue I mentioned?

yongwww added a commit that referenced this pull request Feb 13, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->
follow-up #2464

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Added mixed FP8 (mxfp8) precision support for SM100 GPUs, expanding
available kernel variants for more flexible computation options.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->