diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index aa4582bbda0c..79e74766e855 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -33,10 +33,11 @@ steps: num_gpus: 2 # only support 1 or 2 for now. commands: - pytest -v -s test_pynccl.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py + # Use spawn to avoid CUDA re-init issues + - TEST_DIST_MODEL=facebook/opt-125m MULTIPROC_METHOD=spawn pytest -v -s test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf MULTIPROC_METHOD=spawn pytest -v -s test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m MULTIPROC_METHOD=spawn pytest -v -s test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf MULTIPROC_METHOD=spawn pytest -v -s test_chunked_prefill_distributed.py - label: Engine Test command: pytest -v -s engine tokenization test_sequence.py test_config.py diff --git a/Dockerfile b/Dockerfile index d1d29177b0f4..eaaaa44f83d5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -113,7 +113,7 @@ RUN ldconfig /usr/local/cuda-12.1/compat/ # install vllm wheel first, so that torch etc will be installed RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ - pip install dist/*.whl --verbose + pip install "$(echo dist/*.whl)[ray]" --verbose RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ --mount=type=cache,target=/root/.cache/pip \ diff --git a/requirements-cuda.txt b/requirements-cuda.txt index c6d2cd46aee5..e2115b586015 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -2,7 +2,6 @@ -r requirements-common.txt # Dependencies for NVIDIA GPUs -ray >= 2.9 pynvml == 11.5.0 vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.2.1 diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 903845b64d98..0c944b138f6d 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -1,5 +1,5 @@ # Common dependencies -r requirements-common.txt -# Dependencies for AMD GPUs -ray == 2.9.3 +# No specific dependencies currently for AMD GPUs + diff --git a/setup.py b/setup.py index 19a9150ad2e6..0320e97fc7c4 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ import subprocess import sys from shutil import which -from typing import Dict, List +from typing import Dict, List, Optional import torch from packaging.version import Version, parse @@ -361,6 +361,20 @@ def _read_requirements(filename: str) -> List[str]: return requirements +def get_extra_requirements() -> Optional[Dict[str, List[str]]]: + extras = {"tensorizer": ["tensorizer==2.9.0a1"]} + if _is_cuda(): + extras["ray"] = ["ray>=2.9"] + elif _is_hip(): + extras["ray"] = ["ray==2.9.3"] + elif _is_neuron() or _is_cpu(): + pass + else: + raise ValueError( + "Unsupported platform, please use CUDA, ROCM or Neuron.") + return extras + + ext_modules = [] if _is_cuda(): @@ -405,9 +419,7 @@ def _read_requirements(filename: str) -> List[str]: python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, - extras_require={ - "tensorizer": ["tensorizer==2.9.0a1"], - }, + 
extras_require=get_extra_requirements(), cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, ) diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 77aa90b12bf8..faed6f1e82e7 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -25,6 +25,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("worker_use_ray", [False, True]) def test_models( hf_runner, vllm_runner, @@ -32,17 +33,17 @@ def test_models( model: str, dtype: str, max_tokens: int, + worker_use_ray: bool, ) -> None: hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - ) + vllm_model = vllm_runner(model, + dtype=dtype, + tensor_parallel_size=2, + worker_use_ray=worker_use_ray) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py index 737b1f316951..a57660794f00 100644 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -27,6 +27,7 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("chunked_prefill_token_size", [16]) +@pytest.mark.parametrize("worker_use_ray", [False, True]) def test_models( hf_runner, vllm_runner, @@ -35,6 +36,7 @@ def test_models( dtype: str, max_tokens: int, chunked_prefill_token_size: int, + worker_use_ray: bool, ) -> None: # Add a chunked prefill config. 
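For context, the new `worker_use_ray` test parameter exercises the same switch users reach through the `LLM` entrypoint. A minimal sketch of the non-Ray path (assuming two visible GPUs and that `LLM` forwards `worker_use_ray` through to `EngineArgs`; the model choice simply mirrors the tests above):

```python
from vllm import LLM, SamplingParams

# With worker_use_ray=False (or with ray simply not installed), the
# tensor-parallel workers are plain multiprocessing processes.
llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=2,
    worker_use_ray=False,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)
```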
max_num_seqs = min(chunked_prefill_token_size, 256) @@ -53,6 +55,7 @@ def test_models( max_num_seqs=max_num_seqs, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens, + worker_use_ray=worker_use_ray, ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/engine/test_local_workers.py b/tests/engine/test_local_workers.py new file mode 100644 index 000000000000..dc84f20fe130 --- /dev/null +++ b/tests/engine/test_local_workers.py @@ -0,0 +1,179 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from time import sleep +from typing import Any, List, Tuple + +import pytest + +from vllm.engine.local_worker_utils import (LocalWorkerVllm, ResultHandler, + WorkerMonitor) + + +class DummyWorker: + """Dummy version of vllm.worker.worker""" + + def __init__(self, rank: int): + self.rank = rank + + def worker_method(self, worker_input: Any) -> Tuple[int, Any]: + sleep(0.05) + + if isinstance(worker_input, Exception): + # simulate error case + raise worker_input + + return self.rank, input + + +def _start_workers() -> Tuple[List[LocalWorkerVllm], WorkerMonitor]: + result_handler = ResultHandler() + workers = [ + LocalWorkerVllm(result_handler, partial(DummyWorker, rank=rank)) + for rank in range(8) + ] + + for worker in workers: + worker.start() + + worker_monitor = WorkerMonitor(workers, result_handler) + assert not worker_monitor.is_alive() + + result_handler.start() + worker_monitor.start() + assert worker_monitor.is_alive() + + return workers, worker_monitor + + +def test_local_workers() -> None: + """Test workers with sync task submission""" + + workers, worker_monitor = _start_workers() + + def execute_workers(worker_input: str) -> None: + worker_outputs = [ + worker.execute_method("worker_method", worker_input) + for worker in workers + ] + + for rank, output in enumerate(worker_outputs): + assert output.get() == (rank, input) + + executor = ThreadPoolExecutor(max_workers=4) + + # Test concurrent submission from different threads + futures = [ + executor.submit(partial(execute_workers, f"thread {thread_num}")) + for thread_num in range(4) + ] + + for future in futures: + future.result() + + # Test error case + exception = ValueError("fake error") + result = workers[0].execute_method("worker_method", exception) + try: + result.get() + pytest.fail("task should have failed") + except Exception as e: + assert isinstance(e, ValueError) + assert str(e) == "fake error" + + # Test cleanup when a worker fails + assert worker_monitor.is_alive() + workers[3].kill() + + # Other workers should get shut down here + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = workers[0].execute_method("worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) + + +def test_local_workers_clean_shutdown() -> None: + """Test clean shutdown""" + + workers, worker_monitor = _start_workers() + + assert worker_monitor.is_alive() + assert all(worker.is_alive() for worker in workers) + + # Clean shutdown + worker_monitor.close() + + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.is_alive() for worker in workers) + + # Further attempts to submit 
tasks should fail + try: + _result = workers[0].execute_method("worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) + + +@pytest.mark.asyncio +async def test_local_workers_async() -> None: + """Test local workers with async task submission""" + + workers, worker_monitor = _start_workers() + + async def execute_workers(worker_input: str) -> None: + worker_coros = [ + worker.execute_method_async("worker_method", worker_input) + for worker in workers + ] + + results = await asyncio.gather(*worker_coros) + for rank, result in enumerate(results): + assert result == (rank, input) + + tasks = [ + asyncio.create_task(execute_workers(f"task {task_num}")) + for task_num in range(4) + ] + + for task in tasks: + await task + + # Test error case + exception = ValueError("fake error") + try: + _result = await workers[0].execute_method_async( + "worker_method", exception) + pytest.fail("task should have failed") + except Exception as e: + assert isinstance(e, ValueError) + assert str(e) == "fake error" + + # Test cleanup when a worker fails + assert worker_monitor.is_alive() + workers[3].kill() + + # Other workers should get shut down here + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = await workers[0].execute_method_async( + "worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) diff --git a/vllm/config.py b/vllm/config.py index dce2944b2ee8..a2e9dfa8ae75 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,4 +1,5 @@ import enum +import importlib.util import io import json import os @@ -422,7 +423,7 @@ def verify_with_parallel_config( @dataclass class TokenizerPoolConfig: """Configuration for the tokenizer pool. - + Args: pool_size: Number of tokenizer workers in the pool. pool_type: Type of the pool. @@ -446,9 +447,9 @@ def create_config( tokenizer_pool_extra_config: Optional[Union[str, dict]] ) -> Optional["TokenizerPoolConfig"]: """Create a TokenizerPoolConfig from the given parameters. - + If tokenizer_pool_size is 0, return None. - + Args: tokenizer_pool_size: Number of tokenizer workers in the pool. tokenizer_pool_type: Type of the pool. @@ -477,9 +478,9 @@ class ParallelConfig: Args: pipeline_parallel_size: Number of pipeline parallel groups. tensor_parallel_size: Number of tensor parallel groups. - worker_use_ray: Whether to use Ray for model workers. Will be set to + worker_use_ray: Whether to use Ray for model workers. Will default to True if either pipeline_parallel_size or tensor_parallel_size is - greater than 1. + greater than 1 and Ray is installed. max_parallel_loading_workers: Maximum number of multiple batches when load model sequentially. To avoid RAM OOM when using tensor parallel and large models. 
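To illustrate the new default described in the docstring above, here is a small sketch of how `worker_use_ray` now resolves when left unset (assuming `ParallelConfig` can be constructed on its own in an environment where vLLM imports cleanly):

```python
import importlib.util

from vllm.config import ParallelConfig

# worker_use_ray is left as None and resolved automatically: it becomes
# True only when ray is importable and the world size is greater than 1.
config = ParallelConfig(pipeline_parallel_size=1, tensor_parallel_size=2)

ray_installed = importlib.util.find_spec("ray") is not None
assert config.worker_use_ray == ray_installed
```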
@@ -495,7 +496,7 @@ def __init__( self, pipeline_parallel_size: int, tensor_parallel_size: int, - worker_use_ray: bool, + worker_use_ray: Optional[bool] = None, max_parallel_loading_workers: Optional[int] = None, disable_custom_all_reduce: bool = False, tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, @@ -512,8 +513,10 @@ def __init__( self.placement_group = placement_group self.world_size = pipeline_parallel_size * self.tensor_parallel_size - if self.world_size > 1: - self.worker_use_ray = True + if self.worker_use_ray is None: + ray_found = importlib.util.find_spec("ray") is not None + self.worker_use_ray = ray_found and self.world_size > 1 + self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 831a03be65f6..e685f3589ab2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -27,7 +27,7 @@ class EngineArgs: quantization_param_path: Optional[str] = None seed: int = 0 max_model_len: Optional[int] = None - worker_use_ray: bool = False + worker_use_ray: Optional[bool] = None pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None @@ -201,10 +201,12 @@ def add_cli_args( help='model context length. If unspecified, ' 'will be automatically derived from the model.') # Parallel arguments - parser.add_argument('--worker-use-ray', - action='store_true', - help='use Ray for distributed serving, will be ' - 'automatically set when using more than 1 GPU') + parser.add_argument( + '--worker-use-ray', + action=argparse.BooleanOptionalAction, + default=None, + help='use Ray for distributed serving, will default ' + 'to true when ray is installed and more than 1 GPU is used') parser.add_argument('--pipeline-parallel-size', '-pp', type=int, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f61049513512..a53ce630fc24 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -338,9 +338,11 @@ def from_engine_args( initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync executor_class = RayGPUExecutorAsync + elif engine_config.parallel_config.world_size > 1: + from vllm.executor.multiproc_gpu_executor import ( + MultiProcGPUExecutorAsync) + executor_class = MultiProcGPUExecutorAsync else: - assert engine_config.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutorAsync executor_class = GPUExecutorAsync # Create the async LLM engine. 
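The switch from `store_true` to `argparse.BooleanOptionalAction` above gives the CLI flag three states instead of two. A standalone illustration (plain argparse, Python 3.9+, no vLLM imports needed):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--worker-use-ray",
                    action=argparse.BooleanOptionalAction,
                    default=None)

# Unset -> None, which lets ParallelConfig auto-detect ray at runtime.
print(parser.parse_args([]).worker_use_ray)                       # None
print(parser.parse_args(["--worker-use-ray"]).worker_use_ray)     # True
print(parser.parse_args(["--no-worker-use-ray"]).worker_use_ray)  # False
```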
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8c37c5a9d6ee..ad5394e97491 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -224,9 +224,11 @@ def from_engine_args( initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutor executor_class = RayGPUExecutor + elif engine_config.parallel_config.world_size > 1: + from vllm.executor.multiproc_gpu_executor import ( + MultiProcGPUExecutor) + executor_class = MultiProcGPUExecutor else: - assert engine_config.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutor executor_class = GPUExecutor @@ -244,6 +246,12 @@ def __reduce__(self): # the closure used to initialize Ray worker actors raise RuntimeError("LLMEngine should not be pickled!") + def __del__(self): + # Shutdown model executor when engine is garbage collected + # Use getattr since __init__ can fail before the field is set + if model_executor := getattr(self, "model_executor", None): + model_executor.shutdown() + def get_tokenizer(self) -> "PreTrainedTokenizer": return self.tokenizer.get_lora_tokenizer(None) diff --git a/vllm/engine/local_worker_utils.py b/vllm/engine/local_worker_utils.py new file mode 100644 index 000000000000..4a5d12b08704 --- /dev/null +++ b/vllm/engine/local_worker_utils.py @@ -0,0 +1,242 @@ +import asyncio +import multiprocessing +import os +import sys +import threading +import traceback +import uuid +from dataclasses import dataclass +from io import TextIOBase +from multiprocessing.connection import wait +from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +T = TypeVar('T') + +_TERMINATE = "TERMINATE" # sentinel + +# ANSI color codes +CYAN = '\033[1;36m' +RESET = '\033[0;0m' + +# Use dedicated multiprocess context for workers. 
+# Both spawn and fork work +mp_method = os.getenv("MULTIPROC_METHOD", "fork") +mp = multiprocessing.get_context(mp_method) + + +@dataclass +class Result(Generic[T]): + """Result of task dispatched to worker""" + + task_id: uuid.UUID = None + value: Optional[T] = None + exception: Optional[BaseException] = None + + +class ResultFuture(threading.Event, Generic[T]): + """Synchronous future for non-async case""" + + def __init__(self): + super().__init__() + self.result: Optional[Result[T]] = None + + def set_result(self, result: Result[T]): + self.result = result + self.set() + + def get(self) -> T: + self.wait() + if self.result.exception is not None: + raise self.result.exception + return self.result.value + + +def _set_future_result(future: Union[ResultFuture, asyncio.Future], + result: Result): + if isinstance(future, ResultFuture): + future.set_result(result) + return + loop = future.get_loop() + if result.exception is not None: + loop.call_soon_threadsafe(future.set_exception, result.exception) + else: + loop.call_soon_threadsafe(future.set_result, result.value) + + +class ResultHandler(threading.Thread): + """Handle results from all workers (in background thread)""" + + def __init__(self) -> None: + super().__init__(daemon=True) + self.result_queue = mp.Queue() + self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} + + def run(self): + for result in iter(self.result_queue.get, _TERMINATE): + future = self.tasks.pop(result.task_id) + _set_future_result(future, result) + # Ensure that all waiters will receive an exception + for future in self.tasks.values(): + _set_future_result( + future, Result(exception=ChildProcessError("worker died"))) + + def close(self): + self.result_queue.put(_TERMINATE) + + +class WorkerMonitor(threading.Thread): + """Monitor worker status (in background thread)""" + + def __init__(self, workers: List['LocalWorkerVllm'], + result_handler: ResultHandler): + super().__init__(daemon=True) + self.workers = workers + self.result_handler = result_handler + self._close = False + + def run(self) -> None: + # Blocks until any worker exits + dead_sentinels = wait([p.sentinel for p in self.workers]) + if not self._close: + self._close = True + + # Kill / cleanup all workers + for worker in self.workers: + if worker.sentinel in dead_sentinels: + worker.join(1) + if worker.exitcode is not None and worker.exitcode != 0: + logger.error( + f"Worker {worker.name} pid {worker.pid} died, " + f"exit code: {worker.exitcode}") + # Cleanup any remaining workers + logger.info("Killing local vLLM worker processes") + for worker in self.workers: + worker.kill_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + for worker in self.workers: + worker.join(2) + + def close(self): + if self._close: + return + self._close = True + logger.info("Terminating local vLLM worker processes") + for worker in self.workers: + worker.terminate_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + +class LocalWorkerVllm(mp.Process): + """Local process wrapper for vllm.worker.Worker + for handling single-node multi-GPU tensor parallel.""" + + def __init__(self, result_handler: ResultHandler, + worker_factory: Callable[[], Any]) -> None: + super().__init__(daemon=True) + self._task_queue = mp.Queue() + self.result_queue = result_handler.result_queue + self.tasks = result_handler.tasks + self.worker_factory = worker_factory + self.worker = None + + def _enqueue_task(self, future: 
Union[ResultFuture, asyncio.Future], + method: str, args, kwargs): + task_id = uuid.uuid4() + self.tasks[task_id] = future + try: + self._task_queue.put((task_id, method, args, kwargs)) + except BaseException as e: + del self.tasks[task_id] + raise ChildProcessError("worker died") from e + + def execute_method(self, method: str, *args, **kwargs): + future = ResultFuture() + self._enqueue_task(future, method, args, kwargs) + return future + + async def execute_method_async(self, method: str, *args, **kwargs): + future = asyncio.get_running_loop().create_future() + self._enqueue_task(future, method, args, kwargs) + return await future + + def terminate_worker(self): + try: + self._task_queue.put(_TERMINATE) + except ValueError: + self.kill() + self._task_queue.close() + + def kill_worker(self): + self._task_queue.close() + self.kill() + + def run(self) -> None: + # Add process-specific prefix to stdout and stderr + process_name = mp.current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + del self.tasks # Not used in forked process + self.worker = self.worker_factory() + del self.worker_factory + + # Accept tasks from the engine in task_queue + # and return task output in result_queue + logger.info("Worker ready; awaiting tasks") + try: + for items in iter(self._task_queue.get, _TERMINATE): + output = None + exception = None + task_id, method, args, kwargs = items + try: + executor = getattr(self.worker, method) + output = executor(*args, **kwargs) + except BaseException as e: + tb = traceback.format_exc() + logger.error( + f"Exception in worker {mp.current_process().name} " + f"while processing method {method}: {e}, {tb}") + exception = e + self.result_queue.put( + Result(task_id=task_id, value=output, exception=exception)) + except KeyboardInterrupt: + pass + except Exception: + logger.exception("Worker failed") + + logger.info("Worker exiting") + + +def _add_prefix(file: TextIOBase, worker_name: str, pid: int) -> None: + """Prepend output with process-specific prefix""" + + prefix = f"{CYAN}({worker_name} pid={pid}){RESET} " + file_write = file.write + + def write_with_prefix(s: str): + if not s: + return + if file.start_new_line: + file_write(prefix) + idx = 0 + while (next_idx := s.find('\n', idx)) != -1: + next_idx += 1 + file_write(s[idx:next_idx]) + if next_idx == len(s): + file.start_new_line = True + return + file_write(prefix) + idx = next_idx + file_write(s[idx:]) + file.start_new_line = False + + file.start_new_line = True + file.write = write_with_prefix diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index 04d4ed83976d..d78e73ed8ba1 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -73,9 +73,10 @@ def execute_model_compiled_dag_remote(self, ignored): return output except ImportError as e: - logger.warning(f"Failed to import Ray with {e!r}. " - "For distributed inference, please install Ray with " - "`pip install ray`.") + logger.warning( + f"Unable to import Ray with {e!r}. 
" + "For multi-node distributed inference, please install Ray with " + "`pip install ray`.") ray = None # type: ignore RayWorkerVllm = None # type: ignore diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index bbb6ec80f7b7..bbe543638879 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -94,6 +94,13 @@ def check_health(self) -> None: exception.""" raise NotImplementedError + def shutdown(self) -> None: + """Shutdown the executor.""" + return + + def __del__(self): + self.shutdown() + class ExecutorAsyncBase(ExecutorBase): diff --git a/vllm/executor/multi_gpu_executor.py b/vllm/executor/multi_gpu_executor.py new file mode 100644 index 000000000000..499667e47350 --- /dev/null +++ b/vllm/executor/multi_gpu_executor.py @@ -0,0 +1,140 @@ +from abc import abstractmethod +from typing import Any, Dict, Optional, Set, Tuple + +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import SamplerOutput + +logger = init_logger(__name__) + + +class MultiGPUExecutor(ExecutorBase): + """Abstract superclass of multi-GPU executor implementations.""" + + def _init_driver_worker_and_model(self, rank: int, local_rank: int, + distributed_init_method: str): + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + from vllm.worker.worker import Worker + + # Initialize the driver worker with the Worker class. + self.driver_worker = Worker( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + tensorizer_config=self.tensorizer_config, + is_driver_worker=True, + ) + + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self._run_workers("determine_num_available_blocks", ) + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers. + """ + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. 
+ logger.info(f"# GPU blocks: {num_gpu_blocks}, " + f"# CPU blocks: {num_cpu_blocks}") + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def execute_model(self, *args, **kwargs) -> SamplerOutput: + all_outputs = self._run_workers("execute_model", + driver_args=args, + driver_kwargs=kwargs) + + # Only the driver worker returns the sampling results. + return all_outputs[0] + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> Set[int]: + return self._run_workers("list_loras") + + @abstractmethod + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + raise NotImplementedError + + +class MultiGPUExecutorAsync(MultiGPUExecutor, ExecutorAsyncBase): + + @abstractmethod + async def _run_workers_async( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + raise NotImplementedError + + async def execute_model_async(self, *args, **kwargs) -> SamplerOutput: + all_outputs = await self._run_workers_async("execute_model", + driver_args=args, + driver_kwargs=kwargs) + + # Only the driver worker returns the sampling results. + return all_outputs[0] diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py new file mode 100644 index 000000000000..b00472ebe770 --- /dev/null +++ b/vllm/executor/multiproc_gpu_executor.py @@ -0,0 +1,147 @@ +import asyncio +import os +from functools import partial +from typing import Any, Dict, Optional, Tuple + +from vllm.engine.local_worker_utils import (LocalWorkerVllm, ResultHandler, + WorkerMonitor) +from vllm.executor.multi_gpu_executor import (MultiGPUExecutor, + MultiGPUExecutorAsync) +from vllm.logger import init_logger +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async, set_cuda_visible_devices) + +logger = init_logger(__name__) + + +def _create_worker(*args, **kwargs): + # Import within worker process to avoid CUDA init issues + from vllm.worker.worker import Worker + return Worker(*args, **kwargs) + + +class MultiProcGPUExecutor(MultiGPUExecutor): + """Python multiprocessing-based multi-GPU executor""" + + def _init_executor(self) -> None: + assert ( + not self.speculative_config + ), "Speculative decoding not yet supported for MultiProcGPU backend." + + # Create the parallel GPU workers. 
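For reference, a toy sketch of the worker fan-out that `MultiProcGPUExecutor` builds on, using the new `local_worker_utils` primitives directly (`EchoWorker` and the worker count are illustrative, not vLLM APIs; the `__main__` guard matters when `MULTIPROC_METHOD=spawn`):

```python
from functools import partial

from vllm.engine.local_worker_utils import (LocalWorkerVllm, ResultHandler,
                                            WorkerMonitor)


class EchoWorker:
    """Stand-in for vllm.worker.worker.Worker."""

    def __init__(self, rank: int):
        self.rank = rank

    def ping(self, msg: str):
        return self.rank, msg


if __name__ == "__main__":
    result_handler = ResultHandler()
    workers = [
        LocalWorkerVllm(result_handler, partial(EchoWorker, rank))
        for rank in range(1, 4)
    ]
    for worker in workers:
        worker.start()
    worker_monitor = WorkerMonitor(workers, result_handler)
    result_handler.start()
    worker_monitor.start()

    # Mirror _run_workers: fan out to the workers, run the driver inline,
    # then collect the worker results.
    driver = EchoWorker(rank=0)
    futures = [worker.execute_method("ping", "hello") for worker in workers]
    outputs = [driver.ping("hello")] + [future.get() for future in futures]
    print(outputs)  # [(0, 'hello'), (1, 'hello'), (2, 'hello'), (3, 'hello')]

    worker_monitor.close()
```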
+ world_size = self.parallel_config.tensor_parallel_size + + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers + if "CUDA_VISIBLE_DEVICES" not in os.environ: + set_cuda_visible_devices(range(world_size)) + + from torch.cuda import device_count + assert world_size <= device_count(), ( + "please set tensor_parallel_size to less than max local gpu count") + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + if world_size == 1: + self.workers = [] + else: + result_handler = ResultHandler() + self.workers = [ + LocalWorkerVllm( + result_handler, + partial( + _create_worker, + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=rank, + rank=rank, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + tensorizer_config=self.tensorizer_config, + )) for rank in range(1, world_size) + ] + + for worker in self.workers: + worker.start() + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + self._init_driver_worker_and_model(0, 0, distributed_init_method) + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the workers first. + worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # Start the driver worker after all the ray workers. + driver_worker_method = getattr(self.driver_worker, method) + driver_worker_output = driver_worker_method(*driver_args, + **driver_kwargs) + + # Get the results of the workers. + return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + if not self.worker_monitor.is_alive(): + raise RuntimeError("Worker processes are not running") + + +class MultiProcGPUExecutorAsync(MultiProcGPUExecutor, MultiGPUExecutorAsync): + + async def _run_workers_async( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + driver_executor = make_async(getattr(self.driver_worker, method)) + + # Run all the workers asynchronously. 
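For the async path, a minimal driver sketch (assumes two visible GPUs; with `worker_use_ray=False` the engine builds `MultiProcGPUExecutorAsync`, whose `_run_workers_async`, continued just below, awaits the driver call and the workers' `execute_method_async` coroutines together with `asyncio.gather`):

```python
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m",
                        tensor_parallel_size=2,
                        worker_use_ray=False))
    params = SamplingParams(temperature=0.0, max_tokens=5)
    final = None
    # generate() is an async generator; keep the last (finished) output.
    async for request_output in engine.generate("Hello, my name is", params,
                                                request_id="example-0"):
        final = request_output
    print(final.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())
```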
+ coros = [driver_executor(*driver_args, **driver_kwargs)] + [ + worker.execute_method_async(method, *args, **kwargs) + for worker in self.workers + ] + + return await asyncio.gather(*coros) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 5db2f3f65253..2c949d778b91 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -3,12 +3,12 @@ import os import pickle from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from vllm.engine.ray_utils import RayWorkerVllm, ray -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.executor.multi_gpu_executor import (MultiGPUExecutor, + MultiGPUExecutorAsync) from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async, set_cuda_visible_devices) @@ -27,7 +27,7 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) -class RayGPUExecutor(ExecutorBase): +class RayGPUExecutor(MultiGPUExecutor): def _init_executor(self) -> None: assert (not self.speculative_config @@ -154,69 +154,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", tensorizer_config=self.tensorizer_config, )) - # Initialize the driver worker with the Worker class. driver_rank = 0 driver_local_rank = node_workers[driver_node_id].index(driver_rank) - self.driver_worker = Worker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - local_rank=driver_local_rank, - rank=driver_rank, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - tensorizer_config=self.tensorizer_config, - is_driver_worker=True, - ) - - self._run_workers("init_device") - self._run_workers( - "load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, - ) - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks. - - This invokes `determine_num_available_blocks` on each worker and takes - the min of the results, guaranteeing that the selected cache sizes are - compatible with all workers. - - Returns: - - Tuple[num_gpu_blocks, num_cpu_blocks] - """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self._run_workers("determine_num_available_blocks", ) - - # Since we use a shared centralized controller, we take the minimum - # number of blocks across all workers to make sure all the memory - # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) - - return num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache in all workers. - """ - - # NOTE: We log here to avoid multiple logs when number of workers is - # greater than one. We could log in the engine, but not all executors - # have GPUs. 
- logger.info(f"# GPU blocks: {num_gpu_blocks}, " - f"# CPU blocks: {num_cpu_blocks}") - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self._run_workers("initialize_cache", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks) + self._init_driver_worker_and_model(driver_rank, driver_local_rank, + distributed_init_method) def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -237,23 +178,6 @@ def execute_model(self, output = all_outputs[0] return output - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "add_lora", - lora_request=lora_request, - ) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "remove_lora", - lora_id=lora_id, - ) - - def list_loras(self) -> Set[int]: - return self._run_workers("list_loras") - def _run_workers( self, method: str, @@ -292,19 +216,17 @@ def _run_workers( method)(*driver_args, **driver_kwargs) # Get the results of the ray workers. - if self.workers: - if use_ray_compiled_dag: - try: - ray_worker_outputs = [ - pickle.loads(chan.begin_read()) - for chan in output_channels - ] - finally: - # Has to call end_read in order to reuse the DAG. - for chan in output_channels: - chan.end_read() - else: - ray_worker_outputs = ray.get(ray_worker_outputs) + if use_ray_compiled_dag: + try: + ray_worker_outputs = [ + pickle.loads(chan.begin_read()) for chan in output_channels + ] + finally: + # Has to call end_read in order to reuse the DAG. + for chan in output_channels: + chan.end_read() + else: + ray_worker_outputs = ray.get(ray_worker_outputs) return [driver_worker_output] + ray_worker_outputs @@ -346,7 +268,7 @@ def _check_if_any_actor_is_dead(self): f"Dead Workers: {dead_actors}. ") -class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): +class RayGPUExecutorAsync(RayGPUExecutor, MultiGPUExecutorAsync): async def _run_workers_async( self, @@ -374,23 +296,3 @@ async def _run_workers_async( all_outputs = await asyncio.gather(*coros) return all_outputs - - async def execute_model_async( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: - all_outputs = await self._run_workers_async( - "execute_model", - driver_kwargs={ - "seq_group_metadata_list": seq_group_metadata_list, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - }) - - # Only the driver worker returns the sampling results. 
- output = all_outputs[0] - return output diff --git a/vllm/utils.py b/vllm/utils.py index 4c0dc9ca729a..193c9f225283 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -7,6 +7,7 @@ import uuid import warnings from collections import OrderedDict, defaultdict +from collections.abc import Iterable from functools import lru_cache, partial from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, @@ -270,7 +271,7 @@ def get_open_port() -> int: return s.getsockname()[1] -def set_cuda_visible_devices(device_ids: List[int]) -> None: +def set_cuda_visible_devices(device_ids: Iterable[int]) -> None: os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) @@ -492,7 +493,7 @@ def maybe_expand_dim(tensor: torch.Tensor, def merge_dicts(dict1: Dict[Any, List[Any]], dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]: """Merge 2 dicts that have key -> List of items. - + When a key conflicts, the values in dict1 is prioritized. """ merged_dict = defaultdict(list)
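The widened `Iterable[int]` signature lets callers pass a `range` directly, as `MultiProcGPUExecutor` now does. A quick check (assuming a vLLM install where `vllm.utils` imports cleanly):

```python
import os

from vllm.utils import set_cuda_visible_devices

set_cuda_visible_devices(range(2))  # no need to materialize a list first
assert os.environ["CUDA_VISIBLE_DEVICES"] == "0,1"
```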