diff --git a/tests/conftest.py b/tests/conftest.py index 77efaa40d..055275f2d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -104,11 +104,13 @@ def pytest_generate_tests(metafunc): # markers if ("mode" in metafunc.fixturenames and "cb" not in existing_markers and "chunked_prefill" not in existing_markers + and "cp" not in existing_markers and "pc" not in existing_markers and "mode" not in existing_markers): metafunc.parametrize("mode", [ "sb", pytest.param("cb", marks=pytest.mark.cb, id="cb"), - pytest.param("cp", marks=pytest.mark.chunked_prefill, id="cp") + pytest.param("cp", marks=pytest.mark.chunked_prefill, id="cp"), + pytest.param("pc", marks=pytest.mark.prefix_caching, id="pc") ]) @@ -252,7 +254,7 @@ def remote_openai_server(request): skip_unsupported_tp_size(int(tp_size), backend) server_args.extend(["--tensor-parallel-size", str(tp_size)]) - if "mode" in params and params["mode"] in ["cb", "cp"]: + if "mode" in params and params["mode"] in ["cb", "cp", "pc"]: max_model_len = params["max_model_len"] max_num_seqs = params["max_num_seqs"] env_dict = { @@ -265,12 +267,16 @@ def remote_openai_server(request): str(max_model_len) ]) # Chunked prefill extra - if params["mode"] == "cp": + if params["mode"] in ["cp", "pc"]: env_dict.update({"VLLM_SPYRE_USE_CHUNKED_PREFILL": "1"}) server_args.extend([ "--max_num_batched_tokens", str(128), ]) + if params["mode"] == "pc": + server_args.extend([ + "--enable-prefix-caching", + ]) else: warmup_shapes = params['warmup_shapes'] diff --git a/tests/e2e/test_logits_processors.py b/tests/e2e/test_logits_processors.py index 7e27896fb..e64de3933 100644 --- a/tests/e2e/test_logits_processors.py +++ b/tests/e2e/test_logits_processors.py @@ -3,6 +3,7 @@ import pytest import torch from llm_cache import patch_environment +from llm_cache_util import force_engine_shutdown from spyre_util import ModelInfo from vllm import LLM, SamplingParams from vllm.config import VllmConfig @@ -57,6 +58,7 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: params = SamplingParams(max_tokens=5, temperature=0, logprobs=0) spyre_model.generate(prompt, params) + force_engine_shutdown(spyre_model) assert has_invoked_logits_processor @@ -154,6 +156,7 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: spy_outputs = {} params = [params0, params1, params2] outputs = spyre_model.generate(prompt, params) + force_engine_shutdown(spyre_model) assert spy_outputs[5] == outputs[0].outputs[0].token_ids assert spy_outputs[10] == outputs[1].outputs[0].token_ids diff --git a/tests/e2e/test_sampling_params.py b/tests/e2e/test_sampling_params.py index b14a79667..b6041e106 100644 --- a/tests/e2e/test_sampling_params.py +++ b/tests/e2e/test_sampling_params.py @@ -7,6 +7,10 @@ pytestmark = [pytest.mark.full_model, pytest.mark.other_e2e] +sb_mark = pytest.param("sb", marks=pytest.mark.sb, id="sb") +cb_mark = pytest.param("cb", marks=pytest.mark.cb, id="cb") +cp_mark = pytest.param("cp", marks=pytest.mark.chunked_prefill, id="cp") + def test_spyre_batch1_temperature(model: ModelInfo, backend, monkeypatch, use_llm_cache, warmup_shapes): @@ -212,6 +216,7 @@ def test_spyre_batch1_top_k(model: ModelInfo, backend, monkeypatch, assert token_div1 < token_div2 +@pytest.mark.parametrize("mode", [sb_mark, cb_mark, cp_mark]) def test_spyre_batch1_logit_bias(model: ModelInfo, backend, monkeypatch, use_llm_cache, warmup_shapes, max_model_len, max_num_seqs, mode: str): @@ -253,6 +258,7 @@ def test_spyre_batch1_logit_bias(model: ModelInfo, backend, monkeypatch, assert output[0].outputs[0].text 
           != output[1].outputs[0].text


+@pytest.mark.parametrize("mode", [sb_mark, cb_mark, cp_mark])
 def test_spyre_batch1_min_tokens(model: ModelInfo, backend, monkeypatch,
                                  use_llm_cache, max_model_len, max_num_seqs,
                                  warmup_shapes, mode: str):
@@ -322,6 +328,7 @@ def test_spyre_batch1_ignore_eos(model: ModelInfo, backend, monkeypatch,
     assert output2.outputs[0].finish_reason != 'length'


+@pytest.mark.parametrize("mode", [sb_mark, cb_mark, cp_mark])
 def test_spyre_batch1_min_p(model: ModelInfo, backend, monkeypatch,
                             use_llm_cache, max_model_len, max_num_seqs,
                             warmup_shapes, mode: str):
diff --git a/tests/e2e/test_spyre_basic.py b/tests/e2e/test_spyre_basic.py
index 6b810ef5b..b1074feec 100644
--- a/tests/e2e/test_spyre_basic.py
+++ b/tests/e2e/test_spyre_basic.py
@@ -175,7 +175,7 @@ def test_max_model_len_override(model: ModelInfo, backend, warmup_shapes,
             "use_cb": True,
             "warmup_shapes": None,
             "use_chunked_prefill": mode == "cp",
-        } if mode in ["cb", "cp"] else {
+        } if mode in ["cb", "cp", "pc"] else {
             "use_cb": False,
             "warmup_shapes": warmup_shapes,
         })
diff --git a/tests/e2e/test_spyre_max_new_tokens.py b/tests/e2e/test_spyre_max_new_tokens.py
index ee660ab84..eb74e4703 100644
--- a/tests/e2e/test_spyre_max_new_tokens.py
+++ b/tests/e2e/test_spyre_max_new_tokens.py
@@ -8,8 +8,13 @@
 from spyre_util import DecodeWarmupShapes, ModelInfo, get_chicken_soup_prompts
 from vllm import SamplingParams

+sb_mark = pytest.param("sb", marks=pytest.mark.sb, id="sb")
+cb_mark = pytest.param("cb", marks=pytest.mark.cb, id="cb")
+cp_mark = pytest.param("cp", marks=pytest.mark.chunked_prefill, id="cp")
+

 @pytest.mark.parametrize("stop_last", [True, False])
