
Commit 76515f3

[Frontend] Use MQLLMEngine for embeddings models too (#8584)
1 parent 855c8ae commit 76515f3

File tree: 3 files changed (+90, -46 lines)


vllm/engine/multiprocessing/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -2,6 +2,7 @@
 from enum import Enum
 from typing import List, Mapping, Optional, Union

+from vllm import PoolingParams
 from vllm.inputs import PromptInputs
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
@@ -21,9 +22,9 @@ class MQEngineDeadError(RuntimeError):


 @dataclass
-class RPCGenerateRequest:
+class RPCProcessRequest:
     inputs: PromptInputs
-    sampling_params: SamplingParams
+    params: Union[SamplingParams, PoolingParams]
     request_id: str
     lora_request: Optional[LoRARequest] = None
     trace_headers: Optional[Mapping[str, str]] = None
@@ -55,7 +56,7 @@ class RPCStartupResponse:
     tracing_enabled: bool


-RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest,
+RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest,
                       RPCStartupRequest]

 REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
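The upshot of this file is that one RPC message type now covers both generation and embedding requests, distinguished only by the runtime type of its params field. A minimal sketch of that pattern, using simplified stand-in classes rather than vLLM's real SamplingParams/PoolingParams/PromptInputs, shows how a receiver can dispatch on the params type:

# Minimal sketch of the new message shape; SamplingParams and PoolingParams
# here are simplified stand-ins, not vLLM's real classes.
from dataclasses import dataclass
from typing import Union


@dataclass
class SamplingParams:  # stand-in: the real class holds temperature, top_p, ...
    temperature: float = 1.0


@dataclass
class PoolingParams:  # stand-in: the real class configures embedding pooling
    pass


@dataclass
class RPCProcessRequest:
    inputs: str  # stands in for vllm.inputs.PromptInputs
    params: Union[SamplingParams, PoolingParams]
    request_id: str


def dispatch(request: RPCProcessRequest) -> str:
    # The engine infers generation vs. embedding from the params type,
    # so one queue and one request class serve both workloads.
    if isinstance(request.params, SamplingParams):
        return "generate"
    return "encode"


assert dispatch(RPCProcessRequest("hi", SamplingParams(), "req-1")) == "generate"
assert dispatch(RPCProcessRequest("hi", PoolingParams(), "req-2")) == "encode"

Collapsing the two request kinds into one dataclass keeps RPC_REQUEST_T small and lets the server reuse a single code path for queuing work.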

vllm/engine/multiprocessing/client.py

Lines changed: 74 additions & 32 deletions
@@ -11,6 +11,7 @@
 from zmq import Frame  # type: ignore[attr-defined]
 from zmq.asyncio import Socket

+from vllm import PoolingParams
 from vllm.config import DecodingConfig, EngineConfig, ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 # yapf conflicts with isort for this block
@@ -19,8 +20,8 @@
                                          IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                          IPC_OUTPUT_EXT, RPC_REQUEST_T,
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
-                                         RPCError, RPCGenerateRequest,
-                                         RPCHealthRequest, RPCStartupRequest,
+                                         RPCError, RPCHealthRequest,
+                                         RPCProcessRequest, RPCStartupRequest,
                                          RPCStartupResponse)
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
@@ -111,20 +112,8 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig):

     @staticmethod
     def is_unsupported_config(engine_args: AsyncEngineArgs):
-        if engine_args.pipeline_parallel_size > 1:
-            return True
-
-        is_embedding = ModelConfig(
-            model=engine_args.model,
-            revision=engine_args.revision,
-            tokenizer=engine_args.model,
-            tokenizer_mode="auto",
-            trust_remote_code=engine_args.trust_remote_code,
-            quantization=engine_args.quantization,
-            seed=0,
-            dtype="auto").embedding_mode
-
-        return is_embedding
+        # Pipeline parallel not yet supported
+        return engine_args.pipeline_parallel_size > 1

     @contextmanager
     def get_data_socket(self) -> Iterator[Socket]:
@@ -382,12 +371,9 @@ def errored(self) -> bool:

     @property
     def dead_error(self) -> BaseException:
-        if self._errored_with is not None:
-            return ENGINE_DEAD_ERROR(self._errored_with)
-        else:
-            return ENGINE_DEAD_ERROR()
+        return ENGINE_DEAD_ERROR(self._errored_with)

-    async def generate(
+    def generate(
         self,
         inputs: PromptInputs,
         sampling_params: SamplingParams,
@@ -396,6 +382,67 @@ async def generate(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncGenerator[RequestOutput, None]:
+        """Generate outputs for a request.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        Args:
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
+            sampling_params: The sampling parameters of the request.
+            request_id: The unique id of the request.
+            lora_request: LoRA request to use for generation, if any.
+            trace_headers: OpenTelemetry trace headers.
+            prompt_adapter_request: Prompt Adapter request to use
+                for generation, if any.
+        """
+        return self._process_request(inputs, sampling_params, request_id,
+                                     lora_request, trace_headers,
+                                     prompt_adapter_request)
+
+    def encode(
+        self,
+        inputs: PromptInputs,
+        pooling_params: PoolingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+        """Generate outputs for a request from an embedding model.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        Args:
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
+            pooling_params: The pooling parameters of the request.
+            request_id: The unique id of the request.
+            lora_request: LoRA request to use for generation, if any.
+            trace_headers: OpenTelemetry trace headers.
+
+        Yields:
+            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            for the request.
+        """
+        return self._process_request(inputs, pooling_params, request_id,
+                                     lora_request, trace_headers)
+
+    async def _process_request(
+        self,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+    ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
+            EmbeddingRequestOutput, None]]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""

         # If already dead, error out.
@@ -410,19 +457,19 @@ async def generate(
         try:
             # 2) Detach logits processors so that they can be pickled
             # separately (may require cloudpickle which is slower)
-            if sampling_params.logits_processors:
+            if isinstance(params, SamplingParams) and params.logits_processors:
                 # Defensive shallow copy
-                sampling_params = copy.copy(sampling_params)
-                logits_processors = sampling_params.logits_processors
-                sampling_params.logits_processors = None
+                params = copy.copy(params)
+                logits_processors = params.logits_processors
+                params.logits_processors = None
                 lp_bytes = cloudpickle.dumps(logits_processors)
             else:
                 lp_bytes = None

             request_bytes = pickle.dumps(
-                RPCGenerateRequest(
+                RPCProcessRequest(
                     inputs=inputs,
-                    sampling_params=sampling_params,
+                    params=params,
                     request_id=request_id,
                     lora_request=lora_request,
                     trace_headers=trace_headers,
@@ -452,8 +499,3 @@ async def generate(
             await self.abort(request_id)
         finally:
             self.output_queues.pop(request_id)
-
-    async def encode(self, *args,
-                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
-        raise NotImplementedError(
-            "Embeddings not supported with multiprocessing backend")

vllm/engine/multiprocessing/engine.py

Lines changed: 12 additions & 11 deletions
@@ -6,7 +6,7 @@
 import cloudpickle
 import zmq

-from vllm import AsyncEngineArgs, LLMEngine
+from vllm import AsyncEngineArgs, LLMEngine, SamplingParams
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 # yapf conflicts with isort for this block
@@ -15,8 +15,8 @@
                                          IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                          IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
-                                         RPCError, RPCGenerateRequest,
-                                         RPCHealthRequest, RPCStartupRequest,
+                                         RPCError, RPCHealthRequest,
+                                         RPCProcessRequest, RPCStartupRequest,
                                          RPCStartupResponse)
 # yapf: enable
 from vllm.logger import init_logger
@@ -39,8 +39,8 @@ class MQLLMEngine:
     in concurrent manner. It runs a background loop and uses zeromq to
     receive new requests and stream outputs incrementally via ipc.

-    The :class:`LLMEngine.generate` is kicked off when a new
-    RPCGenerateRequest is received by the input_socket.
+    The :class:`LLMEngine` generate or encode process is kicked off when a new
+    RPCProcessRequest is received by the input_socket.

     The self.engine_loop checks the input_socket for new requests,
     adds them to the LLMEngine if there are any, calls the internal
@@ -213,12 +213,13 @@ def handle_new_input(self):
         frames = self.input_socket.recv_multipart(copy=False)
         request = pickle.loads(frames[0].buffer)

-        if isinstance(request, RPCGenerateRequest):
+        if isinstance(request, RPCProcessRequest):
             if len(frames) > 1:
                 # Use cloudpickle for logits processors
+                assert isinstance(request.params, SamplingParams)
                 lprocs = cloudpickle.loads(frames[1].buffer)
-                request.sampling_params.logits_processors = lprocs
-            self._handle_generate_request(request)
+                request.params.logits_processors = lprocs
+            self._handle_process_request(request)
         elif isinstance(request, RPCAbortRequest):
             self._handle_abort_request(request)
         elif isinstance(request, RPCHealthRequest):
@@ -231,8 +232,8 @@ def handle_new_input(self):
             self._send_unhealthy(e)
             raise e

-    def _handle_generate_request(self, request: RPCGenerateRequest):
-        """Handle RPCGenerateRequest by adding it to the LLMEngine."""
+    def _handle_process_request(self, request: RPCProcessRequest):
+        """Handle RPCProcessRequest by adding it to the LLMEngine."""
        request_id = request.request_id

         if self._errored_with is not None:
@@ -245,7 +246,7 @@ def _handle_generate_request(self, request: RPCGenerateRequest):
         self.engine.add_request(
             request_id=request_id,
             inputs=request.inputs,
-            params=request.sampling_params,
+            params=request.params,
             lora_request=request.lora_request,
             trace_headers=request.trace_headers,
             prompt_adapter_request=request.prompt_adapter_request)
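The added assert guards the one type-specific branch left on the server: only SamplingParams can carry logits processors, which the client detached and shipped in a second cloudpickle frame. Below is a standalone sketch of that round trip under simplifying assumptions: a toy SamplingParams class and a plain list of byte strings stand in for vLLM's types and the zmq frames.

# Standalone sketch of the logits-processor detach/reattach round trip.
# SamplingParams and the frame list are toy stand-ins; only the
# pickle/cloudpickle split mirrors the actual client/engine code.
import copy
import pickle

import cloudpickle


class SamplingParams:  # toy stand-in
    def __init__(self, logits_processors=None):
        self.logits_processors = logits_processors


def client_side(params):
    """Build the frames the client would put on the input socket."""
    if isinstance(params, SamplingParams) and params.logits_processors:
        params = copy.copy(params)  # defensive shallow copy
        lp_bytes = cloudpickle.dumps(params.logits_processors)
        params.logits_processors = None  # keep the main frame plain-picklable
    else:
        lp_bytes = None
    frames = [pickle.dumps(params)]
    if lp_bytes is not None:
        frames.append(lp_bytes)
    return frames


def engine_side(frames):
    """Rebuild params the way handle_new_input does."""
    params = pickle.loads(frames[0])
    if len(frames) > 1:
        # Only sampling requests ever carry a second frame.
        assert isinstance(params, SamplingParams)
        params.logits_processors = cloudpickle.loads(frames[1])
    return params


identity = lambda token_ids, logits: logits  # toy logits processor
rebuilt = engine_side(client_side(SamplingParams([identity])))
assert rebuilt.logits_processors[0]([1, 2], [0.1]) == [0.1]

cloudpickle can serialize lambdas and closures that the standard pickle module cannot, which is why the processors travel in their own frame while the rest of the request stays on the faster plain-pickle path.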
