
Commit 5bbd56e

hmellor authored and FeiDaLI committed
Fix model name included in responses (vllm-project#24663)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 9f00890 commit 5bbd56e

10 files changed: +50 -74 lines changed


tests/entrypoints/openai/test_chat.py

Lines changed: 1 addition & 54 deletions
@@ -12,7 +12,7 @@
 import regex as re
 import requests
 import torch
-from openai import BadRequestError, OpenAI
+from openai import BadRequestError
 
 from ...utils import RemoteOpenAIServer
 
@@ -968,59 +968,6 @@ async def test_long_seed(client: openai.AsyncOpenAI):
             or "less_than_equal" in exc_info.value.message)
 
 
-@pytest.mark.asyncio
-async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):
-    url = f"http://localhost:{server.port}/v1/chat/completions"
-    headers = {
-        "Content-Type": "application/json",
-    }
-    data = {
-        # model_name is avoided here.
-        "messages": [{
-            "role": "system",
-            "content": "You are a helpful assistant."
-        }, {
-            "role": "user",
-            "content": "what is 1+1?"
-        }],
-        "max_tokens":
-        5
-    }
-
-    response = requests.post(url, headers=headers, json=data)
-    response_data = response.json()
-    print(response_data)
-    assert response_data.get("model") == MODEL_NAME
-    choice = response_data.get("choices")[0]
-    message = choice.get("message")
-    assert message is not None
-    content = message.get("content")
-    assert content is not None
-    assert len(content) > 0
-
-
-@pytest.mark.asyncio
-async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer):
-    openai_api_key = "EMPTY"
-    openai_api_base = f"http://localhost:{server.port}/v1"
-
-    client = OpenAI(
-        api_key=openai_api_key,
-        base_url=openai_api_base,
-    )
-    messages = [
-        {
-            "role": "user",
-            "content": "Hello, vLLM!"
-        },
-    ]
-    response = client.chat.completions.create(
-        model="",  # empty string
-        messages=messages,
-    )
-    assert response.model == MODEL_NAME
-
-
 @pytest.mark.asyncio
 async def test_invocations(server: RemoteOpenAIServer,
                            client: openai.AsyncOpenAI):

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 41 additions & 1 deletion
@@ -213,8 +213,12 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
 
 
 MODEL_NAME = "openai-community/gpt2"
+MODEL_NAME_SHORT = "gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
-BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
+BASE_MODEL_PATHS = [
+    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
+    BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT)
+]
 
 
 @dataclass
@@ -270,6 +274,42 @@ def test_async_serving_chat_init():
     assert serving_completion.chat_template == CHAT_TEMPLATE
 
 
+@pytest.mark.asyncio
+async def test_serving_chat_returns_correct_model_name():
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    models = OpenAIServingModels(engine_client=mock_engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=MockModelConfig())
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     MockModelConfig(),
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+    messages = [{"role": "user", "content": "what is 1+1?"}]
+
+    async def return_model_name(*args):
+        return args[3]
+
+    serving_chat.chat_completion_full_generator = return_model_name
+
+    # Test that full name is returned when short name is requested
+    req = ChatCompletionRequest(model=MODEL_NAME_SHORT, messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+    # Test that full name is returned when empty string is specified
+    req = ChatCompletionRequest(model="", messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+    # Test that full name is returned when no model is specified
+    req = ChatCompletionRequest(messages=messages)
+    assert await serving_chat.create_chat_completion(req) == MODEL_NAME
+
+
 @pytest.mark.asyncio
 async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine = MagicMock(spec=MQLLMEngineClient)
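The new unit test above drives create_chat_completion directly against a mocked engine. At the HTTP level, the client-visible effect is what the deleted test_chat.py tests checked: the response's "model" field reports the served model's full name even when the request sends an empty model name. A minimal, hypothetical client snippet illustrating this (the port and served model are assumptions, not taken from this commit):

from openai import OpenAI

# Assumes a vLLM server serving "openai-community/gpt2" on localhost:8000;
# both values are illustrative only.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

response = client.chat.completions.create(
    model="",  # empty model name in the request
    messages=[{"role": "user", "content": "what is 1+1?"}],
)

# With this fix the response echoes the full served model name.
assert response.model == "openai-community/gpt2"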

vllm/entrypoints/openai/serving_chat.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ async def create_chat_completion(
             lora_request = self._maybe_get_adapters(
                 request, supports_default_mm_loras=True)
 
-            model_name = self._get_model_name(request.model, lora_request)
+            model_name = self.models.model_name(lora_request)
 
             tokenizer = await self.engine_client.get_tokenizer(lora_request)

vllm/entrypoints/openai/serving_classification.py

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ async def create_classify(
         request: ClassificationRequest,
         raw_request: Request,
     ) -> Union[ClassificationResponse, ErrorResponse]:
-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()
         request_id = (f"{self.request_id_prefix}-"
                       f"{self._base_request_id(raw_request)}")

vllm/entrypoints/openai/serving_completion.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ async def create_completion(
 
         result_generator = merge_async_iterators(*generators)
 
-        model_name = self._get_model_name(request.model, lora_request)
+        model_name = self.models.model_name(lora_request)
         num_prompts = len(engine_prompts)
 
         # Similar to the OpenAI API, when n != best_of, we do not stream the

vllm/entrypoints/openai/serving_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -599,7 +599,7 @@ async def create_embedding(
         See https://platform.openai.com/docs/api-reference/embeddings/create
         for the API specification. This API mimics the OpenAI Embedding API.
         """
-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()
         request_id = (
             f"{self.request_id_prefix}-"
             f"{self._base_request_id(raw_request, request.request_id)}")

vllm/entrypoints/openai/serving_engine.py

Lines changed: 0 additions & 11 deletions
@@ -980,17 +980,6 @@ def _is_model_supported(self, model_name: Optional[str]) -> bool:
980980
return True
981981
return self.models.is_base_model(model_name)
982982

983-
def _get_model_name(
984-
self,
985-
model_name: Optional[str] = None,
986-
lora_request: Optional[LoRARequest] = None,
987-
) -> str:
988-
if lora_request:
989-
return lora_request.lora_name
990-
if not model_name:
991-
return self.models.base_model_paths[0].name
992-
return model_name
993-
994983

995984
def clamp_prompt_logprobs(
996985
prompt_logprobs: Union[PromptLogprobs,
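The _get_model_name helper removed above is replaced throughout by calls to self.models.model_name(...) on OpenAIServingModels. That method's implementation is not part of this diff; a rough, hypothetical sketch of behavior consistent with the deleted helper (minus the client-supplied name) might look like this:

from typing import Optional

from vllm.lora.request import LoRARequest


class OpenAIServingModelsSketch:
    """Hypothetical stand-in for OpenAIServingModels; only the pieces
    needed to illustrate model_name() are shown."""

    def __init__(self, base_model_paths):
        self.base_model_paths = base_model_paths

    def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
        # Prefer the LoRA adapter's name when one is attached to the request.
        if lora_request is not None:
            return lora_request.lora_name
        # Otherwise report the first served model's full name, regardless of
        # what (possibly empty or abbreviated) name the client supplied.
        return self.base_model_paths[0].name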

vllm/entrypoints/openai/serving_pooling.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ async def create_pooling(
         if error_check_ret is not None:
             return error_check_ret
 
-        model_name = self._get_model_name(request.model)
+        model_name = self.models.model_name()
 
         request_id = f"pool-{self._base_request_id(raw_request)}"
         created_time = int(time.time())

vllm/entrypoints/openai/serving_responses.py

Lines changed: 1 addition & 1 deletion
@@ -237,7 +237,7 @@ async def create_responses(
237237

238238
try:
239239
lora_request = self._maybe_get_adapters(request)
240-
model_name = self._get_model_name(request.model, lora_request)
240+
model_name = self.models.model_name(lora_request)
241241
tokenizer = await self.engine_client.get_tokenizer(lora_request)
242242

243243
if self.use_harmony:

vllm/entrypoints/openai/serving_score.py

Lines changed: 2 additions & 2 deletions
@@ -353,7 +353,7 @@ async def create_score(
353353
final_res_batch,
354354
request_id,
355355
created_time,
356-
self._get_model_name(request.model),
356+
self.models.model_name(),
357357
)
358358
except asyncio.CancelledError:
359359
return self.create_error_response("Client disconnected")
@@ -399,7 +399,7 @@ async def do_rerank(
399399
return self.request_output_to_rerank_response(
400400
final_res_batch,
401401
request_id,
402-
self._get_model_name(request.model),
402+
self.models.model_name(),
403403
documents,
404404
top_n,
405405
)
