 from zmq import Frame  # type: ignore[attr-defined]
 from zmq.asyncio import Socket

+from vllm import PoolingParams
 from vllm.config import DecodingConfig, EngineConfig, ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                          IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                          IPC_OUTPUT_EXT, RPC_REQUEST_T,
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
-                                         RPCError, RPCGenerateRequest,
-                                         RPCHealthRequest, RPCStartupRequest,
+                                         RPCError, RPCHealthRequest,
+                                         RPCProcessRequest, RPCStartupRequest,
                                          RPCStartupResponse)
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
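The import change replaces RPCGenerateRequest with RPCProcessRequest: one message type now carries either sampling or pooling parameters, so the same request path serves both generate() and encode(). A minimal sketch of what that message could look like, inferred from the fields passed later in this diff (the dataclass form and the defaults are assumptions):

```python
from dataclasses import dataclass
from typing import Mapping, Optional, Union

from vllm import PoolingParams, SamplingParams
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest


@dataclass
class RPCProcessRequest:
    # `params` decides whether the engine runs a sampling (generate)
    # or a pooling (encode) pass for this request.
    inputs: PromptInputs
    params: Union[SamplingParams, PoolingParams]
    request_id: str
    lora_request: Optional[LoRARequest] = None
    trace_headers: Optional[Mapping[str, str]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
```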
@@ -111,20 +112,8 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig):

     @staticmethod
     def is_unsupported_config(engine_args: AsyncEngineArgs):
-        if engine_args.pipeline_parallel_size > 1:
-            return True
-
-        is_embedding = ModelConfig(
-            model=engine_args.model,
-            revision=engine_args.revision,
-            tokenizer=engine_args.model,
-            tokenizer_mode="auto",
-            trust_remote_code=engine_args.trust_remote_code,
-            quantization=engine_args.quantization,
-            seed=0,
-            dtype="auto").embedding_mode
-
-        return is_embedding
+        # Pipeline parallel not yet supported
+        return engine_args.pipeline_parallel_size > 1

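With the embedding check gone, only pipeline parallelism still forces a fallback away from this backend. A hedged sketch of how a frontend might branch on this check (the MQLLMEngineClient import path matches vLLM's tree, but the fallback shape is an assumption):

```python
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import MQLLMEngineClient

engine_args = AsyncEngineArgs(model="facebook/opt-125m")

if MQLLMEngineClient.is_unsupported_config(engine_args):
    ...  # fall back to the in-process AsyncLLMEngine
else:
    ...  # spawn the engine process and connect this client over IPC
```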
     @contextmanager
     def get_data_socket(self) -> Iterator[Socket]:
@@ -382,12 +371,9 @@ def errored(self) -> bool:

     @property
     def dead_error(self) -> BaseException:
-        if self._errored_with is not None:
-            return ENGINE_DEAD_ERROR(self._errored_with)
-        else:
-            return ENGINE_DEAD_ERROR()
+        return ENGINE_DEAD_ERROR(self._errored_with)

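The one-line version works only because ENGINE_DEAD_ERROR accepts an optional cause; the removed if/else moves into the callee. A minimal sketch of that assumed contract (the error class and message text here are illustrative, not taken from this diff):

```python
from typing import Optional


class MQEngineDeadError(RuntimeError):
    """Raised when the background engine process has stopped."""


def ENGINE_DEAD_ERROR(
        error: Optional[BaseException] = None) -> MQEngineDeadError:
    # Tolerating None is what lets `dead_error` pass `self._errored_with`
    # through unconditionally.
    if error is None:
        return MQEngineDeadError("Engine loop is not running.")
    return MQEngineDeadError(
        f"Engine loop is not running. Original error: {error!r}")
```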
-    async def generate(
+    def generate(
         self,
         inputs: PromptInputs,
         sampling_params: SamplingParams,
@@ -396,6 +382,67 @@ async def generate(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncGenerator[RequestOutput, None]:
+        """Generate outputs for a request.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        Args:
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
+            sampling_params: The sampling parameters of the request.
+            request_id: The unique id of the request.
+            lora_request: LoRA request to use for generation, if any.
+            trace_headers: OpenTelemetry trace headers.
+            prompt_adapter_request: Prompt Adapter request to use
+                for generation, if any.
+        """
+        return self._process_request(inputs, sampling_params, request_id,
+                                     lora_request, trace_headers,
+                                     prompt_adapter_request)
+
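Although generate() is no longer declared async, it returns the async generator produced by the _process_request coroutine below, so call sites are unchanged. A hedged usage sketch, assuming `client` is an already-connected instance of this engine client:

```python
import uuid

from vllm import SamplingParams


async def stream_completion(client, prompt: str) -> str:
    # `client` is assumed to be a connected instance of this client class.
    sampling_params = SamplingParams(temperature=0.8, max_tokens=64)
    text = ""
    # Each RequestOutput carries the cumulative generation so far.
    async for output in client.generate(inputs=prompt,
                                        sampling_params=sampling_params,
                                        request_id=str(uuid.uuid4())):
        text = output.outputs[0].text
    return text
```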
+    def encode(
+        self,
+        inputs: PromptInputs,
+        pooling_params: PoolingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+        """Generate outputs for a request from an embedding model.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        Args:
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
+            pooling_params: The pooling parameters of the request.
+            request_id: The unique id of the request.
+            lora_request: LoRA request to use for generation, if any.
+            trace_headers: OpenTelemetry trace headers.
+
+        Yields:
+            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            for the request.
+        """
+        return self._process_request(inputs, pooling_params, request_id,
+                                     lora_request, trace_headers)
+
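A matching sketch for the embedding path, under the same assumption that `client` is a connected instance and the served model is an embedding model:

```python
import uuid
from typing import List

from vllm import PoolingParams


async def embed(client, text: str) -> List[float]:
    final = None
    # The stream ends with an EmbeddingRequestOutput whose
    # `.outputs.embedding` holds the pooled vector.
    async for output in client.encode(inputs=text,
                                      pooling_params=PoolingParams(),
                                      request_id=str(uuid.uuid4())):
        final = output
    return final.outputs.embedding
```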
+    async def _process_request(
+        self,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+    ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
+            EmbeddingRequestOutput, None]]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""

         # If already dead, error out.
@@ -410,19 +457,19 @@ async def generate(
         try:
             # 2) Detach logits processors so that they can be pickled
             # separately (may require cloudpickle which is slower)
-            if sampling_params.logits_processors:
+            if isinstance(params, SamplingParams) and params.logits_processors:
                 # Defensive shallow copy
-                sampling_params = copy.copy(sampling_params)
-                logits_processors = sampling_params.logits_processors
-                sampling_params.logits_processors = None
+                params = copy.copy(params)
+                logits_processors = params.logits_processors
+                params.logits_processors = None
                 lp_bytes = cloudpickle.dumps(logits_processors)
             else:
                 lp_bytes = None

             request_bytes = pickle.dumps(
-                RPCGenerateRequest(
+                RPCProcessRequest(
                     inputs=inputs,
-                    sampling_params=sampling_params,
+                    params=params,
                     request_id=request_id,
                     lora_request=lora_request,
                     trace_headers=trace_headers,
@@ -452,8 +499,3 @@ async def generate(
                 await self.abort(request_id)
         finally:
             self.output_queues.pop(request_id)
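The detach step in this hunk exists because logits processors are usually closures or lambdas, which the standard pickle module refuses to serialize; cloudpickle can handle them, but more slowly, so only that one field takes the slow path. A standalone illustration of the difference:

```python
import pickle

import cloudpickle


def make_processor(banned_token: int):
    # A closure, like the logits processors users attach to SamplingParams.
    def processor(token_ids, logits):
        logits[banned_token] = float("-inf")
        return logits

    return processor


proc = make_processor(banned_token=42)

try:
    pickle.dumps(proc)  # pickle cannot serialize nested functions
except (pickle.PicklingError, AttributeError) as exc:
    print(f"pickle failed: {exc}")

lp_bytes = cloudpickle.dumps(proc)  # cloudpickle captures the closure
restored = cloudpickle.loads(lp_bytes)
assert restored([0], {42: 0.0})[42] == float("-inf")
```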
-
-    async def encode(self, *args,
-                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
-        raise NotImplementedError(
-            "Embeddings not supported with multiprocessing backend")