Commit d721168

Improve setup script & Add a guard for bfloat16 kernels (#130)
1 parent 4a151dd commit d721168

4 files changed: +90 -16 lines changed

csrc/attention/attention_dtypes.h

Lines changed: 0 additions & 3 deletions

@@ -3,7 +3,4 @@
 #include "attention_generic.cuh"
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
-
-#ifdef ENABLE_BF16
 #include "dtype_bfloat16.cuh"
-#endif // ENABLE_BF16

csrc/attention/attention_kernels.cu

Lines changed: 0 additions & 2 deletions

@@ -458,10 +458,8 @@ void single_query_cached_kv_attention(
   // TODO(woosuk): Support FP32.
   if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
-#ifdef ENABLE_BF16
   } else if (query.dtype() == at::ScalarType::BFloat16) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
-#endif
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
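
With the #ifdef ENABLE_BF16 guards removed, the host-side dtype dispatch always compiles a BFloat16 branch; whether the target GPU can actually execute bf16 math is now checked inside the device functions (see dtype_bfloat16.cuh below). A minimal sketch of this dispatch shape, with a hypothetical launch_attention<T>() stub standing in for the CALL_KERNEL_LAUNCHER_BLOCK_SIZE macro:

    #include <cstdint>
    #include <cuda_bf16.h>
    #include <torch/extension.h>

    // Hypothetical stub standing in for CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T).
    template <typename scalar_t>
    void launch_attention(const torch::Tensor& query) { /* launch kernels specialized for scalar_t */ }

    void dispatch_attention(const torch::Tensor& query) {
      if (query.dtype() == at::ScalarType::Half) {
        launch_attention<uint16_t>(query);
      } else if (query.dtype() == at::ScalarType::BFloat16) {
        // Always compiled in; pre-sm_80 GPUs hit the __CUDA_ARCH__ guards at run time.
        launch_attention<__nv_bfloat16>(query);
      } else {
        TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
      }
    }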

csrc/attention/dtype_bfloat16.cuh

Lines changed: 44 additions & 0 deletions

@@ -78,20 +78,36 @@ struct FloatVec<bf16_8_t> {
 
 // Utility functions for type conversions.
 inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __bfloat1622float2(val);
+#endif
 }
 
 inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __bfloat162bfloat162(val);
+#endif
 }
 
 // Vector addition.
 inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return a + b;
+#endif
 }
 
 inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hadd2(a, b);
+#endif
 }
 
 inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {

@@ -134,12 +150,20 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
 // Vector multiplication.
 template<>
 inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hmul(a, b);
+#endif
 }
 
 template<>
 inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hmul2(a, b);
+#endif
 }
 
 template<>

@@ -244,11 +268,19 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
 
 // Vector fused multiply-add.
 inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hfma2(a, b, c);
+#endif
 }
 
 inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   return __hfma2(bf162bf162(a), b, c);
+#endif
 }
 
 inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {

@@ -361,19 +393,31 @@ inline __device__ void from_float(__nv_bfloat16& dst, float src) {
 }
 
 inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst = __float22bfloat162_rn(src);
+#endif
 }
 
 inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst.x = __float22bfloat162_rn(src.x);
   dst.y = __float22bfloat162_rn(src.y);
+#endif
 }
 
 inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   dst.x = __float22bfloat162_rn(src.x);
   dst.y = __float22bfloat162_rn(src.y);
   dst.z = __float22bfloat162_rn(src.z);
   dst.w = __float22bfloat162_rn(src.w);
+#endif
 }
 
 } // namespace cacheflow
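
Every guard added in this file follows the same pattern: the bfloat16 intrinsics used here are not usable in device code compiled for architectures below 8.0, so for those targets the function body is replaced by assert(false). The file still compiles into a fatbin that includes sm_70/sm_75, and the trap only fires if bf16 code is actually executed on such a GPU. A minimal self-contained sketch of the pattern (illustrative, not taken from the repo):

    #include <cassert>
    #include <cuda_bf16.h>

    // Compiles for every target architecture; the bf16 intrinsic is only emitted
    // for sm_80+, while older architectures get a device-side assert instead.
    inline __device__ __nv_bfloat16 guarded_add(__nv_bfloat16 a, __nv_bfloat16 b) {
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
      assert(false);           // reached only if bf16 code runs on sm_70/sm_75
      return __nv_bfloat16();  // unreachable; keeps the compiler satisfied
    #else
      return __hadd(a, b);
    #endif
    }

    __global__ void bf16_add_kernel(const __nv_bfloat16* x, const __nv_bfloat16* y,
                                    __nv_bfloat16* out, int n) {
      const int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = guarded_add(x[i], y[i]);
    }

This per-function guard is what lets setup.py (below) drop the file-level -DENABLE_BF16 switch and build one extension covering several architectures at once.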

setup.py

Lines changed: 46 additions & 11 deletions

@@ -1,28 +1,63 @@
-from typing import List
+import subprocess
+from typing import List, Set
 
+from packaging.version import parse, Version
 import setuptools
 import torch
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 from torch.utils.cpp_extension import CUDA_HOME
 
-
-# Build custom operators.
-CXX_FLAGS = ["-g"]
+# Compiler flags.
+CXX_FLAGS = ["-g", "-O2"]
 # TODO(woosuk): Should we use -O3?
 NVCC_FLAGS = ["-O2"]
 
+
 if not torch.cuda.is_available():
     raise RuntimeError(
         f"Cannot find CUDA at CUDA_HOME: {CUDA_HOME}. "
         "CUDA must be available in order to build the package.")
 
-# FIXME(woosuk): Consider the case where the machine has multiple GPUs with
-# different compute capabilities.
-compute_capability = torch.cuda.get_device_capability()
-major, minor = compute_capability
-# Enable bfloat16 support if the compute capability is >= 8.0.
-if major >= 8:
-    NVCC_FLAGS.append("-DENABLE_BF16")
+
+def get_nvcc_cuda_version(cuda_dir: str) -> Version:
+    """Get the CUDA version from nvcc.
+
+    Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
+    """
+    nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+                                          universal_newlines=True)
+    output = nvcc_output.split()
+    release_idx = output.index("release") + 1
+    nvcc_cuda_version = parse(output[release_idx].split(",")[0])
+    return nvcc_cuda_version
+
+
+# Collect the compute capabilities of all available GPUs.
+device_count = torch.cuda.device_count()
+compute_capabilities: Set[int] = set()
+for i in range(device_count):
+    major, minor = torch.cuda.get_device_capability(i)
+    if major < 7:
+        raise RuntimeError(
+            "GPUs with compute capability less than 7.0 are not supported.")
+    compute_capabilities.add(major * 10 + minor)
+# If no GPU is available, add all supported compute capabilities.
+if not compute_capabilities:
+    compute_capabilities = {70, 75, 80, 86, 90}
+# Add target compute capabilities to NVCC flags.
+for capability in compute_capabilities:
+    NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
+
+# Validate the NVCC CUDA version.
+nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
+if nvcc_cuda_version < Version("11.0"):
+    raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
+if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
+    raise RuntimeError(
+        "CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
+if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
+    raise RuntimeError(
+        "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
 
 ext_modules = []
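
On the build side, bfloat16 support no longer hinges on a single -DENABLE_BF16 define: the extension is compiled for every compute capability detected on the build machine, or for all of 7.0, 7.5, 8.0, 8.6, and 9.0 when no GPU is visible. As a worked example of the flag expansion above, on a machine whose only GPU reports compute capability 8.0, NVCC_FLAGS ends up as ["-O2", "-gencode", "arch=compute_80,code=sm_80"]. The trailing checks then ensure the local nvcc is new enough for whichever architectures were requested: CUDA 11.0 as a baseline, 11.1 if 8.6 is targeted, and 11.8 if 9.0 is targeted.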