+@pytest.mark.parametrize("mode", [sb_mark, cb_mark, cp_mark])
 def test_output(model: ModelInfo, stop_last: bool, max_model_len: int,
                 max_num_seqs: int, warmup_shapes: DecodeWarmupShapes,
                 backend: str, mode: str, monkeypatch: pytest.MonkeyPatch,
diff --git a/tests/e2e/test_spyre_pc_scheduler_steps.py b/tests/e2e/test_spyre_pc_scheduler_steps.py
new file mode 100644
index 000000000..f9dde696f
--- /dev/null
+++ b/tests/e2e/test_spyre_pc_scheduler_steps.py
@@ -0,0 +1,794 @@
+"""Verification of the correctness of the step-by-step execution of chunked
+prefill with prefix caching. It does so by comparing, at every engine step
+(i.e. prefill or decode iteration), a set of scheduler and KV-cache attributes.
+This allows finer-grained testing of the padding and scheduling implementation.
+
+Run `python -m pytest tests/e2e/test_spyre_pc_scheduler_steps.py`.
+"""
+
+import pytest
+from scheduling_utils import check_scheduler_inference_steps
+from spyre_util import ModelInfo
+
+
+@pytest.mark.cpu
+@pytest.mark.chunked_prefill
+@pytest.mark.full_model
+@pytest.mark.prefix_caching
+# These values are all parameterized for test sorting
+@pytest.mark.parametrize("max_num_seqs", [2])
+@pytest.mark.parametrize("max_model_len", [256])
+@pytest.mark.parametrize("max_num_batched_tokens", [128])
+@pytest.mark.parametrize("available_blocks", [None])
+def test_prefix_hit_within_batch(model: ModelInfo, backend: str,
+                                 monkeypatch: pytest.MonkeyPatch,
+                                 set_random_seed, max_num_seqs: int,
+                                 max_model_len: int,
+                                 max_num_batched_tokens: int,
+                                 available_blocks: int):
+    """ Scenario where two equal sequences are scheduled.
+    While prefilling the second sequence we have a prefix cache
+    hit and can reuse the first chunk. Note that the fetched prefix blocks
+    are still part of the existing decode batch. Hence we have duplicated
+    blocks in the block table for this example.
+ + Configuration: + * max_num_seqs: 2 + * number of prompts: 2 + * 0: len = 192, max tokens = 2, step joining = 0 + * 1: len = 192, max tokens = 2, step joining = 0 + """ + monkeypatch.setenv("VLLM_SPYRE_CP_INTERLEAVE_STEPS", "0") + + seqs_max_tokens = [2, 2] + prompts_lengths = [192, 192] + steps_add_reqs = [0, 0] + seeds = [0, 0] # twice the same sequence + + checked_steps = [ + { + "step": 0, + "tkv": 0, + "waiting": ["0", "1"], + "running": [], + "request_outputs": [], + "n_reserved_blocks": 0, + "n_used_blocks": 0 + }, + { # prefill chunk 1 seq 0 + "step": 1, + "tkv": 192, + "waiting": ["1"], + "running": ["0"], + "request_outputs": [], + "n_reserved_blocks": 4, + "n_used_blocks": 3, + "n_prefix_hits": 0, + }, + { # prefill chunk 2 seq 0 + "step": 2, + "tkv": 192, + "waiting": ["1"], + "running": ["0"], + "request_outputs": ["0"], + "n_reserved_blocks": 4, + "n_used_blocks": 3, + "n_prefix_hits": 0, + }, + { # prefill chunk 1 seq 1 + # prefix hit! + "step": 3, + "tkv": 192, + "waiting": [], + "running": ["1", "0"], + "request_outputs": [], + "n_reserved_blocks": 8, + "n_used_blocks": 6, + "n_prefix_hits": 1, + # each chunk has two blocks. Due to padding, the first chunk has + # only one usable block + "n_cached_blocks": 1 + }, + { # prefill chunk 2 seq 1 + # cannot use prefix, as the last chunk has to always be recomputed + "step": 4, + "tkv": 192, + "waiting": [], + "running": ["1", "0"], + "request_outputs": ["1"], + "n_reserved_blocks": 8, + "n_used_blocks": 6, + "n_prefix_hits": 0, + "n_cached_blocks": 1 + }, + + { + # Decode 1 of request 0. + # Decode 1 of request 1. + "step": 5, + "tkv": 193, + "waiting": [], + "running": [], + "request_outputs": ["1", "0"], + "finished_requests": ["1", "0"], + "n_reserved_blocks": 8, + "n_used_blocks": 8, + "n_cached_blocks": 1 + }, + { + # Tkv should be cleared one step later + "step": 6, + "tkv": 0, + "waiting": [], + "running": [], + "request_outputs": [], + "n_reserved_blocks": 0, + "n_used_blocks": 0 + }, + ] + + check_scheduler_inference_steps( + model=model, + backend=backend, + monkeypatch=monkeypatch, + seqs_max_tokens=seqs_max_tokens, + prompts_lengths=prompts_lengths, + steps_add_reqs=steps_add_reqs, + checked_steps=checked_steps, + max_num_seqs=max_num_seqs, + max_model_len=max_model_len, + available_blocks=available_blocks, + use_cb=False, + random_prompts=True, + max_num_batched_tokens=max_num_batched_tokens, + prefix_caching=True, + seeds=seeds, + ) + + +@pytest.mark.cpu +@pytest.mark.chunked_prefill +@pytest.mark.full_model +@pytest.mark.prefix_caching +# These values are all parameterized for test sorting +@pytest.mark.parametrize("max_num_seqs", [2]) +@pytest.mark.parametrize("max_model_len", [256]) +@pytest.mark.parametrize("max_num_batched_tokens", [128]) +@pytest.mark.parametrize("available_blocks", [None]) +def test_prefix_hit_not_in_batch(model: ModelInfo, backend: str, + monkeypatch: pytest.MonkeyPatch, + set_random_seed, max_num_seqs: int, + max_model_len: int, + max_num_batched_tokens: int, + available_blocks: int): + """ Scenario where two equal sequences are scheduled. + While prefilling the second sequence we have a prefix cache + hit and can reuse the first chunk. Note that the fetched prefix blocks + are not part of the existing decode batch as the sequence has already + left the batch at the time of prefilling the new sequence. Hence we have + no duplicated blocks in the block table for this example. 
+
+    Configuration:
+        * max_num_seqs: 2
+        * number of prompts: 2
+            * 0: len = 192, max tokens = 2, step joining = 0
+            * 1: len = 192, max tokens = 2, step joining = 3
+    """
+    monkeypatch.setenv("VLLM_SPYRE_CP_INTERLEAVE_STEPS", "0")
+
+    seqs_max_tokens = [2, 2]
+    prompts_lengths = [192, 192]
+    # sequence 1 joins only at step 3, when seq 0 has already finished
+    steps_add_reqs = [0, 3]
+    seeds = [0, 0]  # twice the same sequence
+
+    checked_steps = [
+        {
+            "step": 0,
+            "tkv": 0,
+            "waiting": ["0"],
+            "running": [],
+            "request_outputs": [],
+            "n_reserved_blocks": 0,
+            "n_used_blocks": 0
+        },
+        {  # prefill chunk 1 seq 0
+            "step": 1,
+            "tkv": 192,
+            "waiting": [],
+            "running": ["0"],
+            "request_outputs": [],
+            "n_reserved_blocks": 4,
+            "n_used_blocks": 3,
+            "n_prefix_hits": 0,
+        },
+        {  # prefill chunk 2 seq 0
+            "step": 2,
+            "tkv": 192,
+            "waiting": [],
+            "running": ["0"],
+            "request_outputs": ["0"],
+            "n_reserved_blocks": 4,
+            "n_used_blocks": 3,
+            "n_prefix_hits": 0,
+        },
+        {
+            # Decode 1 of request 0.
+            # request 1 joined the waiting queue
+            "step": 3,
+            "tkv": 193,
+            "waiting": ["1"],
+            "running": [],
+            "request_outputs": ["0"],
+            "finished_requests": ["0"],
+            "n_reserved_blocks": 4,
+            "n_used_blocks": 4
+        },
+        {  # prefill chunk 1 seq 1
+            # prefix hit!
+            "step": 4,
+            "tkv": 192,
+            "waiting": [],
+            "running": ["1"],
+            "request_outputs": [],
+            "n_reserved_blocks": 4,
+            "n_used_blocks": 3,
+            "n_prefix_hits": 1,
+            "n_cached_blocks": 1
+        },
+        {  # prefill chunk 2 seq 1
+            # cannot use prefix, as the last chunk has to always be recomputed
+            "step": 5,
+            "tkv": 192,
+            "waiting": [],
+            "running": ["1"],
+            "request_outputs": ["1"],
+            "n_reserved_blocks": 4,
+            "n_used_blocks": 3,
+            "n_prefix_hits": 0,
+            "n_cached_blocks": 1
+        },
+        {
+            # Decode 1 of request 1.
+            "step": 6,
+            "tkv": 193,
+            "waiting": [],
+            "running": [],
+            "request_outputs": ["1"],
+            "finished_requests": ["1"],
+            "n_reserved_blocks": 4,
+            "n_used_blocks": 4,
+            "n_cached_blocks": 1
+        },
+        {
+            # Tkv should be cleared one step later
+            "step": 7,
+            "tkv": 0,
+            "waiting": [],
+            "running": [],
+            "request_outputs": [],
+            "n_reserved_blocks": 0,
+            "n_used_blocks": 0
+        },
+    ]
+
+    check_scheduler_inference_steps(
+        model=model,
+        backend=backend,
+        monkeypatch=monkeypatch,
+        seqs_max_tokens=seqs_max_tokens,
+        prompts_lengths=prompts_lengths,
+        steps_add_reqs=steps_add_reqs,
+        checked_steps=checked_steps,
+        max_num_seqs=max_num_seqs,
+        max_model_len=max_model_len,
+        available_blocks=available_blocks,
+        use_cb=False,
+        random_prompts=True,
+        max_num_batched_tokens=max_num_batched_tokens,
+        prefix_caching=True,
+        seeds=seeds,
+    )
+
+
+@pytest.mark.cpu
+@pytest.mark.chunked_prefill
+@pytest.mark.full_model
+@pytest.mark.prefix_caching
+# These values are all parameterized for test sorting
+@pytest.mark.parametrize("max_num_seqs", [2])
+@pytest.mark.parametrize("max_model_len", [256])
+@pytest.mark.parametrize("max_num_batched_tokens", [128])
+@pytest.mark.parametrize("available_blocks", [4])
+def test_limit_blocks_no_prefix_hit(model: ModelInfo, backend: str,
+                                    monkeypatch: pytest.MonkeyPatch,
+                                    set_random_seed, max_num_seqs: int,
+                                    max_model_len: int,
+                                    max_num_batched_tokens: int,
+                                    available_blocks: int):
+    """ Scenario where three sequences are scheduled with the 1st and 3rd
+    sequences being identical. While prefilling the third sequence we don't
+    have a prefix cache hit for the first chunk as the KV cache has already
+    been overwritten. This is because we limit the number of available blocks
+    to 4.
Note: When increasing the number of available blocks to 8, see + test_limit_blocks_prefix_hit, the same test results in a prefix hit. + + Configuration: + * max_num_seqs: 2 + * number of prompts: 3 + * 0: len = 192, max tokens = 2, step joining = 0 + * 1: len = 192, max tokens = 2, step joining = 3 + * 2: len = 192, max tokens = 2, step joining = 6 + """ + monkeypatch.setenv("VLLM_SPYRE_CP_INTERLEAVE_STEPS", "0") + + seqs_max_tokens = [2, 2, 2] + prompts_lengths = [192, 192, 192] + # sequence 1 joins only at step 3, when seq 0 has already finished + # sequence 2 joins only at step 6, when seq 1 has already finished + steps_add_reqs = [0, 3, 6] + seeds = [0, 1, 0] # 1st and 3rd sequence are the same + + checked_steps = [ + { + "step": 0, + "tkv": 0, + "waiting": ["0"], + "running": [], + "request_outputs": [], + "n_reserved_blocks": 0, + "n_used_blocks": 0 + }, + { # prefill chunk 1 seq 0 + "step": 1, + "tkv": 192, + "waiting": [], + "running": ["0"], + "request_outputs": [], + "n_reserved_blocks": 4, + "n_used_blocks": 3, + "n_prefix_hits": 0, + }, + { # prefill chunk 2 seq 0 + "step": 2, + "tkv": 192, + "waiting": [], + "running": ["0"], + "request_outputs": ["0"], + "n_reserved_blocks": 4, + "n_used_blocks": 3, + "n_prefix_hits": 0, + }, + { + # Decode 1 of request 0 + # request 1 joined the waiting queue + "step": 3, + "tkv": 193, + "waiting": ["1"], + "running": [], + "request_outputs": ["0"], + "finished_requests": ["0"], + "n_reserved_blocks": 4, + "n_used_blocks": 4 + }, + { # prefill chunk 1 seq 1 + "step": 4, + "tkv": 192, + "waiting": [], + "running": ["1"], + "request_outputs": [], + "n_reserved_blocks": 4, + "n_used_blocks": 3, + "n_prefix_hits": 0, + }, + { # prefill chunk 2 seq 1 + "step": 5, + "tkv": 192, + "waiting": [], + "running": ["1"], + "request_outputs": ["1"], + "n_reserved_blocks": 4, + "n_used_blocks": 3, + "n_prefix_hits": 0, + }, + { + # Decode 1 of request 1 + # request 2 joined the waiting queue + "step": 6, + "tkv": 193, + "waiting": ['2'], + "running": [], + "request_outputs": ["1"], + "finished_requests": ["1"], + "n_reserved_blocks": 4, + "n_used_blocks": 4 + }, + { # prefill chunk 1 seq 2 + # no prefix hit as KV cache is already overwritten! 
+            "step": 7,
+            "tkv": 192,
+            "waiting": [],
+            "running": ["2"],
+            "request_outputs": [],
+            "n_reserved_blocks": 4,
+            "n_used_blocks": 3,
+            "n_prefix_hits": 0,
+        },
+        {  # prefill chunk 2 seq 2
+            "step": 8,
+            "tkv": 192,
+            "waiting": [],
+            "running": ["2"],
+            "request_outputs": ["2"],
+            "n_reserved_blocks": 4,
+            "n_used_blocks": 3,
+            "n_prefix_hits": 0,
+        },
+        {
+            # Decode 1 of request 2
+            "step": 9,
+            "tkv": 193,
+            "waiting": [],
+            "running": [],
+            "request_outputs": ["2"],
+            "finished_requests": ["2"],
+            "n_reserved_blocks": 4,
+            "n_used_blocks": 4
+        },
+        {
+            # Tkv should be cleared one step later
+            "step": 10,
+            "tkv": 0,
+            "waiting": [],
+            "running": [],
+            "request_outputs": [],
+            "n_reserved_blocks": 0,
+            "n_used_blocks": 0
+        },
+    ]
+
+    check_scheduler_inference_steps(
+        model=model,
+        backend=backend,
+        monkeypatch=monkeypatch,
+        seqs_max_tokens=seqs_max_tokens,
+        prompts_lengths=prompts_lengths,
+        steps_add_reqs=steps_add_reqs,
+        checked_steps=checked_steps,
+        max_num_seqs=max_num_seqs,
+        max_model_len=max_model_len,
+        available_blocks=available_blocks,
+        use_cb=False,
+        random_prompts=True,
+        max_num_batched_tokens=max_num_batched_tokens,
+        prefix_caching=True,
+        seeds=seeds)
+
+
+@pytest.mark.cpu
+@pytest.mark.chunked_prefill
+@pytest.mark.full_model
+@pytest.mark.prefix_caching
+# These values are all parameterized for test sorting
+@pytest.mark.parametrize("max_num_seqs", [2])
+@pytest.mark.parametrize("max_model_len", [256])
+@pytest.mark.parametrize("max_num_batched_tokens", [128])
+@pytest.mark.parametrize("available_blocks", [8])
+def test_limit_blocks_prefix_hit(model: ModelInfo, backend: str,
+                                 monkeypatch: pytest.MonkeyPatch,
+                                 set_random_seed, max_num_seqs: int,
+                                 max_model_len: int,
+                                 max_num_batched_tokens: int,
+                                 available_blocks: int):
+    """ Scenario where three sequences are scheduled with the 1st and 3rd
+    sequences being identical. While prefilling the third sequence we
+    have a prefix cache hit for the first chunk as the KV cache is still
+    persistent. This is because the number of available blocks (8) is high
+    enough. Note: When decreasing the number of available blocks to 4, see
+    test_limit_blocks_no_prefix_hit, the same test results in no prefix hit.
+ + Configuration: + * max_num_seqs: 2 + * number of prompts: 3 + * 0: len = 192, max tokens = 2, step joining = 0 + * 1: len = 192, max tokens = 2, step joining = 3 + * 2: len = 192, max tokens = 2, step joining = 6 + """ + monkeypatch.setenv("VLLM_SPYRE_CP_INTERLEAVE_STEPS", "0") + + seqs_max_tokens = [2, 2, 2] + prompts_lengths = [192, 192, 192] + # sequence 1 joins only at step 3, when seq 0 has already finished + # sequence 2 joins only at step 6, when seq 1 has already finished + steps_add_reqs = [0, 3, 6] + seeds = [0, 1, 0] # 1st and 3rd sequence are the same + + checked_steps = [ + { + "step": 0, + "tkv": 0, + "waiting": ["0"], + "running": [], + "request_outputs": [], + "n_reserved_blocks": 0, + "n_used_blocks": 0 + }, + { # prefill chunk 1 seq 0 + "step": 1, + "tkv": 192, + "waiting": [], + "running": ["0"], + "request_outputs": [], + "n_reserved_blocks": 4, + "n_used_blocks": 3, + "n_prefix_hits": 0, + }, + { # prefill chunk 2 seq 0 + "step": 2, + "tkv": 192, + "waiting": [], + "running": ["0"], + "request_outputs": ["0"], + "n_reserved_blocks": 4, + "n_used_blocks": 3, + "n_prefix_hits": 0, + }, + { + # Decode 1 of request 0 + # request 1 joined the waiting queue + "step": 3, + "tkv": 193, + "waiting": ["1"], + "running": [], + "request_outputs": ["0"], + "finished_requests": ["0"], + "n_reserved_blocks": 4, + "n_used_blocks": 4 + }, + { # prefill chunk 1 seq 1 + "step": 4, + "tkv": 192, + "waiting": [], + "running": ["1"], + "request_outputs": [], + "n_reserved_blocks": 4, + "n_used_blocks": 3, + "n_prefix_hits": 0, + }, + { # prefill chunk 2 seq 1 + "step": 5, + "tkv": 192, + "waiting": [], + "running": ["1"], + "request_outputs": ["1"], + "n_reserved_blocks": 4, + "n_used_blocks": 3, + "n_prefix_hits": 0, + }, + { + # Decode 1 of request 1 + # request 2 joined the waiting queue + "step": 6, + "tkv": 193, + "waiting": ['2'], + "running": [], + "request_outputs": ["1"], + "finished_requests": ["1"], + "n_reserved_blocks": 4, + "n_used_blocks": 4 + }, + { # prefill chunk 1 seq 2 + # prefix hit as KV cache is still persistent + "step": 7, + "tkv": 192, + "waiting": [], + "running": ["2"], + "request_outputs": [], + "n_reserved_blocks": 4, + "n_used_blocks": 3, + "n_prefix_hits": 1, + "n_cached_blocks": 1 + }, + { # prefill chunk 2 seq 2 + "step": 8, + "tkv": 192, + "waiting": [], + "running": ["2"], + "request_outputs": ["2"], + "n_reserved_blocks": 4, + "n_used_blocks": 3, + "n_prefix_hits": 0, + "n_cached_blocks": 1 + }, + { + # Decode 1 of request 2 + "step": 9, + "tkv": 193, + "waiting": [], + "running": [], + "request_outputs": ["2"], + "finished_requests": ["2"], + "n_reserved_blocks": 4, + "n_used_blocks": 4, + "n_cached_blocks": 1 + }, + { + # Tkv should be cleared one step later + "step": 10, + "tkv": 0, + "waiting": [], + "running": [], + "request_outputs": [], + "n_reserved_blocks": 0, + "n_used_blocks": 0 + }, + ] + + check_scheduler_inference_steps( + model=model, + backend=backend, + monkeypatch=monkeypatch, + seqs_max_tokens=seqs_max_tokens, + prompts_lengths=prompts_lengths, + steps_add_reqs=steps_add_reqs, + checked_steps=checked_steps, + max_num_seqs=max_num_seqs, + max_model_len=max_model_len, + available_blocks=available_blocks, + use_cb=False, + random_prompts=True, + max_num_batched_tokens=max_num_batched_tokens, + prefix_caching=True, + seeds=seeds) + + +@pytest.mark.cpu +@pytest.mark.chunked_prefill +@pytest.mark.full_model +@pytest.mark.prefix_caching +# These values are all parameterized for test sorting +@pytest.mark.parametrize("max_num_seqs", [2]) 
+@pytest.mark.parametrize("max_model_len", [512]) +@pytest.mark.parametrize("max_num_batched_tokens", [128]) +@pytest.mark.parametrize("available_blocks", [None]) +def test_full_match(model: ModelInfo, backend: str, + monkeypatch: pytest.MonkeyPatch, set_random_seed, + max_num_seqs: int, max_model_len: int, + max_num_batched_tokens: int, available_blocks: int): + """ Scenario where two equal sequences are scheduled. + Both sequences have exactly 3 chunks worth of tokens, thus + resulting in a 100% match up to the last token. This test + makes sure that the last chunk is not reused. + + Configuration: + * max_num_seqs: 2 + * number of prompts: 2 + * 0: len = 384, max tokens = 2, step joining = 0 + * 1: len = 384, max tokens = 2, step joining = 0 + """ + monkeypatch.setenv("VLLM_SPYRE_CP_INTERLEAVE_STEPS", "0") + + seqs_max_tokens = [2, 2] + prompts_lengths = [384, 384] + steps_add_reqs = [0, 0] + seeds = [0, 0] # twice the same sequence + + checked_steps = [ + { + "step": 0, + "tkv": 0, + "waiting": ["0", "1"], + "running": [], + "request_outputs": [], + "n_reserved_blocks": 0, + "n_used_blocks": 0 + }, + { # prefill chunk 1 seq 0 + "step": 1, + "tkv": 384, + "waiting": ["1"], + "running": ["0"], + "request_outputs": [], + "n_reserved_blocks": 7, + "n_used_blocks": 6, + "n_prefix_hits": 0, + }, + { # prefill chunk 2 seq 0 + "step": 2, + "tkv": 384, + "waiting": ["1"], + "running": ["0"], + "request_outputs": [], + "n_reserved_blocks": 7, + "n_used_blocks": 6, + "n_prefix_hits": 0, + }, + { # prefill chunk 3 seq 0 + "step": 3, + "tkv": 384, + "waiting": ["1"], + "running": ["0"], + "request_outputs": ["0"], + "n_reserved_blocks": 7, + "n_used_blocks": 6, + "n_prefix_hits": 0, + }, + { # prefill chunk 1 seq 1 + # prefix hit! + "step": 4, + "tkv": 384, + "waiting": [], + "running": ["1", "0"], + "request_outputs": [], + "n_reserved_blocks": 14, + "n_used_blocks": 12, + "n_prefix_hits": 1, + # The number of cached blocks is determined up front + "n_cached_blocks": 4 + }, + { # prefill chunk 2 seq 1 + # cannot use prefix, as the last chunk has to always be recomputed + "step": 5, + "tkv": 384, + "waiting": [], + "running": ["1", "0"], + "request_outputs": [], + "n_reserved_blocks": 14, + "n_used_blocks": 12, + "n_prefix_hits": 1, + "n_cached_blocks": 4 + }, + { # prefill chunk 3 seq 1 + # cannot use prefix, as the last chunk has to always be recomputed + "step": 6, + "tkv": 384, + "waiting": [], + "running": ["1", "0"], + "request_outputs": ["1"], + "n_reserved_blocks": 14, + "n_used_blocks": 12, + "n_prefix_hits": 0, + "n_cached_blocks": 4 + }, + { + # Decode 1 of request 0. + # Decode 1 of request 1. 
+ "step": 7, + "tkv": 385, + "waiting": [], + "running": [], + "request_outputs": ["1", "0"], + "finished_requests": ["1", "0"], + "n_reserved_blocks": 14, + "n_used_blocks": 14, + "n_cached_blocks": 4 + }, + { + # Tkv should be cleared one step later + "step": 8, + "tkv": 0, + "waiting": [], + "running": [], + "request_outputs": [], + "n_reserved_blocks": 0, + "n_used_blocks": 0 + }, + ] + + check_scheduler_inference_steps( + model=model, + backend=backend, + monkeypatch=monkeypatch, + seqs_max_tokens=seqs_max_tokens, + prompts_lengths=prompts_lengths, + steps_add_reqs=steps_add_reqs, + checked_steps=checked_steps, + max_num_seqs=max_num_seqs, + max_model_len=max_model_len, + available_blocks=available_blocks, + use_cb=False, + random_prompts=True, + max_num_batched_tokens=max_num_batched_tokens, + prefix_caching=True, + seeds=seeds, + ) diff --git a/tests/e2e/test_spyre_seed.py b/tests/e2e/test_spyre_seed.py index cff109099..2f78d69b7 100644 --- a/tests/e2e/test_spyre_seed.py +++ b/tests/e2e/test_spyre_seed.py @@ -10,10 +10,14 @@ from spyre_util import DecodeWarmupShapes, ModelInfo, get_chicken_soup_prompts from vllm import SamplingParams +sb_mark = pytest.param("sb", marks=pytest.mark.sb, id="sb") +cb_mark = pytest.param("cb", marks=pytest.mark.cb, id="cb") + @pytest.mark.xfail(reason="Failing currently because of output mismatch") @pytest.mark.parametrize("temperature", [0.1, 1.0]) @pytest.mark.parametrize("seed", [42]) +@pytest.mark.parametrize("mode", [sb_mark, cb_mark]) def test_seed(model: ModelInfo, temperature: float, seed: int, max_model_len: int, max_num_seqs: int, warmup_shapes: DecodeWarmupShapes, backend: str, mode: str, diff --git a/tests/e2e/test_spyre_stagger_basic.py b/tests/e2e/test_spyre_stagger_basic.py index 51911d66d..1ebcc4079 100644 --- a/tests/e2e/test_spyre_stagger_basic.py +++ b/tests/e2e/test_spyre_stagger_basic.py @@ -10,7 +10,11 @@ skip_unsupported_tp_size) from vllm import SamplingParams +sb_mark = pytest.param("sb", marks=pytest.mark.sb, id="sb") +cb_mark = pytest.param("cb", marks=pytest.mark.cb, id="cb") + +@pytest.mark.parametrize("mode", [sb_mark, cb_mark]) def test_stagger_output(model: ModelInfo, tp_size: int, backend: str, mode: str, max_num_seqs: int, max_model_len: int, warmup_shapes, monkeypatch: pytest.MonkeyPatch, diff --git a/tests/hf_cache.json b/tests/hf_cache.json index cf94dad10..b89669070 100644 --- a/tests/hf_cache.json +++ b/tests/hf_cache.json @@ -415,6 +415,30 @@ ], "logprobs": [ -4.388803482055664, -2.1905605792999268, -1.1574044227600098, -3.3592681884765625 ] } + }, + 
"__tokens__41505_37255_20672_12727_25130_19903_38525_14909_23426_28674_44635_24806_13853_37149_30394_12313_44715_48305_39823_44343_15245_35872_44179_33619_23207_4950_21340_30026_44876_47510_23446_42531_12804_39568_26970_691_35375_19603_40542_32841_57_24260_42644_11989_15985_42785_9392_27894_11729_47556_39478_22019_3954_15732_24966_45850_5361_27096_34729_26908_40032_26556_47374_29648_28882_21872_29309_18919_28294_14270_9309_9178_30119_32276_23422_4415_37237_43095_45386_41408_44147_45371_26572_19233_34666_13548_39893_41754_43993_28990_46682_28493_22146_32452_48968_45069_38993_4049_30119_23910_30973_41537_11946_35954_5758_10836_39055_16345_40103_4945_7194_34292_2224_28207_44729_26257_33452_1313_31211_29803_28309_19229_18193_48194_1789_1064_47236_9092_6090_10351_39358_46053_1120_20920_4989_12776_10854_31798_17218_8863_24755_1936_4961_48573_9799_17624_35959_41205_45145_8328_33061_47507_2854_33236_41554_16826_12322_29333_21741_8593_23181_20148_27973_24999_15308_17555_41172_12334_27555_612_36450_16511_2247_13806_11803_46848_17313_14150_17656_46542_31150_30527_35174_19072_20370_31990_75_9453_16437_11768_31329_18611": { + "2": { + "text": "uary ", + "token_ids": [ 20360, 225 ], + "tokens": [ "uary", " " ], + "logprobs": [ -2.8783950805664062, -2.698195695877075 ] + } + }, + "__tokens__6605_41653_37541_12537_24352_22093_32027_38767_4614_1394_41079_21271_37467_104_21892_35465_11244_46461_44307_1504_1251_26611_46161_18737_10647_20748_1428_10897_21523_24370_11457_11348_10754_22590_14244_1057_41168_27351_31570_9138_48785_42268_5942_16353_35462_34956_46027_20747_40798_32947_14911_28881_43375_41592_24836_28951_1697_11931_39194_20364_8504_26975_34556_33152_18418_21576_24990_38262_25605_19329_24069_1454_2138_34572_48325_29156_19346_8373_24686_48271_37872_26523_42285_11412_25253_46815_28400_22567_13236_26935_47044_281_38518_40328_43557_36397_39771_25494_27592_20943_2759_42762_28017_9823_24808_23835_17537_17011_26467_30646_30103_22519_1375_11286_8711_28727_42320_39245_39179_40129_12548_41373_33085_4092_821_716_37138_12267_5382_30710_16929_3417_7846_25922_8265_13415_34976_22350_15827_23287_1162_19000_20689_9243_5346_44227_25073_10278_29769_40159_1024_879_7199_35332_7876_34633_33334_26773_10843_47952_39214_25392_10971_31875_19410_28304_15790_31012_2890_14677_47574_43034_15060_42197_15255_46167_36561_20456_12404_417_43190_1864_40276_47294_28030_8431_42653_47863_34604_25012_18578_17052": { + "2": { + "text": "light s", + "token_ids": [ 2429, 309 ], + "tokens": [ "light", " s" ], + "logprobs": [ -1.708275318145752, -3.340717315673828 ] + } + }, + 
"__tokens__41505_37255_20672_12727_25130_19903_38525_14909_23426_28674_44635_24806_13853_37149_30394_12313_44715_48305_39823_44343_15245_35872_44179_33619_23207_4950_21340_30026_44876_47510_23446_42531_12804_39568_26970_691_35375_19603_40542_32841_57_24260_42644_11989_15985_42785_9392_27894_11729_47556_39478_22019_3954_15732_24966_45850_5361_27096_34729_26908_40032_26556_47374_29648_28882_21872_29309_18919_28294_14270_9309_9178_30119_32276_23422_4415_37237_43095_45386_41408_44147_45371_26572_19233_34666_13548_39893_41754_43993_28990_46682_28493_22146_32452_48968_45069_38993_4049_30119_23910_30973_41537_11946_35954_5758_10836_39055_16345_40103_4945_7194_34292_2224_28207_44729_26257_33452_1313_31211_29803_28309_19229_18193_48194_1789_1064_47236_9092_6090_10351_39358_46053_1120_20920_4989_12776_10854_31798_17218_8863_24755_1936_4961_48573_9799_17624_35959_41205_45145_8328_33061_47507_2854_33236_41554_16826_12322_29333_21741_8593_23181_20148_27973_24999_15308_17555_41172_12334_27555_612_36450_16511_2247_13806_11803_46848_17313_14150_17656_46542_31150_30527_35174_19072_20370_31990_75_9453_16437_11768_31329_18611_43028_27926_20369_19772_34496_20557_32548_2300_21890_12742_7751_25931_23950_27594_37133_43444_24310_15338_22949_39766_43008_39932_9241_49123_31117_4103_35662_48504_19750_33350_15541_10495_35258_116_40439_25969_4807_5845_31913_42941_13762_48095_4924_41972_19499_3999_13503_22265_38945_42337_6558_25602_31987_17059_42853_13685_913_1999_33472_27444_46522_46126_44721_2065_36821_34471_32212_35014_44370_31464_18307_26440_10216_28858_438_7423_16388_38811_35315_16626_30501_2026_8054_48263_14231_19405_26959_14422_23498_11782_2372_8827_25709_3483_19817_16148_20384_4886_44662_23298_41329_47983_16891_23548_34386_20965_14839_36114_43961_45204_30805_18460_47901_31402_3236_4162_36857_3006_386_19357_25510_22047_24017_28748_33389_20793_18104_48584_12825_38196_21195_17622_3139_42446_34505_44384_22198_33272_5845_19560_10186_2070_46594_10612_7194_9731_18581_26856_7439_48596_48315_7295_19951_33420_43138_24350_45074_15850_24499_24509_32935_9929_29971_10753_16723_47312_44188_40212_1744_7293_12627_38543_41402_28653_35297_39668_3262_4161_42708_1938_11064_1998_752_41482_16250_7899_7315_32248_47608_24822_44290_24695": { + "2": { + "text": "Listings", + "token_ids": [ 720, 2052 ], + "tokens": [ "List", "ings" ], + "logprobs": [ -3.185478448867798, -3.137629747390747 ] + } } } }, @@ -505,6 +529,30 @@ "tokens": [ "\"", "Hello", ",", "welcome", "to", "the" ], "logprobs": [ -1.5613254308700562, -1.955716848373413, -0.5875644683837891, -2.0813889503479004, -0.3190162479877472, -1.1618010997772217 ] } + }, + 
"__tokens__27021_24255_13460_8287_16362_12959_25082_9708_15252_18669_29059_16151_9020_24186_19788_8018_29112_31449_25927_28869_9926_23355_28763_21888_15110_3225_13895_19549_29216_30931_15265_27690_8337_25761_17559_452_23031_12764_26395_21381_39_15796_27763_7807_10408_27855_6116_18161_7637_30961_25702_14336_2577_10243_16255_29850_3492_17641_22610_17519_26063_17290_30842_19303_18804_14241_19082_12318_18422_9292_6062_5977_19609_21014_15250_2877_24244_28057_29548_26959_28741_29538_17300_12523_22569_8822_25972_27184_28641_18874_30392_18551_14419_21128_31880_29342_25387_2638_19610_15567_20165_27042_7779_23408_3750_7057_25427_10643_26109_3222_4686_22326_1450_18364_29120_17095_21779_857_20321_19404_18431_12520_11846_31376_1167_695_30753_5921_3967_6740_25624_29983_731_13621_3250_8319_7068_20702_11211_5772_16117_1263_3232_31623_6381_11475_23411_26826_29391_5424_21525_30929_1860_21639_27054_10955_8024_19098_14155_5596_15093_13118_18212_16276_9968_11430_26805_8032_17940_400_23731_10751_1465_8990_7686_30500_11273_9214_11496_30301_20281_19875_22900_12418_13263_20827_51_6156_10702_7663_20397_12118": { + "2": { + "text": "b.", + "token_ids": [ 29890, 29889 ], + "tokens": [ "b", "." ], + "logprobs": [ -4.186609268188477, -4.957806587219238 ] + } + }, + "__tokens__4302_27118_24441_8164_15855_14385_20852_25239_3006_910_26744_13850_24393_70_14254_23090_7322_30248_28845_981_817_17326_30052_12200_6933_13509_932_7096_14014_15867_7461_7390_7003_14708_9275_690_26802_17807_20554_5951_31761_27518_3871_10648_23088_22759_29966_13509_26561_21450_9709_18803_28239_27078_16170_18849_1107_7769_25517_13259_5538_17562_22498_21584_11992_14048_16271_24910_16671_12585_15671_949_1394_22509_31462_18983_12597_5453_16073_31426_24657_17269_27529_7431_16442_30479_18490_14693_8619_17537_30627_185_25077_26256_28358_23696_25893_16599_17964_13636_1798_27840_18241_6397_16152_15519_11419_11076_17232_19952_19599_14662_898_7349_5673_18703_27552_25550_25507_26126_8171_26936_21540_2666_537_468_24179_7988_3506_19994_11023_2227_5110_16877_5383_8735_22771_14552_10306_15162_759_12371_13471_6019_3483_28794_16325_6693_19381_26145_669_574_4689_23003_5129_22548_21702_17431_7061_31219_25530_16532_7144_20753_12638_18428_10281_20191_1883_9557_30973_28017_9806_27472_9933_30057_23803_13319_8077_274_28119_1216_26221_30790_18250_5491_27769_31160_22529_16285_12096_11103": { + "2": { + "text": "of the", + "token_ids": [ 310, 278 ], + "tokens": [ "of", "the" ], + "logprobs": [ -4.2294392585754395, -2.3842616081237793 ] + } + }, + 
"__tokens__27021_24255_13460_8287_16362_12959_25082_9708_15252_18669_29059_16151_9020_24186_19788_8018_29112_31449_25927_28869_9926_23355_28763_21888_15110_3225_13895_19549_29216_30931_15265_27690_8337_25761_17559_452_23031_12764_26395_21381_39_15796_27763_7807_10408_27855_6116_18161_7637_30961_25702_14336_2577_10243_16255_29850_3492_17641_22610_17519_26063_17290_30842_19303_18804_14241_19082_12318_18422_9292_6062_5977_19609_21014_15250_2877_24244_28057_29548_26959_28741_29538_17300_12523_22569_8822_25972_27184_28641_18874_30392_18551_14419_21128_31880_29342_25387_2638_19610_15567_20165_27042_7779_23408_3750_7057_25427_10643_26109_3222_4686_22326_1450_18364_29120_17095_21779_857_20321_19404_18431_12520_11846_31376_1167_695_30753_5921_3967_6740_25624_29983_731_13621_3250_8319_7068_20702_11211_5772_16117_1263_3232_31623_6381_11475_23411_26826_29391_5424_21525_30929_1860_21639_27054_10955_8024_19098_14155_5596_15093_13118_18212_16276_9968_11430_26805_8032_17940_400_23731_10751_1465_8990_7686_30500_11273_9214_11496_30301_20281_19875_22900_12418_13263_20827_51_6156_10702_7663_20397_12118_28013_18182_13262_12874_22459_13384_21191_1499_14252_8297_5048_16883_15594_17966_24176_28284_15828_9987_14942_25890_28000_25997_6018_31981_20259_2673_23218_31578_12859_21713_10119_6835_22955_78_26327_16908_3131_3807_20777_27957_8961_31312_3208_27326_12696_2605_8793_14496_25355_27563_4272_16669_20826_11107_27900_8911_597_1304_21792_17868_30288_30030_29115_1347_23973_22443_20972_22796_28887_20485_11920_17215_6653_18789_287_4835_10671_25268_22992_10826_19858_1321_5246_31421_9267_12635_17552_9391_15299_7672_1547_5749_16739_2270_12903_10514_13272_3183_29077_15169_26907_31239_10998_15332_22387_13650_9662_23512_28621_29430_20056_12020_31186_20445_2109_2712_23996_1959_254_12603_16609_14355_15637_18717_21738_13538_11788_31630_8351_24867_13800_11474_2046_27634_22465_28896_14453_21662_3807_12736_6633_1350_30334_6910_4685_6337_12098_17485_4845_31638_31455_4751_12990_21758_28085_15854_29345_10320_15951_15958_21443_6466_19513_7003_10889_30802_28768_26180_1137_4750_8222_25093_26955_18655_22981_25826_2126_2711_27805_1264_7205_1303_492_27007_10581_5144_4764_20995_30995_16161_28835_16079": { + "2": { + "text": "zzt", + "token_ids": [ 29920, 2065 ], + "tokens": [ "z", "zt" ], + "logprobs": [ -4.901067733764648, -4.543080806732178 ] + } } } }, diff --git a/tests/llm_cache.py b/tests/llm_cache.py index a3f2ff623..2e57c2d9f 100644 --- a/tests/llm_cache.py +++ b/tests/llm_cache.py @@ -5,10 +5,12 @@ from typing import Callable, Generic, Optional, TypeVar import pytest -from llm_cache_util import force_engine_shutdown +from llm_cache_util import force_engine_core_shutdown, force_engine_shutdown from spyre_util import (DecodeWarmupShapes, ModelInfo, RemoteOpenAIServer, patch_environment) from vllm import LLM, EngineArgs +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.single_type_kv_cache_manager import FullAttentionManager from vllm.v1.engine.core import EngineCore from vllm.v1.executor.abstract import Executor @@ -95,6 +97,7 @@ def get_cached_llm(self, warmup_shapes: DecodeWarmupShapes | None = None, max_num_seqs: Optional[int] = None, use_cb: bool = False, + use_pc: bool = False, max_num_batched_tokens: Optional[int] = None) -> LLM: """Creates an LLM with the provided runtime configuration. 
@@ -106,6 +109,7 @@ def get_cached_llm(self, "tensor_parallel_size": tensor_parallel_size, "backend": backend, "use_cb": use_cb, + "use_pc": use_pc, "max_num_batched_tokens": max_num_batched_tokens } if use_cb: @@ -118,12 +122,14 @@ def get_cached_llm(self, # Always patch the environment so that it's consistent with the LLM # Use chunked prefill if max_num_batched_tokens is set + use_chunked_prefill = bool(max_num_batched_tokens) + if use_pc: + assert use_chunked_prefill patch_environment(use_cb, warmup_shapes, backend, monkeypatch, - use_chunked_prefill=max_num_batched_tokens - is not None, + use_chunked_prefill=use_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens) maybe_llm = self._cache.maybe_get(runtime_config) @@ -149,6 +155,7 @@ def get_cached_llm(self, tensor_parallel_size=tensor_parallel_size, max_num_batched_tokens=max_num_batched_tokens, logits_processors=[GoldenTokenInjector], + enable_prefix_caching=use_pc, ), ) @@ -160,7 +167,8 @@ class EngineCache: """Cache for continuous batching engines""" def __init__(self): - self._cache: ModelCache[EngineCore] = ModelCache[EngineCore]() + self._cache: ModelCache[EngineCore] = ModelCache[EngineCore]( + teardown_method=lambda x: force_engine_core_shutdown(x)) def get_engine( self, @@ -169,6 +177,7 @@ def get_engine( max_num_seqs: int, available_blocks: int, max_num_batched_tokens: int, + use_pc: bool, backend: str, monkeypatch, ) -> EngineCore: @@ -177,6 +186,7 @@ def get_engine( "max_model_len": max_model_len, "max_num_seqs": max_num_seqs, "available_blocks": available_blocks, + "use_pc": use_pc, "max_num_batched_tokens": max_num_batched_tokens, } @@ -185,6 +195,9 @@ def get_engine( use_chunked_prefill = True else: use_chunked_prefill = False + + if use_pc: + assert use_chunked_prefill patch_environment(use_cb=True, warmup_shapes=None, backend=backend, @@ -193,6 +206,10 @@ def get_engine( maybe_engine = self._cache.maybe_get(runtime_config) if maybe_engine: + if use_pc: + # reset the prefix cache across tests + (maybe_engine.model_executor.driver_worker.worker.model_runner. + block_pool.reset_prefix_cache()) return maybe_engine self.clear() @@ -224,7 +241,8 @@ def get_engine( max_num_seqs=max_num_seqs_compiled, num_gpu_blocks_override=None, logits_processors=[GoldenTokenInjector], - max_num_batched_tokens=max_num_batched_tokens) + max_num_batched_tokens=max_num_batched_tokens, + enable_prefix_caching=use_pc) vllm_config = engine_args.create_engine_config() executor_class = Executor.get_class(vllm_config) @@ -243,6 +261,23 @@ def get_engine( assert worker.model_runner.n_blocks >= available_blocks, \ "Cannot set available_blocks > (context * batch size // 64)" worker.model_runner.n_blocks = available_blocks + # need to overwrite the block pool and kv cache manager if the + # number of available blocks has changed + worker.model_runner.block_pool = BlockPool( + num_gpu_blocks=available_blocks + 1, + enable_caching=use_pc, + enable_kv_cache_events=False) + worker.model_runner.kv_cache_manager = FullAttentionManager( + kv_cache_spec=worker.model_runner._attn_spec, + block_pool=worker.model_runner.block_pool, + # Currently don't support models with more than one + # attention type, e.g. full and sliding window, so + # there is only one group. 
+ kv_cache_group_id=0, + # We don't support DCP + # https://docs.vllm.ai/en/latest/serving/context_parallel_deployment/#decode-context-parallel + dcp_world_size=1, + ) return self._cache.set( runtime_config, @@ -352,7 +387,8 @@ def get_cached_engine(model: str, available_blocks: int, backend: str, monkeypatch, - max_num_batched_tokens: int | None = None) -> EngineCore: + max_num_batched_tokens: int | None = None, + use_pc: bool = False) -> EngineCore: # Clear other caches first LLM_CACHE.clear() API_SERVER_CACHE.clear() @@ -363,6 +399,7 @@ def get_cached_engine(model: str, max_num_seqs=max_num_seqs, available_blocks=available_blocks, max_num_batched_tokens=max_num_batched_tokens, + use_pc=use_pc, backend=backend, monkeypatch=monkeypatch, ) diff --git a/tests/llm_cache_util.py b/tests/llm_cache_util.py index 8fce9842b..b57de8766 100644 --- a/tests/llm_cache_util.py +++ b/tests/llm_cache_util.py @@ -7,6 +7,10 @@ def force_engine_shutdown(llm: LLM): + force_engine_core_shutdown(llm.llm_engine.engine_core) + + +def force_engine_core_shutdown(engine_core): """ 🌶️🌶️🌶️ This hack is here because of an issue in vllm 0.9.2+ where a circular @@ -20,7 +24,7 @@ def force_engine_shutdown(llm: LLM): new engine will fail with an EADDRINUSE error. 🌶️🌶️🌶️ """ - llm.llm_engine.engine_core.shutdown() + engine_core.shutdown() def sort_tests_for_llm_caching(items: list) -> None: @@ -52,6 +56,7 @@ class SortKey(NamedTuple): tp_size: int = 1 use_cb: bool = False use_cp: bool = False + use_pc: bool = False max_model_len: int = 0 max_num_seqs: int = 0 num_blocks: int = 0 @@ -73,6 +78,7 @@ def from_item(item) -> "SortKey": use_cb = SortKey._uses_cb(item) use_cp = SortKey._uses_cp(item) + use_pc = SortKey._uses_pc(item) if use_cb or use_cp: sort_kwargs = { "max_model_len": SortKey._get_max_model_len(item), @@ -90,6 +96,7 @@ def from_item(item) -> "SortKey": tp_size=SortKey._get_tp_size(item), use_cb=use_cb, use_cp=use_cp, + use_pc=use_pc, num_blocks=SortKey._get_num_blocks(item), max_num_batched_tokens=SortKey._get_max_num_batched_tokens(item), **sort_kwargs, @@ -132,6 +139,13 @@ def _uses_cp(item) -> bool: markers = {mark.name for mark in item.own_markers} return "chunked_prefill" in markers + @staticmethod + def _uses_pc(item) -> bool: + """True if the test uses prefix caching. 
+ Checks for the pytest.mark.prefix_caching mark.""" + markers = {mark.name for mark in item.own_markers} + return "chunked_prefill" in markers and "prefix_caching" in markers + def _get_max_num_batched_tokens(item) -> int: """Chunk size for chunked prefill, if enabled""" params = item.callspec.params diff --git a/tests/scheduling_utils.py b/tests/scheduling_utils.py index 5fc3faee0..74fb98ca3 100644 --- a/tests/scheduling_utils.py +++ b/tests/scheduling_utils.py @@ -39,13 +39,12 @@ def augment_checked_steps( return all_checked_steps -def generate_prompts( - model: ModelInfo, - steps_add_reqs: list[int], - seqs_max_tokens: list[int], - prompts_lengths: list[int], - from_model_vocab: bool = False, -): +def generate_prompts(model: ModelInfo, + steps_add_reqs: list[int], + seqs_max_tokens: list[int], + prompts_lengths: list[int], + from_model_vocab: bool = False, + seeds: list[int] = None): generated_prompts = [] # Create random requests of specified lengths and max_tokens @@ -53,6 +52,16 @@ def generate_prompts( # will be overridden sorted_reqs_params = zip(steps_add_reqs, seqs_max_tokens, prompts_lengths) requests: deque[tuple[int, EngineCoreRequest]] = deque() + + # seeds for random (repeated) prompts generation to test prefix caching + if seeds: + assert from_model_vocab, \ + "when providing seeds we create random prompts" + assert len(seeds) == len(steps_add_reqs), \ + "number of seeds must be equal to the number of prompts" + else: + seeds = [None] * len(steps_add_reqs) + for i, (add_step, max_tokens, prompt_length) in enumerate(sorted_reqs_params): # ignoring eos because we want to force the decoding to finish @@ -65,7 +74,8 @@ def generate_prompts( num_tokens=prompt_length, sampling_params=sampling_params, model=model, - from_model_vocab=from_model_vocab) + from_model_vocab=from_model_vocab, + seed=seeds[i]) requests.append((add_step, request)) # NOTE: It is going to be decoded later generated_prompts.append(request.prompt_token_ids) @@ -88,6 +98,8 @@ def check_scheduler_inference_steps( use_cb: bool = True, max_num_batched_tokens: int = None, random_prompts: bool = False, + prefix_caching: bool = False, + seeds: list[int] = None, ): """ Test the scheduler execution by comparing the scheduler attributes at each @@ -131,7 +143,8 @@ def check_scheduler_inference_steps( steps_add_reqs, seqs_max_tokens, prompts_lengths, - from_model_vocab=random_prompts) + from_model_vocab=random_prompts, + seeds=seeds) hf_results = generate_hf_output( model=model, @@ -160,6 +173,7 @@ def check_scheduler_inference_steps( max_model_len=max_model_len, max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, + use_pc=prefix_caching, available_blocks=available_blocks, backend=backend, monkeypatch=monkeypatch) @@ -219,12 +233,13 @@ def check_scheduler_inference_steps( ), f"Step {step}, finished request output: {out_reqs_finished}" # checking the scheduler handling of free and reserved blocks - n_blocks = (engine_core.model_executor.driver_worker.worker. - model_runner.n_blocks) + model_runner = ( + engine_core.model_executor.driver_worker.worker.model_runner) + n_blocks = model_runner.n_blocks + block_size = model_runner.block_size n_reserved_blocks = n_blocks - scheduler.n_free_blocks - kv_cache_manager = (engine_core.model_executor.driver_worker. 
-                            worker.model_runner.kv_cache_manager)
+        kv_cache_manager = model_runner.kv_cache_manager
         req_ids2blocks = {
             req_id: [block.block_id for block in blocks]
@@ -232,23 +247,45 @@ ...
             if blocks
         }
         req_ids2num_reserved_blocks = (
-            engine_core.model_executor.driver_worker.worker.model_runner.
-            req_ids2num_reserved_blocks)
+            model_runner.req_ids2num_reserved_blocks)
         n_used_blocks = sum(
             [len(blocks) for blocks in req_ids2blocks.values()])
