Commit 0ffdf8c

[HTTP Server] Make model param optional in request (#13568)
1 parent 8c0dd3d commit 0ffdf8c

9 files changed: +61 −18 lines changed

tests/entrypoints/openai/test_chat.py

Lines changed: 32 additions & 0 deletions

@@ -9,6 +9,7 @@
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests
 import torch
 from openai import BadRequestError

@@ -996,3 +997,34 @@ async def test_long_seed(client: openai.AsyncOpenAI):
 
     assert ("greater_than_equal" in exc_info.value.message
             or "less_than_equal" in exc_info.value.message)
+
+
+@pytest.mark.asyncio
+async def test_http_chat_wo_model_name(server: RemoteOpenAIServer):
+    url = f"http://localhost:{server.port}/v1/chat/completions"
+    headers = {
+        "Content-Type": "application/json",
+    }
+    data = {
+        # The "model" field is intentionally omitted from this request.
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        "max_tokens":
+        5
+    }
+
+    response = requests.post(url, headers=headers, json=data)
+    response_data = response.json()
+    print(response_data)
+
+    choice = response_data.get("choices")[0]
+    message = choice.get("message")
+    assert message is not None
+    content = message.get("content")
+    assert content is not None
+    assert len(content) > 0
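For context (not part of the commit): when "model" is omitted, the server falls back to its first served base model. A minimal sketch of how a client could check which name that is, assuming a server on localhost:8000 and that /v1/models lists base models before any LoRA adapters:

import requests

# Hypothetical deployment; adjust host/port as needed.
base_url = "http://localhost:8000"

# /v1/models enumerates the served models; under the assumption above,
# data[0]["id"] is the name used when "model" is omitted from a request.
models = requests.get(f"{base_url}/v1/models").json()["data"]
print(models[0]["id"])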

vllm/entrypoints/openai/protocol.py

Lines changed: 10 additions & 10 deletions

@@ -213,7 +213,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
     messages: List[ChatCompletionMessageParam]
-    model: str
+    model: Optional[str] = None
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[bool] = False

@@ -642,7 +642,7 @@ def check_generation_prompt(cls, data):
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
-    model: str
+    model: Optional[str] = None
     prompt: Union[List[int], List[List[int]], str, List[str]]
     best_of: Optional[int] = None
     echo: Optional[bool] = False

@@ -907,7 +907,7 @@ def validate_stream_options(cls, data):
 class EmbeddingCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings
-    model: str
+    model: Optional[str] = None
     input: Union[List[int], List[List[int]], str, List[str]]
     encoding_format: Literal["float", "base64"] = "float"
     dimensions: Optional[int] = None

@@ -939,7 +939,7 @@ def to_pooling_params(self):
 
 
 class EmbeddingChatRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     messages: List[ChatCompletionMessageParam]
 
     encoding_format: Literal["float", "base64"] = "float"

@@ -1007,7 +1007,7 @@ def to_pooling_params(self):
 
 
 class ScoreRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     text_1: Union[List[str], str]
     text_2: Union[List[str], str]
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

@@ -1031,7 +1031,7 @@ def to_pooling_params(self):
 
 
 class RerankRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     query: str
     documents: List[str]
     top_n: int = Field(default_factory=lambda: 0)

@@ -1345,7 +1345,7 @@ class BatchRequestOutput(OpenAIBaseModel):
 
 
 class TokenizeCompletionRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     prompt: str
 
     add_special_tokens: bool = Field(

@@ -1357,7 +1357,7 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
 
 
 class TokenizeChatRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     messages: List[ChatCompletionMessageParam]
 
     add_generation_prompt: bool = Field(

@@ -1423,7 +1423,7 @@ class TokenizeResponse(OpenAIBaseModel):
 
 
 class DetokenizeRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     tokens: List[int]
 

@@ -1456,7 +1456,7 @@ class TranscriptionRequest(OpenAIBaseModel):
     formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
     """
 
-    model: str
+    model: Optional[str] = None
     """ID of the model to use.
     """
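A rough interactive check of the schema change (illustrative only, not part of the commit; assumes plain message dicts validate as ChatCompletionMessageParam): a request body without "model" now parses, and the field defaults to None.

from vllm.entrypoints.openai.protocol import ChatCompletionRequest

# "model" is deliberately left out of the payload.
req = ChatCompletionRequest(
    messages=[{"role": "user", "content": "what is 1+1?"}])
assert req.model is None  # before this change, this raised a validation error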

vllm/entrypoints/openai/serving_chat.py

Lines changed: 1 addition & 1 deletion

@@ -141,7 +141,7 @@ async def create_chat_completion(
                 prompt_adapter_request,
             ) = self._maybe_get_adapters(request)
 
-            model_name = self.models.model_name(lora_request)
+            model_name = self._get_model_name(request.model, lora_request)
 
             tokenizer = await self.engine_client.get_tokenizer(lora_request)

vllm/entrypoints/openai/serving_completion.py

Lines changed: 1 addition & 1 deletion

@@ -166,7 +166,7 @@ async def create_completion(
 
         result_generator = merge_async_iterators(*generators)
 
-        model_name = self.models.model_name(lora_request)
+        model_name = self._get_model_name(request.model, lora_request)
         num_prompts = len(engine_prompts)
 
         # Similar to the OpenAI API, when n != best_of, we do not stream the

vllm/entrypoints/openai/serving_embedding.py

Lines changed: 1 addition & 1 deletion

@@ -83,7 +83,7 @@ async def create_embedding(
             return self.create_error_response(
                 "dimensions is currently not supported")
 
-        model_name = request.model
+        model_name = self._get_model_name(request.model)
         request_id = f"embd-{self._base_request_id(raw_request)}"
         created_time = int(time.time())
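The same fallback now applies to the embeddings endpoint. A minimal sketch (illustrative, assuming an embedding model served on localhost:8000):

import requests

# "model" is omitted on purpose; the server resolves it to the base model.
resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    headers={"Content-Type": "application/json"},
    json={"input": "what is 1+1?"},
)
print(resp.json()["model"])  # echoes the resolved (base) model name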

vllm/entrypoints/openai/serving_engine.py

Lines changed: 12 additions & 1 deletion

@@ -523,5 +523,16 @@ def _get_decoded_token(logprob: Logprob,
             return logprob.decoded_token
         return tokenizer.decode(token_id)
 
-    def _is_model_supported(self, model_name):
+    def _is_model_supported(self, model_name) -> bool:
+        if not model_name:
+            return True
         return self.models.is_base_model(model_name)
+
+    def _get_model_name(self,
+                        model_name: Optional[str] = None,
+                        lora_request: Optional[LoRARequest] = None) -> str:
+        if lora_request:
+            return lora_request.lora_name
+        if model_name is None:
+            return self.models.base_model_paths[0].name
+        return model_name
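Restated outside the class, a sketch of the precedence _get_model_name implements (function and names below are hypothetical, for illustration only): a matched LoRA adapter wins, then the explicit request value, then the first served base model.

from typing import Optional

def resolve_model_name(requested: Optional[str],
                       lora_name: Optional[str],
                       default_base_model: str) -> str:
    # 1. A LoRA adapter matched for this request takes priority.
    if lora_name:
        return lora_name
    # 2. No "model" in the request: fall back to the first served base model.
    if requested is None:
        return default_base_model
    # 3. Otherwise echo the client-supplied name back in responses.
    return requested

assert resolve_model_name(None, None, "base-model") == "base-model"
assert resolve_model_name("served-name", None, "base-model") == "served-name"
assert resolve_model_name(None, "my-lora", "base-model") == "my-lora"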

vllm/entrypoints/openai/serving_models.py

Lines changed: 1 addition & 1 deletion

@@ -95,7 +95,7 @@ async def init_static_loras(self):
             if isinstance(load_result, ErrorResponse):
                 raise ValueError(load_result.message)
 
-    def is_base_model(self, model_name):
+    def is_base_model(self, model_name) -> bool:
         return any(model.name == model_name for model in self.base_model_paths)
 
     def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:

vllm/entrypoints/openai/serving_pooling.py

Lines changed: 1 addition & 1 deletion

@@ -79,7 +79,7 @@ async def create_pooling(
             return self.create_error_response(
                 "dimensions is currently not supported")
 
-        model_name = request.model
+        model_name = self._get_model_name(request.model)
         request_id = f"pool-{self._base_request_id(raw_request)}"
         created_time = int(time.time())

vllm/entrypoints/openai/serving_score.py

Lines changed: 2 additions & 2 deletions

@@ -318,7 +318,7 @@ async def create_score(
                 final_res_batch,
                 request_id,
                 created_time,
-                request.model,
+                self._get_model_name(request.model),
             )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")

@@ -358,7 +358,7 @@ async def do_rerank(
                 request.truncate_prompt_tokens,
             )
             return self.request_output_to_rerank_response(
-                final_res_batch, request_id, request.model, documents, top_n)
+                final_res_batch, request_id, self._get_model_name(request.model), documents, top_n)
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
