Skip to content

Commit a043643

Browse files
committed
Fixes the common_computed_block_nums updating.
Signed-off-by: Tao He <[email protected]>
1 parent b4867ba commit a043643

File tree

6 files changed

+20
-11
lines changed

6 files changed

+20
-11
lines changed

tests/worker/test_model_runner.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
import pytest
55
import torch
66

7-
from tests.kernels.utils import (STR_FLASH_ATTN_VAL,
8-
override_backend_env_variable)
7+
from tests.kernels.utils import override_backend_env_variable
98
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
109
init_distributed_environment)
1110
from vllm.engine.arg_utils import EngineArgs
1211
from vllm.model_executor.sampling_metadata import SamplingMetadata
1312
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
1413
SequenceData, SequenceGroupMetadata)
15-
from vllm.utils import get_open_port
14+
from vllm.utils import STR_FLASH_ATTN_VAL, get_open_port
1615
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
1716

1817

vllm/core/block_manager_v1.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -680,10 +680,15 @@ def access_all_blocks_in_seq(
680680
for block in block_table:
681681
block.last_accessed = access_time
682682

683-
def compute_full_blocks_in_seq(self, seq: Sequence):
683+
def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
684684
if seq.seq_id not in self.block_tables:
685685
return
686-
max_full_block = seq.get_len() // self.block_size - 1
686+
# We ensure at least 1 token to prefill even fully matched in the
687+
# model runner, so "computing computed_blocks as it is" is safe here.
688+
max_full_block = min(
689+
seq.get_prompt_len(),
690+
seq.data.get_num_computed_tokens() +
691+
token_chunk_size) // self.block_size
687692
block_table = self.block_tables[seq.seq_id]
688693
if max_full_block == -1:
689694
return
@@ -717,10 +722,11 @@ def get_common_computed_block_ids(
717722
ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
718723
return commonprefix([ids for ids in ids_list if ids != []])
719724

720-
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
725+
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
726+
token_chunk_size: int):
721727
if self.enable_caching:
722728
for seq in seq_group.get_seqs():
723-
self.compute_full_blocks_in_seq(seq)
729+
self.compute_full_blocks_in_seq(seq, token_chunk_size)
724730

725731
def get_prefix_cache_hit_rate(self, device: Device) -> float:
726732
if device == Device.GPU:

vllm/core/block_manager_v2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float):
286286
self._last_access_blocks_tracker.update_last_access(
287287
seq.seq_id, now)
288288

289-
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
289+
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
290+
token_chunk_size: int):
290291
# The only need for mark block as computed is for prefix caching,
291292
# while currently we could determine whether one block is computed
292293
# or not by check whether it has content hash.

vllm/core/embedding_model_block_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def get_common_computed_block_ids(self,
8080
seq_group: SequenceGroup) -> List[int]:
8181
return None # type: ignore
8282

83-
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
83+
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
84+
token_chunk_size: int):
8485
pass
8586

8687
def get_prefix_cache_hit_rate(self, device: Device) -> float:

vllm/core/interfaces.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def get_common_computed_block_ids(
115115
pass
116116

117117
@abstractmethod
118-
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
118+
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
119+
token_chunk_size: int):
119120
pass
120121

121122
@abstractmethod

vllm/core/scheduler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1145,7 +1145,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
11451145
# will crash the vLLM instance / will not retry.
11461146
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
11471147
self.block_manager.mark_blocks_as_computed(
1148-
scheduled_seq_group.seq_group)
1148+
scheduled_seq_group.seq_group,
1149+
scheduled_seq_group.token_chunk_size)
11491150

11501151
scheduler_time = time.perf_counter() - scheduler_start_time
11511152
# Add this to scheduler time to all the sequences that are currently

0 commit comments

Comments
 (0)