
Commit 65ee972

Authored by: alecsolder (Alec Solder), gemini-code-assist[bot], 22quinn
[BugFix] Adding env variable to disable async grammar compilation (#29996)
Signed-off-by: Alec Solder <[email protected]>
Signed-off-by: Alec S <[email protected]>
Co-authored-by: Alec Solder <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: 22quinn <[email protected]>
1 parent 62b3333 commit 65ee972
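The gating mechanism in this change is the engine's parallel config rather than a literal environment variable: async grammar compilation is skipped whenever distributed_executor_backend is set to "external_launcher". A hedged sketch of how a deployment would land on the new synchronous path (the model name and kwarg plumbing are illustrative assumptions; external_launcher is normally used under torchrun):

# Hedged sketch: with the external_launcher backend, structured-output
# grammars compile synchronously instead of via executor.submit.
# The model name below is an illustrative assumption.
from vllm import LLM
from vllm.sampling_params import SamplingParams, StructuredOutputsParams

llm = LLM(
    model="facebook/opt-125m",
    distributed_executor_backend="external_launcher",
)
params = SamplingParams(
    structured_outputs=StructuredOutputsParams(json='{"type": "object"}'),
)
outputs = llm.generate(["Return a JSON object:"], params)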

File tree

2 files changed: +89 -2 lines changed

tests/v1/structured_output/test_backend_guidance.py

Lines changed: 74 additions & 0 deletions
@@ -1,9 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import time
+from concurrent.futures import Future
+
+import pytest
 from transformers import AutoTokenizer
 
 from vllm.config import StructuredOutputsConfig, VllmConfig
 from vllm.config.model import ModelConfig
+from vllm.config.parallel import ParallelConfig
 from vllm.config.speculative import SpeculativeConfig
 from vllm.sampling_params import SamplingParams, StructuredOutputsParams
 from vllm.v1.request import Request
@@ -116,3 +121,72 @@ def grammar_bitmask(req: Request, tokens: list[int]) -> None:
     )  # EOS not the final token
     grammar_bitmask(request, prompt[i:])  # EOS not present
     grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])
+
+
+@pytest.mark.parametrize("async_grammar", [True, False])
+def test_grammar_init_async_and_sync(async_grammar):
+    """Test grammar initialization works correctly in both async and sync modes.
+
+    This test validates that the distributed_executor_backend config option
+    correctly controls whether grammar compilation happens asynchronously
+    (via executor.submit) or synchronously. When set to "external_launcher",
+    grammar compilation is synchronous to avoid deadlocks.
+    """
+    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
+    prompt = tokenizer.encode('{"a": "b"}')
+
+    # Use "external_launcher" for sync mode, None for async mode
+    executor_backend = None if async_grammar else "external_launcher"
+    vllm_config = VllmConfig(
+        model_config=ModelConfig(tokenizer=TOKENIZER),
+        structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
+        parallel_config=ParallelConfig(distributed_executor_backend=executor_backend),
+    )
+    structured_output_manager = StructuredOutputManager(vllm_config)
+
+    sampling_params = SamplingParams(
+        structured_outputs=StructuredOutputsParams(
+            json='{"type": "object"}',
+        ),
+    )
+    sampling_params.structured_outputs._backend = "guidance"
+
+    request = Request(
+        "test_request",
+        prompt_token_ids=prompt,
+        sampling_params=sampling_params,
+        pooling_params=None,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+
+    structured_output_manager.grammar_init(request)
+
+    # Check the internal _grammar type immediately after init.
+    # Before _check_grammar_completion is called, async mode should have a Future.
+    raw_grammar = request.structured_output_request._grammar
+    if async_grammar:
+        assert isinstance(raw_grammar, Future), (
+            "Async mode should store a Future before completion"
+        )
+    else:
+        assert not isinstance(raw_grammar, Future), (
+            "Sync mode should store the grammar directly, not a Future"
+        )
+
+    # Wait for grammar to be ready (handles both async and sync cases)
+    start_time = time.time()
+    while not request.structured_output_request._check_grammar_completion():
+        if time.time() - start_time > 5:  # 5-second timeout
+            pytest.fail("Grammar compilation timed out")
+        time.sleep(0.01)
+
+    # After completion, _grammar should no longer be a Future
+    assert not isinstance(request.structured_output_request._grammar, Future)
+
+    # Verify grammar is properly initialized and functional
+    grammar = request.structured_output_request.grammar
+    assert grammar is not None
+    assert not grammar.is_terminated()
+
+    # Verify the grammar can accept valid tokens
+    assert grammar.accept_tokens(request.request_id, prompt)
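
The new test can be run on its own; a hedged invocation, assuming pytest and the guidance backend are available in the dev environment:

pytest tests/v1/structured_output/test_backend_guidance.py -k test_grammar_init_async_and_sync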

vllm/v1/structured_output/__init__.py

Lines changed: 15 additions & 2 deletions
@@ -40,6 +40,16 @@ def __init__(self, vllm_config: VllmConfig):
         self.reasoner: ReasoningParser | None = None
         self.vllm_config = vllm_config
 
+        # When in external_launcher mode, async grammar compilation causes deadlocks
+        # due to external_launcher mode having a scheduler for each TP rank.
+        # Async grammar compilation causes the WAITING_FOR_FSM → WAITING transition to
+        # happen at different times on different TP ranks,
+        # breaking the determinism assumption that external_launcher relies on.
+        self._use_async_grammar_compilation = (
+            vllm_config.parallel_config.distributed_executor_backend
+            != "external_launcher"
+        )
+
         self._grammar_bitmask: torch.Tensor | None = None
         self._full_mask = torch.tensor(-1, dtype=torch.int32)
 
@@ -138,10 +148,13 @@ def grammar_init(self, request: Request) -> None:
         else:
             raise ValueError(f"Unsupported structured output backend: {backend}")
 
-        grammar = self.executor.submit(self._async_create_grammar, request)
+        if self._use_async_grammar_compilation:
+            grammar = self.executor.submit(self._create_grammar, request)
+        else:
+            grammar = self._create_grammar(request)  # type: ignore[assignment]
         request.structured_output_request.grammar = grammar  # type: ignore[assignment]
 
-    def _async_create_grammar(
+    def _create_grammar(
         self,
         request: Request,
     ) -> StructuredOutputGrammar:
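
The dispatch in grammar_init follows a standard store-a-Future-or-a-value pattern: the same attribute holds either a concurrent.futures.Future (async path) or the compiled grammar itself (sync path), and a later completion check swaps a finished Future for its result, which is what the test above asserts on via _check_grammar_completion. A minimal self-contained sketch of the pattern (compile_grammar and GrammarHolder are illustrative names, not vLLM APIs):

from concurrent.futures import Future, ThreadPoolExecutor


def compile_grammar(spec: str) -> str:
    # Stand-in for expensive grammar compilation.
    return f"compiled({spec})"


class GrammarHolder:
    """Holds either a Future (async path) or the compiled grammar (sync path)."""

    def __init__(self) -> None:
        self._grammar: Future | str | None = None

    def init(self, spec: str, executor: ThreadPoolExecutor | None) -> None:
        if executor is not None:
            # Async path: schedule compilation and store the Future.
            self._grammar = executor.submit(compile_grammar, spec)
        else:
            # Sync path (the external_launcher case): compile inline.
            self._grammar = compile_grammar(spec)

    def check_completion(self) -> bool:
        # Mirrors _check_grammar_completion: resolve a finished Future in place.
        if isinstance(self._grammar, Future):
            if not self._grammar.done():
                return False
            self._grammar = self._grammar.result()
        return self._grammar is not None


with ThreadPoolExecutor(max_workers=1) as pool:
    async_holder = GrammarHolder()
    async_holder.init('{"type": "object"}', pool)  # stores a Future
    sync_holder = GrammarHolder()
    sync_holder.init('{"type": "object"}', None)  # stores the value directly
    while not async_holder.check_completion():
        pass  # poll until the worker finishes, as the test does
    assert sync_holder.check_completion()

Keeping both paths behind one attribute means downstream code never needs to know which executor backend was configured.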
