Commit 4d31cd4

[Frontend] merge beam search implementations (#9296)
1 parent 473e7b3 commit 4d31cd4

File tree

5 files changed: +145 -234 lines changed


vllm/engine/async_llm_engine.py

Lines changed: 6 additions & 104 deletions
@@ -7,33 +7,31 @@
 from weakref import ReferenceType
 
 import vllm.envs as envs
-from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
 from vllm.engine.metrics_types import StatLoggerBase
+from vllm.engine.protocol import EngineClient
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType, TokensPrompt
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
-                          RequestOutput)
+from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
-                        random_uuid, weak_bind)
+from vllm.utils import deprecate_kwargs, weak_bind
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -541,7 +539,7 @@ async def build_guided_decoding_logits_processor_async(
     return sampling_params
 
 
-class AsyncLLMEngine:
+class AsyncLLMEngine(EngineClient):
     """An asynchronous wrapper for :class:`LLMEngine`.
 
     This class is used to wrap the :class:`LLMEngine` class to make it
@@ -1039,102 +1037,6 @@ async def generate(
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
 
-    async def beam_search(
-        self,
-        prompt: Union[PromptType, List[int]],
-        request_id: str,
-        params: BeamSearchParams,
-    ) -> AsyncGenerator[RequestOutput, None]:
-
-        beam_width = params.beam_width
-        max_tokens = params.max_tokens
-        ignore_eos = params.ignore_eos
-        temperature = params.temperature
-        length_penalty = params.length_penalty
-
-        tokenizer = await self.get_tokenizer()
-        tokenizedPrompt = prompt if isinstance(
-            prompt, list) else tokenizer.encode(prompt)
-        tokenizedLength = len(tokenizedPrompt)
-
-        sort_beams_key = create_sort_beams_key_function(
-            tokenizer.eos_token_id, length_penalty)
-
-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
-        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
-        completed = []
-
-        for _ in range(max_tokens):
-            prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
-            ]
-
-            tasks = []
-
-            request_id = f"beam_search-{random_uuid()}"
-            for i, individual_prompt in enumerate(prompts_batch):
-                request_id_item = f"{request_id}-{i}"
-                task = asyncio.create_task(
-                    collect_from_async_generator(
-                        self.generate(individual_prompt, beam_search_params,
-                                      request_id_item)))
-                tasks.append(task)
-
-            output = await asyncio.gather(*tasks)
-
-            output = [x[0] for x in output]
-
-            logger.info(output)
-
-            new_beams = []
-            for i, current_beam in enumerate(all_beams):
-                result = output[i]
-
-                if result.outputs[0].logprobs is not None:
-                    logprobs = result.outputs[0].logprobs[0]
-                    for token_id, logprob_obj in logprobs.items():
-                        new_beam = BeamSearchSequence(
-                            tokens=current_beam.tokens + [token_id],
-                            cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
-
-                        if token_id == tokenizer.eos_token_id and \
-                            not ignore_eos:
-                            completed.append(new_beam)
-                        else:
-                            new_beams.append(new_beam)
-
-            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
-            all_beams = sorted_beams[:beam_width]
-
-        completed.extend(all_beams)
-        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
-        best_beams = sorted_completed[:beam_width]
-
-        for beam in best_beams:
-            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
-
-        beam_search_output = RequestOutput(
-            request_id=request_id,
-            prompt=prompt,
-            outputs=[
-                CompletionOutput(
-                    text=beam.text,
-                    cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens,
-                    index=i,
-                    logprobs=beam.cum_logprob,
-                ) for (i, beam) in enumerate(best_beams)
-            ],
-            finished=True,
-            prompt_token_ids=tokenizedPrompt,
-            prompt_logprobs=None)
-
-        yield LLMEngine.validate_output(beam_search_output, RequestOutput)
-
     async def encode(
         self,
         prompt: PromptType,
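
The beam_search method deleted above and its twin in vllm/engine/multiprocessing/client.py below were line-for-line copies. They drive the engine one token at a time: each live beam is resubmitted as a one-token request with SamplingParams(logprobs=2 * beam_width, max_tokens=1), the returned top logprobs fan each beam out into candidates, candidates ending in EOS move to a completed list (unless ignore_eos), and the remainder are pruned back to the best beam_width. A minimal, engine-free sketch of that expand-and-prune step; Beam, beam_search_step, and top_logprobs_fn are illustrative stand-ins, not vLLM API:

import math
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Beam:
    tokens: List[int]         # prompt + generated token ids
    cum_logprob: float = 0.0  # running sum of chosen-token logprobs


def beam_search_step(
    beams: List[Beam],
    completed: List[Beam],
    top_logprobs_fn: Callable[[List[int]], Dict[int, float]],
    beam_width: int,
    eos_token_id: int,
    ignore_eos: bool = False,
) -> List[Beam]:
    # One round of the deleted loop body: fan out over the returned top
    # logprobs, split off finished beams, prune the rest. top_logprobs_fn
    # stands in for the one-token generate() call made per live beam.
    new_beams: List[Beam] = []
    for beam in beams:
        for token_id, logprob in top_logprobs_fn(beam.tokens).items():
            candidate = Beam(tokens=beam.tokens + [token_id],
                             cum_logprob=beam.cum_logprob + logprob)
            if token_id == eos_token_id and not ignore_eos:
                completed.append(candidate)
            else:
                new_beams.append(candidate)
    # The deleted code sorts with a length-penalized key; raw cum_logprob
    # is equivalent here because all live beams share the same length.
    new_beams.sort(key=lambda b: b.cum_logprob, reverse=True)
    return new_beams[:beam_width]


# Toy next-token model: prefer token 1, emit EOS (= 0) once 5 tokens exist.
def fake_top_logprobs(tokens: List[int]) -> Dict[int, float]:
    if len(tokens) >= 5:
        return {0: math.log(0.9), 1: math.log(0.1)}
    return {1: math.log(0.6), 2: math.log(0.3)}


completed: List[Beam] = []
beams = [Beam(tokens=[101, 102])]  # the tokenized prompt
for _ in range(8):                 # max_tokens
    beams = beam_search_step(beams, completed, fake_top_logprobs,
                             beam_width=2, eos_token_id=0)
completed.extend(beams)
completed.sort(key=lambda b: b.cum_logprob, reverse=True)
print([b.tokens for b in completed[:2]])

Because each expansion is an ordinary generate() call, the whole search runs client-side, which is presumably what made the two copies identical and therefore mergeable into one shared implementation.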

vllm/engine/multiprocessing/client.py

Lines changed: 17 additions & 109 deletions
@@ -12,8 +12,8 @@
 from zmq.asyncio import Socket
 
 from vllm import PoolingParams
-from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, EngineConfig, ModelConfig
+from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -26,18 +26,18 @@
                                          RPCError, RPCProcessRequest,
                                          RPCStartupRequest, RPCStartupResponse,
                                          RPCUProfileRequest)
+from vllm.engine.protocol import EngineClient
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
-from vllm.inputs import PromptType, TokensPrompt
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
-                          RequestOutput)
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
-                        random_uuid)
+from vllm.utils import deprecate_kwargs
 
 logger = init_logger(__name__)
 
@@ -53,7 +53,7 @@ class MQClientClosedError(Exception):
     """
 
 
-class MQLLMEngineClient:
+class MQLLMEngineClient(EngineClient):
     """A client wrapper for MQLLMEngine that conforms to the
     EngineClient protocol.
 
@@ -316,7 +316,7 @@ async def _check_success(error_message: str, socket: Socket):
                 or response != VLLM_RPC_SUCCESS_STR):
             raise ValueError(error_message)
 
-    async def get_tokenizer(self, lora_request: LoRARequest):
+    async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
         return await self.tokenizer.get_lora_tokenizer_async(lora_request)
 
     async def get_decoding_config(self) -> DecodingConfig:
@@ -344,8 +344,14 @@ async def abort(self, request_id: str):
         await self._send_one_way_rpc_request(
             request=RPCAbortRequest(request_id), socket=self.input_socket)
 
-    async def do_log_stats(self):
-        """Ignore do_log_stats (handled on MQLLMEngine polling)"""
+    async def do_log_stats(
+        self,
+        scheduler_outputs: Optional[SchedulerOutputs] = None,
+        model_output: Optional[List[SamplerOutput]] = None,
+    ) -> None:
+        """
+        Ignore do_log_stats (handled on MQLLMEngine polling)
+        """
         pass
 
     async def check_health(self):
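
The widened get_tokenizer and do_log_stats signatures above track the EngineClient base class that MQLLMEngineClient, like AsyncLLMEngine, now inherits: once the shared base declares a method, each override has to accept the full parameter list even if it ignores it. A hedged sketch of that constraint, with EngineClientSketch as an illustrative stand-in for vllm.engine.protocol.EngineClient rather than its real definition:

from abc import ABC, abstractmethod
from typing import Any, List, Optional


class EngineClientSketch(ABC):
    # Illustrative stand-in: the real EngineClient lives in
    # vllm/engine/protocol.py and declares many more methods.
    @abstractmethod
    async def do_log_stats(
        self,
        scheduler_outputs: Optional[Any] = None,   # SchedulerOutputs in vLLM
        model_output: Optional[List[Any]] = None,  # List[SamplerOutput] in vLLM
    ) -> None:
        ...


class ClientSketch(EngineClientSketch):
    # The override keeps the base signature even though this client ignores
    # both arguments; stats are handled on the engine side, as the
    # docstring in the diff above says.
    async def do_log_stats(self,
                           scheduler_outputs: Optional[Any] = None,
                           model_output: Optional[List[Any]] = None) -> None:
        pass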
@@ -444,104 +450,6 @@ def generate(
                                      lora_request, trace_headers,
                                      prompt_adapter_request, priority)
 
-    async def beam_search(
-        self,
-        prompt: Union[PromptType, List[int]],
-        request_id: str,
-        params: BeamSearchParams,
-    ) -> AsyncGenerator[RequestOutput, None]:
-
-        beam_width = params.beam_width
-        max_tokens = params.max_tokens
-        ignore_eos = params.ignore_eos
-        temperature = params.temperature
-        length_penalty = params.length_penalty
-
-        tokenizer = await self.get_tokenizer(lora_request=None)
-        tokenizedPrompt = prompt if isinstance(
-            prompt, list) else tokenizer.encode(prompt)
-        tokenizedLength = len(tokenizedPrompt)
-
-        sort_beams_key = create_sort_beams_key_function(
-            tokenizer.eos_token_id, length_penalty)
-
-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
-        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
-        completed = []
-
-        for _ in range(max_tokens):
-            prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
-            ]
-
-            tasks = []
-
-            request_id = f"beam_search-{random_uuid()}"
-            for i, individual_prompt in enumerate(prompts_batch):
-                request_id_item = f"{request_id}-{i}"
-                task = asyncio.create_task(
-                    collect_from_async_generator(
-                        self.generate(individual_prompt, beam_search_params,
-                                      request_id_item)))
-                tasks.append(task)
-
-            output = await asyncio.gather(*tasks)
-
-            output = [x[0] for x in output]
-
-            logger.info(output)
-
-            new_beams = []
-            for i, current_beam in enumerate(all_beams):
-                result = output[i]
-
-                if result.outputs[0].logprobs is not None:
-                    logprobs = result.outputs[0].logprobs[0]
-                    for token_id, logprob_obj in logprobs.items():
-                        new_beam = BeamSearchSequence(
-                            tokens=current_beam.tokens + [token_id],
-                            cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
-
-                        if token_id == tokenizer.eos_token_id and \
-                            not ignore_eos:
-                            completed.append(new_beam)
-                        else:
-                            new_beams.append(new_beam)
-
-            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
-            all_beams = sorted_beams[:beam_width]
-
-        completed.extend(all_beams)
-        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
-        best_beams = sorted_completed[:beam_width]
-
-        for beam in best_beams:
-            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
-
-        beam_search_output = RequestOutput(
-            request_id=request_id,
-            prompt=prompt,
-            outputs=[
-                CompletionOutput(
-                    text=beam.text,
-                    cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens,
-                    index=i,
-                    logprobs=beam.cum_logprob,
-                ) for (i, beam) in enumerate(best_beams)
-            ],
-            finished=True,
-            prompt_token_ids=tokenizedPrompt,
-            prompt_logprobs=None)
-
-        logger.info(beam_search_output)
-
-        yield beam_search_output
-
     @overload  # DEPRECATED
     def encode(
         self,
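
This copy differed from the AsyncLLMEngine one only at the edges: it fetched the tokenizer with an explicit lora_request=None, and it logged and yielded the final RequestOutput directly instead of routing it through LLMEngine.validate_output. Both copies ranked beams with create_sort_beams_key_function(eos_token_id, length_penalty). A sketch of such a length-penalized key, reusing Beam and completed from the first sketch and assuming the conventional score = cum_logprob / len(tokens) ** length_penalty; the real formula lives in vllm/beam_search.py, which this diff does not show:

from typing import Callable


def make_sort_beams_key(eos_token_id: int,
                        length_penalty: float) -> Callable[[Beam], float]:
    # length_penalty > 1 favors longer beams, < 1 shorter ones. A trailing
    # EOS is excluded from the effective length so finished and still-live
    # beams are compared on generated content alone (an assumption about
    # create_sort_beams_key_function, not read from this diff).
    def key(beam: Beam) -> float:
        seq_len = len(beam.tokens)
        if beam.tokens and beam.tokens[-1] == eos_token_id:
            seq_len -= 1
        return beam.cum_logprob / (seq_len ** length_penalty)
    return key


sort_beams_key = make_sort_beams_key(eos_token_id=0, length_penalty=1.0)
best_beams = sorted(completed, key=sort_beams_key, reverse=True)[:2]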
