|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +import time |
| 4 | +from concurrent.futures import Future |
| 5 | + |
| 6 | +import pytest |
3 | 7 | from transformers import AutoTokenizer |
4 | 8 |
|
5 | 9 | from vllm.config import StructuredOutputsConfig, VllmConfig |
6 | 10 | from vllm.config.model import ModelConfig |
| 11 | +from vllm.config.parallel import ParallelConfig |
7 | 12 | from vllm.config.speculative import SpeculativeConfig |
8 | 13 | from vllm.sampling_params import SamplingParams, StructuredOutputsParams |
9 | 14 | from vllm.v1.request import Request |
@@ -116,3 +121,72 @@ def grammar_bitmask(req: Request, tokens: list[int]) -> None: |
116 | 121 | ) # EOS not the final token |
117 | 122 | grammar_bitmask(request, prompt[i:]) # EOS not present |
118 | 123 | grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id]) |
| 124 | + |
| 125 | + |
@pytest.mark.parametrize("async_grammar", [True, False])
def test_grammar_init_async_and_sync(async_grammar):
    """Test grammar initialization works correctly in both async and sync modes.

    This test validates that the distributed_executor_backend config option
    correctly controls whether grammar compilation happens asynchronously
    (via executor.submit) or synchronously. When set to "external_launcher",
    grammar compilation is synchronous to avoid deadlocks.
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
    prompt = tokenizer.encode('{"a": "b"}')

    # Use "external_launcher" for sync mode, None for async mode
    executor_backend = None if async_grammar else "external_launcher"
    vllm_config = VllmConfig(
        model_config=ModelConfig(tokenizer=TOKENIZER),
        structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
        parallel_config=ParallelConfig(distributed_executor_backend=executor_backend),
    )
    structured_output_manager = StructuredOutputManager(vllm_config)

    sampling_params = SamplingParams(
        structured_outputs=StructuredOutputsParams(
            json='{"type": "object"}',
        ),
    )
    sampling_params.structured_outputs._backend = "guidance"

    request = Request(
        "test_request",
        prompt_token_ids=prompt,
        sampling_params=sampling_params,
        pooling_params=None,
        eos_token_id=tokenizer.eos_token_id,
    )

    structured_output_manager.grammar_init(request)

    # Check the internal _grammar type immediately after init.
    # Before _check_grammar_completion is called, async mode should have a
    # Future, while sync mode should have stored the compiled grammar directly.
    raw_grammar = request.structured_output_request._grammar
    if async_grammar:
        assert isinstance(raw_grammar, Future), (
            "Async mode should store a Future before completion"
        )
    else:
        assert not isinstance(raw_grammar, Future), (
            "Sync mode should store the grammar directly, not a Future"
        )

    # Wait for grammar to be ready (handles both async and sync cases).
    # Use a monotonic clock for the deadline: time.time() is wall-clock and
    # can jump (NTP adjustment, DST), which would silently shrink or stretch
    # the 5-second timeout budget.
    deadline = time.monotonic() + 5.0  # 5-second timeout
    while not request.structured_output_request._check_grammar_completion():
        if time.monotonic() > deadline:
            pytest.fail("Grammar compilation timed out")
        time.sleep(0.01)

    # After completion, _grammar should no longer be a Future
    assert not isinstance(request.structured_output_request._grammar, Future)

    # Verify grammar is properly initialized and functional
    grammar = request.structured_output_request.grammar
    assert grammar is not None
    assert not grammar.is_terminated()

    # Verify the grammar can accept valid tokens
    assert grammar.accept_tokens(request.request_id, prompt)
0 commit comments