Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tests/core/block/test_prefix_caching_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,32 @@ def test_eviction_order(num_blocks: int, block_size: int, seed: int):

assert new_block[0].block_id == last_block_id

# Test case for cache metrics
@staticmethod
def test_metric():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: test overflow case

Copy link
Collaborator Author

@comaniac comaniac Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I improved the way of handling overflow so there won't be overflow anymore. Specifically, we group the hit rate of n*1000 queries, where n is an integer. Additionally, we maintain hit_count and query_count for less than 1000 queries. Thus, we could combine them to get the real hit rate:

incomplete_ratio = query_count / 1000
(grouped_hit_rate * n + (hit_count / query_count) * incomplete_ratio) / (n + incomplete_ratio)

Also improved the test to cover this case.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG. btw i don't think we need this since python int won't overflow

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true. I'm just afraid that if we host an endpoint for months, the counter will grow to a huge number which might hurt performance

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel there will be many other performance issues in such a case in vLLM. But I don't mind this code being here, as long as it's well tested.

block_size = 16
allocator = PrefixCachingBlockAllocator(num_blocks=4,
block_size=block_size)
# Test when no query (0/0)
assert allocator.get_prefix_cache_hit_rate() == 0.0

token_ids = list(range(block_size))
allocator.allocate_immutable_block(prev_block=None,
token_ids=token_ids)
# Test 0/1 hit rate
assert allocator.get_prefix_cache_hit_rate() == 0.0

allocator.allocate_immutable_block(prev_block=None,
token_ids=token_ids)
# Test 1/2 hit rate
assert allocator.get_prefix_cache_hit_rate() == 0.5

# Test more than one block
for _ in range(2, 1005):
allocator.allocate_immutable_block(prev_block=None,
token_ids=token_ids)
assert allocator.get_prefix_cache_hit_rate() > 0.99

