
Commit abdff1d

youkaichao authored and Alvant committed
[core] move parallel sampling out from vllm core (vllm-project#9302)
1 parent 744b0a1 commit abdff1d

4 files changed: +222 additions, -29 deletions

tests/entrypoints/openai/test_completion.py

Lines changed: 34 additions & 0 deletions
@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
     assert "".join(chunks) == single_output
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
+)
+async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
+    """Streaming for parallel sampling.
+    The tokens from multiple samples are flattened into a single stream,
+    with an index to indicate which sample the token belongs to.
+    """
+
+    prompt = "What is an LLM?"
+    n = 3
+    max_tokens = 5
+
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=max_tokens,
+                                             n=n,
+                                             stream=True)
+    chunks: List[List[str]] = [[] for i in range(n)]
+    finish_reason_count = 0
+    async for chunk in stream:
+        index = chunk.choices[0].index
+        text = chunk.choices[0].text
+        chunks[index].append(text)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    assert finish_reason_count == n
+    for chunk in chunks:
+        assert len(chunk) == max_tokens
+        print("".join(chunk))
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",

vllm/engine/llm_engine.py

Lines changed: 42 additions & 10 deletions
@@ -44,8 +44,10 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
-                           Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceStatus)
+                           ParallelSampleSequenceGroup, Sequence,
+                           SequenceGroup, SequenceGroupBase,
+                           SequenceGroupMetadata, SequenceGroupOutput,
+                           SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
@@ -474,6 +476,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
                 ),
             ))
 
+        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
 
@@ -648,7 +652,10 @@ def _add_processed_request(
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> SequenceGroup:
+        """Add a processed request to the engine's request pool.
+        return the created sequence group.
+        """
         self._validate_model_inputs(processed_inputs)
         # Create the sequences.
         block_size = self.cache_config.block_size
@@ -701,6 +708,8 @@ def _add_processed_request(
         min_cost_scheduler = self.scheduler[costs.index(min(costs))]
         min_cost_scheduler.add_seq_group(seq_group)
 
+        return seq_group
+
     def stop_remote_worker_execution_loop(self) -> None:
         self.model_executor.stop_remote_worker_execution_loop()
 
@@ -754,7 +763,7 @@ def add_request(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...
 
     @overload
@@ -768,7 +777,7 @@ def add_request(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...
 
     @deprecate_kwargs(
@@ -787,7 +796,7 @@ def add_request(
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         """Add a request to the engine's request pool.
 
         The request is added to the request pool and will be processed by the
@@ -831,6 +840,22 @@ def add_request(
             >>> # continue the request processing
             >>> ...
         """
+
+        if isinstance(params, SamplingParams) and params.n > 1:
+            ParallelSampleSequenceGroup.add_request(
+                request_id,
+                self,
+                params,
+                prompt=prompt,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=prompt_adapter_request,
+                priority=priority,
+                inputs=inputs,
+            )
+            return None
+
         if inputs is not None:
             prompt = inputs
         assert prompt is not None and params is not None
@@ -865,7 +890,7 @@ def add_request(
             params=params,
             lora_request=lora_request)
 
-        self._add_processed_request(
+        return self._add_processed_request(
             request_id=request_id,
             processed_inputs=processed_inputs,
             params=processed_params,
@@ -1182,7 +1207,9 @@ def _process_model_outputs(self,
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)
 
@@ -1222,7 +1249,9 @@ def _process_model_outputs(self,
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)
 
@@ -1241,7 +1270,10 @@ def _process_model_outputs(self,
                 continue
 
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs,
+            )
             if request_output:
                 ctx.request_outputs.append(request_output)
 
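The n > 1 branch added to add_request delegates to ParallelSampleSequenceGroup, defined in vllm/sequence.py (presumably the fourth changed file, not reproduced on this page). The following is a rough, self-contained sketch of the fan-out idea only: a request with n > 1 is split into n single-sample child requests whose ids map back to a shared parent via seq_id_to_seq_group. The class names, the TinyEngine wrapper, and the f"{i}_{request_id}" child-id scheme are illustrative assumptions, not the actual vLLM implementation.

# Toy fan-out sketch; names here are stand-ins, not vLLM classes.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class StubSamplingParams:
    """Local stand-in for vllm.SamplingParams (only n matters here)."""
    n: int = 1


@dataclass
class ParentGroup:
    """Local stand-in for the group object that later reassembles outputs."""
    request_id: str
    n: int


class TinyEngine:
    """Toy engine showing only the fan-out bookkeeping, nothing else."""

    def __init__(self) -> None:
        # Mirrors LLMEngine.seq_id_to_seq_group: child request id -> parent.
        self.seq_id_to_seq_group: Dict[str, ParentGroup] = {}
        self.pending: List[str] = []

    def add_request(self, request_id: str, params: StubSamplingParams) -> None:
        if params.n > 1:
            parent = ParentGroup(request_id=request_id, n=params.n)
            for i in range(params.n):
                child_id = f"{i}_{request_id}"
                # Each child is an ordinary single-sample request; the engine
                # core never has to reason about parallel sampling itself.
                self.seq_id_to_seq_group[child_id] = parent
                self.pending.append(child_id)
            return
        self.pending.append(request_id)


engine = TinyEngine()
engine.add_request("req-0", StubSamplingParams(n=3))
print(engine.pending)  # ['0_req-0', '1_req-0', '2_req-0']
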
vllm/outputs.py

Lines changed: 26 additions & 17 deletions
@@ -1,13 +1,13 @@
 import time
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
 
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
-                           SequenceGroup, SequenceStatus)
+                           SequenceGroup, SequenceGroupBase, SequenceStatus)
 
 
 @dataclass
@@ -114,14 +114,28 @@ def __init__(
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
 
     @classmethod
-    def from_seq_group(cls, seq_group: SequenceGroup,
-                       use_cache: bool) -> Optional["RequestOutput"]:
+    def from_seq_group(
+            cls, seq_group: SequenceGroup, use_cache: bool,
+            seq_id_to_seq_group: Dict[str, SequenceGroupBase]
+    ) -> Optional["RequestOutput"]:
+        finished = seq_group.is_finished()
+
+        if seq_group.request_id in seq_id_to_seq_group:
+            group: SequenceGroupBase = seq_id_to_seq_group[
+                seq_group.request_id]
+            if finished:
+                group.finish_seq(seq_group)
+            assembled_seq_group = group.maybe_assemble_group(seq_group)
+            if assembled_seq_group is None:
+                return None
+            return cls.from_seq_group(assembled_seq_group, use_cache,
+                                      seq_id_to_seq_group)
+
         sampling_params = seq_group.sampling_params
         if sampling_params is None:
             raise ValueError(
                 "Sampling parameters are missing for a CompletionRequest.")
 
-        finished = seq_group.is_finished()
         if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
                 not finished):
             return None
@@ -136,15 +150,7 @@ def from_seq_group(cls, seq_group: SequenceGroup,
                        outputs=[],
                        finished=False)
 
-        seqs = seq_group.get_seqs()
-        if len(seqs) == 1:
-            top_n_seqs = seqs
-        else:
-            # Get the top-n sequences.
-            n = sampling_params._real_n or sampling_params.n
-            sorting_key = lambda seq: seq.get_cumulative_logprob()
-            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
-            top_n_seqs = sorted_seqs[:n]
+        top_n_seqs = seq_group.get_seqs()
 
         # Create the outputs.
         # NOTE: We need omit logprobs here explicitly because the sequence
@@ -208,7 +214,7 @@ def from_seq_group(cls, seq_group: SequenceGroup,
 
             else:
                 output = CompletionOutput(
-                    seqs.index(seq), output_text, [output_token_ids]
+                    top_n_seqs.index(seq), output_text, [output_token_ids]
                     if isinstance(output_token_ids, int) else output_token_ids,
                     seq.get_cumulative_logprob() if include_logprobs else None,
                     output_logprobs,
@@ -309,10 +315,13 @@ def __repr__(self):
 class RequestOutputFactory:
 
     @staticmethod
-    def create(seq_group: SequenceGroup, use_cache: bool = False):
+    def create(seq_group: SequenceGroup,
+               seq_id_to_seq_group: Dict[str, SequenceGroupBase],
+               use_cache: bool = False):
         # Determine the type based on a condition, for example:
         if hasattr(seq_group,
                    'embeddings') and seq_group.embeddings is not None:
             return EmbeddingRequestOutput.from_seq_group(seq_group)
         else:
-            return RequestOutput.from_seq_group(seq_group, use_cache)
+            return RequestOutput.from_seq_group(seq_group, use_cache,
+                                                seq_id_to_seq_group)
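
The regrouping path added to from_seq_group only makes sense together with the group object it consults. Below is a simplified, self-contained sketch of that flow: a child output is withheld until the group decides an assembled parent output can be emitted. FakeSeqGroup, Group, and make_output are illustrative stand-ins, and the "wait until every child finishes" rule in maybe_assemble_group is a deliberate simplification of the real streaming-aware logic in SequenceGroupBase.

# Toy reassembly sketch; names and assembly rule are stand-ins only.
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class FakeSeqGroup:
    request_id: str
    text: str
    finished: bool


@dataclass
class Group:
    parent_id: str
    n: int
    finished: List[FakeSeqGroup] = field(default_factory=list)

    def finish_seq(self, seq_group: FakeSeqGroup) -> None:
        self.finished.append(seq_group)

    def maybe_assemble_group(
            self, seq_group: FakeSeqGroup) -> Optional[FakeSeqGroup]:
        # Only emit once every child has finished (a simplification of the
        # streaming case, where partial assembly is also possible).
        if len(self.finished) < self.n:
            return None
        text = " | ".join(s.text for s in self.finished)
        return FakeSeqGroup(self.parent_id, text, True)


def make_output(seq_group: FakeSeqGroup,
                seq_id_to_seq_group: Dict[str, Group]) -> Optional[str]:
    # Mirrors the new branch in RequestOutput.from_seq_group: child groups
    # are routed through their parent before any output is produced.
    if seq_group.request_id in seq_id_to_seq_group:
        group = seq_id_to_seq_group[seq_group.request_id]
        if seq_group.finished:
            group.finish_seq(seq_group)
        assembled = group.maybe_assemble_group(seq_group)
        if assembled is None:
            return None
        return make_output(assembled, seq_id_to_seq_group)
    return f"{seq_group.request_id}: {seq_group.text}"


groups: Dict[str, Group] = {}
parent = Group("req-0", n=2)
groups["0_req-0"] = parent
groups["1_req-0"] = parent
print(make_output(FakeSeqGroup("0_req-0", "first sample", True), groups))   # None
print(make_output(FakeSeqGroup("1_req-0", "second sample", True), groups))  # assembled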
