[Ray backend] Better error when pg topology is bad. #7584
New test file added by this PR (all 75 lines are new):

@@ -0,0 +1,75 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.

Run:
```sh
cd $VLLM_PATH/tests
pytest distributed/test_multi_node.py
```
"""
import os

import pytest
import ray
from ray.cluster_utils import Cluster

from vllm.utils import cuda_device_count_stateless

TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")


@pytest.mark.skipif(cuda_device_count_stateless() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model, distributed_executor_backend, test_suite", [
    ("facebook/opt-125m", "ray", "L4"),
])
def test_multi_node_bad_topology(
    vllm_runner,
    model: str,
    distributed_executor_backend: str,
    test_suite: str,
) -> None:
    """Verify that Ray + multi-node with a bad topology raises an exception.

    This test simulates a multi-node Ray cluster, so we don't have to start
    2 real nodes. There are 2 potential bad cases:
    - The engine's node doesn't have enough GPUs.
    - The tensor parallel size exceeds the GPUs available in the current node.
    """
    dtype = "half"
    assert test_suite == TARGET_TEST_SUITE

    # Simulate a 2-node cluster, 1 GPU each.
    cluster = Cluster()
    head_node = cluster.add_node(num_cpus=8, num_gpus=1, resources={"head": 1})
    ray.init(address=head_node.address)
    cluster.add_node(num_cpus=8, num_gpus=1)

    # Create tp == 2. Since the TP workers have to be spread across 2 nodes,
    # it should log a warning.
    with vllm_runner(
            model,
            dtype=dtype,
            tensor_parallel_size=2,
            distributed_executor_backend=distributed_executor_backend) as _:
        pass

    # Simulate that there's no GPU left in the current node.
    @ray.remote(num_gpus=1, resources={"head": 1})
    class Actor:
        pass

    # The actor is created on the head node and occupies its only GPU.
    a = Actor.remote()  # type: ignore
    ray.get(a.__ray_ready__.remote())

    # Now the vLLM engine is created on the head node, but there's no GPU
    # left. It should raise an exception.
    with pytest.raises(RuntimeError), vllm_runner(
            model,
            dtype=dtype,
            tensor_parallel_size=1,
            distributed_executor_backend=distributed_executor_backend) as _:
        pass
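For context, `ray.cluster_utils.Cluster` (used above) emulates extra nodes by starting additional raylet processes on the local machine, and the resource counts it registers are logical, so the multi-node topology itself can be reproduced without real hardware (the test still needs 2 physical GPUs because it actually loads a model). A minimal standalone sketch of that pattern, under those assumptions:

```python
# Standalone sketch: fake a 2-node Ray cluster on one machine.
# The resource numbers are logical, so no physical GPUs are required here.
import ray
from ray.cluster_utils import Cluster

cluster = Cluster()
head = cluster.add_node(num_cpus=1, num_gpus=1, resources={"head": 1})
ray.init(address=head.address)
cluster.add_node(num_cpus=1, num_gpus=1)  # second "node" on the same machine

print(len(ray.nodes()))         # 2 entries, one per simulated node
print(ray.cluster_resources())  # roughly {"CPU": 2.0, "GPU": 2.0, "head": 1.0, ...}

ray.shutdown()
cluster.shutdown()
```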
Second file in the diff (the Ray backend utility module; its exact path is not shown in this extract):

@@ -1,4 +1,10 @@
-from typing import List, Optional, Tuple, Union
+import time
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union
+
+from ray._private.state import available_resources_per_node
+from ray.util import placement_group_table
+from ray.util.placement_group import PlacementGroup
 
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger

@@ -8,6 +14,7 @@
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
+PG_WAIT_TIMEOUT = 1800
 
 try:
     import ray
@@ -83,6 +90,88 @@ def assert_ray_available():
             "`pip install ray`.") from ray_import_err
 
 
+def _verify_bundles(placement_group: PlacementGroup,
+                    parallel_config: ParallelConfig):
+    """Verify a given placement group has bundles located in the right place.
+
+    There are 2 rules.
+    - Warn if all tensor parallel workers cannot fit in a single node.
+    - Fail if the driver node is not included in the placement group.
+    """
+    assert ray.is_initialized(), (
+        "Ray is not initialized although distributed-executor-backend is ray.")
+    pg_data = placement_group_table(placement_group)
+    # bundle_idx -> node_id
+    bundle_to_node_ids = pg_data["bundles_to_node_id"]
+    # bundle_idx -> bundle (e.g., {"GPU": 1})
+    bundles = pg_data["bundles"]
+    # node_id -> List of bundle (e.g., {"GPU": 1})
+    node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)
+
+    for bundle_idx, node_id in bundle_to_node_ids.items():
+        node_id_to_bundle[node_id].append(bundles[bundle_idx])
+    driver_node_id = ray.get_runtime_context().get_node_id()
+
+    if driver_node_id not in node_id_to_bundle:
+        raise RuntimeError(
+            f"driver node id {driver_node_id} is not included in the "
+            f"placement group {placement_group.id}. Node id -> bundles "
+            f"{node_id_to_bundle}. "
+            "You don't have enough GPUs available in the current node. Check "
+            "`ray status` to see if you have available GPUs in the node "
+            f"{driver_node_id} before starting a vLLM engine.")
+
+    for node_id, bundles in node_id_to_bundle.items():
+        if len(bundles) < parallel_config.tensor_parallel_size:
+            logger.warning(
+                "tensor_parallel_size=%d is bigger than the number of GPUs "
+                "reserved on node %s (%d GPUs). Tensor parallel workers can "
+                "be spread across 2 or more nodes, which can degrade "
+                "performance. To resolve this issue, make sure you have more "
+                "than %d GPUs available at each node.",
+                parallel_config.tensor_parallel_size, node_id, len(bundles),
+                parallel_config.tensor_parallel_size)
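To make the check above concrete, here is a rough sketch of the data `_verify_bundles` works with; the node ids and values are made up, and only the `bundles` and `bundles_to_node_id` field names come from the code:

```python
# Illustrative only: a placement group of two 1-GPU bundles that Ray happened
# to place on two different nodes (node ids are made up).
pg_data = {
    "bundles": {0: {"GPU": 1.0}, 1: {"GPU": 1.0}},
    "bundles_to_node_id": {0: "node-aaa", 1: "node-bbb"},
}

# _verify_bundles inverts bundles_to_node_id into node_id -> [bundle, ...]:
#   {"node-aaa": [{"GPU": 1.0}], "node-bbb": [{"GPU": 1.0}]}
# then (1) fails if the driver's node id is missing from the keys, and
# (2) warns if any node holds fewer bundles than tensor_parallel_size
# (e.g. tensor_parallel_size=2 here would trigger the warning).
node_id_to_bundle = {}
for bundle_idx, node_id in pg_data["bundles_to_node_id"].items():
    node_id_to_bundle.setdefault(node_id, []).append(
        pg_data["bundles"][bundle_idx])
print(node_id_to_bundle)
```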
+def _wait_until_pg_ready(current_placement_group: PlacementGroup):
+    """Wait until a placement group is ready.
+
+    It logs informative messages while the placement group has not yet been
+    created, and raises if it is not ready within PG_WAIT_TIMEOUT seconds.
+    """
+    # Wait until PG is ready - this will block until all
+    # requested resources are available, and will timeout
+    # if they cannot be provisioned.
+    placement_group_specs = current_placement_group.bundle_specs
+
+    s = time.time()
+    ref = current_placement_group.ready()
+    wait_interval = 10
+    while time.time() - s < PG_WAIT_TIMEOUT:
+        ready, _ = ray.wait([ref], timeout=wait_interval)
+        if len(ready) > 0:
+            break
+
+        # Exponential backoff for the progress log.
+        wait_interval *= 2
+        logger.info(
+            "Waited %d seconds for the placement group to be created. "
+            "specs=%s. Check `ray status` to see if you have enough "
+            "resources.", int(time.time() - s), placement_group_specs)
+
+    try:
+        ray.get(current_placement_group.ready(), timeout=0)
+    except ray.exceptions.GetTimeoutError:
+        raise ValueError(
+            "Cannot provide a placement group of "
+            f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
+            "`ray status` to make sure the cluster has enough resources."
+        ) from None
+
+
 def initialize_ray_cluster(
     parallel_config: ParallelConfig,
     ray_address: Optional[str] = None,
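The loop above leans on `ray.wait`, which returns the subset of refs that completed within the timeout, so the function can interleave progress logs with a potentially long wait. A standalone sketch of the same pattern, with a hypothetical `slow_task` standing in for `placement_group.ready()`:

```python
# Standalone sketch of the wait-with-backoff pattern used in
# _wait_until_pg_ready (assumption: a local Ray instance, no GPUs needed).
import time

import ray

ray.init()


@ray.remote
def slow_task():
    time.sleep(15)  # stand-in for a placement group that takes a while
    return "ready"


ref = slow_task.remote()
start, wait_interval, timeout = time.time(), 2, 60
while time.time() - start < timeout:
    # ray.wait returns (ready_refs, pending_refs) after at most wait_interval s.
    ready, _ = ray.wait([ref], timeout=wait_interval)
    if ready:
        break
    wait_interval *= 2  # exponential backoff between progress messages
    print(f"still waiting after {int(time.time() - start)}s ...")
print(ray.get(ref))
ray.shutdown()
```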
@@ -141,15 +230,32 @@ def initialize_ray_cluster(
             f"The number of required {device_str}s exceeds the total "
             f"number of available {device_str}s in the placement group.")
         # Create a new placement group
-        placement_group_specs = ([{
-            device_str: 1
-        }] * parallel_config.world_size)
+        placement_group_specs: List[Dict[str, float]] = ([{
+            device_str: 1.0
+        } for _ in range(parallel_config.world_size)])
+
+        # The vLLM engine is also a worker that executes the model with an
+        # accelerator, so it requires the device to be present in the
+        # current node. Check if the current node has at least one device.
+        current_ip = get_ip()
+        current_node_id = ray.get_runtime_context().get_node_id()
+        current_node_resource = available_resources_per_node()[current_node_id]
+        if current_node_resource.get(device_str, 0) < 1:
+            raise ValueError(
+                f"Current node has no {device_str} available. "
+                f"{current_node_resource=}. vLLM engine cannot start without "
+                f"{device_str}. Make sure you have at least 1 {device_str} "
+                f"available in the node {current_node_id=} {current_ip=}.")
+        # This way, at least one bundle is required to be created on the
+        # current node.
+        placement_group_specs[0][f"node:{current_ip}"] = 0.001

Review comment on the line above:
Reviewer (Member): Is this the key to making sure the current node is included in the placement group?
Author: Yes, that's right. Unfortunately, I couldn't find another clean way to support this.
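A standalone sketch of the node-affinity trick confirmed above: Ray automatically exposes a `node:<ip>` resource with quantity 1.0 on every node, so requesting a tiny slice of it in a bundle pins that bundle to that node without consuming anything meaningful. The snippet uses `ray.util.get_node_ip_address()` in place of vLLM's `get_ip()` and a CPU-only bundle so it runs without GPUs:

```python
# Minimal sketch of pinning a placement group bundle to the current node.
import ray

ray.init()
ip = ray.util.get_node_ip_address()

# Requesting a tiny amount of the current node's "node:<ip>" resource forces
# the first bundle (and therefore the driver/engine worker placed in it) onto
# this node; this is what placement_group_specs[0][f"node:{current_ip}"] = 0.001
# does in the diff above.
pg = ray.util.placement_group([{"CPU": 1.0, f"node:{ip}": 0.001}],
                              strategy="PACK")
ray.get(pg.ready())
print(ray.util.placement_group_table(pg)["bundles_to_node_id"])
```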
+        # By default, Ray packs resources as much as possible.
         current_placement_group = ray.util.placement_group(
-            placement_group_specs)
-        # Wait until PG is ready - this will block until all
-        # requested resources are available, and will timeout
-        # if they cannot be provisioned.
-        ray.get(current_placement_group.ready(), timeout=1800)
+            placement_group_specs, strategy="PACK")
+        _wait_until_pg_ready(current_placement_group)
 
+    assert current_placement_group is not None
+    _verify_bundles(current_placement_group, parallel_config)
     # Set the placement group in the parallel config
     parallel_config.placement_group = current_placement_group
Review comment:
Please clean up the tests and remove unnecessary code. This test does not need TARGET_TEST_SUITE, I think.