
[WIP]Optimize TritonAttention with cache load#9778

Open
yuan-luo wants to merge 3 commits into sgl-project:main from antgroup:opt_triton_attn

Conversation

Collaborator

@yuan-luo yuan-luo commented Aug 29, 2025

Motivation

In Triton, the `cache_modifier` argument to `tl.load` specifies the cache policy for a memory load. `.cg` ("cache global") is a memory-access hint provided by Triton that lowers to the PTX `ld.global.cg` instruction, caching loads at the global (L2) level. This PR applies it to the K/V buffer loads (e.g. `tl.load(ptrs, mask=mask, cache_modifier=".cg")`), allowing the loaded data to better utilize the GPU's cache hierarchy and thereby improving performance.

This PR also fixes the regression in the Triton attention SWA unit test that was introduced by an earlier code refactor.

The benchmark results show a 17-20% speedup for Triton attention without a sliding window (WINDOW_SIZE=-1 in the tables below). With 4k input and 1.5k output, end-to-end TTFT drops by 3.6%.

Main:
$python ./benchmark/kernels/sliding_window_attention_triton/bench_triton_swa_kernel.py
extend_attention_triton_vs_torch:
     N_CTX  WINDOW_SIZE       torch      triton
0   1024.0         -1.0   13.313979    3.400983
1   1024.0        127.0   13.998181    2.839544
2   1024.0        256.0   13.671337    3.277899
3   1024.0        512.0   13.484036    3.724898
4   2048.0         -1.0   37.498367   12.688069
5   2048.0        127.0   39.144402    8.056528
6   2048.0        256.0   38.922625    9.115718
7   2048.0        512.0   38.468863   10.897383
8   4096.0         -1.0  136.244034   49.277617
9   4096.0        127.0  141.462463   24.554945
10  4096.0        256.0  143.137283   26.608192
11  4096.0        512.0  140.915131   30.392427
12  8192.0         -1.0  404.721649  185.647873
13  8192.0        127.0  423.228149   79.989822
14  8192.0        256.0  421.111450   84.352928
15  8192.0        512.0  419.148163   92.420738

============================================================

This PR:
$python ./benchmark/kernels/sliding_window_attention_triton/bench_triton_swa_kernel.py
extend_attention_triton_vs_torch:
     N_CTX  WINDOW_SIZE       torch      triton
0   1024.0         -1.0   13.216082    2.986481
1   1024.0        127.0   13.694473    2.850151
2   1024.0        256.0   13.617353    3.294159
3   1024.0        512.0   13.488297    3.724024
4   2048.0         -1.0   37.403521   10.948128
5   2048.0        127.0   39.125616    8.041440
6   2048.0        256.0   39.439072    9.110515
7   2048.0        512.0   38.453457   10.890365
8   4096.0         -1.0  136.164291   41.194319
9   4096.0        127.0  141.763000   24.516776
10  4096.0        256.0  140.950470   26.599690
11  4096.0        512.0  139.955719   30.417290
12  8192.0         -1.0  403.730072  148.042236
13  8192.0        127.0  421.817413   79.692383
14  8192.0        256.0  420.936523   84.381027
15  8192.0        512.0  419.162079   92.444572
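As a sanity check (not part of the original PR), the speedup quoted above can be recomputed from the WINDOW_SIZE=-1 rows of the two tables; the ratio grows with context length, from roughly 1.14x at 1k tokens to 1.25x at 8k:

```python
# Speedup ratios for the WINDOW_SIZE=-1 rows, taken from the triton
# column of the two tables above (main vs. this PR, times in us).
main_us = [3.400983, 12.688069, 49.277617, 185.647873]  # N_CTX 1024..8192
pr_us = [2.986481, 10.948128, 41.194319, 148.042236]
speedups = [m / p for m, p in zip(main_us, pr_us)]
for n_ctx, s in zip([1024, 2048, 4096, 8192], speedups):
    print(f"N_CTX={n_ctx}: {s:.2f}x")
```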

Modifications

Accuracy Tests

GSM8K throughput and accuracy:
$python3 -m sglang.launch_server --model /home/admin/Qwen3-30B-A3B --tp-size 4 --port 30000 --attention-backend triton --disable-radix-cache

PR:
$python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1000 --parallel 1000
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:58<00:00, 17.22it/s]
Accuracy: 0.707
Invalid: 0.001
Latency: 58.318 s
Output throughput: 4511.574 token/s

Main:
$python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1000 --parallel 1000
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:58<00:00, 17.03it/s]
Accuracy: 0.706
Invalid: 0.000
Latency: 58.919 s
Output throughput: 4485.227 token/s
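A quick parity check on the two GSM8K runs above: accuracy is unchanged within run-to-run noise, and output throughput is marginally higher with the PR.

```python
# Accuracy/throughput comparison between this PR and main, from the
# two GSM8K runs above.
pr = {"acc": 0.707, "tok_s": 4511.574}
main = {"acc": 0.706, "tok_s": 4485.227}
print(f"accuracy delta: {pr['acc'] - main['acc']:+.3f}")          # within noise
print(f"throughput gain: {(pr['tok_s'] / main['tok_s'] - 1) * 100:.2f}%")
```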

Benchmarking and Profiling

4k input, 1.5k output, triton backend e2e TTFT reduce 3.6%.

$python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 100 --random-input-len 4000 --random-output-len 1500 --random-range-ratio 1

TritonAttention (main):
============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 not set   
Successful requests:                     100       
Benchmark duration (s):                  29.57     
Total input tokens:                      400000    
Total generated tokens:                  150000    
Total generated tokens (retokenized):    150000    
Request throughput (req/s):              3.38      
Input token throughput (tok/s):          13527.04  
Output token throughput (tok/s):         5072.64   
Total token throughput (tok/s):          18599.68  
Concurrency:                             99.92     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   29545.74  
Median E2E Latency (ms):                 29546.03  
---------------Time to First Token----------------
Mean TTFT (ms):                          3704.15   
Median TTFT (ms):                        3742.19   
P99 TTFT (ms):                           6994.59   
---------------Inter-Token Latency----------------
Mean ITL (ms):                           17.24     
Median ITL (ms):                         14.72     
P95 ITL (ms):                            16.82     
P99 ITL (ms):                            19.25     
Max ITL (ms):                            6604.85   
==================================================

TritonAttention (this PR):
============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 not set   
Successful requests:                     100       	
Benchmark duration (s):                  29.03     
Total input tokens:                      400000    
Total generated tokens:                  150000    
Total generated tokens (retokenized):    150000    
Request throughput (req/s):              3.45      
Input token throughput (tok/s):          13781.14  
Output token throughput (tok/s):         5167.93   
Total token throughput (tok/s):          18949.07  
Concurrency:                             99.90     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   28996.56  
Median E2E Latency (ms):                 28996.46  
---------------Time to First Token----------------
Mean TTFT (ms):                          3570.24   
Median TTFT (ms):                        3605.76   
P99 TTFT (ms):                           6689.81   
---------------Inter-Token Latency----------------
Mean ITL (ms):                           16.96     
Median ITL (ms):                         14.82     
P95 ITL (ms):                            16.61     
P99 ITL (ms):                            18.46     
Max ITL (ms):                            6042.77   
==================================================
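The 3.6% TTFT reduction quoted earlier follows directly from the two mean-TTFT figures reported above:

```python
# TTFT improvement from the two serving benchmark runs above (mean TTFT, ms).
main_ttft, pr_ttft = 3704.15, 3570.24
reduction_pct = (main_ttft - pr_ttft) / main_ttft * 100
print(f"TTFT reduction: {reduction_pct:.1f}%")
```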

Checklist

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @yuan-luo, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on optimizing the TritonAttention kernel by leveraging Triton's memory caching mechanism. The changes aim to improve performance by enabling efficient data loading and better utilization of GPU caches, resulting in notable speed improvements in attention computations.

Highlights

  • Performance Enhancement: Introduced ".cg" cache modifier to tl.load operations within the Triton attention kernel, optimizing memory access and data loading for improved GPU cache utilization, leading to significant speedups (17-20% for Triton Attention without window_size).
  • Kernel Configuration Adjustment: Increased the num_stages parameter from 1 to 2 in the extend_attention_fwd function, potentially enhancing pipeline efficiency for the Triton kernel.
  • Test Infrastructure Update: Modified test/srt/test_swa_unittest.py to reflect changes in the SWARadixCache import path and updated the SWATokenToKVPoolAllocator constructor call.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@yuan-luo yuan-luo changed the title Optimize TritonAttention Optimize TritonAttention with cache load Aug 29, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant performance optimization to the Triton attention kernel by enabling caching for key/value loads from the buffer, which is well-supported by the provided benchmark results. The change to increase num_stages is also a good complementary optimization. My review includes suggestions to extend this caching strategy to other tensor loads within the kernel, which could potentially yield further performance gains.


  grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
- num_stages = 1
+ num_stages = 2
Collaborator


QQ: why did you change num_stages?

Collaborator Author


Increasing the number of pipeline stages gives Triton's software pipeliner more room to rearrange instructions, so different parts of the window computation can proceed in parallel. With num_stages=2, loading the next column block can overlap with computing the current column block, which increases throughput.
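That overlap can be illustrated with a pure-Python double-buffering sketch; `load` and `compute` here are hypothetical stand-ins for the kernel's `tl.load` and dot/softmax steps (on a GPU the prefetch actually runs concurrently with compute; this sketch only shows the ordering):

```python
# Double-buffering sketch of num_stages=2 software pipelining.
def load(block):
    return [x * 2 for x in block]   # stand-in for tl.load with cache hint

def compute(tile):
    return sum(tile)                # stand-in for the dot/softmax step

blocks = [[1, 2], [3, 4], [5, 6]]
next_tile = load(blocks[0])         # prologue: prefetch the first tile
acc = 0
for i in range(len(blocks)):
    tile = next_tile
    if i + 1 < len(blocks):
        next_tile = load(blocks[i + 1])  # overlaps with compute on a GPU
    acc += compute(tile)
print(acc)  # same result as the unpipelined loop
```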

@ispobock
Collaborator

ispobock commented Sep 2, 2025

This related CI test failed: https://github.com/sgl-project/sglang/actions/runs/17392068796/job/49367795438?pr=9778#step:5:545

  File "/sglang-checkout/python/sglang/srt/layers/attention/triton_backend.py", line 748, in forward_extend
    self.extend_attention_fwd(
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 751, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/sglang-checkout/python/sglang/srt/layers/attention/triton_ops/extend_attention.py", line 471, in extend_attention_fwd
    _fwd_kernel[grid](
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 330, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 653, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
    ^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 402, in __getattribute__
    self._init_handles()
  File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 395, in _init_handles
    raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 212992, Hardware limit: 65536. Reducing block sizes or `num_stages` may help.
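A quick look at the reported numbers (a rough estimate, assuming the staged K/V buffers scale roughly linearly with num_stages; actual usage also depends on block sizes and dtype):

```python
# Shared memory reported by the failing CI run vs. the hardware limit.
required, limit = 212992, 65536
print(f"over budget by {required / limit:.2f}x")  # 3.25x

# If the staged buffers scale ~linearly with num_stages, dropping from
# 2 stages back to 1 would roughly halve the requirement -- still above
# the 64 KiB limit, so block sizes matter on this hardware too.
print(required / 2 > limit)
```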

@yuan-luo yuan-luo changed the title Optimize TritonAttention with cache load [WIP]Optimize TritonAttention with cache load Sep 5, 2025
