32 changes: 32 additions & 0 deletions tests/entrypoints/openai/test_chat.py
@@ -9,6 +9,7 @@
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests
 import torch
 from openai import BadRequestError
 
@@ -996,3 +997,34 @@ async def test_long_seed(client: openai.AsyncOpenAI):
 
     assert ("greater_than_equal" in exc_info.value.message
             or "less_than_equal" in exc_info.value.message)
+
+
+@pytest.mark.asyncio
+async def test_http_chat_wo_model_name(server: RemoteOpenAIServer):
+    url = f"http://localhost:{server.port}/v1/chat/completions"
+    headers = {
+        "Content-Type": "application/json",
+    }
+    data = {
+        # The "model" field is intentionally omitted here.
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        "max_tokens":
+        5
+    }
+
+    response = requests.post(url, headers=headers, json=data)
+    response_data = response.json()
+    print(response_data)
+
+    choice = response_data.get("choices")[0]
+    message = choice.get("message")
+    assert message is not None
+    content = message.get("content")
+    assert content is not None
+    assert len(content) > 0
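
Outside the test suite, the same behavior can be exercised against a live server. A minimal sketch, assuming a vLLM OpenAI-compatible server is already listening on localhost:8000 (the port and prompt are illustrative, not taken from the diff):

import requests

# The payload deliberately omits "model"; the server resolves the name itself
# (the served base model, or the LoRA adapter name when one applies).
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"Content-Type": "application/json"},
    json={
        "messages": [{"role": "user", "content": "what is 1+1?"}],
        "max_tokens": 5,
    },
)
body = resp.json()
print(body["model"])  # the model name the server fell back to
print(body["choices"][0]["message"]["content"])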
20 changes: 10 additions & 10 deletions vllm/entrypoints/openai/protocol.py
@@ -213,7 +213,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
     messages: List[ChatCompletionMessageParam]
-    model: str
+    model: Optional[str] = None
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[bool] = False
@@ -642,7 +642,7 @@ def check_generation_prompt(cls, data):
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
-    model: str
+    model: Optional[str] = None
     prompt: Union[List[int], List[List[int]], str, List[str]]
     best_of: Optional[int] = None
     echo: Optional[bool] = False
@@ -907,7 +907,7 @@ def validate_stream_options(cls, data):
 class EmbeddingCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings
-    model: str
+    model: Optional[str] = None
     input: Union[List[int], List[List[int]], str, List[str]]
     encoding_format: Literal["float", "base64"] = "float"
     dimensions: Optional[int] = None
@@ -939,7 +939,7 @@ def to_pooling_params(self):
 
 
 class EmbeddingChatRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     messages: List[ChatCompletionMessageParam]
 
     encoding_format: Literal["float", "base64"] = "float"
@@ -1007,7 +1007,7 @@ def to_pooling_params(self):
 
 
 class ScoreRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     text_1: Union[List[str], str]
     text_2: Union[List[str], str]
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
@@ -1031,7 +1031,7 @@ def to_pooling_params(self):
 
 
 class RerankRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     query: str
     documents: List[str]
     top_n: int = Field(default_factory=lambda: 0)
@@ -1345,7 +1345,7 @@ class BatchRequestOutput(OpenAIBaseModel):
 
 
 class TokenizeCompletionRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     prompt: str
 
     add_special_tokens: bool = Field(
@@ -1357,7 +1357,7 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
 
 
 class TokenizeChatRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     messages: List[ChatCompletionMessageParam]
 
     add_generation_prompt: bool = Field(
@@ -1423,7 +1423,7 @@ class TokenizeResponse(OpenAIBaseModel):
 
 
 class DetokenizeRequest(OpenAIBaseModel):
-    model: str
+    model: Optional[str] = None
     tokens: List[int]
 
 
@@ -1456,7 +1456,7 @@ class TranscriptionRequest(OpenAIBaseModel):
     formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
     """
 
-    model: str
+    model: Optional[str] = None
     """ID of the model to use.
     """
 
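Because these request classes are pydantic models, switching the field to Optional[str] with a None default means payloads that omit "model" now pass validation instead of being rejected. A minimal sketch of that behavior, using a hypothetical trimmed-down stand-in rather than the real ChatCompletionRequest:

from typing import List, Optional

from pydantic import BaseModel


class ChatRequestSketch(BaseModel):
    # Only the fields relevant to this change are modeled here.
    messages: List[dict]
    model: Optional[str] = None


# A JSON body without "model" validates cleanly and defaults to None.
req = ChatRequestSketch.model_validate({
    "messages": [{"role": "user", "content": "what is 1+1?"}],
})
assert req.model is None
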
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -141,7 +141,7 @@ async def create_chat_completion(
                 prompt_adapter_request,
             ) = self._maybe_get_adapters(request)
 
-            model_name = self.models.model_name(lora_request)
+            model_name = self._get_model_name(request.model, lora_request)
 
             tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -166,7 +166,7 @@ async def create_completion(
 
         result_generator = merge_async_iterators(*generators)
 
-        model_name = self.models.model_name(lora_request)
+        model_name = self._get_model_name(request.model, lora_request)
         num_prompts = len(engine_prompts)
 
         # Similar to the OpenAI API, when n != best_of, we do not stream the
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_embedding.py
@@ -83,7 +83,7 @@ async def create_embedding(
                 return self.create_error_response(
                     "dimensions is currently not supported")
 
-        model_name = request.model
+        model_name = self._get_model_name(request.model)
         request_id = f"embd-{self._base_request_id(raw_request)}"
         created_time = int(time.time())
 
13 changes: 12 additions & 1 deletion vllm/entrypoints/openai/serving_engine.py
@@ -523,5 +523,16 @@ def _get_decoded_token(logprob: Logprob,
             return logprob.decoded_token
         return tokenizer.decode(token_id)
 
-    def _is_model_supported(self, model_name):
+    def _is_model_supported(self, model_name) -> bool:
+        if not model_name:
+            return True
         return self.models.is_base_model(model_name)
+
+    def _get_model_name(self,
+                        model_name: Optional[str] = None,
+                        lora_request: Optional[LoRARequest] = None) -> str:
+        if lora_request:
+            return lora_request.lora_name
+        if model_name is None:
+            return self.models.base_model_paths[0].name
+        return model_name
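
The new _get_model_name helper resolves the reported model name with a fixed precedence: the LoRA adapter name wins, then the model explicitly named in the request, then the first configured base model. A standalone sketch of that precedence (the model names below are illustrative only):

from typing import Optional


def resolve_model_name(requested: Optional[str],
                       lora_name: Optional[str],
                       default_name: str) -> str:
    # Mirrors _get_model_name: LoRA name > requested model > first base model.
    if lora_name:
        return lora_name
    if requested is None:
        return default_name
    return requested


assert resolve_model_name(None, None, "facebook/opt-125m") == "facebook/opt-125m"
assert resolve_model_name("served-model", None, "facebook/opt-125m") == "served-model"
assert resolve_model_name(None, "my-lora", "facebook/opt-125m") == "my-lora"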
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_models.py
@@ -95,7 +95,7 @@ async def init_static_loras(self):
             if isinstance(load_result, ErrorResponse):
                 raise ValueError(load_result.message)
 
-    def is_base_model(self, model_name):
+    def is_base_model(self, model_name) -> bool:
         return any(model.name == model_name for model in self.base_model_paths)
 
     def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_pooling.py
@@ -79,7 +79,7 @@ async def create_pooling(
                 return self.create_error_response(
                     "dimensions is currently not supported")
 
-        model_name = request.model
+        model_name = self._get_model_name(request.model)
         request_id = f"pool-{self._base_request_id(raw_request)}"
         created_time = int(time.time())
 
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/serving_score.py
@@ -318,7 +318,7 @@ async def create_score(
                 final_res_batch,
                 request_id,
                 created_time,
-                request.model,
+                self._get_model_name(request.model),
             )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
@@ -358,7 +358,7 @@ async def do_rerank(
                 request.truncate_prompt_tokens,
             )
             return self.request_output_to_rerank_response(
-                final_res_batch, request_id, request.model, documents, top_n)
+                final_res_batch, request_id, self._get_model_name(request.model), documents, top_n)
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
except ValueError as e:
Expand Down