From 3c05ea0847fa90abcc7af075bab97975f5938819 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Fri, 16 Feb 2024 15:55:52 -0800
Subject: [PATCH 1/6] Make ray optional for single-node deployment

Ray is a powerful platform for general-purpose distributed computing, but it
is potentially overkill for the specific requirements of real-time,
synchronized inference between GPUs on a single node. We would prefer to have
a "lightweight" option without the ray dependency for non-Ray cluster
environments.

With the changes in this PR, ray will continue to be used for parallel
workers if it's installed; otherwise, vanilla Python multiprocessing is used.
Worker processes are shut down when the LLMEngine is garbage collected.

Co-authored-by: Sahil Suneja
---
 Dockerfile                        |   3 +-
 requirements-rocm.txt             |   1 -
 requirements.txt                  |   1 -
 setup.py                          |   8 ++
 tests/engine/test_local_worker.py |  66 ++++++++++
 vllm/config.py                    |  14 ++-
 vllm/engine/arg_utils.py          |  12 +-
 vllm/engine/async_llm_engine.py   |  12 +-
 vllm/engine/llm_engine.py         | 119 ++++++++++++------
 vllm/engine/local_worker_utils.py | 194 ++++++++++++++++++++++++++++++
 vllm/engine/ray_utils.py          |   9 +-
 vllm/logger.py                    |  13 ++
 vllm/utils.py                     |   3 +-
 13 files changed, 391 insertions(+), 64 deletions(-)
 create mode 100644 tests/engine/test_local_worker.py
 create mode 100644 vllm/engine/local_worker_utils.py

diff --git a/Dockerfile b/Dockerfile
index dd4867702d3d..3e9f8e1e08ac 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -70,7 +70,7 @@ ADD . /vllm-workspace/
 COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
 # ignore build dependencies installation because we are using pre-complied extensions
 RUN rm pyproject.toml
-RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
+RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install .[ray] --verbose

 #################### TEST IMAGE ####################

@@ -80,7 +80,6 @@ RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip instal
 # In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
 FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base

-# libnccl required for ray
 RUN apt-get update -y \
     && apt-get install -y python3-pip

diff --git a/requirements-rocm.txt b/requirements-rocm.txt
index 53bd11de7c9d..7a77fd959f3f 100644
--- a/requirements-rocm.txt
+++ b/requirements-rocm.txt
@@ -2,7 +2,6 @@ ninja # For faster builds.
 typing-extensions>=4.8.0
 starlette
 psutil
-ray >= 2.9
 sentencepiece # Required for LLaMA tokenizer.
 numpy
 tokenizers>=0.15.0
diff --git a/requirements.txt b/requirements.txt
index 05ec2e804e13..1fc5504ed7fd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,5 @@
 ninja # For faster builds.
 psutil
-ray >= 2.9
 sentencepiece # Required for LLaMA tokenizer.
numpy torch == 2.1.2 diff --git a/setup.py b/setup.py index 745b5a9b2d02..5032c1a89fa8 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ import torch import torch.utils.cpp_extension as torch_cpp_ext from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME +from typing import Optional ROOT_DIR = os.path.dirname(__file__) @@ -434,6 +435,12 @@ def get_requirements() -> List[str]: return requirements +def get_ray_requirement() -> Optional[List[str]]: + if _is_neuron(): + return None + return ["ray >= 2.9"] + + package_data = { "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] } @@ -467,6 +474,7 @@ def get_requirements() -> List[str]: "examples", "tests")), python_requires=">=3.8", install_requires=get_requirements(), + extras_requires=get_ray_requirement(), ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {}, package_data=package_data, diff --git a/tests/engine/test_local_worker.py b/tests/engine/test_local_worker.py new file mode 100644 index 000000000000..aae10b9032b1 --- /dev/null +++ b/tests/engine/test_local_worker.py @@ -0,0 +1,66 @@ +import pytest +import torch +import multiprocessing as mp +from vllm import LLM, SamplingParams + +TENSOR_PARALLEL_SIZE = 2 +MAX_GENERATION_TOKENS = 256 + + +def llm_generate(result_queue, prompt_token_ids, worker_use_ray=False): + try: + llm = LLM(model="facebook/opt-350m", + tensor_parallel_size=TENSOR_PARALLEL_SIZE, + worker_use_ray=worker_use_ray) + + output = llm.generate( + prompt_token_ids=prompt_token_ids, + sampling_params=SamplingParams(max_tokens=MAX_GENERATION_TOKENS)) + except BaseException as e: + output = e + + result_queue.put(output) + + +def run_llm(prompt_token_ids, worker_use_ray=False): + result_queue = mp.Queue() + proc = mp.Process(target=llm_generate, + args=(result_queue, prompt_token_ids, worker_use_ray)) + proc.start() + result = result_queue.get() + proc.join() + if isinstance(result, BaseException): + raise result + return result + + +def get_prompts(): + # https://github.com/vllm-project/vllm/issues/367#issuecomment-1629872996 + batch_size = 32 + dim = 120 + max_token_id = 32000 + torch.manual_seed(42) + batch = torch.randint(max_token_id, (batch_size, dim)) + prompt_token_ids = [tokens.tolist() for tokens in batch] + return prompt_token_ids + + +@pytest.mark.skip("Requires multiple GPUs") +def test_local_worker(): + # Similar to tests/lora/test_llama.py + # Cannot use as it will initialize torch.cuda too early... + # if torch.cuda.device_count() < 2: + # pytest.skip(f"Not enough GPUs for tensor parallelism {2}") + + prompt_token_ids = get_prompts() + output1 = run_llm(prompt_token_ids, worker_use_ray=False) + output2 = run_llm(prompt_token_ids, worker_use_ray=True) + assert len(output1) == len(output2) + + completion_token_ids1 = [item.outputs[0].token_ids for item in output1] + completion_token_ids2 = [item.outputs[0].token_ids for item in output2] + assert completion_token_ids1 == completion_token_ids2 + + +if __name__ == "__main__": + test_local_worker() diff --git a/vllm/config.py b/vllm/config.py index ef9a920f29c2..68758b8ab061 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,3 +1,4 @@ +import importlib.util from typing import Optional, Union, ClassVar from dataclasses import dataclass import os @@ -376,9 +377,9 @@ class ParallelConfig: Args: pipeline_parallel_size: Number of pipeline parallel groups. tensor_parallel_size: Number of tensor parallel groups. 
- worker_use_ray: Whether to use Ray for model workers. Will be set to + worker_use_ray: Whether to use Ray for model workers. Will default to True if either pipeline_parallel_size or tensor_parallel_size is - greater than 1. + greater than 1 and Ray is installed. max_parallel_loading_workers: Maximum number of multiple batches when load model sequentially. To avoid RAM OOM when using tensor parallel and large models. @@ -392,7 +393,7 @@ def __init__( self, pipeline_parallel_size: int, tensor_parallel_size: int, - worker_use_ray: bool, + worker_use_ray: Optional[bool] = None, max_parallel_loading_workers: Optional[int] = None, disable_custom_all_reduce: bool = False, ray_workers_use_nsight: bool = False, @@ -412,9 +413,10 @@ def __init__( self.ray_workers_use_nsight = ray_workers_use_nsight self.world_size = pipeline_parallel_size * self.tensor_parallel_size - # Ray worker is not supported for Neuron backend. - if self.world_size > 1 and not is_neuron(): - self.worker_use_ray = True + if self.worker_use_ray is None: + ray_found = importlib.util.find_spec("ray") is not None + self.worker_use_ray = ray_found and self.world_size > 1 + self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c3dccdd5bb50..765baa90b953 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -20,7 +20,7 @@ class EngineArgs: kv_cache_dtype: str = 'auto' seed: int = 0 max_model_len: Optional[int] = None - worker_use_ray: bool = False + worker_use_ray: Optional[bool] = None pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None @@ -149,10 +149,12 @@ def add_cli_args( help='model context length. If unspecified, ' 'will be automatically derived from the model.') # Parallel arguments - parser.add_argument('--worker-use-ray', - action='store_true', - help='use Ray for distributed serving, will be ' - 'automatically set when using more than 1 GPU') + parser.add_argument( + '--worker-use-ray', + action='store_true', + default=None, + help='use Ray for distributed serving, will default ' + 'to true when ray is installed and more than 1 GPU is used') parser.add_argument('--pipeline-parallel-size', '-pp', type=int, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 65ab0c063417..0d131a78c0ec 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -287,9 +287,15 @@ async def _run_workers_async( coros.append(asyncio.get_event_loop().run_in_executor( None, partial(driver_executor, *driver_args, **driver_kwargs))) - # Run the ray workers asynchronously. - for worker in self.workers: - coros.append(worker.execute_method.remote(method, *args, **kwargs)) + # Run the workers asynchronously. 
+ if self.parallel_config.worker_use_ray: + for worker in self.workers: + coros.append( + worker.execute_method.remote(method, *args, **kwargs)) + else: + for worker in self.workers: + coros.append( + worker.execute_method_async(method, *args, **kwargs)) all_outputs = await asyncio.gather(*coros) return all_outputs diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1f518cbf39b2..57584def9e2f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -15,6 +15,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray +from vllm.engine.local_worker_utils import LocalWorkerVllm, WorkerMonitor, ResultHandler from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams @@ -119,6 +120,7 @@ def __init__( self.seq_counter = Counter() # Create the parallel GPU workers. + self.worker_monitor = None if self.parallel_config.worker_use_ray: # Disable Ray usage stats collection. ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") @@ -169,30 +171,79 @@ def _dispatch_worker(self): return Worker def _init_workers(self): - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - Worker = self._dispatch_worker() + world_size = self.parallel_config.tensor_parallel_size - assert self.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers + if "CUDA_VISIBLE_DEVICES" not in os.environ: + set_cuda_visible_devices(range(world_size)) + + from torch.cuda import device_count + assert world_size <= device_count(), ( + "please set tensor_parallel_size to less than max local gpu count") - self.workers: List[Worker] = [] distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) + + if world_size == 1: + self.workers = [] + else: + result_handler = ResultHandler() + self.workers = [ + LocalWorkerVllm( + result_handler, + self.model_config, + self.parallel_config, + self.scheduler_config, + self.device_config, + local_rank=rank, + rank=rank, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + ) for rank in range(1, world_size) + ] + + for worker in self.workers: + worker.start() + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + self._init_driver_worker_and_model(0, 0, distributed_init_method) + + def __del__(self): + # Terminate local worker processes when engine is garbage collected + if self.worker_monitor is not None: + self.worker_monitor.close() + + def _init_driver_worker_and_model(self, rank: int, local_rank: int, + distributed_init_method: str): + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + Worker = self._dispatch_worker() + self.driver_worker = Worker( self.model_config, self.parallel_config, self.scheduler_config, self.device_config, - local_rank=0, - rank=0, + local_rank=local_rank, + rank=rank, distributed_init_method=distributed_init_method, lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) - self._run_workers("init_model") - self._run_workers("load_model") + # don't use cupy for eager mode + self._run_workers("init_model", + 
cupy_port=get_open_port() + if not self.model_config.enforce_eager else None) + self._run_workers( + "load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers, + ) def _init_tokenizer(self, **tokenizer_init_kwargs): init_kwargs = dict( @@ -301,28 +352,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", driver_rank = 0 driver_local_rank = node_workers[driver_node_id].index(driver_rank) - self.driver_worker = Worker( - model_config, - parallel_config, - scheduler_config, - device_config, - driver_local_rank, - driver_rank, - distributed_init_method, - lora_config=self.lora_config, - kv_cache_dtype=self.cache_config.cache_dtype, - is_driver_worker=True, - ) - - # don't use cupy for eager mode - self._run_workers("init_model", - cupy_port=get_open_port() - if not model_config.enforce_eager else None) - self._run_workers( - "load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, - ) + self._init_driver_worker_and_model(driver_rank, driver_local_rank, + distributed_init_method) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -1063,13 +1094,19 @@ def _run_workers( raise NotImplementedError( "max_concurrent_workers is not supported yet.") - if use_ray_compiled_dag: + # Start the workers first. + if not self.parallel_config.worker_use_ray: + worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + elif use_ray_compiled_dag: # Right now, compiled DAG can only accept a single # input. TODO(sang): Fix it. output_channels = self.forward_dag.execute(1) else: # Start the ray workers first. - ray_worker_outputs = [ + worker_outputs = [ worker.execute_method.remote(method, *args, **kwargs) for worker in self.workers ] @@ -1079,15 +1116,17 @@ def _run_workers( if driver_kwargs is None: driver_kwargs = kwargs - # Start the driver worker after all the ray workers. + # Start the driver worker after all the other workers. driver_worker_output = getattr(self.driver_worker, method)(*driver_args, **driver_kwargs) - # Get the results of the ray workers. + # Get the results of the workers. 
if self.workers: - if use_ray_compiled_dag: + if not self.parallel_config.worker_use_ray: + worker_outputs = [output.get() for output in worker_outputs] + elif use_ray_compiled_dag: try: - ray_worker_outputs = [ + worker_outputs = [ pickle.loads(chan.begin_read()) for chan in output_channels ] @@ -1096,9 +1135,9 @@ def _run_workers( for chan in output_channels: chan.end_read() else: - ray_worker_outputs = ray.get(ray_worker_outputs) + worker_outputs = ray.get(worker_outputs) - return [driver_worker_output] + ray_worker_outputs + return [driver_worker_output] + worker_outputs def _compiled_ray_dag(self): import pkg_resources diff --git a/vllm/engine/local_worker_utils.py b/vllm/engine/local_worker_utils.py new file mode 100644 index 000000000000..3b4711bbc033 --- /dev/null +++ b/vllm/engine/local_worker_utils.py @@ -0,0 +1,194 @@ +import asyncio +import os +import traceback +import threading +import multiprocessing as mp +import uuid +from multiprocessing.connection import wait +from dataclasses import dataclass +from typing import TypeVar, Generic, Optional, Union, List, Dict + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +T = TypeVar('T') + +_TERMINATE = "TERMINATE" # sentinel + + +@dataclass +class Result(Generic[T]): + """Result of task dispatched to worker""" + + task_id: uuid.UUID = None + value: Optional[T] = None + exception: Optional[BaseException] = None + + +class ResultFuture(threading.Event, Generic[T]): + """Synchronous future for non-async case""" + + def __init__(self): + super().__init__() + self.result: Optional[Result[T]] = None + + def set_result(self, result: Result[T]): + self.result = result + self.set() + + def get(self) -> T: + self.wait() + if self.result.exception is not None: + raise self.result.exception + return self.result.value + + +def _set_future_result(future: Union[ResultFuture, asyncio.Future], + result: Result): + if isinstance(future, ResultFuture): + future.set_result(result) + return + loop = future.get_loop() + if result.exception is not None: + loop.call_soon_threadsafe(future.set_exception, result.exception) + else: + loop.call_soon_threadsafe(future.set_result, result.value) + + +class ResultHandler(threading.Thread): + """Handle results from all workers (in background thread)""" + + def __init__(self) -> None: + super().__init__(daemon=True) + self.result_queue = mp.Queue() + self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} + + def run(self): + for result in iter(self.result_queue.get, _TERMINATE): + future = self.tasks.pop(result.task_id) + _set_future_result(future, result) + # Ensure that all waiters will receive an exception + for future in self.tasks.values(): + _set_future_result( + future, Result(exception=ChildProcessError("worker died"))) + + def close(self): + self.result_queue.put(_TERMINATE) + + +class WorkerMonitor(threading.Thread): + """Monitor worker status (in background thread)""" + + def __init__(self, workers: List['LocalWorkerVllm'], + result_handler: ResultHandler): + super().__init__(daemon=True) + self.workers = workers + self.result_handler = result_handler + self._close = False + + def run(self) -> None: + dead_sentinels = wait([p.sentinel for p in self.workers]) + if self._close: + return + self._close = True + + # Kill / cleanup all workers + for worker in self.workers: + if worker.sentinel in dead_sentinels: + worker.join(1) + if worker.exitcode is not None and worker.exitcode != 0: + logger.error(f"Worker {worker.name} pid {worker.pid} died, " + f"exit code: 
{worker.exitcode}") + # Cleanup any remaining workers + logger.info("Killing local vLLM worker processes") + for worker in self.workers: + worker.kill_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + def close(self): + if self._close: + return + self._close = True + logger.info("Terminating local vLLM worker processes") + for worker in self.workers: + worker.terminate_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + +class LocalWorkerVllm(mp.Process): + """Local process wrapper for vllm.worker.Worker + for handling single-node multi-GPU tensor parallel.""" + + def __init__(self, result_handler: ResultHandler, *args, **kwargs) -> None: + super().__init__(daemon=True) + self._task_queue = mp.Queue() + self.result_queue = result_handler.result_queue + self.tasks = result_handler.tasks + self.worker_args = args + self.worker_kwargs = kwargs + self.worker = None + + def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], + method: str, args, kwargs): + task_id = uuid.uuid4() + self.tasks[task_id] = future + try: + self._task_queue.put((task_id, method, args, kwargs)) + except BaseException as e: + del self.tasks[task_id] + raise ChildProcessError("worker died") from e + + def execute_method(self, method: str, *args, **kwargs): + future = ResultFuture() + self._enqueue_task(future, method, args, kwargs) + return future + + async def execute_method_async(self, method: str, *args, **kwargs): + future = asyncio.get_running_loop().create_future() + self._enqueue_task(future, method, args, kwargs) + return await future + + def terminate_worker(self): + try: + self._task_queue.put(_TERMINATE) + except ValueError: + self.kill() + self._task_queue.close() + + def kill_worker(self): + self._task_queue.close() + self.kill() + + def run(self) -> None: + del self.tasks + from vllm.worker.worker import Worker + self.worker = Worker(*self.worker_args, **self.worker_kwargs) + del self.worker_args + del self.worker_kwargs + + # Accept tasks from the engine in task_queue + # and return task output in result_queue + logger.info( + f"Worker {mp.current_process().name} pid {os.getpid()} ready; " + "awaiting tasks") + for items in iter(self._task_queue.get, _TERMINATE): + output = None + exception = None + task_id, method, args, kwargs = items + try: + executor = getattr(self.worker, method) + output = executor(*args, **kwargs) + except BaseException as e: + tb = traceback.format_exc() + logger.error( + f"Exception in worker {mp.current_process().name} " + f"while processing method {method}: {e}, {tb}") + exception = e + self.result_queue.put( + Result(task_id=task_id, value=output, exception=exception)) + + logger.info( + f"Worker {mp.current_process().name} pid {os.getpid()} exiting") diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index bbcbbdfea2f0..fdd5a593445e 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -59,9 +59,10 @@ def execute_model_compiled_dag_remote(self, ignored): return output except ImportError as e: - logger.warning(f"Failed to import Ray with {e!r}. " - "For distributed inference, please install Ray with " - "`pip install ray`.") + logger.warning( + f"Unable to import Ray with {e!r}. 
" + "For multi-node distributed inference, please install Ray with " + "`pip install ray`.") ray = None RayWorkerVllm = None @@ -101,8 +102,6 @@ def initialize_cluster( ray.init(address=ray_address, ignore_reinit_error=True) if not parallel_config.worker_use_ray: - assert parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") return None # Create placement group for worker processes diff --git a/vllm/logger.py b/vllm/logger.py index d25fcef9ba2e..2a5c592f87aa 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -4,6 +4,7 @@ import logging import sys import os +import multiprocessing as mp VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) @@ -11,6 +12,13 @@ _DATE_FORMAT = "%m-%d %H:%M:%S" +class ProcessLogger(logging.LoggerAdapter): + + def process(self, msg, kwargs): + msg = f"[{self.extra['process_name']} pid {self.extra['pid']}] {msg}" + return msg, kwargs + + class NewLineFormatter(logging.Formatter): """Adds logging prefix to newlines to align multi-line messages.""" @@ -58,4 +66,9 @@ def init_logger(name: str): if VLLM_CONFIGURE_LOGGING: logger.addHandler(_default_handler) logger.propagate = False + if mp.parent_process() is not None: + logger = ProcessLogger(logger, { + 'process_name': mp.current_process().name, + 'pid': os.getpid() + }) return logger diff --git a/vllm/utils.py b/vllm/utils.py index 9cdf62337951..8c22cf54ddd9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -3,6 +3,7 @@ import socket import subprocess import uuid +from collections.abc import Iterable from platform import uname from typing import List, Tuple, Union from packaging.version import parse, Version @@ -199,7 +200,7 @@ def get_open_port() -> int: return s.getsockname()[1] -def set_cuda_visible_devices(device_ids: List[int]) -> None: +def set_cuda_visible_devices(device_ids: Iterable[int]) -> None: os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) From bd58c6806962b13a032ea5398352560a4177d333 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 19 Feb 2024 14:10:53 -0800 Subject: [PATCH 2/6] Fix setuptools extras_requires --- setup.py | 6 +++--- vllm/engine/local_worker_utils.py | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 5032c1a89fa8..756ecbc15dea 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ import subprocess import warnings from pathlib import Path -from typing import List, Set +from typing import List, Set, Dict from packaging.version import parse, Version import setuptools @@ -435,10 +435,10 @@ def get_requirements() -> List[str]: return requirements -def get_ray_requirement() -> Optional[List[str]]: +def get_ray_requirement() -> Optional[Dict[str, List[str]]]: if _is_neuron(): return None - return ["ray >= 2.9"] + return {"ray": ["ray >= 2.9"]} package_data = { diff --git a/vllm/engine/local_worker_utils.py b/vllm/engine/local_worker_utils.py index 3b4711bbc033..6a061339eca2 100644 --- a/vllm/engine/local_worker_utils.py +++ b/vllm/engine/local_worker_utils.py @@ -68,7 +68,7 @@ def run(self): for result in iter(self.result_queue.get, _TERMINATE): future = self.tasks.pop(result.task_id) _set_future_result(future, result) - # Ensure that all waiters will receive an exception + # Ensure that all waiters will receive an exception for future in self.tasks.values(): _set_future_result( future, Result(exception=ChildProcessError("worker died"))) @@ -88,6 +88,7 @@ def __init__(self, workers: List['LocalWorkerVllm'], self._close = False def run(self) -> None: + # Blocks until any 
worker exits dead_sentinels = wait([p.sentinel for p in self.workers]) if self._close: return @@ -163,7 +164,7 @@ def kill_worker(self): self.kill() def run(self) -> None: - del self.tasks + del self.tasks # Not used in forked process from vllm.worker.worker import Worker self.worker = Worker(*self.worker_args, **self.worker_kwargs) del self.worker_args From dde90a3265a31ff63a88b3462ce084613e94f87c Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 27 Feb 2024 15:04:16 -0800 Subject: [PATCH 3/6] Rename initialize_cluster to initialize_ray_cluster --- vllm/__init__.py | 4 ++-- vllm/engine/async_llm_engine.py | 6 +++--- vllm/engine/llm_engine.py | 10 +++++----- vllm/engine/ray_utils.py | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/__init__.py b/vllm/__init__.py index f1e30f5eb6e6..5e40c3c20fcd 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -3,7 +3,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine -from vllm.engine.ray_utils import initialize_cluster +from vllm.engine.ray_utils import initialize_ray_cluster from vllm.entrypoints.llm import LLM from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams @@ -19,5 +19,5 @@ "EngineArgs", "AsyncLLMEngine", "AsyncEngineArgs", - "initialize_cluster", + "initialize_ray_cluster", ] diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 0d131a78c0ec..f20976182e2e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -9,7 +9,7 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.engine.ray_utils import initialize_cluster, ray +from vllm.engine.ray_utils import initialize_ray_cluster, ray from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams @@ -680,8 +680,8 @@ def from_engine_args(cls, engine_configs = engine_args.create_engine_configs() parallel_config = engine_configs[2] # Initialize the cluster. - placement_group = initialize_cluster(parallel_config, - engine_args.engine_use_ray) + placement_group = initialize_ray_cluster(parallel_config, + engine_args.engine_use_ray) # Create the async LLM engine. engine = cls(parallel_config.worker_use_ray, engine_args.engine_use_ray, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 57584def9e2f..983b4a10a9b3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -14,7 +14,7 @@ from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats -from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray +from vllm.engine.ray_utils import RayWorkerVllm, initialize_ray_cluster, ray from vllm.engine.local_worker_utils import LocalWorkerVllm, WorkerMonitor, ResultHandler from vllm.logger import init_logger from vllm.outputs import RequestOutput @@ -120,7 +120,6 @@ def __init__( self.seq_counter = Counter() # Create the parallel GPU workers. - self.worker_monitor = None if self.parallel_config.worker_use_ray: # Disable Ray usage stats collection. 
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") @@ -214,8 +213,9 @@ def _init_workers(self): def __del__(self): # Terminate local worker processes when engine is garbage collected - if self.worker_monitor is not None: - self.worker_monitor.close() + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() def _init_driver_worker_and_model(self, rank: int, local_rank: int, distributed_init_method: str): @@ -431,7 +431,7 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": engine_configs = engine_args.create_engine_configs() parallel_config = engine_configs[2] # Initialize the cluster. - placement_group = initialize_cluster(parallel_config) + placement_group = initialize_ray_cluster(parallel_config) # Create the LLM engine. engine = cls(*engine_configs, placement_group, diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index fdd5a593445e..94ad24b025c6 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -70,7 +70,7 @@ def execute_model_compiled_dag_remote(self, ignored): from ray.util.placement_group import PlacementGroup -def initialize_cluster( +def initialize_ray_cluster( parallel_config: ParallelConfig, engine_use_ray: bool = False, ray_address: Optional[str] = None, From 6000d64f9f82bfaab26ff9313b2fa5a5a37da953 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 4 Mar 2024 08:05:44 -0800 Subject: [PATCH 4/6] Fix CUDA initialization after rebase --- vllm/config.py | 8 ++++---- vllm/engine/arg_utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 68758b8ab061..8e212d084446 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -500,12 +500,12 @@ class DeviceConfig: def __init__(self, device: str = "auto") -> None: if device == "auto": # Automated device type detection - if torch.cuda.is_available(): - self.device_type = "cuda" - elif is_neuron(): + if is_neuron(): self.device_type = "neuron" else: - raise RuntimeError("No supported device detected.") + # We don't call torch.cuda.is_available() here to + # avoid initializing CUDA before workers are forked + self.device_type = "cuda" else: # Device type is assigned explicitly self.device_type = device diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 765baa90b953..0079b27a20c5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -151,7 +151,7 @@ def add_cli_args( # Parallel arguments parser.add_argument( '--worker-use-ray', - action='store_true', + action=argparse.BooleanOptionalAction, default=None, help='use Ray for distributed serving, will default ' 'to true when ray is installed and more than 1 GPU is used') From f69dd4d081d19f5649763e44e933d9d15624ef3c Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 4 Mar 2024 15:54:10 -0800 Subject: [PATCH 5/6] Clean up worker logging --- vllm/engine/local_worker_utils.py | 49 +++++++++++++++++-------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/vllm/engine/local_worker_utils.py b/vllm/engine/local_worker_utils.py index 6a061339eca2..5165af92b0be 100644 --- a/vllm/engine/local_worker_utils.py +++ b/vllm/engine/local_worker_utils.py @@ -1,5 +1,4 @@ import asyncio -import os import traceback import threading import multiprocessing as mp @@ -164,6 +163,10 @@ def kill_worker(self): self.kill() def run(self) -> None: + # Re-init logger in forked process, to include worker-specific prefix + global logger + logger = init_logger(__name__) + del self.tasks # Not used in forked 
process from vllm.worker.worker import Worker self.worker = Worker(*self.worker_args, **self.worker_kwargs) @@ -172,24 +175,26 @@ def run(self) -> None: # Accept tasks from the engine in task_queue # and return task output in result_queue - logger.info( - f"Worker {mp.current_process().name} pid {os.getpid()} ready; " - "awaiting tasks") - for items in iter(self._task_queue.get, _TERMINATE): - output = None - exception = None - task_id, method, args, kwargs = items - try: - executor = getattr(self.worker, method) - output = executor(*args, **kwargs) - except BaseException as e: - tb = traceback.format_exc() - logger.error( - f"Exception in worker {mp.current_process().name} " - f"while processing method {method}: {e}, {tb}") - exception = e - self.result_queue.put( - Result(task_id=task_id, value=output, exception=exception)) - - logger.info( - f"Worker {mp.current_process().name} pid {os.getpid()} exiting") + logger.info("Worker ready; awaiting tasks") + try: + for items in iter(self._task_queue.get, _TERMINATE): + output = None + exception = None + task_id, method, args, kwargs = items + try: + executor = getattr(self.worker, method) + output = executor(*args, **kwargs) + except BaseException as e: + tb = traceback.format_exc() + logger.error( + f"Exception in worker {mp.current_process().name} " + f"while processing method {method}: {e}, {tb}") + exception = e + self.result_queue.put( + Result(task_id=task_id, value=output, exception=exception)) + except KeyboardInterrupt: + pass + except Exception: + logger.exception("Worker failed") + + logger.info("Worker exiting") From 6ad3fa6184024fb76a1cfe22b4426f38ead67631 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 5 Mar 2024 14:47:07 -0800 Subject: [PATCH 6/6] Use dedicated multiprocessing context for workers --- vllm/engine/local_worker_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/engine/local_worker_utils.py b/vllm/engine/local_worker_utils.py index 5165af92b0be..b5db52ce9d75 100644 --- a/vllm/engine/local_worker_utils.py +++ b/vllm/engine/local_worker_utils.py @@ -1,7 +1,7 @@ import asyncio +import multiprocessing import traceback import threading -import multiprocessing as mp import uuid from multiprocessing.connection import wait from dataclasses import dataclass @@ -15,6 +15,10 @@ _TERMINATE = "TERMINATE" # sentinel +# Use dedicated multiprocess context for workers. +# Both spawn and fork work +mp = multiprocessing.get_context("fork") + @dataclass class Result(Generic[T]):
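
As a quick orientation for reviewers, the sketch below shows how the two worker backends introduced by this series would be selected from the LLM entrypoint once the patches are applied. It mirrors tests/engine/test_local_worker.py above; the model name, prompts, and tensor_parallel_size are arbitrary assumptions, the machine is assumed to have at least two local GPUs, and each configuration would normally run in its own process (as the test does via multiprocessing). This is an illustrative sketch, not part of the patch series.

    from vllm import LLM, SamplingParams

    sampling_params = SamplingParams(max_tokens=64)
    prompts = ["Hello, my name is", "The capital of France is"]

    # Local multiprocessing backend: with this series applied, worker_use_ray
    # defaults to False when Ray is not importable, so passing it explicitly
    # is only needed to force the choice.
    llm = LLM(model="facebook/opt-350m",
              tensor_parallel_size=2,
              worker_use_ray=False)
    for output in llm.generate(prompts, sampling_params):
        print(output.outputs[0].text)

    # Ray backend: requires the optional extra added in setup.py and used in
    # the Dockerfile (e.g. `pip install .[ray]`). Run in a fresh process with
    # worker_use_ray=True to exercise the Ray code path instead:
    # llm = LLM(model="facebook/opt-350m",
    #           tensor_parallel_size=2,
    #           worker_use_ray=True)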