24 changes: 22 additions & 2 deletions benchmarks/README.md
@@ -5,11 +5,11 @@ The aim of `flashinfer_benchmark.py` is to provide a single framework for benchm
## Overview

This framework provides tools to:
- Benchmark FlashInfer's Attention, GEMM, MOE, Norm, Quantization, Sampling, and RoPE API performance from different kernel backends such as FlashAttention2/3, cuDNN, cuBLAS, CUTLASS, CuTe-DSL, and TensorRT-LLM
- Benchmark FlashInfer's Attention, GEMM, MOE, Norm, Quantization, Sampling, RoPE, and Mamba API performance from different kernel backends such as FlashAttention2/3, cuDNN, cuBLAS, CUTLASS, CuTe-DSL, TensorRT-LLM, and Triton
- Compare performance across different configurations
- Batch performance test multiple test cases

Currently supports testing attention, gemm, fused MOE, normalization, quantization, sampling, and RoPE APIs:
Currently supports testing attention, gemm, fused MOE, normalization, quantization, sampling, RoPE, and Mamba APIs:
- Attention:
- `BatchDecodeWithPagedKVCacheWrapper` - Decode attention with paged KV cache.
- Also supports computationally similar `cudnn_batch_decode_with_kv_cache` and `trtllm_batch_decode_with_kv_cache`.
@@ -67,6 +67,8 @@ Currently supports testing attention, gemm, fused MOE, normalization, quantizati
- `mla_rope_quantize_fp8` - MLA RoPE with FP8 quantization (SM8.9+).
- `rope_quantize_fp8` - RoPE with FP8 quantization (SM8.9+).
- `rope_quantize_fp8_append_paged_kv_cache` - RoPE with FP8 quantization and paged KV cache append (SM8.9+).
- Mamba (Selective State Space Models):
- `selective_state_update` - Selective state update for Mamba layers (generation phase). Supports both single-token prediction (STP) and multi-token prediction (MTP) via `--cache_steps`. Backends: `flashinfer` (CUDA, architecture-specific kernels for base/SM90/SM100+) and `triton` (reference). A minimal reference sketch of the update rule follows this list.
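
For orientation, the recurrence that `selective_state_update` implements can be sketched in a few lines of PyTorch. This is a minimal single-token (STP) sketch built from the flag semantics described above (optional softplus on `dt`, optional `z * sigmoid(z)` gating, grouped `B`/`C`); the tensor names and shape conventions are illustrative assumptions, not the exact FlashInfer API surface.

```python
# Minimal PyTorch sketch of single-token selective state update semantics.
# Assumed shapes: state (B, H, D, N), x/dt/z (B, H, D), A (H, D, N),
# B_/C_ (B, G, N) shared across H // G heads, D_/dt_bias (H, D).
import torch
import torch.nn.functional as F

def selective_state_update_ref(state, x, dt, A, B_, C_, D_=None, z=None,
                               dt_bias=None, dt_softplus=False):
    batch, nheads, headdim, dstate = state.shape
    ngroups = B_.shape[1]
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)
    # Broadcast grouped B and C across the heads that share each group.
    B_ = B_.repeat_interleave(nheads // ngroups, dim=1)  # (B, H, N)
    C_ = C_.repeat_interleave(nheads // ngroups, dim=1)  # (B, H, N)
    # Discretize and update the state in place: s' = s * exp(dt*A) + (dt*x) * B
    dA = torch.exp(dt.unsqueeze(-1) * A)                  # (B, H, D, N)
    dBx = (dt * x).unsqueeze(-1) * B_.unsqueeze(2)        # (B, H, D, N)
    state.copy_(state * dA + dBx)
    # Output projection: contract the state dim against C, add skip term D*x.
    y = torch.einsum("bhdn,bhn->bhd", state, C_)
    if D_ is not None:
        y = y + D_ * x
    if z is not None:
        y = y * (z * torch.sigmoid(z))                    # SiLU-style gating
    return y
```

Both backends implement this same recurrence, which is what makes the Triton backend usable as a reference check for the CUDA kernels.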

## Quick Start
### Single Test Run
@@ -379,6 +381,22 @@ mpirun -np 8 python benchmarks/flashinfer_benchmark.py \
| `--old_context_len` | Old context length for Llama 3.1 RoPE. Default: 8192 |
| `--backends` | Backend to test: `cuda` (default) |

### Mamba Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--batch_size` | Batch size (number of sequences) |
| `--nheads` | Number of SSM heads |
| `--dim` | Head dimension (headdim) |
| `--dstate` | SSM state size |
| `--ngroups` | Number of groups for B and C matrices. `nheads` must be divisible by `ngroups`, and `nheads/ngroups` must be 1, 8, or 16. Default: 8 |
| `--cache_steps` | Number of steps/tokens for multi-token prediction (MTP). 0 = single-token prediction (STP). Default: 0 |
| `--input_dtype` | Data type for input tensors (x, B, C, z): `bfloat16` (default). Only `bfloat16` is supported. |
| `--state_dtype` | Data type for the SSM state cache: `bfloat16` (default), `float16`, or `float32` |
| `--weight_dtype` | Data type for weight tensors (dt, D, dt_bias): `float32` (default) or `bfloat16` |
| `--has_z` | Include z tensor for gating (`z * sigmoid(z)` applied to output) |
| `--dt_softplus` | Apply softplus to dt before use |
| `--backends` | Backends to test: `flashinfer` (default) and `triton` (reference). The reference check compares results against the Triton reference |
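
Putting these flags together, an invocation along the following lines would benchmark `selective_state_update` on both backends. The sizes here are arbitrary illustrations (chosen so that `nheads / ngroups = 8`, satisfying the constraint above), and the `--routine` spelling follows the framework's usual routine-selection convention:

```bash
# Illustrative example; sizes are arbitrary and only the flags documented
# in the table above are used.
python benchmarks/flashinfer_benchmark.py \
    --routine selective_state_update \
    --backends flashinfer triton \
    --batch_size 64 \
    --nheads 64 \
    --dim 64 \
    --dstate 128 \
    --ngroups 8 \
    --cache_steps 0 \
    --has_z \
    --dt_softplus
```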

## `flashinfer_benchmark.py` Routine & Backend Support Matrix
The following table summarizes the support surface of each routine and backend on various [CUDA Compute Capabilities](https://developer.nvidia.com/cuda-gpus).

@@ -443,6 +461,7 @@ Legend:
| **mla_rope_quantize_fp8** | | | | cuda | cuda | cuda | cuda | cuda |
| **rope_quantize_fp8** | | | | cuda | cuda | cuda | cuda | cuda |
| **rope_quantize_fp8_append_paged_kv_cache** | | | | cuda | cuda | cuda | cuda | cuda |
| **selective_state_update** | flashinfer, triton | flashinfer, triton | flashinfer, triton | flashinfer, triton | flashinfer, triton | flashinfer, triton | flashinfer, triton | flashinfer, triton |

Backend Legend:
- fa2: FlashAttention2
@@ -458,3 +477,4 @@ Backend Legend:
- cuda: FlashInfer CUDA kernels
- cute-dsl: FlashInfer CuTe-DSL kernels (Blackwell SM10.0+)
- moe_a2a: MoE All-to-All communication (requires mpirun, Blackwell SM10.0+ with MNNVL)
- triton: Triton reference kernels (used for Mamba selective_state_update)
11 changes: 10 additions & 1 deletion benchmarks/flashinfer_benchmark.py
@@ -52,6 +52,10 @@ def run_test(args):
from routines.rope import run_rope_test

res = run_rope_test(args)
elif args.routine in benchmark_apis["mamba"]:
from routines.mamba import run_mamba_test

res = run_mamba_test(args)
else:
raise ValueError(f"Unsupported routine: {args.routine}")

@@ -99,7 +103,8 @@ def parse_args(line=sys.argv[1:]):
+ list(benchmark_apis["norm"])
+ list(benchmark_apis["quantization"])
+ list(benchmark_apis["sampling"])
+ list(benchmark_apis["rope"]),
+ list(benchmark_apis["rope"])
+ list(benchmark_apis["mamba"]),
)
args, _ = parser.parse_known_args(line[:])

@@ -217,6 +222,10 @@ def parse_args(line=sys.argv[1:]):
from routines.rope import parse_rope_args

args = parse_rope_args(line, parser)
elif args.routine in benchmark_apis["mamba"]:
from routines.mamba import parse_mamba_args

args = parse_mamba_args(line, parser)
else:
raise ValueError(f"Unsupported routine: {args.routine}")

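The dispatch code above only assumes that `routines/mamba.py` exposes `parse_mamba_args(line, parser)` and `run_mamba_test(args)`, mirroring the other routine modules. A skeleton consistent with those imports might look like the following; the flag registration is reconstructed from the README table, and the defaults and validation shown here are assumptions, not the module's actual contents.

```python
# Hypothetical skeleton of benchmarks/routines/mamba.py, inferred from the
# imports in flashinfer_benchmark.py; the real module's internals may differ.

def parse_mamba_args(line, parser):
    # Register the Mamba-specific flags documented in the README table.
    parser.add_argument("--nheads", type=int, required=True)
    parser.add_argument("--dim", type=int, required=True)
    parser.add_argument("--dstate", type=int, required=True)
    parser.add_argument("--ngroups", type=int, default=8)
    parser.add_argument("--cache_steps", type=int, default=0)
    parser.add_argument("--input_dtype", default="bfloat16",
                        choices=["bfloat16"])
    parser.add_argument("--state_dtype", default="bfloat16",
                        choices=["bfloat16", "float16", "float32"])
    parser.add_argument("--weight_dtype", default="float32",
                        choices=["float32", "bfloat16"])
    parser.add_argument("--has_z", action="store_true")
    parser.add_argument("--dt_softplus", action="store_true")
    args, _ = parser.parse_known_args(line)
    if args.nheads % args.ngroups != 0:
        raise ValueError("nheads must be divisible by ngroups")
    return args

def run_mamba_test(args):
    # Build inputs, run each requested backend, time it, and return result
    # rows (one per backend) for the benchmark's output table.
    ...
```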
27 changes: 27 additions & 0 deletions benchmarks/routines/flashinfer_benchmark_utils.py
@@ -102,6 +102,17 @@
"interleave",
"kv_layout",
],
"mamba": [
"nheads",
"dim",
"dstate",
"ngroups",
"cache_steps",
"state_dtype",
"weight_dtype",
"has_z",
"dt_softplus",
],
"general": [
"batch_size",
"hidden_size",
@@ -136,6 +147,7 @@
+ output_column_dict["quantization"]
+ output_column_dict["sampling"]
+ output_column_dict["rope"]
+ output_column_dict["mamba"]
+ output_column_dict["general"]
)

@@ -202,6 +214,9 @@
"rope_quantize_fp8",
"rope_quantize_fp8_append_paged_kv_cache",
],
"mamba": [
"selective_state_update",
],
}


@@ -708,6 +723,18 @@ def dtype_str_to_torch_dtype(dtype_str):
"10.3": ["cuda"],
"12.0": ["cuda"],
},
# MAMBA
"selective_state_update": {
"7.5": ["flashinfer", "triton"],
"8.0": ["flashinfer", "triton"],
"8.6": ["flashinfer", "triton"],
"8.9": ["flashinfer", "triton"],
"9.0": ["flashinfer", "triton"],
"10.0": ["flashinfer", "triton"],
"10.3": ["flashinfer", "triton"],
"11.0": ["flashinfer", "triton"],
"12.0": ["flashinfer", "triton"],
},
}

