
Commit 62936e3

maktukmak authored and bigPYJ1151 committed
CPU only build (vllm-project#9)
1 parent e20ae23 commit 62936e3

11 files changed: +185 −36 lines changed

Makefile

Lines changed: 3 additions & 0 deletions

@@ -28,6 +28,9 @@ sanitizer:
 py_install:
 	VLLM_BUILD_CPU_OPS=1 MAX_JOBS=JOBS pip install --no-build-isolation -v -e .
 
+py_install_cpu:
+	VLLM_BUILD_CPU_ONLY=1 MAX_JOBS=JOBS pip install --no-build-isolation -v -e .
+
 package:
 	VLLM_BUILD_CPU_OPS=1 MAX_JOBS=JOBS python setup.py bdist_wheel
 	echo "Wheel package is saved in ./dist/"

cpu.Dockerfile

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
+FROM python:3.10 AS dev
+
+RUN apt-get update -y \
+    && apt-get install -y python3-pip
+
+WORKDIR /workspace
+
+# install build and runtime dependencies
+COPY requirements-cpu.txt requirements-cpu.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install -r requirements-cpu.txt
+
+# install development dependencies
+COPY requirements-dev.txt requirements-dev.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install -r requirements-dev.txt
+
+# image to build pytorch extensions
+FROM dev AS build
+
+# install build dependencies
+COPY requirements-build-cpu.txt requirements-build-cpu.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install -r requirements-build-cpu.txt
+
+# copy input files
+COPY csrc csrc
+COPY setup.py setup.py
+COPY requirements-cpu.txt requirements-cpu.txt
+COPY pyproject.toml pyproject.toml
+COPY vllm/__init__.py vllm/__init__.py
+
+# max jobs used by Ninja to build extensions
+ENV MAX_JOBS=$max_jobs
+RUN python3 setup.py build_ext --inplace
+
+# image to run unit testing suite
+FROM dev AS test
+
+# copy pytorch extensions separately to avoid having to rebuild
+# when python code changes
+COPY --from=build /workspace/vllm/*.so /workspace/vllm/
+COPY tests tests
+COPY vllm vllm
+
+ENTRYPOINT ["python3", "-m", "pytest", "tests"]
+
+# runtime base image (plain Python base; the CPU build needs no CUDA runtime)
+FROM python:3.10 AS vllm-base
+
+# python3-pip is needed to install the runtime requirements
+RUN apt-get update -y \
+    && apt-get install -y python3-pip
+
+WORKDIR /workspace
+COPY requirements-cpu.txt requirements-cpu.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install -r requirements-cpu.txt
+
+FROM vllm-base AS vllm
+COPY --from=build /workspace/vllm/*.so /workspace/vllm/
+COPY vllm vllm
+
+EXPOSE 8000
+ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]
+
+# openai api server alternative
+FROM vllm-base AS vllm-openai
+# install additional dependencies for openai api server
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install accelerate fschat
+
+COPY --from=build /workspace/vllm/*.so /workspace/vllm/
+COPY vllm vllm
+
+ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
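Note that the runtime stage originally reused the stage name dev and carried stale comments copied from the GPU Dockerfile ("use CUDA base", "libnccl required for ray"), even though the later FROM vllm-base stages would then reference a stage that was never defined; the diff above names that stage vllm-base and corrects the comments. Also, ENV MAX_JOBS=$max_jobs references $max_jobs with no ARG max_jobs declared anywhere in the file, so MAX_JOBS resolves to an empty string unless such an ARG is added. A hedged smoke test against a running container (assumes docker run -p 8000:8000 with a model configured; the /generate endpoint and payload fields follow vLLM's demo api_server of this era and should be treated as assumptions):

    # Hypothetical smoke test; endpoint and field names are assumptions.
    import json
    import urllib.request

    payload = {"prompt": "Hello, my name is", "max_tokens": 16}
    req = urllib.request.Request(
        "http://localhost:8000/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        print(json.loads(resp.read()))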

csrc/dispatch_utils.h

Lines changed: 4 additions & 0 deletions

@@ -14,10 +14,14 @@
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
+#ifdef VLLM_BUILD_CPU_ONLY
+#define VLLM_DISPATCH_TO_CUDA_CASE(BASENAME, ...)
+#else
 #define VLLM_DISPATCH_TO_CUDA_CASE(BASENAME, ...) \
   case c10::DeviceType::CUDA: { \
     return BASENAME(__VA_ARGS__); \
   }
+#endif
 
 #ifdef VLLM_BUILD_CPU_OPS
 #define VLLM_DISPATCH_TO_CPU_CASE(BASENAME, ...) \
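With VLLM_BUILD_CPU_ONLY defined, VLLM_DISPATCH_TO_CUDA_CASE expands to nothing, so the generated device switch simply has no CUDA arm. A conceptual Python analogue of that dispatch (names are illustrative, not vLLM's actual API):

    import torch

    CPU_ONLY = True  # stands in for the VLLM_BUILD_CPU_ONLY compile-time flag

    def dispatch(x: torch.Tensor, cpu_impl, cuda_impl=None):
        # Route by device, like VLLM_DISPATCH_DEVICES; in a CPU-only
        # build the CUDA arm does not exist at all.
        if x.device.type == "cpu":
            return cpu_impl(x)
        if not CPU_ONLY and cuda_impl is not None and x.device.type == "cuda":
            return cuda_impl(x)
        raise RuntimeError(f"no kernel compiled for device {x.device}")

    print(dispatch(torch.tensor([-1.0, 2.0]), cpu_impl=torch.relu))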

csrc/pybind.cpp

Lines changed: 6 additions & 0 deletions

@@ -87,6 +87,12 @@ void gptq_shuffle_dispatch(
   VLLM_DISPATCH_DEVICES(q_weight.device(), gptq_shuffle, q_weight, q_perm);
 }
 
+#ifdef VLLM_BUILD_CPU_ONLY
+int get_device_attribute(
+  int attribute,
+  int device_id) { return 94387; }
+#endif
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // vLLM custom ops
   pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
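The stub keeps the get_device_attribute binding linkable in a CPU-only build; the hard-coded 94387 is an arbitrary placeholder, not a real device attribute. Callers should therefore branch on device availability rather than trust the returned value. A hedged sketch of such a guard (the raw_query callable is a stand-in for however the binding is exposed):

    import torch

    def query_device_attribute(raw_query, attribute: int, device_id: int = 0) -> int:
        # In a CPU-only build the C++ stub returns a fixed placeholder,
        # so never interpret it as a real attribute value.
        if not torch.cuda.is_available():
            return 0
        return raw_query(attribute, device_id)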

Dockerfile renamed to gpu.Dockerfile

Lines changed: 7 additions & 7 deletions

@@ -6,9 +6,9 @@ RUN apt-get update -y \
 WORKDIR /workspace
 
 # install build and runtime dependencies
-COPY requirements.txt requirements.txt
+COPY requirements-gpu.txt requirements-gpu.txt
 RUN --mount=type=cache,target=/root/.cache/pip \
-    pip install -r requirements.txt
+    pip install -r requirements-gpu.txt
 
 # install development dependencies
 COPY requirements-dev.txt requirements-dev.txt
@@ -19,14 +19,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \
 FROM dev AS build
 
 # install build dependencies
-COPY requirements-build.txt requirements-build.txt
+COPY requirements-build-gpu.txt requirements-build-gpu.txt
 RUN --mount=type=cache,target=/root/.cache/pip \
-    pip install -r requirements-build.txt
+    pip install -r requirements-build-gpu.txt
 
 # copy input files
 COPY csrc csrc
 COPY setup.py setup.py
-COPY requirements.txt requirements.txt
+COPY requirements-gpu.txt requirements-gpu.txt
 COPY pyproject.toml pyproject.toml
 COPY vllm/__init__.py vllm/__init__.py
 
@@ -60,9 +60,9 @@ RUN apt-get update -y \
     && apt-get install -y python3-pip
 
 WORKDIR /workspace
-COPY requirements.txt requirements.txt
+COPY requirements-gpu.txt requirements-gpu.txt
 RUN --mount=type=cache,target=/root/.cache/pip \
-    pip install -r requirements.txt
+    pip install -r requirements-gpu.txt
 
 FROM vllm-base AS vllm
 COPY --from=build /workspace/vllm/*.so /workspace/vllm/

requirements-build-cpu.txt

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+# Should be mirrored in pyproject.toml
+ninja
+packaging
+setuptools>=49.4.0
+torch @ https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.1.0%2Bcpu.cxx11.abi-cp310-cp310-linux_x86_64.whl#sha256=88f1ee550c6291af8d0417871fb7af84b86527d18bc02ac4249f07dcd84dda56  # 2.1.0+cpu
+wheel

requirements-build.txt renamed to requirements-build-gpu.txt

File renamed without changes.

requirements-cpu.txt

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+ninja  # For faster builds.
+psutil
+ray >= 2.5.1
+pandas  # Required for Ray data.
+pyarrow  # Required for Ray data.
+sentencepiece  # Required for LLaMA tokenizer.
+numpy
+einops  # Required for phi-1_5
+torch @ https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.1.0%2Bcpu.cxx11.abi-cp310-cp310-linux_x86_64.whl#sha256=88f1ee550c6291af8d0417871fb7af84b86527d18bc02ac4249f07dcd84dda56  # 2.1.0+cpu
+transformers >= 4.34.0  # Required for Mistral.
+fastapi
+uvicorn[standard]
+pydantic == 1.10.13  # Required for OpenAI server.
+aioprometheus[starlette]
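Both new requirements files pin torch to the same CPU-only, cxx11-ABI wheel via a PEP 508 direct URL, which keeps the build-time and runtime torch ABIs consistent. A quick hedged check that the installed torch really is that CPU wheel:

    import torch

    # The pinned wheel reports a +cpu local version and no CUDA runtime.
    assert torch.__version__.startswith("2.1.0+cpu"), torch.__version__
    assert torch.version.cuda is None
    print("CPU-only torch confirmed:", torch.__version__)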

requirements.txt renamed to requirements-gpu.txt

File renamed without changes.

setup.py

Lines changed: 66 additions & 29 deletions
@@ -8,7 +8,13 @@
 from packaging.version import parse, Version
 import setuptools
 import torch
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
+
+BUILD_CPU_ONLY = os.getenv('VLLM_BUILD_CPU_ONLY', "0") == "1"
+
+if not BUILD_CPU_ONLY:
+    from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
+else:
+    from torch.utils.cpp_extension import BuildExtension, CppExtension
 
 ROOT_DIR = os.path.dirname(__file__)
 
@@ -21,11 +27,11 @@
 
 
 def _is_hip() -> bool:
-    return torch.version.hip is not None
+    return torch.version.hip is not None and not BUILD_CPU_ONLY
 
 
 def _is_cuda() -> bool:
-    return torch.version.cuda is not None
+    return torch.version.cuda is not None and not BUILD_CPU_ONLY
 
 
 # Compiler flags.
@@ -86,7 +92,6 @@ def get_hipcc_rocm_version():
         print("Could not find HIP version in the output")
         return None
 
-
 def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     """Get the CUDA version from nvcc.
 
@@ -137,6 +142,19 @@ def get_torch_arch_list() -> Set[str]:
             stacklevel=2)
     return arch_list
 
+if not BUILD_CPU_ONLY:
+    # First, check the TORCH_CUDA_ARCH_LIST environment variable.
+    compute_capabilities = get_torch_arch_list()
+    if not compute_capabilities:
+        # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
+        # GPUs on the current machine.
+        device_count = torch.cuda.device_count()
+        for i in range(device_count):
+            major, minor = torch.cuda.get_device_capability(i)
+            if major < 7:
+                raise RuntimeError(
+                    "GPUs with compute capability below 7.0 are not supported.")
+            compute_capabilities.add(f"{major}.{minor}")
 
 # First, check the TORCH_CUDA_ARCH_LIST environment variable.
 compute_capabilities = get_torch_arch_list()
@@ -211,9 +229,11 @@ def get_torch_arch_list() -> Set[str]:
             f"amdgpu_arch_found: {amd_arch}")
 
 # Setup CPU Operations
-BUILD_CPU_OPS = os.getenv('VLLM_BUILD_CPU_OPS', "0") == "1"
+BUILD_CPU_OPS = (os.getenv('VLLM_BUILD_CPU_OPS', "0") == "1" or BUILD_CPU_ONLY)
 CPU_OPS_SOURCES = []
 if BUILD_CPU_OPS:
+    if BUILD_CPU_ONLY:
+        CXX_FLAGS += ["-DVLLM_BUILD_CPU_ONLY"]
     CXX_FLAGS += [
         "-DVLLM_BUILD_CPU_OPS", "-fopenmp", "-mavx512f", "-mavx512bf16",
         "-mavx512vl"
@@ -228,29 +248,42 @@ def get_torch_arch_list() -> Set[str]:
 
 ext_modules = []
 
-vllm_extension_sources = [
-    "csrc/cache_kernels.cu",
-    "csrc/attention/attention_kernels.cu",
-    "csrc/pos_encoding_kernels.cu",
-    "csrc/activation_kernels.cu",
-    "csrc/layernorm_kernels.cu",
-    "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
-    "csrc/quantization/gptq/q_gemm.cu",
-    "csrc/cuda_utils_kernels.cu",
-    "csrc/pybind.cpp",
-] + CPU_OPS_SOURCES
+if not BUILD_CPU_ONLY:
+    vllm_extension_sources = [
+        "csrc/cache_kernels.cu",
+        "csrc/attention/attention_kernels.cu",
+        "csrc/pos_encoding_kernels.cu",
+        "csrc/activation_kernels.cu",
+        "csrc/layernorm_kernels.cu",
+        "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
+        "csrc/quantization/gptq/q_gemm.cu",
+        "csrc/cuda_utils_kernels.cu",
+        "csrc/pybind.cpp",
+    ] + CPU_OPS_SOURCES
+
+    if _is_cuda():
+        vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
+
+    vllm_extension = CUDAExtension(
+        name="vllm._C",
+        sources=vllm_extension_sources,
+        extra_compile_args={
+            "cxx": CXX_FLAGS,
+            "nvcc": NVCC_FLAGS,
+        },
+    )
+else:
+    vllm_extension_sources = [
+        "csrc/pybind.cpp",
+    ] + CPU_OPS_SOURCES
+    vllm_extension = CppExtension(
+        name="vllm._C",
+        sources=vllm_extension_sources,
+        extra_compile_args={
+            "cxx": CXX_FLAGS,
+        },
+    )
 
-if _is_cuda():
-    vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
-
-vllm_extension = CUDAExtension(
-    name="vllm._C",
-    sources=vllm_extension_sources,
-    extra_compile_args={
-        "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
-    },
-)
 ext_modules.append(vllm_extension)
 
 
@@ -280,7 +313,7 @@ def get_vllm_version() -> str:
         if hipcc_version != MAIN_CUDA_VERSION:
             rocm_version_str = hipcc_version.replace(".", "")[:3]
             version += f"+rocm{rocm_version_str}"
-    else:
+    elif _is_cuda():
         cuda_version = str(nvcc_cuda_version)
         if cuda_version != MAIN_CUDA_VERSION:
             cuda_version_str = cuda_version.replace(".", "")[:3]
@@ -303,9 +336,13 @@ def get_requirements() -> List[str]:
     if _is_hip():
         with open(get_path("requirements-rocm.txt")) as f:
             requirements = f.read().strip().split("\n")
+    elif _is_cuda():
+        with open(get_path("requirements-gpu.txt")) as f:
+            requirements = f.read().strip().split("\n")
     else:
-        with open(get_path("requirements.txt")) as f:
+        with open(get_path("requirements-cpu.txt")) as f:
             requirements = f.read().strip().split("\n")
+
     return requirements
 
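Taken together, one setup.py now serves three targets, ROCm, CUDA, and CPU-only, with BUILD_CPU_ONLY masking the GPU probes. A condensed sketch of the resulting selection logic (simplified from the hunks above):

    import os
    import torch

    BUILD_CPU_ONLY = os.getenv("VLLM_BUILD_CPU_ONLY", "0") == "1"

    def _is_hip() -> bool:
        return torch.version.hip is not None and not BUILD_CPU_ONLY

    def _is_cuda() -> bool:
        return torch.version.cuda is not None and not BUILD_CPU_ONLY

    def requirements_file() -> str:
        if _is_hip():
            return "requirements-rocm.txt"
        if _is_cuda():
            return "requirements-gpu.txt"
        return "requirements-cpu.txt"

    print(requirements_file())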
