Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions tests/entrypoints/openai/test_chat_logit_bias_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0

import openai
import pytest
import pytest_asyncio

from vllm.config import ModelConfig

from ...utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


def get_vocab_size(model_name):
    """Return the vocabulary size vLLM's ModelConfig reports for *model_name*."""
    config_kwargs = {
        "model": model_name,
        "task": "auto",
        "tokenizer": model_name,
        "tokenizer_mode": "auto",
        "trust_remote_code": False,
        "seed": 0,
        "dtype": "bfloat16",
    }
    return ModelConfig(**config_kwargs).get_vocab_size()


@pytest.fixture(scope="module")
def server():
    """Start one vLLM OpenAI-compatible server for the whole test module."""
    cli_args = [
        "--dtype", "bfloat16",
        "--max-model-len", "1024",
        "--enforce-eager",
    ]
    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote:
        yield remote


@pytest_asyncio.fixture
async def client(server):
    """Yield an async OpenAI client pointed at the module-scoped server."""
    async with server.get_async_client() as api_client:
        yield api_client


@pytest.mark.asyncio
async def test_chat_logit_bias_valid(client):
    """Test that valid logit_bias values are accepted in chat completions."""
    # The highest legal token id is vocab_size - 1.
    last_valid_id = get_vocab_size(MODEL_NAME) - 1

    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "Testing valid logit bias"
        }],
        max_tokens=5,
        logit_bias={str(last_valid_id): 1.0},
    )

    assert response.choices[0].message.content is not None


@pytest.mark.asyncio
async def test_chat_logit_bias_invalid(client):
    """Test that invalid logit_bias values are rejected in chat completions."""
    vocab_size = get_vocab_size(MODEL_NAME)
    # Any id >= vocab_size is out of range; pick one just past the end.
    out_of_range_id = vocab_size + 1

    with pytest.raises(openai.BadRequestError) as excinfo:
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{
                "role": "user",
                "content": "Testing invalid logit bias"
            }],
            max_tokens=5,
            logit_bias={str(out_of_range_id): 1.0},
        )

    err = excinfo.value
    message = str(err)

    # The server must reject with HTTP 400 and name both the offending
    # token id and the vocabulary size in the error text.
    assert err.status_code == 400
    assert str(out_of_range_id) in message
    assert str(vocab_size) in message
21 changes: 21 additions & 0 deletions vllm/v1/engine/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def _validate_sampling_params(
params: SamplingParams,
) -> None:
self._validate_structured_output(params)
self._validate_logit_bias(params)

if params.allowed_token_ids is None:
return
Expand All @@ -87,6 +88,26 @@ def _validate_sampling_params(
raise ValueError(
"allowed_token_ids contains out-of-vocab token id!")

def _validate_logit_bias(
self,
params: SamplingParams,
) -> None:
"""Validate logit_bias token IDs are within vocabulary range."""
if not params.logit_bias:
return

vocab_size = self.model_config.get_vocab_size()
invalid_token_ids = []

for token_id in params.logit_bias:
if token_id < 0 or token_id >= vocab_size:
invalid_token_ids.append(token_id)

if invalid_token_ids:
raise ValueError(
f"token_id(s) {invalid_token_ids} in logit_bias contain "
f"out-of-vocab token ids. Vocabulary size: {vocab_size}")

def _validate_supported_sampling_params(
self,
params: SamplingParams,
Expand Down
10 changes: 10 additions & 0 deletions vllm/v1/sample/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,19 @@ def apply_logits_bias(
# TODO(houseroad): this implementation is extremely inefficient.
# One idea is implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.

# Get vocabulary size from logits
vocab_size = logits.shape[-1]

for i, logit_bias in enumerate(sampling_metadata.logit_bias):
if logit_bias:
for token_id, bias in logit_bias.items():
# Check token_id bounds to ensure within vocabulary
if token_id < 0 or token_id >= vocab_size:
raise ValueError(
f"token_id {token_id} in logit_bias contains "
f"out-of-vocab token id. Vocabulary size: "
f"{vocab_size}")
Comment on lines +233 to +245
Copy link
Contributor

@afeldman-nm afeldman-nm May 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @rymc question, why do we need out-of-vocab token id validation inside of the logits bias logits processor? It appears that this PR already added validation within the frontend (in processor.py); it appears that the _validate_logit_bias() method will be called for every new request that is submitted to both the sync and async engines. Would it be acceptable to remove this logit bias validation from the internals of the logit bias logits processor?

(I know this PR is already merged; I'm asking because #16728 omits this out-of-vocab token check from the logit bias logits processor but keeps the frontend `validate_logit_bias()` check.)

CC @njhill

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think when I made the above comment I may have missed the fact that this PR adds the validation in both places. Given that, all that's needed is to remove the validation within the sampler / omit it from the vectorized impl that we'll move to.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes. The reason was because I wasn't confident there wasn't a path that ends up here without going through the frontend processor and I didn't want a crash. Given that it does, I'm happy to submit a new PR with the sampler validation removed. Let me know :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rymc thanks for the response :) most likely the omission of the redundant checks will be accomplished by #16728

logits[i, token_id] += bias
return logits

Expand Down