 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
-from vllm.utils import LazyLoader
-from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
+from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
+                                                     StructuredOutputGrammar)
+from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
 
 if TYPE_CHECKING:
     import numpy as np
     import numpy.typing as npt
-    import xgrammar as xgr
+    import torch
 
     from vllm.v1.request import Request
-else:
-    xgr = LazyLoader("xgr", globals(), "xgrammar")
 
 logger = init_logger(__name__)
 
 
 class StructuredOutputManager:
+    """Engine-level manager for structured output requests."""
 
     def __init__(self, vllm_config: VllmConfig):
+        self.backend: Optional[StructuredOutputBackend] = None
         self.vllm_config = vllm_config
-        self.init_complete = False
-
-    def _delayed_init(self):
-        """Initialization delayed until we know it is needed."""
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=self.vllm_config.model_config,
-            scheduler_config=self.vllm_config.scheduler_config,
-            parallel_config=self.vllm_config.parallel_config,
-            lora_config=self.vllm_config.lora_config)  # type: ignore[arg-type]
-        tokenizer_group.ping()
-
-        tokenizer = tokenizer_group.get_lora_tokenizer(None)
-        self.vocab_size = self.vllm_config.model_config.get_vocab_size()
-        if isinstance(tokenizer, MistralTokenizer):
-            # NOTE: ideally, xgrammar should handle this accordingly.
-            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
-            try:
-                encoded_vocab = [
-                    token for token, _ in sorted(
-                        tokenizer.get_vocab().items(),
-                        key=lambda x: x[1],
-                    )
-                ]
-                stop_token_ids = None
-                if hasattr(
-                        tokenizer,
-                        "eos_token_id",
-                ) and tokenizer.eos_token_id is not None:
-                    stop_token_ids = [tokenizer.eos_token_id]
-            except AttributeError as e:
-                raise ValueError(
-                    f"Cannot get the vocabulary of the tokenizer "
-                    f"{type(tokenizer)}. The tokenizer should have a "
-                    "get_vocab method.") from e
-            tokenizer_info = xgr.TokenizerInfo(
-                encoded_vocab=encoded_vocab,
-                # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43  # noqa: E501
-                vocab_type=xgr.VocabType.BYTE_FALLBACK,
-                vocab_size=self.vocab_size,
-                stop_token_ids=stop_token_ids,
-                add_prefix_space=True,
-            )
-        else:
-            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
-                tokenizer,
-                vocab_size=self.vocab_size,
-            )
-        self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
+        self._grammar_bitmask: Optional[torch.Tensor] = None
 
         # The default max_workers if not specified is the number of CPUs * 5,
         # which is way too high since these tasks are CPU-bound, not I/O bound.
         # We also know we would never dominate CPU usage with just grammar
         # compilation, so we set it to half the number of CPUs.
         max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self._grammar_bitmask = xgr.allocate_token_bitmask(
-            self.vllm_config.scheduler_config.max_num_seqs,
-            self.vocab_size,
-        )
-
-        self.init_complete = True
 
     def grammar_init(self, request: Request) -> None:
         if request.structured_output_request is None:
             return
 
-        # The first time this is called, we need to finish initialization
-        # of xgrammar. We defer it to avoid the import of xgrammar and
-        # initialization cost if it is not going to be used.
-        if not self.init_complete:
-            self._delayed_init()
+        # Initialize the backend the first time it is needed.
+        #
+        # NOTE: We only support a single backend. We do NOT support different
+        # backends on a per-request basis in V1 (for now, anyway...).
+        if self.backend is None:
+            backend_name = request.sampling_params.guided_decoding.backend_name
+            if backend_name == "xgrammar":
+                self.backend = XgrammarBackend(self.vllm_config)
+            else:
+                raise ValueError(
+                    f"Unsupported structured output backend: {backend_name}")
 
-        grammar: Future[Grammar] = self.executor.submit(
-            self._async_create_grammar, request)
+        grammar: Future[StructuredOutputGrammar] = self.executor.submit(
+            self._async_create_grammar, request, self.backend)
         request.structured_output_request.grammar = grammar  # type: ignore[assignment]
 
-    def _async_create_grammar(self, request: Request) -> Grammar:
+    def _async_create_grammar(
+            self, request: Request,
+            backend: StructuredOutputBackend) -> StructuredOutputGrammar:
         key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]
 
         # Note that the request was validated in the engine core client,
@@ -114,28 +68,8 @@ def _async_create_grammar(self, request: Request) -> Grammar:
         # though it should be unlikely as we test that up front as well.
         request_type, grammar_spec = key
 
-        if request_type == StructuredOutputOptions.JSON:
-            # TODO -- allow any_whitespace to be configurable
-            # pending merge of https://github.com/vllm-project/vllm/pull/12744
-            ctx = self.compiler.compile_json_schema(grammar_spec,
-                                                    any_whitespace=False)
-        elif request_type == StructuredOutputOptions.JSON_OBJECT:
-            ctx = self.compiler.compile_builtin_json_grammar()
-        elif request_type == StructuredOutputOptions.GRAMMAR:
-            ctx = self.compiler.compile_grammar(grammar_spec)
-        elif request_type == StructuredOutputOptions.REGEX:
-            ctx = self.compiler.compile_regex(grammar_spec)
-        else:
-            logger.error("Validation should have already occurred. "
-                         "Please file an issue.")
-            raise ValueError(
-                f"grammar is not of valid supported types. ({request_type!s})")
-
-        return Grammar(
-            matcher=xgr.GrammarMatcher(ctx),
-            vocab_size=self.vocab_size,
-            ctx=ctx,
-        )
+        assert self.backend is not None
+        return self.backend.compile_grammar(request_type, grammar_spec)
 
     def grammar_bitmask(
         self,
@@ -147,14 +81,19 @@ def grammar_bitmask(
         if not structured_output_request_ids:
             return None
 
+        if self._grammar_bitmask is None:
+            assert self.backend is not None
+            self._grammar_bitmask = self.backend.allocate_token_bitmask(
+                self.vllm_config.scheduler_config.max_num_seqs)
+
         # Fill the bitmask using the index of each request equal to its
         # position in the batch. Resize the bitmask down to the size of
         # the batch.
         bitmask_tensor = self._grammar_bitmask
         for req_id, batch_index in structured_output_request_ids.items():
             request = requests[req_id].structured_output_request
             assert request is not None and request.grammar is not None
-            if not request.grammar.matcher.is_terminated():
+            if not request.grammar.is_terminated():
                 request.grammar.fill_bitmask(bitmask_tensor, batch_index)
         if batch_len < self._grammar_bitmask.shape[0]:
             bitmask_tensor = self._grammar_bitmask[:batch_len]
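
For reference, the manager in this diff only talks to the new abstraction through a few calls: compile_grammar(request_type, grammar_spec) and allocate_token_bitmask(max_num_seqs) on the backend, plus is_terminated() and fill_bitmask(bitmask, batch_index) on the compiled grammar. The sketch below is not part of this commit: it is a hypothetical, minimal backend written against those calls to illustrate how another backend could plug in. The actual abstract definitions live in backend_types.py, which this diff does not show and which may require further methods (e.g. for accepting generated tokens); the bitmask layout here assumes xgrammar's 32-tokens-per-int32 convention.

import torch

from vllm.config import VllmConfig
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar)


class PermissiveGrammar(StructuredOutputGrammar):
    """Hypothetical grammar that never constrains sampling (illustration only)."""

    def is_terminated(self) -> bool:
        # A real grammar reports whether its matcher has reached a final state.
        return False

    def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
        # Mark every token as allowed in this request's row of the batch bitmask.
        bitmask[batch_index].fill_(-1)


class PermissiveBackend(StructuredOutputBackend):
    """Hypothetical backend exercising only the calls StructuredOutputManager makes."""

    def __init__(self, vllm_config: VllmConfig):
        self.vocab_size = vllm_config.model_config.get_vocab_size()

    def compile_grammar(self, request_type, grammar_spec) -> StructuredOutputGrammar:
        # A real backend compiles grammar_spec according to request_type
        # (JSON schema, regex, grammar, ...); this stub accepts anything.
        return PermissiveGrammar()

    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        # One int32 packs allow/deny bits for 32 tokens; -1 means "all allowed".
        num_words = (self.vocab_size + 31) // 32
        return torch.full((max_num_seqs, num_words), -1, dtype=torch.int32)

Note that grammar compilation is still dispatched through the ThreadPoolExecutor set up in __init__, so a backend's compile_grammar may be invoked from worker threads.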