1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -307,6 +307,7 @@ steps:
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
- pytest -v -s distributed/test_multi_node_topology.py

- label: Pipeline Parallelism Test # 23min
working_dir: "/vllm-workspace/tests"
75 changes: 75 additions & 0 deletions tests/distributed/test_multi_node_topology.py
@@ -0,0 +1,75 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
Run:
```sh
cd $VLLM_PATH/tests
pytest distributed/test_multi_node.py
```
"""
import os

import pytest
import ray
from ray.cluster_utils import Cluster

from vllm.utils import cuda_device_count_stateless

TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
Review comment (Member):
please clean up the tests and remove unnecessary code. this test does not need TARGET_TEST_SUITE I think.



@pytest.mark.skipif(cuda_device_count_stateless() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model, distributed_executor_backend, test_suite", [
("facebook/opt-125m", "ray", "L4"),
])
def test_multi_node_bad_topology(
Review comment (youkaichao, Member, Aug 19, 2024):
I think you should test it with a 2-node setup (where the Ray cluster is already created), try to launch the test with 1 GPU from both the head node and the worker node, and make sure the vLLM instance is scheduled on the same node as the process that launched it.

Reply (Collaborator, Author):
hmm yeah, I think we need a real 2-node cluster if we want to use node IP resources. Let me fix it.

vllm_runner,
model: str,
distributed_executor_backend: str,
test_suite: str,
) -> None:
"""Verify ray + multi node's bad topology raises an exception.
This test simulates multi node ray cluster, so we don't have to start
real 2 multi nodes.
There are 2 potential bad issues.
- the engine's node doesn't have enough GPUs.
- the tensor parallel size exceeds the available GPUs in a current node.
"""
dtype = "half"
assert test_suite == TARGET_TEST_SUITE

    # Simulate a 2-node cluster, 1 GPU each.
cluster = Cluster()
head_node = cluster.add_node(num_cpus=8, num_gpus=1, resources={"head": 1})
Review comment (Member):
why num_cpus=8 here? is 8 some magic number? We don't have any guarantee on the number of GPUs we have for this 2-GPU test.

Reply (Collaborator, Author):
Ray doesn't require real hardware CPUs when specifying resources like this; 8 is just an arbitrary number. (I think technically num_cpus=1 should also work.)
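To illustrate the reply above, here is a minimal sketch (not part of the PR) showing that the resource counts passed to `Cluster.add_node` are logical scheduling hints rather than hardware probes:

```python
import ray
from ray.cluster_utils import Cluster

# Declare a simulated node with 1 CPU and 1 GPU. No physical GPU is needed;
# Ray only uses these numbers for scheduling decisions.
cluster = Cluster()
head = cluster.add_node(num_cpus=1, num_gpus=1)
ray.init(address=head.address)

print(ray.cluster_resources())  # reports {'CPU': 1.0, 'GPU': 1.0, ...}

ray.shutdown()
cluster.shutdown()
```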

ray.init(address=head_node.address)
cluster.add_node(num_cpus=8, num_gpus=1)

    # Create tp == 2. Since the TP workers would be spread across the 2 nodes,
    # this should log a warning.
with vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
distributed_executor_backend=distributed_executor_backend) as _:
pass

    # Simulate that there's no GPU available on the current node.
@ray.remote(num_gpus=1, resources={"head": 1})
class Actor:
pass

    # `a` is created on the head node and takes its only GPU.
a = Actor.remote() # type: ignore
ray.get(a.__ray_ready__.remote())

    # Now vLLM is created on the head node, but there is no GPU left. It
    # should raise an exception.
with pytest.raises(RuntimeError), vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=1,
distributed_executor_backend=distributed_executor_backend) as _:
pass
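The reviewer's suggestion above (launch with 1 GPU from both the head node and a worker node, then check where the engine lands) boils down to comparing node IDs. A rough sketch of that check, with an illustrative task name `report_node_id` that is not part of the PR:

```python
import ray


@ray.remote(num_cpus=1)
def report_node_id() -> str:
    # Runs inside the Ray worker process.
    return ray.get_runtime_context().get_node_id()


ray.init()
driver_node = ray.get_runtime_context().get_node_id()
worker_node = ray.get(report_node_id.remote())
# On a real multi-node cluster, this assertion is the interesting part.
assert worker_node == driver_node, "worker was scheduled on a different node"
```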
124 changes: 115 additions & 9 deletions vllm/executor/ray_utils.py
@@ -1,4 +1,10 @@
from typing import List, Optional, Tuple, Union
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

from ray._private.state import available_resources_per_node
from ray.util import placement_group_table
from ray.util.placement_group import PlacementGroup

from vllm.config import ParallelConfig
from vllm.logger import init_logger
@@ -8,6 +14,7 @@
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
PG_WAIT_TIMEOUT = 1800

try:
import ray
@@ -83,6 +90,88 @@ def assert_ray_available():
"`pip install ray`.") from ray_import_err


def _verify_bundles(placement_group: PlacementGroup,
parallel_config: ParallelConfig):
"""Verify a given placement group has bundles located in the right place.

    There are two rules:
    - Warn if all tensor parallel workers cannot fit in a single node.
    - Fail if the driver node is not included in the placement group.
"""
assert ray.is_initialized(), (
"Ray is not initialized although distributed-executor-backend is ray.")
pg_data = placement_group_table(placement_group)
# bundle_idx -> node_id
bundle_to_node_ids = pg_data["bundles_to_node_id"]
# bundle_idx -> bundle (e.g., {"GPU": 1})
bundles = pg_data["bundles"]
# node_id -> List of bundle (e.g., {"GPU": 1})
node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)

for bundle_idx, node_id in bundle_to_node_ids.items():
node_id_to_bundle[node_id].append(bundles[bundle_idx])
driver_node_id = ray.get_runtime_context().get_node_id()

    if driver_node_id not in node_id_to_bundle:
        raise RuntimeError(
            f"driver node id {driver_node_id} is not included in the "
            f"placement group {placement_group.id}. Node id -> bundles "
            f"{node_id_to_bundle}. "
            "You don't have enough GPUs available on the current node. Check "
            "`ray status` to see if you have available GPUs on node "
            f"{driver_node_id} before starting a vLLM engine.")

    for node_id, bundles in node_id_to_bundle.items():
        if len(bundles) < parallel_config.tensor_parallel_size:
            logger.warning(
                "tensor_parallel_size=%d is bigger than the number of GPUs "
                "reserved on node %s (%d GPUs). Tensor parallel workers "
                "will be spread across 2 or more nodes, which can degrade "
                "performance. To resolve this issue, make sure at least "
                "%d GPUs are available on each node.",
                parallel_config.tensor_parallel_size, node_id, len(bundles),
                parallel_config.tensor_parallel_size)


def _wait_until_pg_ready(current_placement_group: PlacementGroup):
"""Wait until a placement group is ready.

    It prints informative log messages if the placement group is
    not created in time.
    """
# Wait until PG is ready - this will block until all
# requested resources are available, and will timeout
# if they cannot be provisioned.
placement_group_specs = current_placement_group.bundle_specs

s = time.time()
ref = current_placement_group.ready()
wait_interval = 10
while time.time() - s < PG_WAIT_TIMEOUT:
ready, _ = ray.wait([ref], timeout=wait_interval)
if len(ready) > 0:
break

# Exponential backoff for warning print.
wait_interval *= 2
        logger.info(
            "Waiting %d seconds for the placement group to be created. "
            "specs=%s. Check `ray status` to see if you have enough "
            "resources.", int(time.time() - s), placement_group_specs)

try:
ray.get(current_placement_group.ready(), timeout=0)
except ray.exceptions.GetTimeoutError:
raise ValueError(
"Cannot provide a placement group of "
f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
"`ray status` to make sure the cluster has enough resources."
) from None


def initialize_ray_cluster(
parallel_config: ParallelConfig,
ray_address: Optional[str] = None,
@@ -141,15 +230,32 @@ def initialize_ray_cluster(
f"The number of required {device_str}s exceeds the total "
f"number of available {device_str}s in the placement group.")
# Create a new placement group
placement_group_specs = ([{
device_str: 1
}] * parallel_config.world_size)
placement_group_specs: List[Dict[str, float]] = ([{
device_str: 1.0
} for _ in range(parallel_config.world_size)])

    # The vLLM engine is also a worker that executes the model with an
    # accelerator, so it requires the device to be available on the current
    # node. Check that the current node has at least one device.
current_ip = get_ip()
current_node_id = ray.get_runtime_context().get_node_id()
current_node_resource = available_resources_per_node()[current_node_id]
    if current_node_resource.get(device_str, 0) < 1:
        raise ValueError(
            f"Current node has no {device_str} available. "
            f"{current_node_resource=}. vLLM engine cannot start without "
            f"{device_str}. Make sure you have at least 1 {device_str} "
            f"available on node {current_node_id=} {current_ip=}.")
    # This way, at least one bundle is required to be created on the current
    # node.
placement_group_specs[0][f"node:{current_ip}"] = 0.001
Review comment (Member):
is this the key to making sure the current node is included in the placement group?

Reply (Collaborator, Author):
yes, that's right. I couldn't find another clean way to support this, unfortunately...
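For context, a minimal sketch (separate from the diff) of the mechanism being discussed: Ray exposes a `node:<ip>` custom resource for every node, so requesting even a tiny fraction of it pins a bundle to that node. The bundle below uses a CPU instead of a GPU only so the sketch runs anywhere:

```python
import ray

ray.init()
driver_ip = ray.util.get_node_ip_address()

# Requesting 0.001 of the driver's `node:<ip>` resource forces this bundle
# onto the driver's node, mirroring what the diff does for the first bundle.
bundle = {"CPU": 1.0, f"node:{driver_ip}": 0.001}
pg = ray.util.placement_group([bundle], strategy="PACK")
ray.get(pg.ready())
```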


# By default, Ray packs resources as much as possible.
current_placement_group = ray.util.placement_group(
placement_group_specs)
# Wait until PG is ready - this will block until all
# requested resources are available, and will timeout
# if they cannot be provisioned.
ray.get(current_placement_group.ready(), timeout=1800)
placement_group_specs, strategy="PACK")
_wait_until_pg_ready(current_placement_group)

assert current_placement_group is not None
_verify_bundles(current_placement_group, parallel_config)
# Set the placement group in the parallel config
parallel_config.placement_group = current_placement_group
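As a rough sketch of the data `_verify_bundles` consumes (the `"bundles"` and `"bundles_to_node_id"` keys follow the usage in the diff above; everything else is illustrative):

```python
from collections import defaultdict
from typing import Dict, List

import ray
from ray.util import placement_group_table

ray.init()
pg = ray.util.placement_group([{"CPU": 1.0}, {"CPU": 1.0}], strategy="PACK")
ray.get(pg.ready())

pg_data = placement_group_table(pg)
bundles = pg_data["bundles"]                    # bundle_idx -> {"CPU": 1.0}
bundle_to_node = pg_data["bundles_to_node_id"]  # bundle_idx -> node_id

# Group bundles by the node they were placed on, as _verify_bundles does.
node_id_to_bundles: Dict[str, List[Dict[str, float]]] = defaultdict(list)
for bundle_idx, node_id in bundle_to_node.items():
    node_id_to_bundles[node_id].append(bundles[bundle_idx])

# On a single-node cluster, both bundles map to the same node id.
print(dict(node_id_to_bundles))
```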