Merged

Changes from 11 commits
3 changes: 1 addition & 2 deletions requirements.txt
@@ -10,5 +10,4 @@ fastapi
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
aioprometheus[starlette]
pynvml == 11.5.0
triton >= 2.1.0
cupy-cuda12x == 13.0.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
@@ -283,7 +283,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
is_driver_worker=True,
)

self._run_workers("init_model")
self._run_workers("init_model", cupy_port=get_open_port())
self._run_workers(
"load_model",
max_concurrent_workers=self.parallel_config.
13 changes: 9 additions & 4 deletions vllm/model_executor/parallel_utils/communication_op.py
@@ -1,14 +1,15 @@
from collections import namedtuple
from typing import Any, Dict, List, Optional, Union

from torch.distributed import ProcessGroup

import torch
from torch.distributed import ProcessGroup

from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
is_cupy_nccl_enabled_for_all_reduce,
)
from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce

@@ -31,8 +32,12 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
out = custom_all_reduce(input_)
if out is not None:
return out
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())
if is_cupy_nccl_enabled_for_all_reduce():
# TODO: support multiple parallel groups.
cupy_utils.all_reduce(input_)
else:
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())
return input_


122 changes: 122 additions & 0 deletions vllm/model_executor/parallel_utils/cupy_utils.py
@@ -0,0 +1,122 @@
"""CuPy utilities for all-reduce.

We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
CUDA graphs, because torch.distributed.all_reduce raises errors during graph
capture.

TODO: Remove this file when torch.distributed.all_reduce is fixed.
"""
import contextlib
from unittest.mock import patch

import torch
from torch.distributed import ReduceOp

try:
import cupy
from cupy.cuda import nccl
from cupyx.distributed._nccl_comm import NCCLBackend, _get_nccl_dtype_and_count
except ImportError as e:
cupy = e
nccl = None
# Define a fallback so this module can still be imported when CuPy is missing.
_get_nccl_dtype_and_count = None

class NCCLBackend:
...


_OP_MAPPING = {
ReduceOp.SUM: "sum",
ReduceOp.PRODUCT: "prod",
ReduceOp.MIN: "min",
ReduceOp.MAX: "max",
}

_NCCL_BACKEND = None
_WORLD_SIZE = 0

_get_nccl_dtype_and_count_original = _get_nccl_dtype_and_count


def _get_nccl_dtype_and_count_bf16(*args, **kwargs):
"""Patch/hack to force bf16 dtype in cupy NCCL.

cupy doesn't support bf16 by default, but the underlying NCCL
kernels do. We can just force the dtype to be bf16 and it will
work fine."""
dtype, count = _get_nccl_dtype_and_count_original(*args, **kwargs)
# Hardcoded to always return bf16 dtype
dtype = nccl.NCCL_BFLOAT16
return dtype, count


def is_initialized() -> bool:
"""Returns whether the NCCL backend is initialized."""
return _NCCL_BACKEND is not None


@contextlib.contextmanager
def set_cupy_stream(stream: torch.cuda.Stream) -> None:
"""Set the cuda stream for communication"""
cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
stream.device_index)
with cupy_stream:
yield


def init_process_group(world_size: int, rank: int, host: str,
port: int) -> None:
"""Initializes the CuPy NCCL backend.

# TODO: handle NCCL timeouts.
"""
assert not is_initialized()

if isinstance(cupy, Exception):
raise ImportError(
"NCCLBackend is not available. Please install cupy==13.0.0.") from cupy
Collaborator
Suggested change
"NCCLBackend is not available. Please install cupy==13.0.0.") from cupy
"NCCLBackend is not available. Please install cupy-cuda12x==13.0.0.") from cupy

Collaborator Author
Thanks for pointing it out. Actually, there are two issues:

  1. I changed the PR to use cupy 12.3 instead of 13.0 because cupy 13.0 does not support python 3.8 (I wasn't able to find the wheel in pypi).
  2. Users need to install different versions of cupy depending on their env. For example, CUDA 11.8 users should install cupy-cuda11x. ROCm users should install cupy-rocm.


# TODO(woosuk): Create TP and PP process groups for CuPy.
global _NCCL_BACKEND
global _WORLD_SIZE
assert world_size > 0, f"{world_size=} should be a positive integer"
assert 0 <= rank < world_size, (
f"{rank=} should be a integer between [0, {world_size})")

cupy.cuda.runtime.setDevice(torch.cuda.current_device())
_NCCL_BACKEND = NCCLBackend(world_size, rank, host, port)
_WORLD_SIZE = world_size


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
# Hack to support bfloat16
torch_dtype = input_.dtype
if torch_dtype is torch.bfloat16:
# We need to view as float16, otherwise
# cupy will fail. This will not change
# the underlying data.
input_ = input_.view(torch.float16)
maybe_patch_cupy_bf16 = patch(
"cupyx.distributed._nccl_comm._get_nccl_dtype_and_count",
_get_nccl_dtype_and_count_bf16)
else:
maybe_patch_cupy_bf16 = contextlib.nullcontext()
cupy_input = cupy.asarray(input_)
with maybe_patch_cupy_bf16:
_NCCL_BACKEND.all_reduce(in_array=cupy_input,
out_array=cupy_input,
op=_OP_MAPPING[op])


def destroy_process_group() -> None:
"""Destroys the NCCL backend."""
global _NCCL_BACKEND
global _WORLD_SIZE
_NCCL_BACKEND = None
_WORLD_SIZE = 0


def get_world_size() -> int:
"""Returns the world size."""
return _WORLD_SIZE
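
For orientation, here is a minimal usage sketch of the helpers defined in this file; the two-rank setup, the port number, and the dummy tensor are illustrative assumptions, not part of the PR.

# Editorial sketch (not part of the diff): driving cupy_utils with one worker
# process per GPU. Assumes two ranks on one node and a pre-agreed free port.
import torch
from vllm.model_executor.parallel_utils import cupy_utils

def worker_main(rank: int, world_size: int = 2, port: int = 29511) -> None:
    torch.cuda.set_device(rank)
    # One-time setup, mirroring init_distributed_environment() in worker.py.
    cupy_utils.init_process_group(world_size=world_size, rank=rank,
                                  host="localhost", port=port)
    assert cupy_utils.is_initialized()

    # The all-reduce runs on the stream passed to set_cupy_stream(); the tensor
    # is updated in place because cupy.asarray() wraps the same device memory.
    x = torch.ones(4, device="cuda")
    with cupy_utils.set_cupy_stream(torch.cuda.current_stream()):
        cupy_utils.all_reduce(x)  # x becomes [2., 2., 2., 2.] with two ranks

    cupy_utils.destroy_process_group()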
37 changes: 37 additions & 0 deletions vllm/model_executor/parallel_utils/parallel_state.py
@@ -3,9 +3,12 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib

import torch

from vllm.model_executor.parallel_utils import cupy_utils

# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline model parallel group that the current rank belongs to.
@@ -206,3 +209,37 @@ def destroy_model_parallel():
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None

# Destroy the cupy states if any.
cupy_utils.destroy_process_group()


# Whether to use cupy for nccl all reduce.
# We use cupy for all reduce when using CUDA graph, because torch.distributed
# is not well supported by CUDA graph.
_ENABLE_CUPY_FOR_ALL_REDUCE = False


@contextlib.contextmanager
def with_cupy_nccl_for_all_reduce():
"""use CuPy nccl instead of torch.distributed for all reduce"""
tp_size = get_tensor_model_parallel_world_size()
if tp_size == 1:
# No-op.
# NOTE(woosuk): We don't initialize CuPy when tp_size is 1.
yield
else:
global _ENABLE_CUPY_FOR_ALL_REDUCE
old = _ENABLE_CUPY_FOR_ALL_REDUCE
_ENABLE_CUPY_FOR_ALL_REDUCE = True

stream = torch.cuda.current_stream()
with cupy_utils.set_cupy_stream(stream):
yield
_ENABLE_CUPY_FOR_ALL_REDUCE = old


def is_cupy_nccl_enabled_for_all_reduce():
"""check if CuPy nccl is enabled for all reduce"""
global _ENABLE_CUPY_FOR_ALL_REDUCE
return _ENABLE_CUPY_FOR_ALL_REDUCE
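
To make the interaction with communication_op.py concrete, the following editorial sketch shows how the flag above routes tensor_model_parallel_all_reduce between the two backends; the tensor shape is a placeholder and the snippet assumes the distributed environment is already initialized.

# Editorial sketch (not part of the diff): which backend handles the all-reduce.
import torch
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
    is_cupy_nccl_enabled_for_all_reduce, with_cupy_nccl_for_all_reduce)

hidden = torch.randn(8, 4096, device="cuda")  # placeholder activation

# Eager path: the flag is False, so torch.distributed.all_reduce is used.
assert not is_cupy_nccl_enabled_for_all_reduce()
hidden = tensor_model_parallel_all_reduce(hidden)

# Capture path: inside the context (and with tp_size > 1) the flag is True,
# so tensor_model_parallel_all_reduce falls through to cupy_utils.all_reduce.
with with_cupy_nccl_for_all_reduce():
    hidden = tensor_model_parallel_all_reduce(hidden)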
7 changes: 5 additions & 2 deletions vllm/test_utils.py
@@ -15,8 +15,11 @@ def init_test_distributed_environment(
tensor_parallel_size,
worker_use_ray=True)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment(parallel_config, rank,
distributed_init_method)
init_distributed_environment(
parallel_config,
rank,
cupy_port=None,
distributed_init_method=distributed_init_method)


def multi_process_tensor_parallel(
36 changes: 24 additions & 12 deletions vllm/worker/model_runner.py
@@ -10,6 +10,8 @@
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
with_cupy_nccl_for_all_reduce)
from vllm.model_executor.parallel_utils import custom_all_reduce
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
@@ -674,6 +676,12 @@ def capture_model(self, kv_caches: List[KVCache]) -> None:

# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
# NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
# kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
# either custom all-reduce kernel or CuPy NCCL. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or CuPy NCCL if it is disabled or not supported.
with custom_all_reduce.capture():
for batch_size in reversed(batch_size_capture_list):
# Create dummy input_metadata.
@@ -734,25 +742,29 @@ def capture(
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
self.model(
input_ids,
positions,
kv_caches,
input_metadata,
)
torch.cuda.synchronize()

# Capture the graph.
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=memory_pool):
hidden_states = self.model(
with with_cupy_nccl_for_all_reduce():
self.model(
input_ids,
positions,
kv_caches,
input_metadata,
)
torch.cuda.synchronize()

# Capture the graph.
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117
with with_cupy_nccl_for_all_reduce():
hidden_states = self.model(
input_ids,
positions,
kv_caches,
input_metadata,
)
torch.cuda.synchronize()

# Save the input and output buffers.
self.input_buffers = {
"input_ids": input_ids,
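
The capture/replay pattern that CUDAGraphRunner relies on can be summarized in a standalone sketch; this is an editorial illustration using plain PyTorch APIs, with `my_model`, the shapes, and the input handling as assumptions rather than vLLM code.

# Editorial sketch (not part of the diff): the generic capture-then-replay flow.
# Warm up first so one-off launches (e.g., autotuning) are not recorded, capture
# into a torch.cuda.CUDAGraph, then replay with new data in the static buffers.
import torch

def capture_and_replay(my_model: torch.nn.Module,
                       example_input: torch.Tensor) -> torch.Tensor:
    static_input = example_input.clone()

    # Warm-up run outside the graph, as in CUDAGraphRunner.capture().
    my_model(static_input)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = my_model(static_input)
    torch.cuda.synchronize()

    # Replay: copy new inputs into the captured buffer, then rerun the graph.
    static_input.copy_(example_input * 2)
    graph.replay()
    torch.cuda.synchronize()
    return static_output  # holds the outputs of the replayed run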
26 changes: 24 additions & 2 deletions vllm/worker/worker.py
@@ -9,6 +9,7 @@
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
@@ -67,7 +68,7 @@ def __init__(
self.cache_events = None
self.gpu_cache = None

def init_model(self) -> None:
def init_model(self, cupy_port: Optional[int] = None) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
@@ -88,7 +89,7 @@ def init_model(self) -> None:
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method)
cupy_port, self.distributed_init_method)
if not self.parallel_config.disable_custom_all_reduce:
init_custom_ar()
# Initialize the model.
@@ -233,6 +234,7 @@ def list_loras(self) -> Set[int]:
def init_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
cupy_port: Optional[int],
distributed_init_method: Optional[str] = None,
) -> None:
"""Initialize the distributed environment."""
@@ -255,8 +257,28 @@
init_method=distributed_init_method,
)

if cupy_utils.is_initialized():
cupy_world_size = cupy_utils.get_world_size()
if cupy_world_size != parallel_config.world_size:
raise RuntimeError(
"cupy.distributed is already initialized but the cupy world "
"size does not match parallel_config.world_size "
f"({cupy_world_size} vs. {parallel_config.world_size}).")
elif parallel_config.world_size > 1 and cupy_port is not None:
# NOTE(woosuk): We don't initialize CuPy process group when world size
# is 1.
# TODO(woosuk): Support multi-node connection.
cupy_utils.init_process_group(
world_size=parallel_config.world_size,
rank=rank,
host="localhost",
port=cupy_port,
)

# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
if cupy_utils.is_initialized():
cupy_utils.all_reduce(torch.zeros(1).cuda())
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
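
Finally, an editorial sketch of how the new cupy_port argument flows from the engine to each worker; `workers` and the loop are illustrative stand-ins for the Ray-based dispatch in llm_engine.py, and the import location of get_open_port is an assumption.

# Editorial sketch (not part of the diff): the driver picks one free port and
# hands it to every worker so all ranks join the same CuPy NCCL group.
from vllm.utils import get_open_port  # assumed import location

def init_all_workers(workers, parallel_config, distributed_init_method) -> None:
    cupy_port = get_open_port()  # mirrors _run_workers("init_model", cupy_port=...)
    for rank, worker in enumerate(workers):
        # Each worker ends up calling init_distributed_environment(
        #     parallel_config, rank, cupy_port, distributed_init_method),
        # which only initializes the CuPy group when world_size > 1 and
        # cupy_port is not None.
        worker.init_model(cupy_port=cupy_port)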
