diff --git a/tests/engine/test_multiproc_workers.py b/tests/engine/test_multiproc_workers.py
new file mode 100644
index 000000000000..610ad9732fb9
--- /dev/null
+++ b/tests/engine/test_multiproc_workers.py
@@ -0,0 +1,176 @@
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from time import sleep
+from typing import Any, List, Tuple
+
+import pytest
+
+from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
+                                                  ResultHandler, WorkerMonitor)
+
+
+class DummyWorker:
+    """Dummy version of vllm.worker.worker.Worker"""
+
+    def __init__(self, rank: int):
+        self.rank = rank
+
+    def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
+        sleep(0.05)
+
+        if isinstance(worker_input, Exception):
+            # simulate error case
+            raise worker_input
+
+        return self.rank, worker_input
+
+
+def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
+    result_handler = ResultHandler()
+    workers = [
+        ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank))
+        for rank in range(8)
+    ]
+
+    worker_monitor = WorkerMonitor(workers, result_handler)
+    assert not worker_monitor.is_alive()
+
+    result_handler.start()
+    worker_monitor.start()
+    assert worker_monitor.is_alive()
+
+    return workers, worker_monitor
+
+
+def test_local_workers() -> None:
+    """Test workers with sync task submission"""
+
+    workers, worker_monitor = _start_workers()
+
+    def execute_workers(worker_input: str) -> None:
+        worker_outputs = [
+            worker.execute_method("worker_method", worker_input)
+            for worker in workers
+        ]
+
+        for rank, output in enumerate(worker_outputs):
+            assert output.get() == (rank, worker_input)
+
+    executor = ThreadPoolExecutor(max_workers=4)
+
+    # Test concurrent submission from different threads
+    futures = [
+        executor.submit(partial(execute_workers, f"thread {thread_num}"))
+        for thread_num in range(4)
+    ]
+
+    for future in futures:
+        future.result()
+
+    # Test error case
+    exception = ValueError("fake error")
+    result = workers[0].execute_method("worker_method", exception)
+    try:
+        result.get()
+        pytest.fail("task should have failed")
+    except Exception as e:
+        assert isinstance(e, ValueError)
+        assert str(e) == "fake error"
+
+    # Test cleanup when a worker fails
+    assert worker_monitor.is_alive()
+    workers[3].process.kill()
+
+    # Other workers should get shut down here
+    worker_monitor.join(2)
+
+    # Ensure everything is stopped
+    assert not worker_monitor.is_alive()
+    assert all(not worker.process.is_alive() for worker in workers)
+
+    # Further attempts to submit tasks should fail
+    try:
+        _result = workers[0].execute_method("worker_method", "test")
+        pytest.fail("task should fail once workers have been shut down")
+    except Exception as e:
+        assert isinstance(e, ChildProcessError)
+
+
+def test_local_workers_clean_shutdown() -> None:
+    """Test clean shutdown"""
+
+    workers, worker_monitor = _start_workers()
+
+    assert worker_monitor.is_alive()
+    assert all(worker.process.is_alive() for worker in workers)
+
+    # Clean shutdown
+    worker_monitor.close()
+
+    worker_monitor.join(5)
+
+    # Ensure everything is stopped
+    assert not worker_monitor.is_alive()
+    assert all(not worker.process.is_alive() for worker in workers)
+
+    # Further attempts to submit tasks should fail
+    try:
+        _result = workers[0].execute_method("worker_method", "test")
+        pytest.fail("task should fail once workers have been shut down")
+    except Exception as e:
+        assert isinstance(e, ChildProcessError)
+
+
+@pytest.mark.asyncio
+async def test_local_workers_async() -> None:
+    """Test local workers with async task submission"""
+
+    workers, worker_monitor = _start_workers()
+
+    async def execute_workers(worker_input: str) -> None:
+        worker_coros = [
+            worker.execute_method_async("worker_method", worker_input)
+            for worker in workers
+        ]
+
+        results = await asyncio.gather(*worker_coros)
+        for rank, result in enumerate(results):
+            assert result == (rank, worker_input)
+
+    tasks = [
+        asyncio.create_task(execute_workers(f"task {task_num}"))
+        for task_num in range(4)
+    ]
+
+    for task in tasks:
+        await task
+
+    # Test error case
+    exception = ValueError("fake error")
+    try:
+        _result = await workers[0].execute_method_async(
+            "worker_method", exception)
+        pytest.fail("task should have failed")
+    except Exception as e:
+        assert isinstance(e, ValueError)
+        assert str(e) == "fake error"
+
+    # Test cleanup when a worker fails
+    assert worker_monitor.is_alive()
+    workers[3].process.kill()
+
+    # Other workers should get shut down here
+    worker_monitor.join(2)
+
+    # Ensure everything is stopped
+    assert not worker_monitor.is_alive()
+    assert all(not worker.process.is_alive() for worker in workers)
+
+    # Further attempts to submit tasks should fail
+    try:
+        _result = await workers[0].execute_method_async(
+            "worker_method", "test")
+        pytest.fail("task should fail once workers have been shut down")
+    except Exception as e:
+        assert isinstance(e, ChildProcessError)
diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py
new file mode 100644
index 000000000000..0c04796bc38e
--- /dev/null
+++ b/vllm/executor/multiproc_worker_utils.py
@@ -0,0 +1,264 @@
+import asyncio
+import multiprocessing
+import os
+import sys
+import threading
+import traceback
+import uuid
+from dataclasses import dataclass
+from multiprocessing import Queue
+from multiprocessing.connection import wait
+from multiprocessing.process import BaseProcess
+from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
+                    TypeVar, Union)
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+T = TypeVar('T')
+
+_TERMINATE = "TERMINATE"  # sentinel
+
+# ANSI color codes
+CYAN = '\033[1;36m'
+RESET = '\033[0;0m'
+
+JOIN_TIMEOUT_S = 2
+
+# Use dedicated multiprocess context for workers.
+# Both spawn and fork work
+mp_method = os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+mp = multiprocessing.get_context(mp_method)
+
+
+@dataclass
+class Result(Generic[T]):
+    """Result of task dispatched to worker"""
+
+    task_id: uuid.UUID
+    value: Optional[T] = None
+    exception: Optional[BaseException] = None
+
+
+class ResultFuture(threading.Event, Generic[T]):
+    """Synchronous future for non-async case"""
+
+    def __init__(self):
+        super().__init__()
+        self.result: Optional[Result[T]] = None
+
+    def set_result(self, result: Result[T]):
+        self.result = result
+        self.set()
+
+    def get(self) -> T:
+        self.wait()
+        assert self.result is not None
+        if self.result.exception is not None:
+            raise self.result.exception
+        return self.result.value  # type: ignore[return-value]
+
+
+def _set_future_result(future: Union[ResultFuture, asyncio.Future],
+                       result: Result):
+    if isinstance(future, ResultFuture):
+        future.set_result(result)
+        return
+    loop = future.get_loop()
+    if result.exception is not None:
+        loop.call_soon_threadsafe(future.set_exception, result.exception)
+    else:
+        loop.call_soon_threadsafe(future.set_result, result.value)
+
+
+class ResultHandler(threading.Thread):
+    """Handle results from all workers (in background thread)"""
+
+    def __init__(self) -> None:
+        super().__init__(daemon=True)
+        self.result_queue = mp.Queue()
+        self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
+
+    def run(self):
+        for result in iter(self.result_queue.get, _TERMINATE):
+            future = self.tasks.pop(result.task_id)
+            _set_future_result(future, result)
+        # Ensure that all waiters will receive an exception
+        for task_id, future in self.tasks.items():
+            _set_future_result(
+                future,
+                Result(task_id=task_id,
+                       exception=ChildProcessError("worker died")))
+
+    def close(self):
+        self.result_queue.put(_TERMINATE)
+
+
+class WorkerMonitor(threading.Thread):
+    """Monitor worker status (in background thread)"""
+
+    def __init__(self, workers: List['ProcessWorkerWrapper'],
+                 result_handler: ResultHandler):
+        super().__init__(daemon=True)
+        self.workers = workers
+        self.result_handler = result_handler
+        self._close = False
+
+    def run(self) -> None:
+        # Blocks until any worker exits
+        dead_sentinels = wait([w.process.sentinel for w in self.workers])
+        if not self._close:
+            self._close = True
+
+            # Kill / cleanup all workers
+            for worker in self.workers:
+                process = worker.process
+                if process.sentinel in dead_sentinels:
+                    process.join(JOIN_TIMEOUT_S)
+                if process.exitcode is not None and process.exitcode != 0:
+                    logger.error("Worker %s pid %s died, exit code: %s",
+                                 process.name, process.pid, process.exitcode)
+            # Cleanup any remaining workers
+            logger.info("Killing local vLLM worker processes")
+            for worker in self.workers:
+                worker.kill_worker()
+            # Must be done after worker task queues are all closed
+            self.result_handler.close()
+
+        for worker in self.workers:
+            worker.process.join(JOIN_TIMEOUT_S)
+
+    def close(self):
+        if self._close:
+            return
+        self._close = True
+        logger.info("Terminating local vLLM worker processes")
+        for worker in self.workers:
+            worker.terminate_worker()
+        # Must be done after worker task queues are all closed
+        self.result_handler.close()
+
+
+class ProcessWorkerWrapper:
+    """Local process wrapper for vllm.worker.Worker,
+    for handling single-node multi-GPU tensor parallel."""
+
+    def __init__(self, result_handler: ResultHandler,
+                 worker_factory: Callable[[], Any]) -> None:
+        self._task_queue = mp.Queue()
+        self.result_queue = result_handler.result_queue
+        self.tasks = result_handler.tasks
+        self.process: BaseProcess = mp.Process(  # type: ignore[attr-defined]
+            target=_run_worker_process,
+            name="VllmWorkerProcess",
+            kwargs=dict(
+                worker_factory=worker_factory,
+                task_queue=self._task_queue,
+                result_queue=self.result_queue,
+            ),
+            daemon=True)
+
+        self.process.start()
+
+    def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
+                      method: str, args, kwargs):
+        task_id = uuid.uuid4()
+        self.tasks[task_id] = future
+        try:
+            self._task_queue.put((task_id, method, args, kwargs))
+        except BaseException as e:
+            del self.tasks[task_id]
+            raise ChildProcessError("worker died") from e
+
+    def execute_method(self, method: str, *args, **kwargs):
+        future: ResultFuture = ResultFuture()
+        self._enqueue_task(future, method, args, kwargs)
+        return future
+
+    async def execute_method_async(self, method: str, *args, **kwargs):
+        future = asyncio.get_running_loop().create_future()
+        self._enqueue_task(future, method, args, kwargs)
+        return await future
+
+    def terminate_worker(self):
+        try:
+            self._task_queue.put(_TERMINATE)
+        except ValueError:
+            self.process.kill()
+        self._task_queue.close()
+
+    def kill_worker(self):
+        self._task_queue.close()
+        self.process.kill()
+
+
+def _run_worker_process(
+    worker_factory: Callable[[], Any],
+    task_queue: Queue,
+    result_queue: Queue,
+) -> None:
+    """Worker process event loop"""
+
+    # Add process-specific prefix to stdout and stderr
+    process_name = mp.current_process().name
+    pid = os.getpid()
+    _add_prefix(sys.stdout, process_name, pid)
+    _add_prefix(sys.stderr, process_name, pid)
+
+    # Initialize worker
+    worker = worker_factory()
+    del worker_factory
+
+    # Accept tasks from the engine in task_queue
+    # and return task output in result_queue
+    logger.info("Worker ready; awaiting tasks")
+    try:
+        for items in iter(task_queue.get, _TERMINATE):
+            output = None
+            exception = None
+            task_id, method, args, kwargs = items
+            try:
+                executor = getattr(worker, method)
+                output = executor(*args, **kwargs)
+            except BaseException as e:
+                tb = traceback.format_exc()
+                logger.error(
+                    "Exception in worker %s while processing method %s: %s, %s",
+                    process_name, method, e, tb)
+                exception = e
+            result_queue.put(
+                Result(task_id=task_id, value=output, exception=exception))
+    except KeyboardInterrupt:
+        pass
+    except Exception:
+        logger.exception("Worker failed")
+
+    logger.info("Worker exiting")
+
+
+def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
+    """Prepend each output line with process-specific prefix"""
+
+    prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
+    file_write = file.write
+
+    def write_with_prefix(s: str):
+        if not s:
+            return
+        if file.start_new_line:  # type: ignore[attr-defined]
+            file_write(prefix)
+        idx = 0
+        while (next_idx := s.find('\n', idx)) != -1:
+            next_idx += 1
+            file_write(s[idx:next_idx])
+            if next_idx == len(s):
+                file.start_new_line = True  # type: ignore[attr-defined]
+                return
+            file_write(prefix)
+            idx = next_idx
+        file_write(s[idx:])
+        file.start_new_line = False  # type: ignore[attr-defined]
+
+    file.start_new_line = True  # type: ignore[attr-defined]
+    file.write = write_with_prefix  # type: ignore[method-assign]
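
The test file's _start_workers helper doubles as the usage reference for these classes: the ResultHandler must be created before any ProcessWorkerWrapper (each wrapper registers its pending futures in result_handler.tasks and publishes into result_handler.result_queue), the handler and a WorkerMonitor are then started as background threads, and tasks are dispatched with execute_method or execute_method_async. The minimal sketch below shows that same wiring outside the test suite; MyWorker, its ping method, and main are hypothetical stand-ins for vllm.worker.worker.Worker, and the __main__ guard is needed because the worker context defaults to the spawn start method (overridable via the VLLM_WORKER_MULTIPROC_METHOD environment variable).

from functools import partial

from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                  ResultHandler, WorkerMonitor)


class MyWorker:
    """Hypothetical stand-in for vllm.worker.worker.Worker."""

    def __init__(self, rank: int):
        self.rank = rank

    def ping(self) -> int:
        return self.rank


def main() -> None:
    # The shared result handler must exist before the workers, since each
    # ProcessWorkerWrapper registers its futures and result queue with it.
    result_handler = ResultHandler()
    workers = [
        ProcessWorkerWrapper(result_handler, partial(MyWorker, rank=rank))
        for rank in range(4)
    ]

    # The monitor tears everything down if any worker process dies.
    worker_monitor = WorkerMonitor(workers, result_handler)
    result_handler.start()
    worker_monitor.start()

    # Synchronous dispatch returns a ResultFuture; get() blocks until the
    # worker has replied (or re-raises the worker-side exception).
    futures = [worker.execute_method("ping") for worker in workers]
    print([future.get() for future in futures])  # -> [0, 1, 2, 3]

    # Clean shutdown: terminate the workers, then stop the result handler.
    worker_monitor.close()
    worker_monitor.join(5)


if __name__ == "__main__":
    main()

The hard-failure path needs no extra wiring: if any worker process dies, WorkerMonitor.run() kills the remaining workers and closes the ResultHandler, so outstanding and subsequent calls fail with ChildProcessError instead of hanging, exactly as exercised in test_local_workers and test_local_workers_async.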