[PC] Refactor CB model runner to use vLLMs block pool #585
Merged

Commits (9):
5447dc5  Replace the block_pool list with the vLLM Block Pool (maxdebayser)
a43e072  manage padding blocks outside of block pool (maxdebayser)
6378294  Switch to Single Type KV Cache manager (maxdebayser)
f8dd1d2  fix linting problem (maxdebayser)
0296747  address review comments (maxdebayser)
8420160  Merge branch 'main' into integrate_block_pool (maxdebayser)
a491e7d  Merge branch 'main' into integrate_block_pool (maxdebayser)
f71df91  address review comments (maxdebayser)
f201675  revert bad change (maxdebayser)
@@ -1,9 +1,9 @@
import math
import time
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from logging import DEBUG
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast

import torch
@@ -16,7 +16,10 @@
from vllm.model_executor.layers.pooler import ClassifierPooler, Pooler
from vllm.sampling_params import SamplingType
from vllm.utils import is_pin_memory_available
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.core.sched.output import CachedRequestData
from vllm.v1.core.single_type_kv_cache_manager import FullAttentionManager
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.pool.metadata import PoolingMetadata
@@ -812,10 +815,8 @@ def __init__(
        self.block_size = SpyrePlatform.get_block_size()

-       # TODO: move to a KV cache manager
-       self.req_ids2blocks: dict[str, deque[int]] = {}
        # max number of blocks needed (reserved) per request id
-       self.req_ids2reserved_blocks: dict[str, int] = {}
+       self.req_ids2num_reserved_blocks: dict[str, int] = {}

        self.tkv: int = 0
@@ -863,7 +864,29 @@ def complete_warmup(self) -> None:
    def _set_blocks(self, num_blocks: int) -> None:
        # set number of available blocks and populate block_pool
        self.n_blocks = num_blocks - 1
-       self.block_pool = deque([i for i in range(1, self.n_blocks + 1)])
+       self.block_pool = BlockPool(num_gpu_blocks=self.n_blocks + 1,
+                                   enable_caching=False,
+                                   enable_kv_cache_events=False)
+       attn_spec = FullAttentionSpec(
+           block_size=self.block_size,
+           # dummy values
+           num_kv_heads=1,
+           head_size=1,
+           dtype=torch.float16)
+       self.kv_cache_manager = FullAttentionManager(
+           kv_cache_spec=attn_spec,
+           block_pool=self.block_pool,
+           # Currently don't support models with more than one
+           # attention type, e.g. full and sliding window, so
+           # there is only one group.
+           kv_cache_group_id=0,
+           # We don't support DCP
+           # https://docs.vllm.ai/en/latest/serving/context_parallel_deployment/#decode-context-parallel
+           dcp_world_size=1,
+       )
+
+   def _get_blocks(self, request_id: str) -> list[KVCacheBlock]:
+       return self.kv_cache_manager.req_to_blocks[request_id]

    def get_total_spyre_blocks(self) -> int:
        """Returns the total number of KV cache blocks available for spyre.
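For orientation, here is a minimal sketch of the allocate/free lifecycle that the rest of this diff builds on. It only uses the constructor arguments and calls visible in the hunks above and below (`BlockPool`, `FullAttentionSpec`, `FullAttentionManager`, `allocate_new_blocks`, `free`); the pool size, request id, and token count are made-up illustration values, and it assumes the vLLM version targeted by this PR.

```python
import torch

from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import FullAttentionManager
from vllm.v1.kv_cache_interface import FullAttentionSpec

block_size = 64
block_pool = BlockPool(num_gpu_blocks=32,        # made-up pool size
                       enable_caching=False,
                       enable_kv_cache_events=False)
manager = FullAttentionManager(
    kv_cache_spec=FullAttentionSpec(block_size=block_size,
                                    # dummy values, as in the diff above
                                    num_kv_heads=1,
                                    head_size=1,
                                    dtype=torch.float16),
    block_pool=block_pool,
    kv_cache_group_id=0,
    dcp_world_size=1,
)

# Allocate enough blocks to cover 100 tokens for a hypothetical request,
# inspect their ids, then return them to the pool.
blocks: list[KVCacheBlock] = manager.allocate_new_blocks("req-0", 100)
print([b.block_id for b in blocks])   # expect 2 block ids for block_size=64
manager.free("req-0")
```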
@@ -912,12 +935,13 @@ def update_states(self, scheduler_output):
        # TODO: move to kv cache manager
        # Continuous batching: free blocks
        for req_id in scheduler_output.finished_req_ids:
-           if blocks_to_free := self.req_ids2blocks.pop(req_id, None):
+           if logger.isEnabledFor(DEBUG) and (blocks_to_free :=
+                                              self._get_blocks(req_id)):
                logger.debug("Freeing request id: %s", req_id)
-               self.req_ids2reserved_blocks.pop(req_id)
-               for block_id in blocks_to_free:
-                   logger.debug("Freeing block with id: %s", block_id)
-                   self.block_pool.append(block_id)
+               for block in blocks_to_free:
+                   logger.debug("Freeing block with id: %s", block.block_id)
+           self.req_ids2num_reserved_blocks.pop(req_id, None)
+           self.kv_cache_manager.free(req_id)

    def _prepare_prompt(
        self,
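The freeing loop above gates the per-block listing behind `logger.isEnabledFor(DEBUG)` combined with a walrus assignment, so the block lookup and per-block log lines are skipped unless debug logging is on, while the `pop` and `free` calls still run for every finished request. The same pattern in isolation, with made-up stand-ins for the runner's helpers:

```python
import logging
from logging import DEBUG

logger = logging.getLogger("example")


def get_blocks(req_id: str) -> list[int]:
    # Stand-in for self._get_blocks(req_id); pretend this is worth skipping.
    return [7, 8, 9]


req_id = "req-0"
# Thanks to short-circuiting of `and`, the walrus assignment only evaluates
# when DEBUG is enabled, so nothing is looked up or logged otherwise.
if logger.isEnabledFor(DEBUG) and (blocks_to_free := get_blocks(req_id)):
    for block_id in blocks_to_free:
        logger.debug("Freeing block with id: %s", block_id)
```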
@@ -930,7 +954,7 @@ def _prepare_prompt(
        req_id = request.req_id
        prompt_token_ids = request.prompt_token_ids
        sampling_params = request.sampling_params
-       is_new_batch = len(self.req_ids2blocks) == 0
+       is_new_batch = len(self.req_ids2num_reserved_blocks) == 0
        prompt_len = len(prompt_token_ids)

        # make sure that the current tkv of the decode batch is greater or
@@ -989,19 +1013,20 @@
            (self.tkv - len(prompt_token_ids)) / self.block_size)
        n_reserved_blocks = math.ceil(
            n / self.block_size) - n_fully_padded_blocks
-       self.req_ids2reserved_blocks[req_id] = n_reserved_blocks
+       self.req_ids2num_reserved_blocks[req_id] = n_reserved_blocks

-       # filling block table and slot mapping
-       blocks = []
-       slots = []
-       for pos_i in range(right_padding_tkv):
-           if pos_i % self.block_size == 0:
-               block_number = self.block_pool.popleft()
-               blocks.append(block_number)
-           block_offset = pos_i % self.block_size
-           slot = block_number * self.block_size + block_offset
-           slots.append(slot)
-       self.req_ids2blocks[req_id] = deque(blocks)
+       blocks = self.kv_cache_manager.allocate_new_blocks(
+           req_id, right_padding_tkv)
+
+       block_offsets = [block.block_id * self.block_size for block in blocks]
+       slot_mapping = torch.arange(self.block_size,
+                                   dtype=torch.int64).repeat(len(blocks))
+       slot_mapping += torch.tensor(block_offsets,
+                                    dtype=torch.int64).repeat_interleave(
+                                        self.block_size)
+       slot_mapping.unsqueeze_(0)

        # Add new request to the cached states.
        if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
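The hand-rolled slot loop is replaced by a vectorized computation: one `arange` per block plus each block's base offset. A small standalone check with made-up block ids (2 and 5) and a toy block size of 4 shows the resulting layout:

```python
import torch

block_size = 4
block_ids = [2, 5]   # pretend ids returned by allocate_new_blocks()

block_offsets = [bid * block_size for bid in block_ids]            # [8, 20]
slot_mapping = torch.arange(block_size,
                            dtype=torch.int64).repeat(len(block_ids))
slot_mapping += torch.tensor(block_offsets,
                             dtype=torch.int64).repeat_interleave(block_size)
slot_mapping.unsqueeze_(0)

# One contiguous run of slots per block:
# tensor([[ 8,  9, 10, 11, 20, 21, 22, 23]])
print(slot_mapping)
```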
@@ -1029,7 +1054,6 @@ def _prepare_prompt(
        self.prefill_batch.refresh_metadata()

        self.model.indices = torch.ones(1, dtype=torch.bool, device='cpu')
-       slot_mapping = torch.tensor([slots], dtype=torch.int64)
        prompt_token_ids_tensor = torch.tensor(prompt_token_ids,
                                               dtype=torch.long,
                                               device=torch.device("cpu"))
@@ -1049,7 +1073,7 @@ def _prepare_prompt(
        current_tkv_mask = None
        # left padding info is stored in CachedRequestState of self.requests
        left_padded_prompt_mask = None
-       # block table is stored in self.req_ids2blocks (only passed for decode)
+       # block table is only passed for decode
        block_table = None

        model_inputs = SamplingForwardInputs(
@@ -1094,24 +1118,35 @@ def _prepare_decode(
        for req_id in req_ids:
            # adding new blocks if needed
            if self.tkv % self.block_size == 0:
-               self.req_ids2blocks[req_id].append(self.block_pool.popleft())
-           n_blocks = max(n_blocks, len(self.req_ids2blocks[req_id]))
+               req_state: SamplingRequestState = self.requests[req_id]
+               # we want to allocate the block for 1 next token.
+               # So we need to compute the number of tokens that the
+               # kv_cache_manager knows about: the number of computed
+               # tokens so far plus any intra-block padding.
+               total_tokens = req_state.left_padding % self.block_size \
+                   + req_state.num_computed_tokens + 1
Review comment on lines +1126 to +1127 (Member): Could we add a comment to explain this formula (especially the …)?
+               blocks = self.kv_cache_manager.allocate_new_blocks(
+                   req_id, total_tokens)
+               assert len(blocks) == 1
+           n_blocks = max(n_blocks, len(self._get_blocks(req_id)))

        for req_id in req_ids:
            # TODO: Will this always just be one token ID if there's no spec
            # or jump decoding?

-           req_state: SamplingRequestState = self.requests[req_id]
+           req_state = self.requests[req_id]
            # filling block table with padding blocks to make it rectangular
            # Note: the padding block id 0 here is chosen arbitrarily, it can
            # be any allocated block id on the Spyre card (has to be in range
            # [0, self.n_blocks - 1]). Further, it can also be a block id that
            # holds actual KV cache for another (or the same) sequence.
-           blocks = self.req_ids2blocks[req_id].copy()
-           for i in range(n_blocks - len(self.req_ids2blocks[req_id])):
-               blocks.appendleft(0)
-           block_table.append(blocks)
+           blocks = self._get_blocks(req_id)
+           left_padding_blocks = n_blocks - len(blocks)
+           req_block_ids = [0] * left_padding_blocks + \
+               [block.block_id for block in blocks]
+           block_table.append(req_block_ids)

            # slot_mapping for all blocks of sequence
            start_slot = block_table[-1][-1] * self.block_size
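To the reviewer's point about the `total_tokens` formula: a back-of-the-envelope check, with made-up numbers and assuming (as the surrounding code suggests) that `tkv` corresponds to left padding plus computed tokens and that whole padding blocks are managed outside the pool, shows why exactly one block comes back when `tkv` crosses a block boundary:

```python
import math

block_size = 64

# Made-up request state at the moment tkv hits a block boundary:
left_padding = 100          # 1 whole padding block + 36 intra-block padding tokens
num_computed_tokens = 92    # tkv = 100 + 92 = 192, a multiple of block_size

# Tokens the kv_cache_manager knows about: intra-block padding + computed tokens
known_tokens = left_padding % block_size + num_computed_tokens      # 128, 2 full blocks
total_tokens = known_tokens + 1                                     # room for 1 next token

blocks_before = known_tokens // block_size                          # 2
blocks_after = math.ceil(total_tokens / block_size)                 # 3
assert blocks_after - blocks_before == 1   # hence `assert len(blocks) == 1`
```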
@@ -1287,7 +1322,7 @@ def get_req_id_to_index(self, is_prefill: bool) -> dict[str, int]:
        return req_id_to_index

    def get_n_free_blocks(self) -> int:
-       return self.n_blocks - sum(self.req_ids2reserved_blocks.values())
+       return self.n_blocks - sum(self.req_ids2num_reserved_blocks.values())

    def no_prompt_logprob(self, is_prefill: bool) -> bool:
        if is_prefill:
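Note that `get_n_free_blocks` still reports free capacity from the reservation counters rather than from the `BlockPool`'s free list, presumably because blocks are reserved up front for a request's worst case but only allocated block by block. A tiny illustration with made-up numbers:

```python
n_blocks = 100
req_ids2num_reserved_blocks = {"req-0": 8, "req-1": 5}   # made-up reservations

# Capacity is judged by what has been promised to running requests,
# not by how many blocks the pool has physically handed out so far.
n_free_blocks = n_blocks - sum(req_ids2num_reserved_blocks.values())
assert n_free_blocks == 87
```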
@@ -1856,10 +1891,11 @@ def _prepare_chunked_prefill(self, req_id: str):
        input_positions_np = input_positions.numpy()

        # create block table tensor
-       blocks = [0] * (left_padding // self.block_size) + list(
-           self.req_ids2blocks[req_id])
+       blocks = self._get_blocks(req_id)
        block_end = (chunk_i + 1) * self.chunk_blocks_count
-       block_table = torch.tensor(blocks[:block_end],
+       block_ids = [0] * (left_padding // self.block_size) + \
+           [block.block_id for block in blocks]
+       block_table = torch.tensor(block_ids[:block_end],
                                   dtype=torch.int64).unsqueeze(0)

        slot_mapping = []
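In the chunked-prefill path the block ids now come from the manager and are left-padded with zeros for whole padding blocks before being sliced up to the current chunk. A standalone illustration with made-up values:

```python
import torch

block_size = 64
left_padding = 128                 # two whole blocks of left padding
real_block_ids = [4, 9, 11, 12]    # pretend block ids held by the request
chunk_blocks_count = 3
chunk_i = 1                        # processing the second chunk

block_ids = [0] * (left_padding // block_size) + real_block_ids
block_end = (chunk_i + 1) * chunk_blocks_count   # 6 entries cover chunks 0..1

# tensor([[ 0,  0,  4,  9, 11, 12]])
block_table = torch.tensor(block_ids[:block_end],
                           dtype=torch.int64).unsqueeze(0)
print(block_table)
```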
@@ -1973,8 +2009,10 @@ def _prepare_decode(
            # adding new blocks if needed
            req_state = self.requests[req_id]
            if req_state.num_computed_tokens % self.block_size == 0:
-               self.req_ids2blocks[req_id].append(self.block_pool.popleft())
-           max_n_blocks = max(max_n_blocks, len(self.req_ids2blocks[req_id]))
+               blocks = self.kv_cache_manager.allocate_new_blocks(
+                   req_id, req_state.num_computed_tokens + 1)
+               assert len(blocks) == 1
+           max_n_blocks = max(max_n_blocks, len(self._get_blocks(req_id)))

            # We'll calculate tkv on the fly, it is the max num computed tokens
            # of a request since there is no tokens left padding, only for blocks
@@ -1990,13 +2028,11 @@
            # be any allocated block id on the Spyre card (has to be in range
            # [0, self.n_blocks - 1]). Further, it can also be a block id that
            # holds actual KV cache for another (or the same) sequence.
-           blocks = self.req_ids2blocks[req_id].copy()
-           left_pad_blocks_count = (max_n_blocks -
-                                    len(self.req_ids2blocks[req_id]))
-
-           for _ in range(left_pad_blocks_count):
-               blocks.appendleft(0)
-           block_table.append(blocks)
+           blocks = self._get_blocks(req_id)
+           left_pad_blocks_count = (max_n_blocks - len(blocks))
+           block_ids = [0]*left_pad_blocks_count + \
+               [block.block_id for block in blocks]
+           block_table.append(block_ids)

            # slot_mapping for all blocks of sequence
            start_slot = block_table[-1][-1] * self.block_size
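The decode block table must be rectangular, so shorter requests are left-padded with the (arbitrary but valid) padding block id 0 up to `max_n_blocks` entries. A minimal standalone illustration with made-up block ids:

```python
# Two requests with different block counts, padded to the same width.
max_n_blocks = 4
per_request_block_ids = {
    "req-0": [3, 7, 9, 14],   # already at the maximum width
    "req-1": [5, 6],          # needs two padding entries on the left
}

block_table = []
for req_id, ids in per_request_block_ids.items():
    left_pad_blocks_count = max_n_blocks - len(ids)
    block_table.append([0] * left_pad_blocks_count + ids)

# [[3, 7, 9, 14], [0, 0, 5, 6]]
print(block_table)
```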
@@ -2055,13 +2091,11 @@ def add_new_request(self, request: NewRequestData):
        req_id = request.req_id
        prompt_token_ids = request.prompt_token_ids
        sampling_params = request.sampling_params
-       is_new_batch = len(self.req_ids2blocks) == 0
+       is_new_batch = len(self.req_ids2num_reserved_blocks) == 0
        prompt_len = len(prompt_token_ids)

        self.prefill_batch.clear_requests()

-       blocks_count = math.ceil(prompt_len / self.block_size)
-
        # set the new tkv to the prompt length if starting a new decode batch
        if is_new_batch:
            self.tkv = prompt_len
@@ -2074,11 +2108,10 @@ def add_new_request(self, request: NewRequestData):
        # subtract the padding blocks from the reserved blocks
        n_reserved_blocks = math.ceil(total_tokens / self.block_size)

-       self.req_ids2reserved_blocks[req_id] = n_reserved_blocks
+       self.req_ids2num_reserved_blocks[req_id] = n_reserved_blocks

        # allocate blocks
-       blocks = [self.block_pool.popleft() for _ in range(blocks_count)]
-       self.req_ids2blocks[req_id] = deque(blocks)
+       self.kv_cache_manager.allocate_new_blocks(req_id, prompt_len)

        # Add new request to the cached states.
        if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
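In `add_new_request` above, blocks are allocated immediately only for the prompt, while the reservation is sized for the full expected sequence (`total_tokens` comes from context not shown in this hunk). Rough arithmetic with made-up numbers:

```python
import math

block_size = 64
prompt_len = 150
total_tokens = 300   # made-up: prompt plus whatever the request may still generate

# Blocks the manager hands out right away, just covering the prompt:
prompt_blocks = math.ceil(prompt_len / block_size)        # 3
# Blocks reserved against this request for the rest of its lifetime:
n_reserved_blocks = math.ceil(total_tokens / block_size)  # 5

assert prompt_blocks <= n_reserved_blocks
```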