MHA chunk prefix: tune and use configs for fa3 and flashinfer #9551
xu-yfei wants to merge 2 commits into sgl-project:main
Conversation
Summary of Changes
Hello @xu-yfei, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the MHA chunked prefix feature by implementing an optimized configuration file approach for both fa3 and flashinfer attention backends. This optimization refines the strategy for selecting between Multi-Head Attention (MHA) and Multi-Query Attention (MLA), leading to notable improvements in Time To First Token (TTFT) across various input lengths. The changes include a new tuning script to generate these configurations and their integration into the system for dynamic, performance-driven backend selection.
Highlights
- Optimized MHA Chunked Prefix: Introduced an optimized configuration file approach for the MHA chunked prefix feature, enhancing the decision-making between MHA and MLA backends.
- Performance Improvement: Achieved significant reductions in Time To First Token (TTFT) during stress testing: 4% for 1000 input length, 11.3% for 2000, and 20.5% for 4000.
- New Tuning Script: Added a new tuning script (tune_mha_chunk_prefix.py) to generate and validate these optimized configurations, ensuring adaptability to different hardware and workloads.
- Dynamic Backend Selection: Integrated the tuned configurations into the system, allowing dynamic selection of MHA or MLA based on a calculated se (sum of extend * (extend + prefix)) threshold.
- FlashInfer Backend Refinement: Streamlined the FlashInfer MLA backend by removing redundant use_ragged flags and simplifying prefill logic, contributing to cleaner and potentially more efficient code.
@Fridge003 Could you please help review this PR?
Code Review
This pull request introduces an optimized configuration-based approach for selecting between MHA and MLA attention mechanisms in the fa3 and flashinfer backends, aiming to reduce Time To First Token (TTFT). A tuning script is provided to generate these configurations. My review focuses on the correctness of the new tuning script and the impact of the changes on existing backend logic. I've identified a potential bug in the tuning script that may incorrectly restrict its usage on certain hardware, and a potential performance regression in the flashinfer backend due to the removal of an optimization path. Overall, the changes are well-structured and the data-driven approach to performance tuning is commendable.
Force-pushed 1e3b6be to 0bb5127
IMO the point is tuning out the overhead and the performance scaling factor (considering fa kernel efficiency) of MLA and MHA for each device (O1/O2 for the overheads and S1/S2 for the scaling), and then simplifying to S = S1/S2, O = O1 - O2.
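The linear cost model suggested in this comment can be written out as a small sketch. All constants below are illustrative placeholders, not measured values for any device:

```python
# Hypothetical per-device linear cost model from the comment above:
#   T_mla(se) = O1 + S1 * se    (MLA: overhead O1, scaling S1)
#   T_mha(se) = O2 + S2 * se    (MHA: overhead O2, scaling S2)
# With S = S1 / S2 and O = O1 - O2, the crossover point is
#   se* = -O / ((S - 1) * S2) = (O2 - O1) / (S1 - S2).

def crossover_se(O1: float, S1: float, O2: float, S2: float) -> float:
    """se value at which both models cost the same (assumes S1 > S2)."""
    return (O2 - O1) / (S1 - S2)

def prefer_mha(se: float, O1: float, S1: float, O2: float, S2: float) -> bool:
    """Prefer MHA when the modeled MLA time is at least the MHA time."""
    return O1 + S1 * se >= O2 + S2 * se

# Illustrative numbers: MLA has lower overhead but scales worse with se.
O1, S1 = 1.0, 3.0
O2, S2 = 5.0, 1.0
threshold = crossover_se(O1, S1, O2, S2)  # 2.0 with these numbers
```

Under this model, tuning per device reduces to estimating the two overheads and two scaling factors, from which a single se threshold follows.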
Force-pushed 0bb5127 to 2dfb6ec
Force-pushed c7b0623 to 9169d85
@Fridge003 @Alcanderian replaced by PR #10953.
Motivation
Replaced by PR #10953
For the MHA chunked prefix feature, we have implemented an optimized configuration file approach for both the fa3 and flashinfer backends. This approach enables more precise decision-making in the strategy of selecting MHA or MLA. In stress testing scenarios with 500 samples each (for input sizes of 1000, 2000, and 4000), we achieved a reduction in TTFT (Time To First Token) by 4%, 11.3%, and 20.5% respectively. Additionally, we have provided a tuning script to support this optimization.
Tuning files are provided for fa3 and flashinfer on H20. When the attention backend is fa3, flashinfer, or flashmla (extend is actually handled by flashinfer in that case), you can enable this feature by setting the environment variable SGL_CHUNKED_PREFIX_CACHE_USE_TUNED=1.
Tuning details
As described in PR #5624, the computational loads of MLA and MHA are approximately related to extend * (extend + prefix). However, real-world scenarios are more complex, so it is difficult to accurately predict the time consumption of MLA and MHA through simple theoretical derivation, and thus equally hard to decide between MLA and MHA by that method alone.
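As a concrete illustration of the quantity involved, here is the per-batch computation for made-up lengths (the numbers are not from the PR's benchmarks):

```python
# Toy illustration of the quantity driving the MHA-vs-MLA decision.
# All lengths below are made up for the example.
prefix_lens = [1024, 4096]   # cached prefix tokens per request
extend_lens = [512, 256]     # newly extended tokens per request

# se = sum(extend * (extend + prefix)) over the batch
se = sum(e * (e + p) for p, e in zip(prefix_lens, extend_lens))
print(se)  # 512*1536 + 256*4352 = 1900544
```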
Via benchmark/kernels/mha_chunk_prefix/tune_mha_chunk_prefix.py in this PR, we measured the time consumption of MLA and MHA for different per-batch distributions of prefix lens and extend lens, under fixed sums of prefix lens, fixed sums of extend lens, and the same batch size. We found that the execution time of MLA has a strong correlation with extend * (extend + prefix): when this value is large, MLA is generally more time-consuming than MHA.

Therefore, we generated either 500 or 100 sets of randomly distributed prefix lens and extend lens and measured the execution time of MLA and MHA for each set. For each sample, we calculated the sum of extend * (extend + prefix) across all batches, i.e. sum([seq_len * extend_len for seq_len, extend_len in zip(seq_lens, extend_lens)]), hereinafter referred to as se. We selected a specific se value as the threshold: for a given scenario, if its se is ≥ the threshold, MLA will most likely take more time than MHA, so we choose MHA; otherwise we choose MLA. We further counted the distribution of each sample under this selection rule and obtained the following data, where:
- se refers to sum([s * e for s, e in zip(seq_lens, extend_lens)]).
- The threshold is defined for scenarios with a specified batch size, a fixed sum of prefix lens, and a fixed sum of extend lens; when se ≥ the threshold, MHA is selected, otherwise MLA is selected.
- "≥ se threshold" column: among samples with se ≥ the threshold, (number of samples where MLA time cost ≥ MHA time cost) / (total number of such samples). A larger value indicates a higher probability that MLA is more time-consuming above the threshold.
- "< se threshold" column: among samples with se < the threshold, (number of samples where MLA time cost ≥ MHA time cost) / (total number of such samples). A smaller value indicates fewer "missed cases", i.e. instances where MLA is unexpectedly slower even though se is below the threshold.
Modifications
Added a switch SGL_CHUNKED_PREFIX_CACHE_USE_TUNED; the feature can be enabled by setting it to 1 or true, with a default value of false.
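The semantics of the switch as described (enabled by "1" or "true", disabled by default) could be parsed along these lines; the actual parsing code in sglang may differ, so treat this as a sketch of the documented behavior:

```python
import os

def chunked_prefix_use_tuned(env=None) -> bool:
    """Sketch of the described flag semantics: True iff
    SGL_CHUNKED_PREFIX_CACHE_USE_TUNED is "1" or "true"
    (case-insensitive); anything else, or unset, means disabled."""
    env = os.environ if env is None else env
    value = env.get("SGL_CHUNKED_PREFIX_CACHE_USE_TUNED", "false")
    return value.lower() in ("1", "true")
```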
The decision on whether to select MHA is determined based on information including batch size, the sum of sequence lengths (seq lens), the sum of extension lengths (extend lens), and the total sum of all seq_len*extend_len values. The batch size range covered is 1–8; for cases where the actual batch size exceeds 8, the configuration for batch size 8 will be used.
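The lookup described above can be sketched as follows. The `thresholds` mapping and function name are hypothetical; the real tuned config is also keyed on the sums of seq lens and extend lens, which is elided here for brevity:

```python
def choose_attention(thresholds, batch_size, seq_lens, extend_lens):
    """Sketch of the tuned decision rule: thresholds is a hypothetical
    mapping from batch size (1..8) to an se threshold. Batch sizes
    above 8 reuse the entry for batch size 8, as described above."""
    bs = min(batch_size, 8)
    # se = total sum of seq_len * extend_len over the batch
    se = sum(s * e for s, e in zip(seq_lens, extend_lens))
    return "mha" if se >= thresholds[bs] else "mla"
```

For example, with a toy table `{1: 100, 8: 1000}`, a single request with seq len 20 and extend len 10 (se = 200) would select MHA, while a batch of 16 falls back to the batch-size-8 threshold.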
The ragged logic of flashinfer has been fine-tuned: init_forward_metadata now only performs paged-prefill initialization, while ragged initialization is triggered at runtime based on whether the MHA chunk feature is enabled. Since the logic that enables the MHA chunk feature lives in the DeepSeek model code, enabling it earlier during init_forward_metadata would require moving that logic out of DeepSeek, which involves a certain amount of refactoring.
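The deferred-initialization pattern described in this item can be illustrated with a toy class; the real backend classes and method names in sglang differ, so this is only a sketch of the control flow:

```python
class PrefillMetadata:
    """Toy sketch of deferring ragged-prefill setup until the MHA
    chunk feature actually fires at runtime."""

    def __init__(self):
        self.paged_ready = False
        self.ragged_ready = False

    def init_forward_metadata(self):
        # Always prepare the paged-prefill path up front.
        self.paged_ready = True

    def ensure_ragged(self, use_mha_chunk: bool):
        # Ragged prefill is initialized lazily, only when needed.
        if use_mha_chunk and not self.ragged_ready:
            self.ragged_ready = True
```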
Accuracy Tests
Benchmarking and Profiling
Request throughput (req/s):
Checklist