24 changes: 16 additions & 8 deletions vllm/v1/core/kv_cache_manager.py
@@ -27,8 +27,8 @@ class KVCacheBlocks:
    `blocks[i][j]` refers to the i-th kv_cache_group
    and the j-th block of tokens. We don't use block of
    tokens as the outer dimension because it assumes all
    kv_cache_groups have the same number of blocks, which is true for now but
    will be broken if we want to give different block_size to different
kv_cache_groups in the future.
"""

@@ -184,9 +184,17 @@ def get_computed_blocks(self,

        if self.log_stats:
            assert self.prefix_cache_stats is not None
-            self.prefix_cache_stats.requests += 1
-            self.prefix_cache_stats.queries += request.num_tokens
-            self.prefix_cache_stats.hits += num_new_computed_tokens
+            if request.num_preemptions > 0:
+                # Previously preempted request
+                self.prefix_cache_stats.preempted_requests += 1
+                self.prefix_cache_stats.preempted_queries += request.num_tokens
+                self.prefix_cache_stats.preempted_hits += (
+                    num_new_computed_tokens)
+            else:
+                # New request
+                self.prefix_cache_stats.requests += 1
+                self.prefix_cache_stats.queries += request.num_tokens
+                self.prefix_cache_stats.hits += num_new_computed_tokens

return KVCacheBlocks(computed_blocks), num_new_computed_tokens
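
To make the new accounting concrete, here is a minimal, self-contained sketch of the branch above. The field names mirror this PR; the `record_lookup` helper and the sample numbers are illustrative, not vLLM code.

```python
from dataclasses import dataclass


@dataclass
class PrefixCacheStats:
    # Counters for requests seen for the first time.
    requests: int = 0
    queries: int = 0
    hits: int = 0
    # Counters for requests re-entering after preemption.
    preempted_requests: int = 0
    preempted_queries: int = 0
    preempted_hits: int = 0


def record_lookup(stats: PrefixCacheStats, num_tokens: int,
                  num_computed_tokens: int, num_preemptions: int) -> None:
    """Route one cache lookup into the new-request or preempted counters."""
    if num_preemptions > 0:
        # A resumed request typically re-hits blocks it cached before being
        # preempted, which would inflate the aggregate hit rate if mixed in.
        stats.preempted_requests += 1
        stats.preempted_queries += num_tokens
        stats.preempted_hits += num_computed_tokens
    else:
        stats.requests += 1
        stats.queries += num_tokens
        stats.hits += num_computed_tokens


stats = PrefixCacheStats()
record_lookup(stats, num_tokens=100, num_computed_tokens=16, num_preemptions=0)
record_lookup(stats, num_tokens=100, num_computed_tokens=96, num_preemptions=1)
assert (stats.hits, stats.preempted_hits) == (16, 96)
```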

@@ -209,10 +217,10 @@ def allocate_slots(
already been computed locally (i.e. new_computed_blocks).
num_new_computed_tokens: The number of new computed tokens just
hitting the prefix caching, excluding external tokens.
new_computed_blocks: The cached blocks for the above new computed
tokens.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
delay_cache_blocks: Whether to skip caching the blocks. This is
used by P/D when allocating blocks used in a KV transfer
@@ -365,7 +373,7 @@ def get_num_common_prefix_blocks(
requests in the current step.

Returns:
list[int]: The number of common prefix blocks for each kv cache
group.
"""
assert request.status == RequestStatus.RUNNING
66 changes: 34 additions & 32 deletions vllm/v1/core/sched/scheduler.py
@@ -251,46 +251,48 @@ def schedule(self) -> SchedulerOutput:
req_index += 1
continue

+            # Schedule newly needed KV blocks for the request.
             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
                     num_new_tokens,
                     num_lookahead_tokens=self.num_lookahead_tokens)
-                if new_blocks is None:
-                    # The request cannot be scheduled.
-                    # Preempt the lowest-priority request.
-                    if self.policy == SchedulingPolicy.PRIORITY:
-                        preempted_req = max(
-                            self.running,
-                            key=lambda r: (r.priority, r.arrival_time),
-                        )
-                        self.running.remove(preempted_req)
-                        if preempted_req in scheduled_running_reqs:
-                            scheduled_running_reqs.remove(preempted_req)
-                    else:
-                        preempted_req = self.running.pop()
-
-                    self.kv_cache_manager.free(preempted_req)
-                    self.encoder_cache_manager.free(preempted_req)
-                    preempted_req.status = RequestStatus.PREEMPTED
-                    preempted_req.num_computed_tokens = 0
-                    if self.log_stats:
-                        preempted_req.record_event(
-                            EngineCoreEventType.PREEMPTED, scheduled_timestamp)
-
-                    self.waiting.prepend_request(preempted_req)
-                    preempted_reqs.append(preempted_req)
-                    if preempted_req == request:
-                        # No more request to preempt.
-                        can_schedule = False
-                        break
-                else:
-                    # The request can be scheduled.
-                    can_schedule = True
-                    break
-            if not can_schedule:
-                break
+
+                if new_blocks is not None:
+                    # The request can be scheduled.
+                    break
+
+                # The request cannot be scheduled.
+                # Preempt the lowest-priority request.
+                if self.policy == SchedulingPolicy.PRIORITY:
+                    preempted_req = max(
+                        self.running,
+                        key=lambda r: (r.priority, r.arrival_time),
+                    )
+                    self.running.remove(preempted_req)
+                    if preempted_req in scheduled_running_reqs:
+                        scheduled_running_reqs.remove(preempted_req)
+                else:
+                    preempted_req = self.running.pop()
+
+                self.kv_cache_manager.free(preempted_req)
+                self.encoder_cache_manager.free(preempted_req)
+                preempted_req.status = RequestStatus.PREEMPTED
+                preempted_req.num_computed_tokens = 0
+                preempted_req.num_preemptions += 1

Member Author commented on the line above: this line is the only actual change in this file; all other changes are code-style changes.

+                if self.log_stats:
+                    preempted_req.record_event(EngineCoreEventType.PREEMPTED,
+                                               scheduled_timestamp)
+
+                self.waiting.prepend_request(preempted_req)
+                preempted_reqs.append(preempted_req)
+                if preempted_req == request:
+                    # No more request to preempt. Cannot schedule this request.
+                    break
+
+            if new_blocks is None:
+                # Cannot schedule this request.
+                break
             assert new_blocks is not None

# Schedule the request.
scheduled_running_reqs.append(request)
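
The restructured loop now reads: try to allocate, break on success, otherwise preempt the lowest-priority running request and retry, giving up once the request would have to preempt itself. A simplified, runnable sketch of that control flow, with stub types standing in for the real scheduler state (`Req` and `try_allocate` are inventions for this sketch):

```python
from dataclasses import dataclass


@dataclass
class Req:
    priority: int        # lower value = higher priority, as in vLLM
    arrival_time: float
    num_preemptions: int = 0


def try_allocate(free_blocks: list[int], need: int) -> list[int] | None:
    """Stub allocator: hand out `need` block IDs, or None if not enough free."""
    if len(free_blocks) < need:
        return None
    return [free_blocks.pop() for _ in range(need)]


def schedule_one(request: Req, running: list[Req], waiting: list[Req],
                 free_blocks: list[int], need: int) -> list[int] | None:
    """Allocate for `request` (assumed present in `running`), preempting
    lowest-priority requests until it fits or nothing is left to preempt."""
    while True:
        new_blocks = try_allocate(free_blocks, need)
        if new_blocks is not None:
            # The request can be scheduled.
            return new_blocks

        # Preempt the lowest-priority, latest-arriving running request.
        preempted = max(running, key=lambda r: (r.priority, r.arrival_time))
        running.remove(preempted)
        preempted.num_preemptions += 1     # the counter this PR introduces
        free_blocks.extend(range(need))    # stand-in for freeing its KV blocks
        waiting.insert(0, preempted)       # re-queue at the front
        if preempted is request:
            # No more request to preempt; give up on this one.
            return None


running = [Req(priority=0, arrival_time=0.0), Req(priority=9, arrival_time=1.0)]
waiting: list[Req] = []
free = [101, 102]  # two free blocks, but the request needs three
assert schedule_one(running[0], running, waiting, free, need=3) is not None
```

The real scheduler additionally resets `num_computed_tokens`, frees encoder-cache entries, records a `PREEMPTED` event, and only applies the priority rule under `SchedulingPolicy.PRIORITY` (FCFS pops from the tail instead); the sketch keeps just the retry/preempt skeleton.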
8 changes: 7 additions & 1 deletion vllm/v1/metrics/stats.py
@@ -17,13 +17,19 @@ class PrefixCacheStats:
"""Stores prefix cache hit statistics."""
# Whether reset_prefix_cache was invoked.
reset: bool = False
# The number of requests in this update.
# The number of new requests in this update.
requests: int = 0
# The number of queries in these requests. Note that "queries" here
# means the number of tokens that were queried from the cache.
queries: int = 0
# The number of hits in these requests.
hits: int = 0
# The number of previously preempted requests in this update.
preempted_requests: int = 0
# The `queries` number for preempted requests.
preempted_queries: int = 0
# The `hits` number for preempted requests.
preempted_hits: int = 0


@dataclass
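
A plausible downstream use of the split counters is to keep resumed requests out of the headline hit rate. The helper below is hypothetical, not part of this PR; it assumes the `PrefixCacheStats` fields above.

```python
def prefix_cache_hit_rates(stats: "PrefixCacheStats") -> tuple[float, float]:
    """Return (hit rate for new requests, hit rate for preempted requests)."""
    new_rate = stats.hits / stats.queries if stats.queries else 0.0
    resumed_rate = (stats.preempted_hits / stats.preempted_queries
                    if stats.preempted_queries else 0.0)
    return new_rate, resumed_rate
```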
3 changes: 3 additions & 0 deletions vllm/v1/request.py
@@ -115,6 +115,9 @@ def __init__(
        # indicates that the output is corrupted
        self.num_nans_in_logits = 0

+        # The number of times this request has been preempted by the scheduler.
+        self.num_preemptions = 0
+
        self.block_hashes: list[BlockHash] = []
        self.get_hash_new_full_blocks: Optional[Callable[
            [], list[BlockHash]]] = None