Skip to content

Commit 120157f

Browse files
authored
Support arbitrary json_object in OpenAI and Context Free Grammar (#3211)
1 parent 8e67598 commit 120157f

File tree

4 files changed

+176
-49
lines changed

4 files changed

+176
-49
lines changed

tests/entrypoints/test_openai_server.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,5 +660,55 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
660660
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
661661

662662

663+
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
    """Request `json_object` output and check the reply is the expected JSON."""
    question = ('what is 1+1? please respond with a JSON object, '
                'the format is {"result": 2}')
    chat_messages = [{"role": "user", "content": question}]
    json_format = {"type": "json_object"}

    resp = await client.chat.completions.create(model=MODEL_NAME,
                                                messages=chat_messages,
                                                response_format=json_format)

    # json.loads raises if the reply is not valid JSON
    parsed = json.loads(resp.choices[0].message.content)
    assert parsed == {"result": 2}, parsed
677+
678+
679+
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
    """Generate SQL under a Lark grammar and check it parses and matches."""
    simple_sql_grammar = """
start: select_statement

select_statement: "SELECT" column "from" table "where" condition

column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number

number: "1" | "2"
"""

    prompt_text = ("Generate a sql state that select col_1 from "
                   "table_1 where it is equals to 1")
    grammar_kwargs = {"guided_grammar": simple_sql_grammar}
    completion = await client.completions.create(model=MODEL_NAME,
                                                 prompt=prompt_text,
                                                 temperature=1.0,
                                                 max_tokens=500,
                                                 extra_body=grammar_kwargs)
    generated = completion.choices[0].text

    # parse the output with Lark to make sure it is valid under the grammar
    from lark import Lark
    Lark(simple_sql_grammar).parse(generated)

    # the grammar contains no whitespace, so compare against the expected
    # statement with its spaces removed
    expected = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")
    assert generated.strip() == expected
711+
712+
663713
if __name__ == "__main__":
664714
pytest.main([__file__])

vllm/entrypoints/openai/protocol.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ class UsageInfo(BaseModel):
5555
completion_tokens: Optional[int] = 0
5656

5757

58+
class ResponseFormat(BaseModel):
    """Output format requested through the OpenAI ``response_format`` field.

    ``type`` must be ``"text"`` or ``"json_object"``; it defaults to
    ``"text"`` so the model can still be built without arguments.
    """
    # The Literal has to be the annotation, not the default value. Written as
    # `type: str = Literal[...]`, the field accepts any string and its default
    # is the typing construct itself, so nothing is actually validated.
    type: Literal["text", "json_object"] = "text"
61+
62+
5863
class ChatCompletionRequest(BaseModel):
5964
model: str
6065
messages: List[Dict[str, str]]
@@ -89,6 +94,8 @@ class ChatCompletionRequest(BaseModel):
8994
guided_json: Optional[Union[str, dict, BaseModel]] = None
9095
guided_regex: Optional[str] = None
9196
guided_choice: Optional[List[str]] = None
97+
guided_grammar: Optional[str] = None
98+
response_format: Optional[ResponseFormat] = None
9299

93100
def to_sampling_params(self) -> SamplingParams:
94101
if self.logprobs and not self.top_logprobs:
@@ -183,6 +190,8 @@ class CompletionRequest(BaseModel):
183190
guided_json: Optional[Union[str, dict, BaseModel]] = None
184191
guided_regex: Optional[str] = None
185192
guided_choice: Optional[List[str]] = None
193+
guided_grammar: Optional[str] = None
194+
response_format: Optional[ResponseFormat] = None
186195

187196
def to_sampling_params(self):
188197
echo_without_generation = self.echo and self.max_tokens == 0

vllm/model_executor/guided_decoding.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,50 @@
66
from json import dumps as json_dumps
77
from re import escape as regex_escape
88
from typing import Union, Tuple
9+
910
from pydantic import BaseModel
11+
from transformers import PreTrainedTokenizerBase
1012

1113
from vllm.entrypoints.openai.protocol import (CompletionRequest,
1214
ChatCompletionRequest)
1315
from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
14-
RegexLogitsProcessor)
16+
RegexLogitsProcessor,
17+
CFGLogitsProcessor)
1518

1619

