Commit dd05b0f
[V1] Delay all xgrammar usage until needed
PR vllm-project#14575 delayed initialization of the grammar bitmask until it was needed, to try to fix a problem encountered on TPU systems. Unfortunately, that change was not sufficient: we need to delay usage of ALL xgrammar APIs, not just the grammar initialization. This change implements that. More initialization is now deferred until the first time a structured output request is received.

Signed-off-by: Russell Bryant <[email protected]>
1 parent 5305673
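For reference, here is a minimal sketch of the deferred-initialization pattern this commit applies. It is illustrative only: the names (LazyManager, handle_request) are invented for the sketch, and the json import merely stands in for a heavy dependency such as xgrammar.

    class LazyManager:
        """Defer expensive imports and setup until the first request needs them."""

        def __init__(self):
            # Construction stays cheap: no heavy imports, no allocations.
            self.init_complete = False

        def _delayed_init(self):
            # Runs at most once, on the first request that needs it.
            # json is a stand-in for a heavy dependency such as xgrammar.
            import json
            self._codec = json
            self.init_complete = True

        def handle_request(self, payload: str):
            if not self.init_complete:
                self._delayed_init()
            return self._codec.loads(payload)

Constructing a LazyManager does no heavy work; the first handle_request() call pays the one-time setup cost, and later calls skip it via the init_complete flag.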


vllm/v1/structured_output/__init__.py

Lines changed: 20 additions & 13 deletions
@@ -14,7 +14,6 @@
 if TYPE_CHECKING:
     import numpy as np
     import numpy.typing as npt
-    import torch
     import xgrammar as xgr
 
     from vllm.v1.request import Request
@@ -27,14 +26,18 @@
 class StructuredOutputManager:
 
     def __init__(self, vllm_config: VllmConfig):
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            parallel_config=vllm_config.parallel_config,
-            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
-        tokenizer_group.ping()
         self.vocab_size = vllm_config.model_config.get_vocab_size()
         self.vllm_config = vllm_config
+        self.init_complete = False
+
+    def _delayed_init(self):
+        """Initialization delayed until we know it is needed."""
+        tokenizer_group = init_tokenizer_from_configs(
+            model_config=self.vllm_config.model_config,
+            scheduler_config=self.vllm_config.scheduler_config,
+            parallel_config=self.vllm_config.parallel_config,
+            lora_config=self.vllm_config.lora_config)  # type: ignore[arg-type]
+        tokenizer_group.ping()
 
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         tokenizer_info = xgr.TokenizerInfo.from_huggingface(
@@ -47,12 +50,21 @@ def __init__(self, vllm_config: VllmConfig):
         # compilation, so we set it to half the number of CPUs.
         max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self._grammar_bitmask: Optional[torch.Tensor] = None
+        self._grammar_bitmask = xgr.allocate_token_bitmask(
+            self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
+
+        self.init_complete = True
 
     def grammar_init(self, request: Request) -> None:
         if request.structured_output_request is None:
             return
 
+        # The first time this is called, we need to finish initialization
+        # of xgrammar. We defer it to avoid the import of xgrammar and
+        # initialization cost if it is not going to be used.
+        if not self.init_complete:
+            self._delayed_init()
+
         grammar: Future[Grammar] = self.executor.submit(
             self._async_create_grammar, request)
         request.structured_output_request.grammar = grammar  # type: ignore[assignment]
@@ -100,11 +112,6 @@ def grammar_bitmask(
         if not structured_output_request_ids:
             return None
 
-        if self._grammar_bitmask is None:
-            self._grammar_bitmask = xgr.allocate_token_bitmask(
-                self.vllm_config.scheduler_config.max_num_seqs,
-                self.vocab_size)
-
         # Fill the bitmask using the index of each request equal to its
         # position in the batch. Resize the bitmask down to the size of
         # the batch.
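A note on the last hunk: the bitmask allocation moves from a lazy None-check inside grammar_bitmask() to a single eager call in _delayed_init(), so the per-step path no longer branches on whether the buffer exists. As a rough usage sketch of that call (the batch and vocabulary sizes below are made up, and the resulting shape assumes xgrammar's bit-packed layout of 32 tokens per int32 word):

    import xgrammar as xgr

    max_num_seqs = 8      # hypothetical scheduler batch size
    vocab_size = 32000    # hypothetical model vocabulary size

    # One row per sequence; each row packs one bit per vocabulary token.
    bitmask = xgr.allocate_token_bitmask(max_num_seqs, vocab_size)
    print(bitmask.shape)  # torch.Size([8, 1000]) under the assumed layout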
