diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index 5ac62f02b99c..2b59fe0fa103 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
@@ -6,6 +6,7 @@
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         make_async)
+from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
 
@@ -23,30 +24,47 @@ def _init_executor(self) -> None:
         else:
             self._init_spec_worker()
 
-    def _init_non_spec_worker(self):
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
-
-        assert self.parallel_config.world_size == 1, (
-            "GPUExecutor only supports single GPU.")
-
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-        self.driver_worker = Worker(
+    def _get_worker_kwargs(
+            self,
+            local_rank: int = 0,
+            rank: int = 0,
+            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
+        """Return worker init args for a given rank."""
+        if distributed_init_method is None:
+            distributed_init_method = get_distributed_init_method(
+                get_ip(), get_open_port())
+        return dict(
             model_config=self.model_config,
             parallel_config=self.parallel_config,
             scheduler_config=self.scheduler_config,
             device_config=self.device_config,
             cache_config=self.cache_config,
             load_config=self.load_config,
-            local_rank=0,
-            rank=0,
+            local_rank=local_rank,
+            rank=rank,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
+            is_driver_worker=rank == 0,
+        )
+
+    def _create_worker(self,
+                       local_rank: int = 0,
+                       rank: int = 0,
+                       distributed_init_method: Optional[str] = None):
+        wrapper = WorkerWrapperBase(
+            worker_module_name="vllm.worker.worker",
+            worker_class_name="Worker",
         )
+        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
+                                                      distributed_init_method))
+        return wrapper.worker
+
+    def _init_non_spec_worker(self):
+        assert self.parallel_config.world_size == 1, (
+            "GPUExecutor only supports single GPU.")
+
+        self.driver_worker = self._create_worker()
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
@@ -57,41 +75,18 @@ def _init_spec_worker(self):
         from vllm.spec_decode.multi_step_worker import MultiStepWorker
         from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
-        from vllm.worker.worker import Worker
 
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-
-        target_worker = Worker(
-            model_config=self.model_config,
-            parallel_config=self.parallel_config,
-            scheduler_config=self.scheduler_config,
-            device_config=self.device_config,
-            cache_config=self.cache_config,
-            load_config=self.load_config,
-            local_rank=0,
-            rank=0,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
-        )
+        target_worker = self._create_worker()
 
-        draft_worker = MultiStepWorker(
+        draft_worker_kwargs = self._get_worker_kwargs()
+        # Override draft-model specific worker args.
+        draft_worker_kwargs.update(
             model_config=self.speculative_config.draft_model_config,
             parallel_config=self.speculative_config.draft_parallel_config,
-            scheduler_config=self.scheduler_config,
-            device_config=self.device_config,
-            cache_config=self.cache_config,
             # TODO allow draft-model specific load config.
-            load_config=self.load_config,
-            local_rank=0,
-            rank=0,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
+            #load_config=self.load_config,
         )
+        draft_worker = MultiStepWorker(**draft_worker_kwargs)
 
         spec_decode_worker = SpecDecodeWorker.from_workers(
             proposer_worker=draft_worker, scorer_worker=target_worker)
 
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index b6bcda4e6b18..496c41697a9c 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -153,29 +153,14 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         distributed_init_method = get_distributed_init_method(
             driver_ip, get_open_port())
 
-        def collect_arg_helper_func(**kwargs):
-            # avoid writing `{"name": value}` manually
-            return kwargs
-
         # Initialize the actual workers inside worker wrapper.
-        init_worker_all_kwargs = []
-        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
-            local_rank = node_workers[node_id].index(rank)
-            init_worker_all_kwargs.append(
-                collect_arg_helper_func(
-                    model_config=self.model_config,
-                    parallel_config=self.parallel_config,
-                    scheduler_config=self.scheduler_config,
-                    device_config=self.device_config,
-                    cache_config=self.cache_config,
-                    load_config=self.load_config,
-                    local_rank=local_rank,
-                    rank=rank,
-                    distributed_init_method=distributed_init_method,
-                    lora_config=self.lora_config,
-                    vision_language_config=self.vision_language_config,
-                    is_driver_worker=rank == 0,
-                ))
+        init_worker_all_kwargs = [
+            self._get_worker_kwargs(
+                local_rank=node_workers[node_id].index(rank),
+                rank=rank,
+                distributed_init_method=distributed_init_method,
+            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
+        ]
         self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
 
         self._run_workers("init_device")
@@ -200,8 +185,7 @@ def execute_model(self,
             use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
 
         # Only the driver worker returns the sampling results.
-        output = all_outputs[0]
-        return output
+        return all_outputs[0]
 
     def _run_workers(
         self,
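
Note on the lazy import: the comment removed from `_init_non_spec_worker` explained that `Worker` was imported inside the method so that torch.cuda/xformers are not imported before `CUDA_VISIBLE_DEVICES` is set. `WorkerWrapperBase` keeps that behavior because it receives the worker module and class as strings (`worker_module_name="vllm.worker.worker"`, `worker_class_name="Worker"`) and only performs the import when `init_worker` is called. The snippet below is a minimal sketch of that deferral pattern for illustration only; `LazyWorkerWrapper` is a made-up name and its body is an assumption, not the actual `vllm.worker.worker_base.WorkerWrapperBase` implementation.

import importlib
from typing import Any, Optional


class LazyWorkerWrapper:
    """Sketch: construct a worker lazily from its module and class name.

    The worker module is imported only when init_worker() runs, i.e. after
    the executor has set up process-level state such as CUDA_VISIBLE_DEVICES,
    so heavy imports are deferred until then.
    """

    def __init__(self, worker_module_name: str, worker_class_name: str) -> None:
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        self.worker: Optional[Any] = None

    def init_worker(self, **kwargs: Any) -> None:
        # Deferred import: the worker module (and whatever it imports) is
        # loaded only at this point, then instantiated with the given kwargs.
        module = importlib.import_module(self.worker_module_name)
        worker_cls = getattr(module, self.worker_class_name)
        self.worker = worker_cls(**kwargs)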