
Conversation

@LucasWilkinson (Collaborator) commented Mar 14, 2025

With FlashMLA now the default on NVIDIA GPUs, and given that the alignment padding seemingly no longer helps the Triton backend (and may never have, which is a bit of a mystery), I think we can go ahead and rip this out and reclaim the memory lost to padding:

VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE=1 VLLM_USE_V1=0 VLLM_USE_FLASHINFER_SAMPLER=1 vllm serve /home/vllm-dev/DeepSeek-R1 --tensor-parallel-size 8 --trust-remote-code --disable-log-requests

INFO 03-14 18:57:59 [executor_base.py:111] # cuda blocks: 4739, # CPU blocks: 859
INFO 03-14 18:57:59 [executor_base.py:116] Maximum concurrency for 163840 tokens per request: 1.85x

Data Preview:
  backend  input_tokens  output_tokens  output_toks/s     req/s  median_itl_ms  median_ttft_ms
3    vllm          1000           1000     854.219899  0.854220      52.914861     3398.488862
2    vllm          5000           1000     712.867089  0.712867      50.757768    13790.549284
4    vllm         10000           1000     153.534461  0.153534     142.423406    19432.205255
1    vllm         32000           1000      49.799586  0.049800     142.251487   361420.859763


VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE=0 VLLM_USE_V1=0 VLLM_USE_FLASHINFER_SAMPLER=1 vllm serve /home/vllm-dev/DeepSeek-R1 --tensor-parallel-size 8 --trust-remote-code --disable-log-requests

INFO 03-14 20:43:52 [executor_base.py:111] # cuda blocks: 5266, # CPU blocks: 954
INFO 03-14 20:43:52 [executor_base.py:116] Maximum concurrency for 163840 tokens per request: 2.06x

Data Preview:
  backend  input_tokens  output_tokens  output_toks/s     req/s  median_itl_ms  median_ttft_ms
3    vllm          1000           1000     835.217510  0.835218      51.893893     4393.823143
2    vllm          5000           1000     736.631281  0.736631      51.266942    13791.848651
4    vllm         10000           1000     155.732726  0.155733     139.090127    20673.930987
1    vllm         32000           1000      55.991667  0.055992     139.235539   362350.173109

This PR:

INFO 03-14 21:50:28 [executor_base.py:111] # cuda blocks: 5266, # CPU blocks: 954
INFO 03-14 21:50:28 [executor_base.py:116] Maximum concurrency for 163840 tokens per request: 2.06x

Data Preview:
  backend  input_tokens  output_tokens  output_toks/s     req/s  median_itl_ms  median_ttft_ms
3    vllm          1000           1000     846.707105  0.846707      53.101287     3406.700477
0    vllm          5000              5       1.085574  0.217115      79.162302      848.571256
2    vllm          5000           1000     717.566549  0.717567      53.103782    13764.687492
4    vllm         10000           1000     154.981544  0.154982     140.511256    20659.451817
1    vllm         32000           1000      55.701930  0.055702     139.734178   364505.537944
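As a rough sanity check on where the reclaimed memory comes from: dropping the alignment padding shrinks each per-token cache entry, and the block counts above move by almost exactly that ratio. The snippet below is illustrative arithmetic only; the 576-element MLA latent (512 kv_lora_rank + 64 qk_rope_head_dim), the bf16 element size, and the 256-byte alignment are assumptions for the sketch, not values read out of the vLLM code.

# Back-of-the-envelope sketch (assumptions noted above).
ELEM_BYTES = 2            # bf16
LATENT_ELEMS = 512 + 64   # kv_lora_rank + qk_rope_head_dim = 576

def entry_bytes(aligned: bool, align: int = 256) -> int:
    raw = LATENT_ELEMS * ELEM_BYTES                # 1152 bytes per token
    return (raw + align - 1) // align * align if aligned else raw

padded, unpadded = entry_bytes(True), entry_bytes(False)
print(padded, unpadded, padded / unpadded)         # 1280 1152 ~1.111

# The observed block counts move by the same factor: 5266 / 4739 ~= 1.111,
# i.e. the padding was costing roughly 10% of KV-cache block capacity,
# which is what this change reclaims.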

Signed-off-by: Lucas Wilkinson <[email protected]>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Signed-off-by: Lucas Wilkinson <[email protected]>
LucasWilkinson added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Mar 14, 2025
@WoosukKwon (Collaborator)

@tylertitsworth @simon-mo Can you take a look? I'm a bit worried that we're making this kind of change back and forth.

@LucasWilkinson Can you please test the accuracy after this change?

@simon-mo (Collaborator)

AFAIK this is getting rid of previous complexity and simplifying the codebase.

simon-mo added this to the v0.8.0 milestone on Mar 15, 2025
@LucasWilkinson (Collaborator, Author)

VLLM_USE_V1=0 vllm serve /home/vllm-dev/DeepSeek-R1 --tensor-parallel-size 8 --trust-remote-code --disable-log-requests

lm_eval --model local-completions --tasks gsm8k --model_args model=/home/vllm-dev/DeepSeek-R1,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=5,max_retries=3,tokenized_requests=False --limit 100

local-completions (model=/home/vllm-dev/DeepSeek-R1,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=5,max_retries=3,tokenized_requests=False), gen_kwargs: (None), limit: 100.0, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.96|±  |0.0197|
|     |       |strict-match    |     5|exact_match|↑  | 0.96|±  |0.0197|

@WoosukKwon the goal with this and #14770 was to go back and audit a lot of the MLA code with an eye toward reducing complexity (as @simon-mo mentioned), since the original code was written in a bit of a rush. I definitely echo the sentiment that we should be biased towards not changing these kinds of things back and forth, so I'm definitely open to an extra pair of eyes on it. In this case specifically, I do think some unneeded complexity slipped in due to the rushed nature of the original MLA code (and it also makes sense to revisit it after the recent updates, namely FlashMLA).

simon-mo enabled auto-merge (squash) on March 15, 2025
simon-mo merged commit 5952d8a into vllm-project:main on Mar 15, 2025
44 checks passed
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
@yuchiwang

There is a bug for fp8 KV cache: model_config.dtype is bf16 but the cache dtype is int8. The key cache entry size for int8/fp8 should be 768, but it is now 640, which will cause OOM.
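For what it is worth, the two sizes in that report are consistent with the entry size being derived from the wrong element width. A minimal sketch of that arithmetic, assuming a 576-element MLA latent per token and 256-byte alignment (both assumptions for illustration, not the actual vLLM code path):

def padded_elems(elem_bytes: int, elems: int = 576, align: int = 256) -> int:
    # Pad the raw per-token entry to the alignment boundary, then convert back to elements.
    raw = elems * elem_bytes
    padded = (raw + align - 1) // align * align
    return padded // elem_bytes

print(padded_elems(elem_bytes=1))  # 768 -> aligned entry size for int8/fp8
print(padded_elems(elem_bytes=2))  # 640 -> aligned entry size for bf16

If the entry size is computed from model_config.dtype (bf16, giving 640) while the cache is actually stored as int8/fp8 (which needs 768 when aligned), the block budget would be computed against the smaller entry while the real allocation needs the larger one, which would be consistent with the OOM described above.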

shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025