diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py
index 4aae755eb4e4..ba773223f353 100644
--- a/benchmarks/benchmark_serving_structured_output.py
+++ b/benchmarks/benchmark_serving_structured_output.py
@@ -250,7 +250,7 @@ def _filter_func(item):
             )
             input_len = len(tokenizer(prompt).input_ids)
             completion = dataset["completion"][idx]
-
+
             requests.append(
                 SampleRequest(
                     prompt=prompt,
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index b5cd6c5c8af5..dbd70739f12b 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -151,7 +151,7 @@ class SchedulerOutput:
     # for filling the next token bitmask
     structured_output_request_ids: dict[str, int]
-    # the bitmask for the whole batch
-    grammar_bitmask: Optional[npt.NDArray[np.int32]]
+    # the EOS token id of each request in the batch, if it has one
+    eos_token_ids: dict[str, int]

     # KV Cache Connector metadata.
     kv_connector_metadata: Optional[KVConnectorMetadata] = None
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 8322fa7335b6..4833b332532b 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -193,6 +193,9 @@ def schedule(self) -> SchedulerOutput:
         # For logging.
         scheduled_timestamp = time.monotonic()

+        # For structured output.
+        structured_output_request_ids: dict[str, int] = {}
+
         # First, schedule the RUNNING requests.
         req_index = 0
         while req_index < len(self.running) and token_budget > 0:
@@ -343,13 +346,7 @@ def schedule(self) -> SchedulerOutput:
             # Skip request if the structured output request is still waiting
             # for FSM compilation.
             if request.status == RequestStatus.WAITING_FOR_FSM:
-                structured_output_req = request.structured_output_request
-                if structured_output_req and structured_output_req.grammar:
-                    request.status = RequestStatus.WAITING
-                else:
-                    self.waiting.pop_request()
-                    skipped_waiting_requests.prepend_request(request)
-                    continue
+                request.status = RequestStatus.WAITING

             # Check that adding the request still respects the max_loras
             # constraint.
@@ -559,9 +556,10 @@ def schedule(self) -> SchedulerOutput:
             scheduled_spec_decode_tokens,
             req_to_new_blocks,
         )
-        structured_output_request_ids, grammar_bitmask = (
-            self.get_grammar_bitmask(self.running,
-                                     scheduled_spec_decode_tokens))
+        eos_token_ids = {
+            req.request_id: req.eos_token_id
+            for req in self.running if req.eos_token_id is not None
+        }
+        structured_output_request_ids = self.get_structured_output_request_ids(
+            self.running)
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
             scheduled_cached_reqs=cached_reqs_data,
@@ -578,7 +576,7 @@ def schedule(self) -> SchedulerOutput:
             free_encoder_mm_hashes=self.encoder_cache_manager.
             get_freed_mm_hashes(),
             structured_output_request_ids=structured_output_request_ids,
-            grammar_bitmask=grammar_bitmask,
+            eos_token_ids=eos_token_ids,
         )

     # NOTE(Kuntai): this function is designed for multiple purposes:
@@ -807,17 +805,13 @@ def _try_schedule_encoder_inputs(
             encoder_compute_budget,
         )

