22import os
33import pickle
44from collections import defaultdict
5+ from itertools import islice , repeat
56from typing import TYPE_CHECKING , Any , Dict , List , Optional , Set , Tuple
67
78from vllm .engine .ray_utils import RayWorkerWrapper , ray
@@ -136,16 +137,14 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
136137 VLLM_INSTANCE_ID = get_vllm_instance_id ()
137138
138139 # Set environment variables for the driver and workers.
139- all_args_to_update_environment_variables = []
140- for (node_id , _ ) in worker_node_and_gpu_ids :
141- all_args_to_update_environment_variables .append ([{
142- "CUDA_VISIBLE_DEVICES" :
143- "," .join (map (str , node_gpus [node_id ])),
144- "VLLM_INSTANCE_ID" :
145- VLLM_INSTANCE_ID ,
146- "VLLM_TRACE_FUNCTION" :
147- os .getenv ("VLLM_TRACE_FUNCTION" , "0" ),
148- }])
140+ all_args_to_update_environment_variables = [({
141+ "CUDA_VISIBLE_DEVICES" :
142+ "," .join (map (str , node_gpus [node_id ])),
143+ "VLLM_INSTANCE_ID" :
144+ VLLM_INSTANCE_ID ,
145+ "VLLM_TRACE_FUNCTION" :
146+ os .getenv ("VLLM_TRACE_FUNCTION" , "0" ),
147+ }, ) for (node_id , _ ) in worker_node_and_gpu_ids ]
149148 self ._run_workers ("update_environment_variables" ,
150149 all_args = all_args_to_update_environment_variables )
151150
@@ -156,10 +155,9 @@ def collect_arg_helper_func(**kwargs):
156155 # avoid writing `{"name": value}` manually
157156 return kwargs
158157
159- init_worker_all_kwargs = []
160-
161158 # Initialize the actual workers inside worker wrapper.
162- for rank , (node_id , _ ) in enumerate (worker_node_and_gpu_ids , ):
159+ init_worker_all_kwargs = []
160+ for rank , (node_id , _ ) in enumerate (worker_node_and_gpu_ids ):
163161 local_rank = node_workers [node_id ].index (rank )
164162 init_worker_all_kwargs .append (
165163 collect_arg_helper_func (
@@ -265,40 +263,40 @@ def _run_workers(
265263 self ,
266264 method : str ,
267265 * args ,
268- driver_args : Optional [Tuple [Any ]] = None ,
266+ driver_args : Optional [Tuple [Any , ... ]] = None ,
269267 driver_kwargs : Optional [Dict [str , Any ]] = None ,
270- all_args : Optional [List [List [Any ]]] = None ,
268+ all_args : Optional [List [Tuple [Any , ... ]]] = None ,
271269 all_kwargs : Optional [List [Dict [str , Any ]]] = None ,
272270 use_dummy_driver : bool = False ,
273271 max_concurrent_workers : Optional [int ] = None ,
274272 use_ray_compiled_dag : bool = False ,
275273 ** kwargs ,
276274 ) -> Any :
277- """Runs the given method on all workers.
278- all_args and all_kwargs are used to pass heterogeneous arguments,
279- i.e. different arguments for each worker.
275+ """Runs the given method on all workers. Can be used in the following
276+ ways:
277+
278+ - args/kwargs: All workers share the same args/kwargs
279+ - args/kwargs and driver_args/driver_kwargs: Driver worker has
280+ different args
281+ - all_args/all_kwargs: args/kwargs for each worker are specified
282+ individually
280283 """
281- if driver_args is None :
282- driver_args = args
283- if driver_kwargs is None :
284- driver_kwargs = kwargs
285-
286- # for mypy type checking
287- assert driver_args is not None
288- assert driver_kwargs is not None
289- if all_args is None :
290- all_args = [driver_args ] + [args ] * len (self .workers )
291- if all_kwargs is None :
292- all_kwargs = [driver_kwargs ] + [kwargs ] * len (self .workers )
293-
294- # for mypy type checking
295- assert all_args is not None
296- assert all_kwargs is not None
297284
298285 if max_concurrent_workers :
299286 raise NotImplementedError (
300287 "max_concurrent_workers is not supported yet." )
301288
289+ if driver_args is None :
290+ driver_args = args if all_args is None else all_args [0 ]
291+ if driver_kwargs is None :
292+ driver_kwargs = kwargs if all_kwargs is None else all_kwargs [0 ]
293+
294+ count = len (self .workers )
295+ all_worker_args = repeat (args , count ) if all_args is None \
296+ else islice (all_args , 1 , None )
297+ all_worker_kwargs = repeat (kwargs , count ) if all_kwargs is None \
298+ else islice (all_kwargs , 1 , None )
299+
302300 if use_ray_compiled_dag :
303301 # Right now, compiled DAG can only accept a single
304302 # input. TODO(sang): Fix it.
@@ -310,22 +308,17 @@ def _run_workers(
310308 worker .execute_method .remote (method , * worker_args ,
311309 ** worker_kwargs )
312310 for (worker , worker_args , worker_kwargs
313- ) in zip (self .workers , all_args [ 1 :], all_kwargs [ 1 :] )
311+ ) in zip (self .workers , all_worker_args , all_worker_kwargs )
314312 ]
315313
316- if driver_args is None :
317- driver_args = args
318- if driver_kwargs is None :
319- driver_kwargs = kwargs
320-
321314 # Start the driver worker after all the ray workers.
322315 if not use_dummy_driver :
323316 driver_worker_output = self .driver_worker .execute_method (
324- method , * all_args [ 0 ] , ** all_kwargs [ 0 ] )
317+ method , * driver_args , ** driver_kwargs )
325318 else :
326319 driver_worker_output = ray .get (
327320 self .driver_dummy_worker .execute_method .remote (
328- method , * all_args [ 0 ] , ** all_kwargs [ 0 ] ))
321+ method , * driver_args , ** driver_kwargs ))
329322 # Get the results of the ray workers.
330323 if self .workers :
331324 if use_ray_compiled_dag :
@@ -383,6 +376,10 @@ def _check_if_any_actor_is_dead(self):
383376
384377class RayGPUExecutorAsync (RayGPUExecutor , ExecutorAsyncBase ):
385378
379+ def __init__ (self , * args , ** kwargs ):
380+ super ().__init__ (* args , ** kwargs )
381+ self .driver_executor = make_async (self .driver_worker .execute_method )
382+
386383 async def _run_workers_async (
387384 self ,
388385 method : str ,
@@ -399,13 +396,8 @@ async def _run_workers_async(
399396 if driver_kwargs is None :
400397 driver_kwargs = kwargs
401398
402- # Run the driver worker asynchronously.
403- def helper ():
404- return self .driver_worker .execute_method (method , * driver_args ,
405- ** driver_kwargs )
406-
407- driver_executor = make_async (helper )
408- coros .append (driver_executor ())
399+ coros .append (
400+ self .driver_executor (method , * driver_args , ** driver_kwargs ))
409401
410402 # Run the ray workers asynchronously.
411403 for worker in self .workers :
0 commit comments