
Commit 992cfc9

Author: Varun Sundar Rabindranath (committed)
ms for chunked prefill
1 parent 3351973 commit 992cfc9

File tree

6 files changed: +166 additions, -50 deletions

examples/openai_completion_client.py

Lines changed: 41 additions & 25 deletions
@@ -4,28 +4,44 @@
 openai_api_key = "EMPTY"
 openai_api_base = "http://localhost:8000/v1"
 
-client = OpenAI(
-    # defaults to os.environ.get("OPENAI_API_KEY")
-    api_key=openai_api_key,
-    base_url=openai_api_base,
-)
-
-models = client.models.list()
-model = models.data[0].id
-
-# Completion API
-stream = False
-completion = client.completions.create(
-    model=model,
-    prompt="A robot may not injure a human being",
-    echo=False,
-    n=2,
-    stream=stream,
-    logprobs=3)
-
-print("Completion results:")
-if stream:
-    for c in completion:
-        print(c)
-else:
-    print(completion)
+
+
+def get_prompts(n=1):
+    ps = ['A robot may not injure a human being']
+    for i in range(1, n):
+        ps.append(' '.join(["hi!"] * i))
+
+    return ps
+
+
+def main():
+    client = OpenAI(
+        # defaults to os.environ.get("OPENAI_API_KEY")
+        api_key=openai_api_key,
+        base_url=openai_api_base,
+    )
+
+    models = client.models.list()
+    model = models.data[0].id
+
+    prompts = get_prompts(50)
+    #print (f"{prompts}")
+    print(f"# PROMPTS : {len(prompts)}")
+
+    # Completion API
+    stream = False
+    completion = client.completions.create(model=model,
+                                           prompt=prompts,
+                                           echo=False,
+                                           n=1,
+                                           stream=stream)
+
+    print("Completion results:")
+    if stream:
+        for c in completion:
+            print(c)
+    else:
+        print(completion)
+
+
+if __name__ == '__main__':
+    main()
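
For reference, a short worked example of the new get_prompts helper added above. The helper is copied verbatim from the diff; the assert is illustrative only. The first prompt is the fixed sentence, and prompt i (for i >= 1) is "hi!" repeated i times, so the batch contains prompts of steadily increasing length that exercise chunked prefill.

def get_prompts(n=1):
    ps = ['A robot may not injure a human being']
    for i in range(1, n):
        ps.append(' '.join(["hi!"] * i))

    return ps

# get_prompts(4) yields four prompts of increasing length:
assert get_prompts(4) == [
    'A robot may not injure a human being',
    'hi!',
    'hi! hi!',
    'hi! hi! hi!',
]
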
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+# Test the AsyncLLMEngine with multi-step-decoding and chunked prefill
+
+from typing import List
+
+import pytest
+
+from ..utils import RemoteOpenAIServer
+
+MODELS = [
+    "facebook/opt-125m",
+    "meta-llama/Llama-2-7b-hf",
+]
+NUM_SCHEDULER_STEPS = [8, 16]  # Multi-step decoding steps
+NUM_PROMPTS = [100]
+
+# TODO (varun) : Expand tests for multiple TP & PP
+DEFAULT_SERVER_ARGS: List[str] = [
+    "--disable-log-requests",
+    "--use-v2-block-manager",
+    "--worker-use-ray",
+    "--gpu-memory-utilization",
+    "0.90",
+    "--swap-space",
+    "16",
+    "--tensor-parallel-size",
+    "1",
+    "--pipeline-parallel-size",
+    "1",
+]
+
+
+async def completions_with_server_args(prompts: List[str], model_name: str,
+                                       server_cli_args: List[str]):
+
+    outputs = None
+    with RemoteOpenAIServer(model_name, server_cli_args) as server:
+        client = server.get_async_client()
+        outputs = await client.completions.create(model=model_name,
+                                                  prompt=prompts,
+                                                  temperature=0,
+                                                  stream=False,
+                                                  max_tokens=150)
+    assert outputs is not None
+
+    return outputs
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
+@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
+@pytest.mark.asyncio
+async def test_mutli_step_with_chunked_prefill(example_prompts, model: str,
+                                               num_scheduler_steps: int,
+                                               num_prompts: int):
+
+    prompts = example_prompts
+    if len(prompts) < num_prompts:
+        prompts = prompts * ((num_prompts // len(prompts)) + 1)
+    prompts = prompts[:num_prompts]
+    assert len(prompts) == num_prompts
+
+    server_args = DEFAULT_SERVER_ARGS + \
+        ["--num-scheduler-steps", f"{num_scheduler_steps}"]
+
+    ref_completions = await completions_with_server_args(
+        prompts, model, server_args)
+    test_completions = await completions_with_server_args(
+        prompts, model, server_args + ["--enable-chunked-prefill"])
+
+    def get_text_generations(completions):
+        return [x.text for x in completions.choices]
+
+    ref_generations = get_text_generations(ref_completions)
+    test_generations = get_text_generations(test_completions)
+    assert ref_generations == test_generations
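
To make the comparison explicit, a small sketch (not part of the commit) of the two server configurations a single parametrized case compares, written out for num_scheduler_steps = 8. The test asserts that greedy (temperature=0) completions from the two servers are identical, i.e. enabling chunked prefill does not change multi-step output.

# Sketch only: the reference server uses multi-step scheduling alone; the
# test server additionally enables chunked prefill.
DEFAULT_SERVER_ARGS = [
    "--disable-log-requests", "--use-v2-block-manager", "--worker-use-ray",
    "--gpu-memory-utilization", "0.90", "--swap-space", "16",
    "--tensor-parallel-size", "1", "--pipeline-parallel-size", "1",
]
ref_args = DEFAULT_SERVER_ARGS + ["--num-scheduler-steps", "8"]
test_args = ref_args + ["--enable-chunked-prefill"]
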

vllm/core/scheduler.py

Lines changed: 17 additions & 1 deletion
@@ -981,6 +981,21 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
             [s.seq_group for s in swapped_in.prefill_seq_groups])
         # Update swapped requests.
         self.swapped.extend(running_scheduled.swapped_out)
+
+        if self.scheduler_config.is_multi_step:
+            # It maybe the case that prefills are scheduled along
+            # with decodes. In that case update the multi-step state
+            # of all the scheduled sequences to perform just a single
+            # decoding step.
+            has_prefills = len(prefills.seq_groups) + \
+                len(running_scheduled.prefill_seq_groups) + \
+                len(swapped_in.prefill_seq_groups) > 0
+            if has_prefills:
+                for sg in running_scheduled.decode_seq_groups:
+                    sg.seq_group.init_multi_step(1)
+                for sg in swapped_in.decode_seq_groups:
+                    sg.seq_group.init_multi_step(1)
+
         return SchedulerOutputs(
             scheduled_seq_groups=(prefills.seq_groups +
                                   running_scheduled.prefill_seq_groups +
@@ -1187,7 +1202,8 @@ def _append_slots(
             the new source and destination block indices for the appended
             slots.
         """
-        num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
+        num_lookahead_slots = self._get_num_lookahead_slots(\
+            is_prefill=seq_group.is_prefill())
         seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)
 
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
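
The scheduler change boils down to a simple rule for how many steps a scheduled decode runs in one iteration. The helper below is a hypothetical illustration only (not in the commit), restating that rule from the diff above.

# Hypothetical helper, for illustration: when any prefill is scheduled in the
# same iteration, decode sequence groups are reset to a single step via
# init_multi_step(1); otherwise _append_slots gives them
# num_lookahead_slots + 1 steps.
def decode_steps_this_iteration(has_prefills: bool,
                                num_lookahead_slots: int) -> int:
    return 1 if has_prefills else num_lookahead_slots + 1
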

vllm/engine/arg_utils.py

Lines changed: 0 additions & 3 deletions
@@ -868,9 +868,6 @@ def create_engine_config(self, ) -> EngineConfig:
             if speculative_config is not None:
                 raise ValueError("Speculative decoding is not supported with "
                                  "multi-step (--num-scheduler-steps > 1)")
-            if self.enable_chunked_prefill:
-                raise ValueError("Chunked prefill is not supported with "
-                                 "multi-step (--num-scheduler-steps > 1)")
             if not self.use_v2_block_manager:
                 raise ValueError("BlockSpaceManagerV2 is required for "
                                  "multi-step (--num-scheduler-steps > 1)")

vllm/engine/async_llm_engine.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,10 @@ async def step_async(
         seq_group_metadata_list, scheduler_outputs = self.scheduler[
             virtual_engine].schedule()
 
-        if (self.scheduler_config.is_multi_step
-                and scheduler_outputs.num_lookahead_slots > 0):
+        if self.scheduler_config.is_multi_step and \
+            self._remaining_steps(seq_group_metadata_list) > 1:
             # cache the scheduler outputs for the next iteration if we have
-            # lookahead slots
+            # one.
             self._cache_scheduler_outputs_for_multi_step(
                 virtual_engine, seq_group_metadata_list, scheduler_outputs)
 
@@ -361,14 +361,15 @@ async def step_async(
 
         return request_outputs
 
-    def _has_remaining_steps(
-        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
-    ) -> bool:
+    def _remaining_steps(
+            self,
+            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
+    ) -> int:
         if not self.scheduler_config.is_multi_step:
-            return False
+            return 0
 
         if not seq_group_metadata_list:
-            return False
+            return 0
 
         # TODO(will) this is a sanity check for now to make sure that all the
         # seqs are on the same steps. Eventually we will want to do some sort of
@@ -381,7 +382,12 @@ def _has_remaining_steps(
             raise AssertionError(("All running sequence groups should "
                                   "have the same remaining steps."))
 
-        return ref_remaining_steps > 0
+        return ref_remaining_steps
+
+    def _has_remaining_steps(
+        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
+    ) -> bool:
+        return self._remaining_steps(seq_group_metadata_list) > 0
 
     def _cache_scheduler_outputs_for_multi_step(
         self, virtual_engine: int,
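
The refactor splits the step counter out of the boolean check: _remaining_steps reports how many multi-step iterations are left, and step_async caches scheduler outputs only while more than one remains. A minimal sketch of that caching decision (hypothetical helper, not in the commit):

def should_cache_scheduler_outputs(is_multi_step: bool,
                                   remaining_steps: int) -> bool:
    # Cache and reuse the schedule only while at least one more multi-step
    # iteration will follow the current one.
    return is_multi_step and remaining_steps > 1
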

vllm/worker/multi_step_model_runner.py

Lines changed: 18 additions & 12 deletions
@@ -481,7 +481,10 @@ def _pythonize_sampler_output(
     # samples generation should have been skipped
     assert not output.outputs
 
-    pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]
+    # dont use num-queries as some of the sequence's may not need sampling.
+    # Like, chunked prefill seqs.
+    n_sampled_token_ids = sampled_token_ids.shape[0]
+    pinned_buffer = pinned_sampled_token_buffer[:n_sampled_token_ids]
 
     # CPU GPU sync
     pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False)
@@ -491,20 +494,23 @@ def _pythonize_sampler_output(
 
     sampling_metadata = frozen_model_input.sampling_metadata
 
-    for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
-                                          samples_list):
-        seq_ids = seq_group.seq_ids
-        next_token_ids = sample_result
-        parent_ids = [0]
+    sample_result_it = iter(samples_list)
+    for seq_group in sampling_metadata.seq_groups:
         seq_outputs: List[SequenceOutput] = []
         if seq_group.sampling_params.logits_processors:
             assert len(seq_group.sampling_params.logits_processors) == 0, (
                 "Logits Processors are not supported in multi-step decoding")
-        for parent_id, next_token_id in zip(parent_ids, next_token_ids):
-            # TODO(will): support logprobs
-            # Hard coded logprob
-            seq_outputs.append(
-                SequenceOutput(seq_ids[parent_id], next_token_id,
-                               {next_token_id: Logprob(logprob=42)}))
+        if seq_group.do_sample:
+            sample_result = next(sample_result_it)
+            seq_ids = seq_group.seq_ids
+            next_token_ids = sample_result
+            parent_ids = [0]
+            for parent_id, next_token_id in zip(parent_ids, next_token_ids):
+                # TODO(will): support logprobs
+                # Hard coded logprob
+                seq_outputs.append(
+                    SequenceOutput(seq_ids[parent_id], next_token_id,
+                                   {next_token_id: Logprob(logprob=42)}))
         output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None))
+
     assert len(output.outputs) > 0
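
The runner change accounts for mixed batches: with chunked prefill enabled, partially-prefilled sequence groups have do_sample == False and contribute no row to sampled_token_ids, so the pinned buffer is now sized by sampled_token_ids.shape[0] rather than model_input.num_queries. A small sketch of the invariant the new code relies on (hypothetical helper, not in the commit):

def expected_num_sampled_rows(seq_groups) -> int:
    # Only sequence groups that actually sample this step (decodes and a
    # prompt's final prefill chunk) produce a sampled token id.
    return sum(1 for sg in seq_groups if sg.do_sample)

# Assumed invariant:
# sampled_token_ids.shape[0] == expected_num_sampled_rows(
#     sampling_metadata.seq_groups)
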
