
MHA chunk prefix: tune and use configs for fa3 and flashinfer#9551

Closed
xu-yfei wants to merge 2 commits into sgl-project:main from xu-yfei:xyf/tune_mha

Conversation

@xu-yfei
Contributor

@xu-yfei xu-yfei commented Aug 24, 2025

Motivation

Replaced by PR #10953

For the MHA chunked prefix feature, this PR adds an optimized configuration file approach for both the fa3 and flashinfer backends. These configurations enable more precise decisions when selecting between MHA and MLA. In stress tests with 500 samples each (input lengths of 1000, 2000, and 4000), TTFT (Time To First Token) was reduced by 4%, 11.3%, and 20.5% respectively. A tuning script is also provided to support this optimization.

Tuning files are provided for fa3 and flashinfer on H20. When the attention backend is fa3, flashinfer, or flashmla (whose extend path is actually flashinfer), the feature can be enabled by setting the environment variable SGL_CHUNKED_PREFIX_CACHE_USE_TUNED=1.

Tuning details

As described in PR #5624, the computational costs of MLA and MHA are approximately:

MHA + Extract KV:
extend * (extend + prefix) * 320 * 2 + prefix * 65536 * 2 
MLA:
extend * (extend + prefix) * 1088 * 2
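As a rough sanity check, the two cost formulas can be evaluated directly. This is an illustrative sketch, not SGLang code: `mha_cost`/`mla_cost` are hypothetical names, and the values are approximate FLOP counts rather than measured time.

```python
def mha_cost(extend: int, prefix: int) -> int:
    # MHA over the fused (prefix + extend) KV, plus the cost of extracting
    # (up-projecting) the cached prefix KV out of its compressed form.
    return extend * (extend + prefix) * 320 * 2 + prefix * 65536 * 2

def mla_cost(extend: int, prefix: int) -> int:
    # MLA attends directly in the latent space: a larger per-score constant,
    # but no per-prefix-token KV-extraction term.
    return extend * (extend + prefix) * 1088 * 2

print(mha_cost(2048, 0) < mla_cost(2048, 0))      # True: no prefix, MLA's larger constant dominates
print(mha_cost(16, 60000) > mla_cost(16, 60000))  # True: long prefix, extraction term dominates
```

The crossover depends on the ratio of extend to prefix, which is exactly why a single analytic threshold is hard to pick and measurement-based tuning is used instead.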

The computational loads of both MLA and MHA scale with extend * (extend + prefix). However, real-world scenarios are more complex, so it is difficult to accurately predict the time consumption of MLA and MHA from simple theoretical derivation, and thus equally hard to choose between them that way.

Using benchmark/kernels/mha_chunk_prefix/tune_mha_chunk_prefix.py from this PR, we measured the time consumption of MLA and MHA for different distributions of prefix lens and extend lens per batch, under fixed sums of prefix lens, fixed sums of extend lens, and the same batch size. We found that the relative execution time of MLA correlates strongly with extend * (extend + prefix): when this value is large, MLA is generally more time-consuming than MHA.

Therefore, we generated either 500 or 100 sets of randomly distributed prefix lens and extend lens and measured the execution time of MLA and MHA for each set. For each sample, we calculated the sum of extend * (extend + prefix) across the batch, i.e. sum([seq_len*extend_len for seq_len, extend_len in zip(seq_lens, extend_lens)]), hereinafter referred to as se. We then selected a specific se value as the threshold: for a given scenario, if its se is ≥ the threshold, MLA will most likely take more time than MHA, so we choose MHA; otherwise we choose MLA. Counting how each sample falls under this selection rule yields the data below, where:

  • se refers to sum([s*e for s, e in zip(seq_lens, extend_lens)]). The se threshold is defined per scenario (given batch size, fixed sum of prefix lens, and fixed sum of extend lens); when se ≥ the threshold, MHA is selected, otherwise MLA.

  • ≥ se threshold: among samples with se ≥ the threshold, (number of samples where MLA time cost ≥ MHA time cost) / (total number of such samples). A larger value indicates a higher probability that MLA is more time-consuming above this threshold.

  • < se threshold: among samples with se < the threshold, (number of samples where MLA time cost ≥ MHA time cost) / (total number of such samples). A smaller value indicates fewer "missed cases" (instances where MLA is unexpectedly slower despite se being below the threshold).
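The decision rule above can be sketched in a few lines. This is an illustrative sketch only: `compute_se` and `choose_backend` are hypothetical names, and in the PR the threshold comes from the tuned config files rather than a constant.

```python
def compute_se(seq_lens, extend_lens):
    # se = sum(seq_len * extend_len), where seq_len = prefix + extend,
    # i.e. the sum of extend * (extend + prefix) over the batch.
    return sum(s * e for s, e in zip(seq_lens, extend_lens))

def choose_backend(seq_lens, extend_lens, se_threshold):
    # Large se means MLA would most likely be slower than MHA.
    return "mha" if compute_se(seq_lens, extend_lens) >= se_threshold else "mla"

# bs=2, total prefix 2048, total extend 1536 (threshold 4193602 in the table):
# an even split across the batch stays below the threshold, so MLA is kept,
print(choose_backend([1792, 1792], [768, 768], 4193602))  # mla
# while a skewed split concentrates work in one request and flips to MHA.
print(choose_backend([3583, 1], [1535, 1], 4193602))      # mha
```

Note that both examples have the same batch size and the same prefix/extend sums; only the per-request distribution differs, which is why se, not the sums alone, drives the choice.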

| bs | prefix | extend | se threshold | ≥ se threshold | < se threshold |
|----|--------|--------|--------------|----------------|----------------|
| 1 | 0 | 1536 | 0 | 1/1 | 0/0 |
| 1 | 1 | 2304 | 0 | 1/1 | 0/0 |
| 1 | 256 | 2048 | 0 | 1/1 | 0/0 |
| 1 | 512 | 2048 | 0 | 1/1 | 0/0 |
| 1 | 1024 | 1792 | 0 | 1/1 | 0/0 |
| 1 | 1280 | 1536 | 0 | 1/1 | 0/0 |
| 1 | 1536 | 1536 | 0 | 1/1 | 0/0 |
| 1 | 1792 | 1280 | 0 | 1/1 | 0/0 |
| 1 | 2048 | 1280 | 0 | 1/1 | 0/0 |
| 2 | 0 | 1536 | 1948448 | 85/89 | 2/411 |
| 2 | 0 | 1792 | 1989320 | 253/263 | 18/237 |
| 2 | 0 | 2048 | 2097152 | 500/500 | 0/0 |
| 2 | 2 | 2560 | 4921032 | 143/147 | 8/353 |
| 2 | 2 | 2816 | 5013202 | 238/246 | 12/254 |
| 2 | 2 | 3072 | 4938146 | 386/397 | 8/103 |
| 2 | 2 | 3328 | 5541120 | 500/500 | 0/0 |
| 2 | 256 | 2304 | 4721960 | 88/91 | 7/409 |
| 2 | 256 | 2560 | 4723080 | 209/226 | 8/274 |
| 2 | 256 | 2816 | 4583928 | 352/369 | 10/131 |
| 2 | 256 | 3072 | 5105376 | 500/500 | 0/0 |
| 2 | 2048 | 1536 | 4193602 | 36/37 | 2/463 |
| 2 | 2048 | 1792 | 4421660 | 137/140 | 12/360 |
| 2 | 2048 | 2048 | 4286264 | 305/326 | 45/174 |
| 2 | 2048 | 2304 | 4585024 | 491/493 | 2/7 |
| 8 | 0 | 2560 | 1858268 | 64/69 | 12/431 |
| 8 | 0 | 2816 | 1780646 | 188/204 | 14/296 |
| 8 | 0 | 3072 | 1744434 | 339/356 | 19/144 |
| 8 | 0 | 3328 | 1629902 | 491/492 | 1/8 |
| 8 | 8 | 4096 | 4841674 | 59/65 | 1/435 |
| 8 | 8 | 4352 | 5053346 | 13/14 | 0/86 |
| 8 | 8 | 4608 | 5121876 | 26/27 | 2/73 |
| 8 | 8 | 4864 | 4874954 | 63/65 | 1/35 |
| 8 | 8 | 5120 | 4899366 | 78/78 | 0/22 |
| 8 | 8 | 5376 | 4845480 | 87/88 | 1/12 |
| 8 | 8 | 5632 | 4842674 | 99/99 | 0/1 |
| 8 | 256 | 3840 | 5095241 | 28/28 | 3/472 |
| 8 | 256 | 4096 | 4908994 | 50/56 | 4/444 |
| 8 | 256 | 4352 | 4801450 | 19/20 | 1/80 |
| 8 | 256 | 4608 | 4890369 | 48/49 | 2/51 |
| 8 | 256 | 4864 | 4756875 | 65/66 | 2/34 |
| 8 | 256 | 5120 | 4697939 | 87/87 | 0/13 |
| 8 | 256 | 5376 | 4859800 | 89/89 | 1/11 |
| 8 | 256 | 5632 | 4698181 | 99/99 | 0/1 |
| 8 | 2048 | 3328 | 4593762 | 22/24 | 4/476 |
| 8 | 2048 | 3584 | 4637970 | 55/63 | 7/437 |
| 8 | 2048 | 3840 | 4654369 | 113/117 | 9/383 |
| 8 | 2048 | 4096 | 4546091 | 219/231 | 17/269 |
| 8 | 2048 | 4352 | 4496188 | 75/76 | 3/24 |
| 8 | 2048 | 4608 | 4562169 | 88/88 | 3/12 |
| 8 | 2048 | 4864 | 4423998 | 98/98 | 1/2 |

Modifications

  • Added a switch SGL_CHUNKED_PREFIX_CACHE_USE_TUNED; set it to 1 or true to enable the feature (default: false).

  • Whether to select MHA is decided from the batch size, the sum of sequence lengths (seq lens), the sum of extend lengths (extend lens), and the total of all seq_len * extend_len values. The configurations cover batch sizes 1–8; when the actual batch size exceeds 8, the batch-size-8 configuration is used.

  • The flashinfer ragged logic has been fine-tuned: init_forward_metadata now performs only paged prefill initialization, while ragged initialization is triggered at runtime depending on whether the MHA chunk feature is enabled. Because the enablement logic for the MHA chunk feature lives in the DeepSeek code, enabling it earlier during init_forward_metadata would require moving that logic out of DeepSeek, which involves a certain amount of refactoring.
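The tuned-config lookup described in the modifications might look roughly like the following. This is a hedged sketch: the dict holds two rows from the table above, while the function name, exact-match lookup, and config layout are illustrative, not the PR's actual format.

```python
TUNED_THRESHOLDS = {
    # (batch_size, sum_of_prefix_lens, sum_of_extend_lens) -> se threshold
    (2, 2048, 1536): 4193602,
    (8, 2048, 3328): 4593762,
}

def lookup_threshold(bs: int, sum_prefix: int, sum_extend: int):
    # Configs cover batch sizes 1-8; larger batches reuse the bs=8 entry.
    bs = min(bs, 8)
    # A real config would match the nearest tuned (prefix, extend) point;
    # for brevity this sketch requires an exact key.
    return TUNED_THRESHOLDS.get((bs, sum_prefix, sum_extend))

print(lookup_threshold(16, 2048, 3328))  # 4593762, via the bs=8 fallback
```

The bs clamp keeps the table small while still giving a usable (if less precise) threshold for large batches.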

Accuracy Tests

# gsm8k
# flashinfer
Accuracy: 0.954
Invalid: 0.000
Latency: 337.475 s
Output throughput: 375.755 token/s

# fa3
Accuracy: 0.955
Invalid: 0.000
Latency: 318.394 s
Output throughput: 401.848 token/s
# mmlu
# flashinfer
subject: abstract_algebra, #q:100, acc: 0.720
subject: anatomy, #q:135, acc: 0.852
subject: astronomy, #q:152, acc: 0.941
subject: business_ethics, #q:100, acc: 0.880
subject: clinical_knowledge, #q:265, acc: 0.928
subject: college_biology, #q:144, acc: 0.972
subject: college_chemistry, #q:100, acc: 0.630
subject: college_computer_science, #q:100, acc: 0.830
subject: college_mathematics, #q:100, acc: 0.770
subject: college_medicine, #q:173, acc: 0.879
subject: college_physics, #q:102, acc: 0.824
subject: computer_security, #q:100, acc: 0.880
subject: conceptual_physics, #q:235, acc: 0.923
subject: econometrics, #q:114, acc: 0.746
subject: electrical_engineering, #q:145, acc: 0.876
subject: elementary_mathematics, #q:378, acc: 0.937
subject: formal_logic, #q:126, acc: 0.794
subject: global_facts, #q:100, acc: 0.660
subject: high_school_biology, #q:310, acc: 0.961
subject: high_school_chemistry, #q:203, acc: 0.862
subject: high_school_computer_science, #q:100, acc: 0.950
subject: high_school_european_history, #q:165, acc
subject: high_school_geography, #q:198, acc: 0.965
subject: high_school_government_and_politics, #q:193, acc: 0.984
subject: high_school_macroeconomics, #q:390, acc: 0.915
subject: high_school_mathematics, #q:270, acc: 0.748
subject: high_school_microeconomics, #q:238, acc: 0.966
subject: high_school_physics, #q:151, acc: 0.841
subject: high_school_psychology, #q:545, acc: 0.972
subject: high_school_statistics, #q:216, acc: 0.856
subject: high_school_us_history, #q:204, acc: 0.956
subject: high_school_world_history, #q:237, acc: 0.937
subject: human_aging, #q:223, acc: 0.857
subject: human_sexuality, #q:131, acc: 0.931
subject: international_law, #q:121, acc: 0.942
subject: jurisprudence, #q:108, acc: 0.917
subject: logical_fallacies, #q:163, acc: 0.920
subject: machine_learning, #q:112, acc: 0.777
subject: management, #q:103, acc: 0.922
subject: marketing, #q:234, acc: 0.949
subject: medical_genetics, #q:100, acc: 0.950
subject: miscellaneous, #q:783, acc: 0.954
subject: moral_disputes, #q:346, acc: 0.876
subject: moral_scenarios, #q:895, acc: 0.782
subject: nutrition, #q:306, acc: 0.912
subject: philosophy, #q:311, acc: 0.900
subject: prehistory, #q:324, acc: 0.944
subject: professional_accounting, #q:282, acc: 0.869
subject: professional_law, #q:1534, acc: 0.702
subject: professional_medicine, #q:272, acc: 0.960
subject: professional_psychology, #q:612, acc: 0.913
subject: public_relations, #q:110, acc: 0.827
subject: security_studies, #q:245, acc: 0.886
subject: sociology, #q:201, acc: 0.965
subject: us_foreign_policy, #q:100, acc: 0.920
subject: virology, #q:166, acc: 0.584
subject: world_religions, #q:171, acc: 0.924
Total latency: 538.643
Average accuracy: 0.871

# fa3
subject: abstract_algebra, #q:100, acc: 0.750
subject: anatomy, #q:135, acc: 0.859
subject: astronomy, #q:152, acc: 0.941
subject: business_ethics, #q:100, acc: 0.880
subject: clinical_knowledge, #q:265, acc: 0.925
subject: college_biology, #q:144, acc: 0.958
subject: college_chemistry, #q:100, acc: 0.620
subject: college_computer_science, #q:100, acc: 0.840
subject: college_mathematics, #q:100, acc: 0.790
subject: college_medicine, #q:173, acc: 0.879
subject: college_physics, #q:102, acc: 0.824
subject: computer_security, #q:100, acc: 0.880
subject: conceptual_physics, #q:235, acc: 0.923
subject: econometrics, #q:114, acc: 0.754
subject: electrical_engineering, #q:145, acc: 0.876
subject: elementary_mathematics, #q:378, acc: 0.939
subject: formal_logic, #q:126, acc: 0.794
subject: global_facts, #q:100, acc: 0.680
subject: high_school_biology, #q:310, acc: 0.961
subject: high_school_chemistry, #q:203, acc: 0.862
subject: high_school_computer_science, #q:100, acc: 0.960
subject: high_school_european_history, #q:165, acc: 0.885
subject: high_school_geography, #q:198, acc: 0.965
subject: high_school_government_and_politics, #q:193, acc: 0.984
subject: high_school_macroeconomics, #q:390, acc: 0.921
subject: high_school_mathematics, #q:270, acc: 0.741
subject: high_school_microeconomics, #q:238, acc: 0.966
subject: high_school_physics, #q:151, acc: 0.848
subject: high_school_psychology, #q:545, acc: 0.969
subject: high_school_statistics, #q:216, acc: 0.866
subject: high_school_us_history, #q:204, acc: 0.951
subject: high_school_world_history, #q:237, acc: 0.941
subject: human_aging, #q:223, acc: 0.839
subject: human_sexuality, #q:131, acc: 0.931
subject: international_law, #q:121, acc: 0.959
subject: jurisprudence, #q:108, acc: 0.907
subject: logical_fallacies, #q:163, acc: 0.920
subject: machine_learning, #q:112, acc: 0.795
subject: management, #q:103, acc: 0.942
subject: marketing, #q:234, acc: 0.944
subject: medical_genetics, #q:100, acc: 0.960
subject: miscellaneous, #q:783, acc: 0.953
subject: moral_disputes, #q:346, acc: 0.879
subject: moral_scenarios, #q:895, acc: 0.770
subject: nutrition, #q:306, acc: 0.908
subject: philosophy, #q:311, acc: 0.897
subject: prehistory, #q:324, acc: 0.938
subject: professional_accounting, #q:282, acc: 0.876
subject: professional_law, #q:1534, acc: 0.705
subject: professional_medicine, #q:272, acc: 0.960
subject: professional_psychology, #q:612, acc: 0.913
subject: public_relations, #q:110, acc: 0.818
subject: security_studies, #q:245, acc: 0.898
subject: sociology, #q:201, acc: 0.975
subject: us_foreign_policy, #q:100, acc: 0.940
subject: virology, #q:166, acc: 0.584
subject: world_religions, #q:171, acc: 0.930
Total latency: 546.428
Average accuracy: 0.871

Benchmarking and Profiling

export SGL_ENABLE_JIT_DEEPGEMM=1
export TORCHINDUCTOR_CACHE_DIR=/home/admin/inductor_root_cache
export SGLANG_TORCH_PROFILER_DIR=/home/admin/torch_profiler

export SGL_CHUNKED_PREFIX_CACHE_USE_TUNED=1  # this PR
attn_backend=flashinfer  # or fa3

model_path=/home/models/deepseek-ai__DeepSeek-R1

python3 -m sglang.launch_server --model-path $model_path \
--host 0.0.0.0 --port 8000 --trust-remote-code \
--enable-cache-report --quantization fp8 --log-level info \
--max-running-requests 16 \
--mem-fraction-static 0.92 --chunked-prefill-size 8192 \
--context-length 65535 --chat-template /home/nas/r1.jinja \
--attention-backend ${attn_backend} \
--tp-size 8 --enable-metrics --cuda-graph-max-bs 16
input_len=1000  # 2000, 4000
python3 -m sglang.bench_serving --backend sglang --dataset-name random \
--random-input ${input_len} --random-output 1 --request-rate 1000 \
--num-prompt 500 --random-range-ratio 1 --max-concurrency 16 --port 8000 \
--dataset-path /home/ShareGPT_V3_unfiltered_cleaned_split.json

Request throughput (req/s) :

| Input Length | Before PR (flashinfer) | After PR (flashinfer) | Before PR (fa3) | After PR (fa3) |
|--------------|------------------------|-----------------------|-----------------|----------------|
| 1000 | 11.88 | 12.38 | 11.86 | 12.35 |
| 2000 | 5.54 | 6.25 | 5.55 | 6.26 |
| 4000 | 2.36 | 2.97 | 2.37 | 2.98 |

Checklist

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @xu-yfei, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the MHA chunked prefix feature by implementing an optimized configuration file approach for both fa3 and flashinfer attention backends. This optimization refines the strategy for selecting between Multi-Head Attention (MHA) and Multi-head Latent Attention (MLA), leading to notable improvements in Time To First Token (TTFT) across various input lengths. The changes include a new tuning script to generate these configurations and their integration into the system for dynamic, performance-driven backend selection.

Highlights

  • Optimized MHA Chunked Prefix: Introduced an optimized configuration file approach for the MHA chunked prefix feature, enhancing the decision-making between MHA and MLA backends.
  • Performance Improvement: Achieved significant reductions in Time To First Token (TTFT) during stress testing: 4% for 1000 input length, 11.3% for 2000, and 20.5% for 4000.
  • New Tuning Script: Added a new tuning script (tune_mha_chunk_prefix.py) to generate and validate these optimized configurations, ensuring adaptability to different hardware and workloads.
  • Dynamic Backend Selection: Integrated the tuned configurations into the system, allowing dynamic selection of MHA or MLA based on a calculated se (sum of extend * (extend + prefix)) threshold.
  • FlashInfer Backend Refinement: Streamlined the FlashInfer MLA backend by removing redundant use_ragged flags and simplifying prefill logic, contributing to cleaner and potentially more efficient code.

@xu-yfei
Contributor Author

xu-yfei commented Aug 24, 2025

@Fridge003 Could you please help review this PR?

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an optimized configuration-based approach for selecting between MHA and MLA attention mechanisms in the fa3 and flashinfer backends, aiming to reduce Time To First Token (TTFT). A tuning script is provided to generate these configurations. My review focuses on the correctness of the new tuning script and the impact of the changes on existing backend logic. I've identified a potential bug in the tuning script that may incorrectly restrict its usage on certain hardware, and a potential performance regression in the flashinfer backend due to the removal of an optimization path. Overall, the changes are well-structured and the data-driven approach to performance tuning is commendable.

@xu-yfei xu-yfei force-pushed the xyf/tune_mha branch 2 times, most recently from 1e3b6be to 0bb5127 on September 12, 2025 02:30
@Alcanderian
Collaborator

Alcanderian commented Sep 17, 2025

IMO the point is tuning the overhead and the performance scaling factor (accounting for fa kernel efficiency) of MLA and MHA for each device (O1/O2 for overheads, S1/S2 for scaling).
The real costs are then:

MHA + Extract KV:
extend * (extend + prefix) * 320 * 2 * S1 + prefix * 65536 * 2 + O1
MLA:
extend * (extend + prefix) * 1088 * 2 * S2 + O2

which, after simplifying with S = S1/S2 and O = O1 - O2, becomes:

MHA + Extract KV:
extend * (extend + prefix) * 320 * 2 * S + prefix * 65536 * 2 + O
MLA:
extend * (extend + prefix) * 1088 * 2
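The per-device scale and overhead proposed in the comment above could be recovered from measured timings with a simple least-squares fit. This is a hedged sketch on synthetic data: `fit_scale_overhead` is a hypothetical name, and the timings are fabricated to illustrate the fitting step, not measured on any device.

```python
import numpy as np

def fit_scale_overhead(flops, times):
    # Least-squares fit of the linear model: time ≈ S * flops + O.
    A = np.stack([flops, np.ones_like(flops)], axis=1)
    (S, O), *_ = np.linalg.lstsq(A, times, rcond=None)
    return S, O

# Synthetic timings for a device whose MHA kernel runs at scale S1 = 0.8
# with fixed overhead O1 = 50 (arbitrary time units). "flops" stands in for
# extend * (extend + prefix) * 320 * 2, pre-scaled to keep numbers small.
flops = np.array([1.0, 4.0, 16.0, 64.0])
times = 0.8 * flops + 50.0
S1, O1 = fit_scale_overhead(flops, times)
print(round(float(S1), 3), round(float(O1), 1))  # 0.8 50.0
```

Fitting S1/O1 and S2/O2 separately for MHA and MLA on a given device would then yield the simplified S and O of the comment's cost model.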

@xu-yfei
Contributor Author

xu-yfei commented Sep 26, 2025

@Fridge003 @Alcanderian This is replaced by PR #10953:
For the fa3 and flashinfer backends, when seq_lens ≤ 128K, the fused prefix and extend KV perform only one MHA computation. Performance is optimized through fused operators that avoid multiple copies and type conversions, so MHA performance is generally better than MLA. It is recommended to switch directly to MHA via SGL_CHUNKED_PREFIX_CACHE_THRESHOLD=0, which makes this PR unnecessary.
