6 changes: 3 additions & 3 deletions tests/e2e/test_chunked_prefill_tkv_steps.py
@@ -318,7 +318,7 @@ def test_decode_padding_to_same_block(model: ModelInfo, max_model_len: int,
computed_tokens_dict = computed_tokens()

# Save the number of free blocks to compare once we allocate a new one
initial_free_blocks = len(runner.block_pool)
initial_free_blocks = runner.block_pool.get_num_free_blocks()

# Both requests are still in the first block
# Scheduler schedules 1 token each
@@ -352,7 +352,7 @@ def test_decode_padding_to_same_block(model: ModelInfo, max_model_len: int,
# 🌶️🌶️🌶️ short prompt gets padded, it's now the longest sequence
assert output.tkv == short_prompt_len + steps + 64
# We should have allocated only one new block for the long prompt entering
assert len(runner.block_pool) == initial_free_blocks - 1
assert runner.block_pool.get_num_free_blocks() == initial_free_blocks - 1
computed_tokens_dict = computed_tokens()

# The shorter request is now at the second block boundary (tkv = 128), so we
@@ -371,4 +371,4 @@ def test_decode_padding_to_same_block(model: ModelInfo, max_model_len: int,
# 🌶️🌶️🌶️ short prompt padding removed again, tkv is back to long + steps
assert output.tkv == long_prompt_len + steps
# One more real block was allocated for the short request
assert len(runner.block_pool) == initial_free_blocks - 2
assert runner.block_pool.get_num_free_blocks() == initial_free_blocks - 2
11 changes: 9 additions & 2 deletions tests/scheduling_utils.py
@@ -220,8 +220,15 @@ def check_scheduler_inference_steps(
n_blocks = (engine_core.model_executor.driver_worker.worker.
model_runner.n_blocks)
n_reserved_blocks = n_blocks - scheduler.n_free_blocks
req_ids2blocks = (engine_core.model_executor.driver_worker.worker.
model_runner.req_ids2blocks)

kv_cache_manager = (engine_core.model_executor.driver_worker.
worker.model_runner.kv_cache_manager)

req_ids2blocks = {
req_id: [block.block_id for block in blocks]
for req_id, blocks in kv_cache_manager.req_to_blocks.items()
if blocks
}
req_ids2reserved_blocks = (
engine_core.model_executor.driver_worker.worker.model_runner.
req_ids2reserved_blocks)
103 changes: 63 additions & 40 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -1,9 +1,9 @@
import math
import time
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from logging import DEBUG
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast

import torch
@@ -16,7 +16,10 @@
from vllm.model_executor.layers.pooler import ClassifierPooler, Pooler
from vllm.sampling_params import SamplingType
from vllm.utils import is_pin_memory_available
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.core.sched.output import CachedRequestData
from vllm.v1.core.single_type_kv_cache_manager import FullAttentionManager
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.pool.metadata import PoolingMetadata
@@ -812,8 +815,6 @@ def __init__(

self.block_size = SpyrePlatform.get_block_size()

# TODO: move to a KV cache manager
self.req_ids2blocks: dict[str, deque[int]] = {}
# max number of blocks needed (reserved) per request id
self.req_ids2reserved_blocks: dict[str, int] = {}
Collaborator
It seems to me that self.req_ids2reserved_blocks is the last thing related to block management left in the model runner. As far as I can tell, there is no concept of reserved blocks in the FullAttentionManager. We could potentially derive a custom class from FullAttentionManager that adds req_ids2reserved_blocks; I'm not sure if this is something we want to do. Advantage: all block management would then happen in the kv cache manager (not split between the model runner and the kv cache manager). Downside: we need a custom class (which only adds one field, though). WDYT?
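A minimal sketch of that idea, assuming we subclass FullAttentionManager (the class name and the reserve_blocks helper are hypothetical; free(req_id) and the fields mirror how the manager is already used in this PR):

from vllm.v1.core.single_type_kv_cache_manager import FullAttentionManager


class SpyreFullAttentionManager(FullAttentionManager):
    """Sketch: keep reserved-block accounting next to the real block accounting."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # max number of blocks needed (reserved) per request id
        self.req_ids2reserved_blocks: dict[str, int] = {}

    def reserve_blocks(self, req_id: str, n_reserved_blocks: int) -> None:
        # Record the upper bound of blocks this request may still need.
        self.req_ids2reserved_blocks[req_id] = n_reserved_blocks

    def free(self, req_id: str) -> None:
        # Drop the reservation together with the request's allocated blocks.
        self.req_ids2reserved_blocks.pop(req_id, None)
        super().free(req_id)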

Collaborator Author
Yes, we could do something like that. Actually, I was going to ask you why we even have a concept of reserved blocks. Is it because we don't support preemption?

Collaborator Author
Note: the volumetric constraint is probably always lower than the available number of blocks. We need to verify this information and resolve this question in a future PR. @yannicks1

Member
> Is it because we don't support preemption?

I don't see why we couldn't support preemption actually, but for now let's aim to keep the behaviour the same w.r.t. reserved blocks.

Collaborator
We are actively discussing this. I agree that we should not change behavior in this PR; this PR is only refactoring/integrating upstream code.

For #586 and follow-up PRs:
The current implementation in #586 (still using the reserved block concept) does not consider blocks with reference counts > 1 (prefix hits), either in the scheduler or in the model runner (which actually 'reserves' the number of blocks).

As we opted for a prefix-caching-unaware scheduler, the minimal thing we should do (probably in #586) is to modify the model runner to consider prefix hits when reserving the number of blocks. The scheduler is then still unaware of prefix hits when making decisions for a new sequence, but the total number of available blocks is less conservative, as it accounts for the duplicates in the existing decode batch (blocks with reference count > 2).

In a next step (follow-up PR) we can remove the concept of reserved blocks. I believe this should be doable by proving that the volumetric constraint is always stricter than the number-of-blocks constraint. If that is not given, we could indeed support preemption.
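A hedged sketch of the prefix-hit idea above: when reserving blocks for a new request, blocks already referenced by the existing decode batch would be discounted. The helper name is hypothetical and it assumes KVCacheBlock exposes a ref_cnt field; the exact accounting still needs to be worked out in the follow-up.

from vllm.v1.core.kv_cache_utils import KVCacheBlock


def effective_blocks_to_reserve(n_reserved_blocks: int,
                                allocated: list[KVCacheBlock]) -> int:
    # Blocks shared with other sequences in the decode batch do not consume
    # additional capacity, so they can be subtracted from the reservation.
    shared = sum(1 for block in allocated if block.ref_cnt > 1)
    return max(n_reserved_blocks - shared, 0)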


@@ -863,7 +864,22 @@ def complete_warmup(self) -> None:
def _set_blocks(self, num_blocks: int) -> None:
# set number of available blocks and populate block_pool
self.n_blocks = num_blocks - 1
self.block_pool = deque([i for i in range(1, self.n_blocks + 1)])
self.block_pool = BlockPool(num_gpu_blocks=self.n_blocks + 1,
enable_caching=False,
enable_kv_cache_events=False)
attn_spec = FullAttentionSpec(block_size=self.block_size,
Collaborator
You could get that via get_kv_cache_spec()... otherwise we have the same default values in two different places in the code, in case we have to change these arguments at some point in time.

Collaborator Author
Yes, I did that because the block size coming from get_kv_cache_spec() was the one from vllm, and it was 512, if I remember correctly. Maybe we could just set the vllm block size to 64 in platform.py and remove this duplication. What do you think?

Collaborator Author
Note: check if this is easy to do, otherwise open an issue to resolve this later.

Collaborator Author
Actually, this might be tricky to do for now. We set the block size to the max model len in platform.py to disable vllm's paged attention scheduling:

# - Set the block size (in tokens) to the maximum sequence length

Collaborator
Okay, what happens if we change

block_size = self.vllm_config.cache_config.block_size

to

block_size = SpyrePlatform.get_block_size()

in get_kv_cache_spec()? Does that solve it? Or does that not work because the upstream scheduler will call get_kv_cache_spec() and we require it to return the max model length to disable vllm's paged attention scheduling?

Collaborator Author
Yes, the engine core calls get_kv_cache_spec() and it has to return the max model length :/

Member
> We set the block size to the max model len in platform.py to disable vllm's paged attention scheduling

What does this mean exactly? I'm wondering why we need to do this, but I agree it is probably better to address it as a follow-up and keep this PR nicely scoped.

Collaborator Author
For reference, there is a comment in platform.py about this:

# To disable any paged attention ops in the base scheduler, we:

I haven't traced the full lineage of this code, but the last PR that touched it was #206.

Collaborator
@tdoublep I believe this was a hack introduced by @joerunde (please correct me if I am wrong). In our Spyre scheduler we call super().schedule(), which calls the upstream vllm scheduler. By setting the block size to the max model length we ensure the upstream policy does not interfere with our custom logic in the plugin.

Member
I see, thanks. Let's keep it like that for now.

num_kv_heads=1,
head_size=1,
dtype=torch.float16)
Member
Wondering why we need to set these to arbitrary values? My understanding is that we are only really using the attention spec for passing the block size to the FullAttentionManager. Perhaps we could create a SpyreAttentionSpec or something without these unnecessary arguments. This is just a thought; it doesn't need to be addressed before merging.

Collaborator Author
Yes, good point.

Collaborator Author
Actually, the FullAttentionManager has an assertion to test that the KVCacheSpec is of type FullAttentionSpec. So we'll have to go with the dummy values for now. :/
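Until a dedicated spec class exists, one hedged option is to at least centralize the dummy values in a small helper so they live in a single place (the helper name is hypothetical; the constructor arguments mirror the ones used in this diff):

import torch
from vllm.v1.kv_cache_interface import FullAttentionSpec


def make_spyre_attention_spec(block_size: int) -> FullAttentionSpec:
    # Only block_size is meaningful on Spyre; the remaining fields are
    # placeholders required by the FullAttentionSpec constructor.
    return FullAttentionSpec(block_size=block_size,
                             num_kv_heads=1,
                             head_size=1,
                             dtype=torch.float16)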

self.kv_cache_manager = FullAttentionManager(
kv_cache_spec=attn_spec,
block_pool=self.block_pool,
kv_cache_group_id=0,
dcp_world_size=1,
)

def _get_blocks(self, request_id: str) -> list[KVCacheBlock]:
return self.kv_cache_manager.req_to_blocks[request_id]

def get_total_spyre_blocks(self) -> int:
"""Returns the total number of KV cache blocks available for spyre.
@@ -912,12 +928,13 @@ def update_states(self, scheduler_output):
# TODO: move to kv cache manager
# Continuous batching: free blocks
for req_id in scheduler_output.finished_req_ids:
if blocks_to_free := self.req_ids2blocks.pop(req_id, None):
if logger.isEnabledFor(DEBUG) and (blocks_to_free :=
self._get_blocks(req_id)):
logger.debug("Freeing request id: %s", req_id)
self.req_ids2reserved_blocks.pop(req_id)
for block_id in blocks_to_free:
logger.debug("Freeing block with id: %s", block_id)
self.block_pool.append(block_id)
for block in blocks_to_free:
logger.debug("Freeing block with id: %s", block.block_id)
self.req_ids2reserved_blocks.pop(req_id, None)
self.kv_cache_manager.free(req_id)

def _prepare_prompt(
self,
@@ -930,7 +947,7 @@ def _prepare_prompt(
req_id = request.req_id
prompt_token_ids = request.prompt_token_ids
sampling_params = request.sampling_params
is_new_batch = len(self.req_ids2blocks) == 0
is_new_batch = len(self.req_ids2reserved_blocks) == 0
prompt_len = len(prompt_token_ids)

# make sure that the current tkv of the decode batch is greater or
@@ -992,16 +1009,17 @@ def _prepare_prompt(
self.req_ids2reserved_blocks[req_id] = n_reserved_blocks
Member
I was a bit confused reading the code because I assumed this dictionary mapped the request id to the list of block IDs that were reserved, and couldn't understand how that interacted with the KV cache manager. Now I understand it is just the number of reserved blocks. Maybe we could change the name to something like req_ids2num_reserved_blocks (as a follow-up).


# filling block table and slot mapping
blocks = []

blocks = self.kv_cache_manager.allocate_new_blocks(
req_id, right_padding_tkv)

slots = []
for pos_i in range(right_padding_tkv):
if pos_i % self.block_size == 0:
block_number = self.block_pool.popleft()
blocks.append(block_number)
block_idx = pos_i // self.block_size
block_offset = pos_i % self.block_size
block_number = blocks[block_idx].block_id
slot = block_number * self.block_size + block_offset
slots.append(slot)
self.req_ids2blocks[req_id] = deque(blocks)

# Add new request to the cached states.
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
@@ -1049,7 +1067,7 @@ def _prepare_prompt(
current_tkv_mask = None
# left padding info is stored in CachedRequestState of self.requests
left_padded_prompt_mask = None
# block table is stored in self.req_ids2blocks (only passed for decode)
# block table is only passed for decode
block_table = None

model_inputs = SamplingForwardInputs(
@@ -1094,24 +1112,31 @@ def _prepare_decode(
for req_id in req_ids:
# adding new blocks if needed
if self.tkv % self.block_size == 0:
self.req_ids2blocks[req_id].append(self.block_pool.popleft())
n_blocks = max(n_blocks, len(self.req_ids2blocks[req_id]))
req_state: SamplingRequestState = self.requests[req_id]
total_tokens = req_state.left_padding % self.block_size \
+ req_state.num_computed_tokens + 1
Comment on lines +1126 to +1127
Member
Could we add a comment to explain this formula (especially the req_state.left_padding % self.block_size bit)?
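One hedged reading of the formula, written as the kind of comment requested above (this is an interpretation based on the padding-block handling elsewhere in this file, to be confirmed by the authors):

# Tokens occupying the request's own blocks:
#  - whole blocks of left padding are mapped to the shared padding block 0 and
#    consume no real blocks, so only the remainder
#    (left_padding % block_size) sits in the request's first real block;
#  - num_computed_tokens are the tokens already written to the KV cache;
#  - +1 accounts for the token produced in this decode step.
total_tokens = req_state.left_padding % self.block_size \
    + req_state.num_computed_tokens + 1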

blocks = self.kv_cache_manager.allocate_new_blocks(
req_id, total_tokens)
assert len(blocks) == 1
n_blocks = max(n_blocks, len(self._get_blocks(req_id)))

for req_id in req_ids:
# TODO: Will this always just be one token ID if there's no spec
# or jump decoding?

req_state: SamplingRequestState = self.requests[req_id]
req_state = self.requests[req_id]

# filling block table with padding blocks to make it rectangular
# Note: the padding block id 0 here is chosen arbitrarily, it can
# be any allocated block id on the Sypre card (has to be in range
# [0, self.n_blocks - 1]). Further, it also be a block id that holds
# actual KV cache for another (or the same) sequence.
blocks = self.req_ids2blocks[req_id].copy()
for i in range(n_blocks - len(self.req_ids2blocks[req_id])):
blocks.appendleft(0)
block_table.append(blocks)
blocks = self._get_blocks(req_id)
left_padding_blocks = n_blocks - len(blocks)

req_block_ids = [0] * left_padding_blocks + \
[block.block_id for block in blocks]
block_table.append(req_block_ids)

# slot_mapping for all blocks of sequence
start_slot = block_table[-1][-1] * self.block_size
@@ -1856,10 +1881,11 @@ def _prepare_chunked_prefill(self, req_id: str):
input_positions_np = input_positions.numpy()

# create block table tensor
blocks = [0] * (left_padding // self.block_size) + list(
self.req_ids2blocks[req_id])
blocks = self._get_blocks(req_id)
block_end = (chunk_i + 1) * self.chunk_blocks_count
block_table = torch.tensor(blocks[:block_end],
block_ids = [0] * (left_padding // self.block_size) + \
[block.block_id for block in blocks]
block_table = torch.tensor(block_ids[:block_end],
dtype=torch.int64).unsqueeze(0)

slot_mapping = []
@@ -1973,8 +1999,10 @@ def _prepare_decode(
# adding new blocks if needed
req_state = self.requests[req_id]
if req_state.num_computed_tokens % self.block_size == 0:
self.req_ids2blocks[req_id].append(self.block_pool.popleft())
max_n_blocks = max(max_n_blocks, len(self.req_ids2blocks[req_id]))
blocks = self.kv_cache_manager.allocate_new_blocks(
req_id, req_state.num_computed_tokens + 1)
assert len(blocks) == 1
max_n_blocks = max(max_n_blocks, len(self._get_blocks(req_id)))

# We'll calculate tkv on the fly, it is the max num computed tokens
# of a request since there is no tokens left padding, only for blocks
@@ -1990,13 +2018,11 @@
# be any allocated block id on the Spyre card (has to be in range
# [0, self.n_blocks - 1]). Further, it can also be a block id that holds
# actual KV cache for another (or the same) sequence.
blocks = self.req_ids2blocks[req_id].copy()
left_pad_blocks_count = (max_n_blocks -
len(self.req_ids2blocks[req_id]))

for _ in range(left_pad_blocks_count):
blocks.appendleft(0)
block_table.append(blocks)
blocks = self._get_blocks(req_id)
left_pad_blocks_count = (max_n_blocks - len(blocks))
block_ids = [0]*left_pad_blocks_count + \
[block.block_id for block in blocks]
block_table.append(block_ids)

# slot_mapping for all blocks of sequence
start_slot = block_table[-1][-1] * self.block_size
@@ -2055,13 +2081,11 @@ def add_new_request(self, request: NewRequestData):
req_id = request.req_id
prompt_token_ids = request.prompt_token_ids
sampling_params = request.sampling_params
is_new_batch = len(self.req_ids2blocks) == 0
is_new_batch = len(self.req_ids2reserved_blocks) == 0
prompt_len = len(prompt_token_ids)

self.prefill_batch.clear_requests()

blocks_count = math.ceil(prompt_len / self.block_size)

# set the new tkv to the prompt length if starting a new decode batch
if is_new_batch:
self.tkv = prompt_len
Expand All @@ -2077,8 +2101,7 @@ def add_new_request(self, request: NewRequestData):
self.req_ids2reserved_blocks[req_id] = n_reserved_blocks

# allocate blocks
blocks = [self.block_pool.popleft() for _ in range(blocks_count)]
self.req_ids2blocks[req_id] = deque(blocks)
self.kv_cache_manager.allocate_new_blocks(req_id, prompt_len)

# Add new request to the cached states.
if sampling_params.sampling_type == SamplingType.RANDOM_SEED: