
Commit c2a8acf

heheda12345 authored and comaniac committed
[V1] Move more control of kv cache initialization from model_executor to EngineCore (vllm-project#11960)
Signed-off-by: Chen Zhang <[email protected]> Co-authored-by: Cody Yu <[email protected]>
1 parent 67bee34 commit c2a8acf

12 files changed, +515 -104 lines

tests/v1/test_utils.py

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+from typing import List
+
+import torch
+
+from vllm.v1.utils import bind_kv_cache
+
+
+def test_bind_kv_cache():
+    from vllm.attention import Attention
+
+    ctx = {
+        'layers.0.self_attn': Attention(32, 128, 0.1),
+        'layers.1.self_attn': Attention(32, 128, 0.1),
+        'layers.2.self_attn': Attention(32, 128, 0.1),
+        'layers.3.self_attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = {
+        'layers.0.self_attn': torch.zeros((1, )),
+        'layers.1.self_attn': torch.zeros((1, )),
+        'layers.2.self_attn': torch.zeros((1, )),
+        'layers.3.self_attn': torch.zeros((1, )),
+    }
+    runner_kv_caches: List[torch.Tensor] = []
+    bind_kv_cache(kv_cache, ctx, runner_kv_caches)
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
+        'layers.0.self_attn']
+    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[
+        'layers.1.self_attn']
+    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[
+        'layers.2.self_attn']
+    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[
+        'layers.3.self_attn']
+
+    assert runner_kv_caches[0] is kv_cache['layers.0.self_attn']
+    assert runner_kv_caches[1] is kv_cache['layers.1.self_attn']
+    assert runner_kv_caches[2] is kv_cache['layers.2.self_attn']
+    assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']
+
+
+def test_bind_kv_cache_non_attention():
+    from vllm.attention import Attention
+
+    # example from Jamba PP=2
+    ctx = {
+        'model.layers.20.attn': Attention(32, 128, 0.1),
+        'model.layers.28.attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = {
+        'model.layers.20.attn': torch.zeros((1, )),
+        'model.layers.28.attn': torch.zeros((1, )),
+    }
+
+    runner_kv_caches: List[torch.Tensor] = []
+    bind_kv_cache(kv_cache, ctx, runner_kv_caches)
+
+    assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[
+        'model.layers.20.attn']
+    assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[
+        'model.layers.28.attn']
+
+    assert runner_kv_caches[0] is kv_cache['model.layers.20.attn']
+    assert runner_kv_caches[1] is kv_cache['model.layers.28.attn']

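For readers unfamiliar with the helper under test: per the assertions above, bind_kv_cache (from vllm.v1.utils) wires each allocated per-layer tensor both into the corresponding Attention module and into the runner's flat per-layer list, in layer-index order. The snippet below is a minimal sketch of behavior compatible with these tests, not the actual vLLM implementation; the name bind_kv_cache_sketch and the regex-based index extraction are assumptions for illustration.

    import re
    from typing import Dict, List

    import torch


    def bind_kv_cache_sketch(kv_caches: Dict[str, torch.Tensor],
                             forward_context: Dict[str, object],
                             runner_kv_caches: List[torch.Tensor]) -> None:
        # Order layers by the numeric index embedded in the name, e.g.
        # 'model.layers.20.attn' -> 20, so a pipeline-parallel rank that only
        # holds a slice of the model still builds a compact, ordered list.
        def layer_index(name: str) -> int:
            return int(re.search(r"layers\.(\d+)\.", name).group(1))

        for name in sorted(kv_caches, key=layer_index):
            tensor = kv_caches[name]
            runner_kv_caches.append(tensor)            # runner view: flat list
            forward_context[name].kv_cache = [tensor]  # module view: one entry
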
vllm/attention/layer.py

Lines changed: 2 additions & 0 deletions

@@ -101,7 +101,9 @@ def __init__(
         self.num_heads = num_heads
         self.head_size = head_size
         self.num_kv_heads = num_kv_heads
+        self.sliding_window = sliding_window
         self.backend = backend_name_to_enum(attn_backend.get_name())
+        self.dtype = dtype
 
         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
         # torch.compile works by registering the attention as one giant

vllm/v1/core/kv_cache_utils.py

Lines changed: 124 additions & 0 deletions

@@ -3,7 +3,10 @@
 from dataclasses import dataclass
 from typing import Any, List, NamedTuple, Optional, Tuple
 
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
+                                        KVCacheTensor)
 from vllm.v1.request import Request
 
 logger = init_logger(__name__)
@@ -305,3 +308,124 @@ def hash_request_tokens(block_size: int,
         ret.append(block_hash)
         parent_block_hash_value = block_hash.hash_value
     return ret
+
+
+def check_enough_kv_cache_memory(vllm_config: VllmConfig,
+                                 kv_cache_spec: KVCacheSpec,
+                                 available_memory: int):
+    """
+    Checks whether `available_memory` is enough for the KV cache to hold at
+    least one request with the model's max_model_len.
+
+    Args:
+        vllm_config: The global VllmConfig
+        kv_cache_spec: The kv cache spec of the model
+        available_memory: Memory available for KV cache in bytes.
+
+    Raises:
+        ValueError: If there is not enough memory available for the KV cache.
+    """
+
+    if available_memory <= 0:
+        raise ValueError("No available memory for the cache blocks. "
+                         "Try increasing `gpu_memory_utilization` when "
+                         "initializing the engine.")
+
+    max_model_len = vllm_config.model_config.max_model_len
+    needed_memory = 0
+    for layer_spec in kv_cache_spec.values():
+        needed_memory += layer_spec.bytes_for_tokens(max_model_len)
+
+    if needed_memory > available_memory:
+        raise ValueError(
+            f"To serve at least one request with the models's max seq len "
+            f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GB KV "
+            f"cache is needed, which is larger than the available KV cache "
+            f"memory ({available_memory/1024/1024/1024:.2f} GB). Try "
+            f"increasing `gpu_memory_utilization` or decreasing "
+            f"`max_model_len` when initializing the engine.")
+
+
+def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
+    """
+    Whether all layers in the given KVCacheSpec have the same type of KV cache.
+
+    Args:
+        kv_cache_spec: The KVCacheSpec of the model
+
+    Returns:
+        True if all layers have the same type, False otherwise.
+    """
+
+    layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
+    return len(layer_keys) == 1
+
+
+def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
+                                      kv_cache_spec: KVCacheSpec,
+                                      available_memory: int) -> KVCacheConfig:
+    """
+    Generates the KV cache configuration for a model with one type of KV cache.
+    Divide the available memory equally among all layers.
+
+    Args:
+        vllm_config: The global VllmConfig
+        kv_cache_spec: The kv cache spec of the model
+        available_memory: Memory available for KV cache in bytes.
+
+    Returns:
+        The generated KVCacheConfig
+    """
+
+    page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
+    assert len(page_sizes) == 1
+    page_size = page_sizes.pop()
+
+    num_blocks = int(available_memory // page_size // len(kv_cache_spec))
+    num_blocks = max(num_blocks, 0)
+
+    if vllm_config.cache_config.num_gpu_blocks_override is not None:
+        num_gpu_blocks_override = \
+            vllm_config.cache_config.num_gpu_blocks_override
+        logger.info(
+            "Overriding num_gpu_blocks=%d with "
+            "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
+        num_blocks = num_gpu_blocks_override
+
+    logger.info("# GPU blocks: %d", num_blocks)
+
+    per_layer_size = page_size * num_blocks
+
+    kv_cache_config = KVCacheConfig(
+        num_blocks=num_blocks,
+        tensors={
+            layer_name: KVCacheTensor(size=per_layer_size)
+            for layer_name in kv_cache_spec
+        },
+        groups=[[layer_name for layer_name in kv_cache_spec]],
+        kv_cache_spec=kv_cache_spec)
+    return kv_cache_config
+
+
+def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
+                        available_memory: int) -> KVCacheConfig:
+    """
+    Generates the KV cache configuration for a model
+    TODO: support hybrid models with more than one type of KV cache.
+
+    Args:
+        vllm_config: The global VllmConfig
+        kv_cache_spec: The kv cache spec of the model
+        available_memory: Memory available for KV cache in bytes.
+
+    Returns:
+        The generated KVCacheConfig
+    """
+    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
+    if is_kv_cache_type_uniform(kv_cache_spec):
+        # KV cache of all layers are the same, which is true for most models.
+        # Allocate the same amount of memory for each layer.
+        return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
+                                                 available_memory)
+    else:
+        raise NotImplementedError

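As a rough sanity check of the uniform-type sizing above, here is an illustrative calculation of how an available-memory budget turns into a block count and a per-layer tensor size; all numbers are made up, not taken from the patch.

    # Illustrative numbers only; real values come from the profiling run and
    # the per-layer KVCacheSpec.
    available_memory = 20 * 1024**3   # 20 GiB left for KV cache, in bytes
    page_size = 2 * 1024**2           # page_size_bytes of one block: 2 MiB
    num_layers = 32                   # len(kv_cache_spec)

    # Every layer gets the same number of blocks, so divide by both factors.
    num_blocks = available_memory // page_size // num_layers   # 320 blocks
    per_layer_size = page_size * num_blocks                    # 640 MiB/layer

    assert num_blocks == 320
    assert per_layer_size == 640 * 1024**2
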
vllm/v1/engine/core.py

Lines changed: 18 additions & 13 deletions

@@ -11,11 +11,12 @@
 import zmq.asyncio
 from msgspec import msgpack
 
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
+from vllm.v1.core.kv_cache_utils import get_kv_cache_config
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
                             EngineCoreRequest, EngineCoreRequestType,
@@ -49,7 +50,7 @@ def __init__(
 
         # Setup KV Caches and update CacheConfig after profiling.
         num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
-            vllm_config.cache_config)
+            vllm_config)
         vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
         vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
 
@@ -65,21 +66,25 @@ def __init__(
             vllm_config.model_config)
 
     def _initialize_kv_caches(self,
-                              cache_config: CacheConfig) -> Tuple[int, int]:
+                              vllm_config: VllmConfig) -> Tuple[int, int]:
         start = time.time()
-        num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
-        )
 
-        if cache_config.num_gpu_blocks_override is not None:
-            num_gpu_blocks_override = cache_config.num_gpu_blocks_override
-            logger.info(
-                "Overriding num_gpu_blocks=%d with "
-                "num_gpu_blocks_override=%d", num_gpu_blocks,
-                num_gpu_blocks_override)
-            num_gpu_blocks = num_gpu_blocks_override
+        # Get all kv cache needed by the model
+        kv_cache_spec = self.model_executor.get_kv_cache_spec()
+
+        # Profiles the peak memory usage of the model to determine how much
+        # memory can be allocated for kv cache.
+        availble_gpu_memory = self.model_executor.determine_available_memory()
 
+        # Get the kv cache tensor size
+        kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
+                                              availble_gpu_memory)
+        num_gpu_blocks = kv_cache_config.num_blocks
         num_cpu_blocks = 0
-        self.model_executor.initialize(num_gpu_blocks)
+
+        # Initialize kv cache and warmup the execution
+        self.model_executor.initialize(kv_cache_config)
+
         elapsed = time.time() - start
         logger.info(("init engine (profile, create kv cache, "
                      "warmup model) took %.2f seconds"), elapsed)

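Read together, the new _initialize_kv_caches turns KV cache setup into a three-step exchange that EngineCore drives: ask the executor for the per-layer spec, ask how much memory profiling left free, then compute the layout centrally and push it back down. The condensed sketch below is schematic only; the executor method names come from the diff, while the helper name and return value are illustrative.

    from vllm.v1.core.kv_cache_utils import get_kv_cache_config


    def initialize_kv_caches_sketch(model_executor, vllm_config) -> int:
        # 1. Workers report what KV cache every attention layer needs.
        kv_cache_spec = model_executor.get_kv_cache_spec()

        # 2. Workers profile a dummy run and report the bytes left for KV cache.
        available_memory = model_executor.determine_available_memory()

        # 3. EngineCore decides the layout and hands the finished config back
        #    to the workers, which allocate the tensors and warm up the model.
        kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
                                              available_memory)
        model_executor.initialize(kv_cache_config)
        return kv_cache_config.num_blocks
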
vllm/v1/executor/abstract.py

Lines changed: 8 additions & 3 deletions

@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
-from typing import Tuple, Type
+from typing import Type
 
 from vllm.config import VllmConfig
+from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 
 
@@ -31,11 +32,15 @@ def __init__(self, vllm_config: VllmConfig) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def initialize(self, num_gpu_blocks: int) -> None:
+    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def determine_num_available_blocks(self) -> Tuple[int, int]:
+    def determine_available_memory(self) -> int:  # in bytes
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_kv_cache_spec(self) -> KVCacheSpec:
         raise NotImplementedError
 
     @abstractmethod

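A concrete executor now has to answer three questions for EngineCore: what KV cache the model needs, how much memory is free, and how to materialize a given KVCacheConfig. The duck-typed stub below is purely illustrative; the _ToyExecutor class and its single-worker assumption are not part of the patch, and real executors also implement execute_model and the other abstract methods.

    from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec


    class _ToyExecutor:
        """Single-worker stand-in that mirrors the new Executor surface."""

        def __init__(self, worker) -> None:
            # Assumed: `worker` exposes the same methods the multiproc
            # executor reaches via collective_rpc().
            self.worker = worker

        def get_kv_cache_spec(self) -> KVCacheSpec:
            return self.worker.get_kv_cache_spec()

        def determine_available_memory(self) -> int:  # in bytes
            return self.worker.determine_available_memory()

        def initialize(self, kv_cache_config: KVCacheConfig) -> None:
            self.worker.initialize_cache(kv_cache_config)
            self.worker.compile_or_warm_up_model()
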
vllm/v1/executor/multiproc_executor.py

Lines changed: 15 additions & 10 deletions

@@ -23,6 +23,7 @@
 from vllm.utils import (get_distributed_init_method, get_mp_context,
                         get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.worker.worker_base import WorkerWrapperBase
 
@@ -90,29 +91,33 @@ def sigusr1_handler(signum, frame):
         for w in self.workers:
             w.worker_response_mq.wait_until_ready()
 
-    def initialize(self, num_gpu_blocks: int) -> None:
+    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize the KV caches and begin the model execution loop of the
         underlying workers.
         """
-        logger.info("# GPU blocks: %d", num_gpu_blocks)
-        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
+        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
        self.collective_rpc("compile_or_warm_up_model")
 
-    def determine_num_available_blocks(self) -> Tuple[int, int]:
+    def determine_available_memory(self) -> int:
         """
-        Determine the number of available KV blocks by invoking the
+        Determine the available memory (in bytes) for KV cache by invoking the
         underlying worker.
         """
-        num_blocks = self.collective_rpc("determine_num_available_blocks")
+        memory_sizes = self.collective_rpc("determine_available_memory")
 
         # Since we use a shared centralized controller, we take the minimum
-        # number of blocks across all workers to make sure all the memory
+        # memory size across all workers to make sure all the memory
         # operators can be applied to all workers.
-        num_gpu_blocks = min(b[0] for b in num_blocks)
-        num_cpu_blocks = min(b[1] for b in num_blocks)
+        return min(memory_sizes)
 
-        return num_gpu_blocks, num_cpu_blocks
+    def get_kv_cache_spec(self) -> KVCacheSpec:
+        """
+        Get all kv cache needed by the model by invoking the underlying worker.
+        """
+        kv_cache_specs = self.collective_rpc("get_kv_cache_spec")
+        assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
+        return kv_cache_specs[0]
 
     def collective_rpc(self,
                        method: str,

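The two collective results are reduced differently: per-rank memory sizes are min-reduced, because the shared scheduler must fit the most constrained rank, while the KV cache specs must simply agree across ranks. A toy illustration with made-up per-rank values:

    # Made-up per-rank results, standing in for collective_rpc() return values.
    memory_sizes = [21_474_836_480, 20_401_094_656, 21_474_836_480]  # bytes
    available_memory = min(memory_sizes)   # plan for the tightest rank

    kv_cache_specs = [{'layers.0.attn': 'full_attention_16x128_bf16'}] * 3
    assert all(s == kv_cache_specs[0] for s in kv_cache_specs)  # ranks agree
    kv_cache_spec = kv_cache_specs[0]
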