6 changes: 4 additions & 2 deletions tests/async_engine/test_openapi_server_ray.py
@@ -94,8 +94,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
chat_completion.choices) == 1
assert chat_completion.choices[0].message is not None
assert chat_completion.choices[0].logprobs is not None
assert chat_completion.choices[0].logprobs.top_logprobs is not None
assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
assert chat_completion.choices[0].logprobs.content[
0].top_logprobs is not None
assert len(
chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
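The updated assertions above follow the OpenAI chat logprobs schema that this PR adopts: per-token log probabilities live under choices[0].logprobs.content, and each entry carries its own nested top_logprobs list. A minimal client-side sketch of walking that structure (the base URL, API key, and model name are illustrative assumptions, not part of the test):

import openai

# Hypothetical client pointed at a locally running OpenAI-compatible server.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat = client.chat.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",  # assumed model name
    messages=[{"role": "user", "content": "what is 1+1?"}],
    max_tokens=5,
    logprobs=True,
    top_logprobs=5,
)

# Each content entry is one sampled token with its own ranked alternatives.
for entry in chat.choices[0].logprobs.content:
    alts = ", ".join(f"{alt.token!r}:{alt.logprob:.2f}" for alt in entry.top_logprobs)
    print(f"{entry.token!r} ({entry.logprob:.2f}) <- [{alts}]")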
209 changes: 181 additions & 28 deletions tests/entrypoints/test_openai_server.py
@@ -184,6 +184,26 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
completion.choices[0].text) >= 5


@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=None,
)
choice = completion.choices[0]
assert choice.logprobs is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
@@ -203,7 +223,72 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
choice = completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
assert choice.logprobs.top_logprobs is None
assert choice.logprobs.top_logprobs is not None
assert len(choice.logprobs.top_logprobs[0]) <= 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=5,
)
choice = completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
assert choice.logprobs.top_logprobs is not None
assert len(choice.logprobs.top_logprobs[0]) <= 6


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):

with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=6,
)
...
with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
stream = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=6,
stream=True,
)
async for chunk in stream:
...

# the server should still work afterwards
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
completion = completion.choices[0].text
assert completion is not None and len(completion) >= 0


@pytest.mark.asyncio
@@ -233,8 +318,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
chat_completion.choices) == 1
assert chat_completion.choices[0].message is not None
assert chat_completion.choices[0].logprobs is not None
assert chat_completion.choices[0].logprobs.top_logprobs is not None
assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
assert chat_completion.choices[0].logprobs.content[
0].top_logprobs is not None
assert len(
chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
@@ -251,10 +338,93 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]

chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=5,
temperature=0.0,
logprobs=False)

choice = chat_completion.choices[0]
assert choice.logprobs is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]

chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=5,
temperature=0.0,
logprobs=True,
top_logprobs=0)

choice = chat_completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.content is not None
assert len(choice.logprobs.content[0].top_logprobs) <= 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]

chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=5,
temperature=0.0,
logprobs=True,
top_logprobs=5)

choice = chat_completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.content is not None
assert len(choice.logprobs.content[0].top_logprobs) <= 6


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -263,13 +433,13 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
"content": "what is 1+1?"
}]

# Default max_logprobs is 5, so this should raise an error
# Default max_logprobs is 20, so this should raise an error
with pytest.raises((openai.BadRequestError, openai.APIError)):
stream = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=10,
top_logprobs=21,
stream=True)
async for chunk in stream:
...
@@ -279,25 +449,9 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=10,
top_logprobs=30,
stream=False)

with pytest.raises((openai.BadRequestError, openai.APIError)):
stream = await client.completions.create(model=model_name,
prompt="Test",
max_tokens=10,
logprobs=10,
stream=True)
async for chunk in stream:
...

with pytest.raises(openai.BadRequestError):
await client.completions.create(model=model_name,
prompt="Test",
max_tokens=10,
logprobs=10,
stream=False)

# the server should still work afterwards
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
@@ -744,13 +898,12 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
top_logprobs=5,
extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
top_logprobs = chat_completion.choices[0].logprobs.top_logprobs
top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs

# -9999.0 is the minimum logprob returned by OpenAI
assert all(
isinstance(logprob, float) and logprob >= -9999.0
for token_dict in top_logprobs
for token, logprob in token_dict.items())
isinstance(token.logprob, float) and token.logprob >= -9999.0
for token in top_logprobs)


@pytest.mark.asyncio
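On the completions endpoint the legacy logprobs shape (tokens, token_logprobs, top_logprobs) is kept, but the request-side logprobs value is now validated against the OpenAI limit of 5, which is what the new completion tests above exercise. A hedged sketch of that behaviour against a running server (URL, API key, and model name are assumptions):

import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


async def demo() -> None:
    # Up to 5 logprobs per token is accepted; the response keeps the legacy
    # completions shape (tokens / token_logprobs / top_logprobs).
    ok = await client.completions.create(model="HuggingFaceH4/zephyr-7b-beta",
                                         prompt="Hello",
                                         max_tokens=5,
                                         logprobs=5)
    print(ok.choices[0].logprobs.top_logprobs[0])

    # Asking for more than 5 should now be rejected by the request validator.
    try:
        await client.completions.create(model="HuggingFaceH4/zephyr-7b-beta",
                                        prompt="Hello",
                                        max_tokens=5,
                                        logprobs=6)
    except (openai.BadRequestError, openai.APIError) as exc:
        print("rejected as expected:", exc)


asyncio.run(demo())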
52 changes: 43 additions & 9 deletions vllm/entrypoints/openai/protocol.py
@@ -192,8 +192,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params

def to_sampling_params(self) -> SamplingParams:
if self.logprobs and not self.top_logprobs:
raise ValueError("Top logprobs must be set when logprobs is.")

logits_processors = None
if self.logit_bias:
@@ -251,6 +249,19 @@ def check_guided_decoding_count(cls, data):
"('guided_json', 'guided_regex' or 'guided_choice').")
return data

@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "top_logprobs" in data and data["top_logprobs"] is not None:
if "logprobs" not in data or data["logprobs"] is False:
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
elif not 0 <= data["top_logprobs"] <= 20:
raise ValueError(
"`top_logprobs` must be a value in the interval [0, 20].")
return data


class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
@@ -397,6 +408,15 @@ def check_guided_decoding_count(cls, data):
"('guided_json', 'guided_regex' or 'guided_choice').")
return data

@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "logprobs" in data and data[
"logprobs"] is not None and not 0 <= data["logprobs"] <= 5:
raise ValueError(("if passed, `logprobs` must be a value",
" in the interval [0, 5]."))
return data


class EmbeddingRequest(BaseModel):
# Ordered by official OpenAI API documentation
@@ -416,7 +436,7 @@ def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)


class LogProbs(OpenAIBaseModel):
class CompletionLogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
@@ -426,7 +446,7 @@ class LogProbs(OpenAIBaseModel):
class CompletionResponseChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
logprobs: Optional[CompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
@@ -449,7 +469,7 @@ class CompletionResponse(OpenAIBaseModel):
class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
logprobs: Optional[CompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
@@ -489,11 +509,25 @@ class ChatMessage(OpenAIBaseModel):
content: str


class ChatCompletionLogProb(OpenAIBaseModel):
token: str
logprob: float = -9999.0
bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None
logprobs: Optional[ChatCompletionLogProbs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
stop_reason: Optional[Union[int, str]] = None


@@ -514,8 +548,8 @@ class DeltaMessage(OpenAIBaseModel):
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None
logprobs: Optional[ChatCompletionLogProbs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
stop_reason: Optional[Union[int, str]] = None


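The new ChatCompletionLogProb, ChatCompletionLogProbsContent, and ChatCompletionLogProbs models above mirror the OpenAI chat response schema: one content entry per generated token, each with optional UTF-8 bytes and a nested list of ranked alternatives. A standalone sketch of how the nesting serializes, assuming plain pydantic BaseModel in place of the repo's OpenAIBaseModel and with made-up sample values:

from typing import List, Optional

from pydantic import BaseModel, Field


class ChatCompletionLogProb(BaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(BaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


# One generated token ("2") with two ranked alternatives.
logprobs = ChatCompletionLogProbs(content=[
    ChatCompletionLogProbsContent(
        token="2",
        logprob=-0.01,
        top_logprobs=[
            ChatCompletionLogProb(token="2", logprob=-0.01),
            ChatCompletionLogProb(token="3", logprob=-5.2),
        ])
])
print(logprobs.model_dump_json(indent=2))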