Conversation

@luccafong luccafong commented Aug 26, 2025

Purpose

Support the MP executor for multi-node distributed inference when there is no Ray setup (compatible with DP hybrid/internal/external LB).

  • How to use:
    vllm serve "model_name" -tp <TP_size> -dp <DP_size> -pp <PP_size> --nnodes <num_nodes> --node-rank <node_rank> --master-addr <leader_host_ip> [--master-port <port>] [--headless (if not exposing an API endpoint)] [--data-parallel-external-lb (for external LB) or --data-parallel-hybrid-lb]
    See concrete examples in the test plan below.

Note for DP Compatibility:

  • We don't implement a new DP coordinator/backend in this PR; we reuse what we already have.
  • This PR supports auto-inference of --data-parallel-rank, --data-parallel-size-local, and --data-parallel-start-rank from the distributed node rank (--node-rank) for all 3 DP LB modes, and it behaves the same as passing them explicitly before (see the sketch after this list).
  • For engines that do not expose an API server (e.g. a DP engine leader or an external DP engine), we need to add --headless (it will auto-start a headless DP engine if the node is the leader (driver) of a DP group, or a headless executor instance if it hosts non-driver workers).
  • This is backward compatible with the existing data-parallel interface.
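
A rough sketch of that equivalence, using the DP=4 (external) x TP=2 case from the test plan below; the flags are the ones listed above, and the pre-PR invocation shown is illustrative only (an actual launch may need additional flags):

# Before (explicit DP flags): each external-LB rank is launched by hand, e.g. rank 1:
CUDA_VISIBLE_DEVICES=2,3 vllm serve "Qwen/Qwen3-1.7B" -dp=4 -tp=2 --data-parallel-rank 1 --data-parallel-external-lb --port 8001

# With this PR: the same rank is inferred from the node layout:
CUDA_VISIBLE_DEVICES=2,3 vllm serve "Qwen/Qwen3-1.7B" -dp=4 -tp=2 --nnodes 4 --node-rank 1 --master-addr 127.0.0.1 --data-parallel-external-lb --port 8001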

Architecture change

(architecture diagram image)

Test Plan

TP=4 (2 instances)

CUDA_VISIBLE_DEVICES=0,1 vllm serve "Qwen/Qwen3-1.7B" -tp=4 --max-model-len=32768 --nnodes 2 --node-rank 0 --master-addr 127.0.0.1 --port 8000 > /tmp/test_rank0.log 2>&1 &

CUDA_VISIBLE_DEVICES=2,3 vllm serve "Qwen/Qwen3-1.7B" -tp=4 --max-model-len=32768 --nnodes 2 --node-rank 1 --master-addr 127.0.0.1 --headless > /tmp/test_rank1.log 2>&1 &

PP=2 x TP=2 (2 instances)

CUDA_VISIBLE_DEVICES=0,1 vllm serve "Qwen/Qwen3-1.7B" -pp=2 -tp=2 --max-model-len=32768 --nnodes 2 --node-rank 0 --master-addr 127.0.0.1 --port 8000 > /tmp/test_rank0.log 2>&1 &

CUDA_VISIBLE_DEVICES=2,3 vllm serve "Qwen/Qwen3-1.7B" -pp=2 -tp=2 --max-model-len=32768 --nnodes 2 --node-rank 1 --master-addr 127.0.0.1 --headless > /tmp/test_rank1.log 2>&1 &

DP=2 (external) x TP=4 (2 GPUs per node, TP spans 2 nodes) (4 instances)

CUDA_VISIBLE_DEVICES=0,1 vllm serve "Qwen/Qwen3-1.7B" -dp=2 -tp=4 --max-model-len=32768 --nnodes 4 --node-rank 0 --master-addr 127.0.0.1 --port 8000 --data-parallel-external-lb > /tmp/test_dp0.log 2>&1 &

CUDA_VISIBLE_DEVICES=2,3 vllm serve "Qwen/Qwen3-1.7B" -dp=2 -tp=4 --max-model-len=32768 --nnodes 4 --node-rank 1 --master-addr 127.0.0.1 --headless > /tmp/test_dp1.log 2>&1 &

CUDA_VISIBLE_DEVICES=4,5 vllm serve "Qwen/Qwen3-1.7B" -dp=2 -tp=4 --max-model-len=32768 --nnodes 4 --node-rank 2 --master-addr 127.0.0.1 --port 8001 --data-parallel-external-lb > /tmp/test_dp2.log 2>&1 &

CUDA_VISIBLE_DEVICES=6,7 vllm serve "Qwen/Qwen3-1.7B" -dp=2 -tp=4 --max-model-len=32768 --nnodes 4 --node-rank 3 --master-addr 127.0.0.1 --headless > /tmp/test_dp3.log 2>&1 &

DP=2 (internal, with headless) x TP=4 (2 GPUs per node, TP spans 2 nodes) (4 instances)

CUDA_VISIBLE_DEVICES=0,1 vllm serve "Qwen/Qwen3-1.7B" -dp=2 -tp=4 --max-model-len=32768 --nnodes 4 --node-rank 0 --master-addr 127.0.0.1 --port 8000 > /tmp/test_dp0.log 2>&1 &

CUDA_VISIBLE_DEVICES=2,3 vllm serve "Qwen/Qwen3-1.7B" -dp=2 -tp=4 --max-model-len=32768 --nnodes 4 --node-rank 1 --master-addr 127.0.0.1 --headless > /tmp/test_dp1.log 2>&1 &

CUDA_VISIBLE_DEVICES=4,5 vllm serve "Qwen/Qwen3-1.7B" -dp=2 -tp=4 --max-model-len=32768 --nnodes 4 --node-rank 2 --master-addr 127.0.0.1 --headless > /tmp/test_dp2.log 2>&1 &

CUDA_VISIBLE_DEVICES=6,7 vllm serve "Qwen/Qwen3-1.7B" -dp=2 -tp=4 --max-model-len=32768 --nnodes 4 --node-rank 3 --master-addr 127.0.0.1 --headless > /tmp/test_dp3.log 2>&1 &

DP=4 (internal) x TP=2 (within a node) (4 instances)

CUDA_VISIBLE_DEVICES=0,1 vllm serve "Qwen/Qwen3-1.7B" -dp=4 -tp=2 --max-model-len=32768 --nnodes 4 --node-rank 0 --master-addr 127.0.0.1 --port 8000 > /tmp/test_dp0.log 2>&1 &

CUDA_VISIBLE_DEVICES=2,3 vllm serve "Qwen/Qwen3-1.7B" -dp=4 -tp=2 --max-model-len=32768 --nnodes 4 --node-rank 1 --master-addr 127.0.0.1 --headless > /tmp/test_dp1.log 2>&1 &

CUDA_VISIBLE_DEVICES=4,5 vllm serve "Qwen/Qwen3-1.7B" -dp=4 -tp=2 --max-model-len=32768 --nnodes 4 --node-rank 2 --master-addr 127.0.0.1 --headless > /tmp/test_dp2.log 2>&1 &

CUDA_VISIBLE_DEVICES=6,7 vllm serve "Qwen/Qwen3-1.7B" -dp=4 -tp=2 --max-model-len=32768 --nnodes 4 --node-rank 3 --master-addr 127.0.0.1 --headless > /tmp/test_dp3.log 2>&1 &

DP=4 (external) x TP=2 (within a node) (4 instances)

CUDA_VISIBLE_DEVICES=0,1 vllm serve "Qwen/Qwen3-1.7B" -dp=4 -tp=2 --max-model-len=32768 --nnodes 4 --node-rank 0 --master-addr 127.0.0.1 --port 8000 --data-parallel-external-lb > /tmp/test_dp0.log 2>&1 &

CUDA_VISIBLE_DEVICES=2,3 vllm serve "Qwen/Qwen3-1.7B" -dp=4 -tp=2 --max-model-len=32768 --nnodes 4 --node-rank 1 --master-addr 127.0.0.1 --port 8001 --data-parallel-external-lb > /tmp/test_dp1.log 2>&1 &

CUDA_VISIBLE_DEVICES=4,5 vllm serve "Qwen/Qwen3-1.7B" -dp=4 -tp=2 --max-model-len=32768 --nnodes 4 --node-rank 2 --master-addr 127.0.0.1 --port 8002 --data-parallel-external-lb > /tmp/test_dp2.log 2>&1 &

CUDA_VISIBLE_DEVICES=6,7 vllm serve "Qwen/Qwen3-1.7B" -dp=4 -tp=2 --max-model-len=32768 --nnodes 4 --node-rank 3 --master-addr 127.0.0.1 --port 8003 --data-parallel-external-lb > /tmp/test_dp3.log 2>&1 &

DP=4 hybrid (2 external x 2 internal) x TP=2 (2 instances)

CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve "Qwen/Qwen3-1.7B" -dp=4 -tp=2 --max-model-len=32768 --nnodes 2 --node-rank 0 --master-addr 127.0.0.1 --data-parallel-hybrid-lb > /tmp/test_dp0.log 2>&1 &

CUDA_VISIBLE_DEVICES=4,5,6,7 vllm serve "Qwen/Qwen3-1.7B" -dp=4 -tp=2 --max-model-len=32768 --nnodes 2 --node-rank 1 --master-addr 127.0.0.1 --data-parallel-hybrid-lb --port 8002 > /tmp/test_dp1.log 2>&1 &

DP=2 (internal) x PP=2 x TP=2 (across nodes) (8 instances)

for i in {0..7}; do
    HEADLESS=""
    if [ $i -gt 0 ]; then HEADLESS="--headless"; fi
    CUDA_VISIBLE_DEVICES=$i vllm serve Qwen/Qwen3-1.7B -dp=2 -pp=2 -tp=2 --max-model-len=32768 -n 8 -r $i --master-addr 127.0.0.1 --port 8000 $HEADLESS > /tmp/test_mp${i}.log 2>&1 &
done
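
As a quick sanity check (not part of the original test plan; this assumes the default OpenAI-compatible server endpoints on the API-serving rank), probe the rank-0 instance and tail the headless ranks' logs:

# rank 0 serves HTTP; confirm it is healthy and lists the served model
curl -sf http://127.0.0.1:8000/health && curl -s http://127.0.0.1:8000/v1/models
# headless ranks expose no HTTP endpoint; check their logs instead
tail -n 20 /tmp/test_mp1.log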

Test Result

Evals were run on all of the above combinations and are on par with the baseline; some examples are shown below.

TP Eval

lm_eval --model local-completions --tasks gsm8k --model_args model=Qwen/Qwen3-1.7B,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=200,max_retries=3,tokenized_requests=False --limit 1000
| Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr   |
|-------|---------|------------------|--------|-------------|-------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.681 | ± 0.0147 |
| gsm8k | 3       | strict-match     | 5      | exact_match | 0.678 | ± 0.0148 |

Baseline

| Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr   |
|-------|---------|------------------|--------|-------------|-------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.681 | ± 0.0147 |
| gsm8k | 3       | strict-match     | 5      | exact_match | 0.679 | ± 0.0148 |

TP Perf

vllm bench serve --model Qwen/Qwen3-1.7B  --port 8003  --dataset-name random  --ignore-eos  --num-prompts 512  --request-rate inf  --random-input-len 1024  --random-output-len 128  --max-concurrency 128
============ Serving Benchmark Result ============
Successful requests:                     512
Maximum request concurrency:             128
Benchmark duration (s):                  5.22
Total input tokens:                      522598
Total generated tokens:                  65536
Request throughput (req/s):              98.05
Output token throughput (tok/s):         12550.62
Total Token throughput (tok/s):          112631.95
---------------Time to First Token----------------
Mean TTFT (ms):                          225.67
Median TTFT (ms):                        235.68
P99 TTFT (ms):                           517.28
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          8.39
Median TPOT (ms):                        8.42
P99 TPOT (ms):                           9.68
---------------Inter-token Latency----------------

Baseline

============ Serving Benchmark Result ============
Successful requests:                     512
Maximum request concurrency:             128
Benchmark duration (s):                  5.17
Total input tokens:                      522598
Total generated tokens:                  65536
Request throughput (req/s):              99.08
Output token throughput (tok/s):         12682.43
Total Token throughput (tok/s):          113814.84
---------------Time to First Token----------------
Mean TTFT (ms):                          229.18
Median TTFT (ms):                        231.62
P99 TTFT (ms):                           510.13
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          8.25
Median TPOT (ms):                        8.33
P99 TPOT (ms):                           9.50
---------------Inter-token Latency----------------
Mean ITL (ms):                           8.25
Median ITL (ms):                         6.32
P99 ITL (ms):                            31.12
==================================================

TP x PP Eval

| Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr   |
|-------|---------|------------------|--------|-------------|-------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.689 | ± 0.0146 |
| gsm8k | 3       | strict-match     | 5      | exact_match | 0.686 | ± 0.0147 |

TP x PP Perf

Perf (Multi Instance)

============ Serving Benchmark Result ============
Successful requests:                     512
Maximum request concurrency:             128
Benchmark duration (s):                  7.90
Total input tokens:                      522598
Total generated tokens:                  65536
Request throughput (req/s):              64.77
Output token throughput (tok/s):         8291.08
Total Token throughput (tok/s):          74405.90
---------------Time to First Token----------------
Mean TTFT (ms):                          146.42
Median TTFT (ms):                        123.13
P99 TTFT (ms):                           406.15
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          14.30
Median TPOT (ms):                        14.43
P99 TPOT (ms):                           15.01
---------------Inter-token Latency----------------
Mean ITL (ms):                           14.30
Median ITL (ms):                         12.82
P99 ITL (ms):                            43.65
==================================================

Baseline (Single Instance)

============ Serving Benchmark Result ============
Successful requests:                     512
Maximum request concurrency:             128
Benchmark duration (s):                  7.76
Total input tokens:                      522598
Total generated tokens:                  65536
Request throughput (req/s):              66.02
Output token throughput (tok/s):         8450.47
Total Token throughput (tok/s):          75836.33
---------------Time to First Token----------------
Mean TTFT (ms):                          144.65
Median TTFT (ms):                        105.95
P99 TTFT (ms):                           406.44
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          14.02
Median TPOT (ms):                        14.20
P99 TPOT (ms):                           14.63
---------------Inter-token Latency----------------
Mean ITL (ms):                           14.02
Median ITL (ms):                         12.54
P99 ITL (ms):                            42.69
==================================================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the v1 label Aug 26, 2025
@facebook-github-bot

@luccafong has imported this pull request. If you are a Meta employee, you can view this in D81078278.

@luccafong luccafong force-pushed the lucia/non-ray-di branch 2 times, most recently from 68187de to 5493fee on August 26, 2025 at 22:23
@mergify

mergify bot commented Aug 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @luccafong.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify

mergify bot commented Aug 30, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @luccafong.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: Lu Fang <[email protected]>
@luccafong
Collaborator Author

@njhill @youkaichao I have addressed the comments and args, and retested all combinations with lm_eval.

Signed-off-by: Lu Fang <[email protected]>
Member

@njhill njhill left a comment

Thanks @luccafong. A few small things...

Signed-off-by: Lu Fang <[email protected]>
Collaborator

@houseroad houseroad left a comment

Offline synced with @youkaichao and @njhill, and aligned.

Let's land this for now.

@houseroad houseroad added the ready-for-merge and ready labels Nov 16, 2025
@houseroad houseroad enabled auto-merge (squash) November 16, 2025 05:16
@houseroad houseroad merged commit b316ac6 into vllm-project:main Nov 16, 2025
53 checks passed
bwasti pushed a commit to bwasti/vllm that referenced this pull request Nov 17, 2025
…roject#23691)

Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Signed-off-by: Lucia Fang <[email protected]>
Signed-off-by: Lucia Fang <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Bram Wasti <[email protected]>
@njhill njhill mentioned this pull request Nov 18, 2025
node_rank: int = 0
"""distributed node rank for multi-node distributed
inference when distributed_executor_backend is mp."""
nnodes: int = 1
Collaborator

I don't think it's a good idea to add nnodes as an attribute to the very general ParallelConfig class when it is only meaningful when distributed_executor_backend is mp. Many AI labs rely heavily on Ray, and when the distributed executor backend is ray, nnodes will always be wrong here as it'll stay 1 even if tensor_parallel is set to something like 16 on 8xH200.

It's not intuitive for self.vllm_config.parallel_config.nnodes to return 1 here for the ray backend, and it's not enough to just state in the comment that it's only for mp. IMO we need to make sure that nnodes always reports the correct number of nodes no matter the parallel backend. Similarly, node_rank also needs to make sense for ray. If it's impossible to put good values in for ray, we should at least rename it to something like nnodes_for_mp_backend.

Collaborator Author

@luccafong luccafong Nov 24, 2025

nnodes_for_mp_backend might be pretty hard to use; would setting them to None be a better way to resolve the conflict and confusion? We could then raise an error if a user sets it for ray. @patrickvonplaten

wangxiyuan added a commit to vllm-project/vllm-ascend that referenced this pull request Nov 26, 2025
Bump vLLM version to v0.11.2

What's broken and changed by vLLM:
1. structured_output is broken by
vllm-project/vllm#26866
2. get_mrope_input_positions is broken by
vllm-project/vllm#28399
3. graph mode is broken by
vllm-project/vllm#25110 we'll upgrade torch to
2.8 to fix the problem later
4. embedding is broken by
vllm-project/vllm#27583
5. `get_attn_backend_cls` and the attention backend are broken by
vllm-project/vllm#28534
6. spec decode is broken by
vllm-project/vllm#28771
7. sp feature is broken by
vllm-project/vllm#27126
8. mtp is broken by vllm-project/vllm#27922
9. lora is broken by vllm-project/vllm#21068
10. execute_model is broken by
vllm-project/vllm#26866
11. `VLLM_DISABLE_SHARED_EXPERTS_STREAM` env is broken by
vllm-project/vllm#28159
12. kv cache is broken by vllm-project/vllm#27753
13. dp is broken by vllm-project/vllm#25110

 
What's broken and changed by ourselves:
1. qwen vl is broken by vllm-project/vllm#28455
We'll remove model files in the future to avoid this kind of error
2. Engine core is broken by
vllm-project/vllm#23691 We'll remove the patch
file in the future.
3. Ascend scheduler is broken by
vllm-project/vllm#28733 We'll remove the ascend scheduler later.
4. qwen3-next is broken by
vllm-project/vllm#28083 We'll remove model files
in the future to avoid this kind of error
5. qwen vl is broken by vllm-project/vllm#27764.
We'll remove model files in the future

Known issue:
1. ray doesn't work 
2. the accuracy of qwen3-next is not correct
3. qwen3-vl is broken
4. prefix cache+ ascend scheduler + deepseek v2 lite is broken.

Co-authored-by: MengqingCao <[email protected]>
Co-authored-by: hfadzxy <[email protected]>
Co-authored-by: leo-pony <[email protected]>
Co-authored-by: 22dimensions <[email protected]>
Co-authored-by: shen-shanshan <[email protected]>


- vLLM version: v0.11.2

---------

Signed-off-by: wangxiyuan <[email protected]>
Signed-off-by: MengqingCao <[email protected]>
Signed-off-by: hfadzxy <[email protected]>
Signed-off-by: leo-pony <[email protected]>
Co-authored-by: MengqingCao <[email protected]>
Co-authored-by: hfadzxy <[email protected]>
Co-authored-by: leo-pony <[email protected]>
Kurumi5210 pushed a commit to lidenghui1110/vllm-ascend that referenced this pull request Nov 26, 2025
bringlein pushed a commit to bringlein/vllm that referenced this pull request Nov 26, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Nov 29, 2025
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
charlotte12l pushed a commit to charlotte12l/vllm that referenced this pull request Dec 5, 2025
Meihan-chen pushed a commit to Meihan-chen/vllm-ascend that referenced this pull request Dec 5, 2025
Zhathw pushed a commit to Zhathw/vllm that referenced this pull request Dec 6, 2025
Labels

frontend · ready (ONLY add when PR is ready to merge/full CI is needed) · ready-for-merge (Indicate this PR is ready to be merged by the maintainers, used by reviewers without merge access) · v1
