Merged
Changes from 6 commits
92 changes: 88 additions & 4 deletions benchmarks/README.md
@@ -5,11 +5,11 @@ The aim of `flashinfer_benchmark.py` is to provide a single framework for benchmarking
## Overview

This framework provides tools to:
- Benchmark FlashInfer's Attention, GEMM, and MOE API performance from different kernel backends such as FlashAttention2/3, cuDNN, cuBLAS, CUTLASS, and TensorRT-LLM
- Benchmark FlashInfer's Attention, GEMM, MOE, Norm, and Quantization API performance from different kernel backends such as FlashAttention2/3, cuDNN, cuBLAS, CUTLASS, CuTe-DSL, and TensorRT-LLM
- Compare performance across different configurations
- Batch performance test multiple attention test cases
- Batch performance test multiple test cases

Currently supports testing most attention, gemm, and fused MOE APIs:
Currently supports testing attention, gemm, fused MOE, normalization, and quantization APIs:
- Attention:
- `BatchDecodeWithPagedKVCacheWrapper` - Decode attention with paged KV cache.
- Also supports computationally similar `cudnn_batch_decode_with_kv_cache` and `trtllm_batch_decode_with_kv_cache`.
@@ -29,6 +29,17 @@ Currently supports testing most attention, gemm, and fused MOE APIs:
- `trtllm_fp8_block_scale_moe` - MOE with FP8 quantized weights and block-wise scaling.
- `trtllm_fp8_per_tensor_scale_moe` - MOE with FP8 quantized weights and per-tensor scaling.
- `cutlass_fused_moe` - CUTLASS fused MoE (base/fp8/nvfp4 variants with optional TP/EP)
- Norm:
- `rmsnorm` - Root Mean Square Layer Normalization.
- `rmsnorm_quant` - RMSNorm with FP8 quantized output.
- `fused_add_rmsnorm_quant` - Fused residual add + RMSNorm with FP8 quantized output.
- `rmsnorm_fp4quant` - RMSNorm with FP4 quantized output (CuTe-DSL, Blackwell SM10.0+).
- `add_rmsnorm_fp4quant` - Fused residual add + RMSNorm with FP4 quantized output (CuTe-DSL, Blackwell SM10.0+).
- Quantization:
- `mxfp8_quantize` - Quantize tensor to MxFP8 format (Blackwell SM10.0+).
- `mxfp4_quantize` - Quantize tensor to MxFP4 format (Blackwell SM10.0+).
- `nvfp4_quantize` - Quantize tensor to NVFP4 format with configurable scale factor layout (Blackwell SM10.0+).
- `nvfp4_batched_quantize` - Batched NVFP4 quantization (Blackwell SM10.0+).
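For reference, the math behind the `rmsnorm` routine can be sketched in a few lines of plain Python (a minimal sketch of the reference computation over one hidden vector, not the CUDA kernel; `rmsnorm_ref` is an illustrative name, not a FlashInfer API):

```python
import math

def rmsnorm_ref(x, weight, eps=1e-6):
    """Reference RMSNorm over one hidden vector: x / rms(x) * weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]

# A constant input vector has rms ~= 1, so the output is ~= the weights.
print(rmsnorm_ref([1.0, 1.0, 1.0, 1.0], [0.5, 0.5, 0.5, 0.5]))
```

The `*_quant` variants fuse this normalization with a cast of the result to FP8/FP4.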

## Quick Start
### Single Test Run
@@ -81,6 +92,35 @@ $ python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper
[PERF] fa2 :: median time 0.495 ms; std 0.006 ms; achieved tflops 219.336 TFLOPs/sec; achieved tb_per_sec 1.736 TB/sec
[PERF] cutlass :: median time 0.530 ms; std 0.002 ms; achieved tflops 204.674 TFLOPs/sec; achieved tb_per_sec 1.620 TB/sec
[PERF] cudnn :: median time 0.313 ms; std 0.000 ms; achieved tflops 346.715 TFLOPs/sec; achieved tb_per_sec 2.745 TB/sec

# RMSNorm with FP8 quantized output
$ python3 flashinfer_benchmark.py --routine rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_quant_fp8_e4m3"
[INFO] Running testRmsnormQuant
[INFO] FlashInfer version: 0.6.1
[VVERBOSE] gpu_name = 'NVIDIA_B300_SXM6_AC'
[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag rmsnorm_quant_fp8_e4m3
[VVERBOSE] input_tensor.shape = torch.Size([32, 4096])
[VVERBOSE] input_tensor.dtype = torch.bfloat16
[VVERBOSE] weight.shape = torch.Size([4096])
[VVERBOSE] out_tensor.dtype = torch.float8_e4m3fn
[VVERBOSE] scale = 1.0
[PERF] cuda :: median time 0.003 ms; std 0.000 ms; achieved tflops 0.229 TFLOPs/sec; achieved tb_per_sec 0.140 TB/sec

# MxFP8 Quantization (Blackwell SM10.0+ only)
$ python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 2048 --k 8192 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "mxfp8_quantize"
[INFO] args = Namespace(routine='mxfp8_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='mxfp8_quantize', generate_repro_command=True, repro_command='', m=2048, k=8192, input_dtype='bfloat16', is_sf_swizzled_layout=True, no_sf_swizzled_layout=False, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
[INFO] Running testMxfp8Quantize
[INFO] FlashInfer version: 0.6.1
[VVERBOSE] gpu_name = 'NVIDIA_B300_SXM6_AC'
[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 2048 --k 8192 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag mxfp8_quantize
[VVERBOSE] input_tensor.shape = torch.Size([2048, 8192])
[VVERBOSE] input_tensor.dtype = torch.bfloat16
[VVERBOSE] is_sf_swizzled_layout = True
[VVERBOSE] alignment = 32
[VVERBOSE] enable_pdl = False
[VVERBOSE] Backend cuda: x_q.shape = torch.Size([2048, 8192]), x_q.dtype = torch.float8_e4m3fn, sf.shape = torch.Size([524288]), sf.dtype = torch.uint8
[VVERBOSE] Round-trip error: 0/16777216 (0.00%) elements differ
[PERF] cuda :: median time 0.016 ms; std 0.000 ms; achieved tflops 3.118 TFLOPs/sec; achieved tb_per_sec 3.150 TB/sec
```
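The `sf.shape = torch.Size([524288])` in the MxFP8 log above follows from one `uint8` scale factor per 32-element block along `k` (a sketch assuming the standard MX block size of 32; `mxfp8_sf_count` is a hypothetical helper, not a FlashInfer API):

```python
def mxfp8_sf_count(m, k, block_size=32):
    """One uint8 scale factor per block_size contiguous elements along k."""
    assert k % block_size == 0, "k must be divisible by the MX block size"
    return m * (k // block_size)

# Matches the log above: a 2048 x 8192 input yields 524288 scale factors.
print(mxfp8_sf_count(2048, 8192))  # 524288
```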

### Batch Testing
@@ -104,7 +144,7 @@ The output CSV will contain detailed metrics including:
### General Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--routine` | Test routine to run: `BatchDecodeWithPagedKVCacheWrapper`, `BatchPrefillWithPagedKVCacheWrapper`, `BatchPrefillWithRaggedKVCacheWrapper`, `BatchMLAPagedAttentionWrapper`, `gemm_fp8_nt_groupwise`, `group_gemm_fp8_nt_groupwise`, `bmm_fp8`, `mm_fp4`, `trtllm_fp4_block_scale_moe`, `trtllm_fp8_block_scale_moe`, `trtllm_fp8_per_tensor_scale_moe`, `cutlass_fused_moe` |
| `--routine` | Test routine to run. See [Overview](#overview) for full list including attention, GEMM, MOE, norm, and quantization routines. |
| `--num_iters` | Number of iterations for performance measurement |
| `--dry_run_iters` | Number of warmup iterations |
| `--no_cuda_graph` | Disable CUDA graph to execute kernels outside of the graph. |
@@ -198,6 +238,38 @@ Notes:
- FP8 MOE kernels require integer values for group parameters, while FP4 MOE kernels accept optional values.
- CUTLASS fused MoE (`cutlass_fused_moe`) ignores `--routing_method`, `--n_group`, and `--topk_group`; it computes routing via softmax+top-k internally from the provided logits.

### Norm Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--batch_size` | Batch size (number of sequences) |
| `--hidden_size` | Hidden dimension size |
| `--num_heads` | Number of heads for 3D input shape (batch, num_heads, hidden_size). Optional; if not set, uses 2D shape. |
| `--input_dtype` | Input data type: `bfloat16` (default) or `float16` |
| `--eps` | Epsilon for numerical stability. Default: 1e-6 |
| `--enable_pdl` | Enable programmatic dependent launch |
| `--scale` | Scale factor for FP8 quantization (used by `rmsnorm_quant`, `fused_add_rmsnorm_quant`). Default: 1.0 |
| `--out_dtype` | Output dtype: `fp8_e4m3`, `fp8_e5m2` (for FP8 quant); `nvfp4`, `mxfp4` (for FP4 quant). Default: `fp8_e4m3` |
| `--use_global_scale` | Use global scale factor for NVFP4 format (FP4 routines only) |
| `--is_sf_swizzled_layout`| Use swizzled scale factor layout for tensor core GEMM (FP4 routines only) |
| `--backends` | Backend to test: `cuda` (default) or `cute-dsl` (for FP4 routines) |
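The `--scale` flag multiplies the normalized values before they are cast to FP8. A plain-Python sketch of that final step (448 is the `float8_e4m3fn` representable maximum; `quantize_fp8_ref` is an illustrative name, not a FlashInfer API):

```python
def quantize_fp8_ref(x, scale, fp8_max=448.0):
    """Scale, then clamp to the float8_e4m3fn representable range."""
    return [max(-fp8_max, min(fp8_max, v * scale)) for v in x]

# Out-of-range values saturate at +-448 before the FP8 cast.
print(quantize_fp8_ref([600.0, -2.0], 1.0))  # [448.0, -2.0]
```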

### Quantization Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--m` | Number of rows in input tensor |
| `--k` | Number of columns in input tensor (must be divisible by 32) |
| `--input_dtype` | Input data type: `bfloat16` (default) or `float16` |
| `--is_sf_swizzled_layout`| Use swizzled layout for scale factors. Default: True |
| `--no_sf_swizzled_layout`| Disable swizzled layout for scale factors |
| `--alignment` | Scale factor vector size (sfVecSize) for quantization. Default: 32 |
| `--enable_pdl` | Enable programmatic dependent launch |
| `--batch_size` | Batch size for batched quantization (`nvfp4_batched_quantize` only) |
| `--global_scale` | Global scale factor for NVFP4 quantization. Default: 1.0 |
| `--sf_layout` | Scale factor layout for NVFP4: `128x4` (default), `8x4`, or `linear` |
| `--do_shuffle` | Shuffle scale factors for TRTLLM backend (`nvfp4_quantize` only) |
| `--sf_vec_size` | Scale factor vector size for NVFP4 quantization. Default: 16 |
| `--backends` | Backend to test. Default: `cuda` |
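For `--sf_layout 128x4`, the scale-factor tensor is padded so its rows are a multiple of 128 and its scale columns a multiple of 4. A sketch of the resulting element count (assuming that common 128x4 padding convention; `nvfp4_sf_count_swizzled` is a hypothetical helper, not a FlashInfer API):

```python
def nvfp4_sf_count_swizzled(m, k, sf_vec_size=16):
    """Scale-factor count for the 128x4 swizzled layout: rows padded to a
    multiple of 128, scale columns (k / sf_vec_size) padded to a multiple of 4."""
    pad = lambda v, a: -(-v // a) * a  # round up to a multiple of a
    return pad(m, 128) * pad(k // sf_vec_size, 4)

# 2048 rows need no padding; 8192 / 16 = 512 scale columns need none either.
print(nvfp4_sf_count_swizzled(2048, 8192))  # 1048576
```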

## `flashinfer_benchmark.py` Routine & Backend Support Matrix
The following table summarizes the support surface of each routine and backend on various [CUDA Compute Capabilities](https://developer.nvidia.com/cuda-gpus).

@@ -228,13 +300,25 @@ Legend:
| **trtllm_fp8_block_scale_moe** | | | | | | trtllm | trtllm | |
| **trtllm_fp8_per_tensor_scale_moe** | | | | | | trtllm | trtllm | |
| **cutlass_fused_moe** | | | | | | cutlass | cutlass | |
| **rmsnorm** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda |
| **rmsnorm_quant** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda |
| **fused_add_rmsnorm_quant** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda |
| **rmsnorm_fp4quant** | | | | | | cute-dsl | cute-dsl | |
| **add_rmsnorm_fp4quant** | | | | | | cute-dsl | cute-dsl | |
| **mxfp8_quantize** | | | | | | cuda | cuda | |
| **mxfp4_quantize** | | | | | | cuda | cuda | |
| **nvfp4_quantize** | | | | | | cuda | cuda | |
| **nvfp4_batched_quantize** | | | | | | cuda | cuda | |

Backend Legend:
- fa2: FlashAttention2
- fa2_tc: FlashAttention2 (with Tensor Cores for `BatchDecodeWithPagedKVCacheWrapper`)
- fa3: FlashAttention-3
- cudnn: cuDNN
- cublas: cuBLAS
- cutlass: CUTLASS
- trtllm: TensorRT-LLM
- trtllm-gen: TensorRT-LLM
- trtllm-native: TensorRT-LLM (out-of-wrapper)
- cuda: FlashInfer CUDA kernels
- cute-dsl: FlashInfer CuTe-DSL kernels (Blackwell SM10.0+)
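Each row of the matrix above amounts to a per-compute-capability lookup, in the style of the support table added to `flashinfer_benchmark_utils.py`. A standalone sketch of how such a table gates backend selection (the dict contents here are illustrative excerpts, and `supported_backends` is a hypothetical helper):

```python
# Illustrative excerpt of a routine -> compute capability -> backends table.
support = {
    "rmsnorm_fp4quant": {"9.0": [], "10.0": ["cute-dsl"], "10.3": ["cute-dsl"]},
    "mxfp8_quantize": {"9.0": [], "10.0": ["cuda"], "10.3": ["cuda"]},
}

def supported_backends(routine, compute_cap):
    """Backends available for a routine on a given compute capability."""
    return support.get(routine, {}).get(compute_cap, [])

print(supported_backends("rmsnorm_fp4quant", "10.0"))  # ['cute-dsl']
print(supported_backends("mxfp8_quantize", "9.0"))     # []
```

An unknown routine or capability simply yields an empty backend list, so the benchmark can skip it rather than fail.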
19 changes: 17 additions & 2 deletions benchmarks/flashinfer_benchmark.py
@@ -9,6 +9,8 @@
)
from routines.gemm import parse_gemm_args, run_gemm_test
from routines.moe import parse_moe_args, run_moe_test
from routines.norm import parse_norm_args, run_norm_test
from routines.quantization import parse_quantization_args, run_quantization_test


def run_test(args):
@@ -26,6 +28,10 @@ def run_test(args):
res = run_gemm_test(args)
elif args.routine in benchmark_apis["moe"]:
res = run_moe_test(args)
elif args.routine in benchmark_apis["norm"]:
res = run_norm_test(args)
elif args.routine in benchmark_apis["quantization"]:
res = run_quantization_test(args)
else:
raise ValueError(f"Unsupported routine: {args.routine}")

@@ -34,7 +40,10 @@ def run_test(args):
with open(args.output_path, "a") as fout:
for cur_res in res:
for key in output_column_dict["general"]:
cur_res[key] = getattr(args, key)
# Only set from args if the routine hasn't already set a value
# This preserves routine-specific formatting while providing defaults
if key not in cur_res or cur_res[key] == "":
cur_res[key] = getattr(args, key, "")

output_line = ",".join(
[str(cur_res[col]) for col in full_output_columns]
@@ -65,7 +74,9 @@ def parse_args(line=sys.argv[1:]):
required=True,
choices=list(benchmark_apis["attention"])
+ list(benchmark_apis["gemm"])
+ list(benchmark_apis["moe"]),
+ list(benchmark_apis["moe"])
+ list(benchmark_apis["norm"])
+ list(benchmark_apis["quantization"]),
)
args, _ = parser.parse_known_args(line[:])

@@ -156,6 +167,10 @@ def parse_args(line=sys.argv[1:]):
args = parse_gemm_args(line, parser)
elif args.routine in benchmark_apis["moe"]:
args = parse_moe_args(line, parser)
elif args.routine in benchmark_apis["norm"]:
args = parse_norm_args(line, parser)
elif args.routine in benchmark_apis["quantization"]:
args = parse_quantization_args(line, parser)
else:
raise ValueError(f"Unsupported routine: {args.routine}")

135 changes: 131 additions & 4 deletions benchmarks/routines/flashinfer_benchmark_utils.py
@@ -15,7 +15,6 @@
],
"attention": [
"page_size",
"batch_size",
"s_qo",
"s_kv",
"num_qo_heads",
@@ -37,14 +36,12 @@
"group_size",
"tile_size",
"scale_major_mode",
"out_dtype",
"mma_sm",
"use_128x4_sf_layout",
"use_nvfp4",
],
"moe": [
"num_tokens",
"hidden_size",
"intermediate_size",
"num_experts",
"top_k",
@@ -58,7 +55,6 @@
"weight_layout",
"use_routing_bias",
"use_routing_scales_on_input",
"input_dtype",
"weight_dtype",
"gated_act",
# CUTLASS fused MoE specific
@@ -69,7 +65,30 @@
"ep_size",
"ep_rank",
],
"norm": [
"num_heads",
"scale",
"eps",
"enable_pdl",
"use_global_scale",
"is_sf_swizzled_layout",
],
"quantization": [
"m",
"k",
"is_sf_swizzled_layout",
"alignment",
"enable_pdl",
"global_scale",
"sf_layout",
"do_shuffle",
"sf_vec_size",
],
"general": [
"batch_size",
"hidden_size",
"input_dtype",
"out_dtype",
"refcheck",
"no_cuda_graph",
"use_cupti",
@@ -86,6 +105,8 @@
+ output_column_dict["attention"]
+ output_column_dict["gemm"]
+ output_column_dict["moe"]
+ output_column_dict["norm"]
+ output_column_dict["quantization"]
+ output_column_dict["general"]
)

@@ -109,6 +130,19 @@
"trtllm_fp8_per_tensor_scale_moe",
"cutlass_fused_moe",
],
"norm": [
"rmsnorm",
"rmsnorm_quant",
"fused_add_rmsnorm_quant",
"rmsnorm_fp4quant",
"add_rmsnorm_fp4quant",
],
"quantization": [
"mxfp8_quantize",
"mxfp4_quantize",
"nvfp4_quantize",
"nvfp4_batched_quantize",
],
}


@@ -289,6 +323,99 @@ def dtype_str_to_torch_dtype(dtype_str):
"10.3": ["cutlass"],
"12.0": [],
},
# NORM
"rmsnorm": {
"7.5": ["cuda"],
"8.0": ["cuda"],
"8.6": ["cuda"],
"8.9": ["cuda"],
"9.0": ["cuda"],
"10.0": ["cuda"],
"10.3": ["cuda"],
"12.0": ["cuda"],
},
"rmsnorm_quant": {
"7.5": ["cuda"],
"8.0": ["cuda"],
"8.6": ["cuda"],
"8.9": ["cuda"],
"9.0": ["cuda"],
"10.0": ["cuda"],
"10.3": ["cuda"],
"12.0": ["cuda"],
},
"fused_add_rmsnorm_quant": {
"7.5": ["cuda"],
"8.0": ["cuda"],
"8.6": ["cuda"],
"8.9": ["cuda"],
"9.0": ["cuda"],
"10.0": ["cuda"],
"10.3": ["cuda"],
"12.0": ["cuda"],
},
# NORM - FP4 Quantization (Blackwell SM100+ only, CuTe-DSL kernels)
"rmsnorm_fp4quant": {
"7.5": [],
"8.0": [],
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cute-dsl"],
"10.3": ["cute-dsl"],
"12.0": ["cute-dsl"],
},
"add_rmsnorm_fp4quant": {
"7.5": [],
"8.0": [],
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cute-dsl"],
"10.3": ["cute-dsl"],
"12.0": ["cute-dsl"],
},
# QUANTIZATION
"mxfp8_quantize": {
"7.5": [],
"8.0": [],
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cuda"],
"10.3": ["cuda"],
"12.0": ["cuda"],
},
"mxfp4_quantize": {
"7.5": [],
"8.0": [],
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cuda"],
"10.3": ["cuda"],
"12.0": ["cuda"],
},
"nvfp4_quantize": {
"7.5": [],
"8.0": [],
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cuda"],
"10.3": ["cuda"],
"12.0": ["cuda"],
},
"nvfp4_batched_quantize": {
"7.5": [],
"8.0": [],
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cuda"],
"10.3": ["cuda"],
"12.0": ["cuda"],
},
}