1720
class GuidedDecodingMode(Enum):
1821
JSON = "json"
1922
REGEX = "regex"
2023
CHOICE = "choice"
24+
GRAMMAR = "grammar"
25+
26+
27+
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
28+
# the main difference is that we changed the start rule from `start: value`
29+
# to `start: object | array`, so scalar values are rejected as the root of
30+
# the JSON document. Allowing a scalar at the root seems to cause Llama to
31+
# keep generating without ever stopping.
32+
JSON_GRAMMAR = r"""
33+
?start: object | array
34+
35+
?value: object
36+
| array
37+
| UNESCAPED_STRING
38+
| SIGNED_NUMBER -> number
39+
| "true" -> true
40+
| "false" -> false
41+
| "null" -> null
42+
43+
array : "[" [value ("," value)*] "]"
44+
object : "{" [pair ("," pair)*] "}"
45+
pair : UNESCAPED_STRING ":" value
46+
47+
%import common.UNESCAPED_STRING
48+
%import common.SIGNED_NUMBER
49+
%import common.WS
2150
51+
%ignore WS
52+
"""
2253

2354
global_thread_pool = None # used for generating logits processor fsm
2455

@@ -57,9 +88,6 @@ def _get_guide_and_mode(
5788
) -> Tuple[str, GuidedDecodingMode]:
5889

