vllm/executor/ray_gpu_executor.py (41 additions, 49 deletions)
@@ -2,6 +2,7 @@
 import os
 import pickle
 from collections import defaultdict
+from itertools import islice, repeat
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
 
 from vllm.engine.ray_utils import RayWorkerWrapper, ray
@@ -136,16 +137,14 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         VLLM_INSTANCE_ID = get_vllm_instance_id()
 
         # Set environment variables for the driver and workers.
-        all_args_to_update_environment_variables = []
-        for (node_id, _) in worker_node_and_gpu_ids:
-            all_args_to_update_environment_variables.append([{
-                "CUDA_VISIBLE_DEVICES":
-                ",".join(map(str, node_gpus[node_id])),
-                "VLLM_INSTANCE_ID":
-                VLLM_INSTANCE_ID,
-                "VLLM_TRACE_FUNCTION":
-                os.getenv("VLLM_TRACE_FUNCTION", "0"),
-            }])
+        all_args_to_update_environment_variables = [({
+            "CUDA_VISIBLE_DEVICES":
+            ",".join(map(str, node_gpus[node_id])),
+            "VLLM_INSTANCE_ID":
+            VLLM_INSTANCE_ID,
+            "VLLM_TRACE_FUNCTION":
+            os.getenv("VLLM_TRACE_FUNCTION", "0"),
+        }, ) for (node_id, _) in worker_node_and_gpu_ids]
         self._run_workers("update_environment_variables",
                           all_args=all_args_to_update_environment_variables)
 
@@ -156,10 +155,9 @@ def collect_arg_helper_func(**kwargs):
             # avoid writing `{"name": value}` manually
             return kwargs
 
-        init_worker_all_kwargs = []
-
         # Initialize the actual workers inside worker wrapper.
-        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ):
+        init_worker_all_kwargs = []
+        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
             local_rank = node_workers[node_id].index(rank)
             init_worker_all_kwargs.append(
                 collect_arg_helper_func(
@@ -265,40 +263,40 @@ def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any]] = None,
+        driver_args: Optional[Tuple[Any, ...]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
-        all_args: Optional[List[List[Any]]] = None,
+        all_args: Optional[List[Tuple[Any, ...]]] = None,
         all_kwargs: Optional[List[Dict[str, Any]]] = None,
         use_dummy_driver: bool = False,
         max_concurrent_workers: Optional[int] = None,
Comment on lines 263 to 271 (Member):

BTW, it's better to document the several usages in the docstring:

  1. only args/kwargs: all workers share the same args
  2. args/kwargs + driver_args/driver_kwargs: all workers except for the driver share the same args
  3. all_args/all_kwargs: the args/kwargs for each worker are specified individually

use_dummy_driver is orthogonal to the above; it indicates whether to use the dummy driver worker.

As this function gets more and more complicated, I'm thinking of simplifying the use cases. Maybe only 1 and 3 should be allowed.
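
To make the three patterns concrete, here is a minimal sketch; `executor`, `num_ray_workers`, and the `set_seed` worker method are hypothetical stand-ins, and only the `_run_workers` signature above comes from the diff:

# 1. Only args/kwargs: the driver and every worker receive seed=0.
executor._run_workers("set_seed", seed=0)

# 2. args/kwargs plus driver_kwargs: the driver receives seed=0,
#    while all other workers receive seed=1.
executor._run_workers("set_seed", seed=1, driver_kwargs={"seed": 0})

# 3. all_args/all_kwargs: one entry per worker, with the driver's entry first.
executor._run_workers(
    "set_seed",
    all_kwargs=[{"seed": rank} for rank in range(1 + num_ray_workers)])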

         use_ray_compiled_dag: bool = False,
         **kwargs,
     ) -> Any:
-        """Runs the given method on all workers.
-        all_args and all_kwargs are used to pass heterogeneous arguments,
-        i.e. different arguments for each worker.
+        """Runs the given method on all workers. Can be used in the following
+        ways:
+
+        - args/kwargs: All workers share the same args/kwargs
+        - args/kwargs and driver_args/driver_kwargs: Driver worker has
+          different args
+        - all_args/all_kwargs: args/kwargs for each worker are specified
+          individually
         """
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        # for mypy type checking
-        assert driver_args is not None
-        assert driver_kwargs is not None
-        if all_args is None:
-            all_args = [driver_args] + [args] * len(self.workers)
-        if all_kwargs is None:
-            all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers)
-
-        # for mypy type checking
-        assert all_args is not None
-        assert all_kwargs is not None
 
         if max_concurrent_workers:
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
+        if driver_args is None:
+            driver_args = args if all_args is None else all_args[0]
+        if driver_kwargs is None:
+            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
+
+        count = len(self.workers)
+        all_worker_args = repeat(args, count) if all_args is None \
+            else islice(all_args, 1, None)
+        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
+            else islice(all_kwargs, 1, None)
+
         if use_ray_compiled_dag:
             # Right now, compiled DAG can only accept a single
             # input. TODO(sang): Fix it.
@@ -310,22 +308,17 @@ def _run_workers(
                 worker.execute_method.remote(method, *worker_args,
                                              **worker_kwargs)
                 for (worker, worker_args, worker_kwargs
-                     ) in zip(self.workers, all_args[1:], all_kwargs[1:])
+                     ) in zip(self.workers, all_worker_args, all_worker_kwargs)
             ]
 
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
         # Start the driver worker after all the ray workers.
         if not use_dummy_driver:
             driver_worker_output = self.driver_worker.execute_method(
-                method, *all_args[0], **all_kwargs[0])
+                method, *driver_args, **driver_kwargs)
         else:
             driver_worker_output = ray.get(
                 self.driver_dummy_worker.execute_method.remote(
-                    method, *all_args[0], **all_kwargs[0]))
+                    method, *driver_args, **driver_kwargs))
         # Get the results of the ray workers.
         if self.workers:
             if use_ray_compiled_dag:
@@ -383,6 +376,10 @@ def _check_if_any_actor_is_dead(self):
 
 class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.driver_executor = make_async(self.driver_worker.execute_method)
+
     async def _run_workers_async(
         self,
         method: str,
@@ -399,13 +396,8 @@ async def _run_workers_async(
         if driver_kwargs is None:
             driver_kwargs = kwargs
 
-        # Run the driver worker asynchronously.
-        def helper():
-            return self.driver_worker.execute_method(method, *driver_args,
-                                                      **driver_kwargs)
-
-        driver_executor = make_async(helper)
-        coros.append(driver_executor())
+        coros.append(
+            self.driver_executor(method, *driver_args, **driver_kwargs))
 
         # Run the ray workers asynchronously.
         for worker in self.workers:
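Aside: the `repeat`/`islice` pattern that replaces the old `all_args[1:]` slicing can be sketched in isolation. The worker names and argument values below are made up for illustration and are not vLLM code:

from itertools import islice, repeat

workers = ["w1", "w2", "w3"]                      # stand-ins for the ray workers
args = ("shared",)                                # homogeneous per-worker args
all_args = [("driver",), ("a",), ("b",), ("c",)]  # heterogeneous, driver entry first

# Homogeneous case: every worker gets the same tuple, repeated lazily.
homogeneous = repeat(args, len(workers))

# Heterogeneous case: skip the driver's entry (index 0) and hand out the rest.
heterogeneous = islice(all_args, 1, None)

print(list(zip(workers, homogeneous)))
# [('w1', ('shared',)), ('w2', ('shared',)), ('w3', ('shared',))]
print(list(zip(workers, heterogeneous)))
# [('w1', ('a',)), ('w2', ('b',)), ('w3', ('c',))]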
vllm/worker/worker_base.py (4 additions, 5 deletions)
@@ -108,7 +108,8 @@ def __init__(self,
         self.worker_class_name = worker_class_name
         self.worker = None
 
-    def update_environment_variables(self, envs: Dict[str, str]) -> None:
+    @staticmethod
+    def update_environment_variables(envs: Dict[str, str]) -> None:
         key = 'CUDA_VISIBLE_DEVICES'
         if key in envs and key in os.environ:
             # overwriting CUDA_VISIBLE_DEVICES is desired behavior
@@ -138,10 +139,8 @@ def init_worker(self, *args, **kwargs):
 
     def execute_method(self, method, *args, **kwargs):
         try:
-            if hasattr(self, method):
-                executor = getattr(self, method)
-            else:
-                executor = getattr(self.worker, method)
+            target = self if self.worker is None else self.worker
+            executor = getattr(target, method)
             return executor(*args, **kwargs)
         except Exception as e:
             # if the driver worker also execute methods,
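For context, the execute_method change above routes calls to the wrapper itself until the real worker has been created, and to the worker afterwards. A minimal sketch of that dispatch behavior with a toy class (ToyWrapper and ping are made up; this is not vLLM code):

class ToyWrapper:
    """Toy stand-in for the wrapper's dispatch logic."""

    def __init__(self):
        self.worker = None  # the real worker is not created yet

    def ping(self):
        return "handled by the wrapper"

    def init_worker(self):
        # Once the worker exists, later calls should go to it instead.
        class _Worker:

            def ping(self):
                return "handled by the worker"

        self.worker = _Worker()

    def execute_method(self, method, *args, **kwargs):
        target = self if self.worker is None else self.worker
        return getattr(target, method)(*args, **kwargs)


w = ToyWrapper()
print(w.execute_method("ping"))  # handled by the wrapper
w.init_worker()
print(w.execute_method("ping"))  # handled by the worker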