Support MLA in Torch Native Attention Backend#3475

Open
YangQun1 wants to merge 11 commits into sgl-project:main from YangQun1:dev/torch-native-attn-mla-support

Conversation

@YangQun1
Contributor

Motivation

Modifications

Checklist

  • Format your code according to the Code Formatting with Pre-Commit.
  • Add unit tests as outlined in the Running Unit Tests.
  • Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
  • Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
  • For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.

@YangQun1 YangQun1 force-pushed the dev/torch-native-attn-mla-support branch from e18840e to 69f6898 Compare February 11, 2025 00:27
@YangQun1 YangQun1 marked this pull request as ready for review February 11, 2025 00:35
@YangQun1 YangQun1 force-pushed the dev/torch-native-attn-mla-support branch from ba44fb1 to fd8b47b Compare February 11, 2025 00:48
@YangQun1 YangQun1 force-pushed the dev/torch-native-attn-mla-support branch from f150931 to 78f6c02 Compare February 11, 2025 12:30
@YangQun1
Contributor Author

Hi @ispobock, could you help review this?

@ispobock
Collaborator

Could you fix the PR test and provide some benchmark data vs. the previous version?

@YangQun1 YangQun1 force-pushed the dev/torch-native-attn-mla-support branch from 503ba06 to 77cccf5 Compare February 13, 2025 01:43
@YangQun1
Contributor Author

Hi @ispobock, the failed test seems unrelated to this PR's changes. Is there any way to retrigger the failed test to work around the flaky error?

FAIL: test_gsm8k (__main__.TestW8A8)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/public_sglang_ci/runner-b-gpu-67/_work/sglang/sglang/test/srt/test_w8a8_quantization.py", line 45, in test_gsm8k
    self.assertGreater(metrics["accuracy"], 0.7)
AssertionError: 0.68 not greater than 0.7

@YangQun1
Contributor Author

YangQun1 commented Feb 13, 2025

Performance comparison using TestTorchNativeAttnBackend::test_latency:

  • before change:
Warmup ...
Prefill. latency: 0.09715 s, throughput:   1317.49 token/s
Decode.  latency: 0.02727 s, throughput:     36.67 token/s
Decode.  latency: 0.02193 s, throughput:     45.60 token/s
Decode.  latency: 0.02186 s, throughput:     45.74 token/s
Decode.  latency: 0.02186 s, throughput:     45.74 token/s
Decode.  latency: 0.02188 s, throughput:     45.70 token/s
Decode.  median latency: 0.02188 s, median throughput:     45.70 token/s
Total. latency:  0.256 s, throughput:    531.85 token/s
Benchmark ...
Prefill. latency: 0.02665 s, throughput:   4803.48 token/s
Decode.  latency: 0.02181 s, throughput:     45.86 token/s
Decode.  latency: 0.02164 s, throughput:     46.22 token/s
Decode.  latency: 0.02168 s, throughput:     46.12 token/s
Decode.  latency: 0.02165 s, throughput:     46.18 token/s
Decode.  latency: 0.02164 s, throughput:     46.22 token/s
Decode.  median latency: 0.02168 s, median throughput:     46.12 token/s
Total. latency:  0.179 s, throughput:    761.47 token/s
  • after change:
Warmup ...
Prefill. latency: 0.09651 s, throughput:   1326.30 token/s
Decode.  latency: 0.90841 s, throughput:      1.10 token/s
Decode.  latency: 0.02256 s, throughput:     44.32 token/s
Decode.  latency: 0.02215 s, throughput:     45.14 token/s
Decode.  latency: 0.02197 s, throughput:     45.52 token/s
Decode.  latency: 0.02222 s, throughput:     45.00 token/s
Decode.  median latency: 0.02222 s, median throughput:     45.00 token/s
Total. latency:  1.138 s, throughput:    119.49 token/s
Benchmark ...
Prefill. latency: 0.02040 s, throughput:   6275.52 token/s
Decode.  latency: 0.02217 s, throughput:     45.11 token/s
Decode.  latency: 0.02213 s, throughput:     45.19 token/s
Decode.  latency: 0.02201 s, throughput:     45.43 token/s
Decode.  latency: 0.02194 s, throughput:     45.57 token/s
Decode.  latency: 0.02206 s, throughput:     45.34 token/s
Decode.  median latency: 0.02206 s, median throughput:     45.34 token/s
Total. latency:  0.175 s, throughput:    778.87 token/s

It seems the decode performance is impacted; I will investigate.

@YangQun1
Contributor Author

Compared to the latest main branch, the decode performance shows no obvious gap, and the prefill performance is improved.

  • main branch:
Warmup ...
Prefill. latency: 0.10436 s, throughput:   1226.48 token/s
Decode.  latency: 0.92203 s, throughput:      1.08 token/s
Decode.  latency: 0.02368 s, throughput:     42.22 token/s
Decode.  latency: 0.02328 s, throughput:     42.96 token/s
Decode.  latency: 0.02306 s, throughput:     43.37 token/s
Decode.  latency: 0.02324 s, throughput:     43.02 token/s
Decode.  median latency: 0.02328 s, median throughput:     42.96 token/s
Total. latency:  1.166 s, throughput:    116.63 token/s
Benchmark ...
Prefill. latency: 0.02907 s, throughput:   4403.07 token/s
Decode.  latency: 0.02334 s, throughput:     42.84 token/s
Decode.  latency: 0.02322 s, throughput:     43.06 token/s
Decode.  latency: 0.02301 s, throughput:     43.47 token/s
Decode.  latency: 0.02306 s, throughput:     43.36 token/s
Decode.  latency: 0.02321 s, throughput:     43.09 token/s
Decode.  median latency: 0.02320 s, median throughput:     43.10 token/s
Total. latency:  0.191 s, throughput:    711.22 token/s
  • this PR:
Warmup ...
Prefill. latency: 0.11434 s, throughput:   1119.49 token/s
Decode.  latency: 0.91263 s, throughput:      1.10 token/s
Decode.  latency: 0.02263 s, throughput:     44.18 token/s
Decode.  latency: 0.02238 s, throughput:     44.69 token/s
Decode.  latency: 0.02231 s, throughput:     44.81 token/s
Decode.  latency: 0.02233 s, throughput:     44.78 token/s
Decode.  median latency: 0.02238 s, median throughput:     44.69 token/s
Total. latency:  1.161 s, throughput:    117.10 token/s
Benchmark ...
Prefill. latency: 0.02052 s, throughput:   6238.77 token/s
Decode.  latency: 0.02227 s, throughput:     44.91 token/s
Decode.  latency: 0.02228 s, throughput:     44.88 token/s
Decode.  latency: 0.02226 s, throughput:     44.92 token/s
Decode.  latency: 0.02230 s, throughput:     44.84 token/s
Decode.  latency: 0.02238 s, throughput:     44.68 token/s
Decode.  median latency: 0.02227 s, median throughput:     44.90 token/s
Total. latency:  0.177 s, throughput:    770.39 token/s

@ispobock

@YangQun1 YangQun1 force-pushed the dev/torch-native-attn-mla-support branch 4 times, most recently from 7ce307f to 7cd5375 Compare February 13, 2025 12:46
@ispobock
Collaborator

Hi @YangQun1, I reviewed this PR, but I'm not sure how this change is related to MLA?

@YangQun1 YangQun1 force-pushed the dev/torch-native-attn-mla-support branch from 7cd5375 to 7ac8da3 Compare February 14, 2025 07:40
@YangQun1
Contributor Author

YangQun1 commented Feb 14, 2025

Hi @YangQun1, I reviewed this PR, but I'm not sure how this change is related to MLA?

With this PR, we can run the DeepSeek-V2-Lite model with the torch native backend without setting the --disable-mla flag.
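For reference, a launch command along these lines should exercise this path (a sketch only — the exact flags depend on your sglang version; `--attention-backend torch_native` is assumed from the backend name in this PR's title):

```
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V2-Lite \
  --attention-backend torch_native \
  --trust-remote-code
```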

@ispobock
Collaborator

With this PR, we can run the DeepSeek-V2-Lite model with the torch native backend without setting the --disable-mla flag.

Got it. This change is mainly for the forward_normal part, where the kv computed for attention differs from what is stored in the kv cache.
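To illustrate the point above with a toy sketch (this is not sglang's actual code — the names `C`, `W_uk`, `W_uv` and the tiny sizes are illustrative): in MLA the cache stores a small latent per token rather than full K/V, so the attention backend either reconstructs K/V from the latent before scoring (the naive torch-native path) or absorbs the up-projection into the query and scores against the latent directly. The two are algebraically identical:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_head, r_latent = 8, 16, 4    # toy sizes, not real model dims

# Hypothetical MLA-style state: cached latent C plus up-projections.
C = rng.standard_normal((seq_len, r_latent))   # what the kv cache holds
W_uk = rng.standard_normal((d_head, r_latent)) # latent -> key head dim
W_uv = rng.standard_normal((d_head, r_latent)) # latent -> value head dim
q = rng.standard_normal((1, d_head))           # single decode-step query

# Path 1: reconstruct full K from the latent, then score.
K = C @ W_uk.T                                 # (seq_len, d_head)
scores_naive = q @ K.T                         # (1, seq_len)

# Path 2: absorb W_uk into the query and score against the latent
# directly -- never materializes K: q (C W_uk^T)^T == (q W_uk) C^T.
scores_absorbed = (q @ W_uk) @ C.T             # (1, seq_len)
assert np.allclose(scores_naive, scores_absorbed)

# Softmax over scores, then value lookup (values reconstructed likewise).
w = np.exp(scores_naive - scores_naive.max())
w /= w.sum()
out = w @ (C @ W_uv.T)                         # (1, d_head)
print(out.shape)
```

This is why the kv handed to the attention kernel in forward_normal is not the same tensor that lives in the kv cache: the cache holds the compact latent, while attention consumes the up-projected (or weight-absorbed) form.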

@YangQun1 YangQun1 force-pushed the dev/torch-native-attn-mla-support branch 3 times, most recently from 2ade106 to 298fbed Compare February 17, 2025 02:57
@YangQun1
Contributor Author

Hi @ispobock, the CI tests passed — could you help merge?

@YangQun1 YangQun1 force-pushed the dev/torch-native-attn-mla-support branch 5 times, most recently from c3534a5 to adfdd51 Compare August 14, 2025 05:24
@yanbing-j yanbing-j force-pushed the dev/torch-native-attn-mla-support branch 2 times, most recently from ede4af2 to a48615d Compare August 25, 2025 04:48
@Alcanderian Alcanderian added ready-to-merge The PR is ready to merge after the CI is green. and removed ready-to-merge The PR is ready to merge after the CI is green. labels Aug 27, 2025
@YangQun1 YangQun1 force-pushed the dev/torch-native-attn-mla-support branch from a48615d to d123570 Compare September 18, 2025 12:35
@yanbing-j yanbing-j force-pushed the dev/torch-native-attn-mla-support branch from d123570 to 7b5cad9 Compare September 22, 2025 03:13
@yanbing-j yanbing-j force-pushed the dev/torch-native-attn-mla-support branch from 95bb18f to f29cf10 Compare September 22, 2025 05:01