
benchmarks: Add norm and quantization routines to microbenchmark harness.#2362

Merged
yzh119 merged 7 commits into flashinfer-ai:main from bkryu:benchmark_utils
Jan 16, 2026
Conversation

@bkryu (Collaborator) commented Jan 15, 2026

📌 Description

No changes to library code

Extends the flashinfer_benchmark.py microbenchmark harness to support normalization and quantization routines. This enables performance benchmarking for RMSNorm variants (including fused residual-add and FP8/FP4 quantization) and standalone quantization kernels (MxFP8, MxFP4, NVFP4).

New Routines

| Category | Routine | Backend | Min SM |
|---|---|---|---|
| Norm | rmsnorm | cuda | 7.5 |
| | rmsnorm_quant | cuda | 7.5 |
| | fused_add_rmsnorm_quant | cuda | 7.5 |
| | rmsnorm_fp4quant | cute-dsl | 10.0 |
| | add_rmsnorm_fp4quant | cute-dsl | 10.0 |
| Quantization | mxfp8_quantize | cuda | 10.0 |
| | mxfp4_quantize | cuda | 10.0 |
| | nvfp4_quantize | cuda | 10.0 |
| | nvfp4_batched_quantize | cuda | 10.0 |

Example outputs:

```shell
$ python3 flashinfer_benchmark.py --routine rmsnorm --batch_size 16384 --hidden_size 16384 --input_dtype bfloat16 --refcheck
[PERF] cuda           :: median time 0.251 ms; std 0.001 ms; achieved tflops 5.356 TFLOPs/sec; achieved tb_per_sec 4.285 TB/sec
$ python3 flashinfer_benchmark.py --routine rmsnorm_quant --batch_size 16384 --hidden_size 16384 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck
[PERF] cuda           :: median time 0.268 ms; std 0.001 ms; achieved tflops 5.004 TFLOPs/sec; achieved tb_per_sec 3.002 TB/sec
$ python3 flashinfer_benchmark.py --routine rmsnorm_fp4quant --batch_size 16384 --hidden_size 16384 --input_dtype bfloat16 --use_global_scale
[PERF] cute-dsl       :: median time 0.209 ms; std 0.000 ms; achieved tflops 6.407 TFLOPs/sec; achieved tb_per_sec 3.284 TB/sec
$ python3 flashinfer_benchmark.py --routine nvfp4_quantize --m 8192 --k 8192 --input_dtype bfloat16 --global_scale 1.0 --sf_layout 128x4
[PERF] cuda           :: median time 0.049 ms; std 0.000 ms; achieved tflops 4.113 TFLOPs/sec; achieved tb_per_sec 3.514 TB/sec
```
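As a sanity check, the reported tb_per_sec for the rmsnorm run is consistent with a simple byte-count model. The sketch below is an illustrative back-of-envelope calculation, assuming the harness counts input read, weight read, and output write; the function name is hypothetical, not harness code.

```python
# Hypothetical bandwidth model for rmsnorm: read input, read the
# per-hidden-dim weight vector, write output, all in the same dtype.
def rmsnorm_tb_per_sec(batch_size: int, hidden_size: int,
                       dtype_bytes: int, median_ms: float) -> float:
    in_bytes = batch_size * hidden_size * dtype_bytes   # input read
    w_bytes = hidden_size * dtype_bytes                 # weight read
    out_bytes = batch_size * hidden_size * dtype_bytes  # output write
    total_bytes = in_bytes + w_bytes + out_bytes
    return total_bytes / (median_ms * 1e-3) / 1e12      # TB/sec

# The 16384x16384 bfloat16 run above (0.251 ms median) works out to
# roughly 4.28 TB/sec, matching the reported 4.285 TB/sec.
print(rmsnorm_tb_per_sec(16384, 16384, 2, 0.251))
```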

🔍 Related Issues

#2361

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added benchmarking support for Norm (RMSNorm and variants) and multiple Quantization routines, plus new operation variants and expanded backend support.
  • Documentation

    • Expanded benchmark guide with Norm and Quantization flag docs, example commands, verbose run samples, and updated Routine & Backend support matrix.
  • Tests

    • Added many sample benchmark entries covering 2D/3D shapes, fused and quantized variants, and FP4/FP8 cases.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai coderabbitai bot commented Jan 15, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Walkthrough

Adds Norm (RMSNorm) and Quantization benchmark routines, CLI wiring, output schema/backends updates, many new sample test entries, and README documentation for the new APIs and flags.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Documentation<br>benchmarks/README.md | Added Norm and Quantization to overview and backend list; added quick-start examples and verbose repro commands; documented Norm Flags and Quantization Flags; updated Routine & Backend Support Matrix and backend legend. |
| Main orchestrator<br>benchmarks/flashinfer_benchmark.py | Added imports and CLI routing for norm and quantization; dispatches to run_norm_test / run_quantization_test; preserves routine-specific output defaults when emitting general fields. |
| Benchmark schema & mappings<br>benchmarks/routines/flashinfer_benchmark_utils.py | Extended output_column_dict with norm and quantization sections; restored some general fields; updated full_output_columns, benchmark_apis, and expanded routine_cc_to_supported_backends to include many new norm/quant routines and per-CC backend support. (Pay attention to backend matrix entries and new column names.) |
| Norm benchmarks<br>benchmarks/routines/norm.py | New module implementing run_norm_test, parse_norm_args, and test functions: testRmsnorm, testRmsnormQuant, testFusedAddRmsnormQuant, testRmsnormFp4quant, testAddRmsnormFp4quant. Implements tensor setup, backend dispatch, optional reference checks, timing, and perf reporting. (High logic density; review correctness of dtype/out_dtype, FP4 handling, and global_scale paths.) |
| Quantization benchmarks<br>benchmarks/routines/quantization.py | New module implementing run_quantization_test, parse_quantization_args, and tests: testMxfp8Quantize, testMxfp4Quantize, testNvfp4Quantize, testNvfp4BatchedQuantize. Implements input/layout handling, optional dequantize reference checks, timing, and perf reporting. (Verify sf_layout / sf_vec_size validation and batched-shape checks.) |
| Sample tests<br>benchmarks/samples/sample_testlist.txt | Added ~100+ RMSNorm and quantization test entries (2D/3D, PDL variants, FP8/FP4, fused add cases, layout/scale options). |

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant CLI as CLI
    participant Orchestrator as flashinfer_benchmark.py
    participant Router as run_*_test
    participant Routine as norm / quantization module
    participant Backends as HardwareBackends
    participant Ref as ReferenceCPU
    CLI->>Orchestrator: parse args, select routine
    Orchestrator->>Router: dispatch to run_norm_test / run_quantization_test
    Router->>Routine: parse routine args; prepare tensors
    Routine->>Backends: invoke backend implementations (cuda, cute-dsl, cutlass, trtllm, cpu...)
    Backends-->>Routine: results, timings
    Routine->>Ref: (optional) run float reference / dequantize check
    Ref-->>Routine: reference outputs
    Routine->>Orchestrator: aggregated results & perf metrics
    Orchestrator-->>CLI: emit results / write output file
```
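The dispatch step in the diagram can be sketched as a routine-to-runner registry. This is illustrative only: the runner names mirror the PR's run_norm_test / run_quantization_test, but the registry shape and argument handling here are assumptions, not the actual harness code.

```python
# Minimal sketch of routine dispatch: each routine name maps to the
# category runner that knows how to parse its args and run the test.
def run_norm_test(args: dict) -> str:
    return f"norm:{args['routine']}"

def run_quantization_test(args: dict) -> str:
    return f"quant:{args['routine']}"

DISPATCH = {
    "rmsnorm": run_norm_test,
    "rmsnorm_quant": run_norm_test,
    "fused_add_rmsnorm_quant": run_norm_test,
    "mxfp8_quantize": run_quantization_test,
    "nvfp4_quantize": run_quantization_test,
}

def dispatch(args: dict) -> str:
    routine = args["routine"]
    if routine not in DISPATCH:
        # Fail fast on unknown routines instead of silently skipping.
        raise ValueError(f"Unsupported routine: {routine}")
    return DISPATCH[routine](args)
```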

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • Anerudhan
  • yzh119
  • cyx-6
  • kahyunnam
  • jiahanc

Poem

🐰
I hopped through code with nimble paws,
Norms and quants now join the cause,
Benchmarks hum with measured beats,
Kernels race and metrics meet,
A carrot and a green-lit pass!

🚥 Pre-merge checks: 2 passed, 1 failed

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 62.50%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and concisely describes the main change: adding norm and quantization routines to the microbenchmark harness. |
| Description check | ✅ Passed | The PR description includes a clear summary of changes, a detailed table of new routines with backends and minimum SM versions, example command outputs, links to related issues, and completed pre-commit and test checklist items. |


@gemini-code-assist bot

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the flashinfer_benchmark.py microbenchmark harness by integrating a comprehensive suite of normalization and quantization routines. This expansion allows for detailed performance analysis of various RMSNorm implementations, including those with FP8 and FP4 quantized outputs, as well as dedicated MxFP8, MxFP4, and NVFP4 quantization kernels. The changes facilitate robust benchmarking across different backends, including the newly supported CuTe-DSL for advanced FP4 operations on Blackwell architectures, providing critical insights into the efficiency of these core deep learning primitives.

Highlights

  • Expanded Benchmarking Capabilities: The flashinfer_benchmark.py harness now supports performance benchmarking for normalization and quantization routines, in addition to existing Attention, GEMM, and MOE APIs.
  • New Norm Routines: Introduced benchmarks for rmsnorm, rmsnorm_quant (with FP8 output), fused_add_rmsnorm_quant (fused residual-add with FP8 output), rmsnorm_fp4quant (FP4 output via CuTe-DSL), and add_rmsnorm_fp4quant (fused residual-add with FP4 output via CuTe-DSL).
  • New Quantization Routines: Added benchmarks for standalone quantization kernels including mxfp8_quantize, mxfp4_quantize, nvfp4_quantize, and nvfp4_batched_quantize.
  • CuTe-DSL Backend Integration: The cute-dsl backend is now supported for FP4 quantization routines, specifically targeting Blackwell SM10.0+ architectures.
  • Comprehensive Documentation Updates: The README.md has been thoroughly updated to reflect the new API categories, supported backends, detailed command-line flags for norm and quantization, and an expanded routine/backend support matrix.


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request extends the flashinfer_benchmark.py microbenchmark harness to support normalization and quantization routines, which is a great addition. The changes are well-structured, with new routines encapsulated in their own files and common utilities refactored. The documentation has also been updated accordingly.

I've identified a couple of areas for improvement in the new benchmark scripts to enhance user experience and correctness:

  • In routines/norm.py, the handling of out_dtype for FP4 routines could be made stricter to avoid user confusion.
  • In routines/quantization.py, the command-line argument parsing for boolean flags can be made more idiomatic.

Overall, this is a solid contribution that significantly expands the benchmarking capabilities.

Comment on lines +87 to +98
```python
parser.add_argument(
    "--is_sf_swizzled_layout",
    action="store_true",
    default=True,
    help="Use swizzled layout for scale factors. Default: True",
)
parser.add_argument(
    "--no_sf_swizzled_layout",
    action="store_true",
    default=False,
    help="Disable swizzled layout for scale factors.",
)
```
medium

The current implementation for handling --is_sf_swizzled_layout and --no_sf_swizzled_layout is not idiomatic argparse usage. Using action="store_true" with default=True makes the --is_sf_swizzled_layout flag have no effect. A clearer approach is to use a single dest for two flags with opposite actions. This also allows removing the manual handling logic on lines 160-162.

```python
parser.add_argument(
    "--is_sf_swizzled_layout",
    dest="is_sf_swizzled_layout",
    action="store_true",
    help="Use swizzled layout for scale factors. (default)",
)
parser.add_argument(
    "--no_sf_swizzled_layout",
    dest="is_sf_swizzled_layout",
    action="store_false",
    help="Disable swizzled layout for scale factors.",
)
parser.set_defaults(is_sf_swizzled_layout=True)
```
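A quick, self-contained demonstration of the paired-flag pattern suggested above: both flags share one dest, so the default and the negating flag compose without any manual post-processing.

```python
# Two flags writing to the same dest: store_true / store_false with a
# shared default set via set_defaults.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--is_sf_swizzled_layout",
    dest="is_sf_swizzled_layout",
    action="store_true",
    help="Use swizzled layout for scale factors. (default)",
)
parser.add_argument(
    "--no_sf_swizzled_layout",
    dest="is_sf_swizzled_layout",
    action="store_false",
    help="Disable swizzled layout for scale factors.",
)
parser.set_defaults(is_sf_swizzled_layout=True)

print(parser.parse_args([]).is_sf_swizzled_layout)                           # True
print(parser.parse_args(["--no_sf_swizzled_layout"]).is_sf_swizzled_layout)  # False
```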

bkryu (Collaborator, Author) replied:

Fair point, will address in latest commit.

Comment on lines +160 to +162
```python
# Handle swizzled layout flag
if args.no_sf_swizzled_layout:
    args.is_sf_swizzled_layout = False
```
medium

With the suggested change to use dest and set_defaults for the is_sf_swizzled_layout argument, this manual handling is no longer necessary and should be removed.

bkryu (Collaborator, Author) replied:

Fair point, will address in latest commit.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
benchmarks/routines/flashinfer_benchmark_utils.py (1)

103-110: Avoid duplicate column names in full_output_columns.

The new sections reuse m, k, enable_pdl, and is_sf_swizzled_layout, so the concatenated header will contain duplicates. Many CSV readers (e.g., pandas) will auto-rename or clobber duplicates, which makes results ambiguous. Consider de-duplicating or prefixing category-specific fields.

♻️ Proposed fix (de-duplicate while preserving order)
```diff
-full_output_columns = (
-    output_column_dict["perf"]
-    + output_column_dict["attention"]
-    + output_column_dict["gemm"]
-    + output_column_dict["moe"]
-    + output_column_dict["norm"]
-    + output_column_dict["quantization"]
-    + output_column_dict["general"]
-)
+full_output_columns = list(
+    dict.fromkeys(
+        output_column_dict["perf"]
+        + output_column_dict["attention"]
+        + output_column_dict["gemm"]
+        + output_column_dict["moe"]
+        + output_column_dict["norm"]
+        + output_column_dict["quantization"]
+        + output_column_dict["general"]
+    )
+)
```
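Why dict.fromkeys works for this de-duplication: Python dicts preserve insertion order (guaranteed since 3.7), so duplicate column names collapse to their first occurrence while the overall ordering is kept. A small standalone illustration (the column names here are a reduced subset for demonstration):

```python
# Order-preserving de-duplication with dict.fromkeys: later duplicates
# of "m" and "k" collapse to their first occurrence.
perf = ["routine", "median_time_ms"]
norm = ["m", "k", "enable_pdl"]
quantization = ["m", "k", "is_sf_swizzled_layout"]

full_output_columns = list(dict.fromkeys(perf + norm + quantization))
print(full_output_columns)
# ['routine', 'median_time_ms', 'm', 'k', 'enable_pdl', 'is_sf_swizzled_layout']
```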
🤖 Fix all issues with AI agents
In `@benchmarks/README.md`:
- Around line 241-257: The Norm and Quantization markdown tables lack
surrounding blank lines required by markdownlint MD058; add one blank line
before the "### Norm Flags" table and one blank line after its closing |---| row
(before the "### Quantization Flags" heading), and likewise ensure the
Quantization table has a blank line before and after it so both tables are
separated by empty lines from adjacent headings/content.
- Around line 32-42: Fix the markdown nested list indentation under the "Norm"
and "Quantization" sections by converting the 4-space indents to 2-space indents
for each nested list item (e.g., `rmsnorm`, `rmsnorm_quant`,
`fused_add_rmsnorm_quant`, `rmsnorm_fp4quant`, `add_rmsnorm_fp4quant`,
`mxfp8_quantize`, `mxfp4_quantize`, `nvfp4_quantize`, `nvfp4_batched_quantize`)
so the lists comply with MD007; locate the entries in the README where these
items are listed and reduce their leading spaces from four to two while
preserving the same bullet characters and text.

In `@benchmarks/routines/norm.py`:
- Around line 743-753: The code derives block_size from out_dtype (nvfp4→16,
mxfp4→32) but never verifies hidden_size is divisible by block_size, which can
cause kernel failures and wrong metrics; add an explicit check after block_size
is set that verifies hidden_size % block_size == 0 and raise a ValueError with a
clear message mentioning out_dtype, block_size and hidden_size (similar to the
check in testAddRmsnormFp4quant) so callers know the required alignment before
running the CuTe-DSL RMSNorm path; reference the variables block_size,
hidden_size and out_dtype near this logic in benchmarks/routines/norm.py.

In `@benchmarks/routines/quantization.py`:
- Around line 292-304: The refcheck branch computes round-trip stats via
is_close_stats but never enforces them; update the exception/validation logic
after is_close_stats in the quantization routine to check num_different_elements
and, if args.refcheck is set and num_different_elements > 0, raise an
AssertionError (unless args.allow_output_mismatch is true) and otherwise emit a
warning or log when mismatches exist; apply the same enforcement/warning logic
to the testMxfp4Quantize path so both routines respect --refcheck and
--allow_output_mismatch while still printing verbose stats guarded by
args.verbose.
- Around line 208-323: The code currently hardcodes a 32-byte scale-factor
vector size when validating k and computing bandwidth; update the checks and
math to use the actual alignment (sf_vec_size) passed to the kernel. Replace the
k % 32 validation with k % alignment (or k % sf_vec_size if you prefer to set
sf_vec_size = alignment early), and set sf_vec_size = alignment before computing
num_scale_factors so that num_scale_factors = (m * k) // sf_vec_size and
problem_bytes uses that value; apply the same replacement in any other
quantization tests that compute sf_vec_size or validate k to ensure consistency
between the alignment argument and bandwidth/validation logic (refer to
variables alignment, sf_vec_size, num_scale_factors, and the k validation
check).
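The first norm.py fix above asks for an explicit divisibility check before launching the FP4 path. A minimal sketch of that check, using the block sizes the comment names (16 for nvfp4, 32 for mxfp4); the function name and dtype strings are illustrative, not the actual norm.py code.

```python
# Sketch of the requested alignment check: hidden_size must be a
# multiple of the FP4 scale-factor block size before the kernel runs.
FP4_BLOCK_SIZES = {"nvfp4": 16, "mxfp4": 32}

def validate_fp4_shape(out_dtype: str, hidden_size: int) -> int:
    if out_dtype not in FP4_BLOCK_SIZES:
        raise ValueError(f"Unsupported FP4 out_dtype: {out_dtype}")
    block_size = FP4_BLOCK_SIZES[out_dtype]
    if hidden_size % block_size != 0:
        # Clear message naming out_dtype, block_size, and hidden_size,
        # as the review suggests, so callers learn the constraint early.
        raise ValueError(
            f"hidden_size={hidden_size} must be divisible by "
            f"block_size={block_size} for out_dtype={out_dtype}"
        )
    return block_size
```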
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 36380e2 and c2c39ad.

📒 Files selected for processing (6)
  • benchmarks/README.md
  • benchmarks/flashinfer_benchmark.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • benchmarks/routines/norm.py
  • benchmarks/routines/quantization.py
  • benchmarks/samples/sample_testlist.txt
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API

Applied to files:

  • benchmarks/flashinfer_benchmark.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures

Applied to files:

  • benchmarks/README.md
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Use `FLASHINFER_CUDA_ARCH_LIST` environment variable to specify target GPU architectures (e.g., '8.0 9.0a') and `FLASHINFER_NVCC_THREADS` to control parallel compilation threads

Applied to files:

  • benchmarks/README.md
🧬 Code graph analysis (1)
benchmarks/routines/norm.py (1)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (1508-1655)
🪛 markdownlint-cli2 (0.18.1)

benchmarks/README.md

  • MD007 (ul-indent) Unordered list indentation, Expected: 2; Actual: 4: lines 33, 34, 35, 36, 37, 39, 40, 41, 42
  • MD058 (blanks-around-tables) Tables should be surrounded by blank lines: lines 242, 257

🪛 Ruff (0.14.11)

benchmarks/routines/quantization.py

  • TRY003 (Avoid specifying long messages outside the exception class): lines 53, 210, 219-221, 244, 393-395, 410, 569-571, 599, 718, 727-729, 751
  • BLE001 (Do not catch blind exception: Exception): lines 304, 462

benchmarks/routines/norm.py

  • TRY003 (Avoid specifying long messages outside the exception class): lines 55, 198-200, 223, 269-271, 359-361, 365-367, 396, 451-453, 539-541, 545-547, 579, 649-651, 751-753, 762-764, 806, 921-923, 932-934, 979

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (5)
benchmarks/samples/sample_testlist.txt (1)

52-153: Sample coverage looks solid for the new routines.

Nice spread across 2D/3D shapes, FP8/FP4 variants, PDL, and layout options.

benchmarks/flashinfer_benchmark.py (1)

31-46: Routing + output defaulting looks good.

The dispatch additions are clean, and the guarded output population preserves routine-specific formatting.

benchmarks/routines/flashinfer_benchmark_utils.py (3)

68-91: Norm/quantization output schema looks aligned with the new routines.

Reintroducing shared fields under general keeps the schema consistent across benchmark families.


133-145: Benchmark API registry updated appropriately.

The new norm/quantization routines are wired into the benchmark API list as expected.


326-418: Backend support mapping covers the new routines.

The per-SM backend lists are consistent with the intended CUDA/cute-dsl availability.


@bkryu bkryu marked this pull request as draft January 15, 2026 23:07
@bkryu bkryu self-assigned this Jan 16, 2026
@bkryu bkryu marked this pull request as ready for review January 16, 2026 00:33
@bkryu (Collaborator, Author) commented Jan 16, 2026

Note: This PR does not touch the library code, so it will not trigger unit tests.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@benchmarks/routines/norm.py`:
- Around line 566-628: The timed runs reuse the same pre-evaluated input_args
tuple so the residual clone is shared across iterations; to guarantee a fresh
residual every iteration, clone the residual inside run_backend (e.g., at the
start of run_backend do residual_tensor = residual_tensor.clone()) before
calling flashinfer.norm.fused_add_rmsnorm_quant, so each invocation gets its own
independent residual regardless of how bench_gpu_time supplies input_args; refer
to run_backend, fused_add_rmsnorm_quant, and bench_gpu_time when making this
change.

In `@benchmarks/routines/quantization.py`:
- Around line 74-78: The CLI help for the "--k" argument is misleading because
divisibility depends on chosen quantization parameters (alignment for mxfp8 or
sf_vec_size for nvfp4); update the help text in the add_argument call for "--k"
to state that k must be divisible by the active quantization granularity (e.g.,
"must be divisible by alignment for mxfp8 or by sf_vec_size for nvfp4"), and
optionally reference which flags select those modes so users know which divisor
applies (check where "--quant-mode"/related flags and variables alignment and
sf_vec_size are defined and mentioned).
- Around line 281-316: The except block in the refcheck for mxfp8 (inside the
try that calls flashinfer.mxfp8_dequantize_host and is_close_stats) currently
only prints a warning; change it to re-raise the exception unless
args.allow_output_mismatch is true: in the except Exception as e: branch, keep
the verbose print but then if not args.allow_output_mismatch: raise (re-raise
the caught exception) so failures surface; otherwise, print the warning and
continue. Make the identical change to the mxfp4_quantize exception handler as
well, preserving its log tag ([mxfp4_quantize]) where applicable.
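The last fix above can be sketched as a small wrapper around the refcheck call: re-raise on failure unless --allow_output_mismatch is set. The function signature and flag names follow the comment but are illustrative, not the actual quantization.py code.

```python
# Sketch of the requested refcheck handling: failures surface by default;
# --allow_output_mismatch downgrades them to a warning.
def refcheck(run_refcheck, allow_output_mismatch: bool = False,
             verbose: bool = False, tag: str = "mxfp8_quantize") -> None:
    try:
        run_refcheck()  # e.g. dequantize + compare against the reference
    except Exception as e:
        if verbose:
            print(f"[{tag}] refcheck failed: {e}")
        if not allow_output_mismatch:
            raise  # re-raise so the mismatch is not silently swallowed
        print(f"[WARNING] [{tag}] output mismatch allowed: {e}")
```

The same wrapper applies to the mxfp4_quantize path by passing tag="mxfp4_quantize".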
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c2c39ad and e9227f0.

📒 Files selected for processing (2)
  • benchmarks/routines/norm.py
  • benchmarks/routines/quantization.py
🧰 Additional context used
🧬 Code graph analysis (2)
benchmarks/routines/quantization.py (1)
benchmarks/flashinfer_benchmark.py (1)
  • parse_args (56-189)
benchmarks/routines/norm.py (2)
benchmarks/routines/flashinfer_benchmark_utils.py (5)
  • dtype_str_to_torch_dtype (179-193)
  • get_device (156-165)
  • print_perf_metrics (149-153)
  • is_close_stats (168-176)
  • filter_backends_by_compute_capability (422-440)
benchmarks/routines/quantization.py (4)
  • run_backend (233-242)
  • run_backend (421-425)
  • run_backend (620-631)
  • run_backend (781-789)
🪛 Ruff (0.14.11)
benchmarks/routines/quantization.py

53-53: Avoid specifying long messages outside the exception class

(TRY003)


208-208: Avoid specifying long messages outside the exception class

(TRY003)


217-219: Avoid specifying long messages outside the exception class

(TRY003)


242-242: Avoid specifying long messages outside the exception class

(TRY003)


312-312: Abstract raise to an inner function

(TRY301)


313-313: Do not catch blind exception: Exception

(BLE001)


397-399: Avoid specifying long messages outside the exception class

(TRY003)


408-410: Avoid specifying long messages outside the exception class

(TRY003)


425-425: Avoid specifying long messages outside the exception class

(TRY003)


487-487: Abstract raise to an inner function

(TRY301)


488-488: Do not catch blind exception: Exception

(BLE001)


590-592: Avoid specifying long messages outside the exception class

(TRY003)


601-603: Avoid specifying long messages outside the exception class

(TRY003)


631-631: Avoid specifying long messages outside the exception class

(TRY003)


750-750: Avoid specifying long messages outside the exception class

(TRY003)


754-756: Avoid specifying long messages outside the exception class

(TRY003)


765-767: Avoid specifying long messages outside the exception class

(TRY003)


789-789: Avoid specifying long messages outside the exception class

(TRY003)

benchmarks/routines/norm.py

55-55: Avoid specifying long messages outside the exception class (TRY003)
198-200: Avoid specifying long messages outside the exception class (TRY003)
223-223: Avoid specifying long messages outside the exception class (TRY003)
269-271: Avoid specifying long messages outside the exception class (TRY003)
359-361: Avoid specifying long messages outside the exception class (TRY003)
365-367: Avoid specifying long messages outside the exception class (TRY003)
396-396: Avoid specifying long messages outside the exception class (TRY003)
451-453: Avoid specifying long messages outside the exception class (TRY003)
539-541: Avoid specifying long messages outside the exception class (TRY003)
545-547: Avoid specifying long messages outside the exception class (TRY003)
579-579: Avoid specifying long messages outside the exception class (TRY003)
649-651: Avoid specifying long messages outside the exception class (TRY003)
751-753: Avoid specifying long messages outside the exception class (TRY003)
757-760: Avoid specifying long messages outside the exception class (TRY003)
769-771: Avoid specifying long messages outside the exception class (TRY003)
813-813: Avoid specifying long messages outside the exception class (TRY003)
928-930: Avoid specifying long messages outside the exception class (TRY003)
934-937: Avoid specifying long messages outside the exception class (TRY003)
946-948: Avoid specifying long messages outside the exception class (TRY003)
993-993: Avoid specifying long messages outside the exception class (TRY003)
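Nearly every finding above is Ruff's TRY003 (long message built at the raise site) or BLE001 (blind `except Exception`). A minimal sketch of the compliant pattern — class, function, and routine names here are illustrative, not taken from the harness:

```python
class UnsupportedRoutineError(ValueError):
    """Builds its own message, so call sites stay short (satisfies TRY003)."""

    def __init__(self, routine: str) -> None:
        super().__init__(f"Unsupported routine: {routine!r}")


def dispatch(routine: str) -> str:
    table = {"rmsnorm": lambda: "ok"}
    if routine not in table:
        # Short raise with no inline message string (TRY003-clean).
        raise UnsupportedRoutineError(routine)
    return table[routine]()


def safe_dispatch(routine: str) -> str:
    # Catch the specific exception, not a blind `except Exception` (BLE001).
    try:
        return dispatch(routine)
    except UnsupportedRoutineError:
        return "skipped"
```

Whether the harness should adopt this or suppress the rules is a style choice; the findings are advisory, not errors.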

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (9)
benchmarks/routines/quantization.py (3)

34-53: Dispatcher routing is clean and fail-fast.

Clear routine mapping and an explicit error for unsupported routines.


530-707: NVFP4 quantize path looks solid.

Good handling of layout mapping, CUDA‑graph compatibility, and sf_vec_size validation.


710-863: Batched NVFP4 benchmark flow looks good.

Input validation, shape checks, and bandwidth accounting are clear.

benchmarks/routines/norm.py (6)

34-55: Norm dispatcher is straightforward.

Clear routing and explicit failure for unknown routines.


69-149: CLI coverage for norm benchmarks looks complete.

Arguments and defaults are well-aligned with the routine matrix.


152-310: RMSNorm benchmark flow looks correct.

Reference check, timing, and metrics are consistent with the harness.


313-491: RMSNorm + FP8 quant benchmark is well-structured.

Refcheck, dtype validation, and metrics look good.


693-867: FP4 RMSNorm benchmark looks solid.

Alignment validation and bandwidth accounting are clear.


870-1049: Add‑RMSNorm FP4 path is in good shape.

Input prep, backend routing, and metrics look consistent.


Comment on lines +74 to +78
"--k",
type=int,
required=True,
help="Number of columns in input tensor (must be divisible by 32).",
)
⚠️ Potential issue | 🟡 Minor

Clarify the --k divisibility rule in CLI help.

The help text says “divisible by 32,” but actual validation depends on alignment (mxfp8) or sf_vec_size (nvfp4). This can mislead users.

✏️ Suggested help text tweak
-        help="Number of columns in input tensor (must be divisible by 32).",
+        help="Number of columns in input tensor (must be divisible by alignment or sf_vec_size for the selected routine).",
🤖 Prompt for AI Agents
In `@benchmarks/routines/quantization.py` around lines 74 - 78, The CLI help for
the "--k" argument is misleading because divisibility depends on chosen
quantization parameters (alignment for mxfp8 or sf_vec_size for nvfp4); update
the help text in the add_argument call for "--k" to state that k must be
divisible by the active quantization granularity (e.g., "must be divisible by
alignment for mxfp8 or by sf_vec_size for nvfp4"), and optionally reference
which flags select those modes so users know which divisor applies (check where
"--quant-mode"/related flags and variables alignment and sf_vec_size are defined
and mentioned).
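A minimal argparse sketch of the suggested wording (a standalone parser for illustration; the flag name and requirement come from the snippet above, the exact phrasing is the reviewer's suggestion):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--k",
    type=int,
    required=True,
    help=(
        "Number of columns in input tensor. Must be divisible by the "
        "alignment for mxfp8/mxfp4 routines, or by sf_vec_size for nvfp4."
    ),
)

args = parser.parse_args(["--k", "8192"])
```

The actual divisibility check still happens in the routine's validation code; the help text only needs to point users at the right divisor.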

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@benchmarks/routines/quantization.py`:
- Around line 668-673: The current NVFP4 shape checks print a warning on
mismatch; change them to raise an AssertionError when x_q.shape !=
expected_shape unless args.allow_output_mismatch is True. Locate the checks that
compute expected_shape = (m, k // 2) and compare to x_q.shape (and the similar
block later in the file) and replace the print warning with: if not
args.allow_output_mismatch: raise AssertionError(f"Unexpected output shape:
{x_q.shape}, expected {expected_shape}") so behavior matches the mxfp8/mxfp4
tests.
♻️ Duplicate comments (1)
benchmarks/routines/quantization.py (1)

74-78: Clarify --k divisibility rule in CLI help.

The divisor depends on alignment or sf_vec_size, not always 32.

✏️ Suggested help text tweak
-        help="Number of columns in input tensor (must be divisible by 32).",
+        help="Number of columns in input tensor (must be divisible by alignment for mxfp8 or sf_vec_size for nvfp4).",
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e9227f0 and 1f2d48a.

📒 Files selected for processing (1)
  • benchmarks/routines/quantization.py
🧰 Additional context used
🪛 Ruff (0.14.11)
benchmarks/routines/quantization.py

53-53: Avoid specifying long messages outside the exception class (TRY003)
208-208: Avoid specifying long messages outside the exception class (TRY003)
217-219: Avoid specifying long messages outside the exception class (TRY003)
242-242: Avoid specifying long messages outside the exception class (TRY003)
312-312: Abstract raise to an inner function (TRY301)
401-403: Avoid specifying long messages outside the exception class (TRY003)
412-414: Avoid specifying long messages outside the exception class (TRY003)
429-429: Avoid specifying long messages outside the exception class (TRY003)
491-491: Abstract raise to an inner function (TRY301)
598-600: Avoid specifying long messages outside the exception class (TRY003)
609-611: Avoid specifying long messages outside the exception class (TRY003)
639-639: Avoid specifying long messages outside the exception class (TRY003)
758-758: Avoid specifying long messages outside the exception class (TRY003)
762-764: Avoid specifying long messages outside the exception class (TRY003)
773-775: Avoid specifying long messages outside the exception class (TRY003)
797-797: Avoid specifying long messages outside the exception class (TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
benchmarks/routines/quantization.py (3)

34-53: Clear routine dispatch.

Simple, explicit mapping with a clean fallback error.


205-319: Alignment validation + refcheck enforcement look solid.

Nice to see alignment-based checks and explicit mismatch handling.


398-499: Refcheck round‑trip handling is consistent.

Mismatch handling aligns well with the mxfp8 path.


Comment on lines +668 to +673
# Verify output shape (M, K/2) for FP4
expected_shape = (m, k // 2)
if x_q.shape != expected_shape:
print(
f"[WARNING] Unexpected output shape: {x_q.shape}, expected {expected_shape}"
)
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

(Verification scripts inspected benchmarks/routines/quantization.py and confirmed: the nvfp4 shape checks at lines 668-673 and 826-831 only print a warning on mismatch, args.allow_output_mismatch is in scope inside testNvfp4Quantize and testNvfp4BatchedQuantize, and the other expected_shape checks in the file raise on mismatch.)

Fail refcheck on nvfp4 shape mismatches (unless --allow_output_mismatch is set).

Shape mismatches currently only warn, which can silently pass invalid outputs during refcheck. Make them consistent with other quantization tests (mxfp8, mxfp4) by raising AssertionError unless args.allow_output_mismatch is set. This pattern is already established in the file for similar validation checks.

🛠️ Proposed fix
                if x_q.shape != expected_shape:
-                    print(
-                        f"[WARNING] Unexpected output shape: {x_q.shape}, expected {expected_shape}"
-                    )
+                    msg = (
+                        f"Unexpected output shape: {x_q.shape}, expected {expected_shape}"
+                    )
+                    if args.allow_output_mismatch:
+                        print(f"[WARNING] {msg}")
+                    else:
+                        raise AssertionError(msg)

Also applies to: 826-831
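Since the same check appears twice, the raise-unless-allowed pattern could also live in a small helper; a sketch under the assumption that the caller passes the harness's `allow_output_mismatch` flag (the helper name is hypothetical):

```python
def check_output_shape(actual_shape, expected_shape, allow_output_mismatch=False):
    """Raise on shape mismatch unless mismatches are explicitly allowed."""
    if actual_shape != expected_shape:
        msg = f"Unexpected output shape: {actual_shape}, expected {expected_shape}"
        if allow_output_mismatch:
            # Opt-in escape hatch: warn instead of failing the refcheck.
            print(f"[WARNING] {msg}")
        else:
            raise AssertionError(msg)
```

This keeps both nvfp4 call sites consistent with the mxfp8/mxfp4 paths while preserving the opt-out.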


@yzh119 yzh119 enabled auto-merge (squash) January 16, 2026 03:02
@yzh119 yzh119 merged commit 94820ca into flashinfer-ai:main Jan 16, 2026
9 checks passed

4 participants