@@ -6,6 +6,7 @@
 
 Run `pytest tests/models/test_chunked_prefill.py`.
 """
+from contextlib import nullcontext
 
 import pytest
 
@@ -156,3 +157,68 @@ def test_models_with_fp8_kv_cache(
         name_0="no_chunked_prefill",
         name_1="chunked_prefill",
     )
+
+
+@pytest.mark.parametrize("max_tokens", [16])
+@pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("chunk_size", [30, 32])
+@pytest.mark.parametrize("use_v2_block_manager", [False, True])
+# NOTE: Increasing this in this suite will fail CI because we currently cannot
+# reset the distributed env properly. Use a value > 1 only for local testing.
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+def test_with_prefix_caching(
+    vllm_runner,
+    max_tokens: int,
+    enforce_eager: bool,
+    chunk_size: int,
+    use_v2_block_manager: bool,
+    tensor_parallel_size: int,
+) -> None:
+    """
+    Checks exact-match decode with and without prefix caching,
+    with chunked prefill enabled.
+    """
+    model = "meta-llama/Llama-2-7b-chat-hf"
+    # The common prompt has 142 tokens with the Llama-2 tokenizer.
+    common_prompt = "You are a helpful AI assistant " * 20
+    unique_prompts = [
+        "Question",  # Warmup
+        "Question",  # Fully cached
+        "Another question",  # Partially cached
+    ]
+    full_prompts = [f"{common_prompt}\n{p}" for p in unique_prompts]
+
+    max_num_batched_tokens = max_num_seqs = chunk_size
+    outputs = {}  # type: ignore
+    check_result = True
+    for enable in (True, False):
+        with vllm_runner(
+            model,
+            dtype="half",
+            max_num_batched_tokens=max_num_batched_tokens,
+            enable_chunked_prefill=True,
+            enable_prefix_caching=enable,
+            tensor_parallel_size=tensor_parallel_size,
+            use_v2_block_manager=use_v2_block_manager,
+            enforce_eager=enforce_eager,
+            max_num_seqs=max_num_seqs,
+        ) as vllm_model:
+            # It should fail when prefix caching is enabled and the chunk
+            # size is not a multiple of the block size (16).
+            should_fail = chunk_size % 16 != 0 and enable
+            check_result &= not should_fail
+            outputs[enable] = []
+            # Send the requests one by one to ensure the cache is populated.
+            with pytest.raises(ValueError) if should_fail else nullcontext():
+                for prompt in full_prompts:
+                    outputs[enable] += vllm_model.generate_greedy(
+                        [prompt], max_tokens)
+
+    # Check the results only if we did not expect a failure.
+    if check_result:
+        check_outputs_equal(
+            outputs_0_lst=outputs[False],
+            outputs_1_lst=outputs[True],
+            name_0="w/o prefix caching",
+            name_1="with prefix caching",
+        )
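
The new `nullcontext` import exists to support the conditional expectation in the added test: the same generation loop runs under `pytest.raises(ValueError)` when a failure is expected (prefix caching enabled with a chunk size that is not a multiple of the block size) and under a no-op context manager otherwise. Below is a minimal, self-contained sketch of that idiom; `BLOCK_SIZE` and `validate_chunk_size` are hypothetical stand-ins for the engine-side check, not part of vLLM.

```python
from contextlib import nullcontext

import pytest

BLOCK_SIZE = 16  # assumed block size, mirroring the comment in the test above


def validate_chunk_size(chunk_size: int, enable_prefix_caching: bool) -> None:
    """Hypothetical stand-in for the engine-side check the test exercises."""
    if enable_prefix_caching and chunk_size % BLOCK_SIZE != 0:
        raise ValueError("chunk size must be a multiple of the block size "
                         "when prefix caching is enabled")


@pytest.mark.parametrize("chunk_size", [30, 32])
@pytest.mark.parametrize("enable_prefix_caching", [False, True])
def test_conditional_raise(chunk_size: int,
                           enable_prefix_caching: bool) -> None:
    should_fail = enable_prefix_caching and chunk_size % BLOCK_SIZE != 0
    # Expect a ValueError only for the failing combination; otherwise
    # nullcontext() turns the `with` statement into a no-op wrapper.
    with pytest.raises(ValueError) if should_fail else nullcontext():
        validate_chunk_size(chunk_size, enable_prefix_caching)
```

Using `nullcontext()` this way keeps a single code path for both the passing and failing parameter combinations, which is why the test above tracks `check_result` separately and only compares outputs when no failure was expected.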