5990
if request.guided_json:
60-
if not isinstance(request.guided_json, (str, dict, BaseModel)):
61-
raise TypeError("JSON schema must be str, dict, or BaseModel")
62-
6391
json = request.guided_json
6492
if isinstance(json, dict):
6593
# turn dict into hashable string
@@ -69,33 +97,33 @@ def _get_guide_and_mode(
6997
# with the same fields will get hashed the same
7098
json = str(json.__signature__)
7199
return json, GuidedDecodingMode.JSON
72-
73100
elif request.guided_regex:
74-
if not isinstance(request.guided_regex, str):
75-
raise TypeError("Regex must be string")
76101
return request.guided_regex, GuidedDecodingMode.REGEX
77-
78102
elif request.guided_choice:
79-
if not isinstance(request.guided_choice, list):
80-
raise TypeError("Choices must be a list")
81-
82103
# choice just uses regex
83104
choices = [
84105
regex_escape(str(choice)) for choice in request.guided_choice
85106
]
86107
choices_regex = "(" + "|".join(choices) + ")"
87108
return choices_regex, GuidedDecodingMode.CHOICE
88-
109+
elif request.guided_grammar:
110+
return request.guided_grammar, GuidedDecodingMode.GRAMMAR
111+
elif (request.response_format is not None
112+
and request.response_format.type == "json_object"):
113+
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
89114
else:
90115
return None, None
91116

92117

93118
@lru_cache(maxsize=32)
94-
def _get_cached_logits_processor(guide: str, tokenizer,
119+
def _get_cached_logits_processor(guide: str,
120+
tokenizer: PreTrainedTokenizerBase,
95121
mode: GuidedDecodingMode):
96122
if mode == GuidedDecodingMode.JSON:
97123
return JSONLogitsProcessor(guide, tokenizer)
98124
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
99125
return RegexLogitsProcessor(guide, tokenizer)
126+
elif mode == GuidedDecodingMode.GRAMMAR:
127+
return CFGLogitsProcessor(guide, tokenizer)
100128
else:
101129
raise ValueError(f"Unknown guided decoding mode {mode}")

vllm/model_executor/guided_logits_processors.py

Lines changed: 76 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,60 @@
1616
import json
1717
import math
1818
from collections import defaultdict
19-
from typing import Union, DefaultDict, Dict, List, Optional
19+
from typing import Union, DefaultDict, Dict, List, Optional, Callable
2020

2121
import torch
2222
from pydantic import BaseModel
23-
from outlines.fsm.fsm import RegexFSM
23+
from transformers import PreTrainedTokenizerBase
24+
from outlines.fsm.fsm import RegexFSM, CFGFSM
2425
from outlines.fsm.json_schema import build_regex_from_schema
2526

2627

27-
class RegexLogitsProcessor:
28+
class BaseLogitsProcessor:
2829

29-
def __init__(self, regex_string: str, tokenizer):
30-
"""Compile the FSM that drives the regex-structured generation.
30+
def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase):
31+
"""Adapt vLLM's tokenizer to use to compile the FSM.
3132
32-
Parameters
33-
----------
34-
regex_string
35-
A string that represents a regular expression
36-
tokenizer
37-
The model's tokenizer
33+
The API of Outlines tokenizers is slightly different from that of
34+
`transformers`. The Outlines decoder returns a list, whereas
35+
vLLM's decode returns a str. To match the Outlines internal API,
36+
the vLLM decoder must be adapted to return a list. In addition,
37+
we need to handle the missing spaces in Llama's tokenizer to be
38+
able to compile FSMs for this model.
3839
3940
"""
40-
tokenizer = self.adapt_tokenizer(tokenizer)
41-
fsm = RegexFSM(regex_string, tokenizer)
42-
self.fsm = fsm
41+
if getattr(tokenizer, "_outlines_adapted", False):
42+
return tokenizer
43+
44+
tokenizer.vocabulary = tokenizer.get_vocab()
45+
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
46+
47+
def convert_token_to_string(token: str) -> str:
48+
from transformers.file_utils import SPIECE_UNDERLINE
49+
50+
string = tokenizer.convert_tokens_to_string([token])
51+
52+
# A hack to handle missing spaces to HF's Llama tokenizers
53+
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
54+
return " " + string
55+
56+
return string
57+
58+
def change_decoder(
59+
decoder: Callable[[List[int]], str]
60+
) -> Callable[[List[int]], List[str]]:
61+
"""Sync vLLM's decoder with the outlines by returning list."""
62+
63+
def new_decoder(inp_tokens: List[int]) -> List[str]:
64+
return [decoder(inp_tokens)]
65+
66+
return new_decoder
67+
68+
tokenizer.convert_token_to_string = convert_token_to_string
69+
tokenizer.decode = change_decoder(tokenizer.decode)
70+
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
71+
72+
return tokenizer
4373

4474
def init_state(self):
4575
"""Initialize the FSM states."""
@@ -69,38 +99,30 @@ def __call__(self, input_ids: List[int],
6999

70100
return scores
71101

72-
def adapt_tokenizer(self, tokenizer):
73-
"""Adapt vLLM's tokenizer to use to compile the FSM.
74-
75-
The API of Outlines tokenizers is slightly different to that of
76-
`transformers`. In addition we need to handle the missing spaces to
77-
Llama's tokenizer to be able to compile FSMs for this model.
78-
79-
"""
80-
tokenizer.vocabulary = tokenizer.get_vocab()
81-
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
82-
83-
def convert_token_to_string(token: str) -> str:
84-
from transformers.file_utils import SPIECE_UNDERLINE
85102

86-
string = tokenizer.convert_tokens_to_string([token])
103+
class RegexLogitsProcessor(BaseLogitsProcessor):
87104

88-
# A hack to handle missing spaces to HF's Llama tokenizers
89-
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
90-
return " " + string
91-
92-
return string
105+
def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
106+
"""Compile the FSM that drives the regex-structured generation.
93107
94-
tokenizer.convert_token_to_string = convert_token_to_string
108+
Parameters
109+
----------
110+
regex_string
111+
A string that represents a regular expression
112+
tokenizer
113+
The model's tokenizer
95114
96-
return tokenizer
115+
"""
116+
tokenizer = self.adapt_tokenizer(tokenizer)
117+
fsm = RegexFSM(regex_string, tokenizer)
118+
self.fsm = fsm
97119

98120

99121
class JSONLogitsProcessor(RegexLogitsProcessor):
100122

101123
def __init__(self,
102124
schema: Union[str, Dict, BaseModel],
103-
tokenizer,
125+
tokenizer: PreTrainedTokenizerBase,
104126
whitespace_pattern: Optional[str] = None):
105127
"""Compile the FSM that drives the JSON-guided generation.
106128
@@ -130,3 +152,21 @@ def __init__(self,
130152
f"the JSON Schema specification")
131153
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
132154
super().__init__(regex_string, tokenizer)
155+
156+
157+
class CFGLogitsProcessor(BaseLogitsProcessor):

    def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
        """Build the FSM used for context-free-grammar guided generation.

        Parameters
        ----------
        cfg
            The context-free grammar, given as a string
        tokenizer
            The model's tokenizer

        """
        # CFGFSM is handed the tokenizer returned by adapt_tokenizer
        adapted = self.adapt_tokenizer(tokenizer)
        self.fsm = CFGFSM(cfg, adapted)

0 commit comments

Comments
 (0)