
Commit d57951f
comments and tests
1 parent f1e9548

File tree

tests/basic_correctness/test_chunked_prefill.py
tests/core/test_block_manager.py
tests/core/test_chunked_prefill_scheduler.py
vllm/core/block_manager_v1.py
vllm/core/scheduler.py

5 files changed: +86 -5 lines

tests/basic_correctness/test_chunked_prefill.py (1 addition, 1 deletion)

@@ -155,7 +155,7 @@ def test_models_with_fp8_kv_cache(
 
 @pytest.mark.parametrize("max_tokens", [16])
 @pytest.mark.parametrize("enforce_eager", [False])
-@pytest.mark.parametrize("chunk_size", [30, 32, 64])
+@pytest.mark.parametrize("chunk_size", [30, 32])
 @pytest.mark.parametrize("use_v2_block_manager", [False, True])
 # NOTE: Increasing this in this suite will fail CI because we currently cannot
 # reset distributed env properly. Use a value > 1 just when you test.

tests/core/test_block_manager.py (40 additions, 0 deletions)

@@ -595,3 +595,43 @@ def test_sliding_window_multi_seq():
 
     # assert all blocks are free now
     assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
+
+
+def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill():
+    """When prefix cache and chunked prefill are enabled, the block manager
+    should only mark a chunk of blocks as computed instead of all blocks.
+    """
+
+    block_size = 4
+    num_cpu_blocks = 0
+    num_gpu_blocks = 16
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_gpu_blocks,
+                                        num_cpu_blocks,
+                                        watermark=0,
+                                        enable_caching=True)
+
+    # Set prompt size to have num_gpu_blocks - 1 full blocks.
+    prompt_length = block_size * num_gpu_blocks - 1
+
+    # Allocate (reserve) all blocks.
+    _, seq_group = create_dummy_prompt("0",
+                                       prompt_length,
+                                       block_size=block_size)
+    block_manager.allocate(seq_group)
+    assert seq_group.seqs[0].n_blocks == num_gpu_blocks
+
+    # 1st chunk: Compute 2 and a half blocks. Should mark 2 blocks as computed.
+    token_chunk_size = int(block_size * 2.5)
+    block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
+    computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
+    assert len(computed_blocks) == 2
+
+    # Actual computed tokens.
+    seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size)
+
+    # 2nd chunk: Complete the 3rd block and 4 additional blocks.
+    token_chunk_size = int(block_size * 4.5)
+    block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
+    computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
+    assert len(computed_blocks) == 7
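The chunk arithmetic in this test can be checked by hand: only full blocks are ever marked as computed, so the count is just the integer division of cumulative tokens by the block size. A minimal standalone sketch (plain Python, no vLLM imports; full_blocks is an illustrative helper, not a vLLM API):

block_size = 4

# Only whole blocks count as computed; a trailing partial block is skipped.
def full_blocks(num_computed_tokens: int, token_chunk_size: int) -> int:
    return (num_computed_tokens + token_chunk_size) // block_size

# 1st chunk: 2.5 blocks = 10 tokens -> 2 full blocks.
assert full_blocks(0, int(block_size * 2.5)) == 2
# 2nd chunk: 10 + 18 = 28 cumulative tokens -> 7 full blocks.
assert full_blocks(10, int(block_size * 4.5)) == 7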

tests/core/test_chunked_prefill_scheduler.py (39 additions, 0 deletions)

@@ -562,3 +562,42 @@ def test_chunked_prefill_max_seqs():
     assert len(get_sequence_groups(out)) == max_seqs
     assert not running[0].is_prefill()
     assert not running[1].is_prefill()
+
+
+def test_prefix_caching():
+    """Verify allocating full blocks when prefix caching is enabled."""
+    block_size = 4
+    max_seqs = 10
+    max_model_len = 80
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True)
+    cache_config = CacheConfig(block_size,
+                               1.0,
+                               1,
+                               "auto",
+                               enable_prefix_caching=True)
+    cache_config.num_cpu_blocks = 0
+    cache_config.num_gpu_blocks = 32
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           block_size=block_size,
+                                           prompt_length=50)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert seq_group_meta[0].token_chunk_size == 50
+    # Verify it is chunked. Note that although the budget is 64-50=14,
+    # we only allocate full blocks for prefix caching, so only 4*(14//4)=12
+    # tokens are allocated.
+    assert seq_group_meta[1].token_chunk_size == 12
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 62
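The numbers in that comment can be reproduced in isolation. A short sketch of the rounding (standalone Python; variable names are illustrative, not the scheduler's internals):

block_size = 4
max_num_batched_tokens = 64

first_chunk = 50                                       # first prompt fits whole
remaining = max_num_batched_tokens - first_chunk       # 14 tokens of budget left
# With prefix caching, only whole blocks are scheduled:
second_chunk = (remaining // block_size) * block_size  # 4 * (14 // 4) = 12
assert first_chunk + second_chunk == 62                # == out.num_batched_tokens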

vllm/core/block_manager_v1.py (1 addition, 1 deletion)

@@ -693,7 +693,7 @@ def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
         block_table = self.block_tables[seq.seq_id]
         if computed_full_blocks == 0:
             return
-        for i in reversed(range(computed_full_blocks - 1)):
+        for i in reversed(range(computed_full_blocks)):
             if block_table[i].computed:
                 break
             block_table[i].computed = True
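This one-liner fixes an off-by-one: range(computed_full_blocks - 1) stopped short of the last computed full block, so that block was never marked. A standalone illustration with assumed values (not vLLM code):

computed_full_blocks = 3
old_indices = list(reversed(range(computed_full_blocks - 1)))  # [1, 0]: block 2 skipped
new_indices = list(reversed(range(computed_full_blocks)))      # [2, 1, 0]: all marked
assert 2 not in old_indices and 2 in new_indices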

vllm/core/scheduler.py (5 additions, 3 deletions)

@@ -1358,12 +1358,14 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
             # the number of new tokens that is divisible by the block size
             # to avoid partial block matching.
             block_size = self.cache_config.block_size
-            if budget.token_budget % block_size != 0:
+            remainder = budget.token_budget % block_size
+            if remainder != 0:
                 raise ValueError("When enabling chunked prefill and "
                                  "prefix caching, max_num_batched_tokens "
                                  "(chunk size) must be divisible by "
-                                 "block size, but got "
-                                 f"{budget.token_budget % block_size = }")
+                                 "block size, but got chunk_size "
+                                 f"({budget.token_budget}) % block_size "
+                                 f"({block_size}) = {remainder}")
             if remaining_token_budget < num_new_tokens:
                 num_new_tokens = (remaining_token_budget //
                                   block_size) * block_size
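Taken together, the check and the clamp behave as in this sketch (a self-contained approximation with illustrative names, not the real _get_num_new_tokens signature):

def clamp_to_full_blocks(num_new_tokens: int, remaining_token_budget: int,
                         block_size: int, token_budget: int) -> int:
    # Reject chunk sizes that are not whole blocks, as the hunk above does.
    remainder = token_budget % block_size
    if remainder != 0:
        raise ValueError(f"chunk_size ({token_budget}) % block_size "
                         f"({block_size}) = {remainder}")
    # Round the chunk down to whole blocks when the budget is the limit.
    if remaining_token_budget < num_new_tokens:
        num_new_tokens = (remaining_token_budget // block_size) * block_size
    return num_new_tokens

# 14 tokens of budget left with block_size 4 -> schedule only 12 tokens,
# exactly the second chunk asserted in test_prefix_caching above.
assert clamp_to_full_blocks(50, 14, 4, 64) == 12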