+        n_cached_blocks = n_prefix_hits = 0
+        if prefix_caching:
+            reqs = model_runner.requests
+            prefix_hits = [
+                reqs[r_id].num_cached_tokens
+                > reqs[r_id].num_computed_tokens for r_id in req_ids2blocks
+            ]
+            cached_blocks = [
+                reqs[r_id].num_cached_tokens // block_size
+                for r_id in req_ids2blocks
+            ]
+            n_cached_blocks = sum(cached_blocks)
+            n_prefix_hits = sum(prefix_hits)
+
         if step > 0:
             if DISABLE_ASSERTS:
                 print(
                     f"{step=}, {n_reserved_blocks=}, {n_used_blocks=}, "
                     f"{scheduler.tkv=}, {waiting=}, {out_reqs_finished=}, "
-                    f"{running=}, {out_reqs_ids=}")
+                    f"{running=}, {out_reqs_ids=}, {n_prefix_hits=}, "
+                    f"{n_cached_blocks=}")
             assert DISABLE_ASSERTS or (
                 n_reserved_blocks == step_ref["n_reserved_blocks"]
             ), f"Step {step}, n_reserved_blocks: {n_reserved_blocks}"
             assert DISABLE_ASSERTS or (
                 n_used_blocks == step_ref["n_used_blocks"]
            ), f"Step {step}, n_used_blocks: {n_used_blocks}"
