[PC] Refactor CB model runner to use vLLM's block pool #585
@@ -1,9 +1,9 @@
import math
import time
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from logging import DEBUG
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast

import torch
@@ -16,7 +16,10 @@
from vllm.model_executor.layers.pooler import ClassifierPooler, Pooler
from vllm.sampling_params import SamplingType
from vllm.utils import is_pin_memory_available
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.core.sched.output import CachedRequestData
from vllm.v1.core.single_type_kv_cache_manager import FullAttentionManager
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.pool.metadata import PoolingMetadata
@@ -812,8 +815,6 @@ def __init__(
        self.block_size = SpyrePlatform.get_block_size()

        # TODO: move to a KV cache manager
        self.req_ids2blocks: dict[str, deque[int]] = {}
        # max number of blocks needed (reserved) per request id
        self.req_ids2reserved_blocks: dict[str, int] = {}
@@ -863,7 +864,22 @@ def complete_warmup(self) -> None:
    def _set_blocks(self, num_blocks: int) -> None:
        # set number of available blocks and populate block_pool
        self.n_blocks = num_blocks - 1
-       self.block_pool = deque([i for i in range(1, self.n_blocks + 1)])
+       self.block_pool = BlockPool(num_gpu_blocks=self.n_blocks + 1,
+                                   enable_caching=False,
+                                   enable_kv_cache_events=False)
+       attn_spec = FullAttentionSpec(block_size=self.block_size,
+       # - Set the block size (in tokens) to the maximum sequence length
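For orientation, here is a rough sketch of the refactored method under discussion. It is only a sketch under assumptions: the FullAttentionSpec fields other than block_size are placeholder values (the call is truncated in the diff above and the exact field set depends on the vLLM version), the class stub stands in for the Spyre CB model runner, and the FullAttentionManager wiring is elided.

```python
import torch
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.kv_cache_interface import FullAttentionSpec


class _CBRunnerSketch:
    """Stand-in for the Spyre continuous-batching model runner."""

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size

    def _set_blocks(self, num_blocks: int) -> None:
        # The +1 below mirrors the old scheme, where block id 0 was excluded
        # from the free deque (deque(range(1, n_blocks + 1))).
        self.n_blocks = num_blocks - 1

        # vLLM's BlockPool replaces the hand-rolled deque of free block ids.
        # Prefix caching and KV cache events are disabled for now.
        self.block_pool = BlockPool(num_gpu_blocks=self.n_blocks + 1,
                                    enable_caching=False,
                                    enable_kv_cache_events=False)

        # Only block_size matters on the Spyre path; the remaining fields are
        # dummy values (see the review thread below).
        attn_spec = FullAttentionSpec(block_size=self.block_size,
                                      num_kv_heads=1,
                                      head_size=1,
                                      dtype=torch.float16)
        # attn_spec is then used to build a FullAttentionManager on top of
        # self.block_pool (constructor arguments omitted in this sketch).
        _ = attn_spec
```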
Okay, what happens if we change
block_size = self.vllm_config.cache_config.block_size
to
block_size = SpyrePlatform.get_block_size()
in get_kv_cache_spec()? Does that solve it? Or does that not work because the upstream scheduler calls get_kv_cache_spec() and we require it to return the max model length to disable vLLM's paged attention scheduling?
Yes, the engine core calls get_kv_cache_spec() and it has to return the max model length :/
Regarding "We set the block size to the max model len in platform.py to disable vllm's paged attention scheduling": what does this mean exactly? I'm wondering why we need to do this, but I agree it is probably better to address it as a follow-up and keep this PR nicely scoped.
For reference, there is a comment in platform.py about this (vllm-spyre/vllm_spyre/platform.py, line 182 in 7e03127):
# To disable any paged attention ops in the base scheduler, we:
@tdoublep I believe this was a hack introduced by @joerunde (please correct me if I am wrong). In our Spyre scheduler we call super().schedule(), which calls the upstream vLLM scheduler. By setting the block size to the max model length we ensure the upstream policy does not interfere with our custom logic in the plugin.
I see, thanks. Let's keep it like that for now.
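To make the platform.py trick more concrete, here is an illustrative sketch. The helper name is made up and the real logic lives in vllm_spyre/platform.py; only the two config attributes shown are assumed.

```python
from vllm.config import VllmConfig


def _disable_upstream_block_accounting(vllm_config: VllmConfig) -> None:
    # Hypothetical helper, not the actual platform.py code. With
    # block_size == max_model_len, every request fits into a single upstream
    # "block", so the base scheduler's paged-attention block accounting never
    # constrains the plugin's custom scheduling logic when super().schedule()
    # is called.
    vllm_config.cache_config.block_size = vllm_config.model_config.max_model_len
```

The Spyre-specific block size used by the model runner then comes from SpyrePlatform.get_block_size() rather than from cache_config.block_size, as in the diff above.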
I'm wondering why we need to set these to arbitrary values. My understanding is that we are only really using the attention spec for passing the block size to the FullAttentionManager. Perhaps we could create a SpyreAttentionSpec or something without these unnecessary arguments. Just a thought, it doesn't need to be addressed before merging.
Yes, good point.
Actually, the FullAttentionManager has an assertion to test that the KVCacheSpec is of type FullAttentionSpec. So we'll have to go with the dummy values for now. :/
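Purely as a hypothetical sketch of how the SpyreAttentionSpec idea could coexist with that assertion, assuming the check is isinstance-based: a subclass of FullAttentionSpec would still pass it, and a small factory could hide the dummy values from call sites. None of this is in the PR, and the exact required fields depend on the vLLM version.

```python
import torch
from vllm.v1.kv_cache_interface import FullAttentionSpec


class SpyreAttentionSpec(FullAttentionSpec):
    """Hypothetical: still a FullAttentionSpec, so an
    isinstance(spec, FullAttentionSpec) assertion passes."""

    @classmethod
    def for_block_size(cls, block_size: int) -> "SpyreAttentionSpec":
        # Dummy values for fields the Spyre code path never reads; the exact
        # required field set depends on the vLLM version in use.
        return cls(block_size=block_size,
                   num_kv_heads=1,
                   head_size=1,
                   dtype=torch.float16)
```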
I was a bit confused reading the code because I assumed this dictionary mapped the request id to the list of block IDs that were reserved, and I couldn't understand how that interacted with the KV cache manager. Now I understand it is just the number of reserved blocks. Maybe we could rename it to something like req_ids2num_reserved_blocks (as a follow-up).
Could we add a comment to explain this formula (especially the req_state.left_padding % self.block_size bit)?
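The formula itself isn't visible in this excerpt, so the following is only one plausible reading, with made-up numbers, of why a left_padding % block_size term would appear: full blocks of left padding can presumably be mapped to a shared padding block and need no reservation of their own, while the remainder shares a block with real tokens and therefore costs space.

```python
import math

# Hypothetical values (tokens).
block_size = 64
left_padding = 150       # tokens of left padding for this request
prompt_len = 100         # real prompt tokens
max_new_tokens = 20      # worst-case decode length to reserve for

# 2 full blocks of padding (128 tokens) need no reservation of their own;
# only the remainder occupies space in a block that also holds real tokens.
padding_remainder = left_padding % block_size            # 150 % 64 == 22

reserved_blocks = math.ceil(
    (padding_remainder + prompt_len + max_new_tokens) / block_size)
print(reserved_blocks)   # ceil(142 / 64) == 3
```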
It seems to me that self.req_ids2reserved_blocks is the last thing related to block management left in the model runner. As far as I can tell, there is no concept of reserved blocks in the FullAttentionManager. We could potentially derive a custom class from FullAttentionManager that adds req_ids2reserved_blocks, as in the sketch below. Not sure if this is something we want to do. Advantage: all block management would then happen in the KV cache manager (rather than being split between the model runner and the KV cache manager). Downside: we need a custom class (which only adds one field, though). WDYT?
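As a rough sketch of that option (class name hypothetical, constructor signature left untouched so it tracks whatever vLLM's FullAttentionManager expects):

```python
from vllm.v1.core.single_type_kv_cache_manager import FullAttentionManager


class SpyreFullAttentionManager(FullAttentionManager):
    """Hypothetical subclass that moves the reserved-blocks bookkeeping out
    of the model runner and into the KV cache manager."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Number of blocks reserved per request id (a count, not the block
        # ids themselves), mirroring req_ids2reserved_blocks in the runner.
        self.req_ids2num_reserved_blocks: dict[str, int] = {}
```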
Yes, we could do something like that. Actually, I was going to ask you why we even have a concept of reserved blocks. Is it because we don't support preemption?
Note: the volumetric constraint is probably always stricter than the available-blocks constraint. We need to verify this and resolve the question in a future PR. @yannicks1
I don't see why we couldn't support pre-emption actually, but for now let's aim to keep the behaviour the same w.r.t reserved blocks.
We are actively discussing this. I agree that we should not change behavior in this PR; this PR only refactors and integrates upstream code.
For #586 and follow-up PRs:
The current implementation in #586 (still using the reserved-block concept) does not consider blocks with reference counts > 1 (prefix hits), either in the scheduler or in the model runner (which is what actually "reserves" the number of blocks).
Since we opted for a prefix-caching-unaware scheduler, the minimal thing we should do (probably in #586) is to modify the model runner to consider prefix hits when reserving blocks. The scheduler then remains unaware of prefix hits when making decisions for a new sequence, but the total number of available blocks becomes less conservative because it accounts for the duplicates in the existing decode batch (blocks with reference count > 1).
In a next step (follow-up PR) we can remove the concept of reserved blocks entirely. I believe this should be doable by proving that the volumetric constraint is always stricter than the number-of-blocks constraint; if that does not hold, we could indeed support preemption.
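As a purely illustrative sketch of the "consider prefix hits when reserving" direction (the helper and the exact ref-count threshold are assumptions, not what #586 implements):

```python
from vllm.v1.core.kv_cache_utils import KVCacheBlock


def blocks_newly_needed(allocated_blocks: list[KVCacheBlock]) -> int:
    """Count only blocks that are not already resident for the running
    decode batch: a block with ref_cnt > 1 is shared (a prefix hit), so it
    arguably should not count against the reservation budget. The exact
    threshold depends on when the reference counts are read."""
    return sum(1 for block in allocated_blocks if block.ref_cnt <= 1)
```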