benchmarks/benchmark_serving_structured_output.py (2 changes: 1 addition & 1 deletion)
@@ -250,7 +250,7 @@ def _filter_func(item):
)
input_len = len(tokenizer(prompt).input_ids)
completion = dataset["completion"][idx]

requests.append(
SampleRequest(
prompt=prompt,
vllm/v1/core/sched/output.py (2 changes: 1 addition & 1 deletion)
@@ -151,7 +151,7 @@ class SchedulerOutput:
# for filling the next token bitmask
structured_output_request_ids: dict[str, int]
# the bitmask for the whole batch
grammar_bitmask: Optional[npt.NDArray[np.int32]]
# the EOS token id of each request, keyed by request_id
eos_token_ids: dict[str, int]

# KV Cache Connector metadata.
kv_connector_metadata: Optional[KVConnectorMetadata] = None
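In effect, SchedulerOutput no longer carries a precomputed grammar bitmask; it ships only the per-request metadata the worker needs to build and apply the bitmask itself. A minimal sketch of just the two structured-output-related fields, with a hypothetical class name and defaults added for illustration (field names are the ones in this diff):

from dataclasses import dataclass, field

@dataclass
class StructuredOutputSchedulerFields:
    # request_id -> index of the request in the batch, for requests that
    # use structured output (used to pair bitmask rows with requests).
    structured_output_request_ids: dict[str, int] = field(default_factory=dict)
    # request_id -> EOS token id, so the worker can stop feeding tokens
    # into the grammar once EOS has been sampled.
    eos_token_ids: dict[str, int] = field(default_factory=dict)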
vllm/v1/core/sched/scheduler.py (51 changes: 13 additions & 38 deletions)
@@ -193,6 +193,9 @@ def schedule(self) -> SchedulerOutput:
# For logging.
scheduled_timestamp = time.monotonic()

# For structured output.
structured_output_request_ids: dict[str, int] = {}

# First, schedule the RUNNING requests.
req_index = 0
while req_index < len(self.running) and token_budget > 0:
@@ -343,13 +346,7 @@ def schedule(self) -> SchedulerOutput:
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
request.status = RequestStatus.WAITING

# Check that adding the request still respects the max_loras
# constraint.
@@ -559,9 +556,10 @@ def schedule(self) -> SchedulerOutput:
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
structured_output_request_ids, grammar_bitmask = (
self.get_grammar_bitmask(self.running,
scheduled_spec_decode_tokens))

eos_token_ids = {req.request_id: req.eos_token_id
for req in self.running if req.eos_token_id is not None}
structured_output_request_ids = self.get_structured_output_request_ids(
self.running)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
@@ -578,7 +576,7 @@ def schedule(self) -> SchedulerOutput:
free_encoder_mm_hashes=self.encoder_cache_manager.
get_freed_mm_hashes(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
eos_token_ids=eos_token_ids,
)

# NOTE(Kuntai): this function is designed for multiple purposes:
@@ -807,17 +805,13 @@ def _try_schedule_encoder_inputs(
encoder_compute_budget,
)

def get_grammar_bitmask(
def get_structured_output_request_ids(
self,
requests: list[Request],
scheduled_spec_decode_tokens: dict[str, list[int]],
):
# NOTE: structured_output_request_ids maps the request_id of each
# request that uses structured output to its index in the batch.
# This is later used to slice the grammar bitmask and apply a
# valid mask only to requests that use structured decoding.
structured_output_request_ids: dict[str, int] = {}
for i, req in enumerate(requests):
if req.use_structured_output:
@@ -826,16 +820,7 @@ def get_grammar_bitmask(
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[req.request_id] = i

if not structured_output_request_ids:
bitmask = None
else:
bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
scheduled_spec_decode_tokens,
)
return structured_output_request_ids, bitmask
return structured_output_request_ids
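To illustrate what this mapping is for (a hedged sketch, not code from this PR; the helper name and shapes are assumptions): the grammar bitmask has one row per structured-output request, laid out in batch order (ignoring the extra rows that speculative decoding adds), so the mapping lets a consumer pair each row with its request.

import numpy as np

def bitmask_rows_by_request(
    bitmask: np.ndarray,  # (num_structured_requests, mask_words), int32
    structured_output_request_ids: dict[str, int],  # request_id -> batch index
) -> dict[str, np.ndarray]:
    # Rows follow the batch order of the structured-output requests.
    ordered = sorted(structured_output_request_ids,
                     key=structured_output_request_ids.get)
    return {req_id: bitmask[row] for row, req_id in enumerate(ordered)}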

def update_from_output(
self,
@@ -919,14 +904,6 @@ def update_from_output(
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)

if new_token_ids and self.structured_output_manager.should_advance(
request):
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# checked above, so safe to ignore type warning
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids)

if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]

@@ -1052,10 +1029,8 @@ def update_draft_token_ids(
if not spec_token_ids:
# NOTE(woosuk): request.spec_token_ids should be updated.
request.spec_token_ids.clear()
elif self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
spec_token_ids)
elif draft_token_ids.should_advance[req_id]:
request.spec_token_ids = draft_token_ids.spec_token_ids[req_id]
else:
request.spec_token_ids = spec_token_ids

vllm/v1/engine/core.py (7 changes: 0 additions & 7 deletions)
@@ -447,13 +447,6 @@ def preprocess_add_request(

req = Request.from_engine_core_request(request,
self.request_block_hasher)
if req.use_structured_output:
# Note on thread safety: no race condition.
# `grammar_init` is only invoked in input processing thread. For
# `structured_output_manager`, each request is independent and
# grammar compilation is async. Scheduler always checks grammar
# compilation status before scheduling request.
self.structured_output_manager.grammar_init(req)
return req, request.current_wave


vllm/v1/outputs.py (6 changes: 5 additions & 1 deletion)
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import NamedTuple, Optional

import torch
@@ -122,6 +122,10 @@ class DraftTokenIds:
# num_reqs x num_draft_tokens
draft_token_ids: list[list[int]]

should_advance: dict[str, bool] = field(default_factory=dict)

spec_token_ids: dict[str, list[int]] = field(default_factory=dict)


EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
req_id_to_index={},
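A short usage sketch of the extended dataclass (request ids and token values are invented; field names follow this diff): the worker reports, per request, whether the grammar FSM should advance and which draft tokens survived grammar validation, so the scheduler no longer needs to consult the grammar itself.

from vllm.v1.outputs import DraftTokenIds

draft = DraftTokenIds(
    req_ids=["req-0", "req-1"],
    draft_token_ids=[[11, 12, 13], [21, 22]],
    should_advance={"req-0": True, "req-1": False},
    spec_token_ids={"req-0": [11, 12]},  # already validated against the grammar
)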
vllm/v1/structured_output/__init__.py (23 changes: 15 additions & 8 deletions)
@@ -15,6 +15,9 @@
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar)
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.utils import ConstantList
import time

if TYPE_CHECKING:
import numpy as np
@@ -71,7 +74,7 @@ def __init__(self, vllm_config: VllmConfig):
reasoning_backend)
self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

def grammar_init(self, request: Request) -> None:
def grammar_init(self, request: CachedRequestState) -> None:
if request.structured_output_request is None:
return

@@ -125,7 +128,7 @@ def grammar_init(self, request: Request) -> None:

def _async_create_grammar(
self,
request: Request,
request: CachedRequestState,
) -> StructuredOutputGrammar:
key = request.structured_output_request.structured_output_key # type: ignore[union-attr]

@@ -161,7 +164,7 @@ def _async_submit_fill_bitmask(

def grammar_bitmask(
self,
requests: dict[str, Request],
requests: dict[str, CachedRequestState],
structured_output_request_ids: dict[str, int],
scheduled_spec_decode_tokens: dict[str, list[int]],
) -> Optional[npt.NDArray[np.int32]]:
@@ -230,6 +233,8 @@ def grammar_bitmask(
assert structured_output_request is not None
assert structured_output_request.grammar is not None
apply_bitmask = self.should_fill_bitmask(request)
# Busy-wait until the asynchronously compiled grammar is ready.
while structured_output_request.grammar is None:
time.sleep(0.0000001)

state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
@@ -256,7 +261,7 @@ def grammar_bitmask(
# and deserialization when sending this to the GPU workers.
return bitmask_tensor.numpy()

def should_fill_bitmask(self, request: Request) -> bool:
def should_fill_bitmask(self, request: CachedRequestState) -> bool:
if self.reasoner is not None:
assert request.structured_output_request is not None
if request.structured_output_request.reasoning_ended is None:
Expand All @@ -265,9 +270,10 @@ def should_fill_bitmask(self, request: Request) -> bool:
return request.structured_output_request.reasoning_ended
return True

def should_advance(self, request: Request) -> bool:
if not request.use_structured_output:
return False
def should_advance(self, request: CachedRequestState) -> bool:
if request.sampling_params is not None:
if request.sampling_params.guided_decoding is None:
return False

# To determine whether we can advance the FSM.
# Supports thinking usage where we skip the reasoning components.
Expand All @@ -283,7 +289,8 @@ def should_advance(self, request: Request) -> bool:
return True

# Check if reasoning ends in *this* step
if self.reasoner.is_reasoning_end(request.all_token_ids):
all_token_ids = ConstantList(request.prompt_token_ids + request.output_token_ids)
if self.reasoner.is_reasoning_end(all_token_ids):
# Reasoning just ended, so we shouldn't advance til
# next pass
structured_req.reasoning_ended = True
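The rewritten guard changes how a structured-output request is recognized: CachedRequestState has no use_structured_output property, so the manager inspects the sampling params directly. The intended check as a standalone sketch (the helper name is an assumption):

def uses_structured_output(request) -> bool:
    # True only when the request carries a guided-decoding config.
    params = request.sampling_params
    return params is not None and params.guided_decoding is not None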
vllm/v1/worker/gpu_input_batch.py (7 changes: 7 additions & 0 deletions)
@@ -24,6 +24,7 @@
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.structured_output.request import StructuredOutputRequest


@dataclass
@@ -47,6 +48,12 @@ class CachedRequestState:

lora_request: Optional[LoRARequest] = None

structured_output_request: Optional[StructuredOutputRequest] = None

requests_stop_id: Optional[int] = None

stop: bool = False

def __post_init__(self):
self.num_prompt_tokens = len(self.prompt_token_ids)

vllm/v1/worker/gpu_model_runner.py (66 changes: 62 additions & 4 deletions)
@@ -11,6 +11,7 @@
from typing import TYPE_CHECKING, Any, Optional, Union, cast

import numpy as np
import numpy.typing as npt
import torch
import torch.distributed
import torch.nn as nn
@@ -83,6 +84,8 @@
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.request import StructuredOutputRequest

from .utils import (AttentionGroup, MultiModalBudget,
add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
@@ -320,6 +323,7 @@ def __init__(
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.structured_output_manager = StructuredOutputManager(self.vllm_config)

def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*args,
@@ -470,8 +474,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
lora_request=new_req_data.lora_request,
structured_output_request=StructuredOutputRequest(
sampling_params=sampling_params),
requests_stop_id=scheduler_output.eos_token_ids.get(req_id),

)
self.requests[req_id] = req_state
if self.requests[req_id].sampling_params is not None:
if self.requests[req_id].sampling_params.guided_decoding is not None:
self.structured_output_manager.grammar_init(self.requests[req_id])

# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
@@ -1275,8 +1286,9 @@ def apply_grammar_bitmask(
self,
scheduler_output: "SchedulerOutput",
logits: torch.Tensor,
bitmask: Optional[npt.NDArray[np.int32]],
):
grammar_bitmask = scheduler_output.grammar_bitmask
grammar_bitmask = bitmask
if grammar_bitmask is None:
return
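apply_grammar_bitmask now takes the bitmask as an argument instead of reading it from SchedulerOutput, since the worker computes it locally. For intuition, a dense-mask sketch of what applying a grammar bitmask to logits amounts to (the real path operates on a packed int32 bitmask; this helper is an illustration, not the PR's code):

import torch

def apply_allowed_mask(logits: torch.Tensor, allowed: torch.Tensor) -> torch.Tensor:
    # logits: (num_reqs, vocab_size); allowed: boolean mask of the same shape.
    # Tokens the grammar disallows get -inf so they can never be sampled.
    return logits.masked_fill(~allowed, float("-inf"))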

@@ -1575,6 +1587,11 @@ def execute_model(
inputs_embeds=inputs_embeds,
**model_kwargs,
)
bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
scheduler_output.structured_output_request_ids,
scheduler_output.scheduled_spec_decode_tokens,
)

if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
@@ -1615,8 +1632,8 @@ def execute_model(
logits = model_output_broadcast_data["logits"]

# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
self.apply_grammar_bitmask(scheduler_output, logits)
if bitmask is not None:
self.apply_grammar_bitmask(scheduler_output, logits, bitmask)

# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
@@ -1740,6 +1757,26 @@ def execute_model(
)

self.eplb_step()
for req_id, num_tokens_scheduled in scheduler_output.num_scheduled_tokens.items():
req_idx = self.input_batch.req_id_to_index[req_id]
generated_token_ids = valid_sampled_token_ids[req_idx] if valid_sampled_token_ids else []
# Find the first EOS token (if any) and mark the request as stopped.
count = 0
for num_new, output_token_id in enumerate(generated_token_ids, 1):
last_token_id = output_token_id
if self.requests[req_id].sampling_params is not None:
if not self.requests[req_id].sampling_params.ignore_eos \
and last_token_id == scheduler_output.eos_token_ids.get(req_id):
count = num_new
self.requests[req_id].stop = True
break
# Advance the grammar FSM with the accepted tokens, truncated at EOS.
if generated_token_ids and self.structured_output_manager.should_advance(self.requests[req_id]):
if self.requests[req_id].structured_output_request is not None:
if self.requests[req_id].structured_output_request.grammar is not None:
if count > 0:
self.requests[req_id].structured_output_request.grammar.accept_tokens(req_id, generated_token_ids[:count])
else:
self.requests[req_id].structured_output_request.grammar.accept_tokens(req_id, generated_token_ids)

return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
@@ -1761,7 +1798,28 @@ def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
else:
draft_token_ids = self._draft_token_ids
self._draft_token_ids = None
return DraftTokenIds(req_ids, draft_token_ids)

should_advance: dict[str, bool] = {}
spec_token_ids: dict[str, list[int]] = {}

for i, req_id in enumerate(req_ids):
if self.structured_output_manager.should_advance(self.requests[req_id]) and not self.requests[req_id].stop:
should_advance[req_id] = True
# Find the first stop token (if any) among the draft tokens.
count = 0
for num_new, output_token_id in enumerate(draft_token_ids[i], 1):
last_token_id = output_token_id
if last_token_id == self.requests[req_id].requests_stop_id:
count = num_new
break
# Validate the draft tokens against the grammar, truncated at the stop token.
if self.requests[req_id].structured_output_request is not None:
if self.requests[req_id].structured_output_request.grammar is not None:
if count == 0 or count == len(draft_token_ids[i]):
spec_token_ids[req_id] = self.requests[req_id].structured_output_request.grammar.validate_tokens(draft_token_ids[i])
else:
spec_token_ids[req_id] = self.requests[req_id].structured_output_request.grammar.validate_tokens(draft_token_ids[i][:count])
else:
should_advance[req_id] = False
return DraftTokenIds(req_ids, draft_token_ids, should_advance, spec_token_ids)
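The sampled-token loop in execute_model and the draft-token loop above apply the same truncation rule: keep tokens up to and including the first EOS/stop token and feed only that prefix to the grammar. A compact sketch of the rule (helper name assumed; the ignore_eos check is omitted for brevity):

from typing import Optional

def truncate_at_stop(token_ids: list[int], stop_id: Optional[int]) -> list[int]:
    # Keep tokens up to and including the first stop token, if one appears.
    if stop_id is None:
        return token_ids
    for i, tok in enumerate(token_ids):
        if tok == stop_id:
            return token_ids[:i + 1]
    return token_ids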

def propose_draft_token_ids(
self,