chore: update benchmark scripts; fix trtllm-gen moe comments #2412
IwakuraRein merged 11 commits into flashinfer-ai:main from
Conversation
📝 Walkthrough

Adds MXInt4 quantization and an MXInt4xBf16 autotuner benchmark path to the fused MoE benchmarking script, refactors FP8/FP4 bench flows to use `functools.partial` with `input_kwargs`, extends CLI quant-mode choices, and clarifies BlockMajorK weight-layout shapes in fused_moe/core docstrings.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant CLI as CLI/Runner
    participant Bench as Bench Script
    participant Quant as MXInt4 Quantizer
    participant TRTL as trtllm_mxint4_block_scale_moe
    participant GPU as GPU/Autotuner
    CLI->>Bench: invoke with quant_mode=MxInt4xBf16
    Bench->>Quant: mxint4_quantize(input_tensor)
    Quant-->>Bench: quantized_tensor + scales
    Bench->>TRTL: call trtllm_mxint4_block_scale_moe(quantized_tensor, scales, routing_bias, ...)
    TRTL->>GPU: submit kernels / autotune runs
    GPU-->>TRTL: perf metrics
    TRTL-->>Bench: results
    Bench->>CLI: report timings / best config
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Summary of Changes

Hello @IwakuraRein, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request expands the benchmarking suite by introducing a new script for FP4 matrix multiplication, which now includes the Cutlass backend for performance evaluation. It also enhances the existing Mixture-of-Experts (MoE) benchmark script to incorporate MXINT4 quantization, allowing for more comprehensive performance analysis of different quantization strategies. These additions aim to provide better insights into the efficiency of various low-precision arithmetic implementations.
Code Review

This pull request updates benchmark scripts for MoE and adds a new benchmark for FP4 matrix multiplication. The changes are primarily within benchmark files. I've identified a few issues, including a potential bug in the new `mxint4_quantize` function and an incorrect return type in `bench_mm_fp4.py`. My comments provide suggestions to address these points.
benchmarks/bench_mm_fp4.py (outdated)

```python
if not mm_fp4.is_backend_supported(backend, compute_capability_number):
    print(
        f"Skipping test for {backend} because it is not supported on compute capability {compute_capability_number}."
    )
    return
```
The function `_bench_mm_fp4` is type-hinted to return a `tuple[float, float]`, but the early `return` statements on lines 39, 44, 47, 50, and 53 return `None`. This will lead to a `TypeError` at the call site when attempting to unpack the `None` value. To ensure type consistency and prevent runtime errors, the function should always return a tuple of two floats. For skipped tests, you could return a sentinel value like `(float('inf'), 0.0)`.
Suggested change:

```python
print(
    f"Skipping test for {backend} because it is not supported on compute capability {compute_capability_number}."
)
return float('inf'), 0.0
```

```python
scales = amax / 8.0
x_scaled = x_reshaped * scales.reciprocal()
x_int8 = (
    x_scaled.round().clamp(-8, 7).to(torch.uint8).reshape(-1, sf_vec_size // 2, 2)
```
Casting the float tensor from `x_scaled.round().clamp(-8, 7)` directly to `torch.uint8` will cause all negative values to be clamped to 0, leading to incorrect quantization. To preserve the negative values, you should first cast to `torch.int8` and then use `.view(torch.uint8)` to reinterpret the bits for the subsequent bitwise packing operations.
Suggested change:

```diff
-    x_scaled.round().clamp(-8, 7).to(torch.uint8).reshape(-1, sf_vec_size // 2, 2)
+    x_scaled.round().clamp(-8, 7).to(torch.int8).view(torch.uint8).reshape(-1, sf_vec_size // 2, 2)
```
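The bit-level reason the `int8 → view(uint8)` route works can be sketched without torch; `pack_int4_pair`/`unpack_int4_pair` below are illustrative helpers, not part of the benchmark:

```python
def pack_int4_pair(lo: int, hi: int) -> int:
    """Pack two signed 4-bit values in [-8, 7] into one byte.

    Masking with 0x0F keeps each value's two's-complement low nibble,
    which is what reinterpreting int8 bits as uint8 preserves; a
    value-converting cast of a negative float to uint8 would not.
    """
    assert -8 <= lo <= 7 and -8 <= hi <= 7
    return (lo & 0x0F) | ((hi & 0x0F) << 4)

def unpack_int4_pair(b: int) -> tuple[int, int]:
    """Recover the two signed 4-bit values from a packed byte."""
    def sign_extend(nibble: int) -> int:
        return nibble - 16 if nibble >= 8 else nibble
    return sign_extend(b & 0x0F), sign_extend(b >> 4)

packed = pack_int4_pair(-3, 7)   # -3 -> nibble 0b1101, 7 -> 0b0111
assert unpack_int4_pair(packed) == (-3, 7)  # sign survives the round trip
```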
```python
return x_int4.reshape(*x.shape[:-1], x.shape[-1] // 2), scales.reshape(
    -1, sf_vec_size
)
```
The `scales.reshape(-1, sf_vec_size)` operation can fail if the number of elements in `scales` is not a multiple of `sf_vec_size`. Since the caller of this function already reshapes the returned scales tensor, it's safer to remove this intermediate reshape and return the `scales` tensor directly. This avoids a potential runtime error and simplifies the function.
Suggested change:

```python
return x_int4.reshape(*x.shape[:-1], x.shape[-1] // 2), scales
```
benchmarks/bench_mm_fp4.py (outdated)

```python
print("mx_fp4 is only supported for cudnn and auto backends")
return
```

```python
input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
```
The variable name `input` shadows the Python built-in function `input()`. This is generally discouraged as it can lead to confusion and potential bugs if the built-in is needed later in the scope. Consider renaming it to something more descriptive, like `input_tensor`, and applying this change to all its occurrences.
Suggested change:

```python
input_tensor = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
```

```python
x_reshaped = x.reshape(-1, sf_vec_size)
x_max = x_reshaped.max(dim=-1, keepdim=True)[0].to(torch.float32)
x_min = x_reshaped.min(dim=-1, keepdim=True)[0].to(torch.float32)
x_max = x_max * 8.0 / 7.0
```
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@benchmarks/bench_mm_fp4.py`:
- Around lines 38-55: The function `_bench_mm_fp4` currently uses early bare returns that yield `None` while its signature declares `-> tuple[float, float]`; update every early-exit path inside `_bench_mm_fp4` (the branches printing "Skipping test..." and the trtllm/mx_fp4 checks) to return a consistent tuple of floats (e.g., `(0.0, 0.0)` or `(float("nan"), float("nan"))`), so callers like `ms, tflops = _bench_mm_fp4(...)` can always unpack safely; keep the existing print/log lines and replace each plain return with the chosen float tuple.
In `@benchmarks/bench_trtllm_gen_fused_moe_autotuner.py`:
- Around lines 392-422: The partial for `trtllm_mxint4_block_scale_moe` binds `hidden_states`, and `input_kwargs` also includes `"hidden_states"`, causing a duplicate-argument `TypeError`; remove the `hidden_states` binding from the partial call (leave it to be passed via `input_kwargs`) so `trtllm_mxint4_block_scale_moe` is only given `hidden_states` once when calling `fn(**input_kwargs)`.
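The duplicate-argument failure mode described above can be reproduced with any function; `moe_stub` below is a hypothetical stand-in for `trtllm_mxint4_block_scale_moe`, assuming the partial bound `hidden_states` positionally:

```python
from functools import partial

def moe_stub(hidden_states, scales):
    """Hypothetical stand-in for trtllm_mxint4_block_scale_moe."""
    return len(hidden_states), len(scales)

# Binding hidden_states positionally in the partial while input_kwargs
# also carries "hidden_states" supplies the argument twice at call time.
fn = partial(moe_stub, [1.0, 2.0])
input_kwargs = {"hidden_states": [1.0, 2.0], "scales": [0.5]}
try:
    fn(**input_kwargs)
    raised = False
except TypeError:  # got multiple values for argument 'hidden_states'
    raised = True
assert raised

# The fix: drop the binding and let input_kwargs provide it once.
fn = partial(moe_stub)
assert fn(**input_kwargs) == (2, 1)
```

Note that keyword-bound partial arguments are silently overridden by call-time keywords, so the `TypeError` only surfaces with positional bindings like the one sketched here.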
🧹 Nitpick comments (2)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (2)
34-50: Minor dtype difference from test implementation.

Line 45 uses `torch.uint8` for the intermediate `x_int8`, while the test implementation at `tests/moe/test_trtllm_gen_fused_moe.py:590-606` uses `torch.int8`. Both produce the same packed result due to the bitwise masking (`& 0x0F`), so this is functionally equivalent.

The scale tensor reshape at lines 48-50 returns shape `(-1, sf_vec_size)`, which appears to be inconsistent with how it's later reshaped at lines 380-384 and 386-390. The intermediate reshape seems unnecessary since `scales` is already the correct shape from line 42.

✨ Optional: simplify scale return

```diff
-    return x_int4.reshape(*x.shape[:-1], x.shape[-1] // 2), scales.reshape(
-        -1, sf_vec_size
-    )
+    return x_int4.reshape(*x.shape[:-1], x.shape[-1] // 2), scales.squeeze(-1)
```

This would return scales with shape `(-1,)`, which is then reshaped at the call site anyway.
368-368: Minor: inconsistent device specification.

Line 368 uses `device="cuda"` while line 365 defines `device = torch.device("cuda:0")`. For consistency with the rest of the function, use the `device` variable.

✨ Proposed fix

```diff
-    routing_bias = torch.randn(num_experts, device="cuda", dtype=torch.bfloat16)
+    routing_bias = torch.randn(num_experts, device=device, dtype=torch.bfloat16)
```
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@benchmarks/bench_trtllm_gen_fused_moe_autotuner.py`:
- Around line 34-50: The mxint4_quantize function uses torch.uint8 when
converting rounded/clamped values but the clamped range is signed (−8..7);
change the cast from torch.uint8 to torch.int8 in mxint4_quantize (the block
that creates x_int8) so the tensor uses a signed dtype consistent with the test
implementation, keep the subsequent bitwise masking and packing logic unchanged
and ensure scales reshape remains the same.
🧹 Nitpick comments (4)
benchmarks/bench_mm_fp4.py (2)
114-114: FIXME comment indicates a known issue with Cupti.

The comment notes that Cupti causes a CUDA Illegal Memory Access. Consider opening an issue to track this problem so it can be properly investigated and resolved.
Would you like me to help draft an issue to track this Cupti-related bug?
150-159: Consider exposing `fp4_type` and `res_dtype` as CLI arguments.

The benchmark currently hardcodes `"nvfp4"` for `fp4_type` and `torch.bfloat16` for `res_dtype`, but `_bench_mm_fp4` supports other values like `"mxfp4"`, `"mxfp4_alpha"`, and `torch.float16`. Exposing these as CLI arguments would allow benchmarking all supported configurations.

♻️ Proposed enhancement

```diff
 parser.add_argument(
     "--backend", type=str, nargs="+", default=["cudnn", "trtllm", "cutlass"]
 )
+parser.add_argument(
+    "--fp4-type", type=str, default="nvfp4", choices=["nvfp4", "mxfp4", "mxfp4_alpha"]
+)
 args = parser.parse_args()
 for m, n, k in product(args.m, args.n, args.k):
     print(f"m={m}, n={n}, k={k}".center(100, "-"))
     for backend in args.backend:
         print(f"  {backend}:")
         ms, tflops = _bench_mm_fp4(
-            m, n, k, torch.bfloat16, backend, True, "nvfp4", False
+            m, n, k, torch.bfloat16, backend, True, args.fp4_type, False
         )
```

benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (2)
354-364: Unused parameter `quant_mode` could be validated or documented.

The `quant_mode` parameter is declared but never used in the function body. Since this function only supports `"MxInt4xBf16"`, consider either:

- Adding an assertion to validate the expected mode
- Removing the parameter if it's only for signature consistency

♻️ Option 1: Add validation

```diff
 def bench_trtllm_gen_fused_moe_autotuner_mxint4(
     tune_max_num_tokens: Optional[int],
     quant_mode: Literal["MxInt4xBf16"],
     ...
 ):
+    assert quant_mode == "MxInt4xBf16", f"Unsupported quant_mode: {quant_mode}"
     device = torch.device("cuda:0")
```
484-490: Consider dict-based dispatch for better maintainability.

The nested ternary works but may become harder to maintain as more quantization modes are added. A dictionary-based dispatch would be clearer.

♻️ Proposed refactor

```diff
-    fn = (
-        bench_trtllm_gen_fused_moe_autotuner_fp8
-        if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"]
-        else bench_trtllm_gen_fused_moe_autotuner_mxint4
-        if args.quant_mode == "MxInt4xBf16"
-        else bench_trtllm_gen_fused_moe_autotuner_fp4
-    )
+    dispatch = {
+        "Fp8-Per-Tensor": bench_trtllm_gen_fused_moe_autotuner_fp8,
+        "Fp8-Block": bench_trtllm_gen_fused_moe_autotuner_fp8,
+        "MxInt4xBf16": bench_trtllm_gen_fused_moe_autotuner_mxint4,
+        "NvFP4xNvFP4": bench_trtllm_gen_fused_moe_autotuner_fp4,
+        "MxFP4xMxFP8": bench_trtllm_gen_fused_moe_autotuner_fp4,
+        "MxFP4xBf16": bench_trtllm_gen_fused_moe_autotuner_fp4,
+    }
+    fn = dispatch[args.quant_mode]
```
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@benchmarks/bench_trtllm_gen_fused_moe_autotuner.py`:
- Around line 354-364: The function bench_trtllm_gen_fused_moe_autotuner_mxint4
declares an unused parameter quant_mode which triggers an ARG001 lint; rename
quant_mode to _quant_mode in the function signature so the parameter intent is
clear and Ruff treats it as intentionally unused, and update any callers or
references if they exist to match the new parameter name.
```python
def bench_trtllm_gen_fused_moe_autotuner_mxint4(
    tune_max_num_tokens: Optional[int],
    quant_mode: Literal["MxInt4xBf16"],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
    top_k: int,
    warmups: int,
    iterations: int,
):
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
# First, let's look at the full function to see if quant_mode is used anywhere
# Search for the function and capture its content
sed -n '354,400p' benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
```

Repository: flashinfer-ai/flashinfer
Length of output: 1557
🏁 Script executed:

```shell
#!/bin/bash
# Find where the function ends and check for quant_mode references within it
# First, find the line count of the file
wc -l benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
# Then search for the function and any references to quant_mode
rg -A 150 "def bench_trtllm_gen_fused_moe_autotuner_mxint4\(" benchmarks/bench_trtllm_gen_fused_moe_autotuner.py | head -200
```

Repository: flashinfer-ai/flashinfer
Length of output: 4806
Rename the unused `quant_mode` parameter to `_quant_mode` to satisfy Ruff (ARG001). The parameter is declared but never referenced in the function body.
🔧 Proposed fix

```diff
 def bench_trtllm_gen_fused_moe_autotuner_mxint4(
     tune_max_num_tokens: Optional[int],
-    quant_mode: Literal["MxInt4xBf16"],
+    _quant_mode: Literal["MxInt4xBf16"],
     num_tokens: int,
```

🧰 Tools
🪛 Ruff (0.14.14)
[warning] 356-356: Unused function argument: quant_mode (ARG001)
🤖 Prompt for AI Agents
In `@benchmarks/bench_trtllm_gen_fused_moe_autotuner.py` around lines 354-364, the function `bench_trtllm_gen_fused_moe_autotuner_mxint4` declares an unused parameter `quant_mode`, which triggers an ARG001 lint; rename `quant_mode` to `_quant_mode` in the function signature so the parameter intent is clear and Ruff treats it as intentionally unused, and update any callers or references if they exist to match the new parameter name.
bkryu left a comment:
Hi @IwakuraRein, today the flashinfer_benchmark.py supports benchmarking mm_fp4. Example commands are:

```shell
flashinfer_benchmark.py --routine mm_fp4 --m 256 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --use_cupti
flashinfer_benchmark.py --routine mm_fp4 --m 64 --n 8192 --k 2048 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --autotune --refcheck -vv --use_cupti
```

It seems like bench_mm_fp4.py is a recreation of the functionality without much being added. Is there a reason to create a separate script for it?
Thanks. I didn't realize there is a benchmark for fp4. Let me remove the `bench_mm_fp4.py` script.
/bot run

[CANCELING] Pipeline #43286865: canceled
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
1-5: ⚠️ Potential issue | 🟡 Minor

Pipeline failures: pre-commit hooks need to be re-run.

The CI shows two pre-commit hook failures:

- `end-of-file-fixer` modified files (likely missing newline at EOF)
- `ruff-format` reformatted 1 file

Please run `pre-commit run --all-files` and commit the changes to fix these issues.
🤖 Fix all issues with AI agents
In `@benchmarks/bench_trtllm_gen_fused_moe_autotuner.py`:
- Around line 374-375: The code creates routing_bias with a hardcoded
device="cuda" which is inconsistent with the local device variable (e.g.,
torch.device("cuda:0")) and can break multi-GPU runs; update the routing_bias
creation to use the same device variable (routing_bias =
torch.randn(num_experts, device=device, dtype=...) ) so both routing_logits and
routing_bias use the unified device, ensuring consistent behaviour across GPUs
and matching dtype handling with routing_logits where needed.
🧹 Nitpick comments (1)

benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)

500-506: Consider dict-based dispatch for extensibility.

The nested ternary works but becomes harder to maintain as more quant modes are added. A dict mapping would be cleaner.

♻️ Optional refactor using dict dispatch

```python
bench_functions = {
    "Fp8-Per-Tensor": bench_trtllm_gen_fused_moe_autotuner_fp8,
    "Fp8-Block": bench_trtllm_gen_fused_moe_autotuner_fp8,
    "MxInt4xBf16": bench_trtllm_gen_fused_moe_autotuner_mxint4,
}
fn = bench_functions.get(args.quant_mode, bench_trtllm_gen_fused_moe_autotuner_fp4)
```
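A minimal runnable version of the same `.get()`-with-default dispatch pattern; the stub bench functions stand in for the real ones:

```python
def bench_fp8():
    return "fp8"

def bench_mxint4():
    return "mxint4"

def bench_fp4():
    return "fp4"

# Explicit modes map to their bench function; FP4 variants fall
# through to the .get() default, matching the nested-ternary behavior.
bench_functions = {
    "Fp8-Per-Tensor": bench_fp8,
    "Fp8-Block": bench_fp8,
    "MxInt4xBf16": bench_mxint4,
}

assert bench_functions.get("Fp8-Block", bench_fp4)() == "fp8"
assert bench_functions.get("MxInt4xBf16", bench_fp4)() == "mxint4"
assert bench_functions.get("NvFP4xNvFP4", bench_fp4)() == "fp4"
```

One trade-off: the `.get()` default silently accepts unknown modes, whereas an explicit dict lookup (`dispatch[args.quant_mode]`) raises `KeyError` on a typo.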
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@benchmarks/bench_trtllm_gen_fused_moe_autotuner.py`:
- Around line 28-51: The mxint4_quantize function lacks validation for the
input's last-dimension divisibility by sf_vec_size and can produce inf/NaN when
blocks are zero; add an input check at the start of mxint4_quantize that raises
a clear ValueError if x.shape[-1] % sf_vec_size != 0, then after computing amax
ensure you clamp amax to a small epsilon (e.g., torch.finfo(x.dtype).tiny or
1e-8) before computing scales = amax / 8.0 so scales.reciprocal() cannot produce
infinities; keep the rest of the existing reshaping/packing logic and return
types as-is, referencing mxint4_quantize, x_reshaped, amax, scales, and x_int4
to locate the changes.
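A sketch of the two guards the prompt asks for; `checked_block_scales` is a hypothetical helper operating on plain floats, not the torch implementation:

```python
def checked_block_scales(x_last_dim, block_amaxes, sf_vec_size=32, eps=1e-8):
    """Sketch of the requested guards for mxint4_quantize:
    reject non-divisible shapes up front, and clamp each block amax
    so the reciprocal scale can never be infinite."""
    if x_last_dim % sf_vec_size != 0:
        raise ValueError(
            f"last dim {x_last_dim} must be divisible by sf_vec_size={sf_vec_size}"
        )
    return [max(a, eps) / 8.0 for a in block_amaxes]

scales = checked_block_scales(64, [0.0, 4.0])
assert all(s > 0.0 for s in scales)                 # zero block no longer yields scale 0
assert all(1.0 / s != float("inf") for s in scales)  # reciprocal stays finite

try:
    checked_block_scales(65, [1.0])
    raised = False
except ValueError:
    raised = True
assert raised  # non-divisible last dim is rejected with a clear message
```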
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used my preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (e.g., `unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Documentation