From 937753b40832eaa722fa263a179854652a20ccdb Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 16 Feb 2024 15:55:52 -0800 Subject: [PATCH 1/8] Make ray optional for single-node deployment ray is a powerful platform for general purpose distributed computing but potentially overkill for the specific requirements of realtime synchronized inferencing between GPUs on a single node. We would prefer to have a "lightweight" option without the ray dependency for non-ray cluster environments. This also helps with production security compliance. With the changes in this PR, ray will continue to be used for parallel workers if it's installed, otherwise vanilla python multiprocessing is used. It can also be overridden with --no-worker-use-ray. Worker processes are shut down when the LLMEngine is garbage collected. Co-authored-by: Sahil Suneja --- Dockerfile | 2 +- requirements-cuda.txt | 1 - requirements-rocm.txt | 4 +- setup.py | 20 ++- tests/engine/test_local_worker.py | 68 ++++++++ vllm/config.py | 19 ++- vllm/engine/arg_utils.py | 12 +- vllm/engine/async_llm_engine.py | 6 +- vllm/engine/llm_engine.py | 11 +- vllm/engine/local_worker_utils.py | 204 ++++++++++++++++++++++++ vllm/engine/ray_utils.py | 7 +- vllm/executor/executor_base.py | 7 + vllm/executor/multi_gpu_executor.py | 140 ++++++++++++++++ vllm/executor/multiproc_gpu_executor.py | 138 ++++++++++++++++ vllm/executor/ray_gpu_executor.py | 134 +++------------- vllm/logger.py | 13 ++ vllm/utils.py | 5 +- 17 files changed, 645 insertions(+), 146 deletions(-) create mode 100644 tests/engine/test_local_worker.py create mode 100644 vllm/engine/local_worker_utils.py create mode 100644 vllm/executor/multi_gpu_executor.py create mode 100644 vllm/executor/multiproc_gpu_executor.py diff --git a/Dockerfile b/Dockerfile index d1d29177b0f4..eaaaa44f83d5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -113,7 +113,7 @@ RUN ldconfig /usr/local/cuda-12.1/compat/ # install vllm wheel first, so that torch etc will be installed RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ - pip install dist/*.whl --verbose + pip install "$(echo dist/*.whl)[ray]" --verbose RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ --mount=type=cache,target=/root/.cache/pip \ diff --git a/requirements-cuda.txt b/requirements-cuda.txt index c6d2cd46aee5..e2115b586015 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -2,7 +2,6 @@ -r requirements-common.txt # Dependencies for NVIDIA GPUs -ray >= 2.9 pynvml == 11.5.0 vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.2.1 diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 903845b64d98..0c944b138f6d 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -1,5 +1,5 @@ # Common dependencies -r requirements-common.txt -# Dependencies for AMD GPUs -ray == 2.9.3 +# No specific dependencies currently for AMD GPUs + diff --git a/setup.py b/setup.py index 19a9150ad2e6..0320e97fc7c4 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ import subprocess import sys from shutil import which -from typing import Dict, List +from typing import Dict, List, Optional import torch from packaging.version import Version, parse @@ -361,6 +361,20 @@ def _read_requirements(filename: str) -> List[str]: return requirements +def get_extra_requirements() -> Optional[Dict[str, List[str]]]: + extras = {"tensorizer": ["tensorizer==2.9.0a1"]} + if _is_cuda(): + extras["ray"] = ["ray>=2.9"] + elif _is_hip(): + extras["ray"] = ["ray==2.9.3"] + elif _is_neuron() or _is_cpu(): + pass + else: + raise ValueError( + "Unsupported platform, please use CUDA, ROCM or Neuron.") + return extras + + ext_modules = [] if _is_cuda(): @@ -405,9 +419,7 @@ def _read_requirements(filename: str) -> List[str]: python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, - extras_require={ - "tensorizer": ["tensorizer==2.9.0a1"], - }, + extras_require=get_extra_requirements(), cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, ) diff --git a/tests/engine/test_local_worker.py b/tests/engine/test_local_worker.py new file mode 100644 index 000000000000..ace0c245a9e8 --- /dev/null +++ b/tests/engine/test_local_worker.py @@ -0,0 +1,68 @@ +import multiprocessing as mp + +import pytest +import torch + +from vllm import LLM, SamplingParams + +TENSOR_PARALLEL_SIZE = 2 +MAX_GENERATION_TOKENS = 256 + + +def llm_generate(result_queue, prompt_token_ids, worker_use_ray=False): + try: + llm = LLM(model="facebook/opt-350m", + tensor_parallel_size=TENSOR_PARALLEL_SIZE, + worker_use_ray=worker_use_ray) + + output = llm.generate( + prompt_token_ids=prompt_token_ids, + sampling_params=SamplingParams(max_tokens=MAX_GENERATION_TOKENS)) + except BaseException as e: + output = e + + result_queue.put(output) + + +def run_llm(prompt_token_ids, worker_use_ray=False): + result_queue = mp.Queue() + proc = mp.Process(target=llm_generate, + args=(result_queue, prompt_token_ids, worker_use_ray)) + proc.start() + result = result_queue.get() + proc.join() + if isinstance(result, BaseException): + raise result + return result + + +def get_prompts(): + # https://github.com/vllm-project/vllm/issues/367#issuecomment-1629872996 + batch_size = 32 + dim = 120 + max_token_id = 32000 + torch.manual_seed(42) + batch = torch.randint(max_token_id, (batch_size, dim)) + prompt_token_ids = [tokens.tolist() for tokens in batch] + return prompt_token_ids + + +@pytest.mark.skip("Requires multiple GPUs") +def test_local_worker(): + # Similar to tests/lora/test_llama.py + # Cannot use as it will initialize torch.cuda too early... + # if torch.cuda.device_count() < 2: + # pytest.skip(f"Not enough GPUs for tensor parallelism {2}") + + prompt_token_ids = get_prompts() + output1 = run_llm(prompt_token_ids, worker_use_ray=False) + output2 = run_llm(prompt_token_ids, worker_use_ray=True) + assert len(output1) == len(output2) + + completion_token_ids1 = [item.outputs[0].token_ids for item in output1] + completion_token_ids2 = [item.outputs[0].token_ids for item in output2] + assert completion_token_ids1 == completion_token_ids2 + + +if __name__ == "__main__": + test_local_worker() diff --git a/vllm/config.py b/vllm/config.py index dce2944b2ee8..a2e9dfa8ae75 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,4 +1,5 @@ import enum +import importlib.util import io import json import os @@ -422,7 +423,7 @@ def verify_with_parallel_config( @dataclass class TokenizerPoolConfig: """Configuration for the tokenizer pool. - + Args: pool_size: Number of tokenizer workers in the pool. pool_type: Type of the pool. @@ -446,9 +447,9 @@ def create_config( tokenizer_pool_extra_config: Optional[Union[str, dict]] ) -> Optional["TokenizerPoolConfig"]: """Create a TokenizerPoolConfig from the given parameters. - + If tokenizer_pool_size is 0, return None. - + Args: tokenizer_pool_size: Number of tokenizer workers in the pool. tokenizer_pool_type: Type of the pool. @@ -477,9 +478,9 @@ class ParallelConfig: Args: pipeline_parallel_size: Number of pipeline parallel groups. tensor_parallel_size: Number of tensor parallel groups. - worker_use_ray: Whether to use Ray for model workers. Will be set to + worker_use_ray: Whether to use Ray for model workers. Will default to True if either pipeline_parallel_size or tensor_parallel_size is - greater than 1. + greater than 1 and Ray is installed. max_parallel_loading_workers: Maximum number of multiple batches when load model sequentially. To avoid RAM OOM when using tensor parallel and large models. @@ -495,7 +496,7 @@ def __init__( self, pipeline_parallel_size: int, tensor_parallel_size: int, - worker_use_ray: bool, + worker_use_ray: Optional[bool] = None, max_parallel_loading_workers: Optional[int] = None, disable_custom_all_reduce: bool = False, tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, @@ -512,8 +513,10 @@ def __init__( self.placement_group = placement_group self.world_size = pipeline_parallel_size * self.tensor_parallel_size - if self.world_size > 1: - self.worker_use_ray = True + if self.worker_use_ray is None: + ray_found = importlib.util.find_spec("ray") is not None + self.worker_use_ray = ray_found and self.world_size > 1 + self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 831a03be65f6..e685f3589ab2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -27,7 +27,7 @@ class EngineArgs: quantization_param_path: Optional[str] = None seed: int = 0 max_model_len: Optional[int] = None - worker_use_ray: bool = False + worker_use_ray: Optional[bool] = None pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None @@ -201,10 +201,12 @@ def add_cli_args( help='model context length. If unspecified, ' 'will be automatically derived from the model.') # Parallel arguments - parser.add_argument('--worker-use-ray', - action='store_true', - help='use Ray for distributed serving, will be ' - 'automatically set when using more than 1 GPU') + parser.add_argument( + '--worker-use-ray', + action=argparse.BooleanOptionalAction, + default=None, + help='use Ray for distributed serving, will default ' + 'to true when ray is installed and more than 1 GPU is used') parser.add_argument('--pipeline-parallel-size', '-pp', type=int, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f61049513512..a53ce630fc24 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -338,9 +338,11 @@ def from_engine_args( initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync executor_class = RayGPUExecutorAsync + elif engine_config.parallel_config.world_size > 1: + from vllm.executor.multiproc_gpu_executor import ( + MultiProcGPUExecutorAsync) + executor_class = MultiProcGPUExecutorAsync else: - assert engine_config.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutorAsync executor_class = GPUExecutorAsync # Create the async LLM engine. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8c37c5a9d6ee..7e5d3094d812 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -224,9 +224,11 @@ def from_engine_args( initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutor executor_class = RayGPUExecutor + elif engine_config.parallel_config.world_size > 1: + from vllm.executor.multiproc_gpu_executor import ( + MultiProcGPUExecutor) + executor_class = MultiProcGPUExecutor else: - assert engine_config.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutor executor_class = GPUExecutor @@ -244,6 +246,11 @@ def __reduce__(self): # the closure used to initialize Ray worker actors raise RuntimeError("LLMEngine should not be pickled!") + def __del__(self): + # Shutdown model executor when engine is garbage collected + if self.model_executor is not None: + self.model_executor.shutdown() + def get_tokenizer(self) -> "PreTrainedTokenizer": return self.tokenizer.get_lora_tokenizer(None) diff --git a/vllm/engine/local_worker_utils.py b/vllm/engine/local_worker_utils.py new file mode 100644 index 000000000000..7f12f3e0aec6 --- /dev/null +++ b/vllm/engine/local_worker_utils.py @@ -0,0 +1,204 @@ +import asyncio +import multiprocessing +import threading +import traceback +import uuid +from dataclasses import dataclass +from multiprocessing.connection import wait +from typing import Dict, Generic, List, Optional, TypeVar, Union + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +T = TypeVar('T') + +_TERMINATE = "TERMINATE" # sentinel + +# Use dedicated multiprocess context for workers. +# Both spawn and fork work +mp = multiprocessing.get_context("fork") + + +@dataclass +class Result(Generic[T]): + """Result of task dispatched to worker""" + + task_id: uuid.UUID = None + value: Optional[T] = None + exception: Optional[BaseException] = None + + +class ResultFuture(threading.Event, Generic[T]): + """Synchronous future for non-async case""" + + def __init__(self): + super().__init__() + self.result: Optional[Result[T]] = None + + def set_result(self, result: Result[T]): + self.result = result + self.set() + + def get(self) -> T: + self.wait() + if self.result.exception is not None: + raise self.result.exception + return self.result.value + + +def _set_future_result(future: Union[ResultFuture, asyncio.Future], + result: Result): + if isinstance(future, ResultFuture): + future.set_result(result) + return + loop = future.get_loop() + if result.exception is not None: + loop.call_soon_threadsafe(future.set_exception, result.exception) + else: + loop.call_soon_threadsafe(future.set_result, result.value) + + +class ResultHandler(threading.Thread): + """Handle results from all workers (in background thread)""" + + def __init__(self) -> None: + super().__init__(daemon=True) + self.result_queue = mp.Queue() + self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} + + def run(self): + for result in iter(self.result_queue.get, _TERMINATE): + future = self.tasks.pop(result.task_id) + _set_future_result(future, result) + # Ensure that all waiters will receive an exception + for future in self.tasks.values(): + _set_future_result( + future, Result(exception=ChildProcessError("worker died"))) + + def close(self): + self.result_queue.put(_TERMINATE) + + +class WorkerMonitor(threading.Thread): + """Monitor worker status (in background thread)""" + + def __init__(self, workers: List['LocalWorkerVllm'], + result_handler: ResultHandler): + super().__init__(daemon=True) + self.workers = workers + self.result_handler = result_handler + self._close = False + + def run(self) -> None: + # Blocks until any worker exits + dead_sentinels = wait([p.sentinel for p in self.workers]) + if self._close: + return + self._close = True + + # Kill / cleanup all workers + for worker in self.workers: + if worker.sentinel in dead_sentinels: + worker.join(1) + if worker.exitcode is not None and worker.exitcode != 0: + logger.error(f"Worker {worker.name} pid {worker.pid} died, " + f"exit code: {worker.exitcode}") + # Cleanup any remaining workers + logger.info("Killing local vLLM worker processes") + for worker in self.workers: + worker.kill_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + def close(self): + if self._close: + return + self._close = True + logger.info("Terminating local vLLM worker processes") + for worker in self.workers: + worker.terminate_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + +class LocalWorkerVllm(mp.Process): + """Local process wrapper for vllm.worker.Worker + for handling single-node multi-GPU tensor parallel.""" + + def __init__(self, result_handler: ResultHandler, *args, **kwargs) -> None: + super().__init__(daemon=True) + self._task_queue = mp.Queue() + self.result_queue = result_handler.result_queue + self.tasks = result_handler.tasks + self.worker_args = args + self.worker_kwargs = kwargs + self.worker = None + + def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], + method: str, args, kwargs): + task_id = uuid.uuid4() + self.tasks[task_id] = future + try: + self._task_queue.put((task_id, method, args, kwargs)) + except BaseException as e: + del self.tasks[task_id] + raise ChildProcessError("worker died") from e + + def execute_method(self, method: str, *args, **kwargs): + future = ResultFuture() + self._enqueue_task(future, method, args, kwargs) + return future + + async def execute_method_async(self, method: str, *args, **kwargs): + future = asyncio.get_running_loop().create_future() + self._enqueue_task(future, method, args, kwargs) + return await future + + def terminate_worker(self): + try: + self._task_queue.put(_TERMINATE) + except ValueError: + self.kill() + self._task_queue.close() + + def kill_worker(self): + self._task_queue.close() + self.kill() + + def run(self) -> None: + # Re-init logger in forked process, to include worker-specific prefix + global logger + logger = init_logger(__name__) + + del self.tasks # Not used in forked process + from vllm.worker.worker import Worker + self.worker = Worker(*self.worker_args, **self.worker_kwargs) + del self.worker_args + del self.worker_kwargs + + # Accept tasks from the engine in task_queue + # and return task output in result_queue + logger.info("Worker ready; awaiting tasks") + try: + for items in iter(self._task_queue.get, _TERMINATE): + output = None + exception = None + task_id, method, args, kwargs = items + try: + executor = getattr(self.worker, method) + output = executor(*args, **kwargs) + except BaseException as e: + tb = traceback.format_exc() + logger.error( + f"Exception in worker {mp.current_process().name} " + f"while processing method {method}: {e}, {tb}") + exception = e + self.result_queue.put( + Result(task_id=task_id, value=output, exception=exception)) + except KeyboardInterrupt: + pass + except Exception: + logger.exception("Worker failed") + + logger.info("Worker exiting") diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index 04d4ed83976d..d78e73ed8ba1 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -73,9 +73,10 @@ def execute_model_compiled_dag_remote(self, ignored): return output except ImportError as e: - logger.warning(f"Failed to import Ray with {e!r}. " - "For distributed inference, please install Ray with " - "`pip install ray`.") + logger.warning( + f"Unable to import Ray with {e!r}. " + "For multi-node distributed inference, please install Ray with " + "`pip install ray`.") ray = None # type: ignore RayWorkerVllm = None # type: ignore diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index bbb6ec80f7b7..bbe543638879 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -94,6 +94,13 @@ def check_health(self) -> None: exception.""" raise NotImplementedError + def shutdown(self) -> None: + """Shutdown the executor.""" + return + + def __del__(self): + self.shutdown() + class ExecutorAsyncBase(ExecutorBase): diff --git a/vllm/executor/multi_gpu_executor.py b/vllm/executor/multi_gpu_executor.py new file mode 100644 index 000000000000..499667e47350 --- /dev/null +++ b/vllm/executor/multi_gpu_executor.py @@ -0,0 +1,140 @@ +from abc import abstractmethod +from typing import Any, Dict, Optional, Set, Tuple + +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import SamplerOutput + +logger = init_logger(__name__) + + +class MultiGPUExecutor(ExecutorBase): + """Abstract superclass of multi-GPU executor implementations.""" + + def _init_driver_worker_and_model(self, rank: int, local_rank: int, + distributed_init_method: str): + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + from vllm.worker.worker import Worker + + # Initialize the driver worker with the Worker class. + self.driver_worker = Worker( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + tensorizer_config=self.tensorizer_config, + is_driver_worker=True, + ) + + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self._run_workers("determine_num_available_blocks", ) + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers. + """ + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + logger.info(f"# GPU blocks: {num_gpu_blocks}, " + f"# CPU blocks: {num_cpu_blocks}") + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def execute_model(self, *args, **kwargs) -> SamplerOutput: + all_outputs = self._run_workers("execute_model", + driver_args=args, + driver_kwargs=kwargs) + + # Only the driver worker returns the sampling results. + return all_outputs[0] + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> Set[int]: + return self._run_workers("list_loras") + + @abstractmethod + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + raise NotImplementedError + + +class MultiGPUExecutorAsync(MultiGPUExecutor, ExecutorAsyncBase): + + @abstractmethod + async def _run_workers_async( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + raise NotImplementedError + + async def execute_model_async(self, *args, **kwargs) -> SamplerOutput: + all_outputs = await self._run_workers_async("execute_model", + driver_args=args, + driver_kwargs=kwargs) + + # Only the driver worker returns the sampling results. + return all_outputs[0] diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py new file mode 100644 index 000000000000..456d02721d17 --- /dev/null +++ b/vllm/executor/multiproc_gpu_executor.py @@ -0,0 +1,138 @@ +import asyncio +import os +from typing import Any, Dict, Optional, Tuple + +from vllm.engine.local_worker_utils import (LocalWorkerVllm, ResultHandler, + WorkerMonitor) +from vllm.executor.multi_gpu_executor import (MultiGPUExecutor, + MultiGPUExecutorAsync) +from vllm.logger import init_logger +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async, set_cuda_visible_devices) + +logger = init_logger(__name__) + + +class MultiProcGPUExecutor(MultiGPUExecutor): + """Python multiprocessing-based multi-GPU executor""" + + def _init_executor(self) -> None: + assert ( + not self.speculative_config + ), "Speculative decoding not yet supported for MultiProcGPU backend." + + # Create the parallel GPU workers. + world_size = self.parallel_config.tensor_parallel_size + + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers + if "CUDA_VISIBLE_DEVICES" not in os.environ: + set_cuda_visible_devices(range(world_size)) + + from torch.cuda import device_count + assert world_size <= device_count(), ( + "please set tensor_parallel_size to less than max local gpu count") + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + if world_size == 1: + self.workers = [] + else: + result_handler = ResultHandler() + self.workers = [ + LocalWorkerVllm( + result_handler, + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=rank, + rank=rank, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + tensorizer_config=self.tensorizer_config, + ) for rank in range(1, world_size) + ] + + for worker in self.workers: + worker.start() + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + self._init_driver_worker_and_model(0, 0, distributed_init_method) + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the workers first. + worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # Start the driver worker after all the ray workers. + driver_worker_method = getattr(self.driver_worker, method) + driver_worker_output = driver_worker_method(*driver_args, + **driver_kwargs) + + # Get the results of the workers. + return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + if not self.worker_monitor.is_alive(): + raise RuntimeError("Worker processes are not running") + + +class MultiProcGPUExecutorAsync(MultiProcGPUExecutor, MultiGPUExecutorAsync): + + async def _run_workers_async( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + driver_executor = make_async(getattr(self.driver_worker, method)) + + # Run all the workers asynchronously. + coros = [driver_executor(*driver_args, **driver_kwargs)] + [ + worker.execute_method_async(method, *args, **kwargs) + for worker in self.workers + ] + + return await asyncio.gather(*coros) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 5db2f3f65253..2c949d778b91 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -3,12 +3,12 @@ import os import pickle from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from vllm.engine.ray_utils import RayWorkerVllm, ray -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.executor.multi_gpu_executor import (MultiGPUExecutor, + MultiGPUExecutorAsync) from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async, set_cuda_visible_devices) @@ -27,7 +27,7 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) -class RayGPUExecutor(ExecutorBase): +class RayGPUExecutor(MultiGPUExecutor): def _init_executor(self) -> None: assert (not self.speculative_config @@ -154,69 +154,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", tensorizer_config=self.tensorizer_config, )) - # Initialize the driver worker with the Worker class. driver_rank = 0 driver_local_rank = node_workers[driver_node_id].index(driver_rank) - self.driver_worker = Worker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - local_rank=driver_local_rank, - rank=driver_rank, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - tensorizer_config=self.tensorizer_config, - is_driver_worker=True, - ) - - self._run_workers("init_device") - self._run_workers( - "load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, - ) - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks. - - This invokes `determine_num_available_blocks` on each worker and takes - the min of the results, guaranteeing that the selected cache sizes are - compatible with all workers. - - Returns: - - Tuple[num_gpu_blocks, num_cpu_blocks] - """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self._run_workers("determine_num_available_blocks", ) - - # Since we use a shared centralized controller, we take the minimum - # number of blocks across all workers to make sure all the memory - # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) - - return num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache in all workers. - """ - - # NOTE: We log here to avoid multiple logs when number of workers is - # greater than one. We could log in the engine, but not all executors - # have GPUs. - logger.info(f"# GPU blocks: {num_gpu_blocks}, " - f"# CPU blocks: {num_cpu_blocks}") - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self._run_workers("initialize_cache", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks) + self._init_driver_worker_and_model(driver_rank, driver_local_rank, + distributed_init_method) def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -237,23 +178,6 @@ def execute_model(self, output = all_outputs[0] return output - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "add_lora", - lora_request=lora_request, - ) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "remove_lora", - lora_id=lora_id, - ) - - def list_loras(self) -> Set[int]: - return self._run_workers("list_loras") - def _run_workers( self, method: str, @@ -292,19 +216,17 @@ def _run_workers( method)(*driver_args, **driver_kwargs) # Get the results of the ray workers. - if self.workers: - if use_ray_compiled_dag: - try: - ray_worker_outputs = [ - pickle.loads(chan.begin_read()) - for chan in output_channels - ] - finally: - # Has to call end_read in order to reuse the DAG. - for chan in output_channels: - chan.end_read() - else: - ray_worker_outputs = ray.get(ray_worker_outputs) + if use_ray_compiled_dag: + try: + ray_worker_outputs = [ + pickle.loads(chan.begin_read()) for chan in output_channels + ] + finally: + # Has to call end_read in order to reuse the DAG. + for chan in output_channels: + chan.end_read() + else: + ray_worker_outputs = ray.get(ray_worker_outputs) return [driver_worker_output] + ray_worker_outputs @@ -346,7 +268,7 @@ def _check_if_any_actor_is_dead(self): f"Dead Workers: {dead_actors}. ") -class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): +class RayGPUExecutorAsync(RayGPUExecutor, MultiGPUExecutorAsync): async def _run_workers_async( self, @@ -374,23 +296,3 @@ async def _run_workers_async( all_outputs = await asyncio.gather(*coros) return all_outputs - - async def execute_model_async( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: - all_outputs = await self._run_workers_async( - "execute_model", - driver_kwargs={ - "seq_group_metadata_list": seq_group_metadata_list, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - }) - - # Only the driver worker returns the sampling results. - output = all_outputs[0] - return output diff --git a/vllm/logger.py b/vllm/logger.py index af9575085ef3..8e497fa7e938 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -2,6 +2,7 @@ # https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py """Logging configuration for vLLM.""" import logging +import multiprocessing as mp import os import sys from typing import Optional @@ -12,6 +13,13 @@ _DATE_FORMAT = "%m-%d %H:%M:%S" +class ProcessLogger(logging.LoggerAdapter): + + def process(self, msg, kwargs): + msg = f"[{self.extra['process_name']} pid {self.extra['pid']}] {msg}" + return msg, kwargs + + class NewLineFormatter(logging.Formatter): """Adds logging prefix to newlines to align multi-line messages.""" @@ -64,4 +72,9 @@ def init_logger(name: str): " Please open an issue on Github.") logger.addHandler(_default_handler) logger.propagate = False + if mp.parent_process() is not None: + logger = ProcessLogger(logger, { + 'process_name': mp.current_process().name, + 'pid': os.getpid() + }) return logger diff --git a/vllm/utils.py b/vllm/utils.py index 4c0dc9ca729a..193c9f225283 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -7,6 +7,7 @@ import uuid import warnings from collections import OrderedDict, defaultdict +from collections.abc import Iterable from functools import lru_cache, partial from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, @@ -270,7 +271,7 @@ def get_open_port() -> int: return s.getsockname()[1] -def set_cuda_visible_devices(device_ids: List[int]) -> None: +def set_cuda_visible_devices(device_ids: Iterable[int]) -> None: os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) @@ -492,7 +493,7 @@ def maybe_expand_dim(tensor: torch.Tensor, def merge_dicts(dict1: Dict[Any, List[Any]], dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]: """Merge 2 dicts that have key -> List of items. - + When a key conflicts, the values in dict1 is prioritized. """ merged_dict = defaultdict(list) From c0bad3d9f628233f6fe8ed71235909883b61e6a0 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 22 Mar 2024 16:39:49 -0700 Subject: [PATCH 2/8] Use getattr to access model_executor in engine destructor --- vllm/engine/llm_engine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7e5d3094d812..67a3ee003953 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -248,8 +248,10 @@ def __reduce__(self): def __del__(self): # Shutdown model executor when engine is garbage collected - if self.model_executor is not None: - self.model_executor.shutdown() + # Use getattr since __init__ can fail before the field is set + model_executor = getattr(self, "model_executor", None) + if model_executor is not None: + model_executor.shutdown() def get_tokenizer(self) -> "PreTrainedTokenizer": return self.tokenizer.get_lora_tokenizer(None) From d0a8709a58320b7e36ea6cca4010b63f02c76e1e Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 28 Mar 2024 21:50:02 -0700 Subject: [PATCH 3/8] Address a couple of review comments --- vllm/engine/llm_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 67a3ee003953..ad5394e97491 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -249,8 +249,7 @@ def __reduce__(self): def __del__(self): # Shutdown model executor when engine is garbage collected # Use getattr since __init__ can fail before the field is set - model_executor = getattr(self, "model_executor", None) - if model_executor is not None: + if model_executor := getattr(self, "model_executor", None): model_executor.shutdown() def get_tokenizer(self) -> "PreTrainedTokenizer": From 5e214a38875118b2e099c65050a150f29d2d0c82 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 29 Mar 2024 10:20:35 -0700 Subject: [PATCH 4/8] Extend existing distributed correctness test Instead of adding equivalent new test --- .buildkite/test-pipeline.yaml | 9 +-- .../test_basic_distributed_correctness.py | 11 +-- tests/engine/test_local_worker.py | 68 ------------------- vllm/engine/local_worker_utils.py | 4 +- 4 files changed, 14 insertions(+), 78 deletions(-) delete mode 100644 tests/engine/test_local_worker.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index aa4582bbda0c..79e74766e855 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -33,10 +33,11 @@ steps: num_gpus: 2 # only support 1 or 2 for now. commands: - pytest -v -s test_pynccl.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py + # Use spawn to avoid CUDA re-init issues + - TEST_DIST_MODEL=facebook/opt-125m MULTIPROC_METHOD=spawn pytest -v -s test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf MULTIPROC_METHOD=spawn pytest -v -s test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m MULTIPROC_METHOD=spawn pytest -v -s test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf MULTIPROC_METHOD=spawn pytest -v -s test_chunked_prefill_distributed.py - label: Engine Test command: pytest -v -s engine tokenization test_sequence.py test_config.py diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 77aa90b12bf8..faed6f1e82e7 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -25,6 +25,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("worker_use_ray", [False, True]) def test_models( hf_runner, vllm_runner, @@ -32,17 +33,17 @@ def test_models( model: str, dtype: str, max_tokens: int, + worker_use_ray: bool, ) -> None: hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - ) + vllm_model = vllm_runner(model, + dtype=dtype, + tensor_parallel_size=2, + worker_use_ray=worker_use_ray) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/engine/test_local_worker.py b/tests/engine/test_local_worker.py deleted file mode 100644 index ace0c245a9e8..000000000000 --- a/tests/engine/test_local_worker.py +++ /dev/null @@ -1,68 +0,0 @@ -import multiprocessing as mp - -import pytest -import torch - -from vllm import LLM, SamplingParams - -TENSOR_PARALLEL_SIZE = 2 -MAX_GENERATION_TOKENS = 256 - - -def llm_generate(result_queue, prompt_token_ids, worker_use_ray=False): - try: - llm = LLM(model="facebook/opt-350m", - tensor_parallel_size=TENSOR_PARALLEL_SIZE, - worker_use_ray=worker_use_ray) - - output = llm.generate( - prompt_token_ids=prompt_token_ids, - sampling_params=SamplingParams(max_tokens=MAX_GENERATION_TOKENS)) - except BaseException as e: - output = e - - result_queue.put(output) - - -def run_llm(prompt_token_ids, worker_use_ray=False): - result_queue = mp.Queue() - proc = mp.Process(target=llm_generate, - args=(result_queue, prompt_token_ids, worker_use_ray)) - proc.start() - result = result_queue.get() - proc.join() - if isinstance(result, BaseException): - raise result - return result - - -def get_prompts(): - # https://github.com/vllm-project/vllm/issues/367#issuecomment-1629872996 - batch_size = 32 - dim = 120 - max_token_id = 32000 - torch.manual_seed(42) - batch = torch.randint(max_token_id, (batch_size, dim)) - prompt_token_ids = [tokens.tolist() for tokens in batch] - return prompt_token_ids - - -@pytest.mark.skip("Requires multiple GPUs") -def test_local_worker(): - # Similar to tests/lora/test_llama.py - # Cannot use as it will initialize torch.cuda too early... - # if torch.cuda.device_count() < 2: - # pytest.skip(f"Not enough GPUs for tensor parallelism {2}") - - prompt_token_ids = get_prompts() - output1 = run_llm(prompt_token_ids, worker_use_ray=False) - output2 = run_llm(prompt_token_ids, worker_use_ray=True) - assert len(output1) == len(output2) - - completion_token_ids1 = [item.outputs[0].token_ids for item in output1] - completion_token_ids2 = [item.outputs[0].token_ids for item in output2] - assert completion_token_ids1 == completion_token_ids2 - - -if __name__ == "__main__": - test_local_worker() diff --git a/vllm/engine/local_worker_utils.py b/vllm/engine/local_worker_utils.py index 7f12f3e0aec6..db4b045ac114 100644 --- a/vllm/engine/local_worker_utils.py +++ b/vllm/engine/local_worker_utils.py @@ -1,5 +1,6 @@ import asyncio import multiprocessing +import os import threading import traceback import uuid @@ -17,7 +18,8 @@ # Use dedicated multiprocess context for workers. # Both spawn and fork work -mp = multiprocessing.get_context("fork") +mp_method = os.getenv("MULTIPROC_METHOD", "fork") +mp = multiprocessing.get_context(mp_method) @dataclass From 0fb0743afbc9c8dedb0753af2e8a57858e460e04 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 29 Mar 2024 13:34:43 -0700 Subject: [PATCH 5/8] Use factory for worker initialization Useful for unit tests --- vllm/engine/local_worker_utils.py | 14 +++++------ vllm/executor/multiproc_gpu_executor.py | 33 ++++++++++++++++--------- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/vllm/engine/local_worker_utils.py b/vllm/engine/local_worker_utils.py index db4b045ac114..f9809b2168d6 100644 --- a/vllm/engine/local_worker_utils.py +++ b/vllm/engine/local_worker_utils.py @@ -6,7 +6,7 @@ import uuid from dataclasses import dataclass from multiprocessing.connection import wait -from typing import Dict, Generic, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union from vllm.logger import init_logger @@ -128,13 +128,13 @@ class LocalWorkerVllm(mp.Process): """Local process wrapper for vllm.worker.Worker for handling single-node multi-GPU tensor parallel.""" - def __init__(self, result_handler: ResultHandler, *args, **kwargs) -> None: + def __init__(self, result_handler: ResultHandler, + worker_factory: Callable[[], Any]) -> None: super().__init__(daemon=True) self._task_queue = mp.Queue() self.result_queue = result_handler.result_queue self.tasks = result_handler.tasks - self.worker_args = args - self.worker_kwargs = kwargs + self.worker_factory = worker_factory self.worker = None def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], @@ -174,10 +174,8 @@ def run(self) -> None: logger = init_logger(__name__) del self.tasks # Not used in forked process - from vllm.worker.worker import Worker - self.worker = Worker(*self.worker_args, **self.worker_kwargs) - del self.worker_args - del self.worker_kwargs + self.worker = self.worker_factory() + del self.worker_factory # Accept tasks from the engine in task_queue # and return task output in result_queue diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 456d02721d17..b00472ebe770 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -1,5 +1,6 @@ import asyncio import os +from functools import partial from typing import Any, Dict, Optional, Tuple from vllm.engine.local_worker_utils import (LocalWorkerVllm, ResultHandler, @@ -13,6 +14,12 @@ logger = init_logger(__name__) +def _create_worker(*args, **kwargs): + # Import within worker process to avoid CUDA init issues + from vllm.worker.worker import Worker + return Worker(*args, **kwargs) + + class MultiProcGPUExecutor(MultiGPUExecutor): """Python multiprocessing-based multi-GPU executor""" @@ -42,18 +49,20 @@ def _init_executor(self) -> None: self.workers = [ LocalWorkerVllm( result_handler, - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - local_rank=rank, - rank=rank, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - tensorizer_config=self.tensorizer_config, - ) for rank in range(1, world_size) + partial( + _create_worker, + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=rank, + rank=rank, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + tensorizer_config=self.tensorizer_config, + )) for rank in range(1, world_size) ] for worker in self.workers: From e04800109559a5dcb828426970bf7bb7716d1dd8 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 29 Mar 2024 15:17:12 -0700 Subject: [PATCH 6/8] Test local worker mechanics in isolation --- tests/engine/test_local_workers.py | 179 +++++++++++++++++++++++++++++ vllm/engine/local_worker_utils.py | 33 +++--- 2 files changed, 197 insertions(+), 15 deletions(-) create mode 100644 tests/engine/test_local_workers.py diff --git a/tests/engine/test_local_workers.py b/tests/engine/test_local_workers.py new file mode 100644 index 000000000000..dc84f20fe130 --- /dev/null +++ b/tests/engine/test_local_workers.py @@ -0,0 +1,179 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from time import sleep +from typing import Any, List, Tuple + +import pytest + +from vllm.engine.local_worker_utils import (LocalWorkerVllm, ResultHandler, + WorkerMonitor) + + +class DummyWorker: + """Dummy version of vllm.worker.worker""" + + def __init__(self, rank: int): + self.rank = rank + + def worker_method(self, worker_input: Any) -> Tuple[int, Any]: + sleep(0.05) + + if isinstance(worker_input, Exception): + # simulate error case + raise worker_input + + return self.rank, input + + +def _start_workers() -> Tuple[List[LocalWorkerVllm], WorkerMonitor]: + result_handler = ResultHandler() + workers = [ + LocalWorkerVllm(result_handler, partial(DummyWorker, rank=rank)) + for rank in range(8) + ] + + for worker in workers: + worker.start() + + worker_monitor = WorkerMonitor(workers, result_handler) + assert not worker_monitor.is_alive() + + result_handler.start() + worker_monitor.start() + assert worker_monitor.is_alive() + + return workers, worker_monitor + + +def test_local_workers() -> None: + """Test workers with sync task submission""" + + workers, worker_monitor = _start_workers() + + def execute_workers(worker_input: str) -> None: + worker_outputs = [ + worker.execute_method("worker_method", worker_input) + for worker in workers + ] + + for rank, output in enumerate(worker_outputs): + assert output.get() == (rank, input) + + executor = ThreadPoolExecutor(max_workers=4) + + # Test concurrent submission from different threads + futures = [ + executor.submit(partial(execute_workers, f"thread {thread_num}")) + for thread_num in range(4) + ] + + for future in futures: + future.result() + + # Test error case + exception = ValueError("fake error") + result = workers[0].execute_method("worker_method", exception) + try: + result.get() + pytest.fail("task should have failed") + except Exception as e: + assert isinstance(e, ValueError) + assert str(e) == "fake error" + + # Test cleanup when a worker fails + assert worker_monitor.is_alive() + workers[3].kill() + + # Other workers should get shut down here + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = workers[0].execute_method("worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) + + +def test_local_workers_clean_shutdown() -> None: + """Test clean shutdown""" + + workers, worker_monitor = _start_workers() + + assert worker_monitor.is_alive() + assert all(worker.is_alive() for worker in workers) + + # Clean shutdown + worker_monitor.close() + + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = workers[0].execute_method("worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) + + +@pytest.mark.asyncio +async def test_local_workers_async() -> None: + """Test local workers with async task submission""" + + workers, worker_monitor = _start_workers() + + async def execute_workers(worker_input: str) -> None: + worker_coros = [ + worker.execute_method_async("worker_method", worker_input) + for worker in workers + ] + + results = await asyncio.gather(*worker_coros) + for rank, result in enumerate(results): + assert result == (rank, input) + + tasks = [ + asyncio.create_task(execute_workers(f"task {task_num}")) + for task_num in range(4) + ] + + for task in tasks: + await task + + # Test error case + exception = ValueError("fake error") + try: + _result = await workers[0].execute_method_async( + "worker_method", exception) + pytest.fail("task should have failed") + except Exception as e: + assert isinstance(e, ValueError) + assert str(e) == "fake error" + + # Test cleanup when a worker fails + assert worker_monitor.is_alive() + workers[3].kill() + + # Other workers should get shut down here + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = await workers[0].execute_method_async( + "worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) diff --git a/vllm/engine/local_worker_utils.py b/vllm/engine/local_worker_utils.py index f9809b2168d6..b83ea5e69204 100644 --- a/vllm/engine/local_worker_utils.py +++ b/vllm/engine/local_worker_utils.py @@ -95,23 +95,26 @@ def __init__(self, workers: List['LocalWorkerVllm'], def run(self) -> None: # Blocks until any worker exits dead_sentinels = wait([p.sentinel for p in self.workers]) - if self._close: - return - self._close = True + if not self._close: + self._close = True + + # Kill / cleanup all workers + for worker in self.workers: + if worker.sentinel in dead_sentinels: + worker.join(1) + if worker.exitcode is not None and worker.exitcode != 0: + logger.error( + f"Worker {worker.name} pid {worker.pid} died, " + f"exit code: {worker.exitcode}") + # Cleanup any remaining workers + logger.info("Killing local vLLM worker processes") + for worker in self.workers: + worker.kill_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() - # Kill / cleanup all workers for worker in self.workers: - if worker.sentinel in dead_sentinels: - worker.join(1) - if worker.exitcode is not None and worker.exitcode != 0: - logger.error(f"Worker {worker.name} pid {worker.pid} died, " - f"exit code: {worker.exitcode}") - # Cleanup any remaining workers - logger.info("Killing local vLLM worker processes") - for worker in self.workers: - worker.kill_worker() - # Must be done after worker task queues are all closed - self.result_handler.close() + worker.join(2) def close(self): if self._close: From 1938c35e80ed106583fe9b919908e8270535d01e Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 29 Mar 2024 22:18:46 -0700 Subject: [PATCH 7/8] Add pid prefix to process stdout/stderr instead of logger --- vllm/engine/local_worker_utils.py | 41 ++++++++++++++++++++++++++++--- vllm/logger.py | 13 ---------- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/vllm/engine/local_worker_utils.py b/vllm/engine/local_worker_utils.py index b83ea5e69204..4a5d12b08704 100644 --- a/vllm/engine/local_worker_utils.py +++ b/vllm/engine/local_worker_utils.py @@ -1,10 +1,12 @@ import asyncio import multiprocessing import os +import sys import threading import traceback import uuid from dataclasses import dataclass +from io import TextIOBase from multiprocessing.connection import wait from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union @@ -16,6 +18,10 @@ _TERMINATE = "TERMINATE" # sentinel +# ANSI color codes +CYAN = '\033[1;36m' +RESET = '\033[0;0m' + # Use dedicated multiprocess context for workers. # Both spawn and fork work mp_method = os.getenv("MULTIPROC_METHOD", "fork") @@ -172,9 +178,11 @@ def kill_worker(self): self.kill() def run(self) -> None: - # Re-init logger in forked process, to include worker-specific prefix - global logger - logger = init_logger(__name__) + # Add process-specific prefix to stdout and stderr + process_name = mp.current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) del self.tasks # Not used in forked process self.worker = self.worker_factory() @@ -205,3 +213,30 @@ def run(self) -> None: logger.exception("Worker failed") logger.info("Worker exiting") + + +def _add_prefix(file: TextIOBase, worker_name: str, pid: int) -> None: + """Prepend output with process-specific prefix""" + + prefix = f"{CYAN}({worker_name} pid={pid}){RESET} " + file_write = file.write + + def write_with_prefix(s: str): + if not s: + return + if file.start_new_line: + file_write(prefix) + idx = 0 + while (next_idx := s.find('\n', idx)) != -1: + next_idx += 1 + file_write(s[idx:next_idx]) + if next_idx == len(s): + file.start_new_line = True + return + file_write(prefix) + idx = next_idx + file_write(s[idx:]) + file.start_new_line = False + + file.start_new_line = True + file.write = write_with_prefix diff --git a/vllm/logger.py b/vllm/logger.py index 8e497fa7e938..af9575085ef3 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -2,7 +2,6 @@ # https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py """Logging configuration for vLLM.""" import logging -import multiprocessing as mp import os import sys from typing import Optional @@ -13,13 +12,6 @@ _DATE_FORMAT = "%m-%d %H:%M:%S" -class ProcessLogger(logging.LoggerAdapter): - - def process(self, msg, kwargs): - msg = f"[{self.extra['process_name']} pid {self.extra['pid']}] {msg}" - return msg, kwargs - - class NewLineFormatter(logging.Formatter): """Adds logging prefix to newlines to align multi-line messages.""" @@ -72,9 +64,4 @@ def init_logger(name: str): " Please open an issue on Github.") logger.addHandler(_default_handler) logger.propagate = False - if mp.parent_process() is not None: - logger = ProcessLogger(logger, { - 'process_name': mp.current_process().name, - 'pid': os.getpid() - }) return logger From 56a1ad439350adcd7d7d3e6c18c95b60d23181ca Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 11 Apr 2024 12:33:02 +0100 Subject: [PATCH 8/8] Update new chunked prefill distributed test to include non-Ray --- tests/distributed/test_chunked_prefill_distributed.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py index 737b1f316951..a57660794f00 100644 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -27,6 +27,7 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("chunked_prefill_token_size", [16]) +@pytest.mark.parametrize("worker_use_ray", [False, True]) def test_models( hf_runner, vllm_runner, @@ -35,6 +36,7 @@ def test_models( dtype: str, max_tokens: int, chunked_prefill_token_size: int, + worker_use_ray: bool, ) -> None: # Add a chunked prefill config. max_num_seqs = min(chunked_prefill_token_size, 256) @@ -53,6 +55,7 @@ def test_models( max_num_seqs=max_num_seqs, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens, + worker_use_ray=worker_use_ray, ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model