
Commit 8a1e046

[Core][AMD] Migrate fully transparent sleep mode to ROCm platform
Signed-off-by: Hollow Man <[email protected]>
1 parent 1769928 commit 8a1e046

File tree: 8 files changed (+154, -11 lines)

CMakeLists.txt

Lines changed: 8 additions & 3 deletions

@@ -208,10 +208,15 @@ set_gencode_flags_for_srcs(
   SRCS "${VLLM_CUMEM_EXT_SRC}"
   CUDA_ARCHS "${CUDA_ARCHS}")
 
-if(VLLM_GPU_LANG STREQUAL "CUDA")
+if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
   message(STATUS "Enabling cumem allocator extension.")
-  # link against cuda driver library
-  list(APPEND CUMEM_LIBS CUDA::cuda_driver)
+  if(VLLM_GPU_LANG STREQUAL "CUDA")
+    # link against cuda driver library
+    list(APPEND CUMEM_LIBS CUDA::cuda_driver)
+  else()
+    # link against rocm driver library
+    list(APPEND CUMEM_LIBS amdhip64)
+  endif()
   define_gpu_extension_target(
     cumem_allocator
     DESTINATION vllm

csrc/cumem_allocator.cpp

Lines changed: 2 additions & 2 deletions

@@ -3,14 +3,14 @@
 // need to be unsigned long long
 #include <iostream>
 
+#include "cumem_allocator_compat.h"
+
 extern "C" {
 
 #define PY_SSIZE_T_CLEAN
 #include <Python.h>
 
 #include <sys/types.h>
-#include <cuda_runtime_api.h>
-#include <cuda.h>
 
 char error_msg[10240];  // 10KB buffer to store error messages
 CUresult no_error = CUresult(0);

csrc/cumem_allocator_compat.h

Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
+#pragma once
+
+#ifdef USE_ROCM
+////////////////////////////////////////
+// For compatibility with CUDA and ROCm
+////////////////////////////////////////
+#include <hip/hip_runtime_api.h>
+
+extern "C" {
+#ifndef CUDA_SUCCESS
+#define CUDA_SUCCESS hipSuccess
+#endif  // CUDA_SUCCESS
+
+// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html
+typedef unsigned long long CUdevice;
+typedef hipDeviceptr_t CUdeviceptr;
+typedef hipError_t CUresult;
+typedef hipCtx_t CUcontext;
+typedef hipStream_t CUstream;
+typedef hipMemGenericAllocationHandle_t CUmemGenericAllocationHandle;
+typedef hipMemAllocationGranularity_flags CUmemAllocationGranularity_flags;
+typedef hipMemAllocationProp CUmemAllocationProp;
+typedef hipMemAccessDesc CUmemAccessDesc;
+
+#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
+#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
+#define CU_MEM_ALLOC_GRANULARITY_MINIMUM hipMemAllocationGranularityMinimum
+
+// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
+#define CU_MEM_ALLOCATION_COMP_NONE 0x0
+
+// Error Handling
+// https://docs.nvidia.com/cuda/archive/11.4.4/cuda-driver-api/group__CUDA__ERROR.html
+CUresult cuGetErrorString(CUresult hipError, const char** pStr) {
+  *pStr = hipGetErrorString(hipError);
+  return CUDA_SUCCESS;
+}
+
+// Context Management
+// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html
+CUresult cuCtxGetCurrent(CUcontext* ctx) {
+  // This API is deprecated on the AMD platform, only for equivalent cuCtx
+  // driver API on the NVIDIA platform.
+  return hipCtxGetCurrent(ctx);
+}
+
+CUresult cuCtxSetCurrent(CUcontext ctx) {
+  // This API is deprecated on the AMD platform, only for equivalent cuCtx
+  // driver API on the NVIDIA platform.
+  return hipCtxSetCurrent(ctx);
+}
+
+// Primary Context Management
+// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PRIMARY__CTX.html
+CUresult cuDevicePrimaryCtxRetain(CUcontext* ctx, CUdevice dev) {
+  return hipDevicePrimaryCtxRetain(ctx, dev);
+}
+
+// Virtual Memory Management
+// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html
+CUresult cuMemAddressFree(CUdeviceptr ptr, size_t size) {
+  return hipMemAddressFree(ptr, size);
+}
+
+CUresult cuMemAddressReserve(CUdeviceptr* ptr, size_t size, size_t alignment,
+                             CUdeviceptr addr, unsigned long long flags) {
+  return hipMemAddressReserve(ptr, size, alignment, addr, flags);
+}
+
+CUresult cuMemCreate(CUmemGenericAllocationHandle* handle, size_t size,
+                     const CUmemAllocationProp* prop,
+                     unsigned long long flags) {
+  return hipMemCreate(handle, size, prop, flags);
+}
+
+CUresult cuMemGetAllocationGranularity(
+    size_t* granularity, const CUmemAllocationProp* prop,
+    CUmemAllocationGranularity_flags option) {
+  return hipMemGetAllocationGranularity(granularity, prop, option);
+}
+
+CUresult cuMemMap(CUdeviceptr dptr, size_t size, size_t offset,
+                  CUmemGenericAllocationHandle handle,
+                  unsigned long long flags) {
+  return hipMemMap(dptr, size, offset, handle, flags);
+}
+
+CUresult cuMemRelease(CUmemGenericAllocationHandle handle) {
+  return hipMemRelease(handle);
+}
+
+CUresult cuMemSetAccess(CUdeviceptr ptr, size_t size,
+                        const CUmemAccessDesc* desc, size_t count) {
+  return hipMemSetAccess(ptr, size, desc, count);
+}
+
+CUresult cuMemUnmap(CUdeviceptr ptr, size_t size) {
+  return hipMemUnmap(ptr, size);
+}
+}  // extern "C"
+
+#else
+////////////////////////////////////////
+// Import CUDA headers for NVIDIA GPUs
+////////////////////////////////////////
+#include <cuda_runtime_api.h>
+#include <cuda.h>
+#endif
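
Note: the header above aliases the CUDA driver-API symbols (cuGetErrorString, cuMemCreate, cuMemMap, ...) to their HIP equivalents, so csrc/cumem_allocator.cpp compiles unchanged on ROCm. The same one-to-one correspondence can be checked from Python with ctypes. A minimal sketch, assuming a standard ROCm install where libamdhip64.so is on the loader path (the same "amdhip64" library the CMake change links against):

import ctypes

# Load the HIP runtime library (name assumed from standard ROCm installs).
libhip = ctypes.CDLL("libamdhip64.so")

# hipGetErrorString(hipError_t) -> const char*, the HIP counterpart that
# cumem_allocator_compat.h wraps as cuGetErrorString.
libhip.hipGetErrorString.restype = ctypes.c_char_p
libhip.hipGetErrorString.argtypes = [ctypes.c_int]

print(libhip.hipGetErrorString(0))  # expected: b'hipSuccess'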

setup.py

Lines changed: 1 addition & 0 deletions

@@ -627,6 +627,7 @@ def _read_requirements(filename: str) -> list[str]:
 
 if _is_cuda() or _is_hip():
     ext_modules.append(CMakeExtension(name="vllm._moe_C"))
+    ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
 
 if _is_hip():
     ext_modules.append(CMakeExtension(name="vllm._rocm_C"))

vllm/config.py

Lines changed: 4 additions & 2 deletions

@@ -321,8 +321,10 @@ def __init__(
 
         from vllm.platforms import current_platform
 
-        if self.enable_sleep_mode and not current_platform.is_cuda():
-            raise ValueError("Sleep mode is only supported on CUDA devices.")
+        if self.enable_sleep_mode and not (current_platform.is_cuda()
+                                           or current_platform.is_rocm()):
+            raise ValueError(
+                "Sleep mode is only supported on CUDA/ROCM devices.")
 
         hf_config = get_config(self.hf_config_path or self.model,
                                trust_remote_code, revision, code_revision,
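
With the check relaxed, enabling sleep mode on an AMD GPU now passes validation. A minimal usage sketch — LLM(enable_sleep_mode=...) and the sleep()/wake_up() methods are vLLM's existing entry points for this feature; running on ROCm hardware with this commit applied is the assumption:

from vllm import LLM

# Before this commit, this constructor raised
# "Sleep mode is only supported on CUDA devices." on ROCm.
llm = LLM(model="facebook/opt-125m", enable_sleep_mode=True)

llm.sleep(level=1)  # offload weights to CPU, discard KV cache
llm.wake_up()       # map memory back and reload the weights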

vllm/device_allocator/cumem.py

Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ def find_loaded_library(lib_name) -> Optional[str]:
     libcudart = CudaRTLibrary()
     cumem_available = True
 except ModuleNotFoundError:
-    # rocm platform does not support cumem allocator
+    # only cuda and rocm platforms support cumem allocator
     init_module = None
     python_create_and_map = None
     python_unmap_and_release = None
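
For context, the changed comment sits in a module-level try/except that degrades gracefully when the compiled extension is absent. A condensed sketch of that availability pattern, built from the names visible in this hunk (the exact surrounding code in cumem.py may differ):

try:
    # vllm.cumem_allocator is the C extension registered in setup.py above;
    # the import fails on platforms where it was not built.
    from vllm.cumem_allocator import (init_module, python_create_and_map,
                                      python_unmap_and_release)
    from vllm.distributed.device_communicators.cuda_wrapper import (
        CudaRTLibrary)
    libcudart = CudaRTLibrary()
    cumem_available = True
except ModuleNotFoundError:
    # only cuda and rocm platforms support cumem allocator
    init_module = None
    python_create_and_map = None
    python_unmap_and_release = None
    libcudart = None
    cumem_available = False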

vllm/distributed/device_communicators/cuda_wrapper.py

Lines changed: 28 additions & 2 deletions

@@ -95,6 +95,20 @@ class CudaRTLibrary:
         ]),
     ]
 
+    # https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Runtime_API_functions_supported_by_HIP.html # noqa
+    cuda_to_hip_mapping = {
+        "cudaSetDevice": "hipSetDevice",
+        "cudaDeviceSynchronize": "hipDeviceSynchronize",
+        "cudaDeviceReset": "hipDeviceReset",
+        "cudaGetErrorString": "hipGetErrorString",
+        "cudaMalloc": "hipMalloc",
+        "cudaFree": "hipFree",
+        "cudaMemset": "hipMemset",
+        "cudaMemcpy": "hipMemcpy",
+        "cudaIpcGetMemHandle": "hipIpcGetMemHandle",
+        "cudaIpcOpenMemHandle": "hipIpcOpenMemHandle",
+    }
+
     # class attribute to store the mapping from the path to the library
     # to avoid loading the same library multiple times
     path_to_library_cache: Dict[str, Any] = {}
@@ -103,11 +117,21 @@ class CudaRTLibrary:
     # to the corresponding dictionary
     path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
 
+    # check if the current process is using ROCm
+    is_rocm = False
+
     def __init__(self, so_file: Optional[str] = None):
         if so_file is None:
             so_file = find_loaded_library("libcudart")
             if so_file is None:
-                so_file = envs.VLLM_CUDART_SO_PATH  # fallback to env var
+                # libcudart is not loaded in the current process, try hip
+                so_file = find_loaded_library("libamdhip64")
+                # should be safe to assume now that we are using ROCm
+                # as the following assertion should error out if the
+                # libhiprtc library is also not loaded
+                self.is_rocm = True
+            if so_file is None:
+                so_file = envs.VLLM_CUDART_SO_PATH  # fallback to env var
         assert so_file is not None, \
             (
                 "libcudart is not loaded in the current process, "
@@ -121,7 +145,9 @@ def __init__(self, so_file: Optional[str] = None):
         if so_file not in CudaRTLibrary.path_to_dict_mapping:
             _funcs = {}
             for func in CudaRTLibrary.exported_functions:
-                f = getattr(self.lib, func.name)
+                f = getattr(
+                    self.lib, CudaRTLibrary.cuda_to_hip_mapping[func.name]
+                    if self.is_rocm else func.name)
                 f.restype = func.restype
                 f.argtypes = func.argtypes
                 _funcs[func.name] = f
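
The three hunks above keep a single CUDA-named function table and only redirect symbol lookup when the HIP runtime turned out to be the loaded library. The same pattern in isolation — a toy, self-contained sketch, not the vLLM class itself (the Function record is a stand-in for the one cuda_wrapper.py defines):

import ctypes
from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class Function:
    name: str
    restype: Any
    argtypes: List[Any]

# The API is declared once, under CUDA names.
exported_functions = [
    Function("cudaGetErrorString", ctypes.c_char_p, [ctypes.c_int]),
]
cuda_to_hip_mapping = {"cudaGetErrorString": "hipGetErrorString"}

def load_runtime(so_file: str, is_rocm: bool) -> Dict[str, Any]:
    lib = ctypes.CDLL(so_file)
    funcs = {}
    for func in exported_functions:
        # On ROCm, resolve the HIP symbol but keep the CUDA name as the
        # dictionary key, so callers stay platform-agnostic.
        f = getattr(lib, cuda_to_hip_mapping[func.name]
                    if is_rocm else func.name)
        f.restype = func.restype
        f.argtypes = func.argtypes
        funcs[func.name] = f
    return funcs

# e.g. load_runtime("libamdhip64.so", is_rocm=True)["cudaGetErrorString"](0)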

vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion

@@ -1042,7 +1042,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             action="store_true",
                             default=False,
                             help="Enable sleep mode for the engine. "
-                            "(only cuda platform is supported)")
+                            "(only cuda and hip platforms are supported)")
 
     parser.add_argument(
         '--calculate-kv-scales',