-    def get_grammar_bitmask(
+    def get_structured_output_request_ids(
         self,
         requests: list[Request],
-        scheduled_spec_decode_tokens: dict[str, list[int]],
     ):
         # NOTE: structured_output_request_ids maps
         # a request's (request that uses structured output)
         # request_id to its index in the batch.
-        # This will help us determine to slice the grammar bitmask
-        # and only applies valid mask for requests that
-        # uses structured decoding.
         structured_output_request_ids: dict[str, int] = {}
         for i, req in enumerate(requests):
             if req.use_structured_output:
@@ -826,16 +820,7 @@ def get_grammar_bitmask(
                 # Therefore, we might introduce some additional
                 # cycle to fill in the bitmask, which could be a big no-op.
                 structured_output_request_ids[req.request_id] = i
-
-        if not structured_output_request_ids:
-            bitmask = None
-        else:
-            bitmask = self.structured_output_manager.grammar_bitmask(
-                self.requests,
-                structured_output_request_ids,
-                scheduled_spec_decode_tokens,
-            )
-        return structured_output_request_ids, bitmask
+        return structured_output_request_ids

     def update_from_output(
         self,
@@ -919,14 +904,6 @@ def update_from_output(
             # the outer lists can be of length > 1.
             new_logprobs = logprobs.slice(req_index, req_index + 1)
-
-            if new_token_ids and self.structured_output_manager.should_advance(
-                    request):
-                # NOTE: structured_output_request
-                # should not be None if use_structured_output, we have
-                # checked above, so safe to ignore type warning
-                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
-                    req_id, new_token_ids)
+
             if num_nans_in_logits is not None and req_id in num_nans_in_logits:
                 request.num_nans_in_logits = num_nans_in_logits[req_id]
@@ -1052,10 +1029,8 @@ def update_draft_token_ids(
             if not spec_token_ids:
                 # NOTE(woosuk): request.spec_token_ids should be updated.
                 request.spec_token_ids.clear()
-            elif self.structured_output_manager.should_advance(request):
-                metadata = request.structured_output_request
-                request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
-                    spec_token_ids)
+            elif draft_token_ids.should_advance[req_id]:
+                request.spec_token_ids = draft_token_ids.spec_token_ids[req_id]
             else:
                 request.spec_token_ids = spec_token_ids
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index d7e9cfa3660b..e3780d66548d 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -447,13 +447,6 @@ def preprocess_add_request(
         req = Request.from_engine_core_request(request,
                                                self.request_block_hasher)

-        if req.use_structured_output:
-            # Note on thread safety: no race condition.
-            # `grammar_init` is only invoked in input processing thread. For
-            # `structured_output_manager`, each request is independent and
-            # grammar compilation is async. Scheduler always checks grammar
-            # compilation status before scheduling request.
-            self.structured_output_manager.grammar_init(req)

         return req, request.current_wave
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index f8d6b24702f3..e16e44ea4b68 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import NamedTuple, Optional

 import torch
@@ -122,6 +122,10 @@ class DraftTokenIds:
     # num_reqs x num_draft_tokens
     draft_token_ids: list[list[int]]

+    # Per-request structured-output results computed in the worker: whether
+    # the grammar may advance, and the grammar-validated draft tokens.
+    should_advance: dict[str, bool] = field(default_factory=dict)
+    spec_token_ids: dict[str, list[int]] = field(default_factory=dict)
+

 EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                               req_id_to_index={},
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py
index 57854cc11204..e1f1c6bc6ebd 100644
--- a/vllm/v1/structured_output/__init__.py
+++ b/vllm/v1/structured_output/__init__.py
@@ -15,6 +15,9 @@
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                      StructuredOutputGrammar)
 from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
+import time
+from vllm.v1.utils import ConstantList
+from vllm.v1.worker.gpu_input_batch import CachedRequestState

 if TYPE_CHECKING:
     import numpy as np
@@ -71,7 +74,7 @@ def __init__(self, vllm_config: VllmConfig):
                 reasoning_backend)
             self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

-    def grammar_init(self, request: Request) -> None:
+    def grammar_init(self, request: CachedRequestState) -> None:
         if request.structured_output_request is None:
             return
@@ -125,7 +128,7 @@ def grammar_init(self, request: Request) -> None:

     def _async_create_grammar(
         self,
-        request: Request,
+        request: CachedRequestState,
     ) -> StructuredOutputGrammar:
         key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]
@@ -161,7 +164,7 @@ def _async_submit_fill_bitmask(

     def grammar_bitmask(
         self,
-        requests: dict[str, Request],
+        requests: dict[str, CachedRequestState],
         structured_output_request_ids: dict[str, int],
         scheduled_spec_decode_tokens: dict[str, list[int]],
     ) -> Optional[npt.NDArray[np.int32]]:
@@ -230,6 +233,8 @@ def grammar_bitmask(
             assert structured_output_request is not None
+            # Grammar compilation is async; block until it is ready.
+            while structured_output_request.grammar is None:
+                time.sleep(0.0000001)
             assert structured_output_request.grammar is not None
             apply_bitmask = self.should_fill_bitmask(request)

             state_advancements = 0
             req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
@@ -256,7 +261,7 @@ def grammar_bitmask(
         # and deserialization when sending this to the GPU workers.
         return bitmask_tensor.numpy()

-    def should_fill_bitmask(self, request: Request) -> bool:
+    def should_fill_bitmask(self, request: CachedRequestState) -> bool:
         if self.reasoner is not None:
             assert request.structured_output_request is not None
             if request.structured_output_request.reasoning_ended is None:
@@ -265,9 +270,10 @@ def should_fill_bitmask(self, request: Request) -> bool:
                 return request.structured_output_request.reasoning_ended

         return True

-    def should_advance(self, request: Request) -> bool:
-        if not request.use_structured_output:
-            return False
+    def should_advance(self, request: CachedRequestState) -> bool:
+        # CachedRequestState has no `use_structured_output`; infer it from
+        # the sampling params instead.
+        if (request.sampling_params is None
+                or request.sampling_params.guided_decoding is None):
+            return False

         # To determine whether we can advance the FSM.
         # Supports thinking usage where we skip the reasoning components.
@@ -283,7 +289,8 @@ def should_advance(self, request: Request) -> bool:
                 return True

             # Check if reasoning ends in *this* step
-            if self.reasoner.is_reasoning_end(request.all_token_ids):
+            all_token_ids = ConstantList(request.prompt_token_ids +
+                                         request.output_token_ids)
+            if self.reasoner.is_reasoning_end(all_token_ids):
                 # Reasoning just ended, so we shouldn't advance til
                 # next pass
                 structured_req.reasoning_ended = True
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index ad70d9efaaaa..047af3a910d4 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -24,6 +24,7 @@
 from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
 from vllm.v1.utils import copy_slice
 from vllm.v1.worker.block_table import MultiGroupBlockTable
+from vllm.v1.structured_output.request import StructuredOutputRequest


 @dataclass
@@ -47,6 +48,12 @@ class CachedRequestState:
     lora_request: Optional[LoRARequest] = None

+    structured_output_request: Optional[StructuredOutputRequest] = None
+
+    # EOS token id for this request, and whether an EOS has been sampled.
+    requests_stop_id: Optional[int] = None
+    stop: bool = False
+
     def __post_init__(self):
         self.num_prompt_tokens = len(self.prompt_token_ids)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 42baf020e9dc..0ed09543e9fb 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -11,6 +11,7 @@
 from typing import TYPE_CHECKING, Any, Optional, Union, cast

 import numpy as np
+import numpy.typing as npt
 import torch
 import torch.distributed
 import torch.nn as nn
@@ -83,6 +84,8 @@
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
+from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.structured_output.request import StructuredOutputRequest

 from .utils import (AttentionGroup, MultiModalBudget,
                     add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
@@ -320,6 +323,7 @@ def __init__(
             dtype=torch.int64,
             device="cpu",
             pin_memory=self.pin_memory)
+        self.structured_output_manager = StructuredOutputManager(self.vllm_config)

     def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
         return CpuGpuBuffer(*args,
@@ -470,8 +474,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
                 lora_request=new_req_data.lora_request,
+                structured_output_request=StructuredOutputRequest(
+                    sampling_params=sampling_params),
+                requests_stop_id=scheduler_output.eos_token_ids.get(req_id),
             )
             self.requests[req_id] = req_state
+            if (req_state.sampling_params is not None
+                    and req_state.sampling_params.guided_decoding is not None):
+                self.structured_output_manager.grammar_init(req_state)

             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             if self.uses_mrope:
@@ -1275,8 +1286,9 @@ def apply_grammar_bitmask(
         self,
         scheduler_output: "SchedulerOutput",
         logits: torch.Tensor,
+        bitmask: Optional[npt.NDArray[np.int32]],
     ):
-        grammar_bitmask = scheduler_output.grammar_bitmask
+        grammar_bitmask = bitmask
         if grammar_bitmask is None:
             return
@@ -1575,6 +1587,11 @@ def execute_model(
                 inputs_embeds=inputs_embeds,
                 **model_kwargs,
             )
+            bitmask = self.structured_output_manager.grammar_bitmask(
+                self.requests,
+                scheduler_output.structured_output_request_ids,
+                scheduler_output.scheduled_spec_decode_tokens,
+            )

             if self.use_aux_hidden_state_outputs:
                 hidden_states, aux_hidden_states = model_output
@@ -1615,8 +1632,8 @@ def execute_model(
             logits = model_output_broadcast_data["logits"]

         # Apply structured output bitmasks if present
-        if scheduler_output.grammar_bitmask is not None:
-            self.apply_grammar_bitmask(scheduler_output, logits)
+        if bitmask is not None:
+            self.apply_grammar_bitmask(scheduler_output, logits, bitmask)

         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
@@ -1740,6 +1757,26 @@ def execute_model(
             )

         self.eplb_step()

+        # Advance each request's grammar with the sampled tokens, stopping at
+        # the first EOS token (unless ignore_eos is set).
+        for req_id in scheduler_output.num_scheduled_tokens:
+            req_idx = self.input_batch.req_id_to_index[req_id]
+            generated_token_ids = (valid_sampled_token_ids[req_idx]
+                                   if valid_sampled_token_ids else [])
+            req_state = self.requests[req_id]
+            eos_token_id = scheduler_output.eos_token_ids.get(req_id)
+            count = 0
+            for num_new, output_token_id in enumerate(generated_token_ids, 1):
+                if (req_state.sampling_params is not None
+                        and not req_state.sampling_params.ignore_eos
+                        and output_token_id == eos_token_id):
+                    count = num_new
+                    req_state.stop = True
+                    break
+            if (generated_token_ids
+                    and self.structured_output_manager.should_advance(req_state)
+                    and req_state.structured_output_request is not None
+                    and req_state.structured_output_request.grammar is not None):
+                grammar = req_state.structured_output_request.grammar
+                if count > 0:
+                    grammar.accept_tokens(req_id, generated_token_ids[:count])
+                else:
+                    grammar.accept_tokens(req_id, generated_token_ids)

         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
@@ -1761,7 +1798,28 @@ def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         else:
             draft_token_ids = self._draft_token_ids
             self._draft_token_ids = None
-        return DraftTokenIds(req_ids, draft_token_ids)
+
+        should_advance: dict[str, bool] = {}
+        spec_token_ids: dict[str, list[int]] = {}
+
+        # Validate the drafts against each request's grammar in the worker,
+        # truncating them at the first EOS token if one is drafted.
+        for i, req_id in enumerate(req_ids):
+            req_state = self.requests[req_id]
+            if (self.structured_output_manager.should_advance(req_state)
+                    and not req_state.stop):
+                should_advance[req_id] = True
+                count = 0
+                for num_new, output_token_id in enumerate(draft_token_ids[i], 1):
+                    if output_token_id == req_state.requests_stop_id:
+                        count = num_new
+                        break
+                structured_output_request = req_state.structured_output_request
+                if (structured_output_request is not None
+                        and structured_output_request.grammar is not None):
+                    grammar = structured_output_request.grammar
+                    if count == 0 or count == len(draft_token_ids[i]):
+                        spec_token_ids[req_id] = grammar.validate_tokens(
+                            draft_token_ids[i])
+                    else:
+                        spec_token_ids[req_id] = grammar.validate_tokens(
+                            draft_token_ids[i][:count])
+            else:
+                should_advance[req_id] = False
+        return DraftTokenIds(req_ids, draft_token_ids, should_advance,
+                             spec_token_ids)

     def propose_draft_token_ids(
         self,
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 5947b54d33ce..1d4a6f4733cb 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -55,6 +55,8 @@
     KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.structured_output.request import StructuredOutputRequest

 from .utils import (MultiModalBudget,
                     add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
                     sanity_check_mm_encoder_outputs)
@@ -303,6 +305,8 @@ def __init__(
         else:
             self.sample_from_logits_func = self.sample_from_logits

+        self.structured_output_manager = StructuredOutputManager(self.vllm_config)
+
     def _update_num_xla_graphs(self, case_str):
         check_comp = self.check_recompilation and not self.enforce_eager
         if not check_comp:
@@ -399,6 +403,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
                 lora_request=new_req_data.lora_request,
             )

+            sampling_params = self.requests[req_id].sampling_params
+            if (sampling_params is not None
+                    and sampling_params.guided_decoding is not None):
+                self.structured_output_manager.grammar_init(self.requests[req_id])
+
             req_ids_to_add.append(req_id)

         # Update the states of the running/resumed requests.
@@ -983,15 +991,21 @@ def execute_model(
                 positions=self.position_ids,
                 inputs_embeds=inputs_embeds,
             )
+
+            bitmask = self.structured_output_manager.grammar_bitmask(
+                self.requests,
+                scheduler_output.structured_output_request_ids,
+                scheduler_output.scheduled_spec_decode_tokens,
+            )
             hidden_states = self.select_hidden_states(hidden_states,
                                                       logits_indices)
             logits = self.compute_logits(hidden_states)
             tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
                 from_input_batch(self.input_batch, padded_num_reqs, self.device)
-            if scheduler_output.grammar_bitmask is not None:
+            if bitmask is not None:
                 require_struct_decoding, grammar_bitmask_padded, arange = \
                     self.prepare_structured_decoding_input(logits,
-                                                           scheduler_output)
+                                                           scheduler_output,
+                                                           bitmask)
                 logits = self.structured_decode(require_struct_decoding,
                                                 grammar_bitmask_padded,
                                                 logits, arange)
@@ -1113,6 +1127,26 @@ def concat_lists(input_lists):
                 finished_sending=finished_sending,
                 finished_recving=finished_recving,
             )

+        # Advance each request's grammar with the sampled tokens, stopping at
+        # the first EOS token (unless ignore_eos is set).
+        for req_id in scheduler_output.num_scheduled_tokens:
+            req_idx = self.input_batch.req_id_to_index[req_id]
+            generated_token_ids = (valid_sampled_token_ids[req_idx]
+                                   if valid_sampled_token_ids else [])
+            req_state = self.requests[req_id]
+            eos_token_id = scheduler_output.eos_token_ids.get(req_id)
+            count = 0
+            for num_new, output_token_id in enumerate(generated_token_ids, 1):
+                if (req_state.sampling_params is not None
+                        and not req_state.sampling_params.ignore_eos
+                        and output_token_id == eos_token_id):
+                    count = num_new
+                    req_state.stop = True
+                    break
+            if (generated_token_ids
+                    and self.structured_output_manager.should_advance(req_state)
+                    and req_state.structured_output_request is not None
+                    and req_state.structured_output_request.grammar is not None):
+                grammar = req_state.structured_output_request.grammar
+                if count > 0:
+                    grammar.accept_tokens(req_id, generated_token_ids[:count])
+                else:
+                    grammar.accept_tokens(req_id, generated_token_ids)

         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids,
@@ -1759,9 +1793,9 @@ def get_input_embeddings(self, *args, **kwargs):
         return self.model.get_input_embeddings(*args, **kwargs)

     def prepare_structured_decoding_input(
-        self, logits: torch.Tensor, scheduler_output: "SchedulerOutput"
+        self, logits: torch.Tensor, scheduler_output: "SchedulerOutput",
+        bitmask: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        grammar_bitmask = scheduler_output.grammar_bitmask
+        grammar_bitmask = bitmask
         assert grammar_bitmask is not None
         num_reqs, _ = logits.shape