Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions vllm/engine/multiprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class RPCProcessRequest:
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0

@overload # DEPRECATED
def __init__(
Expand All @@ -41,6 +42,7 @@ def __init__(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...

Expand All @@ -53,6 +55,7 @@ def __init__(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...

Expand All @@ -68,6 +71,7 @@ def __init__(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
Expand All @@ -84,6 +88,7 @@ def __init__(
self.lora_request = lora_request
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
self.priority = priority


@dataclass
Expand Down
15 changes: 12 additions & 3 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def generate(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
...

Expand All @@ -390,6 +391,7 @@ def generate(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
...

Expand All @@ -405,6 +407,7 @@ def generate(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]:
Expand All @@ -423,6 +426,9 @@ def generate(
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
"""
if inputs is not None:
prompt = inputs
Expand All @@ -431,7 +437,7 @@ def generate(

return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request)
prompt_adapter_request, priority)

@overload # DEPRECATED
def encode(
Expand Down Expand Up @@ -503,7 +509,8 @@ async def _process_request(
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
EmbeddingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
Expand Down Expand Up @@ -536,7 +543,9 @@ async def _process_request(
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request))
prompt_adapter_request=prompt_adapter_request,
priority=priority,
))

# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes,
Expand Down
3 changes: 2 additions & 1 deletion vllm/engine/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def generate(
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""
...
Expand Down
12 changes: 12 additions & 0 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))

# doc: end-chat-completion-extra-params

Expand Down Expand Up @@ -534,6 +540,12 @@ class CompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))

# doc: end-completion-extra-params

Expand Down
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ async def create_chat_completion(
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
Expand Down
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ async def create_completion(
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=request.priority,
)

generators.append(generator)
Expand Down