benchmarks: Add norm and quantization routines to microbenchmark harness. #2362

yzh119 merged 7 commits into flashinfer-ai:main
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Adds Norm (RMSNorm) and Quantization benchmark routines, CLI wiring, output schema/backends updates, many new sample test entries, and README documentation for the new APIs and flags.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the microbenchmark harness.

Highlights
Code Review
This pull request extends the flashinfer_benchmark.py microbenchmark harness to support normalization and quantization routines, which is a great addition. The changes are well-structured, with new routines encapsulated in their own files and common utilities refactored. The documentation has also been updated accordingly.
I've identified a couple of areas for improvement in the new benchmark scripts to enhance user experience and correctness:
- In routines/norm.py, the handling of out_dtype for FP4 routines could be made stricter to avoid user confusion.
- In routines/quantization.py, the command-line argument parsing for boolean flags can be made more idiomatic.
Overall, this is a solid contribution that significantly expands the benchmarking capabilities.
parser.add_argument(
    "--is_sf_swizzled_layout",
    action="store_true",
    default=True,
    help="Use swizzled layout for scale factors. Default: True",
)
parser.add_argument(
    "--no_sf_swizzled_layout",
    action="store_true",
    default=False,
    help="Disable swizzled layout for scale factors.",
)
The current implementation for handling --is_sf_swizzled_layout and --no_sf_swizzled_layout is not idiomatic argparse usage. Using action="store_true" with default=True makes the --is_sf_swizzled_layout flag have no effect. A clearer approach is to use a single dest for two flags with opposite actions. This also allows removing the manual handling logic on lines 160-162.
parser.add_argument(
    "--is_sf_swizzled_layout",
    dest="is_sf_swizzled_layout",
    action="store_true",
    help="Use swizzled layout for scale factors. (default)",
)
parser.add_argument(
    "--no_sf_swizzled_layout",
    dest="is_sf_swizzled_layout",
    action="store_false",
    help="Disable swizzled layout for scale factors.",
)
parser.set_defaults(is_sf_swizzled_layout=True)
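To make the suggested pattern concrete, here is a self-contained sketch (the argument names are copied from the suggestion; everything else is illustrative): the two flags share one dest, so whichever appears last on the command line wins, and set_defaults supplies the default. On Python 3.9+, argparse.BooleanOptionalAction offers a built-in alternative that generates the --no-* flag automatically.

```python
import argparse

# Standalone sketch of the paired-flag pattern suggested above: both flags
# write to the same dest, so the last flag given on the command line wins.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--is_sf_swizzled_layout",
    dest="is_sf_swizzled_layout",
    action="store_true",
    help="Use swizzled layout for scale factors. (default)",
)
parser.add_argument(
    "--no_sf_swizzled_layout",
    dest="is_sf_swizzled_layout",
    action="store_false",
    help="Disable swizzled layout for scale factors.",
)
parser.set_defaults(is_sf_swizzled_layout=True)

# No flag -> default True; the "no" flag flips it off.
assert parser.parse_args([]).is_sf_swizzled_layout is True
assert parser.parse_args(["--no_sf_swizzled_layout"]).is_sf_swizzled_layout is False
```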
Fair point will address in latest commit.
benchmarks/routines/quantization.py (Outdated)

# Handle swizzled layout flag
if args.no_sf_swizzled_layout:
    args.is_sf_swizzled_layout = False
Fair point will address in latest commit.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
benchmarks/routines/flashinfer_benchmark_utils.py (1)
103-110: Avoid duplicate column names in full_output_columns.

The new sections reuse m, k, enable_pdl, and is_sf_swizzled_layout, so the concatenated header will contain duplicates. Many CSV readers (e.g., pandas) will auto-rename or clobber duplicates, which makes results ambiguous. Consider de-duplicating or prefixing category-specific fields.

♻️ Proposed fix (de-duplicate while preserving order)

-full_output_columns = (
-    output_column_dict["perf"]
-    + output_column_dict["attention"]
-    + output_column_dict["gemm"]
-    + output_column_dict["moe"]
-    + output_column_dict["norm"]
-    + output_column_dict["quantization"]
-    + output_column_dict["general"]
-)
+full_output_columns = list(
+    dict.fromkeys(
+        output_column_dict["perf"]
+        + output_column_dict["attention"]
+        + output_column_dict["gemm"]
+        + output_column_dict["moe"]
+        + output_column_dict["norm"]
+        + output_column_dict["quantization"]
+        + output_column_dict["general"]
+    )
+)
🤖 Fix all issues with AI agents
In `@benchmarks/README.md`:
- Around line 241-257: The Norm and Quantization markdown tables lack
surrounding blank lines required by markdownlint MD058; add one blank line
before the "### Norm Flags" table and one blank line after its closing |---| row
(before the "### Quantization Flags" heading), and likewise ensure the
Quantization table has a blank line before and after it so both tables are
separated by empty lines from adjacent headings/content.
- Around line 32-42: Fix the markdown nested list indentation under the "Norm"
and "Quantization" sections by converting the 4-space indents to 2-space indents
for each nested list item (e.g., `rmsnorm`, `rmsnorm_quant`,
`fused_add_rmsnorm_quant`, `rmsnorm_fp4quant`, `add_rmsnorm_fp4quant`,
`mxfp8_quantize`, `mxfp4_quantize`, `nvfp4_quantize`, `nvfp4_batched_quantize`)
so the lists comply with MD007; locate the entries in the README where these
items are listed and reduce their leading spaces from four to two while
preserving the same bullet characters and text.
In `@benchmarks/routines/norm.py`:
- Around line 743-753: The code derives block_size from out_dtype (nvfp4→16,
mxfp4→32) but never verifies hidden_size is divisible by block_size, which can
cause kernel failures and wrong metrics; add an explicit check after block_size
is set that verifies hidden_size % block_size == 0 and raise a ValueError with a
clear message mentioning out_dtype, block_size and hidden_size (similar to the
check in testAddRmsnormFp4quant) so callers know the required alignment before
running the CuTe-DSL RMSNorm path; reference the variables block_size,
hidden_size and out_dtype near this logic in benchmarks/routines/norm.py.
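A hedged sketch of the requested check (the function name and the exact message wording are assumptions; the nvfp4→16 and mxfp4→32 block sizes come from the comment above):

```python
def validate_fp4_block_size(out_dtype: str, hidden_size: int) -> int:
    """Fail fast when hidden_size is not aligned to the FP4 scale-factor block."""
    block_size = {"nvfp4": 16, "mxfp4": 32}[out_dtype]
    if hidden_size % block_size != 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) must be divisible by block_size "
            f"({block_size}) for out_dtype={out_dtype}"
        )
    return block_size

# Aligned shapes pass; misaligned ones raise before any kernel launches.
assert validate_fp4_block_size("nvfp4", 4096) == 16
assert validate_fp4_block_size("mxfp4", 4096) == 32
```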
In `@benchmarks/routines/quantization.py`:
- Around line 292-304: The refcheck branch computes round-trip stats via
is_close_stats but never enforces them; update the exception/validation logic
after is_close_stats in the quantization routine to check num_different_elements
and, if args.refcheck is set and num_different_elements > 0, raise an
AssertionError (unless args.allow_output_mismatch is true) and otherwise emit a
warning or log when mismatches exist; apply the same enforcement/warning logic
to the testMxfp4Quantize path so both routines respect --refcheck and
--allow_output_mismatch while still printing verbose stats guarded by
args.verbose.
- Around line 208-323: The code currently hardcodes a 32-byte scale-factor
vector size when validating k and computing bandwidth; update the checks and
math to use the actual alignment (sf_vec_size) passed to the kernel. Replace the
k % 32 validation with k % alignment (or k % sf_vec_size if you prefer to set
sf_vec_size = alignment early), and set sf_vec_size = alignment before computing
num_scale_factors so that num_scale_factors = (m * k) // sf_vec_size and
problem_bytes uses that value; apply the same replacement in any other
quantization tests that compute sf_vec_size or validate k to ensure consistency
between the alignment argument and bandwidth/validation logic (refer to
variables alignment, sf_vec_size, num_scale_factors, and the k validation
check).
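The fix can be sketched as follows (the helper name and per-element byte counts are illustrative; the point is that sf_vec_size is derived from the alignment argument instead of a hardcoded 32):

```python
def quant_problem_bytes(m: int, k: int, alignment: int,
                        in_elem_bytes: int = 2, out_elem_bytes: int = 1) -> int:
    """Bandwidth accounting driven by the actual scale-factor vector size."""
    if k % alignment != 0:
        raise ValueError(f"k ({k}) must be divisible by alignment ({alignment})")
    sf_vec_size = alignment  # one scale factor per `alignment` input elements
    num_scale_factors = (m * k) // sf_vec_size
    # bytes read (input) + bytes written (quantized output + scale factors)
    return m * k * in_elem_bytes + m * k * out_elem_bytes + num_scale_factors

# With alignment=32: 128*64*(2+1) input/output bytes plus 256 scale-factor bytes.
assert quant_problem_bytes(128, 64, 32) == 128 * 64 * 3 + 256
```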
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- benchmarks/README.md
- benchmarks/flashinfer_benchmark.py
- benchmarks/routines/flashinfer_benchmark_utils.py
- benchmarks/routines/norm.py
- benchmarks/routines/quantization.py
- benchmarks/samples/sample_testlist.txt
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API
Applied to files:
- benchmarks/flashinfer_benchmark.py
- benchmarks/routines/flashinfer_benchmark_utils.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
Applied to files:
benchmarks/README.md
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Use `FLASHINFER_CUDA_ARCH_LIST` environment variable to specify target GPU architectures (e.g., '8.0 9.0a') and `FLASHINFER_NVCC_THREADS` to control parallel compilation threads
Applied to files:
benchmarks/README.md
🧬 Code graph analysis (1)
benchmarks/routines/norm.py (1)
flashinfer/testing/utils.py (1)
bench_gpu_time(1508-1655)
🪛 markdownlint-cli2 (0.18.1)
benchmarks/README.md
33-33: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
34-34: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
35-35: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
36-36: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
37-37: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
39-39: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
40-40: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
41-41: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
42-42: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
242-242: Tables should be surrounded by blank lines
(MD058, blanks-around-tables)
257-257: Tables should be surrounded by blank lines
(MD058, blanks-around-tables)
🪛 Ruff (0.14.11)
benchmarks/routines/quantization.py
53-53: Avoid specifying long messages outside the exception class
(TRY003)
210-210: Avoid specifying long messages outside the exception class
(TRY003)
219-221: Avoid specifying long messages outside the exception class
(TRY003)
244-244: Avoid specifying long messages outside the exception class
(TRY003)
304-304: Do not catch blind exception: Exception
(BLE001)
393-395: Avoid specifying long messages outside the exception class
(TRY003)
410-410: Avoid specifying long messages outside the exception class
(TRY003)
462-462: Do not catch blind exception: Exception
(BLE001)
569-571: Avoid specifying long messages outside the exception class
(TRY003)
599-599: Avoid specifying long messages outside the exception class
(TRY003)
718-718: Avoid specifying long messages outside the exception class
(TRY003)
727-729: Avoid specifying long messages outside the exception class
(TRY003)
751-751: Avoid specifying long messages outside the exception class
(TRY003)
benchmarks/routines/norm.py
55-55: Avoid specifying long messages outside the exception class
(TRY003)
198-200: Avoid specifying long messages outside the exception class
(TRY003)
223-223: Avoid specifying long messages outside the exception class
(TRY003)
269-271: Avoid specifying long messages outside the exception class
(TRY003)
359-361: Avoid specifying long messages outside the exception class
(TRY003)
365-367: Avoid specifying long messages outside the exception class
(TRY003)
396-396: Avoid specifying long messages outside the exception class
(TRY003)
451-453: Avoid specifying long messages outside the exception class
(TRY003)
539-541: Avoid specifying long messages outside the exception class
(TRY003)
545-547: Avoid specifying long messages outside the exception class
(TRY003)
579-579: Avoid specifying long messages outside the exception class
(TRY003)
649-651: Avoid specifying long messages outside the exception class
(TRY003)
751-753: Avoid specifying long messages outside the exception class
(TRY003)
762-764: Avoid specifying long messages outside the exception class
(TRY003)
806-806: Avoid specifying long messages outside the exception class
(TRY003)
921-923: Avoid specifying long messages outside the exception class
(TRY003)
932-934: Avoid specifying long messages outside the exception class
(TRY003)
979-979: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (5)
benchmarks/samples/sample_testlist.txt (1)
52-153: Sample coverage looks solid for the new routines. Nice spread across 2D/3D shapes, FP8/FP4 variants, PDL, and layout options.
benchmarks/flashinfer_benchmark.py (1)
31-46: Routing + output defaulting looks good. The dispatch additions are clean, and the guarded output population preserves routine-specific formatting.
benchmarks/routines/flashinfer_benchmark_utils.py (3)
68-91: Norm/quantization output schema looks aligned with the new routines. Reintroducing shared fields under general keeps the schema consistent across benchmark families.
133-145: Benchmark API registry updated appropriately. The new norm/quantization routines are wired into the benchmark API list as expected.
326-418: Backend support mapping covers the new routines. The per-SM backend lists are consistent with the intended CUDA/cute-dsl availability.
Note: The PR does not touch the library code, so will not trigger a unit test.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@benchmarks/routines/norm.py`:
- Around line 566-628: The timed runs reuse the same pre-evaluated input_args
tuple so the residual clone is shared across iterations; to guarantee a fresh
residual every iteration, clone the residual inside run_backend (e.g., at the
start of run_backend do residual_tensor = residual_tensor.clone()) before
calling flashinfer.norm.fused_add_rmsnorm_quant, so each invocation gets its own
independent residual regardless of how bench_gpu_time supplies input_args; refer
to run_backend, fused_add_rmsnorm_quant, and bench_gpu_time when making this
change.
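A stdlib analogue of the requested change (list.copy() stands in for torch.Tensor.clone(), and all names are illustrative): copying inside run_backend means every timed invocation gets a fresh residual, even though the harness reuses one pre-evaluated input_args tuple.

```python
def run_backend(x, residual):
    # Clone first so the caller's (shared) residual is never mutated,
    # mirroring `residual_tensor = residual_tensor.clone()` from the review.
    residual = residual.copy()
    for i, v in enumerate(x):
        residual[i] += v  # stands in for fused_add_rmsnorm_quant's in-place add
    return residual

x = [1.0, 2.0]
shared_residual = [0.0, 0.0]
run_backend(x, shared_residual)
run_backend(x, shared_residual)
# The shared buffer is untouched across repeated timed iterations.
assert shared_residual == [0.0, 0.0]
```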
In `@benchmarks/routines/quantization.py`:
- Around line 74-78: The CLI help for the "--k" argument is misleading because
divisibility depends on chosen quantization parameters (alignment for mxfp8 or
sf_vec_size for nvfp4); update the help text in the add_argument call for "--k"
to state that k must be divisible by the active quantization granularity (e.g.,
"must be divisible by alignment for mxfp8 or by sf_vec_size for nvfp4"), and
optionally reference which flags select those modes so users know which divisor
applies (check where "--quant-mode"/related flags and variables alignment and
sf_vec_size are defined and mentioned).
- Around line 281-316: The except block in the refcheck for mxfp8 (inside the
try that calls flashinfer.mxfp8_dequantize_host and is_close_stats) currently
only prints a warning; change it to re-raise the exception unless
args.allow_output_mismatch is true: in the except Exception as e: branch, keep
the verbose print but then if not args.allow_output_mismatch: raise (re-raise
the caught exception) so failures surface; otherwise, print the warning and
continue. Make the identical change to the mxfp4_quantize exception handler as
well, preserving its log tag ([mxfp4_quantize]) where applicable.
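The error-handling contract described above can be sketched like this (function names and the log tag are illustrative; only the re-raise-unless-allowed flow is the point):

```python
def guarded_refcheck(run_check, allow_output_mismatch, verbose=False):
    """Run a round-trip check; warn-and-continue only when mismatches are allowed."""
    try:
        run_check()
    except Exception as e:
        if verbose:
            print(f"[mxfp8_quantize] refcheck failed: {e}")
        if not allow_output_mismatch:
            raise  # surface the failure instead of silently continuing

def failing_check():
    raise AssertionError("round-trip mismatch")

guarded_refcheck(failing_check, allow_output_mismatch=True)  # warns, continues
try:
    guarded_refcheck(failing_check, allow_output_mismatch=False)
    surfaced = False
except AssertionError:
    surfaced = True
assert surfaced
```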
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- benchmarks/routines/norm.py
- benchmarks/routines/quantization.py
🧰 Additional context used
🧬 Code graph analysis (2)
benchmarks/routines/quantization.py (1)
benchmarks/flashinfer_benchmark.py (1)
parse_args(56-189)
benchmarks/routines/norm.py (2)
benchmarks/routines/flashinfer_benchmark_utils.py (5)
dtype_str_to_torch_dtype (179-193), get_device (156-165), print_perf_metrics (149-153), is_close_stats (168-176), filter_backends_by_compute_capability (422-440)

benchmarks/routines/quantization.py (4)

run_backend (233-242), run_backend (421-425), run_backend (620-631), run_backend (781-789)
🪛 Ruff (0.14.11)
benchmarks/routines/quantization.py
53-53: Avoid specifying long messages outside the exception class
(TRY003)
208-208: Avoid specifying long messages outside the exception class
(TRY003)
217-219: Avoid specifying long messages outside the exception class
(TRY003)
242-242: Avoid specifying long messages outside the exception class
(TRY003)
312-312: Abstract raise to an inner function
(TRY301)
313-313: Do not catch blind exception: Exception
(BLE001)
397-399: Avoid specifying long messages outside the exception class
(TRY003)
408-410: Avoid specifying long messages outside the exception class
(TRY003)
425-425: Avoid specifying long messages outside the exception class
(TRY003)
487-487: Abstract raise to an inner function
(TRY301)
488-488: Do not catch blind exception: Exception
(BLE001)
590-592: Avoid specifying long messages outside the exception class
(TRY003)
601-603: Avoid specifying long messages outside the exception class
(TRY003)
631-631: Avoid specifying long messages outside the exception class
(TRY003)
750-750: Avoid specifying long messages outside the exception class
(TRY003)
754-756: Avoid specifying long messages outside the exception class
(TRY003)
765-767: Avoid specifying long messages outside the exception class
(TRY003)
789-789: Avoid specifying long messages outside the exception class
(TRY003)
benchmarks/routines/norm.py
55-55: Avoid specifying long messages outside the exception class
(TRY003)
198-200: Avoid specifying long messages outside the exception class
(TRY003)
223-223: Avoid specifying long messages outside the exception class
(TRY003)
269-271: Avoid specifying long messages outside the exception class
(TRY003)
359-361: Avoid specifying long messages outside the exception class
(TRY003)
365-367: Avoid specifying long messages outside the exception class
(TRY003)
396-396: Avoid specifying long messages outside the exception class
(TRY003)
451-453: Avoid specifying long messages outside the exception class
(TRY003)
539-541: Avoid specifying long messages outside the exception class
(TRY003)
545-547: Avoid specifying long messages outside the exception class
(TRY003)
579-579: Avoid specifying long messages outside the exception class
(TRY003)
649-651: Avoid specifying long messages outside the exception class
(TRY003)
751-753: Avoid specifying long messages outside the exception class
(TRY003)
757-760: Avoid specifying long messages outside the exception class
(TRY003)
769-771: Avoid specifying long messages outside the exception class
(TRY003)
813-813: Avoid specifying long messages outside the exception class
(TRY003)
928-930: Avoid specifying long messages outside the exception class
(TRY003)
934-937: Avoid specifying long messages outside the exception class
(TRY003)
946-948: Avoid specifying long messages outside the exception class
(TRY003)
993-993: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (9)
benchmarks/routines/quantization.py (3)
34-53: Dispatcher routing is clean and fail-fast. Clear routine mapping and an explicit error for unsupported routines.

530-707: NVFP4 quantize path looks solid. Good handling of layout mapping, CUDA-graph compatibility, and sf_vec_size validation.

710-863: Batched NVFP4 benchmark flow looks good. Input validation, shape checks, and bandwidth accounting are clear.
benchmarks/routines/norm.py (6)
34-55: Norm dispatcher is straightforward. Clear routing and explicit failure for unknown routines.

69-149: CLI coverage for norm benchmarks looks complete. Arguments and defaults are well-aligned with the routine matrix.

152-310: RMSNorm benchmark flow looks correct. Reference check, timing, and metrics are consistent with the harness.

313-491: RMSNorm + FP8 quant benchmark is well-structured. Refcheck, dtype validation, and metrics look good.

693-867: FP4 RMSNorm benchmark looks solid. Alignment validation and bandwidth accounting are clear.

870-1049: Add-RMSNorm FP4 path is in good shape. Input prep, backend routing, and metrics look consistent.
| "--k", | ||
| type=int, | ||
| required=True, | ||
| help="Number of columns in input tensor (must be divisible by 32).", | ||
| ) |
There was a problem hiding this comment.
Clarify the --k divisibility rule in CLI help.
The help text says “divisible by 32,” but actual validation depends on alignment (mxfp8) or sf_vec_size (nvfp4). This can mislead users.
✏️ Suggested help text tweak
- help="Number of columns in input tensor (must be divisible by 32).",
+ help="Number of columns in input tensor (must be divisible by alignment or sf_vec_size for the selected routine).",📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| "--k", | |
| type=int, | |
| required=True, | |
| help="Number of columns in input tensor (must be divisible by 32).", | |
| ) | |
| "--k", | |
| type=int, | |
| required=True, | |
| help="Number of columns in input tensor (must be divisible by alignment or sf_vec_size for the selected routine).", | |
| ) |
🤖 Prompt for AI Agents
In `@benchmarks/routines/quantization.py` around lines 74 - 78, The CLI help for
the "--k" argument is misleading because divisibility depends on chosen
quantization parameters (alignment for mxfp8 or sf_vec_size for nvfp4); update
the help text in the add_argument call for "--k" to state that k must be
divisible by the active quantization granularity (e.g., "must be divisible by
alignment for mxfp8 or by sf_vec_size for nvfp4"), and optionally reference
which flags select those modes so users know which divisor applies (check where
"--quant-mode"/related flags and variables alignment and sf_vec_size are defined
and mentioned).
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@benchmarks/routines/quantization.py`:
- Around line 668-673: The current NVFP4 shape checks print a warning on
mismatch; change them to raise an AssertionError when x_q.shape !=
expected_shape unless args.allow_output_mismatch is True. Locate the checks that
compute expected_shape = (m, k // 2) and compare to x_q.shape (and the similar
block later in the file) and replace the print warning with: if not
args.allow_output_mismatch: raise AssertionError(f"Unexpected output shape:
{x_q.shape}, expected {expected_shape}") so behavior matches the mxfp8/mxfp4
tests.
♻️ Duplicate comments (1)
benchmarks/routines/quantization.py (1)
74-78: Clarify --k divisibility rule in CLI help. The divisor depends on alignment or sf_vec_size, not always 32.

✏️ Suggested help text tweak

- help="Number of columns in input tensor (must be divisible by 32).",
+ help="Number of columns in input tensor (must be divisible by alignment for mxfp8 or sf_vec_size for nvfp4).",
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
benchmarks/routines/quantization.py
🧰 Additional context used
🪛 Ruff (0.14.11)
benchmarks/routines/quantization.py
53-53: Avoid specifying long messages outside the exception class
(TRY003)
208-208: Avoid specifying long messages outside the exception class
(TRY003)
217-219: Avoid specifying long messages outside the exception class
(TRY003)
242-242: Avoid specifying long messages outside the exception class
(TRY003)
312-312: Abstract raise to an inner function
(TRY301)
401-403: Avoid specifying long messages outside the exception class
(TRY003)
412-414: Avoid specifying long messages outside the exception class
(TRY003)
429-429: Avoid specifying long messages outside the exception class
(TRY003)
491-491: Abstract raise to an inner function
(TRY301)
598-600: Avoid specifying long messages outside the exception class
(TRY003)
609-611: Avoid specifying long messages outside the exception class
(TRY003)
639-639: Avoid specifying long messages outside the exception class
(TRY003)
758-758: Avoid specifying long messages outside the exception class
(TRY003)
762-764: Avoid specifying long messages outside the exception class
(TRY003)
773-775: Avoid specifying long messages outside the exception class
(TRY003)
797-797: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (3)
benchmarks/routines/quantization.py (3)
34-53: Clear routine dispatch. Simple, explicit mapping with a clean fallback error.

205-319: Alignment validation + refcheck enforcement look solid. Nice to see alignment-based checks and explicit mismatch handling.

398-499: Refcheck round-trip handling is consistent. Mismatch handling aligns well with the mxfp8 path.
# Verify output shape (M, K/2) for FP4
expected_shape = (m, k // 2)
if x_q.shape != expected_shape:
    print(
        f"[WARNING] Unexpected output shape: {x_q.shape}, expected {expected_shape}"
    )
Fail refcheck on nvfp4 shape mismatches (unless --allow_output_mismatch is set).
Shape mismatches currently only warn, which can silently pass invalid outputs during refcheck. Make them consistent with other quantization tests (mxfp8, mxfp4) by raising AssertionError unless args.allow_output_mismatch is set. This pattern is already established in the file for similar validation checks.
🛠️ Proposed fix
if x_q.shape != expected_shape:
- print(
- f"[WARNING] Unexpected output shape: {x_q.shape}, expected {expected_shape}"
- )
+ msg = (
+ f"Unexpected output shape: {x_q.shape}, expected {expected_shape}"
+ )
+ if args.allow_output_mismatch:
+ print(f"[WARNING] {msg}")
+ else:
+            raise AssertionError(msg)

Also applies to: 826-831
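A runnable rendering of the proposed fix (a plain bool replaces args.allow_output_mismatch; the FP4 packing of two 4-bit values per output byte motivates the k // 2):

```python
def verify_fp4_shape(x_q_shape, m, k, allow_output_mismatch=False):
    expected_shape = (m, k // 2)  # FP4 packs two 4-bit values per byte
    if x_q_shape != expected_shape:
        msg = f"Unexpected output shape: {x_q_shape}, expected {expected_shape}"
        if allow_output_mismatch:
            print(f"[WARNING] {msg}")
        else:
            raise AssertionError(msg)

verify_fp4_shape((8, 32), m=8, k=64)                              # matches, silent
verify_fp4_shape((8, 30), m=8, k=64, allow_output_mismatch=True)  # warns only
try:
    verify_fp4_shape((8, 30), m=8, k=64)
    raised = False
except AssertionError:
    raised = True
assert raised
```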
🤖 Prompt for AI Agents
In `@benchmarks/routines/quantization.py` around lines 668 - 673, The current
NVFP4 shape checks print a warning on mismatch; change them to raise an
AssertionError when x_q.shape != expected_shape unless
args.allow_output_mismatch is True. Locate the checks that compute
expected_shape = (m, k // 2) and compare to x_q.shape (and the similar block
later in the file) and replace the print warning with: if not
args.allow_output_mismatch: raise AssertionError(f"Unexpected output shape:
{x_q.shape}, expected {expected_shape}") so behavior matches the mxfp8/mxfp4
tests.
📌 Description
No changes to library code
Extends the flashinfer_benchmark.py microbenchmark harness to support normalization and quantization routines. This enables performance benchmarking for RMSNorm variants (including fused residual-add and FP8/FP4 quantization) and standalone quantization kernels (MxFP8, MxFP4, NVFP4).

New Routines
Example outputs:
🔍 Related Issues
#2361
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Documentation
Tests