benchmarks: Add microbenchmark support for Mamba selective_state_update #2512
yzh119 merged 5 commits into flashinfer-ai:main from
Conversation
📝 Walkthrough

Adds a Mamba selective_state_update benchmark: documentation and sample tests, registers a new …

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as User
    participant Runner as Benchmark Runner\n(benchmarks/flashinfer_benchmark.py)
    participant Router as Mamba Runner\n(benchmarks/routines/mamba.py)
    participant Backend as Compute Backend\n(FlashInfer / Triton)
    participant Ref as Triton Reference\nKernel (tests/mamba/selective_state_update_triton.py)
    User->>Runner: invoke CLI (parse_mamba_args)
    Runner->>Router: run_mamba_test(args)
    Router->>Router: prepare inputs (state, x, dt, A, B, C, ...)
    alt reference check enabled
        Router->>Ref: selective_state_update_triton_reference(...)
        Ref-->>Router: reference result
    end
    Router->>Backend: run selective_state_update on backend(s)
    Backend-->>Router: outputs + timings
    alt verification enabled
        Router->>Router: compare outputs -> handle mismatches
    end
    Router->>Router: compute metrics (median time, TFLOPs, TB/s)
    Router-->>Runner: return perf results
    Runner-->>User: display report
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related issues
Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces comprehensive benchmarking capabilities for the Mamba selective_state_update kernel.

