
Conversation

@juliendenize
Contributor

@juliendenize juliendenize commented Nov 17, 2025

Purpose

This is a hotfix to allow launching vLLM on multiple nodes with Ray.

To reproduce on a 2x8 (nodes x GPUs) Ray cluster:

vllm serve mistralai/Mistral-Small-3.2-24B-Instruct-2506   --tokenizer_mode mistral --config_format mistral   --load_format mistral --tool-call-parser mistral   --enable-auto-tool-choice --limit-mm-per-prompt '{"image":10}'   --tensor-parallel-size 16   --max_model_len 65536   --max_num_seqs 128   --distributed-executor-backend ray --enforce_eager

Error:

(EngineCore_DP0 pid=570282)   File "/home/julien.denize/workspace/vllm/vllm/v1/worker/gpu_worker.py", line 212, in init_device
(EngineCore_DP0 pid=570282)     assert self.parallel_config.local_world_size <= visible_device_count, (
(EngineCore_DP0 pid=570282)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=570282) AssertionError: local_world_size (16) must be less than or equal to the number of visible devices (8).
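
For context, a minimal sketch of the arithmetic behind the failing assertion, assuming `local_world_size` ends up reflecting the global worker count rather than the per-node one (the variable names below are illustrative, not vLLM internals):

```python
# Illustrative arithmetic only, not vLLM code.
tensor_parallel_size = 16              # workers across the whole 2x8 Ray cluster
nodes, gpus_per_node = 2, 8
visible_device_count = gpus_per_node   # what torch.cuda.device_count() reports on one node

# If local_world_size is taken as the global worker count, the check fails:
local_world_size = tensor_parallel_size               # 16
print(local_world_size <= visible_device_count)       # False -> the AssertionError above

# The per-node worker count is what the check presumably intends:
local_world_size = tensor_parallel_size // nodes      # 8
print(local_world_size <= visible_device_count)       # True
```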

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Julien Denize <[email protected]>
@mergify mergify bot added the v1 label Nov 17, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to fix a crash when running vLLM with Ray on multiple nodes. The crash comes from an assertion requiring that the number of workers on a node not exceed the number of visible GPUs. The fix moves this assertion inside an if block that is skipped for Ray, which resolves the immediate issue. However, my review finds the change too broad: it also disables this important sanity check for other valid configurations, which could lead to other crashes. I've provided a critical comment explaining the issue and suggesting a more targeted fix.

Comment on lines +209 to +216
visible_device_count = (
    torch.cuda.device_count() if torch.cuda.is_available() else 0
)
assert self.parallel_config.local_world_size <= visible_device_count, (
    f"local_world_size ({self.parallel_config.local_world_size}) must "
    f"be less than or equal to the number of visible devices "
    f"({visible_device_count})."
)
Contributor


critical

This change moves the assertion for local_world_size inside the if block that handles a specific single-node data parallelism setup. While this fixes the issue for multi-node Ray where local_world_size might be miscalculated, it incorrectly disables this important sanity check for other valid configurations, such as multi-node setups without Ray or single-node setups without data parallelism.

The assertion self.parallel_config.local_world_size <= visible_device_count is a general check to ensure that the number of workers on a node does not exceed the number of available GPUs. It should not be confined to the specific data parallelism case.

A more targeted fix would be to skip this check only for Ray, or to fix the underlying issue with the calculation of local_world_size for Ray environments. Disabling this check for all other configurations could hide potential resource allocation issues and lead to crashes in other scenarios.
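
To make that suggestion concrete, a loose sketch of what a Ray-only exemption could look like; the attribute name `distributed_executor_backend` and the `"ray"` value are assumptions based on vLLM's CLI flags, not the actual patch:

```python
import torch


def check_local_world_size(parallel_config) -> None:
    # Hypothetical targeted guard: keep the sanity check for every backend
    # except Ray, where local_world_size is currently miscalculated.
    if getattr(parallel_config, "distributed_executor_backend", None) == "ray":
        return
    visible_device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
    assert parallel_config.local_world_size <= visible_device_count, (
        f"local_world_size ({parallel_config.local_world_size}) must be less "
        f"than or equal to the number of visible devices ({visible_device_count})."
    )
```

Fixing the calculation of local_world_size for Ray placement groups would, of course, make even such a guard unnecessary.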

@njhill
Member

njhill commented Nov 18, 2025

Thanks @juliendenize... looks like this was introduced by #23691? cc @luccafong

@njhill njhill changed the title Hotfix: ray with multiple nodes [BugFix] Ray with multiple nodes Nov 18, 2025
@njhill njhill requested a review from luccafong November 18, 2025 03:52
@youkaichao youkaichao enabled auto-merge (squash) November 18, 2025 13:53
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 18, 2025
Collaborator

@luccafong luccafong left a comment


lgtm, thanks for the fix!

@youkaichao youkaichao merged commit cdeec2e into vllm-project:main Nov 19, 2025
45 checks passed
khluu pushed a commit that referenced this pull request Nov 19, 2025
Signed-off-by: Julien Denize <[email protected]>
(cherry picked from commit cdeec2e)
Victor49152 pushed a commit to Victor49152/vllm that referenced this pull request Nov 20, 2025
LuminolT pushed a commit to LuminolT/vllm that referenced this pull request Nov 21, 2025
bigPYJ1151 pushed a commit that referenced this pull request Nov 25, 2025
Signed-off-by: Julien Denize <[email protected]>
Signed-off-by: jiang1.li <[email protected]>
bringlein pushed a commit to bringlein/vllm that referenced this pull request Nov 26, 2025
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Nov 27, 2025
…sible device count error (#4457)

### What this PR does / why we need it?
Fix the Ray startup failure where the "local_world_size must be less than or
equal to the number of visible devices" assertion fails; see issue #4456 for
details.

The fix is ported from the corresponding vLLM change, PR
[#28873](vllm-project/vllm#28873).


- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: leo-pony <[email protected]>
@ortegaalfredo

Still not working for me:

(APIServer pid=3167402) Value error, Tensor parallel size (10) cannot be larger than the number of available GPUs (8). [type=value_error, input_value=ArgsKwargs((), {'pipeline...'_api_process_rank': 0}), input_type=ArgsKwargs]
(APIServer pid=3167402) For further information visit https://errors.pydantic.dev/2.12/v/value_error

However, this is the output of ray status showing 12 GPUs available:

ray status
======== Autoscaler status: 2025-11-28 07:01:38.452724 ========
Node status

Active:
1 node_5bc3651f81e58183465936ce830e1197646fbeca8a72b261d79a0b17
1 node_6e0cdeabd00bec7a6039f1745ddbeb6aa2a6592fcd4feb7e971b2bdb
Pending:
(no pending nodes)
Recent failures:
(no failures)

Resources

Total Usage:
0.0/40.0 CPU
0.0/12.0 GPU
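
As a diagnostic aid (not part of this patch), a minimal sketch for cross-checking what the running Ray cluster actually exposes before launching `vllm serve`; whether vLLM's tensor-parallel validation consults the Ray cluster or only the local node is not confirmed in this thread:

```python
import ray

ray.init(address="auto")                 # attach to the already-running Ray cluster
resources = ray.cluster_resources()      # aggregate resources across all nodes
print("Cluster GPUs:", resources.get("GPU", 0))
for node in ray.nodes():                 # per-node view of available GPUs
    print(node["NodeManagerAddress"], node["Resources"].get("GPU", 0))
```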

devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
@devonthomas35

Still not working for me either

ChenCangtao pushed a commit to ChenCangtao/vllm-ascend that referenced this pull request Dec 3, 2025
Mercykid-bash pushed a commit to Mercykid-bash/vllm-ascend that referenced this pull request Dec 4, 2025
Mercykid-bash pushed a commit to Mercykid-bash/vllm-ascend that referenced this pull request Dec 4, 2025
charlotte12l pushed a commit to charlotte12l/vllm that referenced this pull request Dec 5, 2025
Signed-off-by: Julien Denize <[email protected]>
Signed-off-by: Xingyu Liu <[email protected]>
Meihan-chen pushed a commit to Meihan-chen/vllm-ascend that referenced this pull request Dec 5, 2025
Zhathw pushed a commit to Zhathw/vllm that referenced this pull request Dec 6, 2025

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1
