perf: Port the separate reduce kernel mode from trtllm. #1685
Summary of Changes
Hello @weireweire, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request integrates a separate reduction kernel mode, originally from TensorRT-LLM, into FlashInfer. This architectural change aims to optimize Flash Attention (FMHA) computations by decoupling the reduction phase into a distinct, specialized CUDA kernel. The modifications involve adding new CUDA kernel implementations, integrating them into the existing FMHA kernel selection logic, and updating the build system to support these new components, ultimately enhancing the efficiency and flexibility of attention mechanisms.
Highlights
- New Separate Reduction Kernel: A dedicated CUDA kernel (`fmhaReduction.cu`) has been introduced to perform the reduction step in Flash Attention (FMHA), allowing for more specialized and potentially optimized processing.
- Enhanced Kernel Integration: The main FMHA kernel (`fmhaKernels.cuh`) now conditionally utilizes this new separate reduction kernel, specifically for high-throughput `KeepsMmaAbForGeneration` kernels, improving flexibility and performance.
- New Multi-CTA KV Mode: A new `MultiCtasKvMode::GmemReductionWithSeparateKernel` has been added to explicitly manage and enable the use of this separate reduction kernel within the system.
- CUDA Utility Functions: A new header file (`kernelUtils.h`) provides essential CUDA helper functions for data type conversions (e.g., float to half, bfloat16, e4m3) and `float2` vector operations, supporting the new kernel's operations.
- Build System Updates: The build configuration has been updated to include the newly added CUDA source files and reflect changes in the compiled artifacts for the TRTLLM FMHA kernel.
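For readers unfamiliar with what a separate reduction step does when the KV sequence is split across several CTAs, here is a minimal host-side C++ sketch of the usual log-sum-exp merge of partial attention results. It is not the code in `fmhaReduction.cu`; the `PartialResult` layout and the names `computePartial` and `reducePartials` are hypothetical and chosen only for illustration.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical per-CTA partial result for one query row over one KV slice.
struct PartialResult {
  float maxScore;          // max attention score seen in the slice
  float sumExp;            // sum of exp(score - maxScore) over the slice
  std::vector<float> acc;  // unnormalized output: sum of exp(score - maxScore) * v
};

// Simulates what one CTA would produce for scores[begin, end).
// `values` is row-major with `headDim` columns.
PartialResult computePartial(const std::vector<float>& scores,
                             const std::vector<float>& values,
                             std::size_t headDim, std::size_t begin, std::size_t end) {
  PartialResult p{-std::numeric_limits<float>::infinity(), 0.0f,
                  std::vector<float>(headDim, 0.0f)};
  for (std::size_t i = begin; i < end; ++i) p.maxScore = std::max(p.maxScore, scores[i]);
  for (std::size_t i = begin; i < end; ++i) {
    const float w = std::exp(scores[i] - p.maxScore);
    p.sumExp += w;
    for (std::size_t d = 0; d < headDim; ++d) p.acc[d] += w * values[i * headDim + d];
  }
  return p;
}

// The reduction step: rescale every partial to a common max, then normalize.
std::vector<float> reducePartials(const std::vector<PartialResult>& partials,
                                  std::size_t headDim) {
  float globalMax = -std::numeric_limits<float>::infinity();
  for (const auto& p : partials) globalMax = std::max(globalMax, p.maxScore);

  float globalSum = 0.0f;
  std::vector<float> out(headDim, 0.0f);
  for (const auto& p : partials) {
    const float scale = std::exp(p.maxScore - globalMax);  // 0 for an empty slice
    globalSum += scale * p.sumExp;
    for (std::size_t d = 0; d < headDim; ++d) out[d] += scale * p.acc[d];
  }
  for (std::size_t d = 0; d < headDim; ++d) out[d] /= globalSum;
  return out;
}
```

Merging the partials of two slices this way should match, up to floating-point rounding, the softmax-weighted output computed over the whole sequence at once. A GPU version would presumably read the partials from global memory, which is what the `GmemReduction` part of the mode name suggests.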
Code Review
This pull request ports a separate reduction kernel from TensorRT-LLM for use in fused multi-head attention, which is a valuable addition for specific kernel configurations. The implementation is well-contained, and the new CUDA code demonstrates good performance practices. I have identified a critical issue related to potential division-by-zero and integer overflow in the kernel launcher, as well as a medium-severity issue concerning a magic number that could be refactored for better maintainability. Addressing these points will strengthen the robustness and clarity of the code.
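The review does not quote the offending launcher code here, so the following is only a generic sketch of how a grid-size computation can guard against division by zero and integer overflow. The helper name `computeNumCtas` and its parameters are hypothetical and do not come from the PR.

```cpp
#include <algorithm>
#include <cstdint>
#include <stdexcept>

// Hypothetical helper: how many CTAs to launch for `numItems` work items when
// each CTA handles `itemsPerCta` items, capped at `maxCtas`.
inline std::int64_t computeNumCtas(std::int64_t numItems, std::int64_t itemsPerCta,
                                   std::int64_t maxCtas) {
  // Reject a zero or negative divisor instead of dividing by it.
  if (itemsPerCta <= 0 || maxCtas <= 0) {
    throw std::invalid_argument("itemsPerCta and maxCtas must be positive");
  }
  if (numItems <= 0) return 0;
  // Ceil-division written so that no intermediate sum can overflow,
  // unlike (numItems + itemsPerCta - 1) / itemsPerCta.
  const std::int64_t needed = (numItems - 1) / itemsPerCta + 1;
  return std::min(needed, maxCtas);
}
```

Doing the arithmetic in 64 bits and narrowing only when filling in the launch dimensions keeps products such as batch size times head count from wrapping in a 32-bit `int`.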
Force-pushed from 3097254 to d27fbbe.
yzh119 left a comment
Overall LGTM, left some minor suggestions
@weireweire Could you address the comments so that we can merge this soon? Thanks
Force-pushed from d27fbbe to 7e93256.
@yzh119 we found some accuracy regressions in the fp16/bf16 sink tests. I don't have enough time right now for accuracy debugging and adding fp8 tests, so let's do that in a separate PR later.
Looking forward to this.
    // Run the separate reduction kernel if needed.
    runFmhaReduction(kernelMeta, kernelParams, params.mMultiProcessorCount, params.enable_pdl,
                     params.stream);
This is defined in namespace tensorrt_llm::kernels, how could this ever build for you?
Related modules are compiled in UT and succeed:
https://ci.tlcpack.ai/blue/rest/organizations/jenkins/pipelines/flashinfer-ci/branches/PR-1685/runs/5/nodes/24/steps/195/log/?start=0
[2025-09-19T07:27:57.117Z] [1659/3127] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output fmha_gen/trtllm_fmha_kernel_launcher.cuda.o.d -DTORCH_EXTENSION_NAME=fmha_gen -DTORCH_API_INCLUDE_EXTENSION_H -DPy_LIMITED_API=0x03090000 -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1018\" -D_GLIBCXX_USE_CXX11_ABI=1 -I/root/.cache/flashinfer/cubins/538f8e38ace07f701f61e26b138b2b8c70ce9e8e/fmha/trtllm-gen/include -isystem /opt/conda/envs/py312/include/python3.12 -isystem /opt/conda/envs/py312/lib/python3.12/site-packages/torch/include -isystem /opt/conda/envs/py312/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda/include -isystem /usr/local/cuda/include/cccl -isystem /workspace/include -isystem /workspace/csrc -isystem /workspace/3rdparty/cutlass/include -isystem /workspace/3rdparty/cutlass/tools/util/include -isystem /workspace/3rdparty/spdlog/include --compiler-options=-fPIC --expt-relaxed-constexpr -static-global-template-stub=false -gencode=arch=compute_89,code=sm_89 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a -gencode=arch=compute_120a,code=sm_120a -gencode=arch=compute_100a,code=sm_100a -DFLASHINFER_ENABLE_FP8_E8M0 -DFLASHINFER_ENABLE_FP4_E2M1 -O3 -std=c++17 --threads=1 -use_fast_math -DFLASHINFER_ENABLE_F16 -DFLASHINFER_ENABLE_BF16 -DFLASHINFER_ENABLE_FP8_E4M3 -DFLASHINFER_ENABLE_FP8_E5M2 -DNDEBUG -lineinfo -DTLLM_GEN_FMHA_CUBIN_PATH=\"538f8e38ace07f701f61e26b138b2b8c70ce9e8e/fmha/trtllm-gen/\" -DTLLM_GEN_FMHA_METAINFO_HASH=\"71f06a8fc03d28cc94ee6fc180fb7e37256a9e1c30ab2a6c0bf20a2d97af3eff\" -c /workspace/csrc/trtllm_fmha_kernel_launcher.cu -o fmha_gen/trtllm_fmha_kernel_launcher.cuda.o
It should compile because of C++ argument-dependent lookup (ADL); maybe you are using a different compiler? But I concur that adding the namespace is better.
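To illustrate the argument-dependent lookup point, here is a minimal standalone C++ sketch. The namespaces `demo_kernels` and `caller` and the function `runReduction` are hypothetical stand-ins, not the real `tensorrt_llm::kernels` declarations. Whether ADL actually applies to the `runFmhaReduction` call depends on the namespaces of the argument types (`kernelMeta`, `kernelParams`), which are not shown in this thread.

```cpp
namespace demo_kernels {  // hypothetical stand-in for a kernels namespace

struct KernelParams {
  int numCtas;
};

// Declared only inside demo_kernels; there is no using-declaration anywhere.
inline int runReduction(const KernelParams& params) { return params.numCtas * 2; }

}  // namespace demo_kernels

namespace caller {

inline int launch() {
  demo_kernels::KernelParams params{4};
  // Unqualified call: ordinary lookup finds nothing in `caller`, but ADL also
  // searches the namespace of the argument type, so this resolves to
  // demo_kernels::runReduction.
  return runReduction(params);
}

}  // namespace caller
```

Calling `caller::launch()` returns 8. If every argument were a built-in type such as `int` or a raw pointer to a built-in type, ADL would contribute no namespaces and the unqualified call would fail to compile, which is one reason explicit qualification is the more robust choice.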
📌 Description
This also updates the kernels, which fixes an MHA perf regression and an fp8 sink attention accuracy issue.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
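- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- All tests are passing (`unittest`, etc.).
Reviewer Notes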