perf: Port the separate reduce kernel mode from trtllm.#1685

Merged
yzh119 merged 3 commits into flashinfer-ai:main from weireweire:fix-mha-perf
Sep 19, 2025

@weireweire weireweire commented Sep 16, 2025

Collaborator

📌 Description

This PR also updates the kernels, which fixes the MHA perf regression and the fp8 sink attention accuracy issue.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Summary of Changes

Hello @weireweire, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a separate reduction kernel mode, originally from TensorRT-LLM, into FlashInfer. This architectural change aims to optimize fused multi-head attention (FMHA) computations by decoupling the reduction phase into a distinct, specialized CUDA kernel. The modifications add new CUDA kernel implementations, integrate them into the existing FMHA kernel selection logic, and update the build system to support these new components.

Highlights

  • New Separate Reduction Kernel: A dedicated CUDA kernel (fmhaReduction.cu) has been introduced to perform the reduction step in fused multi-head attention (FMHA), allowing for more specialized and potentially optimized processing.
  • Enhanced Kernel Integration: The main FMHA kernel (fmhaKernels.cuh) now conditionally utilizes this new separate reduction kernel, specifically for high-throughput KeepsMmaAbForGeneration kernels, improving flexibility and performance.
  • New Multi-CTA KV Mode: A new MultiCtasKvMode::GmemReductionWithSeparateKernel has been added to explicitly manage and enable the use of this separate reduction kernel within the system.
  • CUDA Utility Functions: A new header file (kernelUtils.h) provides essential CUDA helper functions for data type conversions (e.g., float to half, bfloat16, e4m3) and float2 vector operations, supporting the new kernel's operations.
  • Build System Updates: The build configuration has been updated to include the newly added CUDA source files and reflect changes in the compiled artifacts for the TRTLLM FMHA kernel.
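
To make the reduction step concrete, here is a minimal host-side C++ sketch of how partial attention results from several CTAs, each covering a slice of the KV sequence, can be merged with a numerically stable rescaling. It is not taken from fmhaReduction.cu: the `PartialAttention` struct, its fields, and `reducePartials` are hypothetical names, and the sketch assumes each CTA stores a softmax-normalized partial output together with its row max and row sum, with at least one key per slice.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical container for one CTA's partial result over its KV slice.
struct PartialAttention {
  std::vector<float> out;  // softmax-normalized partial output, one value per head-dim element
  float rowMax;            // maximum attention logit seen in this slice
  float rowSum;            // sum of exp(logit - rowMax) over this slice
};

// Merge the partial results into the output a single pass over all keys would give.
std::vector<float> reducePartials(const std::vector<PartialAttention>& parts) {
  assert(!parts.empty());
  const std::size_t headDim = parts.front().out.size();

  // The global max keeps every exponent below non-positive, which avoids overflow.
  float globalMax = -std::numeric_limits<float>::infinity();
  for (const auto& p : parts) globalMax = std::max(globalMax, p.rowMax);

  std::vector<float> merged(headDim, 0.0f);
  float globalSum = 0.0f;
  for (const auto& p : parts) {
    assert(p.out.size() == headDim);
    // Rescale this slice's softmax denominator to the global max.
    const float weight = p.rowSum * std::exp(p.rowMax - globalMax);
    globalSum += weight;
    for (std::size_t d = 0; d < headDim; ++d) merged[d] += weight * p.out[d];
  }
  for (float& v : merged) v /= globalSum;
  return merged;
}
```

For example, two slices with logits 0 and ln 3 and values 1 and 3 merge to 0.25 * 1 + 0.75 * 3 = 2.5, the same result as a softmax over both keys at once. A separate kernel for this step lets the main attention kernel write its partials and exit without waiting on the other CTAs.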

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request ports a separate reduction kernel from TensorRT-LLM for use in fused multi-head attention, which is a valuable addition for specific kernel configurations. The implementation is well-contained, and the new CUDA code demonstrates good performance practices. I have identified a critical issue related to potential division-by-zero and integer overflow in the kernel launcher, as well as a medium-severity issue concerning a magic number that could be refactored for better maintainability. Addressing these points will strengthen the robustness and clarity of the code.
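
The review threads with the details are collapsed here, so the following is only a generic sketch of how a launcher can guard a grid-size computation against division by zero and integer overflow. `computeGridSize` and its parameters are hypothetical and are not the code in this PR; the upper bound assumes the 2^31 - 1 limit on the x grid dimension of current CUDA devices.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical helper: number of blocks needed to cover numItems when each block
// handles itemsPerBlock items. Returns false, leaving gridSize untouched, when the
// inputs would divide by zero or the result does not fit a 32-bit grid dimension.
bool computeGridSize(std::int64_t numItems, std::int64_t itemsPerBlock, std::int32_t& gridSize) {
  if (numItems < 0 || itemsPerBlock <= 0) return false;
  // Ceiling division written without (numItems + itemsPerBlock - 1), which can
  // itself overflow when numItems is close to the integer maximum.
  const std::int64_t blocks = numItems / itemsPerBlock + (numItems % itemsPerBlock != 0 ? 1 : 0);
  if (blocks > std::numeric_limits<std::int32_t>::max()) return false;
  gridSize = static_cast<std::int32_t>(blocks);
  return true;
}
```

A caller would check the return value before launching, and skip the launch when `gridSize` comes back as 0, since a zero-sized grid is not a valid launch configuration.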

Comment thread csrc/fmhaReduction.cu Outdated
Comment thread csrc/fmhaReduction.cu Outdated
Comment thread csrc/kernelUtils.h Outdated
Comment thread include/flashinfer/trtllm/common/cudaUtils.h Outdated
@weireweire weireweire changed the title from "[draft]Port the separate reduce kernel mode from trtllm." to "Port the separate reduce kernel mode from trtllm." Sep 17, 2025

@yzh119 yzh119 left a comment

Collaborator

Overall LGTM, left some minor suggestions

Comment thread include/flashinfer/trtllm/common/cudaUtils.h Outdated
Comment thread csrc/fmhaReduction.cu Outdated
Comment thread flashinfer/artifacts.py Outdated
@nvpohanh

Contributor

@weireweire Could you address the comments so that we can merge this soon? Thanks

@yzh119 yzh119 changed the title from "Port the separate reduce kernel mode from trtllm." to "perf: Port the separate reduce kernel mode from trtllm." Sep 19, 2025
@yzh119 yzh119 enabled auto-merge (squash) September 19, 2025 06:24
@weireweire

Collaborator Author

@yzh119 We found an accuracy regression in the fp16/bf16 sink test test_blackwell_trtllm_gen_decode_attention_sink, so I added a 1% tolerance to make it pass.
We tested gpt-oss end to end and did not see a significant accuracy drop.
Let's merge this first.

I don't have enough time right now for accuracy debugging and for adding fp8 tests; let's do that in a separate PR later.

fzyzcjy commented Sep 19, 2025

Collaborator

looking forward to this

@yzh119 yzh119 merged commit ed1d8c1 into flashinfer-ai:main Sep 19, 2025
2 checks passed

// Run the separate reduction kernel if needed.
runFmhaReduction(kernelMeta, kernelParams, params.mMultiProcessorCount, params.enable_pdl,
                 params.stream);

Collaborator

This is defined in namespace tensorrt_llm::kernels, how could this ever build for you?

Collaborator

The related modules are compiled in the unit tests, and the build succeeds:
https://ci.tlcpack.ai/blue/rest/organizations/jenkins/pipelines/flashinfer-ci/branches/PR-1685/runs/5/nodes/24/steps/195/log/?start=0

[2025-09-19T07:27:57.117Z] [1659/3127] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output fmha_gen/trtllm_fmha_kernel_launcher.cuda.o.d -DTORCH_EXTENSION_NAME=fmha_gen -DTORCH_API_INCLUDE_EXTENSION_H -DPy_LIMITED_API=0x03090000 -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1018\" -D_GLIBCXX_USE_CXX11_ABI=1 -I/root/.cache/flashinfer/cubins/538f8e38ace07f701f61e26b138b2b8c70ce9e8e/fmha/trtllm-gen/include -isystem /opt/conda/envs/py312/include/python3.12 -isystem /opt/conda/envs/py312/lib/python3.12/site-packages/torch/include -isystem /opt/conda/envs/py312/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda/include -isystem /usr/local/cuda/include/cccl -isystem /workspace/include -isystem /workspace/csrc -isystem /workspace/3rdparty/cutlass/include -isystem /workspace/3rdparty/cutlass/tools/util/include -isystem /workspace/3rdparty/spdlog/include --compiler-options=-fPIC --expt-relaxed-constexpr -static-global-template-stub=false -gencode=arch=compute_89,code=sm_89 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a -gencode=arch=compute_120a,code=sm_120a -gencode=arch=compute_100a,code=sm_100a -DFLASHINFER_ENABLE_FP8_E8M0 -DFLASHINFER_ENABLE_FP4_E2M1 -O3 -std=c++17 --threads=1 -use_fast_math -DFLASHINFER_ENABLE_F16 -DFLASHINFER_ENABLE_BF16 -DFLASHINFER_ENABLE_FP8_E4M3 -DFLASHINFER_ENABLE_FP8_E5M2 -DNDEBUG -lineinfo -DTLLM_GEN_FMHA_CUBIN_PATH=\"538f8e38ace07f701f61e26b138b2b8c70ce9e8e/fmha/trtllm-gen/\" -DTLLM_GEN_FMHA_METAINFO_HASH=\"71f06a8fc03d28cc94ee6fc180fb7e37256a9e1c30ab2a6c0bf20a2d97af3eff\" -c /workspace/csrc/trtllm_fmha_kernel_launcher.cu -o fmha_gen/trtllm_fmha_kernel_launcher.cuda.o 

Collaborator Author

It should compile because of C++ argument-dependent lookup (ADL); maybe you are using a different compiler? But I agree that adding the namespace qualifier is better.
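
As an illustration of the lookup rule mentioned above, here is a small self-contained C++ sketch. The namespaces, the `KernelParams` struct, and `runReduction` are hypothetical stand-ins, not the FlashInfer or TensorRT-LLM declarations; the sketch assumes, as the comment implies, that at least one argument type is declared in the same namespace as the function being called.

```cpp
#include <cassert>

namespace kernels_demo {  // hypothetical stand-in for a library namespace
struct KernelParams {
  int numCtas;
};

// Declared only inside kernels_demo.
int runReduction(const KernelParams& params) { return params.numCtas * 2; }
}  // namespace kernels_demo

namespace caller_demo {
// Unqualified call from another namespace with no using-directive. Ordinary
// unqualified lookup finds no runReduction here or in the global namespace, but
// ADL also searches kernels_demo because the argument type is declared there.
int launch(const kernels_demo::KernelParams& params) { return runReduction(params); }
}  // namespace caller_demo
```

Here `caller_demo::launch` compiles on any conforming compiler. Writing `kernels_demo::runReduction(params)` instead makes the dependency explicit and keeps working if the argument types later move to another namespace.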

6 participants