
Commit 4115939

jikunshang and bigPYJ1151 authored and committed
Remove hardcoded device="cuda" to support more devices (vllm-project#2503)
Co-authored-by: Jiang Li <[email protected]>
Co-authored-by: Kunshang Ji <[email protected]>
1 parent 3b1644e commit 4115939

32 files changed (+353, -301 lines)

benchmarks/benchmark_latency.py

Lines changed: 7 additions & 0 deletions
@@ -25,6 +25,7 @@ def main(args: argparse.Namespace):
         dtype=args.dtype,
         enforce_eager=args.enforce_eager,
         kv_cache_dtype=args.kv_cache_dtype,
+        device=args.device,
     )
 
     sampling_params = SamplingParams(
@@ -135,5 +136,11 @@ def run_to_completion(profile_dir: Optional[str] = None):
         default=None,
         help=('path to save the pytorch profiler output. Can be visualized '
               'with ui.perfetto.dev or Tensorboard.'))
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        choices=["cuda"],
+        help='device type for vLLM execution, supporting CUDA only currently.')
     args = parser.parse_args()
     main(args)
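
The new flag is forwarded straight into the LLM constructor as a keyword argument. A minimal sketch of the resulting call pattern outside the benchmark script, assuming a placeholder model name and sampling settings that are not part of this commit:

    from vllm import LLM, SamplingParams

    # "cuda" is the only value accepted by the new --device flag for now.
    llm = LLM(model="facebook/opt-125m", device="cuda")
    params = SamplingParams(temperature=0.0, max_tokens=16)
    outputs = llm.generate(["Hello, my name is"], params)
    print(outputs[0].outputs[0].text)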

benchmarks/benchmark_throughput.py

Lines changed: 9 additions & 1 deletion
@@ -72,6 +72,7 @@ def run_vllm(
     max_model_len: Optional[int],
     enforce_eager: bool,
     kv_cache_dtype: str,
+    device: str,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -85,6 +86,7 @@ def run_vllm(
         max_model_len=max_model_len,
         enforce_eager=enforce_eager,
         kv_cache_dtype=kv_cache_dtype,
+        device=device,
     )
 
     # Add the requests to the engine.
@@ -209,7 +211,7 @@ def main(args: argparse.Namespace):
                                 args.seed, args.n, args.use_beam_search,
                                 args.trust_remote_code, args.dtype,
                                 args.max_model_len, args.enforce_eager,
-                                args.kv_cache_dtype)
+                                args.kv_cache_dtype, args.device)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -294,6 +296,12 @@ def main(args: argparse.Namespace):
         default="auto",
         help=
         'Data type for kv cache storage. If "auto", will use model data type.')
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        choices=["cuda"],
+        help='device type for vLLM execution, supporting CUDA only currently.')
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
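
With the new argument in place, the throughput benchmark can be launched with an explicit device. A hypothetical invocation (the model name is a placeholder and other required arguments are omitted):

    python benchmarks/benchmark_throughput.py \
        --backend vllm \
        --model facebook/opt-125m \
        --device cuda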

benchmarks/kernels/benchmark_paged_attention.py

Lines changed: 18 additions & 9 deletions
@@ -25,30 +25,32 @@ def main(
     dtype: torch.dtype,
     seed: int,
     do_profile: bool,
+    device: str = "cuda",
     kv_cache_dtype: Optional[str] = None,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
 
     scale = float(1.0 / (head_size**0.5))
     query = torch.empty(num_seqs,
                         num_query_heads,
                         head_size,
                         dtype=dtype,
-                        device="cuda")
+                        device=device)
     query.uniform_(-scale, scale)
 
     assert num_query_heads % num_kv_heads == 0
     alibi_slopes = None
     if use_alibi:
         alibi_slopes = torch.randn(num_query_heads,
                                    dtype=torch.float,
-                                   device="cuda")
+                                   device=device)
 
     context_lens = [context_len for _ in range(num_seqs)]
     max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
+    context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)
 
     # Create the block tables.
     max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@@ -59,12 +61,17 @@ def main(
         for _ in range(max_num_blocks_per_seq)
     ]
     block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
+    block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)
 
     # Create the KV cache.
-    key_caches, value_caches = create_kv_caches_with_random(
-        NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
-        dtype)
+    key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
+                                                            block_size,
+                                                            1,
+                                                            num_kv_heads,
+                                                            head_size,
+                                                            kv_cache_dtype,
+                                                            dtype,
+                                                            device=device)
     key_cache, value_cache = key_caches[0], value_caches[0]
 
     # Prepare for the paged attention kernel.
@@ -84,7 +91,7 @@ def main(
     )
     max_logits = torch.empty_like(exp_sums)
 
-    def run_benchmark(num_iters: int, profile: bool = False) -> float:
+    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
         torch.cuda.synchronize()
         if profile:
             torch.cuda.cudart().cudaProfilerStart()
@@ -135,6 +142,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
 
     # Warmup.
     print("Warming up...")
+    run_benchmark = run_cuda_benchmark
     run_benchmark(num_iters=3, profile=False)
 
     # Benchmark.
@@ -175,6 +183,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
         default="auto",
         help=
         'Data type for kv cache storage. If "auto", will use model data type.')
+    parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
     args = parser.parse_args()
     print(args)
 
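
Renaming run_benchmark to run_cuda_benchmark and then aliasing it back (run_benchmark = run_cuda_benchmark) leaves room for device-specific benchmark loops to be selected by the --device value later. A small, self-contained sketch of that dispatch idea, with placeholder timing logic that is not the benchmark's actual code:

    import time
    from typing import Callable, Dict

    def run_cuda_benchmark(num_iters: int) -> float:
        # Placeholder workload: time num_iters empty iterations.
        start = time.perf_counter()
        for _ in range(num_iters):
            pass
        return (time.perf_counter() - start) / num_iters

    # Only "cuda" is wired up today; other devices could register here later.
    BENCHMARKS: Dict[str, Callable[[int], float]] = {"cuda": run_cuda_benchmark}

    def run_benchmark_for(device: str, num_iters: int) -> float:
        if device not in BENCHMARKS:
            raise ValueError(f"Unsupported device: {device}")
        return BENCHMARKS[device](num_iters)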

tests/kernels/test_activation.py

Lines changed: 21 additions & 16 deletions
@@ -7,26 +7,29 @@
 NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
 D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
 SEEDS = [0]
-DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
 
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
-    device: int,
+    device: str,
 ) -> None:
     torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    gpu_id = f"cuda:{device}"
-    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    torch.set_default_device(device)
+    x = torch.randn(num_tokens, 2 * d, dtype=dtype)
     layer = SiluAndMul()
     out = layer(x)
     ref_out = layer._forward(x)
@@ -37,19 +40,20 @@ def test_silu_and_mul(
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_gelu_new(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
-    device: int,
+    device: str,
 ) -> None:
     torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    gpu_id = f"cuda:{device}"
-    x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    torch.set_default_device(device)
+    x = torch.randn(num_tokens, d, dtype=dtype)
     layer = NewGELU()
     out = layer(x)
     ref_out = layer._forward(x)
@@ -60,18 +64,19 @@ def test_gelu_new(
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_gelu_fast(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
     seed: int,
-    device: int,
+    device: str,
 ) -> None:
     torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    gpu_id = f"cuda:{device}"
-    x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    torch.set_default_device(device)
+    x = torch.randn(num_tokens, d, dtype=dtype)
     layer = FastGELU()
     out = layer(x)
     ref_out = layer._forward(x)
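
These tests now call torch.set_default_device instead of passing an explicit device= to every tensor factory. A minimal illustration of why that works, runnable on a machine with or without a GPU (not code from this commit):

    import torch

    # After set_default_device, factory functions such as torch.randn allocate
    # on the chosen device by default, so the explicit device= arguments and
    # the gpu_id string can be dropped from the tests.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch.set_default_device(device)
    x = torch.randn(4, 8)
    print(x.device)  # cuda:0 when a GPU is present, otherwise cpu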

tests/kernels/test_attention.py

Lines changed: 23 additions & 28 deletions
@@ -27,7 +27,9 @@
 USE_ALIBI = [False, True]
 KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
 SEEDS = [0]
-DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
 
 
 def ref_masked_attention(
@@ -91,7 +93,7 @@ def ref_single_query_cached_kv_attention(
         alibi_bias = None
         if alibi_slopes is not None:
             # Create the ALiBi bias used in the paged attention kernel.
-            position_ids = torch.arange(context_len, device=query.device).int()
+            position_ids = torch.arange(context_len).int()
             alibi_bias = (position_ids - context_len + 1).float()
             alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                 1, 1, -1)
@@ -110,7 +112,7 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_paged_attention(
     kv_cache_factory,
     version: str,
@@ -122,33 +124,28 @@ def test_paged_attention(
     dtype: torch.dtype,
     kv_cache_dtype: str,
     seed: int,
-    device: int,
+    device: str,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    gpu_id = f"cuda:{device}"
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    torch.set_default_device(device)
     scale = float(1.0 / (head_size**0.5))
     num_query_heads, num_kv_heads = num_heads
-    query = torch.empty(num_seqs,
-                        num_query_heads,
-                        head_size,
-                        dtype=dtype,
-                        device=gpu_id)
+    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
     query.uniform_(-scale, scale)
 
     assert num_query_heads % num_kv_heads == 0
     num_queries_per_kv = num_query_heads // num_kv_heads
     alibi_slopes = None
     if use_alibi:
-        alibi_slopes = torch.randn(num_query_heads,
-                                   dtype=torch.float,
-                                   device=gpu_id)
+        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
 
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
     context_lens[-1] = MAX_SEQ_LEN
     max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)
+    context_lens = torch.tensor(context_lens, dtype=torch.int)
 
     # Create the block tables.
     max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@@ -159,13 +156,13 @@ def test_paged_attention(
         for _ in range(max_num_blocks_per_seq)
     ]
     block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)
+    block_tables = torch.tensor(block_tables, dtype=torch.int)
 
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                 num_kv_heads, head_size,
                                                 kv_cache_dtype, dtype, seed,
-                                                gpu_id)
+                                                device)
     key_cache, value_cache = key_caches[0], value_caches[0]
 
     # Call the paged attention kernel.
@@ -193,12 +190,10 @@ def test_paged_attention(
         tmp_output = torch.empty(
             size=(num_seqs, num_heads, num_partitions, head_size),
             dtype=output.dtype,
-            device=output.device,
         )
         exp_sums = torch.empty(
             size=(num_seqs, num_heads, num_partitions),
             dtype=torch.float32,
-            device=output.device,
         )
         max_logits = torch.empty_like(exp_sums)
         ops.paged_attention_v2(
@@ -229,14 +224,14 @@ def test_paged_attention(
                            block_size, x)
         dequantized_key_cache = torch.empty(size=key_cache_shape,
                                             dtype=dtype,
-                                            device=gpu_id)
+                                            device=device)
         cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
         key_cache = dequantized_key_cache
 
         value_cache_shape = value_cache.shape
         dequantized_value_cache = torch.empty(size=value_cache_shape,
                                               dtype=dtype,
-                                              device=gpu_id)
+                                              device=device)
         cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
         value_cache = dequantized_value_cache
 
@@ -283,7 +278,7 @@ def ref_multi_query_kv_attention(
         attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                                diagonal=1)
         attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype, device=query.device)
+        attn_mask = attn_mask.to(dtype=dtype)
 
         ref_output = ref_masked_attention(
             query[start_idx:end_idx],
@@ -303,20 +298,21 @@ def ref_multi_query_kv_attention(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,
     num_heads: Tuple[int, int],
     head_size: int,
     dtype: torch.dtype,
     seed: int,
-    device: int,
+    device: str,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    gpu_id = f"cuda:{device}"
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    torch.set_default_device(device)
     # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
     # As the xformers library is already tested with its own tests, we can use
     # a smaller MAX_SEQ_LEN here.
@@ -329,8 +325,7 @@ def test_multi_query_kv_attention(
     qkv = torch.empty(num_tokens,
                       num_query_heads + 2 * num_kv_heads,
                       head_size,
-                      dtype=dtype,
-                      device=gpu_id)
+                      dtype=dtype)
     qkv.uniform_(-scale, scale)
     query, key, value = qkv.split(
         [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
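
The DEVICES list of integer GPU indices is replaced throughout by CUDA_DEVICES, a list of device strings that the tests feed directly to torch.set_default_device. A self-contained sketch of that parametrization pattern (the test body is illustrative, not from the commit):

    import pytest
    import torch

    CUDA_DEVICES = [
        f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
    ]

    @pytest.mark.parametrize("device", CUDA_DEVICES)
    def test_default_device_is_set(device: str) -> None:
        if not torch.cuda.is_available():
            pytest.skip("requires a CUDA device")
        torch.set_default_device(device)
        assert torch.empty(1).device.type == "cuda"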
