Commit 76a5e13

[core] move parallel sampling out from vllm core (#9302)
1 parent ef7faad commit 76a5e13

4 files changed: +222 -29 lines

tests/entrypoints/openai/test_completion.py

Lines changed: 34 additions & 0 deletions
@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
     assert "".join(chunks) == single_output
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
+)
+async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
+    """Streaming for parallel sampling.
+    The tokens from multiple samples, are flattened into a single stream,
+    with an index to indicate which sample the token belongs to.
+    """
+
+    prompt = "What is an LLM?"
+    n = 3
+    max_tokens = 5
+
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=max_tokens,
+                                             n=n,
+                                             stream=True)
+    chunks: List[List[str]] = [[] for i in range(n)]
+    finish_reason_count = 0
+    async for chunk in stream:
+        index = chunk.choices[0].index
+        text = chunk.choices[0].text
+        chunks[index].append(text)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    assert finish_reason_count == n
+    for chunk in chunks:
+        assert len(chunk) == max_tokens
+        print("".join(chunk))
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",

vllm/engine/llm_engine.py

Lines changed: 42 additions & 10 deletions
@@ -44,8 +44,10 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
-                           Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceStatus)
+                           ParallelSampleSequenceGroup, Sequence,
+                           SequenceGroup, SequenceGroupBase,
+                           SequenceGroupMetadata, SequenceGroupOutput,
+                           SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
@@ -474,6 +476,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
                 ),
             ))
 
+        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
 
@@ -642,7 +646,10 @@ def _add_processed_request(
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> SequenceGroup:
+        """Add a processed request to the engine's request pool.
+        return the created sequence group.
+        """
         self._validate_model_inputs(processed_inputs)
         # Create the sequences.
         block_size = self.cache_config.block_size
@@ -696,6 +703,8 @@ def _add_processed_request(
         min_cost_scheduler = self.scheduler[costs.index(min(costs))]
         min_cost_scheduler.add_seq_group(seq_group)
 
+        return seq_group
+
     def stop_remote_worker_execution_loop(self) -> None:
         self.model_executor.stop_remote_worker_execution_loop()
 
@@ -711,7 +720,7 @@ def add_request(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...
 
     @overload
@@ -725,7 +734,7 @@ def add_request(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...
 
     @deprecate_kwargs(
@@ -744,7 +753,7 @@ def add_request(
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         """Add a request to the engine's request pool.
 
         The request is added to the request pool and will be processed by the
@@ -788,6 +797,22 @@ def add_request(
         >>> # continue the request processing
         >>> ...
         """
+
+        if isinstance(params, SamplingParams) and params.n > 1:
+            ParallelSampleSequenceGroup.add_request(
+                request_id,
+                self,
+                params,
+                prompt=prompt,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=prompt_adapter_request,
+                priority=priority,
+                inputs=inputs,
+            )
+            return None
+
         if inputs is not None:
             prompt = inputs
         assert prompt is not None and params is not None
@@ -818,7 +843,7 @@ def add_request(
         processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
             "mm_processor_kwargs")
 
-        self._add_processed_request(
+        return self._add_processed_request(
             request_id=request_id,
             processed_inputs=processed_inputs,
             params=params,
@@ -1135,7 +1160,9 @@ def _process_model_outputs(self,
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)
 
@@ -1175,7 +1202,9 @@ def _process_model_outputs(self,
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)
 
@@ -1194,7 +1223,10 @@ def _process_model_outputs(self,
                 continue
 
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs,
+            )
             if request_output:
                 ctx.request_outputs.append(request_output)
 
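
The key behavioural change here is the early return in add_request: when params.n > 1, the request is no longer handled inside the core as one sequence group carrying n sequences. Instead, ParallelSampleSequenceGroup.add_request (imported from vllm.sequence, presumably the fourth changed file, which is not shown in this view) fans the request out, and the engine's new seq_id_to_seq_group map remembers which child belongs to which parent. The sketch below only illustrates that fan-out idea under assumed names (ToySamplingParams, ParentRecord, fan_out, the child-id format); it is not vLLM's actual implementation.

```python
# Illustrative sketch only (not vLLM's code): fan an n>1 request out into n
# single-sample child requests, while a parent-level record keeps track of
# them for later reassembly. All names and the id format are made up.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToySamplingParams:
    n: int = 1
    max_tokens: int = 16


@dataclass
class ParentRecord:
    request_id: str
    child_ids: List[str] = field(default_factory=list)
    finished: Dict[str, bool] = field(default_factory=dict)


def fan_out(request_id: str, params: ToySamplingParams,
            registry: Dict[str, ParentRecord]) -> List[str]:
    """Split an n>1 request into n single-sample child requests.

    Each child would be submitted with n=1 params; the registry maps every
    child id back to its parent so outputs can be grouped again later.
    """
    parent = ParentRecord(request_id)
    for i in range(params.n):
        child_id = f"{request_id}_parallel_sample_{i}"  # id format is made up
        parent.child_ids.append(child_id)
        parent.finished[child_id] = False
        registry[child_id] = parent
    return list(parent.child_ids)


registry: Dict[str, ParentRecord] = {}
children = fan_out("req-0", ToySamplingParams(n=3), registry)
assert len(children) == 3 and all(c in registry for c in children)
```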

vllm/outputs.py

Lines changed: 26 additions & 17 deletions
@@ -1,13 +1,13 @@
 import time
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
 
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
-                           SequenceGroup, SequenceStatus)
+                           SequenceGroup, SequenceGroupBase, SequenceStatus)
 
 
 @dataclass
@@ -114,14 +114,28 @@ def __init__(
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
 
     @classmethod
-    def from_seq_group(cls, seq_group: SequenceGroup,
-                       use_cache: bool) -> Optional["RequestOutput"]:
+    def from_seq_group(
+        cls, seq_group: SequenceGroup, use_cache: bool,
+        seq_id_to_seq_group: Dict[str, SequenceGroupBase]
+    ) -> Optional["RequestOutput"]:
+        finished = seq_group.is_finished()
+
+        if seq_group.request_id in seq_id_to_seq_group:
+            group: SequenceGroupBase = seq_id_to_seq_group[
+                seq_group.request_id]
+            if finished:
+                group.finish_seq(seq_group)
+            assembled_seq_group = group.maybe_assemble_group(seq_group)
+            if assembled_seq_group is None:
+                return None
+            return cls.from_seq_group(assembled_seq_group, use_cache,
+                                      seq_id_to_seq_group)
+
         sampling_params = seq_group.sampling_params
         if sampling_params is None:
             raise ValueError(
                 "Sampling parameters are missing for a CompletionRequest.")
 
-        finished = seq_group.is_finished()
         if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
             not finished):
             return None
@@ -136,15 +150,7 @@ def from_seq_group(cls, seq_group: SequenceGroup,
                 outputs=[],
                 finished=False)
 
-        seqs = seq_group.get_seqs()
-        if len(seqs) == 1:
-            top_n_seqs = seqs
-        else:
-            # Get the top-n sequences.
-            n = sampling_params._real_n or sampling_params.n
-            sorting_key = lambda seq: seq.get_cumulative_logprob()
-            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
-            top_n_seqs = sorted_seqs[:n]
+        top_n_seqs = seq_group.get_seqs()
 
         # Create the outputs.
         # NOTE: We need omit logprobs here explicitly because the sequence
@@ -208,7 +214,7 @@ def from_seq_group(cls, seq_group: SequenceGroup,
 
         else:
             output = CompletionOutput(
-                seqs.index(seq), output_text, [output_token_ids]
+                top_n_seqs.index(seq), output_text, [output_token_ids]
                 if isinstance(output_token_ids, int) else output_token_ids,
                 seq.get_cumulative_logprob() if include_logprobs else None,
                 output_logprobs,
@@ -309,10 +315,13 @@ def __repr__(self):
 class RequestOutputFactory:
 
     @staticmethod
-    def create(seq_group: SequenceGroup, use_cache: bool = False):
+    def create(seq_group: SequenceGroup,
+               seq_id_to_seq_group: Dict[str, SequenceGroupBase],
+               use_cache: bool = False):
         # Determine the type based on a condition, for example:
         if hasattr(seq_group,
                    'embeddings') and seq_group.embeddings is not None:
             return EmbeddingRequestOutput.from_seq_group(seq_group)
         else:
-            return RequestOutput.from_seq_group(seq_group, use_cache)
+            return RequestOutput.from_seq_group(seq_group, use_cache,
+                                                seq_id_to_seq_group)
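
On the output side, from_seq_group now consults seq_id_to_seq_group: a sequence group that belongs to a parallel-sampling parent is recorded via finish_seq when it finishes, and a RequestOutput is only built once maybe_assemble_group hands back an assembled group; otherwise None is returned and the engine loop's "if request_output:" check skips it. The toy sketch below illustrates that "emit only once assembled" contract with stand-in types; it is not vLLM's SequenceGroupBase, and the real group also has to hand back partial results while streaming.

```python
# Rough sketch of the assemble-before-emit contract (stand-in types only).
from typing import Dict, List, Optional


class FakeGroup:
    """Stand-in for a parallel-sampling parent, not vLLM's SequenceGroupBase."""

    def __init__(self, child_ids: List[str]) -> None:
        self.unfinished = set(child_ids)
        self.collected: List[str] = []

    def finish_seq(self, child_id: str, text: str) -> None:
        # Record one finished child and its output.
        self.unfinished.discard(child_id)
        self.collected.append(text)

    def maybe_assemble_group(self) -> Optional[List[str]]:
        # In this toy version, only hand the outputs back once every child
        # has finished; until then the caller gets None and emits nothing.
        return self.collected if not self.unfinished else None


def on_child_finished(child_id: str, text: str,
                      registry: Dict[str, FakeGroup]) -> Optional[List[str]]:
    group = registry[child_id]
    group.finish_seq(child_id, text)
    return group.maybe_assemble_group()  # None until the last child finishes


group = FakeGroup(["req-0_0", "req-0_1"])
registry = {"req-0_0": group, "req-0_1": group}
assert on_child_finished("req-0_0", "first sample", registry) is None
assert on_child_finished("req-0_1", "second sample", registry) == [
    "first sample", "second sample"
]
```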