+            assert DISABLE_ASSERTS or "n_prefix_hits" not in step_ref or (
+                n_prefix_hits == step_ref["n_prefix_hits"]
+            ), f"Step {step}, n_prefix_hits: {n_prefix_hits}"
+            assert DISABLE_ASSERTS or "n_cached_blocks" not in step_ref or (
+                n_cached_blocks == step_ref["n_cached_blocks"]
+            ), f"Step {step}, n_cached_blocks: {n_cached_blocks}"
             assert DISABLE_ASSERTS or len(req_ids2blocks) == len(
                 req_ids2num_reserved_blocks)
diff --git a/tests/spyre_util.py b/tests/spyre_util.py
index 8d2e93711..dbaf8a425 100644
--- a/tests/spyre_util.py
+++ b/tests/spyre_util.py
@@ -363,13 +363,12 @@ def create_seq_prompt(model: ModelInfo, token_length: int) -> str:
     return tokenizer.decode(tokens)

-def create_random_request(
-    request_id: int,
-    num_tokens: int,
-    sampling_params: SamplingParams,
-    from_model_vocab: bool = False,
-    model: Optional[ModelInfo] = None,
-) -> Request:
+def create_random_request(request_id: int,
+                          num_tokens: int,
+                          sampling_params: SamplingParams,
+                          from_model_vocab: bool = False,
+                          model: Optional[ModelInfo] = None,
+                          seed: int = None) -> Request:

     tokenizer = AutoTokenizer.from_pretrained(model.name,
                                               revision=model.revision)