@staticmethod
def create_immutable_chain(
block_size: int,
Expand Down
7 changes: 7 additions & 0 deletions tests/prefix_caching/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def test_block_allocator(
assert (first_block == second_block)
assert (second_block.ref_count == 2)

# Check metric: 1 hit of 2 queries
assert block_allocator.get_prefix_cache_hit_rate() == 0.5

# Free the first_block and confirm that the ref_count is correctly
# decremented on the second block
block_allocator.free(first_block)
Expand All @@ -48,6 +51,10 @@ def test_block_allocator(
assert (first_block == second_block)
assert (first_block.block_hash == block_hash)

# Allocate one more time to get 3/4 hit rate for easy checking
block_allocator.allocate(block_hash, 0)
assert block_allocator.get_prefix_cache_hit_rate() == 0.75


@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int, ):
Expand Down
53 changes: 53 additions & 0 deletions vllm/core/block/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple

from vllm.core.block.interfaces import Block, BlockAllocator
Expand Down Expand Up @@ -282,6 +283,58 @@ def ids(self) -> List[int]:
return self._block_ids


@dataclass
class CacheMetricData:
    """Maintain a cache hit rate without keeping ever-growing counters.

    Queries are grouped into fixed-size blocks of ``block_size`` queries.
    Only one aggregated hit rate is stored for all completed blocks, plus
    raw hit/query counters for the block currently being filled, so no
    counter ever exceeds ``block_size``. The real-time hit rate is:

        BS = number of queries per block.
        nB = number of completed blocks.
        HR = hit rate of the (nB x BS) queries in completed blocks.
        Q  = number of queries in the current, incomplete block (< BS).
        H  = number of hits in the current, incomplete block (< BS).
        hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS)

    Raises:
        ValueError: if ``block_size`` is not a positive number.
    """
    # Number of fully recorded blocks of `block_size` queries.
    num_completed_blocks: int = 0
    # Average hit rate over all completed blocks.
    completed_block_cache_hit_rate: float = 0.0
    # Queries recorded in the current (incomplete) block.
    num_incompleted_block_queries: int = 0
    # Hits recorded in the current (incomplete) block.
    num_incompleted_block_hit: int = 0
    # Number of queries that make up one block.
    block_size: int = 1000

    def __post_init__(self) -> None:
        # A zero block size makes get_hit_rate() divide by zero, and a
        # negative one means a block can never complete. Fail early.
        if self.block_size <= 0:
            raise ValueError(
                f"block_size must be positive, got {self.block_size}")

    def query(self, hit: bool) -> None:
        """Record one cache query.

        Args:
            hit: Whether the query was served from the cache.
        """
        self.num_incompleted_block_queries += 1
        if hit:
            self.num_incompleted_block_hit += 1

        if self.num_incompleted_block_queries != self.block_size:
            return

        # The current block is full: fold its hit rate into the running
        # average over completed blocks, then start a fresh block.
        hit_rate = (self.num_incompleted_block_hit /
                    self.num_incompleted_block_queries)
        self.completed_block_cache_hit_rate = (
            self.completed_block_cache_hit_rate * self.num_completed_blocks
            + hit_rate) / (self.num_completed_blocks + 1)
        self.num_incompleted_block_queries = 0
        self.num_incompleted_block_hit = 0
        self.num_completed_blocks += 1

    def get_hit_rate(self) -> float:
        """Return the hit rate over every query recorded so far.

        Returns:
            A value in [0, 1] when counters are only updated through
            ``query``; 0.0 when no query has been recorded yet.
        """
        # The incomplete block counts as a fraction of a block.
        incomplete_ratio = self.num_incompleted_block_queries / self.block_size
        total_blocks = self.num_completed_blocks + incomplete_ratio
        if total_blocks == 0:
            return 0.0

        completed_block_hit = 0.0
        if self.num_completed_blocks > 0:
            completed_block_hit = (self.completed_block_cache_hit_rate *
                                   self.num_completed_blocks)

        incompleted_block_hit = 0.0
        if self.num_incompleted_block_queries > 0:
            incompleted_hit_rate = (self.num_incompleted_block_hit /
                                    self.num_incompleted_block_queries)
            incompleted_block_hit = incompleted_hit_rate * incomplete_ratio

        return (completed_block_hit + incompleted_block_hit) / total_blocks


def get_all_blocks_recursively(last_block: Block) -> List[Block]:
"""Retrieves all the blocks in a sequence starting from the last block.

Expand Down
5 changes: 5 additions & 0 deletions vllm/core/block/cpu_gpu_block_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,11 @@ def get_common_computed_block_ids(
def all_block_ids(self) -> FrozenSet[int]:
    """Return the keys of ``_block_ids_to_allocator`` as a frozenset."""
    # Iterating the mapping yields its keys.
    block_ids = frozenset(self._block_ids_to_allocator)
    return block_ids

def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the prefix cache hit rate of the allocator for ``device``.

    -1 means not supported or disabled.
    """
    assert device in self._allocators
    allocator = self._allocators[device]
    return allocator.get_prefix_cache_hit_rate()

def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
"""Returns and clears the mapping of source to destination block IDs.
Will be called after every swapping operations for now, and after every
Expand Down
10 changes: 10 additions & 0 deletions vllm/core/block/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ def get_num_blocks_touched(self,
num_lookahead_slots: int = 0) -> int:
pass

@abstractmethod
def get_prefix_cache_hit_rate(self) -> float:
    """Prefix cache hit rate. -1 means not supported or disabled.

    Returns:
        The hit rate as a float, or -1 when prefix caching is not
        supported or is disabled.
    """
    pass

class NoFreeBlocksError(ValueError):
    # NOTE(review): presumably raised when an allocator has no free block
    # left to hand out — confirm against the raise sites.
    pass

Expand Down Expand Up @@ -278,3 +283,8 @@ def allocate_or_get_null_block(self) -> Block:
There is at most one null block per allocator.
"""
pass

@abstractmethod
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Prefix cache hit rate. -1 means not supported or disabled.

    Args:
        device: The device whose hit rate is requested.

    Returns:
        The hit rate as a float, or -1 when prefix caching is not
        supported or is disabled.
    """
    pass
3 changes: 3 additions & 0 deletions vllm/core/block/naive_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ def swap_in(self, blocks: List[Block]) -> None:

block.block_id = block_id # Assign block_id

def get_prefix_cache_hit_rate(self) -> float:
    """Always return -1.

    The allocator interface documents -1 as "not supported or disabled".
    """
    return -1


class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix
Expand Down
10 changes: 8 additions & 2 deletions vllm/core/block/prefix_caching_block.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Token blocks."""

from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple

from vllm.core.block.common import (CopyOnWriteTracker,
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
Expand Down Expand Up @@ -107,6 +106,8 @@ def __init__(
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly())

self.metric_data = CacheMetricData()

# Implements Block.Factory.
def _create_block(
self,
Expand Down Expand Up @@ -155,9 +156,11 @@ def allocate_immutable_block(self,

cached_block_id = self._cached_blocks.get(block.content_hash, None)
if cached_block_id is not None:
self.metric_data.query(hit=True)
block.block_id = cached_block_id
self._incr_refcount_cached_block(block)
return block
self.metric_data.query(hit=False)
self._block_pool.free_block(block)

# No cached block => Allocate a new block
Expand Down Expand Up @@ -404,6 +407,9 @@ def get_physical_block_id(self, absolute_id: int) -> int:
def all_block_ids(self) -> FrozenSet[int]:
    """Return ``all_block_ids`` of the wrapped ``_hashless_allocator``."""
    return self._hashless_allocator.all_block_ids

def get_prefix_cache_hit_rate(self) -> float:
    """Return the hit rate computed by this allocator's ``metric_data``."""
    return self.metric_data.get_hit_rate()

def is_block_cached(self, block: Block) -> bool:
assert block.content_hash is not None
if block.content_hash in self._cached_blocks:
Expand Down
31 changes: 27 additions & 4 deletions vllm/core/block_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Set, Tuple

from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.block.common import CacheMetricData
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
Expand Down Expand Up @@ -60,6 +61,11 @@ def contains_block(self, block_hash: int) -> bool:
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
    # Stub with no behavior here.
    # NOTE(review): presumably implementations re-key `block` under
    # `block_hash` — confirm against the concrete allocators.
    pass

@abstractmethod
def get_prefix_cache_hit_rate(self) -> float:
    """Prefix cache hit rate. -1 means not supported or disabled.

    Returns:
        The hit rate as a float, or -1 when prefix caching is not
        supported or is disabled.
    """
    pass


class CachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
Expand All @@ -85,6 +91,8 @@ def __init__(self,

self.default_hash_ctr = count()

self.cache_metric_data = CacheMetricData()

def allocate_block(self, block_hash: int,
num_hashed_tokens: int) -> PhysicalTokenBlock:
if self.current_num_blocks == self.num_blocks:
Expand All @@ -105,15 +113,17 @@ def allocate(self,
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
if block_hash is None:
block_hash = next(self.default_hash_ctr)

if block_hash in self.evictor:
assert block_hash not in self.cached_blocks
block = self.evictor.remove(block_hash)
assert block.ref_count == 0
self.cached_blocks[block_hash] = block
block.ref_count += 1
assert block.block_hash == block_hash
return block
if block_hash not in self.cached_blocks:

if block_hash in self.cached_blocks:
self.cache_metric_data.query(hit=True)
else:
self.cache_metric_data.query(hit=False)
self.cached_blocks[block_hash] = self.allocate_block(
block_hash, num_hashed_tokens)
block = self.cached_blocks[block_hash]
Expand Down Expand Up @@ -150,6 +160,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
del self.cached_blocks[old_hash]
self.cached_blocks[block_hash] = block

def get_prefix_cache_hit_rate(self) -> float:
    """Return the hit rate computed by ``cache_metric_data``."""
    return self.cache_metric_data.get_hit_rate()


class UncachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
Expand Down Expand Up @@ -209,6 +222,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
raise NotImplementedError(
"Invalid codepath for uncached block allocator.")

def get_prefix_cache_hit_rate(self) -> float:
    """Always return -1.

    The base allocator documents -1 as "not supported or disabled".
    """
    return -1


class BlockSpaceManagerV1(BlockSpaceManager):
"""Manages the mapping between logical and physical token blocks."""
Expand Down Expand Up @@ -705,3 +721,10 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup):
if self.enable_caching:
for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq)

def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the prefix cache hit rate of the allocator for ``device``.

    Args:
        device: ``Device.GPU`` or ``Device.CPU``.

    Raises:
        ValueError: if ``device`` is neither GPU nor CPU.
    """
    if device == Device.GPU:
        allocator = self.gpu_allocator
    elif device == Device.CPU:
        allocator = self.cpu_allocator
    else:
        raise ValueError(f"Invalid device: {device}")
    return allocator.get_prefix_cache_hit_rate()
3 changes: 3 additions & 0 deletions vllm/core/block_manager_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ def get_num_free_gpu_blocks(self) -> int:
def get_num_free_cpu_blocks(self) -> int:
    """Return the block allocator's free-block count for ``Device.CPU``."""
    return self.block_allocator.get_num_free_blocks(Device.CPU)

def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the block allocator's prefix cache hit rate for ``device``."""
    return self.block_allocator.get_prefix_cache_hit_rate(device)

def _can_swap(self,
seq_group: SequenceGroup,
device: Device,
Expand Down
4 changes: 4 additions & 0 deletions vllm/core/embedding_model_block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup
from vllm.utils import Device


class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
Expand Down Expand Up @@ -81,3 +82,6 @@ def get_common_computed_block_ids(self,

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
    # No-op in this block manager: `seq_group` is ignored.
    pass

def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Always return -1, regardless of ``device``.

    The block-manager interface documents -1 as "not supported or
    disabled".
    """
    return -1
15 changes: 8 additions & 7 deletions vllm/core/evictor_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,21 @@ def evict(self) -> Tuple[int, int]:
if len(self.free_table) == 0:
raise ValueError("No usable cache memory left")

evicted_block = next(iter(self.free_table.values()))
evicted_block_id = next(iter(self.free_table.keys()))
evicted_block, evicted_block_id = None, None
# The blocks with the lowest timestamps should be placed consecutively
# at the start of OrderedDict. Loop through all these blocks to
# find the one with maximum number of hashed tokens.
for _id, block in self.free_table.items():
if evicted_block is None:
evicted_block, evicted_block_id = block, _id
continue
if evicted_block.last_accessed < block.last_accessed:
break
if (evicted_block.last_accessed == block.last_accessed and
evicted_block.num_hashed_tokens < block.num_hashed_tokens):
evicted_block = block
evicted_block_id = _id
if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
evicted_block, evicted_block_id = block, _id

assert evicted_block is not None
assert evicted_block_id is not None
self.free_table.pop(evicted_block_id)

return evicted_block_id, evicted_block.content_hash
Expand All @@ -110,7 +112,6 @@ def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,

def update(self, block_id: int, last_accessed: float):
self.free_table[block_id].last_accessed = last_accessed
self.free_table.move_to_end(block_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why remove this line?
the free_table will be unordered if update op happens.


def remove(self, block_id: int):
if block_id not in self.free_table:
Expand Down
6 changes: 6 additions & 0 deletions vllm/core/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Tuple

from vllm.sequence import Sequence, SequenceGroup
from vllm.utils import Device


class AllocStatus(enum.Enum):
Expand Down Expand Up @@ -116,3 +117,8 @@ def get_common_computed_block_ids(
@abstractmethod
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass

@abstractmethod
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Prefix cache hit rate. -1 means not supported or disabled.

    Args:
        device: The device whose hit rate is requested.

    Returns:
        The hit rate as a float, or -1 when prefix caching is not
        supported or is disabled.
    """
    pass
5 changes: 4 additions & 1 deletion vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus)
from vllm.utils import PyObjectCache
from vllm.utils import Device, PyObjectCache

logger = init_logger(__name__)

Expand Down Expand Up @@ -447,6 +447,9 @@ def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0

def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the block manager's prefix cache hit rate for ``device``."""
    return self.block_manager.get_prefix_cache_hit_rate(device)

def get_num_unfinished_seq_groups(self) -> int:
    """Return the combined length of the waiting, running and swapped queues."""
    queues = (self.waiting, self.running, self.swapped)
    return sum(len(queue) for queue in queues)

Expand Down
Loading