145 changes: 145 additions & 0 deletions vllm/distributed/device_communicators/cuda_wrapper.py
@@ -0,0 +1,145 @@
# This file is a pure Python wrapper for the cudart library.
# It avoids the need to compile a separate shared library, and is
# convenient when we only need to call a few functions.

import ctypes
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# importing torch loads `libcudart.so` into the process,
# which makes it possible to load it directly using `ctypes`
import torch  # noqa

from vllm.logger import init_logger

logger = init_logger(__name__)

# === export types and functions from cudart to Python ===
# for the original cudart definition, please check
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html

cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int


class cudaIpcMemHandle_t(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]


@dataclass
class Function:
name: str
restype: Any
argtypes: List[Any]


class CudaRTLibrary:
    exported_functions = [
        # cudaError_t cudaSetDevice ( int device )
        Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
        # cudaError_t cudaDeviceSynchronize ( void )
        Function("cudaDeviceSynchronize", cudaError_t, []),
        # cudaError_t cudaDeviceReset ( void )
        Function("cudaDeviceReset", cudaError_t, []),

        # const char* cudaGetErrorString ( cudaError_t error )
        Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),

        # cudaError_t cudaMalloc ( void** devPtr, size_t size )
        Function("cudaMalloc", cudaError_t,
                 [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
        # cudaError_t cudaFree ( void* devPtr )
        Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
        # cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
        Function("cudaMemset", cudaError_t,
                 [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
        # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
        Function("cudaMemcpy", cudaError_t, [
            ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
        ]),

        # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
        Function("cudaIpcGetMemHandle", cudaError_t,
                 [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
        # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
        Function("cudaIpcOpenMemHandle", cudaError_t, [
            ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
        ]),
    ]

    # class attribute to cache the mapping from library path to the loaded
    # library, to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to cache the mapping from library path to a dictionary
    # of the functions bound from that library
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

def __init__(self, so_file: Optional[str] = None):
if so_file is None:
assert torch.version.cuda is not None
major_version = torch.version.cuda.split(".")[0]
so_file = f"libcudart.so.{major_version}"
if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib
self.lib = CudaRTLibrary.path_to_library_cache[so_file]

if so_file not in CudaRTLibrary.path_to_dict_mapping:
_funcs = {}
for func in CudaRTLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]

def CUDART_CHECK(self, result: cudaError_t) -> None:
if result != 0:
error_str = self.cudaGetErrorString(result)
raise RuntimeError(f"CUDART error: {error_str}")

def cudaGetErrorString(self, error: cudaError_t) -> str:
return self.funcs["cudaGetErrorString"](error).decode("utf-8")

def cudaSetDevice(self, device: int) -> None:
self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))

def cudaDeviceSynchronize(self) -> None:
self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())

def cudaDeviceReset(self) -> None:
self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())

def cudaMalloc(self, size: int) -> ctypes.c_void_p:
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
return devPtr

def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))

def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
count: int) -> None:
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))

def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
count: int) -> None:
cudaMemcpyDefault = 4
kind = cudaMemcpyDefault
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))

def cudaIpcGetMemHandle(self,
devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
handle = cudaIpcMemHandle_t()
self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
ctypes.byref(handle), devPtr))
return handle

def cudaIpcOpenMemHandle(self,
handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
cudaIpcMemLazyEnablePeerAccess = 1
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
return devPtr
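
For reference, a minimal usage sketch of the wrapper. It assumes a host with at least one CUDA device and a `libcudart` that `ctypes` can resolve; the buffer size and fill value are illustrative:

import ctypes

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

lib = CudaRTLibrary()
lib.cudaSetDevice(0)
pointer = lib.cudaMalloc(1024)  # allocate 1 KiB of device memory
lib.cudaMemset(pointer, 1, 1024)  # fill every byte with 0x01
lib.cudaDeviceSynchronize()

# copy device -> host; the wrapper always passes cudaMemcpyDefault,
# so the copy direction is inferred from the pointers
host_data = (ctypes.c_char * 1024)()
lib.cudaMemcpy(host_data, pointer, 1024)
assert all(ord(b) == 1 for b in host_data)
lib.cudaFree(pointer)
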
166 changes: 93 additions & 73 deletions vllm/distributed/device_communicators/custom_all_reduce_utils.py
@@ -1,83 +1,88 @@
import ctypes
import json
import os
import sys
import tempfile
import time
from contextlib import contextmanager
from typing import Callable, Dict, List, Optional
from typing import Dict, List, Optional

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import vllm.envs as envs
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless

logger = init_logger(__name__)


@contextmanager
def mute_output():
with open(os.devnull, "w") as f:
sys.stderr = f
sys.stdout = f
yield


def producer(i: int,
init_method: str,
def producer(batch_is: List[int],
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None):
if cuda_visible_devices is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
with mute_output():
dist.init_process_group(
backend="gloo",
init_method=init_method,
world_size=2,
rank=0,
)
# produce a tensor in GPU i
data = torch.zeros((128, ), device=f"cuda:{i}")
# get the information to reconstruct the shared tensor
func, args = torch.multiprocessing.reductions.reduce_tensor(data)
args = list(args)
dist.broadcast_object_list([(func, args)], src=0)
dist.barrier()
torch.cuda.synchronize()
assert torch.all(data == 1).item()


def consumer(j: int,
init_method: str,

lib = CudaRTLibrary()
for i in batch_is:
lib.cudaSetDevice(i)
pointer = lib.cudaMalloc(1024)
lib.cudaMemset(pointer, 1, 1024)
lib.cudaDeviceSynchronize()
handle = lib.cudaIpcGetMemHandle(pointer)
producer_queue.put(handle)
open_success = consumer_queue.get()
if open_success:
producer_queue.put(0)
consumer_queue.get()
host_data = (ctypes.c_char * 1024)()
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
for i in range(1024):
if ord(host_data[i]) != 2:
open_success = False
break
result_queue.put(open_success)
lib.cudaDeviceReset()


def consumer(batch_js: List[int],
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None):
if cuda_visible_devices is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
with mute_output():
dist.init_process_group(
backend="gloo",
init_method=init_method,
world_size=2,
rank=1,
)
torch.cuda.set_device(j)
recv = [None]
dist.broadcast_object_list(recv, src=0)
func: Callable
args: List
func, args = recv[0] # type: ignore
# `args[6]` is the device id
# by default pytorch will use `i` from the producer
# here we need to set it to `j` to test P2P access
args[6] = j
data = func(*args)
data += 1
dist.barrier()
torch.cuda.synchronize()
assert torch.all(data == 1).item()


def can_actually_p2p(i, j):

lib = CudaRTLibrary()
for j in batch_js:
lib.cudaSetDevice(j)
handle = producer_queue.get()
open_success = False
try:
pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
open_success = True
except RuntimeError:
# cannot error out here, because the producer process
# is still waiting for the response.
pass
consumer_queue.put(open_success)
if open_success:
lib.cudaMemset(pointer, 2, 1024)
producer_queue.get()
consumer_queue.put(0)
host_data = (ctypes.c_char * 1024)()
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
for i in range(1024):
if ord(host_data[i]) != 2:
open_success = False
break
result_queue.put(open_success)
lib.cudaDeviceReset()


def can_actually_p2p(
batch_is: List[int],
batch_js: List[int],
):
"""
Usually, checking if P2P access is enabled can be done by
`torch.cuda.can_device_access_peer(i, j)`. However, sometimes
@@ -101,30 +106,39 @@ def can_actually_p2p(i, j):
tensor in process j will be reflected in the tensor in process i, because
they are the same memory segment.
It is important to note that process j accesses the tensor in GPU j, not
GPU i. That's why we need p2p access. # noqa
GPU i. That's why we need p2p access.

    The most time-consuming part is process creation. To avoid creating
    processes for every pair of GPUs, we use batched testing: two processes
    test all pairs of GPUs in one pass. The trick is to reset the device
    after each test, an operation that PyTorch does not expose. # noqa
"""
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
    # pass CUDA_VISIBLE_DEVICES to the child processes
    # to make sure they see the same set of GPUs

# make sure the temp file is not the same across different calls
temp_path = tempfile.mktemp() + str(time.time())
# create an empty file
with open(temp_path, "w"):
pass
init_method = f"file://{temp_path}"

    # make sure the child processes are spawned, not forked
    # (forking is unsafe once CUDA has been initialized)
smp = mp.get_context("spawn")
producer_queue = smp.Queue()
consumer_queue = smp.Queue()
result_queue = smp.Queue()
pi = smp.Process(target=producer,
args=(i, init_method, cuda_visible_devices))
args=(batch_is, producer_queue, consumer_queue,
result_queue, cuda_visible_devices))
pj = smp.Process(target=consumer,
args=(j, init_method, cuda_visible_devices))
args=(batch_js, producer_queue, consumer_queue,
result_queue, cuda_visible_devices))
pi.start()
pj.start()
pi.join()
pj.join()
return pi.exitcode == 0 and pj.exitcode == 0
result = []
for i, j in zip(batch_is, batch_js):
a = result_queue.get()
b = result_queue.get()
assert a == b
result.append(a)
return result


# why do we need this cache?
@@ -169,9 +183,15 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
# enter this block to calculate the cache
logger.info("generating GPU P2P access cache in %s", path)
cache = {}
batch_is = []
batch_js = []
for _i in range(num_dev):
for _j in range(num_dev):
cache[f"{_i}->{_j}"] = can_actually_p2p(_i, _j)
batch_is.append(_i)
batch_js.append(_j)
result = can_actually_p2p(batch_is, batch_js)
for _i, _j, r in zip(batch_is, batch_js, result):
cache[f"{_i}->{_j}"] = r
with open(path, "w") as f:
json.dump(cache, f, indent=4)
if is_distributed:
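
For illustration, a minimal sketch of driving the batched check directly. The 2-GPU count is an assumption; the pair enumeration mirrors the loop in `gpu_p2p_access_check`:

from vllm.distributed.device_communicators.custom_all_reduce_utils import (
    can_actually_p2p)

if __name__ == "__main__":  # guard needed because the check spawns processes
    num_dev = 2  # assumed number of visible GPUs
    # enumerate every ordered (i, j) pair once, then test all pairs
    # with a single producer/consumer process pair
    batch_is = [i for i in range(num_dev) for _ in range(num_dev)]
    batch_js = [j for _ in range(num_dev) for j in range(num_dev)]
    result = can_actually_p2p(batch_is, batch_js)
    cache = {f"{i}->{j}": r for i, j, r in zip(batch_is, batch_js, result)}
    # e.g. {"0->0": True, "0->1": True, "1->0": True, "1->1": True}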