@@ -381,6 +380,8 @@ def create_random_request(
             v for v in tokenizer.vocab.values()
             if v not in tokenizer.all_special_ids
         ])
+        if seed is not None:
+            random.seed(seed)
         prompt_token_ids = random.choices(valid_token_ids, k=num_tokens)
     else:
         # start with existing prompts and tokenize them
diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index 3978c5626..33f14095f 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -13,13 +13,19 @@
 import math
 import operator
 import os
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Optional, Union

 import torch
 from transformers.models.granite import GraniteConfig
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger

+try:
+    # pre 0.11.1 compatibility
+    from vllm.utils import FlexibleArgumentParser
+except ImportError:
+    from vllm.utils.argparse_utils import FlexibleArgumentParser
+
 if TYPE_CHECKING:
     # NB: We can't eagerly import many
things from vllm since vllm.config # will import this file. These would lead to circular imports @@ -420,6 +426,14 @@ def _get_matching_warmup_shapes( and max_tokens <= shape['new_tokens'] ] + @classmethod + def pre_register_and_update(cls, + parser: Optional[FlexibleArgumentParser] = None + ) -> None: + + if parser is not None: + parser.set_defaults(enable_prefix_caching=False) + @classmethod def _check_threading_config(cls, worker_count: int): """ diff --git a/vllm_spyre/v1/worker/spyre_input_batch.py b/vllm_spyre/v1/worker/spyre_input_batch.py index 0059c38ed..9ff18ec0f 100644 --- a/vllm_spyre/v1/worker/spyre_input_batch.py +++ b/vllm_spyre/v1/worker/spyre_input_batch.py @@ -12,6 +12,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams, SamplingType from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.request import Request # from vllm.v1.sample.logits_processor.state import LogitsProcessors from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, LogitsProcessors, @@ -205,6 +206,12 @@ def num_tokens(self) -> int: return len(self.prompt_token_ids) + len(self.output_token_ids) +@dataclass +class ChunkedPrefillRequestState(SamplingRequestState): + scheduler_request: Optional[Request] = None + num_cached_tokens: int = 0 + + class SamplingInputBatch(BaseInputBatch[SamplingRequestState]): ''' This class was based on the InputBatch for GPU of vLLM V1. diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index d856b0f47..08e97da56 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -19,17 +19,21 @@ try: # pre 0.11.1 compatibility - from vllm.utils import is_pin_memory_available + from vllm.utils import get_hash_fn_by_name, is_pin_memory_available except ImportError: from vllm.utils.platform_utils import is_pin_memory_available + from vllm.utils.hashing import get_hash_fn_by_name from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import KVCacheBlock +from vllm.v1.core.kv_cache_utils import (KVCacheBlock, + get_request_block_hasher, + init_none_hash) from vllm.v1.core.sched.output import CachedRequestData from vllm.v1.core.single_type_kv_cache_manager import FullAttentionManager from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.request import Request from vllm.v1.sample.logits_processor import build_logitsprocs import vllm_spyre.envs as envs_spyre @@ -44,6 +48,7 @@ # yapf: disable from vllm_spyre.v1.worker.spyre_input_batch import (BaseInputBatch, BaseRequestState, + ChunkedPrefillRequestState, PoolingInputBatch, PoolingRequestState, SamplingInputBatch, @@ -847,6 +852,9 @@ def __init__( self.tkv: int = 0 + self._enable_prefix_caching = ( + vllm_config.cache_config.enable_prefix_caching) + # TODO: Remove this once we can prefill and decode in the same step self.prefill_batch = SamplingInputBatch( # TODO: review this, currently we only support prefill for @@ -857,6 +865,10 @@ def __init__( pin_memory=self.pin_memory, vocab_size=vllm_config.model_config.get_vocab_size()) + @property + def enable_prefix_caching(self): + return self._enable_prefix_caching and not self.warmup_mode + def pre_warmup(self) -> None: # Set the number of kv cache blocks to the minimal value of 2 which is # required for warmup. 
After the warmup, the number of blocks will be @@ -892,16 +904,16 @@ def _set_blocks(self, num_blocks: int) -> None: # set number of available blocks and populate block_pool self.n_blocks = num_blocks - 1 self.block_pool = BlockPool(num_gpu_blocks=self.n_blocks + 1, - enable_caching=False, + enable_caching=self.enable_prefix_caching, enable_kv_cache_events=False) - attn_spec = FullAttentionSpec( + self._attn_spec = FullAttentionSpec( block_size=self.block_size, # dummy values num_kv_heads=1, head_size=1, dtype=torch.float16) self.kv_cache_manager = FullAttentionManager( - kv_cache_spec=attn_spec, + kv_cache_spec=self._attn_spec, block_pool=self.block_pool, # Currently don't support models with more than one # attention type, e.g. full and sliding window, so @@ -1822,8 +1834,18 @@ def __init__( is_driver_worker=is_driver_worker, rank=rank) - self.chunk_blocks_count = \ - self.scheduler_config.max_num_batched_tokens // self.block_size + self.chunk_size = self.scheduler_config.max_num_batched_tokens + self.chunk_blocks_count = self.chunk_size // self.block_size + + if vllm_config.cache_config.enable_prefix_caching: + caching_hash_fn = get_hash_fn_by_name( + vllm_config.cache_config.prefix_caching_hash_algo) + init_none_hash(caching_hash_fn) + + self.request_block_hasher = get_request_block_hasher( + self.block_size, caching_hash_fn) + else: + self.request_block_hasher = None def _prepare_prompt(self, _): AssertionError( @@ -1895,18 +1917,32 @@ def _prepare_chunked_prefill(self, req_id: str): ''' request = self.requests[req_id] - prompt_token_ids = request.prompt_token_ids - - chunk_size = self.scheduler_config.max_num_batched_tokens - num_computed_tokens = request.num_computed_tokens + assert isinstance(request, ChunkedPrefillRequestState) + prompt_token_ids = request.prompt_token_ids prompt_len = len(prompt_token_ids) padded_prompt_len = math.ceil( prompt_len / self.block_size) * self.block_size - chunk_i = math.ceil(num_computed_tokens / chunk_size) + chunk_size = self.chunk_size chunk_count = math.ceil(prompt_len / chunk_size) left_padding = chunk_count * chunk_size - padded_prompt_len + left_padded_prompt_mask = torch.tensor([left_padding], + dtype=torch.int64, + device=self.device) + + num_computed_tokens = request.num_computed_tokens + num_cached_tokens = request.num_cached_tokens + + if num_cached_tokens > num_computed_tokens: + assert self.enable_prefix_caching, \ + "prefix caching has to be enabled" + # this will be an idle step + return SamplingForwardInputs( + is_prompt=True, + left_padded_prompt_mask=left_padded_prompt_mask) + + chunk_i = math.ceil(num_computed_tokens / chunk_size) input_tokens = torch.zeros(chunk_size, dtype=torch.int64, @@ -1964,10 +2000,6 @@ def _prepare_chunked_prefill(self, req_id: str): input_tokens = input_tokens.unsqueeze(0).clone() input_positions = input_positions.unsqueeze(0).clone() - left_padded_prompt_mask = torch.tensor([left_padding], - dtype=torch.int64, - device=self.device) - # NOTE(wallas): Looks like we need to use multiple of blocks for prefill # so, later we use model.n_pads_right to get right logits. 
# In my naive mind this would be the `request_tkv` below, @@ -2114,6 +2146,59 @@ def _prepare_decode( return model_inputs + def _maybe_load_prefix_from_cache(self, scheduler_request: Request) -> int: + num_cached_tokens = 0 + if self.enable_prefix_caching: + + prompt_len = len(scheduler_request.prompt_token_ids) + + chunk_size = self.chunk_size + padded_prompt_len = math.ceil( + prompt_len / self.block_size) * self.block_size + chunk_count = math.ceil(prompt_len / chunk_size) + + # chunks that we can fill from cache + # we can't reuse the last chunk even with a full hit + cacheable_chunks = chunk_count - 1 + cacheable_blocks = cacheable_chunks * self.chunk_blocks_count + + left_padding = chunk_count * chunk_size - padded_prompt_len + assert left_padding % self.block_size == 0 + left_blocks = left_padding // self.block_size + cacheable_blocks -= left_blocks + max_cache_hit_length = cacheable_blocks * self.block_size + + if max_cache_hit_length > 0: + computed_blocks: list[ + KVCacheBlock] = FullAttentionManager.find_longest_cache_hit( + block_hashes=scheduler_request.block_hashes, + max_length=max_cache_hit_length, + kv_cache_group_ids=[0], + block_pool=self.block_pool, + kv_cache_spec=self._attn_spec, + use_eagle=False, + dcp_world_size=1, + )[0] + n_hit = len(computed_blocks) + else: + computed_blocks = list[KVCacheBlock]() + n_hit = 0 + logger.debug("Found: %d cached_blocks", n_hit) + + # trim down to chunk boundary + usable_blocks = (((left_blocks + n_hit) // self.chunk_blocks_count)\ + * self.chunk_blocks_count) - left_blocks + usable_blocks = max(usable_blocks, 0) + logger.debug("Found: %d usable blocks in cache", usable_blocks) + computed_blocks = computed_blocks[:usable_blocks] + num_cached_tokens = usable_blocks * self.block_size + + self.block_pool.touch((computed_blocks, )) + self.kv_cache_manager.save_new_computed_blocks( + scheduler_request.request_id, computed_blocks) + + return num_cached_tokens + def add_new_request(self, request: NewRequestData): req_id = request.req_id prompt_token_ids = request.prompt_token_ids @@ -2137,6 +2222,18 @@ def add_new_request(self, request: NewRequestData): self.req_ids2num_reserved_blocks[req_id] = n_reserved_blocks + num_cached_tokens = 0 + scheduler_request = Request( + request_id=req_id, + prompt_token_ids=prompt_token_ids, + sampling_params=request.sampling_params, + pooling_params=None, + eos_token_id=None, + block_hasher=self.request_block_hasher, + ) + num_cached_tokens = self._maybe_load_prefix_from_cache( + scheduler_request) + # allocate blocks self.kv_cache_manager.allocate_new_blocks(req_id, prompt_len) @@ -2147,7 +2244,7 @@ def add_new_request(self, request: NewRequestData): else: generator = None - req_state = SamplingRequestState( + req_state = ChunkedPrefillRequestState( req_id=req_id, prompt_token_ids=prompt_token_ids, sampling_params=sampling_params, @@ -2157,6 +2254,8 @@ def add_new_request(self, request: NewRequestData): # to always use the optimizations of blocks # usage left_padding=0, + scheduler_request=scheduler_request, + num_cached_tokens=num_cached_tokens, ) self.requests[req_id] = req_state @@ -2264,11 +2363,23 @@ def update_states(self, scheduler_output: SchedulerOutput): # input batch, and update states try to access it there. 
req_id = cached_reqs.req_ids[0] req_state = self.requests[req_id] + assert isinstance(req_state, ChunkedPrefillRequestState) num_computed_tokens = cached_reqs.num_computed_tokens[0] if num_computed_tokens < len(req_state.prompt_token_ids): # For now, if it is prefilling, we only need to update num of # computed tokens of the request req_state.num_computed_tokens = num_computed_tokens + if self.enable_prefix_caching: + num_cached_blocks = self.kv_cache_manager.\ + num_cached_block[req_id] + # if the number of cached tokens is larger or equal to the + # number of computed tokens, it means that during this call + # to execute_model we're just loading blocks from the KV + # cache and can't call `cache_blocks()` + if num_computed_tokens > \ + num_cached_blocks * self.block_size: + self.kv_cache_manager.cache_blocks( + req_state.scheduler_request, num_computed_tokens) # hide the prefill request from the super class scheduler_output.scheduled_cached_reqs = \ CachedRequestData.make_empty() @@ -2295,15 +2406,25 @@ def execute_model( model_input = self.prepare_model_input(scheduler_output) - # Execute the model - attn_metadata = self.build_attn_metadata(model_input) - with set_forward_context(attn_metadata, self.vllm_config): - logits = self.model(input_ids=model_input.input_tokens, - positions=model_input.input_positions, - masks=model_input.input_masks, - is_prompt=model_input.is_prompt) - + incomplete_prefill = False + is_cached_chunk = False is_prefill = cast(bool, model_input.is_prompt) + if is_prefill: + incomplete_prefill = self.check_incomplete_prefill( + scheduler_output) + is_cached_chunk = model_input.input_tokens is None + if is_cached_chunk: + assert incomplete_prefill, \ + "can't apply caching on the last chunked prefill" + + if not is_cached_chunk: + # Execute the model + attn_metadata = self.build_attn_metadata(model_input) + with set_forward_context(attn_metadata, self.vllm_config): + logits = self.model(input_ids=model_input.input_tokens, + positions=model_input.input_positions, + masks=model_input.input_masks, + is_prompt=model_input.is_prompt) # Get mapping between requests ids to the index within the batch req_id_to_index = self.get_req_id_to_index(is_prefill) @@ -2318,8 +2439,10 @@ def execute_model( # TODO: dead code, this only works for SB with bs=1, either # fix it or remove it. - prompt_logprobs_dicts = self._get_prompt_logprobs_dict( - logits=logits, model_inputs=model_input) + #prompt_logprobs_dicts = self._get_prompt_logprobs_dict( + # logits=logits, model_inputs=model_input) + # TODO: disable prefix caching for requests with prompt logprobs + prompt_logprobs_dicts: dict[str, Optional[LogprobsTensors]] = {} # If the prompt is being prefilled we don't have to sample # and generate a new token. @@ -2365,6 +2488,13 @@ def execute_model( for i, req_id in enumerate(req_ids): req_state = self.requests[req_id] + assert isinstance(req_state, ChunkedPrefillRequestState) + assert req_state.scheduler_request is not None + req_state.scheduler_request.append_output_token_ids(sampled_ids[i]) + if self.enable_prefix_caching: + self.kv_cache_manager.cache_blocks( + req_state.scheduler_request, + req_state.scheduler_request.num_tokens) req_state.output_token_ids.extend(sampled_ids[i]) # Only return outputs from the driver worker