diff --git a/requirements.txt b/requirements.txt index 5684b2c29634..b030ae616919 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ pydantic >= 2.0 # Required for OpenAI server. aioprometheus[starlette] pynvml == 11.5.0 triton >= 2.1.0 +cupy-cuda12x == 12.3.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 03a2b1157652..86f092520930 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -283,7 +283,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", is_driver_worker=True, ) - self._run_workers("init_model") + self._run_workers("init_model", cupy_port=get_open_port()) self._run_workers( "load_model", max_concurrent_workers=self.parallel_config. diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 65671994f330..cf805df892fd 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -1,14 +1,15 @@ from collections import namedtuple from typing import Any, Dict, List, Optional, Union -from torch.distributed import ProcessGroup - import torch +from torch.distributed import ProcessGroup +from vllm.model_executor.parallel_utils import cupy_utils from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tensor_model_parallel_group, + is_cupy_nccl_enabled_for_all_reduce, ) from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce @@ -31,8 +32,12 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: out = custom_all_reduce(input_) if out is not None: return out - torch.distributed.all_reduce(input_, - group=get_tensor_model_parallel_group()) + if is_cupy_nccl_enabled_for_all_reduce(): + # TODO: support multiple parallel groups. + cupy_utils.all_reduce(input_) + else: + torch.distributed.all_reduce(input_, + group=get_tensor_model_parallel_group()) return input_ diff --git a/vllm/model_executor/parallel_utils/cupy_utils.py b/vllm/model_executor/parallel_utils/cupy_utils.py new file mode 100644 index 000000000000..f8cffc01e3c3 --- /dev/null +++ b/vllm/model_executor/parallel_utils/cupy_utils.py @@ -0,0 +1,130 @@ +"""CuPy utilities for all-reduce. + +We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing +CUDA graphs, because torch.distributed.all_reduce causes errors when capturing +CUDA graphs. + +NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8. +TODO: Remove this file when torch.distributed.all_reduce is fixed. +""" +import contextlib + +import torch +from torch.distributed import ReduceOp + +try: + import cupy + from cupy.cuda import nccl + from cupyx.distributed import NCCLBackend +except ImportError as e: + cupy = e + nccl = None + + class NCCLBackend: + ... + + +_OP_MAPPING = { + ReduceOp.SUM: "sum", + ReduceOp.PRODUCT: "prod", + ReduceOp.MIN: "min", + ReduceOp.MAX: "max", +} + + +class NCCLBackendWithBFloat16(NCCLBackend): + # This is enough to add bfloat16 support for most operations, + # but broadcast will fail (will require changes in compiled + # cupy code). 
+ def _get_nccl_dtype_and_count(self, array, count=None): + nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count) + torch_dtype = getattr(array, "_torch_dtype", None) + if torch_dtype is torch.bfloat16: + nccl_dtype = nccl.NCCL_BFLOAT16 + return nccl_dtype, count + + def barrier(self) -> None: + raise RuntimeError( + "Currently, CuPy NCCL barrier is not supported since the TCP " + "store is immediately stopped after the initialization.") + + +_NCCL_BACKEND = None +_WORLD_SIZE = 0 + + +def is_initialized() -> bool: + """Returns whether the NCCL backend is initialized.""" + return _NCCL_BACKEND is not None + + +@contextlib.contextmanager +def set_cupy_stream(stream: torch.cuda.Stream): + """Set the cuda stream for communication""" + cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream, + stream.device_index) + with cupy_stream: + yield + + +def init_process_group(world_size: int, rank: int, host: str, + port: int) -> None: + """Initializes the CuPy NCCL backend. + + # TODO: handle NCCL timeouts. + """ + assert not is_initialized() + + if isinstance(cupy, Exception): + raise ImportError( + "NCCLBackend is not available. Please install cupy.") from cupy + + # TODO(woosuk): Create TP and PP process groups for CuPy. + global _NCCL_BACKEND + global _WORLD_SIZE + assert world_size > 0, f"{world_size=} should be a positive integer" + assert 0 <= rank < world_size, ( + f"{rank=} should be a integer between [0, {world_size})") + + cupy.cuda.runtime.setDevice(torch.cuda.current_device()) + _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port) + _WORLD_SIZE = world_size + + # Stop the TCP store to prevent the deadlock issues at termination time. + # FIXME(woosuk): This is hacky. Find a more robust solution. + if rank == 0 and hasattr(_NCCL_BACKEND, "_store"): + _NCCL_BACKEND._store.stop() + + +def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: + """All-reduces the input tensor across the process group.""" + assert input_.is_cuda, f"{input_} should be a cuda tensor" + # Hack to support bfloat16 + torch_dtype = input_.dtype + if torch_dtype is torch.bfloat16: + # We need to view as float16, otherwise + # cupy will fail. This will not change + # the underlying data. + input_ = input_.view(torch.float16) + cupy_input = cupy.asarray(input_) + cupy_input._torch_dtype = torch_dtype # pylint: disable=protected-access + _NCCL_BACKEND.all_reduce(in_array=cupy_input, + out_array=cupy_input, + op=_OP_MAPPING[op]) + + +def destroy_process_group() -> None: + """Destroys the NCCL backend.""" + global _NCCL_BACKEND + global _WORLD_SIZE + _NCCL_BACKEND = None + _WORLD_SIZE = 0 + + +def get_world_size() -> int: + """Returns the world size.""" + return _WORLD_SIZE + + +def get_nccl_backend(): + return _NCCL_BACKEND diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index 59cc19653857..aeb07f64c37d 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -3,9 +3,12 @@ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Tensor and pipeline parallel groups.""" +import contextlib import torch +from vllm.model_executor.parallel_utils import cupy_utils + # Tensor model parallel group that the current rank belongs to. _TENSOR_MODEL_PARALLEL_GROUP = None # Pipeline model parallel group that the current rank belongs to. 
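Note: a minimal sketch of how the new cupy_utils module is meant to be exercised on its own, assuming two GPUs, cupy-cuda12x installed, and a free TCP port (29999 below is a placeholder). The spawn harness and tensor values are illustrative and not part of this change.

# Hypothetical standalone harness for vllm.model_executor.parallel_utils.cupy_utils.
import torch
import torch.multiprocessing as mp

from vllm.model_executor.parallel_utils import cupy_utils


def _worker(rank: int, world_size: int, port: int) -> None:
    torch.cuda.set_device(rank)
    # Rank 0 hosts the temporary TCP store used for the NCCL handshake;
    # init_process_group() stops it right after initialization.
    cupy_utils.init_process_group(world_size=world_size,
                                  rank=rank,
                                  host="localhost",
                                  port=port)
    # bfloat16 exercises the float16-view workaround in all_reduce().
    x = torch.ones(8, dtype=torch.bfloat16, device="cuda") * (rank + 1)
    cupy_utils.all_reduce(x)  # in-place sum: 1.0 + 2.0 on every rank
    assert torch.allclose(x.float(), torch.full((8,), 3.0, device="cuda"))
    cupy_utils.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2, 29999), nprocs=2)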
@@ -206,3 +209,37 @@ def destroy_model_parallel(): _PIPELINE_MODEL_PARALLEL_GROUP = None global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = None + + # Destroy the cupy states if any. + cupy_utils.destroy_process_group() + + +# Whether to use cupy for nccl all reduce. +# We use cupy for all reduce when using CUDA graph, because torch.distributed +# is not well supported by CUDA graph. +_ENABLE_CUPY_FOR_ALL_REDUCE = False + + +@contextlib.contextmanager +def with_cupy_nccl_for_all_reduce(): + """use CuPy nccl instead of torch.distributed for all reduce""" + tp_size = get_tensor_model_parallel_world_size() + if tp_size == 1: + # No-op. + # NOTE(woosuk): We don't initialize CuPy when tp_size is 1. + yield + else: + global _ENABLE_CUPY_FOR_ALL_REDUCE + old = _ENABLE_CUPY_FOR_ALL_REDUCE + _ENABLE_CUPY_FOR_ALL_REDUCE = True + + stream = torch.cuda.current_stream() + with cupy_utils.set_cupy_stream(stream): + yield + _ENABLE_CUPY_FOR_ALL_REDUCE = old + + +def is_cupy_nccl_enabled_for_all_reduce(): + """check if CuPy nccl is enabled for all reduce""" + global _ENABLE_CUPY_FOR_ALL_REDUCE + return _ENABLE_CUPY_FOR_ALL_REDUCE diff --git a/vllm/test_utils.py b/vllm/test_utils.py index 4f74c05038e7..75bf6ce373d9 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -15,8 +15,11 @@ def init_test_distributed_environment( tensor_parallel_size, worker_use_ray=True) distributed_init_method = f"tcp://localhost:{distributed_init_port}" - init_distributed_environment(parallel_config, rank, - distributed_init_method) + init_distributed_environment( + parallel_config, + rank, + cupy_port=None, + distributed_init_method=distributed_init_method) def multi_process_tensor_parallel( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fce0009e3097..62f7530868ad 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -5,11 +5,15 @@ import torch import torch.nn as nn -from vllm.config import DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig +from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, + SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) +from vllm.model_executor.parallel_utils.cupy_utils import get_nccl_backend +from vllm.model_executor.parallel_utils.parallel_state import ( + with_cupy_nccl_for_all_reduce) from vllm.model_executor.parallel_utils import custom_all_reduce from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata @@ -644,6 +648,10 @@ def list_loras(self) -> Set[int]: @torch.inference_mode() def capture_model(self, kv_caches: List[KVCache]) -> None: + # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never + # deleted before the CUDA graphs. + self.cupy_nccl_backend = get_nccl_backend() + assert not self.model_config.enforce_eager logger.info("Capturing the model for CUDA graphs. This may lead to " "unexpected consequences if the model is not static. To " @@ -674,6 +682,12 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. + # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce + # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use + # either custom all-reduce kernel or CuPy NCCL. 
When not using CUDA + # graph, we use either custom all-reduce kernel or PyTorch NCCL. + # We always prioritize using custom all-reduce kernel but fall back + # to PyTorch or CuPy NCCL if it is disabled or not supported. with custom_all_reduce.capture(): for batch_size in reversed(batch_size_capture_list): # Create dummy input_metadata. @@ -713,6 +727,14 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: # This usually takes < 10 seconds. logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.") + def __del__(self) -> None: + # Delete the CUDA graphs before deleting the CuPy NCCL communicator. + # NOTE(woosuk): This is necessary because otherwise deadlocks can + # happen. + # FIXME(woosuk): This is a bit hacky. Find a more robust solution. + self.graph_runners.clear() + self.cupy_nccl_backend = None + class CUDAGraphRunner: @@ -734,18 +756,8 @@ def capture( # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - self.model( - input_ids, - positions, - kv_caches, - input_metadata, - ) - torch.cuda.synchronize() - - # Capture the graph. - self.graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(self.graph, pool=memory_pool): - hidden_states = self.model( + with with_cupy_nccl_for_all_reduce(): + self.model( input_ids, positions, kv_caches, @@ -753,6 +765,20 @@ def capture( ) torch.cuda.synchronize() + # Capture the graph. + # NOTE(woosuk): Python 3.8 does not support multi-line with statements. + # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement + self.graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117 + with with_cupy_nccl_for_all_reduce(): + hidden_states = self.model( + input_ids, + positions, + kv_caches, + input_metadata, + ) + torch.cuda.synchronize() + # Save the input and output buffers. self.input_buffers = { "input_ids": input_ids, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index c97e82a55a1e..b616040367c8 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig) from vllm.model_executor import set_random_seed +from vllm.model_executor.parallel_utils import cupy_utils from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar @@ -67,7 +68,7 @@ def __init__( self.cache_events = None self.gpu_cache = None - def init_model(self) -> None: + def init_model(self, cupy_port: Optional[int] = None) -> None: if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until # the synchronization point. This causes the memory usage to grow @@ -88,7 +89,7 @@ def init_model(self) -> None: f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. init_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method) + cupy_port, self.distributed_init_method) if not self.parallel_config.disable_custom_all_reduce: init_custom_ar() # Initialize the model. 
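Note: the capture-time flow in CUDAGraphRunner.capture() above reduces to "warm up once, then capture the same call, both under with_cupy_nccl_for_all_reduce()". A condensed, hypothetical sketch of that pattern, assuming the distributed environment (including the CuPy group) has already been set up via init_distributed_environment(); the toy forward() below stands in for the real model call and is not part of this change.

import torch

from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
    with_cupy_nccl_for_all_reduce)


def forward(x: torch.Tensor) -> torch.Tensor:
    # Any collective reached here is routed to CuPy NCCL, because
    # with_cupy_nccl_for_all_reduce() flips the flag that
    # tensor_model_parallel_all_reduce() checks.
    return tensor_model_parallel_all_reduce(x * 2)


def capture(x: torch.Tensor) -> torch.cuda.CUDAGraph:
    # 1. Warm-up run outside the graph, so one-time kernel launches
    #    (e.g., Triton autotuning) are not baked into the capture.
    with with_cupy_nccl_for_all_reduce():
        forward(x)
    torch.cuda.synchronize()

    # 2. Capture. The nested `with` blocks mirror the Python 3.8
    #    workaround noted in the diff.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        with with_cupy_nccl_for_all_reduce():
            forward(x)
    torch.cuda.synchronize()
    return graph

Replaying the returned graph re-runs the captured CuPy NCCL all-reduce, which is why model_runner.py above keeps a reference to the NCCL backend and only drops it after the graph runners are deleted.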
@@ -233,6 +234,7 @@ def list_loras(self) -> Set[int]: def init_distributed_environment( parallel_config: ParallelConfig, rank: int, + cupy_port: Optional[int], distributed_init_method: Optional[str] = None, ) -> None: """Initialize the distributed environment.""" @@ -255,8 +257,28 @@ def init_distributed_environment( init_method=distributed_init_method, ) + if cupy_utils.is_initialized(): + cupy_world_size = cupy_utils.get_world_size() + if cupy_world_size != parallel_config.world_size: + raise RuntimeError( + "cupy.distributed is already initialized but the cupy world " + "size does not match parallel_config.world_size " + f"({cupy_world_size} vs. {parallel_config.world_size}).") + elif parallel_config.world_size > 1 and cupy_port is not None: + # NOTE(woosuk): We don't initialize CuPy process group when world size + # is 1. + # TODO(woosuk): Support multi-node connection. + cupy_utils.init_process_group( + world_size=parallel_config.world_size, + rank=rank, + host="localhost", + port=cupy_port, + ) + # A small all_reduce for warmup. torch.distributed.all_reduce(torch.zeros(1).cuda()) + if cupy_utils.is_initialized(): + cupy_utils.all_reduce(torch.zeros(1).cuda()) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size)
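Note: with this change there are three all-reduce code paths, chosen in the priority order described in the model_runner.py comment above. A hypothetical helper that merely restates that decision (not vLLM code; the parameter names are placeholders):

def pick_all_reduce_backend(custom_ar_supported: bool,
                            capturing_or_replaying_cuda_graph: bool) -> str:
    """Restates the priority: custom kernel > CuPy NCCL > PyTorch NCCL."""
    if custom_ar_supported:
        # Preferred whenever the custom all-reduce kernel can handle the
        # tensor (custom_all_reduce() returns None when it cannot).
        return "custom all-reduce kernel"
    if capturing_or_replaying_cuda_graph:
        # torch.distributed.all_reduce cannot be captured into a CUDA graph,
        # so the CuPy NCCL path added by this change is used instead.
        return "CuPy NCCL"
    return "PyTorch NCCL (torch.distributed.all_reduce)"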