
benchmarks: Add microbenchmark support for Mamba selective_state_update #2512

Merged
yzh119 merged 5 commits into flashinfer-ai:main from bkryu:microbench_mamba
Feb 16, 2026

Conversation


@bkryu bkryu commented Feb 6, 2026

📌 Description

  • Adds benchmarking support for the Mamba selective_state_update kernel to flashinfer_benchmark.py, covering both single-token prediction (STP) and multi-token prediction (MTP) modes.
  • Supports two backends: flashinfer (architecture-specific CUDA kernels for base/SM90/SM100+) and triton (reference implementation, used for correctness checking).
  • Updates README.md with Mamba API documentation, CLI flags, and backend support matrix. Adds 11 sample test cases to sample_testlist.txt.
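For reviewers unfamiliar with the op, a minimal dense NumPy sketch of the single-step recurrence the kernel implements (the shapes and argument layout here are illustrative assumptions, not the FlashInfer API signature):

```python
import numpy as np

def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None):
    """One decode step of the Mamba SSM update (dense reference sketch).

    Assumed illustrative shapes:
      state: (batch, nheads, dim, dstate), updated in place
      x, dt: (batch, nheads, dim)
      A:     (nheads, dim, dstate);  B, C: (batch, ngroups, dstate)
      D:     (nheads, dim) optional; z: like x, optional (SiLU gating)
    """
    nheads, ngroups = state.shape[1], B.shape[1]
    # Each group of heads shares one B/C projection
    B = np.repeat(B, nheads // ngroups, axis=1)
    C = np.repeat(C, nheads // ngroups, axis=1)
    dA = np.exp(A[None] * dt[..., None])       # discretized decay
    dB = B[:, :, None, :] * dt[..., None]      # discretized input matrix
    state[:] = state * dA + dB * x[..., None]  # h_t = dA * h_{t-1} + dB * x_t
    out = (state * C[:, :, None, :]).sum(-1)   # y_t = C . h_t
    if D is not None:
        out += x * D[None]                     # skip connection
    if z is not None:
        out *= z / (1.0 + np.exp(-z))          # SiLU gating (has_z)
    return out
```

MTP mode repeats this update for `cache_steps` tokens while caching intermediate states; the sketch above covers only the single-token (STP) step.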
## Example on B200:
$ python3 flashinfer_benchmark.py --routine selective_state_update --batch_size 64 --nheads 64 --dim 128 --dstate 128 --ngroups 8 --input_dtype bfloat16 --backends flashinfer triton --refcheck -v --generate_repro_command --case_tag "mamba2_stp_bf16"
[INFO] args = Namespace(routine='selective_state_update', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=1, output_path=None, num_iters=30, dry_run_iters=5, case_tag='mamba2_stp_bf16', generate_repro_command=True, repro_command='', batch_size=64, nheads=64, dim=128, dstate=128, ngroups=8, cache_steps=0, input_dtype='bfloat16', state_dtype='bfloat16', weight_dtype='float32', has_z=False, dt_softplus=False, backends=['flashinfer', 'triton'])
[INFO] Running testSelectiveStateUpdate
[INFO] FlashInfer version: 0.6.2
[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine selective_state_update --batch_size 64 --nheads 64 --dim 128 --dstate 128 --ngroups 8 --input_dtype bfloat16 --backends flashinfer triton --refcheck -v --generate_repro_command --case_tag mamba2_stp_bf16
[REFCHECK] Backend flashinfer: PASSED (1/524288 elements differ (0.0002%), within 0.01% threshold)
[PERF] flashinfer     :: median time 0.050 ms; std 0.000 ms; achieved tflops 6.752 TFLOPs/sec; achieved tb_per_sec 5.449 TB/sec
[PERF] triton         :: median time 0.157 ms; std 0.001 ms; achieved tflops 2.131 TFLOPs/sec; achieved tb_per_sec 1.720 TB/sec
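The refcheck line above reports a mismatch fraction rather than failing on a single outlier. A hypothetical helper illustrating that policy (the benchmark's actual tolerance logic may differ):

```python
import numpy as np

def mismatch_fraction(ref, out, atol=1e-2, rtol=1e-2):
    """Fraction of elements outside tolerance; a refcheck passes when this
    fraction stays under a small threshold (e.g. the 0.01% shown above)."""
    bad = ~np.isclose(np.asarray(out), np.asarray(ref), atol=atol, rtol=rtol)
    return bad.sum() / bad.size
```

With 1 differing element out of 524288, the fraction is roughly 0.0002%, well under a 0.01% threshold, so the check passes.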

## Same case on H200:
$ python3 flashinfer_benchmark.py --routine selective_state_update --batch_size 64 --nheads 64 --dim 128 --dstate 128 --ngroups 8 --input_dtype bfloat16 --backends flashinfer triton --refcheck -v --generate_repro_command --case_tag "mamba2_stp_bf16"
[INFO] args = Namespace(routine='selective_state_update', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=1, output_path=None, num_iters=30, dry_run_iters=5, case_tag='mamba2_stp_bf16', generate_repro_command=True, repro_command='', batch_size=64, nheads=64, dim=128, dstate=128, ngroups=8, cache_steps=0, input_dtype='bfloat16', state_dtype='bfloat16', weight_dtype='float32', has_z=False, dt_softplus=False, backends=['flashinfer', 'triton'])
[INFO] Running testSelectiveStateUpdate
[INFO] FlashInfer version: 0.6.2
[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine selective_state_update --batch_size 64 --nheads 64 --dim 128 --dstate 128 --ngroups 8 --input_dtype bfloat16 --backends flashinfer triton --refcheck -v --generate_repro_command --case_tag mamba2_stp_bf16
[REFCHECK] Backend flashinfer: PASSED (all 524288 elements match)
[PERF] flashinfer     :: median time 0.073 ms; std 0.001 ms; achieved tflops 4.589 TFLOPs/sec; achieved tb_per_sec 3.704 TB/sec
[PERF] triton         :: median time 0.175 ms; std 0.000 ms; achieved tflops 1.918 TFLOPs/sec; achieved tb_per_sec 1.548 TB/sec

cc @ishovkun

🔍 Related Issues

#2513

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added Mamba selective_state_update benchmarks with single-/multi-token modes, head/group variations, dtype and backend selection (FlashInfer, Triton), optional gating and softplus, reference checks, and performance metrics.
  • Documentation

    • Expanded Overview, Quick Start, flags, and backend support matrix to include Mamba API and Mamba-specific flags.
  • Samples

    • Added sample test cases covering diverse Mamba configurations and perf scenarios.


coderabbitai bot commented Feb 6, 2026

📝 Walkthrough

Walkthrough

Adds a Mamba selective_state_update benchmark: documentation and sample tests, registers a new mamba routine in the benchmark harness, extends benchmark utilities with Mamba schema and backend mappings, and implements a Triton reference kernel plus an end-to-end benchmarking harness and CLI parsing for Mamba.

Changes

Cohort / File(s) Summary
Docs & Samples
benchmarks/README.md, benchmarks/samples/sample_testlist.txt
Add Mamba documentation and Flags (FP8/FP4, dt_softplus), update backend legend, and insert numerous selective_state_update sample tests (STP/MTP, dtype/backends, cache_steps).
Benchmark Harness & Utils
benchmarks/flashinfer_benchmark.py, benchmarks/routines/flashinfer_benchmark_utils.py
Register mamba routine in CLI and run routing; add Mamba output columns; map selective_state_update to flashinfer and triton across CUDA versions and update dtype→backend mappings and supported-backend table.
Mamba Kernel & Benchmark Implementation
benchmarks/routines/mamba.py, tests/mamba/selective_state_update_triton.py
Add Triton reference import; implement parse_mamba_args, run_mamba_test, testSelectiveStateUpdate; generate synthetic inputs, support flashinfer/triton runs, optional reference validation, STP/MTP handling, timing and TFLOPs/TB/s metrics, and CUDA-graph/CUPTI options.
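The CLI surface described above can be sketched as a plain argparse parser. This is a hypothetical approximation of `parse_mamba_args` (flag names mirror the example command in the PR description; defaults and validation are guesses, the real implementation in benchmarks/routines/mamba.py may differ):

```python
import argparse

def parse_mamba_args(line=None):
    # Hypothetical sketch; mirrors the flags shown in the example command.
    p = argparse.ArgumentParser("selective_state_update benchmark flags")
    p.add_argument("--routine", choices=["selective_state_update"], required=True)
    p.add_argument("--batch_size", type=int, required=True)
    p.add_argument("--nheads", type=int, required=True)
    p.add_argument("--dim", type=int, required=True)
    p.add_argument("--dstate", type=int, required=True)
    p.add_argument("--ngroups", type=int, default=1)
    p.add_argument("--cache_steps", type=int, default=0)  # 0 = STP, >0 = MTP
    p.add_argument("--input_dtype", default="bfloat16")
    p.add_argument("--state_dtype", default="bfloat16")
    p.add_argument("--weight_dtype", default="float32")
    p.add_argument("--has_z", action="store_true")
    p.add_argument("--dt_softplus", action="store_true")
    p.add_argument("--backends", nargs="+", default=["flashinfer"])
    args = p.parse_args(line)
    if args.nheads % args.ngroups:  # group size must divide the head count
        p.error("--nheads must be divisible by --ngroups")
    return args
```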

Sequence Diagram(s)

sequenceDiagram
    participant User as User
    participant Runner as Benchmark Runner\n(`benchmarks/flashinfer_benchmark.py`)
    participant Router as Mamba Runner\n(`benchmarks/routines/mamba.py`)
    participant Backend as Compute Backend\n(FlashInfer / Triton)
    participant Ref as Triton Reference\nKernel (`tests/mamba/selective_state_update_triton.py`)

    User->>Runner: invoke CLI (parse_mamba_args)
    Runner->>Router: run_mamba_test(args)
    Router->>Router: prepare inputs (state, x, dt, A,B,C,...)
    alt reference check enabled
        Router->>Ref: selective_state_update_triton_reference(...)
        Ref-->>Router: reference result
    end
    Router->>Backend: run selective_state_update on backend(s)
    Backend-->>Router: outputs + timings
    alt verification enabled
        Router->>Router: compare outputs -> handle mismatches
    end
    Router->>Router: compute metrics (median time, TFLOPs, TB/s)
    Router-->>Runner: return perf results
    Runner-->>User: display report

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

benchmark, op: misc

Suggested reviewers

  • Anerudhan
  • yzh119
  • cyx-6
  • jiahanc
  • kahyunnam
  • nv-yunzheq

Poem

🐇 I hopped through kernels, found a Mamba tune,
I timed each hop beneath the silicon moon,
Triton and FlashInfer danced in my sight,
Metrics gleamed like carrots, measured just right,
Hop, bench, nibble — a rabbit's delight.

🚥 Pre-merge checks | ✅ 3 | ❌ 1
❌ Failed checks (1 warning)
Merge Conflict Detection — ⚠️ Warning: merge conflicts detected (38 files):

⚔️ benchmarks/README.md (content)
⚔️ benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (content)
⚔️ benchmarks/flashinfer_benchmark.py (content)
⚔️ benchmarks/routines/flashinfer_benchmark_utils.py (content)
⚔️ benchmarks/routines/gemm.py (content)
⚔️ benchmarks/samples/sample_testlist.txt (content)
⚔️ csrc/gdn_prefill_launcher.cu (content)
⚔️ csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (content)
⚔️ csrc/trtllm_fmha_kernel_launcher.cu (content)
⚔️ docker/Dockerfile.cu126 (content)
⚔️ docker/Dockerfile.cu128 (content)
⚔️ docker/Dockerfile.cu129 (content)
⚔️ docker/Dockerfile.cu130 (content)
⚔️ flashinfer/__init__.py (content)
⚔️ flashinfer/artifacts.py (content)
⚔️ flashinfer/cute_dsl/__init__.py (content)
⚔️ flashinfer/cute_dsl/utils.py (content)
⚔️ flashinfer/decode.py (content)
⚔️ flashinfer/fused_moe/__init__.py (content)
⚔️ flashinfer/fused_moe/core.py (content)
⚔️ flashinfer/gemm/__init__.py (content)
⚔️ flashinfer/gemm/gemm_base.py (content)
⚔️ flashinfer/gemm/kernels/grouped_gemm_masked_blackwell.py (content)
⚔️ flashinfer/jit/__init__.py (content)
⚔️ flashinfer/jit/gemm/__init__.py (content)
⚔️ flashinfer/jit/gemm/core.py (content)
⚔️ flashinfer/mla.py (content)
⚔️ flashinfer/prefill.py (content)
⚔️ flashinfer/triton/__init__.py (content)
⚔️ flashinfer/utils.py (content)
⚔️ include/flashinfer/trtllm/fmha/fmhaKernels.cuh (content)
⚔️ include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (content)
⚔️ include/flashinfer/trtllm/fmha/kernelParams.h (content)
⚔️ scripts/authorized_codeowner.txt (content)
⚔️ scripts/task_run_unit_tests.sh (content)
⚔️ scripts/test_utils.sh (content)
⚔️ tests/attention/test_trtllm_gen_attention.py (content)
⚔️ tests/gemm/test_bmm_fp8.py (content)

These conflicts must be resolved before merging into main.
Resolve conflicts locally and push changes to this branch.
✅ Passed checks (3 passed)
  • Title check — ✅ Passed: The title clearly and concisely describes the main change: adding microbenchmark support for Mamba selective_state_update to the benchmarks suite.
  • Description check — ✅ Passed: The PR description covers all required sections from the template: a detailed Description section explaining the changes and their purpose, Related Issues linking to #2513, and a completed Pull Request Checklist with pre-commit and tests marked as done.
  • Docstring Coverage — ✅ Passed: Docstring coverage is 85.71%, above the required threshold of 80.00%.


@gemini-code-assist
Contributor

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces comprehensive benchmarking capabilities for the Mamba selective_state_update operation within the flashinfer framework. It enables performance evaluation and correctness verification for Mamba layers across different backends and configurations, facilitating optimization and development of efficient state space models.

Highlights

  • Mamba Benchmarking Support: Benchmarking support for the Mamba selective_state_update kernel has been added to flashinfer_benchmark.py, covering both single-token prediction (STP) and multi-token prediction (MTP) modes.
  • Backend Integration: The new benchmarks support two backends: flashinfer (architecture-specific CUDA kernels for base/SM90/SM100+) for performance and triton (a reference implementation) for correctness checking.
  • Documentation Updates: The README.md has been updated to include Mamba API documentation, new command-line interface flags specific to Mamba, and an updated backend support matrix.
  • New Sample Test Cases: Eleven new sample test cases for Mamba selective_state_update have been added to sample_testlist.txt, covering various configurations and scenarios.


Changelog
  • benchmarks/README.md
    • Updated the overview to explicitly mention Mamba API performance benchmarking.
    • Added selective_state_update to the list of supported Mamba APIs, detailing its functionality and backend support for STP and MTP modes.
    • Introduced a new section for Mamba-specific CLI flags, including parameters like batch_size, nheads, dim, dstate, ngroups, cache_steps, input_dtype, state_dtype, weight_dtype, has_z, dt_softplus, and backends.
    • Updated the flashinfer_benchmark.py Routine & Backend Support Matrix to include selective_state_update with flashinfer and triton backends across various CUDA compute capabilities.
    • Added triton to the Backend Legend with its specific use case for Mamba selective_state_update.
  • benchmarks/flashinfer_benchmark.py
    • Modified the run_test function to dynamically import and execute Mamba routines via run_mamba_test when a Mamba API is specified.
    • Updated the parse_args function to include Mamba APIs in the list of supported routines and to call parse_mamba_args for handling Mamba-specific command-line arguments.
  • benchmarks/routines/flashinfer_benchmark_utils.py
    • Added a new mamba key to output_column_dict to define Mamba-specific output columns for benchmark results, such as nheads, dim, dstate, ngroups, cache_steps, state_dtype, weight_dtype, has_z, and dt_softplus.
    • Included the newly defined mamba columns in the all_output_columns list for comprehensive result reporting.
    • Added selective_state_update to the mamba category within benchmark_apis.
    • Extended the backend_support_matrix to include selective_state_update with flashinfer and triton backends for CUDA compute capabilities ranging from 7.5 to 12.0.
  • benchmarks/routines/mamba.py
    • New file added, containing the Triton reference implementation for selective_state_update, adapted from vllm-project/vllm.
    • Implemented run_mamba_test and parse_mamba_args functions to manage Mamba-specific benchmarking logic and argument parsing.
    • The testSelectiveStateUpdate function provides the core benchmarking logic, including random input tensor generation, execution across flashinfer and triton backends, optional reference checking against Triton, and calculation of performance metrics (TFLOPs/sec, TB/sec).
    • Supports both single-token prediction (STP) and multi-token prediction (MTP) modes for Mamba layer updates.
  • benchmarks/samples/sample_testlist.txt
    • Added 11 new test cases for the selective_state_update routine under a new 'Mamba (Selective State Space Models)' section.
    • New test cases cover various scenarios including Single-Token Prediction (STP) with bfloat16 and float32 state dtypes, STP with z gating and dt_softplus enabled (individually and combined).
    • Included STP tests with different nheads/ngroups ratios (1 and 16) to test kernel flexibility.
    • Added Multi-Token Prediction (MTP) tests with cache_steps of 1 and 2.
    • Incorporated a large batch size STP test and a FlashInfer-only performance-focused test case.
Activity
  • No human activity has been recorded on this pull request since its creation.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds comprehensive microbenchmark support for the Mamba selective_state_update kernel to the benchmarking framework. The changes include updates to the main benchmark script, utility files, and documentation to integrate the new Mamba routine. A new file, benchmarks/routines/mamba.py, contains the core benchmarking logic, including a Triton reference implementation for correctness checking. The PR also adds several sample test cases. The implementation is robust, covering both single-token (STP) and multi-token (MTP) prediction modes, and the performance metric calculations are sound. I have one minor suggestion to improve the readability of the Triton reference kernel. Overall, this is a solid contribution.

Comment on lines +229 to +320
    current_step_idx = 0
    for _ in range(T):
        if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
            if current_step_idx != 0 and cache_idx >= 0:
                parent_ptr = (
                    retrieve_parent_token_ptr
                    + pid_b * stride_retrieve_parent_token_batch
                    + current_step_idx * stride_retrieve_parent_token_T
                )
                parent_step_idx = tl.load(parent_ptr).to(tl.int32)

                if parent_step_idx >= 0 and parent_step_idx < T:
                    step_offset = parent_step_idx * nheads * dim * dstate
                    cache_ptr = (
                        intermediate_states_buffer
                        + cache_idx * cache_steps * nheads * dim * dstate
                        + step_offset
                        + pid_h * dim * dstate
                        + offs_m[:, None] * dstate
                        + offs_n[None, :]
                    )
                    state = tl.load(cache_ptr, mask=mask, other=0.0).to(tl.float32)

        x_ptrs = x_ptr + offs_m * stride_x_dim
        dt_ptrs = dt_ptr + offs_m * stride_dt_dim
        B_ptrs = B_ptr + offs_n * stride_B_dstate
        C_ptrs = C_ptr + offs_n * stride_C_dstate
        if HAS_Z:
            z_ptrs = z_ptr + offs_m * stride_z_dim
        out_ptrs = out_ptr + offs_m * stride_out_dim

        x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if not TIE_HDIM:
            dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
            if HAS_DT_BIAS:
                dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
            if DT_SOFTPLUS:
                dt = softplus(dt)
            A = tl.load(
                A_ptrs,
                mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
                other=0.0,
            ).to(tl.float32)
            dA = tl.exp(A * dt[:, None])
        else:
            dt = tl.load(dt_ptr).to(tl.float32)
            if HAS_DT_BIAS:
                dt += tl.load(dt_bias_ptr).to(tl.float32)
            if DT_SOFTPLUS:
                dt = softplus(dt)
            A = tl.load(A_ptr).to(tl.float32)
            dA = tl.exp(A * dt)  # scalar, not a matrix

        B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
        C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
        if HAS_D:
            D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_Z:
            z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

        dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
        state = state * dA + dB * x[:, None]

        if CACHE_INTERMEDIATE_STATES:
            if state_batch_idx != pad_slot_id:
                cache_ptr_base = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * nheads * dim * dstate
                    + current_step_idx * nheads * dim * dstate
                    + pid_h * dim * dstate
                )
                cache_ptrs = cache_ptr_base + (
                    offs_m[:, None] * dstate + offs_n[None, :]
                )
                tl.store(cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask)

        out = tl.sum(state * C[None, :], axis=1)
        if HAS_D:
            out += x * D
        if HAS_Z:
            out *= z * tl.sigmoid(z)
        tl.store(out_ptrs, out, mask=offs_m < dim)

        current_step_idx += 1  # noqa: SIM113

        x_ptr += stride_x_T
        dt_ptr += stride_dt_T
        B_ptr += stride_B_T
        C_ptr += stride_C_T
        out_ptr += stride_out_T
        if HAS_Z:
            z_ptr += stride_z_T

medium

The manual increment of current_step_idx inside the for _ in range(T): loop can be simplified by using the loop variable directly. This improves readability and is a more idiomatic way to write such loops in Python and Triton. The noqa: SIM113 comment indicates awareness of this, but a refactor would still be beneficial for clarity.

    for current_step_idx in range(T):
        if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
            if current_step_idx != 0 and cache_idx >= 0:
                parent_ptr = (
                    retrieve_parent_token_ptr
                    + pid_b * stride_retrieve_parent_token_batch
                    + current_step_idx * stride_retrieve_parent_token_T
                )
                parent_step_idx = tl.load(parent_ptr).to(tl.int32)

                if parent_step_idx >= 0 and parent_step_idx < T:
                    step_offset = parent_step_idx * nheads * dim * dstate
                    cache_ptr = (
                        intermediate_states_buffer
                        + cache_idx * cache_steps * nheads * dim * dstate
                        + step_offset
                        + pid_h * dim * dstate
                        + offs_m[:, None] * dstate
                        + offs_n[None, :]
                    )
                    state = tl.load(cache_ptr, mask=mask, other=0.0).to(tl.float32)

        x_ptrs = x_ptr + offs_m * stride_x_dim
        dt_ptrs = dt_ptr + offs_m * stride_dt_dim
        B_ptrs = B_ptr + offs_n * stride_B_dstate
        C_ptrs = C_ptr + offs_n * stride_C_dstate
        if HAS_Z:
            z_ptrs = z_ptr + offs_m * stride_z_dim
        out_ptrs = out_ptr + offs_m * stride_out_dim

        x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if not TIE_HDIM:
            dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
            if HAS_DT_BIAS:
                dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
            if DT_SOFTPLUS:
                dt = softplus(dt)
            A = tl.load(
                A_ptrs,
                mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
                other=0.0,
            ).to(tl.float32)
            dA = tl.exp(A * dt[:, None])
        else:
            dt = tl.load(dt_ptr).to(tl.float32)
            if HAS_DT_BIAS:
                dt += tl.load(dt_bias_ptr).to(tl.float32)
            if DT_SOFTPLUS:
                dt = softplus(dt)
            A = tl.load(A_ptr).to(tl.float32)
            dA = tl.exp(A * dt)  # scalar, not a matrix

        B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
        C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
        if HAS_D:
            D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_Z:
            z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

        dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
        state = state * dA + dB * x[:, None]

        if CACHE_INTERMEDIATE_STATES:
            if state_batch_idx != pad_slot_id:
                cache_ptr_base = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * nheads * dim * dstate
                    + current_step_idx * nheads * dim * dstate
                    + pid_h * dim * dstate
                )
                cache_ptrs = cache_ptr_base + (
                    offs_m[:, None] * dstate + offs_n[None, :]
                )
                tl.store(cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask)

        out = tl.sum(state * C[None, :], axis=1)
        if HAS_D:
            out += x * D
        if HAS_Z:
            out *= z * tl.sigmoid(z)
        tl.store(out_ptrs, out, mask=offs_m < dim)

        x_ptr += stride_x_T
        dt_ptr += stride_dt_T
        B_ptr += stride_B_T
        C_ptr += stride_C_T
        out_ptr += stride_out_T
        if HAS_Z:
            z_ptr += stride_z_T


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@benchmarks/README.md`:
- Around lines 384-398: Insert a blank line between the "### Mamba Flags" heading (or the preceding paragraph) and the Markdown table that follows it, so the table renders correctly and satisfies Markdownlint MD058.

In `@benchmarks/routines/flashinfer_benchmark_utils.py`:
- Around lines 726-737: The "selective_state_update" mapping in routine_cc_to_supported_backends contains an inconsistent compute-capability key "11.0". Drop that key so the key set matches the other mappings ({7.5, 8.0, 8.6, 8.9, 9.0, 10.0, 10.3, 12.0}); or, if a real compute capability was intended, add it consistently across all routine_cc_to_supported_backends mappings and the README backend matrix.
🧹 Nitpick comments (2)
benchmarks/routines/flashinfer_benchmark_utils.py (1)

105-115: Duplicate "weight_dtype" column across output categories.

"weight_dtype" appears in both output_column_dict["moe"] (Line 54) and output_column_dict["mamba"] (Line 112). Since full_output_columns is a flat concatenation, the CSV header will have two "weight_dtype" columns. Both will resolve to the same cur_res["weight_dtype"] value.

This won't crash, but it produces a confusing CSV with duplicate column names. Consider renaming the mamba column to e.g. "mamba_weight_dtype", or moving shared fields like weight_dtype to the "general" category.

Also applies to: 150-150

benchmarks/routines/mamba.py (1)

993-1002: Use args.state_dtype / args.weight_dtype strings instead of str(torch_dtype) for consistent CSV formatting.

str(state_dtype) produces "torch.bfloat16" while other dtype columns (e.g. input_dtype from args) are stored as "bfloat16". This creates an inconsistency in the CSV output.

Proposed fix
-                cur_res["state_dtype"] = str(state_dtype)
-                cur_res["weight_dtype"] = str(weight_dtype)
+                cur_res["state_dtype"] = args.state_dtype
+                cur_res["weight_dtype"] = args.weight_dtype

}
)
@triton.jit(do_not_specialize=["T"])
def _selective_scan_update_kernel(
Contributor


[nit] To avoid confusion, maybe add a _reference suffix?

Collaborator Author


This is a good suggestion. Done in the latest commit! 1bf361d


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (6)
benchmarks/routines/mamba.py (6)

64-323: Triton kernel implementation looks correct.

The selective scan update kernel follows standard Mamba SSM patterns. A couple of minor static-analysis notes from the kernel body:

  • Line 118 – batch unused (Ruff ARG001): This is a false positive. batch is part of the Triton kernel's parameter interface and the batch dimension is handled via tl.program_id(axis=1) / grid dispatch — no action needed.

  • Line 312 – stale # noqa: SIM113 (Ruff RUF100): The directive references a non-enabled rule. Safe to remove.

Cleanup for line 312
-        current_step_idx += 1  # noqa: SIM113
+        current_step_idx += 1

424-440: Readability: lambda grid assignment and deeply nested ternary.

Two minor style items flagged by static analysis and readability review:

  1. Line 424 (Ruff E731): Prefer a def over a lambda assignment.
  2. Lines 432-439: The four-level nested ternary for BLOCK_SIZE_M/num_warps is hard to follow at a glance.
Suggested cleanup
-    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
+    def grid(META):
+        return (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
 
-    BLOCK_SIZE_M, num_warps = (
-        (32, 4)
-        if dstate <= 16
-        else (
-            (16, 4)
-            if dstate <= 32
-            else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
-        )
-    )
+    if dstate <= 16:
+        BLOCK_SIZE_M, num_warps = 32, 4
+    elif dstate <= 32:
+        BLOCK_SIZE_M, num_warps = 16, 4
+    elif dstate <= 64:
+        BLOCK_SIZE_M, num_warps = 8, 4
+    elif dstate <= 128:
+        BLOCK_SIZE_M, num_warps = 4, 4
+    else:
+        BLOCK_SIZE_M, num_warps = 4, 8

546-660: Argument parsing and validation look solid.

The nheads % ngroups divisibility and supported-ratio checks are good guardrails. One minor note: supported_ratios = [1, 8, 16] at line 650 is hardcoded. If the FlashInfer CUDA kernel adds support for new ratios in the future, this will silently reject valid configurations. Consider adding a comment noting this mirrors a specific kernel constraint so future maintainers know where to update.


806-838: Closure captures many variables from enclosing scope — works, but worth noting.

run_backend closes over z, dt_bias, dt_softplus, slot_idx, cache_steps, and triton_cache_steps. This is fine for benchmarking, but note that intermediate_states_buffer is never allocated or passed to either backend — so MTP intermediate-state caching is not exercised by this benchmark. If that's intentional (benchmarking the core SSM update path only), a brief comment clarifying the omission would help future readers.


871-888: state_cache is mutated in-place across backends by bench_gpu_time.

The loop passes the same state_cache tensor to bench_gpu_time for every backend. After the first backend's warm-up + measurement iterations, state_cache contains different values for the second backend's run. This doesn't affect correctness of refcheck (captured earlier from clean_state_snapshot clones), and for timing the compute pattern is identical regardless of state values, so results are valid.

Still, if a future change introduces a data-dependent fast-path or NaN-propagation concern, this could silently skew timings. A defensive one-liner to restore state before each backend's bench run would be cheap insurance:

Optional defensive clone
     for cur_backend in backends:
+        if clean_state_snapshot is not None:
+            state_cache.copy_(clean_state_snapshot)
         if run_refcheck and cur_backend != "triton":

983-1002: defaultdict(str) is unnecessary here.

All keys are explicitly assigned, so a plain dict() (or {}) would be clearer and avoid silently returning "" for typo'd keys during downstream consumption.

Minor simplification
-                cur_res = defaultdict(str)
+                cur_res = {}
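To illustrate the failure mode the simplification avoids — a minimal standalone sketch, not code from the PR:

```python
from collections import defaultdict

# defaultdict(str) silently returns "" for a typo'd key instead of raising.
res = defaultdict(str)
res["median_time"] = 0.123
typo_value = res["median_tme"]  # note the typo: no KeyError, just ""

# A plain dict surfaces the mistake immediately.
plain = {"median_time": 0.123}
try:
    plain["median_tme"]
    caught = False
except KeyError:
    caught = True
```

With every key explicitly assigned before use, the defaultdict buys nothing and only masks downstream typos like the one above.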

Comment on lines +953 to +979
read_bytes = (
    batch_size * nheads * dim * dstate * state_dtype.itemsize  # state
    + batch_size * T_val * nheads * dim * input_dtype.itemsize  # x
    + batch_size * T_val * nheads * weight_dtype.itemsize  # dt (broadcast)
    + nheads * 4  # A (float32, broadcast)
    + batch_size * T_val * ngroups * dstate * input_dtype.itemsize  # B
    + batch_size * T_val * ngroups * dstate * input_dtype.itemsize  # C
    + nheads * weight_dtype.itemsize  # D (broadcast)
    + nheads * weight_dtype.itemsize  # dt_bias (broadcast)
)
if has_z:
    read_bytes += batch_size * T_val * nheads * dim * input_dtype.itemsize

write_bytes = (
    batch_size * nheads * dim * dstate * state_dtype.itemsize  # state
    + batch_size * T_val * nheads * dim * input_dtype.itemsize  # output
)

problem_bytes = read_bytes + write_bytes

# FLOPs estimate (TIE_HDIM case, where dt/A/D/dt_bias broadcast across dim):
# Per (dim, dstate) element per (batch, T, head):
#   state * dA: 1 mul, dB * x[:, None]: 1 mul, state + ...: 1 add,
#   state * C[None, :]: 1 mul, sum reduction: ~1 add => 5 FLOPs/element
problem_flops = batch_size * T_val * nheads * dim * dstate * 5
tflops = problem_flops / (10**9 * median_time)  # TFLOPs/sec
tb_per_sec = problem_bytes / (10**9 * median_time)  # TB/sec
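As a sanity check of the formulas above, here is a standalone sketch plugging in the PR's B200 example config (batch 64, nheads 64, dim 128, dstate 128, ngroups 8, bfloat16 inputs/state, float32 weights, T=1, no z). The hard-coded itemsizes and the 0.1 ms timing are illustrative assumptions, not measured values:

```python
# Example config from the PR description (mamba2_stp_bf16 case).
batch_size, nheads, dim, dstate, ngroups, T_val = 64, 64, 128, 128, 8, 1
in_sz, st_sz, wt_sz = 2, 2, 4  # bfloat16, bfloat16, float32 itemsizes

read_bytes = (
    batch_size * nheads * dim * dstate * st_sz        # state
    + batch_size * T_val * nheads * dim * in_sz       # x
    + batch_size * T_val * nheads * wt_sz             # dt (broadcast)
    + nheads * 4                                      # A (float32)
    + batch_size * T_val * ngroups * dstate * in_sz   # B
    + batch_size * T_val * ngroups * dstate * in_sz   # C
    + nheads * wt_sz                                  # D
    + nheads * wt_sz                                  # dt_bias
)
write_bytes = (
    batch_size * nheads * dim * dstate * st_sz        # state
    + batch_size * T_val * nheads * dim * in_sz       # output
)
problem_bytes = read_bytes + write_bytes              # ~271 MB, state-dominated
problem_flops = batch_size * T_val * nheads * dim * dstate * 5

# At a hypothetical 0.1 ms median time, this works out to ~2.7 TB/s.
median_time_ms = 0.1
tb_per_sec = problem_bytes / (10**9 * median_time_ms)
```

The state read/write dominates (~268 of ~271 MB), which matches the expectation that `selective_state_update` is bandwidth-bound at decode time.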
Contributor

⚠️ Potential issue | 🟡 Minor

Verified: median_time is in milliseconds; formulas are correct. Itemsize availability not explicitly documented.

  1. Time units confirmed: bench_gpu_time explicitly returns times in milliseconds (its docstring states "Per-iteration execution times in milliseconds"), and the math checks out: 10^9 * time_ms = 10^12 * time_sec, so flops / (10^9 * median_time) correctly yields TFLOPs/sec when median_time is in milliseconds. The formulas are correct.

  2. torch.dtype.itemsize concern: The attribute is used extensively throughout the codebase (40+ instances across tests, benchmarks, and production code) without version guards. However, requirements.txt specifies only torch with no minimum version constraint. PyTorch added dtype.itemsize in version 2.0, but this is not explicitly documented as a minimum requirement. While the ubiquitous usage suggests it's expected to be available, the lack of an explicit version constraint in the project's dependency specification is a valid concern if supporting older PyTorch versions.
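The unit arithmetic in point 1 can be checked directly; the numbers below are arbitrary illustrative values, not benchmark output:

```python
# With time in milliseconds, ops / (1e9 * t_ms) equals ops / (1e12 * t_sec),
# i.e. tera-ops per second, because t_ms = 1000 * t_sec.
ops = 2.0e12    # hypothetical op count
t_ms = 500.0    # hypothetical elapsed time: 0.5 s
t_sec = t_ms / 1000.0

via_ms = ops / (1e9 * t_ms)      # formula as written in the benchmark
via_sec = ops / 1e12 / t_sec     # same quantity computed in seconds
# both evaluate to 4.0 T-ops/sec
```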

@ishovkun
Contributor

ishovkun commented Feb 6, 2026

Overall, looks good. But I think you copied the contents of the Triton kernel into benchmarks/routines/mamba.py.
There are already at least 4 versions of this kernel in various frameworks, so let's avoid having to maintain two copies in the same repo.

@bkryu
Collaborator Author

bkryu commented Feb 12, 2026

Overall, looks good. I but think that you copied the contents of the Triton kernel to benchmarks/routines/mamba.py. There are already at least 4 versions of the kernel in various frameworks, so let's avoid having to maintain two copies of this in the same repo.

Thanks @ishovkun, I removed the duplicated Triton reference in the latest commit.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@benchmarks/routines/mamba.py`:
- Around line 219-226: Wrap the nheads/ngroups ratio validation in the same
backend guard used by the dim and dstate checks (i.e., only run when
"flashinfer" is in args.backends); specifically, move or wrap the
supported_ratios/ratio check so it is executed only if "flashinfer" in
args.backends, leaving the existing variables supported_ratios, ratio,
args.nheads and args.ngroups unchanged and keeping the same ValueError message
when triggered.
- Around line 553-572: The result dict population in the if args.output_path
block (where cur_res is built and keys like "nheads", "dim", "state_dtype" are
set) is missing batch_size and input_dtype; update that block to add
cur_res["batch_size"] = batch_size and cur_res["input_dtype"] = str(input_dtype)
(or the existing variable's string representation) alongside the other
Mamba-specific columns so outputs include the batch size and input dtype for
reproducibility.
🧹 Nitpick comments (1)
benchmarks/routines/mamba.py (1)

74-74: Eager module-level import will break flashinfer-only users if the Triton reference file is missing or Triton is not installed.

_import_triton_reference() runs unconditionally at import time. If someone only wants the flashinfer backend, this still fails the entire module load when Triton or the reference file is absent.

Consider lazy-loading: call _import_triton_reference() only when the triton backend is actually requested (e.g., inside run_backend or at the top of testSelectiveStateUpdate when "triton" in backends or run_refcheck).

Suggested approach
-selective_state_update_triton_reference = _import_triton_reference()
+selective_state_update_triton_reference = None
+
+def _get_triton_reference():
+    global selective_state_update_triton_reference
+    if selective_state_update_triton_reference is None:
+        selective_state_update_triton_reference = _import_triton_reference()
+    return selective_state_update_triton_reference

Then replace usages of selective_state_update_triton_reference(...) with _get_triton_reference()(...).

Comment on lines +219 to +226
# Validate nheads/ngroups ratio is supported by the CUDA kernel
supported_ratios = [1, 8, 16]
ratio = args.nheads // args.ngroups
if ratio not in supported_ratios:
raise ValueError(
f"nheads/ngroups ratio ({ratio} = {args.nheads}/{args.ngroups}) is not supported by the FlashInfer kernel. "
f"Supported ratios: {supported_ratios}."
)
Contributor

@coderabbitai coderabbitai bot Feb 12, 2026

⚠️ Potential issue | 🟠 Major

nheads/ngroups ratio validation is not gated on flashinfer backend, unlike dim and dstate checks.

Lines 203 and 212 correctly gate their validation with if "flashinfer" in args.backends, but this ratio check applies unconditionally. A user running --backends triton with a ratio not in [1, 8, 16] would be incorrectly rejected.

Proposed fix
     supported_ratios = [1, 8, 16]
     ratio = args.nheads // args.ngroups
-    if ratio not in supported_ratios:
+    if "flashinfer" in args.backends and ratio not in supported_ratios:
         raise ValueError(
             f"nheads/ngroups ratio ({ratio} = {args.nheads}/{args.ngroups}) is not supported by the FlashInfer kernel. "
             f"Supported ratios: {supported_ratios}."
         )
🧰 Tools
🪛 Ruff (0.15.0)

[warning] 223-226: Avoid specifying long messages outside the exception class

(TRY003)


Collaborator Author

This is fine because Triton is a reference. We would always be benchmarking FlashInfer as the main backend.

Contributor

@bkryu, understood! If FlashInfer is always the primary target, the current validation makes sense for your workflow.

Just to note: since the code does accept --backends triton without flashinfer, a user running only Triton with an unsupported ratio would see "is not supported by the FlashInfer kernel" in the error message, which could be a bit confusing. But if that's not a practical use case for you, the current approach is fine.

@bkryu bkryu self-assigned this Feb 12, 2026
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@benchmarks/routines/mamba.py`:
- Line 74: The module currently performs an eager import by assigning
selective_state_update_triton_reference = _import_triton_reference() which
causes ImportError for users who don't have the Triton test file; change this to
a lazy loader by implementing a _get_triton_reference() function that calls
_import_triton_reference() on first use and caches the result (e.g., store in a
module-level variable), then remove the eager assignment and update all call
sites to invoke _get_triton_reference()(...) instead of
selective_state_update_triton_reference(...); ensure the new loader handles
ImportError gracefully so non-Triton backends continue to work.
🧹 Nitpick comments (1)
benchmarks/routines/mamba.py (1)

233-233: Nit: testSelectiveStateUpdate uses camelCase.

Python convention (PEP 8) favors snake_case for function names. Consider test_selective_state_update for consistency.

return module.selective_state_update_triton


selective_state_update_triton_reference = _import_triton_reference()
Contributor

⚠️ Potential issue | 🟠 Major

Module-level eager import will fail even when Triton is not requested.

_import_triton_reference() runs at import time, so any user who only needs --backends flashinfer will still get an ImportError if tests/mamba/selective_state_update_triton.py is absent (e.g., in a packaged/installed environment without the test tree). Consider lazy-loading:

Proposed fix
-selective_state_update_triton_reference = _import_triton_reference()
+selective_state_update_triton_reference = None
+
+def _get_triton_reference():
+    global selective_state_update_triton_reference
+    if selective_state_update_triton_reference is None:
+        selective_state_update_triton_reference = _import_triton_reference()
+    return selective_state_update_triton_reference

Then replace usages of selective_state_update_triton_reference(...) with _get_triton_reference()(...).


@yzh119 yzh119 merged commit 0ebf05e into flashinfer-ai:main Feb 16, 2026
37 of 39 checks passed
@bkryu bkryu deleted the microbench_mamba branch February 17, 2026 19:52