Skip to content

Commit be7370d

Browse files
alecsolder and Alec Solder authored
[Frontend] Enable generic structured_outputs for responses API (#33709)
Signed-off-by: Alec Solder <alecs@fb.com>
Co-authored-by: Alec Solder <alecs@fb.com>
1 parent: 9ea1f59 · commit: be7370d

2 files changed

Lines changed: 58 additions & 7 deletions

File tree

tests/entrypoints/openai/responses/test_sampling_params.py

Lines changed: 47 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,17 @@
44
"""Unit tests for ResponsesRequest.to_sampling_params() parameter mapping."""
55

66
import pytest
7+
import torch
8+
from openai.types.responses.response_format_text_json_schema_config import (
9+
ResponseFormatTextJSONSchemaConfig,
10+
)
11+
from pydantic import ValidationError
712

8-
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
13+
from vllm.entrypoints.openai.responses.protocol import (
14+
ResponsesRequest,
15+
ResponseTextConfig,
16+
)
17+
from vllm.sampling_params import StructuredOutputsParams
918

1019

1120
class TestResponsesRequestSamplingParams:
@@ -76,9 +85,6 @@ def test_default_values(self):
7685

7786
def test_seed_bounds_validation(self):
7887
"""Test that seed values outside torch.long bounds are rejected."""
79-
import torch
80-
from pydantic import ValidationError
81-
8288
# Test seed below minimum
8389
with pytest.raises(ValidationError) as exc_info:
8490
ResponsesRequest(
@@ -111,3 +117,40 @@ def test_seed_bounds_validation(self):
111117
seed=torch.iinfo(torch.long).max,
112118
)
113119
assert request_max.seed == torch.iinfo(torch.long).max
120+
121+
def test_structured_outputs_passed_through(self):
122+
"""Test that structured_outputs field is passed to SamplingParams."""
123+
structured_outputs = StructuredOutputsParams(grammar="root ::= 'hello'")
124+
request = ResponsesRequest(
125+
model="test-model",
126+
input="test input",
127+
structured_outputs=structured_outputs,
128+
)
129+
130+
sampling_params = request.to_sampling_params(default_max_tokens=1000)
131+
132+
assert sampling_params.structured_outputs is not None
133+
assert sampling_params.structured_outputs.grammar == "root ::= 'hello'"
134+
135+
def test_structured_outputs_and_json_schema_conflict(self):
136+
"""Test that specifying both structured_outputs and json_schema raises."""
137+
structured_outputs = StructuredOutputsParams(grammar="root ::= 'hello'")
138+
text_config = ResponseTextConfig()
139+
text_config.format = ResponseFormatTextJSONSchemaConfig(
140+
type="json_schema",
141+
name="test",
142+
schema={"type": "object"},
143+
)
144+
request = ResponsesRequest(
145+
model="test-model",
146+
input="test input",
147+
structured_outputs=structured_outputs,
148+
text=text_config,
149+
)
150+
151+
with pytest.raises(ValueError) as exc_info:
152+
request.to_sampling_params(default_max_tokens=1000)
153+
154+
assert "Cannot specify both structured_outputs and text.format" in str(
155+
exc_info.value
156+
)

vllm/entrypoints/openai/responses/protocol.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -233,6 +233,10 @@ class ResponsesRequest(OpenAIBaseModel):
233233
# this cannot be used in conjunction with previous_response_id
234234
# TODO: consider supporting non harmony messages as well
235235
previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
236+
structured_outputs: StructuredOutputsParams | None = Field(
237+
default=None,
238+
description="Additional kwargs for structured outputs",
239+
)
236240

237241
repetition_penalty: float | None = None
238242
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
@@ -319,8 +323,14 @@ def to_sampling_params(
319323
stop_token_ids = default_sampling_params.get("stop_token_ids")
320324

321325
# Structured output
322-
structured_outputs = None
326+
structured_outputs = self.structured_outputs
327+
328+
# Also check text.format for OpenAI-style json_schema
323329
if self.text is not None and self.text.format is not None:
330+
if structured_outputs is not None:
331+
raise ValueError(
332+
"Cannot specify both structured_outputs and text.format"
333+
)
324334
response_format = self.text.format
325335
if (
326336
response_format.type == "json_schema"
@@ -329,8 +339,6 @@ def to_sampling_params(
329339
structured_outputs = StructuredOutputsParams(
330340
json=response_format.schema_
331341
)
332-
elif response_format.type == "json_object":
333-
raise NotImplementedError("json_object is not supported")
334342

335343
stop = self.stop if self.stop else []
336344
if isinstance(stop, str):

0 commit comments

Comments (0)