16 | 16 | import json |
17 | 17 | import math |
18 | 18 | from collections import defaultdict |
19 | | -from typing import Union, DefaultDict, Dict, List, Optional |
| 19 | +from typing import Union, DefaultDict, Dict, List, Optional, Callable |
20 | 20 |
|
21 | 21 | import torch |
22 | 22 | from pydantic import BaseModel |
23 | | -from outlines.fsm.fsm import RegexFSM |
| 23 | +from transformers import PreTrainedTokenizerBase |
| 24 | +from outlines.fsm.fsm import RegexFSM, CFGFSM |
24 | 25 | from outlines.fsm.json_schema import build_regex_from_schema |
25 | 26 |
|
26 | 27 |
|
27 | | -class RegexLogitsProcessor: |
| 28 | +class BaseLogitsProcessor: |
28 | 29 |
|
29 | | - def __init__(self, regex_string: str, tokenizer): |
30 | | - """Compile the FSM that drives the regex-structured generation. |
| 30 | + def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase): |
| 31 | + """Adapt vLLM's tokenizer to use to compile the FSM. |
31 | 32 |
|
32 | | - Parameters |
33 | | - ---------- |
34 | | - regex_string |
35 | | - A string that represents a regular expression |
36 | | - tokenizer |
37 | | - The model's tokenizer |
| 33 | +        The API of Outlines tokenizers is slightly different from that of |
| 34 | +        `transformers`. Outlines expects `decode` to return a list of |
| 35 | +        strings, whereas vLLM's tokenizer returns a single string, so the |
| 36 | +        decoder has to be wrapped to match Outlines' internal API. In |
| 37 | +        addition, we need to work around the missing spaces in Llama's |
| 38 | +        tokenizer to be able to compile FSMs for this model. |
38 | 39 |
|
39 | 40 | """ |
40 | | - tokenizer = self.adapt_tokenizer(tokenizer) |
41 | | - fsm = RegexFSM(regex_string, tokenizer) |
42 | | - self.fsm = fsm |
| 41 | + if getattr(tokenizer, "_outlines_adapted", False): |
| 42 | + return tokenizer |
| 43 | + |
| 44 | + tokenizer.vocabulary = tokenizer.get_vocab() |
| 45 | + tokenizer.special_tokens = set(tokenizer.all_special_tokens) |
| 46 | + |
| 47 | + def convert_token_to_string(token: str) -> str: |
| 48 | + from transformers.file_utils import SPIECE_UNDERLINE |
| 49 | + |
| 50 | + string = tokenizer.convert_tokens_to_string([token]) |
| 51 | + |
| 52 | +            # A hack to handle missing spaces in HF's Llama tokenizers |
| 53 | + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": |
| 54 | + return " " + string |
| 55 | + |
| 56 | + return string |
| 57 | + |
| 58 | + def change_decoder( |
| 59 | + decoder: Callable[[List[int]], str] |
| 60 | + ) -> Callable[[List[int]], List[str]]: |
| 61 | +            """Wrap vLLM's decoder so it returns a list, as Outlines expects.""" |
| 62 | + |
| 63 | + def new_decoder(inp_tokens: List[int]) -> List[str]: |
| 64 | + return [decoder(inp_tokens)] |
| 65 | + |
| 66 | + return new_decoder |
| 67 | + |
| 68 | + tokenizer.convert_token_to_string = convert_token_to_string |
| 69 | + tokenizer.decode = change_decoder(tokenizer.decode) |
| 70 | + setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 |
| 71 | + |
| 72 | + return tokenizer |
43 | 73 |
|
44 | 74 | def init_state(self): |
45 | 75 | """Initialize the FSM states.""" |
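For context, the decoder adaptation above can be illustrated on its own. Below is a minimal sketch of the same wrapping idea using a stand-in `fake_decode` function (hypothetical, only there so the snippet runs without loading a real tokenizer):

```python
from typing import Callable, List


def change_decoder(
    decoder: Callable[[List[int]], str]
) -> Callable[[List[int]], List[str]]:
    """Wrap a decoder that returns a string so it returns a one-element list."""

    def new_decoder(inp_tokens: List[int]) -> List[str]:
        return [decoder(inp_tokens)]

    return new_decoder


def fake_decode(token_ids: List[int]) -> str:
    # Stand-in for `tokenizer.decode`: a real tokenizer maps ids back to text.
    return " ".join(str(t) for t in token_ids)


print(fake_decode([1, 2, 3]))                  # '1 2 3'   -- what vLLM's decode returns
print(change_decoder(fake_decode)([1, 2, 3]))  # ['1 2 3'] -- what Outlines expects
```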
@@ -69,38 +99,30 @@ def __call__(self, input_ids: List[int], |
69 | 99 |
|
70 | 100 | return scores |
71 | 101 |
|
72 | | - def adapt_tokenizer(self, tokenizer): |
73 | | - """Adapt vLLM's tokenizer to use to compile the FSM. |
74 | | -
|
75 | | - The API of Outlines tokenizers is slightly different to that of |
76 | | - `transformers`. In addition we need to handle the missing spaces to |
77 | | - Llama's tokenizer to be able to compile FSMs for this model. |
78 | | -
|
79 | | - """ |
80 | | - tokenizer.vocabulary = tokenizer.get_vocab() |
81 | | - tokenizer.special_tokens = set(tokenizer.all_special_tokens) |
82 | | - |
83 | | - def convert_token_to_string(token: str) -> str: |
84 | | - from transformers.file_utils import SPIECE_UNDERLINE |
85 | 102 |
|
86 | | - string = tokenizer.convert_tokens_to_string([token]) |
| 103 | +class RegexLogitsProcessor(BaseLogitsProcessor): |
87 | 104 |
|
88 | | - # A hack to handle missing spaces to HF's Llama tokenizers |
89 | | - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": |
90 | | - return " " + string |
91 | | - |
92 | | - return string |
| 105 | + def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): |
| 106 | + """Compile the FSM that drives the regex-structured generation. |
93 | 107 |
|
94 | | - tokenizer.convert_token_to_string = convert_token_to_string |
| 108 | + Parameters |
| 109 | + ---------- |
| 110 | + regex_string |
| 111 | + A string that represents a regular expression |
| 112 | + tokenizer |
| 113 | + The model's tokenizer |
95 | 114 |
|
96 | | - return tokenizer |
| 115 | + """ |
| 116 | + tokenizer = self.adapt_tokenizer(tokenizer) |
| 117 | + fsm = RegexFSM(regex_string, tokenizer) |
| 118 | + self.fsm = fsm |
97 | 119 |
|
98 | 120 |
|
99 | 121 | class JSONLogitsProcessor(RegexLogitsProcessor): |
100 | 122 |
|
101 | 123 | def __init__(self, |
102 | 124 | schema: Union[str, Dict, BaseModel], |
103 | | - tokenizer, |
| 125 | + tokenizer: PreTrainedTokenizerBase, |
104 | 126 | whitespace_pattern: Optional[str] = None): |
105 | 127 | """Compile the FSM that drives the JSON-guided generation. |
106 | 128 |
|
@@ -130,3 +152,21 @@ def __init__(self, |
130 | 152 | f"the JSON Schema specification") |
131 | 153 | regex_string = build_regex_from_schema(schema_str, whitespace_pattern) |
132 | 154 | super().__init__(regex_string, tokenizer) |
| 155 | + |
| 156 | + |
| 157 | +class CFGLogitsProcessor(BaseLogitsProcessor): |
| 158 | + |
| 159 | + def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): |
| 160 | +        """Compile the FSM that drives the context-free grammar generation. |
| 161 | +
|
| 162 | + Parameters |
| 163 | + ---------- |
| 164 | + cfg |
| 165 | + A string that represents a context-free grammar |
| 166 | + tokenizer |
| 167 | + The model's tokenizer |
| 168 | +
|
| 169 | + """ |
| 170 | + tokenizer = self.adapt_tokenizer(tokenizer) |
| 171 | + fsm = CFGFSM(cfg, tokenizer) |
| 172 | + self.fsm = fsm |
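To give a sense of how the processors above are meant to be used, here is a hedged sketch of building a `JSONLogitsProcessor` from a plain JSON Schema dict. The tokenizer checkpoint is only an example, the import of the processor class is omitted because it depends on where this module lives in your tree, and the actual per-step wiring happens inside vLLM's engine:

```python
from transformers import AutoTokenizer

# Any HF tokenizer works; this checkpoint name is purely illustrative.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

# A JSON Schema passed as a dict (a Pydantic model class or a JSON string also work).
schema = {
    "type": "object",
    "properties": {
        "origin": {"type": "string"},
        "destination": {"type": "string"},
        "seats": {"type": "integer"},
    },
    "required": ["origin", "destination", "seats"],
}

# JSONLogitsProcessor is the class defined in the diff above.
json_processor = JSONLogitsProcessor(schema, tokenizer)
json_processor.init_state()

# During generation vLLM calls the processor once per decoding step, roughly:
#     scores = json_processor(generated_token_ids, scores)
# which masks the logits so only schema-conforming JSON can be sampled.
```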
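The new `CFGLogitsProcessor` follows the same pattern, but is driven by a grammar instead of a regex. A sketch with a toy grammar (Outlines' `CFGFSM` accepts Lark-style EBNF; the grammar and tokenizer here are illustrative only):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

# A toy grammar that only admits the strings "yes" or "no".
answer_grammar = r"""
?start: "yes" | "no"
"""

# CFGLogitsProcessor is the class added in the diff above.
cfg_processor = CFGLogitsProcessor(answer_grammar, tokenizer)
cfg_processor.init_state()

# Applied per decoding step exactly like the regex/JSON processors:
#     scores = cfg_processor(generated_token_ids, scores)
```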