Commit ed71c6b

br3norshaw@neuralmagic.com authored and committed
[BUGFIX] [FRONTEND] Correct chat logprobs (vllm-project#5029)
Co-authored-by: Breno Faria <[email protected]>
1 parent f3fdfff commit ed71c6b

6 files changed: +361 additions, -98 deletions

tests/async_engine/test_openapi_server_ray.py

Lines changed: 4 additions & 2 deletions

@@ -94,8 +94,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
         chat_completion.choices) == 1
     assert chat_completion.choices[0].message is not None
     assert chat_completion.choices[0].logprobs is not None
-    assert chat_completion.choices[0].logprobs.top_logprobs is not None
-    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
+    assert chat_completion.choices[0].logprobs.content[
+        0].top_logprobs is not None
+    assert len(
+        chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
     message = chat_completion.choices[0].message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"
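For reference, the corrected assertions follow the OpenAI chat logprobs schema: each generated token appears as an entry under choices[0].logprobs.content, and its alternatives sit under that entry's top_logprobs. A minimal usage sketch with the official openai client (the server URL and model name are placeholders, not taken from this commit):

# Hypothetical usage sketch; base_url and model are placeholders.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",  # placeholder model name
    messages=[{"role": "user", "content": "what is 1+1?"}],
    max_tokens=5,
    logprobs=True,
    top_logprobs=5,
)
# One entry per generated token; each entry carries its own top alternatives.
for entry in chat.choices[0].logprobs.content:
    print(entry.token, entry.logprob)
    for alt in entry.top_logprobs:
        print("    ", alt.token, alt.logprob)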

tests/entrypoints/test_openai_server.py

Lines changed: 181 additions & 28 deletions
@@ -183,6 +183,26 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
         completion.choices[0].text) >= 5


+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_no_logprobs(server, client: openai.AsyncOpenAI,
+                           model_name: str):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=None,
+    )
+    choice = completion.choices[0]
+    assert choice.logprobs is None
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     # first test base model, then test loras
@@ -202,7 +222,72 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
     choice = completion.choices[0]
     assert choice.logprobs is not None
     assert choice.logprobs.token_logprobs is not None
-    assert choice.logprobs.top_logprobs is None
+    assert choice.logprobs.top_logprobs is not None
+    assert len(choice.logprobs.top_logprobs[0]) <= 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_some_logprobs(server, client: openai.AsyncOpenAI,
+                             model_name: str):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=5,
+    )
+    choice = completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.token_logprobs is not None
+    assert choice.logprobs.top_logprobs is not None
+    assert len(choice.logprobs.top_logprobs[0]) <= 6
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
+                                            model_name: str):
+
+    with pytest.raises(
+        (openai.BadRequestError, openai.APIError)):  # test using token IDs
+        await client.completions.create(
+            model=MODEL_NAME,
+            prompt=[0, 0, 0, 0, 0],
+            max_tokens=5,
+            temperature=0.0,
+            logprobs=6,
+        )
+        ...
+    with pytest.raises(
+        (openai.BadRequestError, openai.APIError)):  # test using token IDs
+        stream = await client.completions.create(
+            model=MODEL_NAME,
+            prompt=[0, 0, 0, 0, 0],
+            max_tokens=5,
+            temperature=0.0,
+            logprobs=6,
+            stream=True,
+        )
+        async for chunk in stream:
+            ...
+
+    # the server should still work afterwards
+    completion = await client.completions.create(
+        model=model_name,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    completion = completion.choices[0].text
+    assert completion is not None and len(completion) >= 0


 @pytest.mark.asyncio
@@ -232,8 +317,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
         chat_completion.choices) == 1
     assert chat_completion.choices[0].message is not None
     assert chat_completion.choices[0].logprobs is not None
-    assert chat_completion.choices[0].logprobs.top_logprobs is not None
-    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
+    assert chat_completion.choices[0].logprobs.content[
+        0].top_logprobs is not None
+    assert len(
+        chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
     message = chat_completion.choices[0].message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"
@@ -250,10 +337,93 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
     assert message.content is not None and len(message.content) >= 0


+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI,
+                                model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=False)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI,
+                                  model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=True,
+                                                           top_logprobs=0)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.content is not None
+    assert len(choice.logprobs.content[0].top_logprobs) <= 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI,
+                                  model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=True,
+                                                           top_logprobs=5)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.content is not None
+    assert len(choice.logprobs.content[0].top_logprobs) <= 6
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
-                                 model_name: str):
+async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,
+                                      model_name: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -262,13 +432,13 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
         "content": "what is 1+1?"
     }]

-    # Default max_logprobs is 5, so this should raise an error
+    # Default max_logprobs is 20, so this should raise an error
     with pytest.raises((openai.BadRequestError, openai.APIError)):
         stream = await client.chat.completions.create(model=model_name,
                                                       messages=messages,
                                                       max_tokens=10,
                                                       logprobs=True,
-                                                      top_logprobs=10,
+                                                      top_logprobs=21,
                                                       stream=True)
         async for chunk in stream:
             ...
@@ -278,25 +448,9 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
                                              messages=messages,
                                              max_tokens=10,
                                              logprobs=True,
-                                             top_logprobs=10,
+                                             top_logprobs=30,
                                              stream=False)

-    with pytest.raises((openai.BadRequestError, openai.APIError)):
-        stream = await client.completions.create(model=model_name,
-                                                 prompt="Test",
-                                                 max_tokens=10,
-                                                 logprobs=10,
-                                                 stream=True)
-        async for chunk in stream:
-            ...
-
-    with pytest.raises(openai.BadRequestError):
-        await client.completions.create(model=model_name,
-                                        prompt="Test",
-                                        max_tokens=10,
-                                        logprobs=10,
-                                        stream=False)
-
     # the server should still work afterwards
     chat_completion = await client.chat.completions.create(model=model_name,
                                                            messages=messages,
@@ -743,13 +897,12 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
         top_logprobs=5,
         extra_body=dict(guided_choice=TEST_CHOICE,
                         guided_decoding_backend=guided_decoding_backend))
-    top_logprobs = chat_completion.choices[0].logprobs.top_logprobs
+    top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs

     # -9999.0 is the minimum logprob returned by OpenAI
     assert all(
-        isinstance(logprob, float) and logprob >= -9999.0
-        for token_dict in top_logprobs
-        for token, logprob in token_dict.items())
+        isinstance(token.logprob, float) and token.logprob >= -9999.0
+        for token in top_logprobs)


 @pytest.mark.asyncio
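For contrast, the completions endpoints keep the flat CompletionLogProbs layout exercised by the tests above: parallel lists (tokens, token_logprobs, top_logprobs) indexed by output position. A minimal usage sketch, assuming a local vLLM OpenAI-compatible server; the base URL and model name are placeholders:

# Hypothetical usage sketch; base_url and model are placeholders.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",  # placeholder model name
    prompt="Hello, my name is",
    max_tokens=5,
    logprobs=5,  # completions-style integer; capped at 5 by this commit's validator
)
lp = completion.choices[0].logprobs
# Parallel lists: one entry per sampled token.
for token, token_logprob, alternatives in zip(lp.tokens, lp.token_logprobs,
                                              lp.top_logprobs):
    print(token, token_logprob, alternatives)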

vllm/entrypoints/openai/protocol.py

Lines changed: 43 additions & 7 deletions
@@ -250,6 +250,19 @@ def check_guided_decoding_count(cls, data):
                 "('guided_json', 'guided_regex' or 'guided_choice').")
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+        if "top_logprobs" in data and data["top_logprobs"] is not None:
+            if "logprobs" not in data or data["logprobs"] is False:
+                raise ValueError(
+                    "when using `top_logprobs`, `logprobs` must be set to true."
+                )
+            elif not 0 <= data["top_logprobs"] <= 20:
+                raise ValueError(
+                    "`top_logprobs` must be a value in the interval [0, 20].")
+        return data
+

 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
@@ -396,6 +409,15 @@ def check_guided_decoding_count(cls, data):
                 "('guided_json', 'guided_regex' or 'guided_choice').")
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+        if "logprobs" in data and data[
+                "logprobs"] is not None and not 0 <= data["logprobs"] <= 5:
+            raise ValueError(("if passed, `logprobs` must be a value",
+                              " in the interval [0, 5]."))
+        return data
+

 class EmbeddingRequest(BaseModel):
     # Ordered by official OpenAI API documentation
@@ -415,7 +437,7 @@ def to_pooling_params(self):
         return PoolingParams(additional_data=self.additional_data)


-class LogProbs(OpenAIBaseModel):
+class CompletionLogProbs(OpenAIBaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
@@ -425,7 +447,7 @@ class LogProbs(OpenAIBaseModel):
 class CompletionResponseChoice(OpenAIBaseModel):
     index: int
     text: str
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[CompletionLogProbs] = None
     finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = Field(
         default=None,
@@ -448,7 +470,7 @@ class CompletionResponse(OpenAIBaseModel):
 class CompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     text: str
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[CompletionLogProbs] = None
     finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = Field(
         default=None,
@@ -488,11 +510,25 @@ class ChatMessage(OpenAIBaseModel):
     content: str


+class ChatCompletionLogProb(OpenAIBaseModel):
+    token: str
+    logprob: float = -9999.0
+    bytes: Optional[List[int]] = None
+
+
+class ChatCompletionLogProbsContent(ChatCompletionLogProb):
+    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
+
+
+class ChatCompletionLogProbs(OpenAIBaseModel):
+    content: Optional[List[ChatCompletionLogProbsContent]] = None
+
+
 class ChatCompletionResponseChoice(OpenAIBaseModel):
     index: int
     message: ChatMessage
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
+    logprobs: Optional[ChatCompletionLogProbs] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
     stop_reason: Optional[Union[int, str]] = None

@@ -513,8 +549,8 @@ class DeltaMessage(OpenAIBaseModel):
 class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     delta: DeltaMessage
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
+    logprobs: Optional[ChatCompletionLogProbs] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
     stop_reason: Optional[Union[int, str]] = None
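To see how the new chat logprobs models nest, here is a standalone sketch that mirrors the classes added above; vLLM's OpenAIBaseModel is swapped for pydantic.BaseModel so the snippet runs outside the repository (assumes pydantic v2):

# Standalone sketch mirroring the models introduced above; OpenAIBaseModel is
# replaced with pydantic.BaseModel so this runs on its own.
from typing import List, Optional
from pydantic import BaseModel, Field


class ChatCompletionLogProb(BaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(BaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


# One content entry per generated token, each carrying its own alternatives.
logprobs = ChatCompletionLogProbs(content=[
    ChatCompletionLogProbsContent(
        token="2",
        logprob=-0.1,
        top_logprobs=[
            ChatCompletionLogProb(token="2", logprob=-0.1),
            ChatCompletionLogProb(token="two", logprob=-2.3),
        ])
])
print(logprobs.model_dump_json(indent=2))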
