
Commit f1434d4

fix ut
1 parent c546e89 commit f1434d4

File tree: 6 files changed (+16 −7 lines)

benchmarks/benchmark_throughput.py
benchmarks/kernels/benchmark_paged_attention.py
tests/kernels/test_cache.py
tests/lora/test_worker.py
vllm/engine/llm_engine.py
vllm/worker/model_runner.py

benchmarks/benchmark_throughput.py

Lines changed: 1 addition & 2 deletions
@@ -211,8 +211,7 @@ def main(args: argparse.Namespace):
                                  args.seed, args.n, args.use_beam_search,
                                  args.trust_remote_code, args.dtype,
                                  args.max_model_len, args.enforce_eager,
-                                 args.kv_cache_dtype,
-                                 args.device)
+                                 args.kv_cache_dtype, args.device)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,

benchmarks/kernels/benchmark_paged_attention.py

Lines changed: 8 additions & 3 deletions
@@ -64,9 +64,14 @@ def main(
     block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)

     # Create the KV cache.
-    key_caches, value_caches = create_kv_caches_with_random(
-        NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
-        dtype, device=device)
+    key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
+                                                            block_size,
+                                                            1,
+                                                            num_kv_heads,
+                                                            head_size,
+                                                            kv_cache_dtype,
+                                                            dtype,
+                                                            device=device)
     key_cache, value_cache = key_caches[0], value_caches[0]

     # Prepare for the paged attention kernel.

tests/kernels/test_cache.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]
 CUDA_DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
+KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]


 @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
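Only the new KV_CACHE_DTYPE constant is visible in this hunk, not the tests that consume it. Below is a minimal sketch of how such a list is typically wired in, following the parametrize pattern already used in this file; the test name and body are placeholders, not the actual tests.

import pytest
import torch

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]


@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cache_kernel_sketch(num_mappings, kv_cache_dtype, seed, device):
    # Placeholder body: a real test would allocate KV caches with the given
    # kv_cache_dtype on the given CUDA device and exercise the cache kernels.
    torch.manual_seed(seed)
    assert kv_cache_dtype in ("auto", "fp8_e5m2")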

tests/lora/test_worker.py

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,7 @@

 from vllm.lora.models import LoRAMapping
 from vllm.lora.request import LoRARequest
-from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig
+from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig, DeviceConfig
 from vllm.worker.worker import Worker


@@ -25,6 +25,7 @@ def test_worker_apply_lora(sql_lora_files):
         ),
         parallel_config=ParallelConfig(1, 1, False),
         scheduler_config=SchedulerConfig(32, 32, 32, 256),
+        device_config=DeviceConfig("cuda"),
         local_rank=0,
         rank=0,
         lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
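Taken together, the two hunks mean the test now imports DeviceConfig and passes one into the Worker constructor. A minimal sketch of just the configs visible in this diff; the full Worker(...) call and the remaining LoRAConfig arguments are elided because they are not part of the hunk.

from vllm.config import DeviceConfig, ParallelConfig, SchedulerConfig

# New in this commit: the worker test supplies an explicit device.
device_config = DeviceConfig("cuda")

# Unchanged arguments shown as context in the hunk above.
parallel_config = ParallelConfig(1, 1, False)
scheduler_config = SchedulerConfig(32, 32, 32, 256)

# Worker(model_config=..., parallel_config=parallel_config,
#        scheduler_config=scheduler_config, device_config=device_config,
#        local_rank=0, rank=0, lora_config=..., ...)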

vllm/engine/llm_engine.py

Lines changed: 1 addition & 0 deletions
@@ -88,6 +88,7 @@ def __init__(
             f"quantization={model_config.quantization}, "
             f"enforce_eager={model_config.enforce_eager}, "
             f"kv_cache_dtype={cache_config.cache_dtype}, "
+            f"device_config={device_config.device}, "
             f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.

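The practical effect is that the engine's startup log line now reports the configured device. An illustrative reconstruction of the tail of that message, using made-up placeholder values for the config fields; only the f-string fragments in the hunk are taken from the source.

cache_dtype = "auto"  # placeholder for cache_config.cache_dtype
device = "cuda"       # placeholder for device_config.device
seed = 0              # placeholder for model_config.seed

tail = (f"kv_cache_dtype={cache_dtype}, "
        f"device_config={device}, "
        f"seed={seed})")
print(tail)  # -> kv_cache_dtype=auto, device_config=cuda, seed=0)
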
vllm/worker/model_runner.py

Lines changed: 3 additions & 1 deletion
@@ -314,7 +314,9 @@ def _prepare_decode(
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long)
-        context_lens = torch.tensor(context_lens, dtype=torch.int)
+        context_lens = torch.tensor(context_lens,
+                                    dtype=torch.int,
+                                    device=self.device_config.device)

         if use_captured_graph:
             # The shape of graph_block_tables is