20 changes: 2 additions & 18 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -143,6 +143,7 @@ def _is_full_nvlink(rank, world_size):
 
 
 def _can_p2p(rank: int, world_size: int) -> bool:
+    from vllm.distributed.utils import gpu_p2p_access_check
     num_dev = torch.cuda.device_count()
     # note: num dev can be larger than world_size if we're only using
     # first few GPUs
@@ -155,28 +156,11 @@ def _can_p2p(rank: int, world_size: int) -> bool:
     for i in range(world_size):
         if i == rank:
             continue
-        if not torch.cuda.can_device_access_peer(rank, i):
-            return False
-        # on some platforms, P2P support might be buggy and we need
-        # additional checks. See also:
-        # https://github.com/vllm-project/vllm/issues/2728
-        if not _can_actually_p2p(rank, i):
+        if not gpu_p2p_access_check(rank, i):
             return False
     return True
 
 
-# code partly borrowed from
-# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
-# License: MIT
-def _can_actually_p2p(idx_a, idx_b):
-    dev_i = f"cuda:{idx_a}"
-    dev_j = f"cuda:{idx_b}"
-    a = torch.randn(5, device=dev_i) + 123.0
-    b = a.to(dev_j)
-    c = b.to(dev_i)
-    return torch.all(a == c)
-
-
 class CustomAllreduce:
 
     # max_size: max supported allreduce size
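
With this change, `_can_p2p` delegates each pairwise check to the cached helper introduced in vllm/distributed/utils.py below. A minimal sketch of the resulting call pattern (illustrative only; the helper name `_peers_reachable` is hypothetical and not part of this PR):

    from vllm.distributed.utils import gpu_p2p_access_check

    def _peers_reachable(rank: int, world_size: int) -> bool:
        # Same pattern as the loop in _can_p2p after this diff: every other
        # rank's GPU must be reachable from `rank`, via the cached check.
        return all(
            gpu_p2p_access_check(rank, i)
            for i in range(world_size) if i != rank)
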
66 changes: 66 additions & 0 deletions vllm/distributed/utils.py
@@ -2,9 +2,18 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+import json
+import os
 from typing import Sequence
 
 import torch
+import torch.distributed as dist
+
+from vllm.logger import init_logger
+
+from .parallel_state import get_cpu_world_group
+
+logger = init_logger(__name__)
 
 
 def ensure_divisibility(numerator, denominator):
@@ -46,3 +55,60 @@ def split_tensor_along_last_dim(
         return tuple(chunk.contiguous() for chunk in tensor_list)
 
     return tensor_list
+
+
+# code partly borrowed from
+# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
+# License: MIT
+def _can_actually_p2p(idx_a, idx_b):
+    dev_i = f"cuda:{idx_a}"
+    dev_j = f"cuda:{idx_b}"
+    a = torch.randn(5, device=dev_i) + 123.0
+    b = a.to(dev_j)
+    c = b.to(dev_i)
+    return torch.all(a == c).cpu().item()
+
+
+_gpu_p2p_access_cache = None
+
+
+def gpu_p2p_access_check(i: int, j: int) -> bool:
+    """Check if GPU i can access GPU j."""
+
+    # if the cache variable is already calculated,
+    # read from the cache instead of checking it again
+    global _gpu_p2p_access_cache
+    if _gpu_p2p_access_cache is not None:
+        return _gpu_p2p_access_cache[f"{i}->{j}"]
+
+    is_distributed = dist.is_initialized()
+
+    num_dev = torch.cuda.device_count()
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+    if cuda_visible_devices is None:
+        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
+    path = os.path.expanduser(
+        f"~/.config/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    if (not is_distributed or dist.get_rank() == 0) \
+            and (not os.path.exists(path)):
+        # only the master process can enter this block to calculate the cache
+        logger.info(f"generating GPU P2P access cache in {path}")
+        cache = {}
+        for _i in range(num_dev):
+            for _j in range(num_dev):
+                # on some platforms, P2P support might be buggy and we need
+                # additional checks. See also:
+                # https://github.com/vllm-project/vllm/issues/2728
+                cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer(
+                    _i, _j) and _can_actually_p2p(_i, _j)
+        with open(path, "w") as f:
+            json.dump(cache, f, indent=4)
+    if is_distributed:
+        cpu_world_group = get_cpu_world_group()
+        dist.barrier(cpu_world_group)
+    logger.info(f"reading GPU P2P access cache from {path}")
+    with open(path, "r") as f:
+        cache = json.load(f)
+    _gpu_p2p_access_cache = cache
+    return _gpu_p2p_access_cache[f"{i}->{j}"]
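
For illustration, a hypothetical consumer of the new helper on a host with CUDA_VISIBLE_DEVICES=0,1 and working peer access between the two GPUs (the file contents shown in the comment are an example of the cache format, not an artifact of this PR):

    from vllm.distributed.utils import gpu_p2p_access_check

    # The first call probes every ordered pair of visible devices and writes
    # ~/.config/vllm/gpu_p2p_access_cache_for_0,1.json with entries such as
    #     {"0->1": true, "1->0": true, ...}
    # Later calls are answered from the in-memory cache without re-probing.
    print(gpu_p2p_access_check(0, 1))  # True if GPU 0 can really reach GPU 1
    print(gpu_p2p_access_check(1, 0))  # served from the cached results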