Highlights
Changelog
Activity
Code Review
This pull request adds comprehensive microbenchmark support for the Mamba selective_state_update kernel to the benchmarking framework. The changes include updates to the main benchmark script, utility files, and documentation to integrate the new Mamba routine. A new file, benchmarks/routines/mamba.py, contains the core benchmarking logic, including a Triton reference implementation for correctness checking. The PR also adds several sample test cases. The implementation is robust, covering both single-token (STP) and multi-token (MTP) prediction modes, and the performance metric calculations are sound. I have one minor suggestion to improve the readability of the Triton reference kernel. Overall, this is a solid contribution.
benchmarks/routines/mamba.py (Outdated)
```python
current_step_idx = 0
for _ in range(T):
    if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
        if current_step_idx != 0 and cache_idx >= 0:
            parent_ptr = (
                retrieve_parent_token_ptr
                + pid_b * stride_retrieve_parent_token_batch
                + current_step_idx * stride_retrieve_parent_token_T
            )
            parent_step_idx = tl.load(parent_ptr).to(tl.int32)

            if parent_step_idx >= 0 and parent_step_idx < T:
                step_offset = parent_step_idx * nheads * dim * dstate
                cache_ptr = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * nheads * dim * dstate
                    + step_offset
                    + pid_h * dim * dstate
                    + offs_m[:, None] * dstate
                    + offs_n[None, :]
                )
                state = tl.load(cache_ptr, mask=mask, other=0.0).to(tl.float32)

    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(
            A_ptrs,
            mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt)  # scalar, not a matrix

    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
    state = state * dA + dB * x[:, None]

    if CACHE_INTERMEDIATE_STATES:
        if state_batch_idx != pad_slot_id:
            cache_ptr_base = (
                intermediate_states_buffer
                + cache_idx * cache_steps * nheads * dim * dstate
                + current_step_idx * nheads * dim * dstate
                + pid_h * dim * dstate
            )
            cache_ptrs = cache_ptr_base + (
                offs_m[:, None] * dstate + offs_n[None, :]
            )
            tl.store(cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask)

    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)

    current_step_idx += 1  # noqa: SIM113

    x_ptr += stride_x_T
    dt_ptr += stride_dt_T
    B_ptr += stride_B_T
    C_ptr += stride_C_T
    out_ptr += stride_out_T
    if HAS_Z:
        z_ptr += stride_z_T
```
The manual increment of current_step_idx inside the for _ in range(T): loop can be simplified by using the loop variable directly. This improves readability and is a more idiomatic way to write such loops in Python and Triton. The noqa: SIM113 comment indicates awareness of this, but a refactor would still be beneficial for clarity.
```python
for current_step_idx in range(T):
    if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
        if current_step_idx != 0 and cache_idx >= 0:
            parent_ptr = (
                retrieve_parent_token_ptr
                + pid_b * stride_retrieve_parent_token_batch
                + current_step_idx * stride_retrieve_parent_token_T
            )
            parent_step_idx = tl.load(parent_ptr).to(tl.int32)

            if parent_step_idx >= 0 and parent_step_idx < T:
                step_offset = parent_step_idx * nheads * dim * dstate
                cache_ptr = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * nheads * dim * dstate
                    + step_offset
                    + pid_h * dim * dstate
                    + offs_m[:, None] * dstate
                    + offs_n[None, :]
                )
                state = tl.load(cache_ptr, mask=mask, other=0.0).to(tl.float32)

    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(
            A_ptrs,
            mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt)  # scalar, not a matrix

    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
    state = state * dA + dB * x[:, None]

    if CACHE_INTERMEDIATE_STATES:
        if state_batch_idx != pad_slot_id:
            cache_ptr_base = (
                intermediate_states_buffer
                + cache_idx * cache_steps * nheads * dim * dstate
                + current_step_idx * nheads * dim * dstate
                + pid_h * dim * dstate
            )
            cache_ptrs = cache_ptr_base + (
                offs_m[:, None] * dstate + offs_n[None, :]
            )
            tl.store(cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask)

    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)

    x_ptr += stride_x_T
    dt_ptr += stride_dt_T
    B_ptr += stride_B_T
    C_ptr += stride_C_T
    out_ptr += stride_out_T
    if HAS_Z:
        z_ptr += stride_z_T
```
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@benchmarks/README.md`:
- Around line 384-398: Insert a blank line immediately before the Markdown table
that follows the "### Mamba Flags" heading so the table is separated from the
preceding text; modify the README content around the "### Mamba Flags" section
to ensure there is an empty line between the heading (or prior paragraph) and
the table start to satisfy Markdownlint MD058 and proper rendering.
In `@benchmarks/routines/flashinfer_benchmark_utils.py`:
- Around line 726-737: The "selective_state_update" mapping in
routine_cc_to_supported_backends contains an inconsistent compute capability key
"11.0"; remove the "11.0" entry (or if you intended to support a real CC, add it
consistently across all routine_cc_to_supported_backends mappings and the README
backend matrix) so the set of keys matches the rest ({7.5, 8.0, 8.6, 8.9, 9.0,
10.0, 10.3, 12.0}); update the "selective_state_update" dict to drop the "11.0"
key (or add corresponding keys elsewhere) to restore consistency.
🧹 Nitpick comments (2)
benchmarks/routines/flashinfer_benchmark_utils.py (1)
105-115: Duplicate "weight_dtype" column across output categories.
"weight_dtype"appears in bothoutput_column_dict["moe"](Line 54) andoutput_column_dict["mamba"](Line 112). Sincefull_output_columnsis a flat concatenation, the CSV header will have two"weight_dtype"columns. Both will resolve to the samecur_res["weight_dtype"]value.This won't crash, but it produces a confusing CSV with duplicate column names. Consider renaming the mamba column to e.g.
"mamba_weight_dtype", or moving shared fields likeweight_dtypeto the"general"category.Also applies to: 150-150
benchmarks/routines/mamba.py (1)
993-1002: Use `args.state_dtype` / `args.weight_dtype` strings instead of `str(torch_dtype)` for consistent CSV formatting.

`str(state_dtype)` produces `"torch.bfloat16"` while other dtype columns (e.g. `input_dtype` from `args`) are stored as `"bfloat16"`. This creates an inconsistency in the CSV output. Proposed fix:

```diff
- cur_res["state_dtype"] = str(state_dtype)
- cur_res["weight_dtype"] = str(weight_dtype)
+ cur_res["state_dtype"] = args.state_dtype
+ cur_res["weight_dtype"] = args.weight_dtype
```
benchmarks/routines/mamba.py
Outdated
```python
    }
)
@triton.jit(do_not_specialize=["T"])
def _selective_scan_update_kernel(
```
[nit] To avoid confusion, maybe add a _reference suffix?
This is a good suggestion. Done in the latest commit! 1bf361d
Actionable comments posted: 1
🧹 Nitpick comments (6)
benchmarks/routines/mamba.py (6)
64-323: Triton kernel implementation looks correct. The selective scan update kernel follows standard Mamba SSM patterns. A couple of minor static-analysis notes from the kernel body:

Line 118: `batch` unused (Ruff ARG001). This is a false positive; `batch` is part of the Triton kernel's parameter interface and the batch dimension is handled via `tl.program_id(axis=1)` / grid dispatch. No action needed.

Line 312: stale `# noqa: SIM113` (Ruff RUF100). The directive references a non-enabled rule and is safe to remove. Cleanup for line 312:

```diff
- current_step_idx += 1  # noqa: SIM113
+ current_step_idx += 1
```
424-440: Readability: lambda grid assignment and deeply nested ternary. Two minor style items flagged by static analysis and readability review:

- Line 424 (Ruff E731): prefer a `def` over a `lambda` assignment.
- Lines 432-439: the four-level nested ternary for `BLOCK_SIZE_M` / `num_warps` is hard to follow at a glance.

Suggested cleanup:

```diff
-    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
+    def grid(META):
+        return (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)

-    BLOCK_SIZE_M, num_warps = (
-        (32, 4)
-        if dstate <= 16
-        else (
-            (16, 4)
-            if dstate <= 32
-            else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
-        )
-    )
+    if dstate <= 16:
+        BLOCK_SIZE_M, num_warps = 32, 4
+    elif dstate <= 32:
+        BLOCK_SIZE_M, num_warps = 16, 4
+    elif dstate <= 64:
+        BLOCK_SIZE_M, num_warps = 8, 4
+    elif dstate <= 128:
+        BLOCK_SIZE_M, num_warps = 4, 4
+    else:
+        BLOCK_SIZE_M, num_warps = 4, 8
```
546-660: Argument parsing and validation look solid. The `nheads % ngroups` divisibility and supported-ratio checks are good guardrails. One minor note: `supported_ratios = [1, 8, 16]` at line 650 is hardcoded. If the FlashInfer CUDA kernel adds support for new ratios in the future, this will silently reject valid configurations. Consider adding a comment noting this mirrors a specific kernel constraint so future maintainers know where to update.
806-838: Closure captures many variables from the enclosing scope; works, but worth noting.

`run_backend` closes over `z`, `dt_bias`, `dt_softplus`, `slot_idx`, `cache_steps`, and `triton_cache_steps`. This is fine for benchmarking, but note that `intermediate_states_buffer` is never allocated or passed to either backend, so MTP intermediate-state caching is not exercised by this benchmark. If that's intentional (benchmarking the core SSM update path only), a brief comment clarifying the omission would help future readers.
871-888: `state_cache` is mutated in-place across backends by `bench_gpu_time`.

The loop passes the same `state_cache` tensor to `bench_gpu_time` for every backend. After the first backend's warm-up and measurement iterations, `state_cache` contains different values for the second backend's run. This doesn't affect correctness of refcheck (captured earlier from `clean_state_snapshot` clones), and for timing the compute pattern is identical regardless of state values, so results are valid.

Still, if a future change introduces a data-dependent fast path or NaN-propagation concern, this could silently skew timings. A defensive one-liner to restore state before each backend's bench run would be cheap insurance:

Optional defensive clone

```diff
 for cur_backend in backends:
+    if clean_state_snapshot is not None:
+        state_cache.copy_(clean_state_snapshot)
     if run_refcheck and cur_backend != "triton":
```
983-1002: `defaultdict(str)` is unnecessary here.

All keys are explicitly assigned, so a plain `dict()` (or `{}`) would be clearer and avoid silently returning `""` for typo'd keys during downstream consumption. Minor simplification:

```diff
- cur_res = defaultdict(str)
+ cur_res = {}
```
```python
read_bytes = (
    batch_size * nheads * dim * dstate * state_dtype.itemsize  # state
    + batch_size * T_val * nheads * dim * input_dtype.itemsize  # x
    + batch_size * T_val * nheads * weight_dtype.itemsize  # dt (broadcast)
    + nheads * 4  # A (float32, broadcast)
    + batch_size * T_val * ngroups * dstate * input_dtype.itemsize  # B
    + batch_size * T_val * ngroups * dstate * input_dtype.itemsize  # C
    + nheads * weight_dtype.itemsize  # D (broadcast)
    + nheads * weight_dtype.itemsize  # dt_bias (broadcast)
)
if has_z:
    read_bytes += batch_size * T_val * nheads * dim * input_dtype.itemsize

write_bytes = (
    batch_size * nheads * dim * dstate * state_dtype.itemsize  # state
    + batch_size * T_val * nheads * dim * input_dtype.itemsize  # output
)

problem_bytes = read_bytes + write_bytes

# FLOPs estimate (TIE_HDIM case, where dt/A/D/dt_bias broadcast across dim):
# Per (dim, dstate) element per (batch, T, head):
#   state * dA: 1 mul, dB * x[:, None]: 1 mul, state + ...: 1 add,
#   state * C[None, :]: 1 mul, sum reduction: ~1 add => 5 FLOPs/element
problem_flops = batch_size * T_val * nheads * dim * dstate * 5
tflops = problem_flops / (10**9 * median_time)  # TFLOPs/sec
tb_per_sec = problem_bytes / (10**9 * median_time)  # TB/sec
```
🧩 Analysis chain
🏁 Scripts executed against flashinfer-ai/flashinfer:

```shell
# Verify bench_gpu_time return value documentation / units
rg -n -C5 "def bench_gpu_time" --type=py

# Verify torch.dtype.itemsize usage in the repo
rg -n "\.itemsize" --type=py -C2

# Check PyTorch version requirements in the repo
find . -maxdepth 2 -type f \( -name "requirements*.txt" -o -name "setup.py" -o -name "setup.cfg" -o -name "pyproject.toml" -o -name "*.yml" -o -name "*.yaml" \) | head -20

# Get full bench_gpu_time function definition
sed -n '1508,1550p' flashinfer/testing/utils.py

# Check PyTorch version requirements
cat pyproject.toml | grep -A 20 "dependencies"

# Also check requirements.txt
cat requirements.txt | grep -i torch

# Get full requirements.txt to see all dependencies and versions
cat requirements.txt

# Check CI configuration for PyTorch version requirements
find . -maxdepth 3 -type f \( -name "*.yml" -o -name "*.yaml" \) -exec grep -l "pytorch\|torch" {} \;

# Look for torch version constraints in setup.py if it exists
cat setup.py 2>/dev/null || echo "No setup.py found"

# Check CI workflow files for PyTorch version constraints
cat .github/workflows/nightly-release.yml | grep -i "torch\|pytorch" -A 2 -B 2

# Also check release.yml
cat .github/workflows/release.yml | grep -i "torch\|pytorch" -A 2 -B 2

# Search for any PyTorch version constraints or minimum version docs
rg -i "minimum.*torch|torch.*version|torch.*>=|>=.*torch" -A 3 -B 3 --type=py --type=md

# Check usage of itemsize in the new benchmark routine
rg -i "itemsize" --type=py benchmarks/routines/mamba.py -B 5 -A 5

# Look for the actual bench_gpu_time function to see full return type and timing backend
sed -n '1508,1600p' flashinfer/testing/utils.py

# Check if there are any torch version checks related to dtype.itemsize usage
rg "TorchVersion|torch_version|version" flashinfer/testing/utils.py | head -20

# Check for any README or documentation about minimum PyTorch version
rg "minimum.*torch|torch.*>=|>=.*torch|supported.*version" README.md CONTRIBUTING.md -i -A 2

# Check the full calc_bw_tflops function to see the time unit context
sed -n '740,760p' flashinfer/testing/utils.py

# Search for any itemsize-related comments or version checks
rg "itemsize" --type=py -B 3 -A 3 | grep -E "(itemsize|version|>=|<=)" -A 2 -B 2 | head -50

# Look for any PyTorch version checks or constraints in the codebase
rg "TorchVersion.*2\.|torch.*>=" --type=py -B 2 -A 2 | head -100

# Check PyTorch documentation or comments about itemsize availability
rg "itemsize" --type=py -B 5 | grep -E "(version|2\.|deprecated|available|since)" -i | head -30

# Look at the actual mamba.py file to confirm the exact context and verify the formulas
sed -n '945,980p' benchmarks/routines/mamba.py
```
Verified: `median_time` is in milliseconds; formulas are correct. `itemsize` availability is not explicitly documented.

- Time units confirmed: `bench_gpu_time` explicitly returns times in milliseconds (its docstring states "Per-iteration execution times in milliseconds"), and the math checks out: since `time_ms = 1000 * time_sec`, we have `10^9 * time_ms = 10^12 * time_sec`, so `flops / (10**9 * median_time)` correctly yields TFLOPs/sec when `median_time` is in milliseconds. The formulas are correct.

- `torch.dtype.itemsize` concern: the attribute is used extensively throughout the codebase (40+ instances across tests, benchmarks, and production code) without version guards. However, `requirements.txt` specifies only `torch` with no minimum version constraint. PyTorch added `dtype.itemsize` in version 2.0, but this is not explicitly documented as a minimum requirement. While the ubiquitous usage suggests it's expected to be available, the lack of an explicit version constraint in the project's dependency specification is a valid concern if supporting older PyTorch versions.
Overall, looks good. But I think you copied the contents of the Triton kernel into benchmarks/routines/mamba.py.
Thanks @ishovkun, I removed the duplicated Triton reference in the latest commit.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@benchmarks/routines/mamba.py`:
- Around line 219-226: Wrap the nheads/ngroups ratio validation in the same
backend guard used by the dim and dstate checks (i.e., only run when
"flashinfer" is in args.backends); specifically, move or wrap the
supported_ratios/ratio check so it is executed only if "flashinfer" in
args.backends, leaving the existing variables supported_ratios, ratio,
args.nheads and args.ngroups unchanged and keeping the same ValueError message
when triggered.
- Around line 553-572: The result dict population in the if args.output_path
block (where cur_res is built and keys like "nheads", "dim", "state_dtype" are
set) is missing batch_size and input_dtype; update that block to add
cur_res["batch_size"] = batch_size and cur_res["input_dtype"] = str(input_dtype)
(or the existing variable's string representation) alongside the other
Mamba-specific columns so outputs include the batch size and input dtype for
reproducibility.
🧹 Nitpick comments (1)
benchmarks/routines/mamba.py (1)
74-74: Eager module-level import will break `flashinfer`-only users if the Triton reference file is missing or Triton is not installed.

`_import_triton_reference()` runs unconditionally at import time. If someone only wants the `flashinfer` backend, this still fails the entire module load when Triton or the reference file is absent.

Consider lazy-loading: call `_import_triton_reference()` only when the `triton` backend is actually requested (e.g., inside `run_backend` or at the top of `testSelectiveStateUpdate` when `"triton" in backends or run_refcheck`). Suggested approach:

```diff
-selective_state_update_triton_reference = _import_triton_reference()
+selective_state_update_triton_reference = None
+
+def _get_triton_reference():
+    global selective_state_update_triton_reference
+    if selective_state_update_triton_reference is None:
+        selective_state_update_triton_reference = _import_triton_reference()
+    return selective_state_update_triton_reference
```

Then replace usages of `selective_state_update_triton_reference(...)` with `_get_triton_reference()(...)`.
```python
# Validate nheads/ngroups ratio is supported by the CUDA kernel
supported_ratios = [1, 8, 16]
ratio = args.nheads // args.ngroups
if ratio not in supported_ratios:
    raise ValueError(
        f"nheads/ngroups ratio ({ratio} = {args.nheads}/{args.ngroups}) is not supported by the FlashInfer kernel. "
        f"Supported ratios: {supported_ratios}."
    )
```
nheads/ngroups ratio validation is not gated on flashinfer backend, unlike dim and dstate checks.
Lines 203 and 212 correctly gate their validation with `if "flashinfer" in args.backends`, but this ratio check applies unconditionally. A user running `--backends triton` with a ratio not in `[1, 8, 16]` would be incorrectly rejected.
Proposed fix
```diff
 supported_ratios = [1, 8, 16]
 ratio = args.nheads // args.ngroups
-if ratio not in supported_ratios:
+if "flashinfer" in args.backends and ratio not in supported_ratios:
     raise ValueError(
         f"nheads/ngroups ratio ({ratio} = {args.nheads}/{args.ngroups}) is not supported by the FlashInfer kernel. "
         f"Supported ratios: {supported_ratios}."
     )
```
🧰 Tools
🪛 Ruff (0.15.0)
[warning] 223-226: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@benchmarks/routines/mamba.py` around lines 219 - 226, Wrap the nheads/ngroups
ratio validation in the same backend guard used by the dim and dstate checks
(i.e., only run when "flashinfer" is in args.backends); specifically, move or
wrap the supported_ratios/ratio check so it is executed only if "flashinfer" in
args.backends, leaving the existing variables supported_ratios, ratio,
args.nheads and args.ngroups unchanged and keeping the same ValueError message
when triggered.
This is fine because Triton is a reference. We would always be benchmarking FlashInfer as the main backend.
@bkryu, understood! If FlashInfer is always the primary target, the current validation makes sense for your workflow.
Just to note: since the code does accept --backends triton without flashinfer, a user running only Triton with an unsupported ratio would see "is not supported by the FlashInfer kernel" in the error message, which could be a bit confusing. But if that's not a practical use case for you, the current approach is fine.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@benchmarks/routines/mamba.py`:
- Line 74: The module currently performs an eager import by assigning
selective_state_update_triton_reference = _import_triton_reference() which
causes ImportError for users who don't have the Triton test file; change this to
a lazy loader by implementing a _get_triton_reference() function that calls
_import_triton_reference() on first use and caches the result (e.g., store in a
module-level variable), then remove the eager assignment and update all call
sites to invoke _get_triton_reference()(...) instead of
selective_state_update_triton_reference(...); ensure the new loader handles
ImportError gracefully so non-Triton backends continue to work.
🧹 Nitpick comments (1)
benchmarks/routines/mamba.py (1)
233-233: Nit:testSelectiveStateUpdateuses camelCase.Python convention (PEP 8) favors
snake_casefor function names. Considertest_selective_state_updatefor consistency.
```python
    return module.selective_state_update_triton


selective_state_update_triton_reference = _import_triton_reference()
```
Module-level eager import will fail even when Triton is not requested.
_import_triton_reference() runs at import time, so any user who only needs --backends flashinfer will still get an ImportError if tests/mamba/selective_state_update_triton.py is absent (e.g., in a packaged/installed environment without the test tree). Consider lazy-loading:
Proposed fix

```diff
-selective_state_update_triton_reference = _import_triton_reference()
+selective_state_update_triton_reference = None
+
+def _get_triton_reference():
+    global selective_state_update_triton_reference
+    if selective_state_update_triton_reference is None:
+        selective_state_update_triton_reference = _import_triton_reference()
+    return selective_state_update_triton_reference
```

Then replace usages of `selective_state_update_triton_reference(...)` with `_get_triton_reference()(...)`.
🤖 Prompt for AI Agents
In `@benchmarks/routines/mamba.py` at line 74, The module currently performs an
eager import by assigning selective_state_update_triton_reference =
_import_triton_reference() which causes ImportError for users who don't have the
Triton test file; change this to a lazy loader by implementing a
_get_triton_reference() function that calls _import_triton_reference() on first
use and caches the result (e.g., store in a module-level variable), then remove
the eager assignment and update all call sites to invoke
_get_triton_reference()(...) instead of
selective_state_update_triton_reference(...); ensure the new loader handles
ImportError gracefully so non-Triton backends continue to work.
📌 Description
- Adds `selective_state_update` support to `flashinfer_benchmark.py`, covering both single-token prediction (STP) and multi-token prediction (MTP) modes.
- Backends: `flashinfer` (architecture-specific CUDA kernels for base/SM90/SM100+) and `triton` (reference implementation, used for correctness checking).
- Updates `README.md` with Mamba API documentation, CLI flags, and backend support matrix. Adds 11 sample test cases to `sample_testlist.txt`.

cc @ishovkun
🔍 Related Issues
#2513
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Documentation
Samples