feat: Add MXFP8 GEMM mm_mxfp8 (cutlass)#2464
Conversation
> Note: Reviews paused. This branch is under active development; to avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough: Adds end-to-end MXFP8 GEMM support: a new public Python API `mm_mxfp8` backed by CUTLASS kernels.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant User
    participant API as mm_mxfp8_API
    participant Validator as ProblemValidator
    participant Tuner as AutoTuner
    participant Backend as CutlassRunner
    participant CUDA as CUDA_mxfp8_gemm
    User->>API: call mm_mxfp8(A,B,a_descale,b_descale[,out])
    API->>Validator: validate shapes, dtypes, scales, layout
    Validator-->>API: ok / error
    API->>Tuner: select backend & tactic (heuristic / autotune)
    Tuner->>Backend: query tactics/configs & workspace requirements
    Backend-->>Tuner: tactics + workspace needs
    Tuner->>Backend: choose tactic, request workspace
    Backend->>CUDA: launch mxfp8_gemm(tactic, workspace, inputs)
    CUDA-->>Backend: execution status + output
    Backend-->>API: return result tensor
    API-->>User: deliver output
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

Hello @danisereb, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances FlashInfer by introducing support for MXFP8 GEMM operations. This new functionality, powered by CUTLASS, enables efficient mixed-precision computation, particularly for large language models (LLMs) and their deployment with vLLM and ModelOpt MXFP8 checkpoints. The changes include core kernel implementations, Python API exposure, and tests to verify accuracy and performance on modern NVIDIA architectures.
Code Review
This pull request adds support for MXFP8 GEMM using the CUTLASS library. The changes include new C++/CUDA kernels, Python bindings, benchmark routines, and unit tests. The implementation looks solid and follows the existing patterns in the repository. I've identified a couple of bugs in the new benchmark code and a minor issue in a test case that should be addressed. Overall, this is a great addition to the library.
Force-pushed 998efe7 to 00780c7
Actionable comments posted: 8
🤖 Fix all issues with AI agents
In `@benchmarks/routines/flashinfer_benchmark_utils.py`:
- Around line 300-309: The mm_mxfp8 backend map is missing an entry for SM110 so
filter_backends_by_compute_capability() will drop the CUTLASS backend on compute
capability 11.x; update the "mm_mxfp8" dict (the map shown) to include an "11.0"
key with the same backend list as other supported SM versions (e.g., add "11.0":
["cutlass"]) so CUTLASS is not filtered out for SM110.
In `@benchmarks/routines/gemm.py`:
- Around line 1268-1305: The mm_mxfp8 routine is being passed the default
backends (e.g., ["cudnn"]) which get filtered out and cause an early exit;
update the argument handling so that when args.routine == "mm_mxfp8" you force
the backends variable to include only "cutlass" before calling
filter_backends_by_compute_capability and before any autotune filtering (modify
the backends assignment/handling around the existing backends variable and the
autotune_supported_backends list); this ensures backends contains "cutlass" for
mm_mxfp8 and avoids the early return while preserving downstream filtering
logic.
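The gist of the suggested override, shown as a hypothetical helper (the real code mutates an existing `backends` variable in `benchmarks/routines/gemm.py` rather than calling a function like this):

```python
def resolve_backends(routine, requested):
    # mm_mxfp8 is CUTLASS-only; override whatever default (e.g. ["cudnn"]) came in
    # so the routine is not filtered out and exited early.
    if routine == "mm_mxfp8":
        return ["cutlass"]
    return list(requested)
```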
In `@csrc/mxfp8_gemm_cutlass.cu`:
- Around line 94-157: mxfp8_bmm_impl currently only type-checks mat1Scale and
mat2Scale; add shape validation using the sfVecSize=32 contract to prevent OOB
scale access: compute sfVecSize=32 and expected scale lengths as scale_len =
(dim + sfVecSize - 1) / sfVecSize, then for the 2D path require
mat1Scale.ndim()==1 and mat1Scale.size(0)==scale_len(m) and mat2Scale.ndim()==1
and mat2Scale.size(0)==scale_len(n); for the 3D path require mat1Scale.ndim()==2
and mat1Scale.size(0)==b and mat1Scale.size(1)==scale_len(m) (and similarly
mat2Scale.size(0)==b and mat2Scale.size(1)==scale_len(n)); use TVM_FFI_ICHECK_EQ
with clear messages referencing mxfp8_bmm_impl, mat1Scale and mat2Scale so
incorrect shapes fail before kernel launch.
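The `sfVecSize=32` contract behind these checks can be sketched in Python (the C++ fix would use `TVM_FFI_ICHECK_EQ`; this helper is illustrative only, and a later review comment in this thread corrects which dimension the scales run along):

```python
SF_VEC_SIZE = 32  # MXFP8 block size: one scale factor covers 32 elements

def scale_len(dim):
    # ceil(dim / SF_VEC_SIZE): number of scale factors needed to cover `dim`
    return (dim + SF_VEC_SIZE - 1) // SF_VEC_SIZE
```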
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2482-2517: The function _check_mm_mxfp8_problem_size currently
ignores a user-provided out tensor; add validation when out is not None to
ensure out is 2D with shape (a.shape[0], b.shape[1]) (remember b is passed
transposed as [k, n]), on the same device as a (or b), and has dtype equal to
out_dtype (after calling _validate_mxfp8_output_dtype). If any of these checks
fail, raise ValueError with a clear message referencing expected vs actual
shape/device/dtype.
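The requested `out` validation might look like this sketch over plain shape tuples (names and the tuple-based signature are hypothetical; the real function operates on torch tensors):

```python
def validate_out(a_shape, b_shape, out_shape, a_device, out_device,
                 out_dtype, expected_dtype):
    # b arrives transposed as [k, n], so the product is (m, n) = (a.shape[0], b.shape[1])
    expected_shape = (a_shape[0], b_shape[1])
    if out_shape != expected_shape:
        raise ValueError(f"out shape: expected {expected_shape}, got {out_shape}")
    if out_device != a_device:
        raise ValueError(f"out device: expected {a_device}, got {out_device}")
    if out_dtype != expected_dtype:
        raise ValueError(f"out dtype: expected {expected_dtype}, got {out_dtype}")
```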
In `@flashinfer/jit/gemm/core.py`:
- Around line 239-244: In gen_gemm_sm100_module_cutlass_mxfp8(), copy the
CUTLASS source mxfp8_gemm_cutlass.cu from jit_env.FLASHINFER_CSRC_DIR into
gen_directory (e.g., using shutil.copy2 or Path operations) before constructing
source_paths, and then add the copied file path (located under gen_directory) to
source_paths instead of referencing the original FLASHINFER_CSRC_DIR file; this
ensures builds reference the hermetic, cached file in gen_directory while
leaving existing os.makedirs(...) intact.
In `@include/flashinfer/gemm/mxfp8_gemm_cutlass_template.h`:
- Around line 119-171: The dispatcher dispatchMXFP8xMXFP8GemmCTAShapeSm100
currently accepts CutlassTileConfigSM100::CtaShape128x64x128B (N=64) which
contradicts the comment "Cta N should be one of 128/192/256"; either remove that
case or make the code and comment consistent. To fix: decide which behavior is
correct, then (A) if N=64 is invalid, remove the CtaShape128x64x128B case and
replace it with a branch that throws a clear std::runtime_error (same style as
other error branches), or (B) if N=64 is supported, update the top comment to
document that N=64 is allowed for MXFP8 and add a short inline comment above the
CtaShape128x64x128B case explaining why N=64 is permitted; reference function
name dispatchMXFP8xMXFP8GemmCTAShapeSm100 and the enum value
CutlassTileConfigSM100::CtaShape128x64x128B when making the change.
- Around line 271-299: The static workspace_hashmap in CutlassMxfp8GemmRunner<T,
mxfp8GemmType>::getWorkspaceSize is mutated without synchronization causing data
races; add a std::mutex (include <mutex>) scoped at the same translation-unit
scope as workspace_hashmap (e.g., static std::mutex workspace_mutex) and wrap
all accesses and modifications of workspace_hashmap in a
std::lock_guard<std::mutex> to serialize lookups and inserts (both the find, the
call to getWorkspaceSizeImpl, and the subsequent assignment/get). Ensure the
mutex protects uses of MNK/MNKHash keys created via
std::make_tuple(m,n,k,batch_count) so getWorkspaceSize is thread-safe.
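The requested pattern, shown as a Python analog (the actual fix is C++ with `std::mutex` and `std::lock_guard`; `threading.Lock` plays the same role here, and the `compute` callback stands in for `getWorkspaceSizeImpl`):

```python
import threading

_workspace_cache = {}
_workspace_lock = threading.Lock()

def get_workspace_size(m, n, k, batch_count, compute):
    """Thread-safe memoized lookup keyed on (m, n, k, batch_count)."""
    key = (m, n, k, batch_count)
    with _workspace_lock:  # serializes the find, the compute, and the insert
        size = _workspace_cache.get(key)
        if size is None:
            size = compute(m, n, k, batch_count)
            _workspace_cache[key] = size
        return size
```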
In `@tests/gemm/test_mm_mxfp8.py`:
- Around line 376-387: The test currently catches any RuntimeError from mm_mxfp8
and unconditionally skips, which hides real failures; change the except block
around mm_mxfp8 to inspect the exception message (or specific exception type if
available) and only call pytest.skip for known unsupported/backends cases (e.g.,
message contains "unsupported", "cutlass not available", or other
project-specific sentinel), otherwise re-raise the exception so real regressions
surface; keep references to mm_mxfp8 and the contextual m, n, k values in the
skip message.
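The selective-skip policy could be factored like this sketch; the sentinel strings are the examples from the comment, not an exhaustive list:

```python
KNOWN_SKIP_MARKERS = ("unsupported", "cutlass not available")

def should_skip(exc_message):
    """True only for known backend/unsupported errors; anything else should re-raise."""
    msg = exc_message.lower()
    return any(marker in msg for marker in KNOWN_SKIP_MARKERS)
```

In the test, the `except RuntimeError as e:` block would call `pytest.skip(...)` only when `should_skip(str(e))` is true and `raise` otherwise.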
🧹 Nitpick comments (5)
csrc/mxfp8_gemm_cutlass.cu (1)
158-168: Consider replacing with `DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16` macro for consistency. The manual switch on dtype works correctly, but other GEMM kernel launchers in the project (e.g., `csrc/gemm_groupwise_sm100.cu`) use the standard dispatch macros defined in `csrc/tvm_ffi_utils.h`. Replace the switch with `DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(out.dtype(), c_type, [&] { runGemm<c_type>(...); })` to align with project conventions.

tests/gemm/test_mm_mxfp8.py (4)
45-66: Avoid duplicating GPU-compatibility checks. Keep the skip policy in one place so future updates don't drift.

♻️ Suggested refactor

```diff
 def _run_mm_mxfp8(
     m, n, k, input_dtype, is_sf_swizzled_layout,
     out_dtype, backend, auto_tuning, provide_out,
 ):
-    compute_capability = get_compute_capability(torch.device("cuda"))
-    if compute_capability[0] in [11, 12]:
-        pytest.skip("Not tested on SM110/SM120/SM121")
-    if compute_capability[0] < 10:
-        pytest.skip("mm_mxfp8 is only supported on SM100 and above GPUs.")
+    _skip_if_unsupported()
```

Also applies to: 161-166
114-158: Consider trimming the parameter grid to keep runtime manageable.
The Cartesian product here is large; consider a smaller smoke matrix plus a separate extended/slow suite to keep CI stable.
334-340: Prefer `next(...)` over single-element list indexing. Avoids materializing a list and matches common Python idioms.

♻️ Suggested tweak

```diff
- min_scale = [scale for scale, sim in results if sim == min_sim][0]
+ min_scale = next(scale for scale, sim in results if sim == min_sim)
```
483-525: Clarify the non-contiguous scale expectation. The test name implies a contiguity requirement, but the assertions only check finiteness. If non-contiguous scales are unsupported, assert that the call raises; if supported, compare outputs against the contiguous path and rename the test accordingly.
```python
def gen_gemm_sm100_module_cutlass_mxfp8() -> JitSpec:
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100_cutlass_mxfp8"
    os.makedirs(gen_directory, exist_ok=True)
    source_paths = [
        jit_env.FLASHINFER_CSRC_DIR / "mxfp8_gemm_cutlass.cu",
    ]
```
Copy the CUTLASS MXFP8 source into the gen directory before adding to source_paths.
The JIT generators should reference a copied source under gen_directory to keep generated builds hermetic and cacheable.
🔧 Suggested fix

```diff
- source_paths = [
-     jit_env.FLASHINFER_CSRC_DIR / "mxfp8_gemm_cutlass.cu",
- ]
+ src_path = jit_env.FLASHINFER_CSRC_DIR / "mxfp8_gemm_cutlass.cu"
+ dest_path = gen_directory / "mxfp8_gemm_cutlass.cu"
+ with open(src_path, "r") as f:
+     source = f.read()
+ write_if_different(dest_path, source)
+ source_paths = [dest_path]
```
@dhiraj113 @aleozlx What do you think? Is this suggested fix required?
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!
I am not sure about this. We can leave it as is for now and take a look later if required.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2482-2531: The function _check_mm_mxfp8_problem_size currently
only checks dtypes for a_descale/b_descale; add shape validation to ensure the
scale tensors cannot underflow the kernel: for each of a_descale and b_descale
accept exactly two layouts — swizzled 1D or non‑swizzled 2D — and assert their
shapes match the matrix tiling used by the kernel (for a_descale validate
against A's row partitioning derived from a.shape[0], for b_descale validate
against B's column/column‑block partitioning derived from b.shape[1]);
specifically, if a_descale.ndim == 1 ensure a_descale.shape[0] ==
num_blocks_for_rows(a.shape[0]) (and likewise for b_descale with
num_blocks_for_cols(b.shape[1])), and if ndim == 2 ensure its shape equals
(num_blocks, block_scale_width) as used by the kernel; use the same block/tile
constants/functions the implementation uses to compute num_blocks (don’t
hardcode magic numbers) and raise ValueError with a clear message referencing
a_descale.shape / b_descale.shape when mismatched.
🧹 Nitpick comments (2)
tests/gemm/test_mm_mxfp8.py (2)
56-60: DRY the GPU capability skip logic. `_run_mm_mxfp8` duplicates `_skip_if_unsupported`; calling the helper keeps skip criteria consistent and avoids divergence.

♻️ Proposed refactor

```diff
 def _run_mm_mxfp8(
     m, n, k, input_dtype, is_sf_swizzled_layout,
     out_dtype, backend, auto_tuning, provide_out,
 ):
-    compute_capability = get_compute_capability(torch.device("cuda"))
-    if compute_capability[0] in [11, 12]:
-        pytest.skip("Not tested on SM110/SM120/SM121")
-    if compute_capability[0] < 10:
-        pytest.skip("mm_mxfp8 is only supported on SM100 and above GPUs.")
+    _skip_if_unsupported()
```
338-340: Use `next(...)` to avoid building a list for `min_scale`. Slightly clearer and avoids allocating a list.

♻️ Suggested tweak

```diff
- min_scale = [scale for scale, sim in results if sim == min_sim][0]
+ min_scale = next(scale for scale, sim in results if sim == min_sim)
```
Force-pushed b57b6bf to 2dd82ea
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@csrc/mxfp8_gemm_cutlass.cu`:
- Around line 180-214: The validation is checking scale lengths along M/N
instead of K; update the mat1Scale/mat2Scale non-swizzled checks in
mxfp8_bmm_impl so the per-scale dimension is scale_len(k) (not
scale_len(m)/scale_len(n)) and the other dims match either the unbatched 2D
shape ([M, scale_len(k)] / [N, scale_len(k)]) or the batched shapes ([B, M,
scale_len(k)] or flattened [B*M, scale_len(k)] / [B, N, scale_len(k)] or
flattened [B*N, scale_len(k)]). Concretely, change the branches that reference
scale_len(m) and scale_len(n) to validate scale_len(k) and ensure
mat1Scale.size(0)/size(1) correspond to M or B*M (use symbols mat1Scale,
mat2Scale, scale_len, swizzled_len, mxfp8_bmm_impl, and variables b,m,n,k to
locate the checks).
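The corrected contract (scales run along K, not M/N) can be summarized with a hypothetical helper that enumerates the accepted non-swizzled shapes; the names mirror the comment above but are illustrative:

```python
def expected_scale_shapes(b, m, n, k, sf_vec_size=32):
    """Valid non-swizzled scale shapes: the per-scale dimension is scale_len(k)."""
    sl_k = (k + sf_vec_size - 1) // sf_vec_size
    return {
        # unbatched 2D, batched 3D, and flattened batched layouts
        "mat1Scale": [(m, sl_k), (b, m, sl_k), (b * m, sl_k)],
        "mat2Scale": [(n, sl_k), (b, n, sl_k), (b * n, sl_k)],
    }
```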
Force-pushed 58076a9 to 8dd450c
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@tests/gemm/test_mm_mxfp8.py`:
- Around line 118-163: The test parametrizations for test_mm_mxfp8 and
test_mm_mxfp8_large_dimensions create too many heavy GPU runs and should be
gated: mark the large/slow matrix set with a pytest marker (e.g.,
`@pytest.mark.slow` or `@pytest.mark.xslow`) or wrap the large-parameter test in a
skipif driven by an environment variable (e.g., os.getenv("RUN_SLOW_TESTS") ==
"1"), and keep the default parametrization small; update the two test functions
(test_mm_mxfp8, test_mm_mxfp8_large_dimensions) or the helper _run_mm_mxfp8 call
sites to apply the marker/skip condition, and add the marker declaration to
pytest.ini if using a custom marker so CI can opt-in to run the extended tests.
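A sketch of the env-var gate variant; `RUN_SLOW_TESTS` is the illustrative variable name from the comment, and a plain dict stands in for `os.environ` so the gate is easy to exercise:

```python
def run_slow_tests(env):
    """Opt-in gate for the extended parameter matrix."""
    return env.get("RUN_SLOW_TESTS") == "1"
```

In the test module this would back a skip condition such as `pytest.mark.skipif(not run_slow_tests(os.environ), reason="set RUN_SLOW_TESTS=1 for the extended matrix")`.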
🧹 Nitpick comments (2)
tests/gemm/test_mm_mxfp8.py (2)
60-65: Deduplicate GPU capability skip logic. `_run_mm_mxfp8` repeats the same capability checks already handled by `_skip_if_unsupported`, which risks divergence over time. Prefer calling the helper to keep policy centralized.

♻️ Proposed refactor

```diff
 def _run_mm_mxfp8(
 @@
 ):
-    compute_capability = get_compute_capability(torch.device("cuda"))
-    if compute_capability[0] in [11, 12]:
-        pytest.skip("Not tested on SM110/SM120/SM121")
-    if compute_capability[0] < 10:
-        pytest.skip("mm_mxfp8 is only supported on SM100 and above GPUs.")
+    _skip_if_unsupported()
```
366-368: Use `next(...)` for the min-scale lookup. Avoid building an intermediate list to extract a single element; `next(...)` is clearer and cheaper.

🧹 Proposed change

```diff
- min_scale = [scale for scale, sim in results if sim == min_sim][0]
+ min_scale = next(scale for scale, sim in results if sim == min_sim)
```
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2482-2553: The non-swizzled MXFP8 scale checks in
_check_mm_mxfp8_problem_size currently compute scale shapes using integer
division by sf_vec_size (32) but don't verify K is divisible by 32; add explicit
divisibility checks and raise a ValueError if not. Specifically, when handling
a_descale.ndim == 2 ensure a.shape[1] % sf_vec_size == 0 before comparing
expected_shape, and when handling b_descale.ndim == 2 ensure b.shape[0] %
sf_vec_size == 0 before comparing expected_shape; use the same error
style/messages as the existing checks mentioning a_descale/b_descale and the
required divisibility by 32.
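The requested divisibility check, sketched as a standalone helper (names are illustrative; the real check lives inside `_check_mm_mxfp8_problem_size`):

```python
def check_k_divisible(k, sf_vec_size=32):
    """Reject K values that would make integer division silently drop elements."""
    if k % sf_vec_size != 0:
        raise ValueError(
            f"K={k} must be divisible by {sf_vec_size} for non-swizzled MXFP8 scales"
        )
    return k // sf_vec_size  # number of scale factors along K
```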
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@benchmarks/routines/gemm.py`:
- Around line 1268-1271: The backend check incorrectly rejects "auto" values
even though mm_mxfp8 accepts auto (which maps to "cutlass"); update the
conditional that examines the backends variable (the code block around backends
= args.backends and the if check) to allow either "cutlass" or "auto" (e.g.,
test for "cutlass" in backends or "auto" in backends) and keep the existing
ValueError only when neither is present so mm_mxfp8 calls succeed for
backend="auto".
🧹 Nitpick comments (2)
tests/gemm/test_mm_mxfp8.py (2)
345-347: Nit: prefer `next(...)` over a single-element slice. Per Ruff RUF015, use a generator with `next()` instead of building a list and indexing `[0]`:

🔧 Suggested fix

```diff
- min_scale = [scale for scale, sim in results if sim == min_sim][0]
+ min_scale = next(scale for scale, sim in results if sim == min_sim)
```
488-530: Consider comparing outputs between non-contiguous and contiguous scale paths. Currently both paths only assert `isfinite`. Comparing the two outputs against each other (or both against a reference) would catch silent correctness divergence between the contiguous and non-contiguous code paths.
Signed-off-by: Daniel Serebrenik <[email protected]>
Force-pushed 79a80de to 9f1813d
/bot run
dhiraj113 left a comment:
Approving the change. Check that the CI/CD is passing before submission.
```python
def test_mm_mxfp8_invalid_ndim():
```
Where is this function used/called?
```python
def gen_gemm_sm100_module_cutlass_mxfp8() -> JitSpec:
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100_cutlass_mxfp8"
    os.makedirs(gen_directory, exist_ok=True)
    source_paths = [
        jit_env.FLASHINFER_CSRC_DIR / "mxfp8_gemm_cutlass.cu",
    ]
```
I am not sure about this. We can leave it as is for now and take a look later if required.
[CANCELING] Pipeline #43624652: canceled
/bot run
Signed-off-by: Daniel Serebrenik <[email protected]>
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/gemm/gemm_base.py`:
- Line 2648: Fix the typo in the docstring that reads "Input Btensor, shape (k,
n), should be column major, mxfp8 e4m3." — change "Btensor" to "B tensor" so the
sentence reads "Input B tensor, shape (k, n), should be column major, mxfp8
e4m3." Update the docstring where the variable B is described (the string
containing "Input Btensor...") in gemm_base.py.
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)
2709-2722: Redundant `assert` statements duplicate `_check_mm_mxfp8_problem_size` validations. `@backend_requirement` already invokes `_check_mm_mxfp8_problem_size` (which performs these same checks with `raise ValueError`) before the function body executes. These `assert`s are redundant and will silently disappear under `python -O`. Consider removing them to reduce noise, or converting them to `raise ValueError` if you want belt-and-suspenders validation inside the body.
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
[FAILED] Pipeline #43649889: 9/20 passed
```python
        extra_cflags=[
            "-DFAST_BUILD",
        ],
    )


def gen_gemm_sm100_module_cutlass_bf16() -> JitSpec:
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100_cutlass_bf16"


def gen_gemm_sm100_module_cutlass_mxfp8() -> JitSpec:
```
Can you also add this module to https://github.com/flashinfer-ai/flashinfer/blob/579435f982856a7be744793ddcf1a6d12b31c46b/flashinfer/aot.py to make sure we compile this module in the flashinfer-jit-cache package?
|
I just noticed this PR was merged. @yongwww @danisereb would you mind creating another PR fixing the issue I mentioned?
📌 Description

follow-up #2464

Summary by CodeRabbit

* New Features
  * Added mixed FP8 (mxfp8) precision support for SM100 GPUs, expanding available kernel variants for more flexible computation options.
📌 Description
Add new API `mm_mxfp8` to support MXFP8 GEMM (using CUTLASS).

The logic/flow of `mm_mxfp8` was based on `mm_fp4`. The MXFP8 CUTLASS code is based on the existing FlashInfer NVFP4 code.

Will be used for MXFP8 support in vLLM (with ModelOpt MXFP8 checkpoints).
Performance results (Blackwell, B200): comparison of MXFP8 and FP8 benchmark results.
More about MXFP8 performance (BF16 / FP8 / MXFP8): https://cursor.com/ja/blog/kernels
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Tests
Chores