diff --git a/tests/engine/test_multiproc_workers.py b/tests/engine/test_multiproc_workers.py new file mode 100644 index 000000000000..794b6306ff57 --- /dev/null +++ b/tests/engine/test_multiproc_workers.py @@ -0,0 +1,176 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from time import sleep +from typing import Any, List, Tuple + +import pytest + +from vllm.executor.multiproc_worker_utils import (LocalWorkerVllm, + ResultHandler, WorkerMonitor) + + +class DummyWorker: + """Dummy version of vllm.worker.worker.Worker""" + + def __init__(self, rank: int): + self.rank = rank + + def worker_method(self, worker_input: Any) -> Tuple[int, Any]: + sleep(0.05) + + if isinstance(worker_input, Exception): + # simulate error case + raise worker_input + + return self.rank, input + + +def _start_workers() -> Tuple[List[LocalWorkerVllm], WorkerMonitor]: + result_handler = ResultHandler() + workers = [ + LocalWorkerVllm(result_handler, partial(DummyWorker, rank=rank)) + for rank in range(8) + ] + + worker_monitor = WorkerMonitor(workers, result_handler) + assert not worker_monitor.is_alive() + + result_handler.start() + worker_monitor.start() + assert worker_monitor.is_alive() + + return workers, worker_monitor + + +def test_local_workers() -> None: + """Test workers with sync task submission""" + + workers, worker_monitor = _start_workers() + + def execute_workers(worker_input: str) -> None: + worker_outputs = [ + worker.execute_method("worker_method", worker_input) + for worker in workers + ] + + for rank, output in enumerate(worker_outputs): + assert output.get() == (rank, input) + + executor = ThreadPoolExecutor(max_workers=4) + + # Test concurrent submission from different threads + futures = [ + executor.submit(partial(execute_workers, f"thread {thread_num}")) + for thread_num in range(4) + ] + + for future in futures: + future.result() + + # Test error case + exception = ValueError("fake error") + result = workers[0].execute_method("worker_method", exception) + try: + result.get() + pytest.fail("task should have failed") + except Exception as e: + assert isinstance(e, ValueError) + assert str(e) == "fake error" + + # Test cleanup when a worker fails + assert worker_monitor.is_alive() + workers[3].process.kill() + + # Other workers should get shut down here + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.process.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = workers[0].execute_method("worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) + + +def test_local_workers_clean_shutdown() -> None: + """Test clean shutdown""" + + workers, worker_monitor = _start_workers() + + assert worker_monitor.is_alive() + assert all(worker.process.is_alive() for worker in workers) + + # Clean shutdown + worker_monitor.close() + + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.process.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = workers[0].execute_method("worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) + + +@pytest.mark.asyncio +async def test_local_workers_async() -> None: + """Test local workers 
with async task submission""" + + workers, worker_monitor = _start_workers() + + async def execute_workers(worker_input: str) -> None: + worker_coros = [ + worker.execute_method_async("worker_method", worker_input) + for worker in workers + ] + + results = await asyncio.gather(*worker_coros) + for rank, result in enumerate(results): + assert result == (rank, input) + + tasks = [ + asyncio.create_task(execute_workers(f"task {task_num}")) + for task_num in range(4) + ] + + for task in tasks: + await task + + # Test error case + exception = ValueError("fake error") + try: + _result = await workers[0].execute_method_async( + "worker_method", exception) + pytest.fail("task should have failed") + except Exception as e: + assert isinstance(e, ValueError) + assert str(e) == "fake error" + + # Test cleanup when a worker fails + assert worker_monitor.is_alive() + workers[3].process.kill() + + # Other workers should get shut down here + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.process.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = await workers[0].execute_method_async( + "worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 4f8295d25cf4..dde79b3f57c5 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -11,7 +11,7 @@ from vllm.sequence import (Logprob, SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceGroupOutput, SequenceOutput) -from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.utils import get_distributed_init_method from vllm.worker.cache_engine import CacheEngine from vllm.worker.worker import Worker @@ -112,8 +112,7 @@ def create_worker(cls: type, ) engine_config = engine_args.create_engine_config() - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) + distributed_init_method = get_distributed_init_method() worker = cls( model_config=engine_config.model_config, diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 1804cf78d800..aab993d7d94e 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -1,7 +1,7 @@ import torch from vllm.engine.arg_utils import EngineArgs -from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.utils import get_distributed_init_method from vllm.worker.worker import Worker @@ -15,8 +15,7 @@ def test_swap() -> None: engine_config.cache_config.num_cpu_blocks = 1000 # Create the worker. 
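
For context on the test changes above: `get_distributed_init_method` no longer takes an IP and port; it resolves and caches them itself (see the `vllm/utils.py` hunk later in this diff). A minimal sketch of what callers now see; the printed address is illustrative only:

```python
from vllm.utils import get_distributed_init_method

addr = get_distributed_init_method()
print(addr)  # e.g. "tcp://192.168.1.10:45321" -- the host IP and a free port
# The result is lru_cache'd, so every caller in the process gets the same URL.
assert addr == get_distributed_init_method()
```
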
- distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) + distributed_init_method = get_distributed_init_method() worker = Worker( model_config=engine_config.model_config, parallel_config=engine_config.parallel_config, diff --git a/vllm/__init__.py b/vllm/__init__.py index 5ca468022759..ca454efd44b2 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -3,8 +3,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine -from vllm.engine.ray_utils import initialize_ray_cluster from vllm.entrypoints.llm import LLM +from vllm.executor.ray_utils import initialize_ray_cluster from vllm.model_executor.models import ModelRegistry from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3a2f7db67935..4b007d71e9cf 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -10,7 +10,7 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.engine.ray_utils import initialize_ray_cluster, ray +from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 19e58fb1722c..3d5d9925daa8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -15,8 +15,8 @@ SequenceGroupOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.util import create_output_by_sequence_group -from vllm.engine.ray_utils import initialize_ray_cluster from vllm.executor.executor_base import ExecutorBase +from vllm.executor.ray_utils import initialize_ray_cluster from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput @@ -28,7 +28,7 @@ get_tokenizer_group) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter +from vllm.utils import Counter, enable_trace_function_call_for_thread logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -133,6 +133,8 @@ def __init__( self.decoding_config = decoding_config or DecodingConfig() self.log_stats = log_stats + enable_trace_function_call_for_thread() + if not self.model_config.skip_tokenizer_init: self.tokenizer: BaseTokenizerGroup self._init_tokenizer() @@ -287,6 +289,12 @@ def __reduce__(self): # the closure used to initialize Ray worker actors raise RuntimeError("LLMEngine should not be pickled!") + def __del__(self): + # Shutdown model executor when engine is garbage collected + # Use getattr since __init__ can fail before the field is set + if model_executor := getattr(self, "model_executor", None): + model_executor.shutdown() + def get_tokenizer(self) -> "PreTrainedTokenizer": return self.tokenizer.get_lora_tokenizer(None) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 8d6a1fff91fd..54f984f4bc92 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -8,8 +8,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.utils import 
(get_distributed_init_method, get_ip, get_open_port, - make_async) +from vllm.utils import get_distributed_init_method, make_async logger = init_logger(__name__) @@ -33,8 +32,7 @@ def _init_worker(self): assert self.parallel_config.world_size == 1, ( "CPUExecutor only supports single CPU socket currently.") - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) + distributed_init_method = get_distributed_init_method() self.driver_worker = CPUWorker( model_config=self.model_config, parallel_config=self.parallel_config, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 1839b5603ff3..1838c34be2fd 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -95,6 +95,13 @@ def check_health(self) -> None: exception.""" raise NotImplementedError + def shutdown(self) -> None: + """Shutdown the executor.""" + return + + def __del__(self): + self.shutdown() + class ExecutorAsyncBase(ExecutorBase): diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index d413a7d27ff3..bcceac19dde2 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,11 +1,10 @@ -from typing import Dict, List, Set, Tuple +from typing import Any, Dict, List, Set, Tuple from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) +from vllm.utils import get_distributed_init_method, make_async logger = init_logger(__name__) @@ -23,30 +22,34 @@ def _init_executor(self) -> None: else: self._init_spec_worker() - def _init_non_spec_worker(self): - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker - - assert self.parallel_config.world_size == 1, ( - "GPUExecutor only supports single GPU.") - - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - self.driver_worker = Worker( + def _get_worker_kwargs(self, + local_rank: int = 0, + rank: int = 0) -> Dict[str, Any]: + return dict( model_config=self.model_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, device_config=self.device_config, cache_config=self.cache_config, load_config=self.load_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, + local_rank=local_rank, + rank=rank, + distributed_init_method=get_distributed_init_method(), lora_config=self.lora_config, vision_language_config=self.vision_language_config, - is_driver_worker=True, + is_driver_worker=rank == 0, ) + + def _create_worker(self, local_rank: int = 0, rank: int = 0): + # Lazy import to avoid CUDA init issues + from vllm.worker.worker import Worker + return Worker(**self._get_worker_kwargs(local_rank, rank)) + + def _init_non_spec_worker(self): + assert self.parallel_config.world_size == 1, ( + "GPUExecutor only supports single GPU.") + + self.driver_worker = self._create_worker() self.driver_worker.init_device() self.driver_worker.load_model() @@ -57,41 +60,18 @@ def _init_spec_worker(self): from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker - from vllm.worker.worker import Worker - - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - 
- target_worker = Worker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - is_driver_worker=True, - ) - draft_worker = MultiStepWorker( - model_config=self.speculative_config.draft_model_config, - parallel_config=self.speculative_config.draft_parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - # TODO allow draft-model specific load config. - load_config=self.load_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - is_driver_worker=True, - ) + target_worker = self._create_worker() + + draft_worker_kwargs = self._get_worker_kwargs() + # Override draft-model specific worker args. + draft_worker_kwargs.update( + dict( + parallel_config=self.speculative_config.draft_parallel_config, + scheduler_config=self.scheduler_config, + # TODO allow draft-model specific load config. + )) + draft_worker = MultiStepWorker(**draft_worker_kwargs) spec_decode_worker = SpecDecodeWorker.from_workers( proposer_worker=draft_worker, scorer_worker=target_worker) diff --git a/vllm/executor/multi_gpu_executor.py b/vllm/executor/multi_gpu_executor.py new file mode 100644 index 000000000000..0a3110af1359 --- /dev/null +++ b/vllm/executor/multi_gpu_executor.py @@ -0,0 +1,120 @@ +from abc import abstractmethod +from typing import Any, Dict, Optional, Set, Tuple + +from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.gpu_executor import GPUExecutor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import SamplerOutput + +logger = init_logger(__name__) + + +class MultiGPUExecutor(GPUExecutor): + """Abstract superclass of multi-GPU executor implementations.""" + + def _init_device_and_model(self) -> None: + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self._run_workers("determine_num_available_blocks", ) + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers. + """ + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. 
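
The min-reduction in `determine_num_available_blocks` above is worth spelling out; a toy example with made-up per-worker block counts:

```python
# Each worker reports (num_gpu_blocks, num_cpu_blocks) from its own profiling
# run; the executor keeps the smallest of each so the chosen cache size fits
# on every worker. The numbers below are purely illustrative.
num_blocks = [(1200, 4096), (1150, 4096), (1180, 4000)]
num_gpu_blocks = min(b[0] for b in num_blocks)  # 1150
num_cpu_blocks = min(b[1] for b in num_blocks)  # 4000
print(num_gpu_blocks, num_cpu_blocks)
```
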
+ logger.info(f"# GPU blocks: {num_gpu_blocks}, " + f"# CPU blocks: {num_cpu_blocks}") + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def execute_model(self, *args, **kwargs) -> SamplerOutput: + all_outputs = self._run_workers("execute_model", + driver_args=args, + driver_kwargs=kwargs) + + # Only the driver worker returns the sampling results. + return all_outputs[0] + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> Set[int]: + return self._run_workers("list_loras") + + @abstractmethod + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + raise NotImplementedError + + +class MultiGPUExecutorAsync(MultiGPUExecutor, ExecutorAsyncBase): + + @abstractmethod + async def _run_workers_async( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + raise NotImplementedError + + async def execute_model_async(self, *args, **kwargs) -> SamplerOutput: + all_outputs = await self._run_workers_async("execute_model", + driver_args=args, + driver_kwargs=kwargs) + + # Only the driver worker returns the sampling results. + return all_outputs[0] diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py new file mode 100644 index 000000000000..dc0cb829c046 --- /dev/null +++ b/vllm/executor/multiproc_gpu_executor.py @@ -0,0 +1,127 @@ +import asyncio +import os +from functools import partial +from typing import Any, Dict, Optional, Tuple + +from vllm.executor.multi_gpu_executor import (MultiGPUExecutor, + MultiGPUExecutorAsync) +from vllm.executor.multiproc_worker_utils import (LocalWorkerVllm, + ResultHandler, WorkerMonitor) +from vllm.logger import init_logger +from vllm.utils import get_vllm_instance_id, make_async + +logger = init_logger(__name__) + + +class MultiProcGPUExecutor(MultiGPUExecutor): + """Python multiprocessing-based multi-GPU executor""" + + def _init_executor(self) -> None: + assert ( + not self.speculative_config + ), "Speculative decoding not yet supported for MultiProcGPU backend." + + # Create the parallel GPU workers. 
+ world_size = self.parallel_config.tensor_parallel_size + + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers + if "CUDA_VISIBLE_DEVICES" not in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = (",".join( + map(str, range(world_size)))) + + # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers + os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() + + from torch.cuda import device_count + assert world_size <= device_count(), ( + "please set tensor_parallel_size to less than max local gpu count") + + if world_size == 1: + self.workers = [] + else: + result_handler = ResultHandler() + self.workers = [ + LocalWorkerVllm( + result_handler, + partial(self._create_worker, rank=rank, local_rank=rank)) + for rank in range(1, world_size) + ] + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + self.driver_worker = self._create_worker() + self._init_device_and_model() + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the workers first. + worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # Start the driver worker after all the ray workers. + driver_worker_method = getattr(self.driver_worker, method) + driver_worker_output = driver_worker_method(*driver_args, + **driver_kwargs) + + # Get the results of the workers. + return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + if not self.worker_monitor.is_alive(): + raise RuntimeError("Worker processes are not running") + + +class MultiProcGPUExecutorAsync(MultiProcGPUExecutor, MultiGPUExecutorAsync): + + async def _run_workers_async( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + driver_executor = make_async(getattr(self.driver_worker, method)) + + # Run all the workers asynchronously. 
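
Roughly what the environment prepared by `_init_executor` above looks like on a bare two-GPU host before the workers are forked; the values are illustrative, and `get_vllm_instance_id` produces a random id when `VLLM_INSTANCE_ID` is unset:

```python
import os

from vllm.utils import get_vllm_instance_id

world_size = 2  # stand-in for parallel_config.tensor_parallel_size
if "CUDA_VISIBLE_DEVICES" not in os.environ:
    # Inherited by the forked workers, which then pick their device by rank.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(world_size)))
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()

print(os.environ["CUDA_VISIBLE_DEVICES"])  # "0,1"
print(os.environ["VLLM_INSTANCE_ID"])      # e.g. "vllm-instance-..."
```
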
+ coros = [driver_executor(*driver_args, **driver_kwargs)] + [ + worker.execute_method_async(method, *args, **kwargs) + for worker in self.workers + ] + + return await asyncio.gather(*coros) diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py new file mode 100644 index 000000000000..efff9d8e3ca9 --- /dev/null +++ b/vllm/executor/multiproc_worker_utils.py @@ -0,0 +1,264 @@ +import asyncio +import multiprocessing +import os +import sys +import threading +import traceback +import uuid +from dataclasses import dataclass +from multiprocessing import Queue +from multiprocessing.connection import wait +from multiprocessing.process import BaseProcess +from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO, + TypeVar, Union) + +from vllm.logger import init_logger +from vllm.utils import enable_trace_function_call_for_thread + +logger = init_logger(__name__) + +T = TypeVar('T') + +_TERMINATE = "TERMINATE" # sentinel + +# ANSI color codes +CYAN = '\033[1;36m' +RESET = '\033[0;0m' + +# Use dedicated multiprocess context for workers. +# Both spawn and fork work +mp_method = os.getenv("MULTIPROC_METHOD", "fork") +mp = multiprocessing.get_context(mp_method) + + +@dataclass +class Result(Generic[T]): + """Result of task dispatched to worker""" + + task_id: uuid.UUID + value: Optional[T] = None + exception: Optional[BaseException] = None + + +class ResultFuture(threading.Event, Generic[T]): + """Synchronous future for non-async case""" + + def __init__(self): + super().__init__() + self.result: Optional[Result[T]] = None + + def set_result(self, result: Result[T]): + self.result = result + self.set() + + def get(self) -> T: + self.wait() + assert self.result is not None + if self.result.exception is not None: + raise self.result.exception + return self.result.value # type: ignore + + +def _set_future_result(future: Union[ResultFuture, asyncio.Future], + result: Result): + if isinstance(future, ResultFuture): + future.set_result(result) + return + loop = future.get_loop() + if result.exception is not None: + loop.call_soon_threadsafe(future.set_exception, result.exception) + else: + loop.call_soon_threadsafe(future.set_result, result.value) + + +class ResultHandler(threading.Thread): + """Handle results from all workers (in background thread)""" + + def __init__(self) -> None: + super().__init__(daemon=True) + self.result_queue = mp.Queue() + self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} + + def run(self): + for result in iter(self.result_queue.get, _TERMINATE): + future = self.tasks.pop(result.task_id) + _set_future_result(future, result) + # Ensure that all waiters will receive an exception + for task_id, future in self.tasks.items(): + _set_future_result( + future, + Result(task_id=task_id, + exception=ChildProcessError("worker died"))) + + def close(self): + self.result_queue.put(_TERMINATE) + + +class WorkerMonitor(threading.Thread): + """Monitor worker status (in background thread)""" + + def __init__(self, workers: List['LocalWorkerVllm'], + result_handler: ResultHandler): + super().__init__(daemon=True) + self.workers = workers + self.result_handler = result_handler + self._close = False + + def run(self) -> None: + # Blocks until any worker exits + dead_sentinels = wait([w.process.sentinel for w in self.workers]) + if not self._close: + self._close = True + + # Kill / cleanup all workers + for worker in self.workers: + process = worker.process + if process.sentinel in dead_sentinels: + process.join(1) + if 
process.exitcode is not None and process.exitcode != 0: + logger.error( + f"Worker {process.name} pid {process.pid} died, " + f"exit code: {process.exitcode}") + # Cleanup any remaining workers + logger.info("Killing local vLLM worker processes") + for worker in self.workers: + worker.kill_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + for worker in self.workers: + worker.process.join(2) + + def close(self): + if self._close: + return + self._close = True + logger.info("Terminating local vLLM worker processes") + for worker in self.workers: + worker.terminate_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + +class LocalWorkerVllm(): + """Local process wrapper for vllm.worker.Worker + for handling single-node multi-GPU tensor parallel.""" + + def __init__(self, result_handler: ResultHandler, + worker_factory: Callable[[], Any]) -> None: + self._task_queue = mp.Queue() + self.result_queue = result_handler.result_queue + self.tasks = result_handler.tasks + self.process: BaseProcess = mp.Process( # type: ignore + target=_run_worker_process, + kwargs=dict( + worker_factory=worker_factory, + task_queue=self._task_queue, + result_queue=self.result_queue, + ), + daemon=True) + + self.process.start() + + def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], + method: str, args, kwargs): + task_id = uuid.uuid4() + self.tasks[task_id] = future + try: + self._task_queue.put((task_id, method, args, kwargs)) + except BaseException as e: + del self.tasks[task_id] + raise ChildProcessError("worker died") from e + + def execute_method(self, method: str, *args, **kwargs): + future: ResultFuture = ResultFuture() + self._enqueue_task(future, method, args, kwargs) + return future + + async def execute_method_async(self, method: str, *args, **kwargs): + future = asyncio.get_running_loop().create_future() + self._enqueue_task(future, method, args, kwargs) + return await future + + def terminate_worker(self): + try: + self._task_queue.put(_TERMINATE) + except ValueError: + self.process.kill() + self._task_queue.close() + + def kill_worker(self): + self._task_queue.close() + self.process.kill() + + +def _run_worker_process( + worker_factory: Callable[[], Any], + task_queue: Queue, + result_queue: Queue, +) -> None: + """Worker process event loop""" + + # Add process-specific prefix to stdout and stderr + process_name = mp.current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + enable_trace_function_call_for_thread() + + worker = worker_factory() + del worker_factory + + # Accept tasks from the engine in task_queue + # and return task output in result_queue + logger.info("Worker ready; awaiting tasks") + try: + for items in iter(task_queue.get, _TERMINATE): + output = None + exception = None + task_id, method, args, kwargs = items + try: + executor = getattr(worker, method) + output = executor(*args, **kwargs) + except BaseException as e: + tb = traceback.format_exc() + logger.error(f"Exception in worker {process_name} " + f"while processing method {method}: {e}, {tb}") + exception = e + result_queue.put( + Result(task_id=task_id, value=output, exception=exception)) + except KeyboardInterrupt: + pass + except Exception: + logger.exception("Worker failed") + + logger.info("Worker exiting") + + +def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: + """Prepend output with process-specific prefix""" + + 
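
A minimal async usage sketch of the classes defined above; `SquareWorker` is a toy stand-in for `vllm.worker.worker.Worker`. `execute_method_async` hands back an `asyncio.Future` that the `ResultHandler` thread completes via `call_soon_threadsafe`, so results can simply be awaited:

```python
import asyncio

from vllm.executor.multiproc_worker_utils import (LocalWorkerVllm,
                                                  ResultHandler,
                                                  WorkerMonitor)


class SquareWorker:
    """Toy stand-in for vllm.worker.worker.Worker."""

    def square(self, x: int) -> int:
        return x * x


async def main():
    result_handler = ResultHandler()
    worker = LocalWorkerVllm(result_handler, SquareWorker)
    monitor = WorkerMonitor([worker], result_handler)
    result_handler.start()
    monitor.start()

    # Futures are fulfilled from the ResultHandler thread, so they can be
    # gathered like any other awaitable.
    results = await asyncio.gather(
        *(worker.execute_method_async("square", i) for i in range(4)))
    print(results)  # [0, 1, 4, 9]

    monitor.close()
    monitor.join(2)


if __name__ == "__main__":
    asyncio.run(main())
```
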
prefix = f"{CYAN}({worker_name} pid={pid}){RESET} " + file_write = file.write + + def write_with_prefix(s: str): + if not s: + return + assert hasattr(file, "start_new_line") + if file.start_new_line: + file_write(prefix) + idx = 0 + while (next_idx := s.find('\n', idx)) != -1: + next_idx += 1 + file_write(s[idx:next_idx]) + if next_idx == len(s): + file.start_new_line = True + return + file_write(prefix) + idx = next_idx + file_write(s[idx:]) + file.start_new_line = False + + file.start_new_line = True # type: ignore + file.write = write_with_prefix # type: ignore diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index e69f104e7d5a..5515c3eaa71b 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -3,15 +3,14 @@ import pickle from collections import defaultdict from itertools import islice, repeat -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from vllm.engine.ray_utils import RayWorkerWrapper, ray -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.executor.multi_gpu_executor import (MultiGPUExecutor, + MultiGPUExecutorAsync) +from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - get_vllm_instance_id, make_async) +from vllm.utils import get_ip, get_vllm_instance_id, make_async if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -27,7 +26,7 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) -class RayGPUExecutor(ExecutorBase): +class RayGPUExecutor(MultiGPUExecutor): def _init_executor(self) -> None: assert (not self.speculative_config @@ -74,7 +73,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # The driver dummy worker does not actually use any resources. # It holds the resource for the driver worker. - self.driver_dummy_worker: RayWorkerWrapper = None + self.driver_dummy_worker: Optional[RayWorkerWrapper] = None # The remaining workers are the actual ray actors. self.workers: List[RayWorkerWrapper] = [] @@ -150,79 +149,16 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", self._run_workers("update_environment_variables", all_args=all_args_to_update_environment_variables) - distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port()) - - def collect_arg_helper_func(**kwargs): - # avoid writing `{"name": value}` manually - return kwargs - # Initialize the actual workers inside worker wrapper. 
- init_worker_all_kwargs = [] - for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids): - local_rank = node_workers[node_id].index(rank) - init_worker_all_kwargs.append( - collect_arg_helper_func( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - is_driver_worker=rank == 0, - )) + init_worker_all_kwargs = [ + self._get_worker_kwargs( + local_rank=node_workers[node_id].index(rank), + rank=rank, + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ] self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) - self._run_workers("init_device") - self._run_workers( - "load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, - ) - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks. - - This invokes `determine_num_available_blocks` on each worker and takes - the min of the results, guaranteeing that the selected cache sizes are - compatible with all workers. - - Returns: - - Tuple[num_gpu_blocks, num_cpu_blocks] - """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self._run_workers("determine_num_available_blocks", ) - - # Since we use a shared centralized controller, we take the minimum - # number of blocks across all workers to make sure all the memory - # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) - - return num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache in all workers. - """ - - # NOTE: We log here to avoid multiple logs when number of workers is - # greater than one. We could log in the engine, but not all executors - # have GPUs. - logger.info(f"# GPU blocks: {num_gpu_blocks}, " - f"# CPU blocks: {num_cpu_blocks}") - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self._run_workers("initialize_cache", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks) + self._init_device_and_model() def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -244,23 +180,6 @@ def execute_model(self, output = all_outputs[0] return output - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "add_lora", - lora_request=lora_request, - ) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "remove_lora", - lora_id=lora_id, - ) - - def list_loras(self) -> Set[int]: - return self._run_workers("list_loras") - def _run_workers( self, method: str, @@ -318,6 +237,7 @@ def _run_workers( driver_worker_output = self.driver_worker.execute_method( method, *driver_args, **driver_kwargs) else: + assert self.driver_dummy_worker is not None driver_worker_output = ray.get( self.driver_dummy_worker.execute_method.remote( method, *driver_args, **driver_kwargs)) @@ -353,8 +273,8 @@ def _compiled_ray_dag(self): # a dummy value for now. It will be fixed soon. 
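
The per-rank kwargs construction above is compact; a toy expansion of how `local_rank` is derived from `node_workers` (all values made up):

```python
# node_workers maps a node id to the global ranks placed on that node;
# local_rank is a rank's position within its own node's list.
node_workers = {"node-a": [0, 1], "node-b": [2, 3]}
worker_node_and_gpu_ids = [("node-a", [0]), ("node-a", [1]),
                           ("node-b", [0]), ("node-b", [1])]

per_rank = [
    dict(rank=rank, local_rank=node_workers[node_id].index(rank))
    for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
print(per_rank)
# [{'rank': 0, 'local_rank': 0}, {'rank': 1, 'local_rank': 1},
#  {'rank': 2, 'local_rank': 0}, {'rank': 3, 'local_rank': 1}]
```
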
with InputNode() as input_data: forward_dag = MultiOutputNode([ - worker.execute_model_compiled_dag_remote.bind(input_data) - for worker in self.workers + worker.execute_model_compiled_dag_remote.bind( # type: ignore + input_data) for worker in self.workers ]) return forward_dag.experimental_compile() @@ -376,7 +296,7 @@ def _check_if_any_actor_is_dead(self): f"Dead Workers: {dead_actors}. ") -class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): +class RayGPUExecutorAsync(RayGPUExecutor, MultiGPUExecutorAsync): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -407,23 +327,3 @@ async def _run_workers_async( all_outputs = await asyncio.gather(*coros) return all_outputs - - async def execute_model_async( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: - all_outputs = await self._run_workers_async( - "execute_model", - driver_kwargs={ - "seq_group_metadata_list": seq_group_metadata_list, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - }) - - # Only the driver worker returns the sampling results. - output = all_outputs[0] - return output diff --git a/vllm/engine/ray_utils.py b/vllm/executor/ray_utils.py similarity index 100% rename from vllm/engine/ray_utils.py rename to vllm/executor/ray_utils.py diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 69380d67f9b9..0195c40c27f6 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -1,7 +1,7 @@ from typing import Optional from vllm.config import TokenizerPoolConfig -from vllm.engine.ray_utils import ray +from vllm.executor.ray_utils import ray from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( BaseTokenizerGroup) from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index f3cdc00564db..7c605416854b 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -6,7 +6,7 @@ from transformers import PreTrainedTokenizer from vllm.config import TokenizerPoolConfig -from vllm.engine.ray_utils import ray +from vllm.executor.ray_utils import ray from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( BaseTokenizerGroup) diff --git a/vllm/utils.py b/vllm/utils.py index 15c8818cc450..2193356a293c 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,10 +1,13 @@ import asyncio +import datetime import enum import gc import glob import os import socket import subprocess +import tempfile +import threading import uuid import warnings from collections import defaultdict @@ -18,7 +21,7 @@ import torch from packaging.version import Version, parse -from vllm.logger import init_logger +from vllm.logger import enable_trace_function_call, init_logger T = TypeVar("T") logger = init_logger(__name__) @@ -232,6 +235,7 @@ async def consumer(): return consumer() +@lru_cache(maxsize=None) def get_ip() -> str: host_ip = os.environ.get("HOST_IP") if host_ip: @@ -264,7 +268,10 @@ def get_ip() -> str: return "0.0.0.0" -def get_distributed_init_method(ip: str, 
port: int) -> str: +@lru_cache(maxsize=None) +def get_distributed_init_method() -> str: + ip = get_ip() + port = get_open_port() # Brackets are not permitted in ipv4 addresses, # see https://github.com/python/cpython/issues/103848 return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" @@ -607,3 +614,15 @@ def find_nccl_library(): raise ValueError("NCCL only supports CUDA and ROCm backends.") logger.info(f"Found nccl from library {so_file}") return so_file + + +def enable_trace_function_call_for_thread(): + if int(os.getenv("VLLM_TRACE_FUNCTION", "0")): + tmp_dir = tempfile.gettempdir() + filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" + f"_thread_{threading.get_ident()}_" + f"at_{datetime.datetime.now()}.log").replace(" ", "_") + log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(), + filename) + os.makedirs(os.path.dirname(log_path), exist_ok=True) + enable_trace_function_call(log_path) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index b5dade0a770a..0a89e3a79769 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,15 +1,13 @@ -import datetime import importlib import os -import tempfile -import threading from abc import ABC, abstractmethod from typing import Dict, List, Set, Tuple -from vllm.logger import enable_trace_function_call, init_logger +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.utils import get_vllm_instance_id, update_environment_variables +from vllm.utils import (enable_trace_function_call_for_thread, + update_environment_variables) logger = init_logger(__name__) @@ -128,15 +126,7 @@ def init_worker(self, *args, **kwargs): function tracing if required. Arguments are passed to the worker class constructor. """ - if int(os.getenv("VLLM_TRACE_FUNCTION", "0")): - tmp_dir = tempfile.gettempdir() - filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" - f"_thread_{threading.get_ident()}_" - f"at_{datetime.datetime.now()}.log").replace(" ", "_") - log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(), - filename) - os.makedirs(os.path.dirname(log_path), exist_ok=True) - enable_trace_function_call(log_path) + enable_trace_function_call_for_thread() mod = importlib.import_module(self.worker_module_name) worker_class = getattr(mod, self.worker_class_name)
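
A sketch of how the relocated tracing helper behaves: it is a no-op unless `VLLM_TRACE_FUNCTION=1`; otherwise each thread that calls it writes its own trace log under `<tmpdir>/vllm/<instance id>/`. The environment values below are hypothetical stand-ins for what the user or executor would normally set:

```python
import os
import tempfile

from vllm.utils import enable_trace_function_call_for_thread

# Hypothetical values; normally set by the user or the executor.
os.environ["VLLM_TRACE_FUNCTION"] = "1"
os.environ["VLLM_INSTANCE_ID"] = "vllm-instance-demo"

enable_trace_function_call_for_thread()

log_dir = os.path.join(tempfile.gettempdir(), "vllm", "vllm-instance-demo")
print(sorted(os.listdir(log_dir)))
# ['VLLM_TRACE_FUNCTION_for_process_<pid>_thread_<tid>_at_<timestamp>.log']
```
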