6 changes: 3 additions & 3 deletions tests/e2e/test_chunked_prefill_tkv_steps.py
@@ -318,7 +318,7 @@ def test_decode_padding_to_same_block(model: ModelInfo, max_model_len: int,
computed_tokens_dict = computed_tokens()

# Save the number of free blocks to compare once we allocate a new one
initial_free_blocks = len(runner.block_pool)
initial_free_blocks = runner.block_pool.get_num_free_blocks()

# Both requests are still in the first block
# Scheduler schedules 1 token each
@@ -352,7 +352,7 @@ def test_decode_padding_to_same_block(model: ModelInfo, max_model_len: int,
# 🌶️🌶️🌶️ short prompt gets padded, it's now the longest sequence
assert output.tkv == short_prompt_len + steps + 64
# We should have allocated only one new block for the long prompt entering
assert len(runner.block_pool) == initial_free_blocks - 1
assert runner.block_pool.get_num_free_blocks() == initial_free_blocks - 1
computed_tokens_dict = computed_tokens()

# The shorter request is now at the second block boundary (tkv = 128), so we
@@ -371,4 +371,4 @@ def test_decode_padding_to_same_block(model: ModelInfo, max_model_len: int,
# 🌶️🌶️🌶️ short prompt padding removed again, tkv is back to long + steps
assert output.tkv == long_prompt_len + steps
# One more real block was allocated for the short request
assert len(runner.block_pool) == initial_free_blocks - 2
assert runner.block_pool.get_num_free_blocks() == initial_free_blocks - 2
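The assertions above now count free blocks through vLLM's BlockPool API instead of taking the length of a plain deque. A minimal sketch of the accounting they rely on, assuming the BlockPool arguments used later in this PR and that the pool holds back block 0 as its internal null/padding block (which would be why the runner passes n_blocks + 1):

from vllm.v1.core.block_pool import BlockPool

# 8 blocks usable by requests, plus the pool's reserved block 0
pool = BlockPool(num_gpu_blocks=9,
                 enable_caching=False,
                 enable_kv_cache_events=False)
initial_free = pool.get_num_free_blocks()  # replaces len(...) on the old deque
# after the kv cache manager hands one block to a request:
# assert pool.get_num_free_blocks() == initial_free - 1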
21 changes: 14 additions & 7 deletions tests/scheduling_utils.py
@@ -220,11 +220,18 @@ def check_scheduler_inference_steps(
n_blocks = (engine_core.model_executor.driver_worker.worker.
model_runner.n_blocks)
n_reserved_blocks = n_blocks - scheduler.n_free_blocks
req_ids2blocks = (engine_core.model_executor.driver_worker.worker.
model_runner.req_ids2blocks)
req_ids2reserved_blocks = (

kv_cache_manager = (engine_core.model_executor.driver_worker.
worker.model_runner.kv_cache_manager)

req_ids2blocks = {
req_id: [block.block_id for block in blocks]
for req_id, blocks in kv_cache_manager.req_to_blocks.items()
if blocks
}
req_ids2num_reserved_blocks = (
engine_core.model_executor.driver_worker.worker.model_runner.
req_ids2reserved_blocks)
req_ids2num_reserved_blocks)
n_used_blocks = sum(
[len(blocks) for blocks in req_ids2blocks.values()])

@@ -242,16 +249,16 @@ def check_scheduler_inference_steps(
), f"Step {step}, n_used_blocks: {n_used_blocks}"

assert DISABLE_ASSERTS or len(req_ids2blocks) == len(
req_ids2reserved_blocks)
req_ids2num_reserved_blocks)
for req_id in req_ids2blocks:
# current number of used blocks should be at most the reserved count
assert (DISABLE_ASSERTS or len(req_ids2blocks[req_id])
<= req_ids2reserved_blocks[req_id])
<= req_ids2num_reserved_blocks[req_id])
# update requested/reserved blocks to check in last step
# Note: overwrite and not max
# because of reduce_left_padding()
requested_blocks[req_id] = len(req_ids2blocks[req_id])
reserved_blocks[req_id] = req_ids2reserved_blocks[req_id]
reserved_blocks[req_id] = req_ids2num_reserved_blocks[req_id]

# last step: check that sequences used all their reserved blocks
# Note: no early stopping, all sequences produce max_num_tokens
133 changes: 83 additions & 50 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -1,9 +1,9 @@
import math
import time
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from logging import DEBUG
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast

import torch
@@ -16,7 +16,10 @@
from vllm.model_executor.layers.pooler import ClassifierPooler, Pooler
from vllm.sampling_params import SamplingType
from vllm.utils import is_pin_memory_available
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.core.sched.output import CachedRequestData
from vllm.v1.core.single_type_kv_cache_manager import FullAttentionManager
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.pool.metadata import PoolingMetadata
@@ -812,10 +815,8 @@ def __init__(

self.block_size = SpyrePlatform.get_block_size()

# TODO: move to a KV cache manager
self.req_ids2blocks: dict[str, deque[int]] = {}
# max number of blocks needed (reserved) per request id
self.req_ids2reserved_blocks: dict[str, int] = {}
self.req_ids2num_reserved_blocks: dict[str, int] = {}

self.tkv: int = 0

@@ -863,7 +864,29 @@ def complete_warmup(self) -> None:
def _set_blocks(self, num_blocks: int) -> None:
# set number of available blocks and populate block_pool
self.n_blocks = num_blocks - 1
self.block_pool = deque([i for i in range(1, self.n_blocks + 1)])
self.block_pool = BlockPool(num_gpu_blocks=self.n_blocks + 1,
enable_caching=False,
enable_kv_cache_events=False)
attn_spec = FullAttentionSpec(
block_size=self.block_size,
# dummy values
num_kv_heads=1,
head_size=1,
dtype=torch.float16)
self.kv_cache_manager = FullAttentionManager(
kv_cache_spec=attn_spec,
block_pool=self.block_pool,
# Currently don't support models with more than one
# attention type, e.g. full and sliding window, so
# there is only one group.
kv_cache_group_id=0,
# We don't support DCP
# https://docs.vllm.ai/en/latest/serving/context_parallel_deployment/#decode-context-parallel
dcp_world_size=1,
)

def _get_blocks(self, request_id: str) -> list[KVCacheBlock]:
return self.kv_cache_manager.req_to_blocks[request_id]
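# A minimal usage sketch of the wiring above, assuming the vLLM v1 APIs exactly
# as they are imported and called in this file (BlockPool, FullAttentionSpec,
# FullAttentionManager); the block count, block size and request id below are
# illustrative only:
#
#     pool = BlockPool(num_gpu_blocks=9, enable_caching=False,
#                      enable_kv_cache_events=False)
#     spec = FullAttentionSpec(block_size=64, num_kv_heads=1, head_size=1,
#                              dtype=torch.float16)
#     manager = FullAttentionManager(kv_cache_spec=spec, block_pool=pool,
#                                    kv_cache_group_id=0, dcp_world_size=1)
#     blocks = manager.allocate_new_blocks("req-0", 100)  # ceil(100/64) = 2 blocks
#     assert len(manager.req_to_blocks["req-0"]) == 2
#     manager.free("req-0")  # both blocks return to the shared pool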

def get_total_spyre_blocks(self) -> int:
"""Returns the total number of KV cache blocks available for spyre.
@@ -912,12 +935,13 @@ def update_states(self, scheduler_output):
# TODO: move to kv cache manager
# Continuous batching: free blocks
for req_id in scheduler_output.finished_req_ids:
if blocks_to_free := self.req_ids2blocks.pop(req_id, None):
if logger.isEnabledFor(DEBUG) and (blocks_to_free :=
self._get_blocks(req_id)):
logger.debug("Freeing request id: %s", req_id)
self.req_ids2reserved_blocks.pop(req_id)
for block_id in blocks_to_free:
logger.debug("Freeing block with id: %s", block_id)
self.block_pool.append(block_id)
for block in blocks_to_free:
logger.debug("Freeing block with id: %s", block.block_id)
self.req_ids2num_reserved_blocks.pop(req_id, None)
self.kv_cache_manager.free(req_id)

def _prepare_prompt(
self,
@@ -930,7 +954,7 @@ def _prepare_prompt(
req_id = request.req_id
prompt_token_ids = request.prompt_token_ids
sampling_params = request.sampling_params
is_new_batch = len(self.req_ids2blocks) == 0
is_new_batch = len(self.req_ids2num_reserved_blocks) == 0
prompt_len = len(prompt_token_ids)

# make sure that the current tkv of the decode batch is greater or
@@ -989,19 +1013,20 @@ def _prepare_prompt(
(self.tkv - len(prompt_token_ids)) / self.block_size)
n_reserved_blocks = math.ceil(
n / self.block_size) - n_fully_padded_blocks
self.req_ids2reserved_blocks[req_id] = n_reserved_blocks
self.req_ids2num_reserved_blocks[req_id] = n_reserved_blocks

# filling block table and slot mapping
blocks = []
slots = []
for pos_i in range(right_padding_tkv):
if pos_i % self.block_size == 0:
block_number = self.block_pool.popleft()
blocks.append(block_number)
block_offset = pos_i % self.block_size
slot = block_number * self.block_size + block_offset
slots.append(slot)
self.req_ids2blocks[req_id] = deque(blocks)

blocks = self.kv_cache_manager.allocate_new_blocks(
req_id, right_padding_tkv)

block_offsets = [block.block_id * self.block_size for block in blocks]
slot_mapping = torch.arange(self.block_size,
dtype=torch.int64).repeat(len(blocks))
slot_mapping += torch.tensor(block_offsets,
dtype=torch.int64).repeat_interleave(
self.block_size)
slot_mapping.unsqueeze_(0)
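# Worked example of the slot-mapping arithmetic above (illustrative ids,
# block_size = 4): for allocated blocks with ids [2, 5], block_offsets is
# [8, 20] and
#     torch.arange(4).repeat(2)                  -> [0, 1, 2, 3, 0, 1, 2, 3]
#     torch.tensor([8, 20]).repeat_interleave(4) -> [8, 8, 8, 8, 20, 20, 20, 20]
# so their sum is [8, 9, 10, 11, 20, 21, 22, 23], i.e. every position gets
# slot = block_id * block_size + offset, just as the removed per-position loop
# computed it.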

# Add new request to the cached states.
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
@@ -1029,7 +1054,6 @@ def _prepare_prompt(
self.prefill_batch.refresh_metadata()

self.model.indices = torch.ones(1, dtype=torch.bool, device='cpu')
slot_mapping = torch.tensor([slots], dtype=torch.int64)
prompt_token_ids_tensor = torch.tensor(prompt_token_ids,
dtype=torch.long,
device=torch.device("cpu"))
@@ -1049,7 +1073,7 @@ def _prepare_prompt(
current_tkv_mask = None
# left padding info is stored in CachedRequestState of self.requests
left_padded_prompt_mask = None
# block table is stored in self.req_ids2blocks (only passed for decode)
# block table is only passed for decode
block_table = None

model_inputs = SamplingForwardInputs(
@@ -1094,24 +1118,35 @@ def _prepare_decode(
for req_id in req_ids:
# adding new blocks if needed
if self.tkv % self.block_size == 0:
self.req_ids2blocks[req_id].append(self.block_pool.popleft())
n_blocks = max(n_blocks, len(self.req_ids2blocks[req_id]))
req_state: SamplingRequestState = self.requests[req_id]
# We want to allocate the block for the next token.
# So we need to compute the number of tokens that the
# kv_cache_manager knows about: the number of computed
# tokens so far plus any intra-block padding.
total_tokens = req_state.left_padding % self.block_size \
+ req_state.num_computed_tokens + 1
Reviewer comment on lines +1126 to +1127:
Could we add a comment to explain this formula (especially the req_state.left_padding % self.block_size bit)?

blocks = self.kv_cache_manager.allocate_new_blocks(
req_id, total_tokens)
assert len(blocks) == 1
n_blocks = max(n_blocks, len(self._get_blocks(req_id)))
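# Illustrative numbers for the total_tokens formula above (block_size = 64):
# with left_padding = 70 and num_computed_tokens = 122 (tkv = 192, a block
# boundary), the manager accounts for 70 % 64 + 122 = 6 + 122 = 128 slots,
# i.e. exactly two full blocks; the remaining 64 tokens of left padding live
# in the shared padding block 0 and are never allocated. Requesting
# 128 + 1 = 129 slots therefore crosses a block boundary, so exactly one new
# block is handed out, which the assert checks.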

for req_id in req_ids:
# TODO: Will this always just be one token ID if there's no spec
# or jump decoding?

req_state: SamplingRequestState = self.requests[req_id]
req_state = self.requests[req_id]

# filling block table with padding blocks to make it rectangular
# Note: the padding block id 0 here is chosen arbitrarily, it can
# be any allocated block id on the Spyre card (it has to be in range
# [0, self.n_blocks - 1]). Further, it can also be a block id that holds
# actual KV cache for another (or the same) sequence.
blocks = self.req_ids2blocks[req_id].copy()
for i in range(n_blocks - len(self.req_ids2blocks[req_id])):
blocks.appendleft(0)
block_table.append(blocks)
blocks = self._get_blocks(req_id)
left_padding_blocks = n_blocks - len(blocks)

req_block_ids = [0] * left_padding_blocks + \
[block.block_id for block in blocks]
block_table.append(req_block_ids)

# slot_mapping for all blocks of sequence
start_slot = block_table[-1][-1] * self.block_size
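# Small example of the rectangular block table above (illustrative ids,
# n_blocks = 3): a request holding blocks with ids [7, 3] contributes the row
# [0, 7, 3], the leading 0 being the arbitrary padding block, so every row has
# n_blocks entries. start_slot then points at the first slot of the row's last
# block, i.e. 3 * block_size.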
@@ -1287,7 +1322,7 @@ def get_req_id_to_index(self, is_prefill: bool) -> dict[str, int]:
return req_id_to_index

def get_n_free_blocks(self) -> int:
return self.n_blocks - sum(self.req_ids2reserved_blocks.values())
return self.n_blocks - sum(self.req_ids2num_reserved_blocks.values())

def no_prompt_logprob(self, is_prefill: bool) -> bool:
if is_prefill:
@@ -1856,10 +1891,11 @@ def _prepare_chunked_prefill(self, req_id: str):
input_positions_np = input_positions.numpy()

# create block table tensor
blocks = [0] * (left_padding // self.block_size) + list(
self.req_ids2blocks[req_id])
blocks = self._get_blocks(req_id)
block_end = (chunk_i + 1) * self.chunk_blocks_count
block_table = torch.tensor(blocks[:block_end],
block_ids = [0] * (left_padding // self.block_size) + \
[block.block_id for block in blocks]
block_table = torch.tensor(block_ids[:block_end],
dtype=torch.int64).unsqueeze(0)
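# Worked example of the chunked-prefill block table above (illustrative
# values, block_size = 64, chunk_blocks_count = 2): with left_padding = 64 one
# padding entry of block id 0 is prepended (64 // 64 = 1), so a request whose
# allocated block ids are [4, 9, 6] yields block_ids = [0, 4, 9, 6];
# chunk_i = 0 keeps block_ids[:2] = [0, 4] and chunk_i = 1 keeps
# block_ids[:4] = [0, 4, 9, 6].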

slot_mapping = []
@@ -1973,8 +2009,10 @@ def _prepare_decode(
# adding new blocks if needed
req_state = self.requests[req_id]
if req_state.num_computed_tokens % self.block_size == 0:
self.req_ids2blocks[req_id].append(self.block_pool.popleft())
max_n_blocks = max(max_n_blocks, len(self.req_ids2blocks[req_id]))
blocks = self.kv_cache_manager.allocate_new_blocks(
req_id, req_state.num_computed_tokens + 1)
assert len(blocks) == 1
max_n_blocks = max(max_n_blocks, len(self._get_blocks(req_id)))
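# Illustrative numbers for the allocation above (block_size = 64): when
# num_computed_tokens reaches 128 the request's two existing blocks are full
# (there is no token left padding on this path), so
# allocate_new_blocks(req_id, 129) must hand out exactly one new block for the
# next token, which the assert checks.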

# We'll calculate tkv on the fly: it is the max num computed tokens
# of a request, since there is no token left padding, only block padding
@@ -1990,13 +2028,11 @@ def _prepare_decode(
# be any allocated block id on the Spyre card (it has to be in range
# [0, self.n_blocks - 1]). Further, it can also be a block id that holds
# actual KV cache for another (or the same) sequence.
blocks = self.req_ids2blocks[req_id].copy()
left_pad_blocks_count = (max_n_blocks -
len(self.req_ids2blocks[req_id]))

for _ in range(left_pad_blocks_count):
blocks.appendleft(0)
block_table.append(blocks)
blocks = self._get_blocks(req_id)
left_pad_blocks_count = (max_n_blocks - len(blocks))
block_ids = [0]*left_pad_blocks_count + \
[block.block_id for block in blocks]
block_table.append(block_ids)

# slot_mapping for all blocks of sequence
start_slot = block_table[-1][-1] * self.block_size
@@ -2055,13 +2091,11 @@ def add_new_request(self, request: NewRequestData):
req_id = request.req_id
prompt_token_ids = request.prompt_token_ids
sampling_params = request.sampling_params
is_new_batch = len(self.req_ids2blocks) == 0
is_new_batch = len(self.req_ids2num_reserved_blocks) == 0
prompt_len = len(prompt_token_ids)

self.prefill_batch.clear_requests()

blocks_count = math.ceil(prompt_len / self.block_size)

# set the new tkv to the prompt length if starting a new decode batch
if is_new_batch:
self.tkv = prompt_len
@@ -2074,11 +2108,10 @@ def add_new_request(self, request: NewRequestData):
# subtract the padding blocks from the reserved blocks
n_reserved_blocks = math.ceil(total_tokens / self.block_size)

self.req_ids2reserved_blocks[req_id] = n_reserved_blocks
self.req_ids2num_reserved_blocks[req_id] = n_reserved_blocks

# allocate blocks
blocks = [self.block_pool.popleft() for _ in range(blocks_count)]
self.req_ids2blocks[req_id] = deque(blocks)
self.kv_cache_manager.allocate_new_blocks(req_id, prompt_len)

# Add new request to the cached states.
if sampling_params.sampling_type == SamplingType.RANDOM_SEED: