Skip to content

Commit d674e59

Browse files
committed
update
1 parent df845b2 commit d674e59

16 files changed

Lines changed: 200 additions & 16 deletions

tests/core/block/test_prefix_caching_block.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,32 @@ def test_eviction_order(num_blocks: int, block_size: int, seed: int):
682682

683683
assert new_block[0].block_id == last_block_id
684684

685+
# Test case for cache mertics
686+
@staticmethod
687+
def test_metric():
688+
block_size = 16
689+
allocator = PrefixCachingBlockAllocator(num_blocks=4,
690+
block_size=block_size)
691+
# Test when no query (0/0)
692+
assert allocator.get_prefix_cache_hit_rate() == 0.0
693+
694+
token_ids = list(range(block_size))
695+
allocator.allocate_immutable_block(prev_block=None,
696+
token_ids=token_ids)
697+
# Test 0/1 hit rate
698+
assert allocator.get_prefix_cache_hit_rate() == 0.0
699+
700+
allocator.allocate_immutable_block(prev_block=None,
701+
token_ids=token_ids)
702+
# Test 1/2 hit rate
703+
assert allocator.get_prefix_cache_hit_rate() == 0.5
704+
705+
# Test more than one block
706+
for _ in range(2, 1005):
707+
allocator.allocate_immutable_block(prev_block=None,
708+
token_ids=token_ids)
709+
assert allocator.get_prefix_cache_hit_rate() > 0.99
710+
685711
@staticmethod
686712
def create_immutable_chain(
687713
block_size: int,

tests/prefix_caching/test_prefix_caching.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ def test_block_allocator(
3434
assert (first_block == second_block)
3535
assert (second_block.ref_count == 2)
3636

37+
# Check metric: 1 hit of 2 queries
38+
assert block_allocator.get_prefix_cache_hit_rate() == 0.5
39+
3740
# Free the first_block and confirm that the ref_count is correctly
3841
# decremented on the second block
3942
block_allocator.free(first_block)
@@ -48,6 +51,10 @@ def test_block_allocator(
4851
assert (first_block == second_block)
4952
assert (first_block.block_hash == block_hash)
5053

54+
# Allocate one more time to get 3/4 hit rate for easy checking
55+
block_allocator.allocate(block_hash, 0)
56+
assert block_allocator.get_prefix_cache_hit_rate() == 0.75
57+
5158

5259
@pytest.mark.parametrize("num_blocks", [16])
5360
def test_eviction(num_blocks: int, ):

vllm/core/block/common.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections import deque
2+
from dataclasses import dataclass
23
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
34

45
from vllm.core.block.interfaces import Block, BlockAllocator
@@ -282,6 +283,58 @@ def ids(self) -> List[int]:
282283
return self._block_ids
283284

284285

286+
@dataclass
287+
class CacheMetricData:
288+
"""A utility dataclass to maintain cache metric.
289+
To avoid overflow, we maintain the hit rate in block granularity, so that
290+
we can maintain a single hit rate for n_completed_block x block_size,
291+
and calculate the real time hit rate by the following:
292+
BS = The number of queries per block.
293+
nB = The number of completed blocks.
294+
HR = hit rate of (nB x BS) queries.
295+
Q = current number of queries (< BS).
296+
H = current number of hits (< BS).
297+
hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS)
298+
"""
299+
num_completed_blocks: int = 0
300+
completed_block_cache_hit_rate: float = 0.0
301+
num_incompleted_block_queries: int = 0
302+
num_incompleted_block_hit: int = 0
303+
block_size: int = 1000
304+
305+
def query(self, hit: bool):
306+
self.num_incompleted_block_queries += 1
307+
self.num_incompleted_block_hit += 1 if hit else 0
308+
309+
# When a block is completed, update the cache hit rate
310+
# and reset the incomplete numbers.
311+
if self.num_incompleted_block_queries == self.block_size:
312+
hit_rate = (self.num_incompleted_block_hit /
313+
self.num_incompleted_block_queries)
314+
self.completed_block_cache_hit_rate = (
315+
self.completed_block_cache_hit_rate * self.num_completed_blocks
316+
+ hit_rate) / (self.num_completed_blocks + 1)
317+
self.num_incompleted_block_queries = 0
318+
self.num_incompleted_block_hit = 0
319+
self.num_completed_blocks += 1
320+
321+
def get_hit_rate(self):
322+
incomplete_ratio = self.num_incompleted_block_queries / self.block_size
323+
total_blocks = self.num_completed_blocks + incomplete_ratio
324+
if total_blocks == 0:
325+
return 0.0
326+
327+
completed_block_hit, incompleted_block_hit = 0.0, 0.0
328+
if self.num_completed_blocks > 0:
329+
completed_block_hit = (self.completed_block_cache_hit_rate *
330+
self.num_completed_blocks)
331+
if self.num_incompleted_block_queries > 0:
332+
incompleted_hit_rate = (self.num_incompleted_block_hit /
333+
self.num_incompleted_block_queries)
334+
incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio)
335+
return (completed_block_hit + incompleted_block_hit) / total_blocks
336+
337+
285338
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
286339
"""Retrieves all the blocks in a sequence starting from the last block.
287340

vllm/core/block/cpu_gpu_block_allocator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,11 @@ def get_common_computed_block_ids(
323323
def all_block_ids(self) -> FrozenSet[int]:
324324
return frozenset(self._block_ids_to_allocator.keys())
325325

326+
def get_prefix_cache_hit_rate(self, device: Device) -> float:
327+
"""Prefix cache hit rate. -1 means not supported or disabled."""
328+
assert device in self._allocators
329+
return self._allocators[device].get_prefix_cache_hit_rate()
330+
326331
def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
327332
"""Returns and clears the mapping of source to destination block IDs.
328333
Will be called after every swapping operations for now, and after every

vllm/core/block/interfaces.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,11 @@ def get_num_blocks_touched(self,
186186
num_lookahead_slots: int = 0) -> int:
187187
pass
188188

189+
@abstractmethod
190+
def get_prefix_cache_hit_rate(self) -> float:
191+
"""Prefix cache hit rate. -1 means not supported or disabled."""
192+
pass
193+
189194
class NoFreeBlocksError(ValueError):
190195
pass
191196

@@ -278,3 +283,8 @@ def allocate_or_get_null_block(self) -> Block:
278283
There is at most one null block per allocator.
279284
"""
280285
pass
286+
287+
@abstractmethod
288+
def get_prefix_cache_hit_rate(self, device: Device) -> float:
289+
"""Prefix cache hit rate. -1 means not supported or disabled."""
290+
pass

vllm/core/block/naive_block.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,9 @@ def swap_in(self, blocks: List[Block]) -> None:
341341

342342
block.block_id = block_id # Assign block_id
343343

344+
def get_prefix_cache_hit_rate(self) -> float:
345+
return -1
346+
344347

345348
class NaiveBlock(Block):
346349
"""An implementation of the Block class that does not support prefix

vllm/core/block/prefix_caching_block.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
"""Token blocks."""
2-
32
from os.path import commonprefix
43
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
54

6-
from vllm.core.block.common import (CopyOnWriteTracker,
5+
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
76
get_all_blocks_recursively)
87
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
98
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
@@ -107,6 +106,8 @@ def __init__(
107106
self._cow_tracker = CopyOnWriteTracker(
108107
refcounter=self._refcounter.as_readonly())
109108

109+
self.metric_data = CacheMetricData()
110+
110111
# Implements Block.Factory.
111112
def _create_block(
112113
self,
@@ -155,9 +156,11 @@ def allocate_immutable_block(self,
155156

156157
cached_block_id = self._cached_blocks.get(block.content_hash, None)
157158
if cached_block_id is not None:
159+
self.metric_data.query(hit=True)
158160
block.block_id = cached_block_id
159161
self._incr_refcount_cached_block(block)
160162
return block
163+
self.metric_data.query(hit=False)
161164
self._block_pool.free_block(block)
162165

163166
# No cached block => Allocate a new block
@@ -404,6 +407,9 @@ def get_physical_block_id(self, absolute_id: int) -> int:
404407
def all_block_ids(self) -> FrozenSet[int]:
405408
return self._hashless_allocator.all_block_ids
406409

410+
def get_prefix_cache_hit_rate(self) -> float:
411+
return self.metric_data.get_hit_rate()
412+
407413
def is_block_cached(self, block: Block) -> bool:
408414
assert block.content_hash is not None
409415
if block.content_hash in self._cached_blocks:

vllm/core/block_manager_v1.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Set, Tuple
99

1010
from vllm.block import BlockTable, PhysicalTokenBlock
11+
from vllm.core.block.common import CacheMetricData
1112
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
1213
from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
1314
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
@@ -60,6 +61,11 @@ def contains_block(self, block_hash: int) -> bool:
6061
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
6162
pass
6263

64+
@abstractmethod
65+
def get_prefix_cache_hit_rate(self) -> float:
66+
"""Prefix cache hit rate. -1 means not supported or disabled."""
67+
pass
68+
6369

6470
class CachedBlockAllocator(BlockAllocatorBase):
6571
"""Manages free physical token blocks for a device.
@@ -85,6 +91,8 @@ def __init__(self,
8591

8692
self.default_hash_ctr = count()
8793

94+
self.cache_metric_data = CacheMetricData()
95+
8896
def allocate_block(self, block_hash: int,
8997
num_hashed_tokens: int) -> PhysicalTokenBlock:
9098
if self.current_num_blocks == self.num_blocks:
@@ -105,15 +113,17 @@ def allocate(self,
105113
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
106114
if block_hash is None:
107115
block_hash = next(self.default_hash_ctr)
116+
108117
if block_hash in self.evictor:
109118
assert block_hash not in self.cached_blocks
110119
block = self.evictor.remove(block_hash)
111120
assert block.ref_count == 0
112121
self.cached_blocks[block_hash] = block
113-
block.ref_count += 1
114-
assert block.block_hash == block_hash
115-
return block
116-
if block_hash not in self.cached_blocks:
122+
123+
if block_hash in self.cached_blocks:
124+
self.cache_metric_data.query(hit=True)
125+
else:
126+
self.cache_metric_data.query(hit=False)
117127
self.cached_blocks[block_hash] = self.allocate_block(
118128
block_hash, num_hashed_tokens)
119129
block = self.cached_blocks[block_hash]
@@ -150,6 +160,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
150160
del self.cached_blocks[old_hash]
151161
self.cached_blocks[block_hash] = block
152162

163+
def get_prefix_cache_hit_rate(self) -> float:
164+
return self.cache_metric_data.get_hit_rate()
165+
153166

154167
class UncachedBlockAllocator(BlockAllocatorBase):
155168
"""Manages free physical token blocks for a device.
@@ -209,6 +222,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
209222
raise NotImplementedError(
210223
"Invalid codepath for uncached block allocator.")
211224

225+
def get_prefix_cache_hit_rate(self) -> float:
226+
return -1
227+
212228

213229
class BlockSpaceManagerV1(BlockSpaceManager):
214230
"""Manages the mapping between logical and physical token blocks."""
@@ -705,3 +721,10 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup):
705721
if self.enable_caching:
706722
for seq in seq_group.get_seqs():
707723
self.compute_full_blocks_in_seq(seq)
724+
725+
def get_prefix_cache_hit_rate(self, device: Device) -> float:
726+
if device == Device.GPU:
727+
return self.gpu_allocator.get_prefix_cache_hit_rate()
728+
if device == Device.CPU:
729+
return self.cpu_allocator.get_prefix_cache_hit_rate()
730+
raise ValueError(f"Invalid device: {device}")

vllm/core/block_manager_v2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,9 @@ def get_num_free_gpu_blocks(self) -> int:
441441
def get_num_free_cpu_blocks(self) -> int:
442442
return self.block_allocator.get_num_free_blocks(Device.CPU)
443443

444+
def get_prefix_cache_hit_rate(self, device: Device) -> float:
445+
return self.block_allocator.get_prefix_cache_hit_rate(device)
446+
444447
def _can_swap(self,
445448
seq_group: SequenceGroup,
446449
device: Device,

vllm/core/embedding_model_block_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
44
from vllm.sequence import Sequence, SequenceGroup
5+
from vllm.utils import Device
56

67

78
class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
@@ -81,3 +82,6 @@ def get_common_computed_block_ids(self,
8182

8283
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
8384
pass
85+
86+
def get_prefix_cache_hit_rate(self, device: Device) -> float:
87+
return -1

0 commit comments

Comments
 (0)