diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index f779b0f8a511..d0b5e682bb6f 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -2,6 +2,7 @@ import os import pickle from collections import defaultdict +from itertools import islice, repeat from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple from vllm.engine.ray_utils import RayWorkerWrapper, ray @@ -136,16 +137,14 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", VLLM_INSTANCE_ID = get_vllm_instance_id() # Set environment variables for the driver and workers. - all_args_to_update_environment_variables = [] - for (node_id, _) in worker_node_and_gpu_ids: - all_args_to_update_environment_variables.append([{ - "CUDA_VISIBLE_DEVICES": - ",".join(map(str, node_gpus[node_id])), - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, - "VLLM_TRACE_FUNCTION": - os.getenv("VLLM_TRACE_FUNCTION", "0"), - }]) + all_args_to_update_environment_variables = [({ + "CUDA_VISIBLE_DEVICES": + ",".join(map(str, node_gpus[node_id])), + "VLLM_INSTANCE_ID": + VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": + os.getenv("VLLM_TRACE_FUNCTION", "0"), + }, ) for (node_id, _) in worker_node_and_gpu_ids] self._run_workers("update_environment_variables", all_args=all_args_to_update_environment_variables) @@ -156,10 +155,9 @@ def collect_arg_helper_func(**kwargs): # avoid writing `{"name": value}` manually return kwargs - init_worker_all_kwargs = [] - # Initialize the actual workers inside worker wrapper. - for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ): + init_worker_all_kwargs = [] + for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids): local_rank = node_workers[node_id].index(rank) init_worker_all_kwargs.append( collect_arg_helper_func( @@ -265,40 +263,40 @@ def _run_workers( self, method: str, *args, - driver_args: Optional[Tuple[Any]] = None, + driver_args: Optional[Tuple[Any, ...]] = None, driver_kwargs: Optional[Dict[str, Any]] = None, - all_args: Optional[List[List[Any]]] = None, + all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: - """Runs the given method on all workers. - all_args and all_kwargs are used to pass heterogeneous arguments, - i.e. different arguments for each worker. + """Runs the given method on all workers. Can be used in the following + ways: + + - args/kwargs: All workers share the same args/kwargs + - args/kwargs and driver_args/driver_kwargs: Driver worker has + different args + - all_args/all_kwargs: args/kwargs for each worker are specified + individually """ - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs - - # for mypy type checking - assert driver_args is not None - assert driver_kwargs is not None - if all_args is None: - all_args = [driver_args] + [args] * len(self.workers) - if all_kwargs is None: - all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers) - - # for mypy type checking - assert all_args is not None - assert all_kwargs is not None if max_concurrent_workers: raise NotImplementedError( "max_concurrent_workers is not supported yet.") + if driver_args is None: + driver_args = args if all_args is None else all_args[0] + if driver_kwargs is None: + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + + count = len(self.workers) + all_worker_args = repeat(args, count) if all_args is None \ + else islice(all_args, 1, None) + all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ + else islice(all_kwargs, 1, None) + if use_ray_compiled_dag: # Right now, compiled DAG can only accept a single # input. TODO(sang): Fix it. @@ -310,22 +308,17 @@ def _run_workers( worker.execute_method.remote(method, *worker_args, **worker_kwargs) for (worker, worker_args, worker_kwargs - ) in zip(self.workers, all_args[1:], all_kwargs[1:]) + ) in zip(self.workers, all_worker_args, all_worker_kwargs) ] - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs - # Start the driver worker after all the ray workers. if not use_dummy_driver: driver_worker_output = self.driver_worker.execute_method( - method, *all_args[0], **all_kwargs[0]) + method, *driver_args, **driver_kwargs) else: driver_worker_output = ray.get( self.driver_dummy_worker.execute_method.remote( - method, *all_args[0], **all_kwargs[0])) + method, *driver_args, **driver_kwargs)) # Get the results of the ray workers. if self.workers: if use_ray_compiled_dag: @@ -383,6 +376,10 @@ def _check_if_any_actor_is_dead(self): class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_executor = make_async(self.driver_worker.execute_method) + async def _run_workers_async( self, method: str, @@ -399,13 +396,8 @@ async def _run_workers_async( if driver_kwargs is None: driver_kwargs = kwargs - # Run the driver worker asynchronously. - def helper(): - return self.driver_worker.execute_method(method, *driver_args, - **driver_kwargs) - - driver_executor = make_async(helper) - coros.append(driver_executor()) + coros.append( + self.driver_executor(method, *driver_args, **driver_kwargs)) # Run the ray workers asynchronously. for worker in self.workers: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 783dff3a4340..bcd04e0f98db 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -108,7 +108,8 @@ def __init__(self, self.worker_class_name = worker_class_name self.worker = None - def update_environment_variables(self, envs: Dict[str, str]) -> None: + @staticmethod + def update_environment_variables(envs: Dict[str, str]) -> None: key = 'CUDA_VISIBLE_DEVICES' if key in envs and key in os.environ: # overwriting CUDA_VISIBLE_DEVICES is desired behavior @@ -138,10 +139,8 @@ def init_worker(self, *args, **kwargs): def execute_method(self, method, *args, **kwargs): try: - if hasattr(self, method): - executor = getattr(self, method) - else: - executor = getattr(self.worker, method) + target = self if self.worker is None else self.worker + executor = getattr(target, method) return executor(*args, **kwargs) except Exception as e: # if the driver worker also execute methods,