diff --git a/tests/conftest.py b/tests/conftest.py index 91155a72b16c..41fda04a6c92 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,9 @@ from tblib import pickling_support +# Import fixture +from tests.v1.entrypoints.conftest import sample_json_schema # noqa + # ruff: noqa # Install support for pickling exceptions so that we can nicely propagate diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index fba577239682..92e3831b9c7a 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -337,8 +337,6 @@ def test_stop_via_update_from_output(): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) model_output = ModelRunnerOutput( @@ -385,8 +383,6 @@ def test_stop_via_update_from_output(): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) model_output = ModelRunnerOutput( @@ -431,8 +427,6 @@ def test_stop_via_update_from_output(): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) model_output = ModelRunnerOutput( @@ -472,8 +466,6 @@ def test_stop_via_update_from_output(): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) model_output = ModelRunnerOutput( @@ -1988,7 +1980,6 @@ def test_schedule_skip_tokenizer_init(): scheduler.add_request(request) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) - assert output.grammar_bitmask is None def test_schedule_skip_tokenizer_init_structured_output_request(): diff --git a/tests/v1/e2e/test_async_sched_and_preempt.py b/tests/v1/e2e/test_async_scheduling.py similarity index 91% rename from tests/v1/e2e/test_async_sched_and_preempt.py rename to tests/v1/e2e/test_async_scheduling.py index 15a1cc255817..444afd5196dd 100644 --- a/tests/v1/e2e/test_async_sched_and_preempt.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -7,6 +7,7 @@ from vllm import SamplingParams from vllm.logprobs import Logprob +from vllm.sampling_params import StructuredOutputsParams from ...conftest import VllmRunner from ...models.utils import check_outputs_equal @@ -15,9 +16,12 @@ @dynamo_config.patch(cache_size_limit=16) -def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch): +def test_preempt_and_async_scheduling_e2e( + sample_json_schema, monkeypatch: pytest.MonkeyPatch +): """Test consistency of combos of async scheduling, preemption, - uni/multiproc executor, and various sampling parameters.""" + uni/multiproc executor, and various sampling parameters + including structured outputs.""" first_prompt = ( "The following numbers of the sequence " @@ -35,6 +39,12 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch): dict(bad_words=["the", " the"]), dict(logprobs=2), dict(logprobs=2, presence_penalty=-1.0), + dict(structured_outputs=StructuredOutputsParams(json=sample_json_schema)), + dict( + structured_outputs=StructuredOutputsParams(json=sample_json_schema), + logprobs=2, + presence_penalty=-1.0, + ), ] default_params = dict( diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index becedb59f644..534b60312fd1 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -248,7 +248,7 @@ def execute_model( 
self, scheduler_output, non_block=False, - ) -> Future[ModelRunnerOutput]: + ) -> Future[ModelRunnerOutput | None]: """Make execute_model non-blocking.""" # DummyExecutor used only for testing async case. @@ -263,6 +263,23 @@ def _execute(): # Use the thread pool instead of creating a new thread return self.thread_pool.submit(_execute) + def sample_tokens( + self, grammar_output, non_block=False + ) -> Future[ModelRunnerOutput]: + """Make sample_tokens non-blocking.""" + + # DummyExecutor used only for testing async case. + assert non_block + + def _execute(): + output = self.collective_rpc("sample_tokens", args=(grammar_output,)) + # Make a copy because output[0] may be reused + # by the next batch. + return copy.deepcopy(output[0]) + + # Use the thread pool instead of creating a new thread + return self.thread_pool.submit(_execute) + @property def max_concurrent_batches(self) -> int: return 2 diff --git a/tests/v1/executor/test_executor.py b/tests/v1/executor/test_executor.py index 7293ad09a717..56574124b272 100644 --- a/tests/v1/executor/test_executor.py +++ b/tests/v1/executor/test_executor.py @@ -31,7 +31,9 @@ def collective_rpc( # Drop marker to show that this was run with open(".marker", "w"): ... - return super().collective_rpc(method, timeout, args, kwargs) + return super().collective_rpc( + method, timeout, args, kwargs, non_block, unique_reply_rank + ) CustomMultiprocExecutorAsync = CustomMultiprocExecutor diff --git a/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py b/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py index b5c8f378be18..d0a6eeae6286 100644 --- a/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py +++ b/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py @@ -26,8 +26,6 @@ def _make_empty_scheduler_output(): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, kv_connector_metadata=SharedStorageConnectorMetadata(), ) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 44d8b3e331fd..1f3fdafc644d 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -981,9 +981,7 @@ def test_scheduler_kv_connector_stats_aggregation(): scheduled_encoder_inputs={}, num_common_prefix_blocks=[0], finished_req_ids=set(), - free_encoder_mm_hashes=set(), - structured_output_request_ids={}, - grammar_bitmask=None, + free_encoder_mm_hashes=[], ) engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output) diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 18aa599f1aaf..7b3a07b4e12a 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -92,8 +92,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) @@ -171,8 +169,6 @@ def test_update_states_request_finished(model_runner): num_common_prefix_blocks=[], finished_req_ids={req_id}, free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) model_runner._update_states(scheduler_output) @@ -201,8 +197,6 @@ def test_update_states_request_resumed(model_runner): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - 
grammar_bitmask=None, ) model_runner._update_states(scheduler_output) @@ -230,8 +224,6 @@ def test_update_states_request_resumed(model_runner): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) model_runner._update_states(scheduler_output) @@ -261,8 +253,6 @@ def test_update_states_no_changes(model_runner): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) model_runner._update_states(scheduler_output) @@ -296,8 +286,6 @@ def test_update_states_request_unscheduled(model_runner): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) model_runner._update_states(scheduler_output) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 9007436350be..a3cadb6a7308 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -150,8 +150,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) @@ -216,8 +214,6 @@ def test_update_states_request_finished(model_runner, dist_init): num_common_prefix_blocks=[], finished_req_ids={req_id}, free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) metadata_before = model_runner.input_batch.sampling_metadata @@ -248,8 +244,6 @@ def test_update_states_request_resumed(model_runner, dist_init): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) model_runner._update_states(scheduler_output) @@ -277,8 +271,6 @@ def test_update_states_request_resumed(model_runner, dist_init): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) metadata_before = model_runner.input_batch.sampling_metadata @@ -370,8 +362,6 @@ def test_update_states_no_changes(model_runner, dist_init): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) metadata_before = model_runner.input_batch.sampling_metadata @@ -407,8 +397,6 @@ def test_update_states_request_unscheduled(model_runner, dist_init): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids=[], - grammar_bitmask=None, ) metadata_before = model_runner._update_states(scheduler_output) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 22af489a89b9..7464f8469c3b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -6,7 +6,7 @@ from collections.abc import Sequence from concurrent.futures import CancelledError, Future -from typing import TYPE_CHECKING, Literal, cast +from typing import TYPE_CHECKING, Literal import torch @@ -138,8 +138,11 @@ def from_connector(cls, connector: "KVConnectorBase", world_size: int): return cls(connector.get_finished_count() or world_size) def aggregate( - self, outputs: list[ModelRunnerOutput], output_rank: int = 0 - ) -> ModelRunnerOutput: + self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0 + ) -> 
ModelRunnerOutput | None: + if not outputs[output_rank]: + return None + # Aggregate kv_connector_output from all workers def update_finished_set( @@ -161,6 +164,7 @@ def update_finished_set( aggregated_kv_connector_stats = None invalid_block_ids = set[int]() for model_runner_output in outputs: + assert model_runner_output is not None kv_output = model_runner_output.kv_connector_output if not kv_output: continue @@ -204,6 +208,7 @@ def update_finished_set( # select output of the worker specified by output_rank output = outputs[output_rank] + assert output is not None output.kv_connector_output = KVConnectorOutput( finished_sending=finished_sending or None, finished_recving=finished_recving or None, @@ -215,13 +220,16 @@ def update_finished_set( return output def async_aggregate( - self, output_futures: Sequence[Future[ModelRunnerOutput]], output_rank: int = 0 - ) -> Future[ModelRunnerOutput]: + self, + output_futures: Sequence[Future[ModelRunnerOutput | None]], + output_rank: int = 0, + ) -> Future[ModelRunnerOutput | None]: """Takes a list of futures and returns a single future which resolves to the respective list of outputs.""" - result_future: Future[ModelRunnerOutput] = Future() + result_future: Future[ModelRunnerOutput | None] = Future() outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures) + remaining = len(output_futures) def make_callback(idx): def callback(fut): @@ -236,12 +244,10 @@ def callback(fut): result_future.set_exception(e) # this check assumes io_thread_pool uses a single thread - if all(outputs): - result_future.set_result( - self.aggregate( - cast(list[ModelRunnerOutput], outputs), output_rank - ) - ) + nonlocal remaining + remaining -= 1 + if not remaining: + result_future.set_result(self.aggregate(outputs, output_rank)) return callback diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index da6e4aa2996b..0ad994c360b0 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -15,8 +15,12 @@ def _update_after_schedule( scheduler_output: SchedulerOutput, ) -> None: super()._update_after_schedule(scheduler_output) + pending_structured_output_tokens = False for req_id in scheduler_output.num_scheduled_tokens: request = self.requests[req_id] + pending_structured_output_tokens |= ( + request.use_structured_output and request.num_output_placeholders > 0 + ) if ( request.num_computed_tokens == request.num_tokens + request.num_output_placeholders @@ -25,6 +29,10 @@ def _update_after_schedule( # TODO(woosuk): Support speculative decoding. 
request.num_output_placeholders += 1 + scheduler_output.pending_structured_output_tokens = ( + pending_structured_output_tokens + ) + def _update_request_with_output( self, request: Request, diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index c36483203343..291d33c9bf98 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 - from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.engine import EngineCoreOutputs from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput @@ -40,6 +40,12 @@ def schedule(self) -> "SchedulerOutput": """ raise NotImplementedError + @abstractmethod + def get_grammar_bitmask( + self, scheduler_output: "SchedulerOutput" + ) -> "GrammarOutput | None": + raise NotImplementedError + @abstractmethod def update_from_output( self, diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index cc6b89e2bf3f..866136648bcb 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -181,12 +181,17 @@ class SchedulerOutput: # freed from the encoder cache. free_encoder_mm_hashes: list[str] - # ids of structured outputs requests included in the bitmask, in the - # same order as the corresponding stacked rows of the bitmask. - # There may be more than one row per request in the case of speculative decoding. - structured_output_request_ids: list[str] - # the bitmask for the whole batch - grammar_bitmask: "npt.NDArray[np.int32] | None" + # Whether the scheduled requests have all the output tokens they + # need to perform grammar bitmask computation. + pending_structured_output_tokens: bool = False # KV Cache Connector metadata. kv_connector_metadata: KVConnectorMetadata | None = None + + +@dataclass +class GrammarOutput: + # ids of structured output requests. + structured_output_request_ids: list[str] + # Bitmask ordered as structured_output_request_ids. 
+ grammar_bitmask: "npt.NDArray[np.int32]" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 98c8f08b0aae..f51744eb2640 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -5,7 +5,7 @@ import time from collections import defaultdict from collections.abc import Iterable -from typing import TYPE_CHECKING, Any +from typing import Any from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch @@ -24,7 +24,12 @@ ) from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface -from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput +from vllm.v1.core.sched.output import ( + CachedRequestData, + GrammarOutput, + NewRequestData, + SchedulerOutput, +) from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs @@ -35,10 +40,6 @@ from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager -if TYPE_CHECKING: - import numpy as np - import numpy.typing as npt - logger = init_logger(__name__) @@ -619,9 +620,6 @@ def schedule(self) -> SchedulerOutput: scheduled_spec_decode_tokens, req_to_new_blocks, ) - structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask( - num_scheduled_tokens.keys(), scheduled_spec_decode_tokens - ) # Record the request ids that were scheduled in this step. self.prev_step_scheduled_req_ids.clear() @@ -641,8 +639,6 @@ def schedule(self) -> SchedulerOutput: # the previous and the current steps. finished_req_ids=self.finished_req_ids, free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), - structured_output_request_ids=structured_output_request_ids, - grammar_bitmask=grammar_bitmask, ) # NOTE(Kuntai): this function is designed for multiple purposes: @@ -872,9 +868,8 @@ def _try_schedule_encoder_inputs( def get_grammar_bitmask( self, - scheduled_request_ids: Iterable[str], - scheduled_spec_decode_tokens: dict[str, list[int]], - ) -> tuple[list[str], "npt.NDArray[np.int32] | None"]: + scheduler_output: SchedulerOutput, + ) -> GrammarOutput | None: # Collect list of scheduled request ids that use structured output. # The corresponding rows of the bitmask will be in this order. # PERF: in case of chunked prefill, @@ -883,18 +878,18 @@ def get_grammar_bitmask( # cycle to fill in the bitmask, which could be a big no-op. 
structured_output_request_ids = [ req_id - for req_id in scheduled_request_ids + for req_id in scheduler_output.num_scheduled_tokens if (req := self.requests.get(req_id)) and req.use_structured_output ] if not structured_output_request_ids: - return structured_output_request_ids, None + return None bitmask = self.structured_output_manager.grammar_bitmask( self.requests, structured_output_request_ids, - scheduled_spec_decode_tokens, + scheduler_output.scheduled_spec_decode_tokens, ) - return structured_output_request_ids, bitmask + return GrammarOutput(structured_output_request_ids, bitmask) def update_from_output( self, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index bfe87b718282..78af197821e2 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -12,7 +12,7 @@ from contextlib import ExitStack, contextmanager from inspect import isclass, signature from logging import DEBUG -from typing import Any, TypeVar +from typing import Any, TypeVar, cast import msgspec import zmq @@ -334,9 +334,12 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: if not self.scheduler.has_requests(): return {}, False scheduler_output = self.scheduler.schedule() - + future = self.model_executor.execute_model(scheduler_output, non_block=True) + grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) with self.log_error_detail(scheduler_output): - model_output = self.model_executor.execute_model(scheduler_output) + model_output = future.result() + if model_output is None: + model_output = self.model_executor.sample_tokens(grammar_output) engine_core_outputs = self.scheduler.update_from_output( scheduler_output, model_output @@ -376,20 +379,47 @@ def step_with_batch_queue( assert len(batch_queue) < self.batch_queue_size model_executed = False + deferred_scheduler_output = None if self.scheduler.has_requests(): scheduler_output = self.scheduler.schedule() - future = self.model_executor.execute_model(scheduler_output, non_block=True) - batch_queue.appendleft((future, scheduler_output)) - + exec_future = self.model_executor.execute_model( + scheduler_output, non_block=True + ) model_executed = scheduler_output.total_num_scheduled_tokens > 0 - if ( - model_executed - and len(batch_queue) < self.batch_queue_size - and not batch_queue[-1][0].done() - ): - # Don't block on next worker response unless the queue is full - # or there are no more requests to schedule. - return None, True + + if scheduler_output.pending_structured_output_tokens: + # We need to defer sampling until we have processed the model output + # from the prior step. + deferred_scheduler_output = scheduler_output + # Block-wait for execute to return (continues running async on the GPU). + with self.log_error_detail(scheduler_output): + exec_result = exec_future.result() + assert exec_result is None + else: + # We aren't waiting for any tokens, get any grammar output immediately. + grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) + # Block-wait for execute to return (continues running async on the GPU). + with self.log_error_detail(scheduler_output): + exec_result = exec_future.result() + + if exec_result is None: + # Call sample tokens. + future = self.model_executor.sample_tokens( + grammar_output, non_block=True + ) + else: + # No sampling required (e.g. all requests finished). + future = cast(Future[ModelRunnerOutput], exec_future) + # Add this step's future to the queue. 
+ batch_queue.appendleft((future, scheduler_output)) + if ( + model_executed + and len(batch_queue) < self.batch_queue_size + and not batch_queue[-1][0].done() + ): + # Don't block on next worker response unless the queue is full + # or there are no more requests to schedule. + return None, True elif not batch_queue: # Queue is empty. We should not reach here since this method should @@ -405,6 +435,19 @@ def step_with_batch_queue( engine_core_outputs = self.scheduler.update_from_output( scheduler_output, model_output ) + + # NOTE(nick): We can either handle the deferred tasks here or save + # in a field and do it immediately once step_with_batch_queue is + # re-called. The latter slightly favors TTFT over TPOT/throughput. + if deferred_scheduler_output: + # We now have the tokens needed to compute the bitmask for the + # deferred request. Get the bitmask and call sample tokens. + grammar_output = self.scheduler.get_grammar_bitmask( + deferred_scheduler_output + ) + future = self.model_executor.sample_tokens(grammar_output, non_block=True) + batch_queue.appendleft((future, deferred_scheduler_output)) + return engine_core_outputs, model_executed def shutdown(self): diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index ef7840e1796f..d76c6107ad2b 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -16,7 +16,7 @@ from vllm.lora.request import LoRARequest from vllm.tasks import SupportedTask from vllm.utils.import_utils import resolve_obj_by_qualname -from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput @@ -187,28 +187,44 @@ def get_kv_connector_handshake_metadata( @overload def execute_model( - self, - scheduler_output: SchedulerOutput, - non_block: Literal[False] = False, - ) -> ModelRunnerOutput: + self, scheduler_output: SchedulerOutput, non_block: Literal[False] = False + ) -> ModelRunnerOutput | None: pass @overload def execute_model( - self, - scheduler_output: SchedulerOutput, - non_block: Literal[True] = True, - ) -> Future[ModelRunnerOutput]: + self, scheduler_output: SchedulerOutput, non_block: Literal[True] = True + ) -> Future[ModelRunnerOutput | None]: pass def execute_model( self, scheduler_output: SchedulerOutput, non_block: bool = False - ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: + ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: output = self.collective_rpc( # type: ignore[call-overload] "execute_model", args=(scheduler_output,), non_block=non_block ) return output[0] + @overload + def sample_tokens( + self, grammar_output: GrammarOutput | None, non_block: Literal[False] = False + ) -> ModelRunnerOutput: + pass + + @overload + def sample_tokens( + self, grammar_output: GrammarOutput | None, non_block: Literal[True] = True + ) -> Future[ModelRunnerOutput]: + pass + + def sample_tokens( + self, grammar_output: GrammarOutput | None, non_block: bool = False + ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: + output = self.collective_rpc( # type: ignore[call-overload] + "sample_tokens", args=(grammar_output,), non_block=non_block + ) + return output[0] + def execute_dummy_batch(self) -> None: self.collective_rpc("execute_dummy_batch") diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 
4c58d5771c39..999a3ba870ea 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -46,7 +46,7 @@ get_mp_context, set_process_title, ) -from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput from vllm.v1.worker.worker_base import WorkerWrapperBase @@ -132,15 +132,12 @@ def _init_executor(self) -> None: uw.death_writer.close() self._ensure_worker_termination([uw.proc for uw in unready_workers]) - # For pipeline parallel, we use a thread pool for asynchronous - # execute_model. - if self.max_concurrent_batches > 1: - # Note: must use only 1 IO thread to keep dequeue sequence - # from the response queue - # _async_aggregate_workers_output also assumes a single IO thread - self.io_thread_pool = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="mp_exec_io" - ) + # Note: must use only 1 IO thread to keep dequeue sequence + # from the response queue. + # _async_aggregate_workers_output also assumes a single IO thread. + self.io_thread_pool = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="mp_exec_io" + ) self.output_rank = self._get_output_rank() self.has_connector = self.vllm_config.kv_transfer_config is not None @@ -180,15 +177,27 @@ def register_failure_callback(self, callback: FailureCallback): self.failure_callback = callback def execute_model( # type: ignore[override] - self, - scheduler_output: SchedulerOutput, - non_block: bool = False, + self, scheduler_output: SchedulerOutput, non_block: bool = False + ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: + return self._execute_with_aggregation( + "execute_model", scheduler_output, non_block=non_block + ) + + def sample_tokens( # type: ignore[override] + self, grammar_output: GrammarOutput | None, non_block: bool = False ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: + return self._execute_with_aggregation( # type: ignore[return-value] + "sample_tokens", grammar_output, non_block=non_block + ) + + def _execute_with_aggregation( + self, method: str, *args, non_block: bool = False + ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: if not self.has_connector: # get output only from a single worker (output_rank) (output,) = self.collective_rpc( - "execute_model", - args=(scheduler_output,), + method, + args=args, unique_reply_rank=self.output_rank, non_block=non_block, timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, @@ -197,8 +206,8 @@ def execute_model( # type: ignore[override] # get output from all workers outputs = self.collective_rpc( - "execute_model", - args=(scheduler_output,), + method, + args=args, non_block=non_block, timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, ) diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index a4823acc8764..4a69cca723ac 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -19,7 +19,7 @@ get_ip, get_open_port, ) -from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.executor.abstract import Executor from vllm.v1.executor.ray_utils import ( @@ -41,6 +41,9 @@ logger = init_logger(__name__) +COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future() 
+COMPLETED_NONE_FUTURE.set_result(None) + @dataclass class RayWorkerMetaData: @@ -96,6 +99,8 @@ def _init_executor(self) -> None: # KV connector setup self.has_connector = self.vllm_config.kv_transfer_config is not None + self.scheduler_output: SchedulerOutput | None = None + @property def max_concurrent_batches(self) -> int: """Ray distributed executor supports pipeline parallelism, @@ -381,22 +386,46 @@ def reinitialize_distributed( self.shutdown() def execute_model( # type: ignore[override] - self, scheduler_output: SchedulerOutput, non_block: bool = False + self, + scheduler_output: SchedulerOutput, + non_block: bool = False, + ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: + if self.scheduler_output is not None: + raise RuntimeError( + "State error: sample_tokens() must be called " + "after execute_model() returns None." + ) + self.scheduler_output = scheduler_output + return COMPLETED_NONE_FUTURE if non_block else None + + def sample_tokens( # type: ignore[override] + self, + grammar_output: "GrammarOutput | None", + non_block: bool = False, ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: """Execute the model on the Ray workers. + The scheduler output to use should have been provided in + a prior call to execute_model(). + Args: - scheduler_output: The scheduler output to execute. + grammar_output: The structured outputs grammar bitmask, if applicable. non_block: If True, the method will return a Future. Returns: The model runner output. """ + scheduler_output = self.scheduler_output + if scheduler_output is None: + return None # noqa + + self.scheduler_output = None + # Build the compiled DAG for the first time. if self.forward_dag is None: # type: ignore self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) - refs = self.forward_dag.execute(scheduler_output) # type: ignore + refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore if not self.has_connector: # Get output only from a single worker (output_rank) diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py index 9385e55b066f..a282cdc9909d 100644 --- a/vllm/v1/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -19,7 +19,7 @@ from vllm.v1.worker.worker_base import WorkerWrapperBase if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput logger = init_logger(__name__) @@ -82,36 +82,41 @@ def setup_device_if_necessary(self): def execute_model_ray( self, - scheduler_output: Union[ - "SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"] - ], + execute_model_input: tuple["SchedulerOutput", "GrammarOutput"] + | tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"], ) -> Union[ - "ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"] + "ModelRunnerOutput", + tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"], ]: # This method is used by Ray Compiled Graph to execute the model, # and it needs a special logic of self.setup_device_if_necessary() self.setup_device_if_necessary() assert self.worker is not None, "Worker is not initialized" - if isinstance(scheduler_output, tuple): - scheduler_output, intermediate_tensors = scheduler_output + if len(execute_model_input) == 3: + scheduler_output, grammar_output, intermediate_tensors = ( + execute_model_input + ) else: - scheduler_output, intermediate_tensors = scheduler_output, None + scheduler_output, grammar_output = 
execute_model_input + intermediate_tensors = None assert self.worker.model_runner is not None output = self.worker.model_runner.execute_model( scheduler_output, intermediate_tensors ) if isinstance(output, IntermediateTensors): - output = scheduler_output, output + output = scheduler_output, grammar_output, output elif not get_pp_group().is_last_rank: # Case where there are no scheduled requests # but may still be finished requests. assert not output or not output.req_ids - output = scheduler_output, None - # Ensure outputs crossing Ray compiled DAG are serializable. - # AsyncModelRunnerOutput holds CUDA events and cannot be - # pickled. - if isinstance(output, AsyncModelRunnerOutput): - output = output.get_output() + output = scheduler_output, grammar_output, None + elif output is None: + output = self.worker.model_runner.sample_tokens(grammar_output) + # Ensure outputs crossing Ray compiled DAG are serializable. + # AsyncModelRunnerOutput holds CUDA events and cannot be + # pickled. + if isinstance(output, AsyncModelRunnerOutput): + output = output.get_output() return output def override_env_vars(self, vars: dict[str, str]): diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index ef9bae2367be..d2d14fcfc436 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -16,6 +16,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.utils.import_utils import LazyLoader +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput if TYPE_CHECKING: import outlines_core as oc @@ -24,7 +25,6 @@ import xgrammar as xgr from vllm.transformers_utils.tokenizer import AnyTokenizer - from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch else: xgr = LazyLoader("xgr", globals(), "xgrammar") @@ -47,6 +47,7 @@ def apply_grammar_bitmask( scheduler_output: SchedulerOutput, + grammar_output: GrammarOutput, input_batch: InputBatch, logits: torch.Tensor, ) -> None: @@ -58,9 +59,9 @@ def apply_grammar_bitmask( input_batch (InputBatch): The input of model runner. logits (torch.Tensor): The output logits of model forward. """ - grammar_bitmask = scheduler_output.grammar_bitmask - if grammar_bitmask is None: - return + # Serialization of np.ndarray is much more efficient than a tensor, + # so we receive it in that format. + grammar_bitmask = grammar_output.grammar_bitmask # We receive the structured output bitmask from the scheduler, # compacted to contain bitmasks only for structured output requests. @@ -79,7 +80,7 @@ def apply_grammar_bitmask( cumulative_offset += len( scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) ) - if req_id in scheduler_output.structured_output_request_ids: + if req_id in grammar_output.structured_output_request_ids: struct_out_req_batch_indices[req_id] = logit_index out_indices = [] @@ -91,7 +92,7 @@ def apply_grammar_bitmask( dtype=grammar_bitmask.dtype, ) cumulative_index = 0 - for req_id in scheduler_output.structured_output_request_ids: + for req_id in grammar_output.structured_output_request_ids: num_spec_tokens = len( scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) ) @@ -101,22 +102,28 @@ def apply_grammar_bitmask( sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i] out_indices.append(logit_index + i) cumulative_index += 1 + num_spec_tokens - grammar_bitmask = sorted_bitmask + + # Copy async to device as tensor. 
+ grammar_bitmask = torch.from_numpy(sorted_bitmask).to( + logits.device, non_blocking=True + ) # If the length of out indices and the logits have the same shape # we don't need to pass indices to the kernel, # since the bitmask is already aligned with the logits. skip_out_indices = len(out_indices) == logits.shape[0] - # Serialization of np.ndarray is much more efficient than a tensor, - # so we receive it in that format. - grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous() + index_tensor = None + if not skip_out_indices: + # xgrammar expects a python list of indices but it will actually work with + # a tensor. If we copy the tensor ourselves here we can do it in a non_blocking + # manner and there should be no cpu sync within xgrammar. + index_tensor = torch.tensor( + out_indices, dtype=torch.int32, device="cpu", pin_memory=True + ) + index_tensor = index_tensor.to(logits.device, non_blocking=True) - xgr.apply_token_bitmask_inplace( - logits, - grammar_bitmask.to(logits.device, non_blocking=True), - indices=out_indices if not skip_out_indices else None, - ) + xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor) class OutlinesVocabulary: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 747a7b377e40..3c32971a88c3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -109,6 +109,7 @@ EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, + KVConnectorOutput, LogprobsLists, LogprobsTensors, ModelRunnerOutput, @@ -150,7 +151,7 @@ if TYPE_CHECKING: from vllm.model_executor.model_loader.tensorizer import TensorizerConfig - from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput logger = init_logger(__name__) @@ -218,6 +219,20 @@ def get_output(self) -> ModelRunnerOutput: return output +class ExecuteModelState(NamedTuple): + """Ephemeral cached state transferred between execute_model() and + sample_tokens(), after execute_model() returns None.""" + + scheduler_output: "SchedulerOutput" + logits: torch.Tensor + spec_decode_metadata: SpecDecodeMetadata | None + spec_decode_common_attn_metadata: CommonAttentionMetadata | None + hidden_states: torch.Tensor + sample_hidden_states: torch.Tensor + aux_hidden_states: list[torch.Tensor] | None + kv_connector_output: KVConnectorOutput | None + + class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def __init__( self, @@ -509,6 +524,9 @@ def __init__( pin_memory=self.pin_memory, ) + # Ephemeral state transferred between execute_model() and sample_tokens(). 
+ self.execute_model_state: ExecuteModelState | None = None + def reset_mm_cache(self) -> None: if self.mm_budget: self.mm_budget.reset_cache() @@ -2113,7 +2131,6 @@ def _preprocess( num_input_tokens: int, # Padded intermediate_tensors: IntermediateTensors | None = None, ) -> tuple[ - int, torch.Tensor | None, torch.Tensor | None, torch.Tensor, @@ -2207,7 +2224,6 @@ def _preprocess( model_kwargs.update(encoder_inputs) return ( - num_scheduled_tokens, input_ids, inputs_embeds, positions, @@ -2425,13 +2441,19 @@ def execute_model( self, scheduler_output: "SchedulerOutput", intermediate_tensors: IntermediateTensors | None = None, - ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + ) -> ModelRunnerOutput | IntermediateTensors | None: + if self.execute_model_state is not None: + raise RuntimeError( + "State error: sample_tokens() must be called " + "after execute_model() returns None." + ) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens with record_function_or_nullcontext("Preprocess"): with self.synchronize_input_prep(): # Update persistent batch states. self._update_states(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: + if not num_scheduled_tokens: if not has_kv_transfer_group(): # Return empty ModelRunnerOutput if no work to do. return EMPTY_MODEL_RUNNER_OUTPUT @@ -2471,7 +2493,6 @@ def execute_model( ) ( - num_scheduled_tokens, input_ids, inputs_embeds, positions, @@ -2559,6 +2580,7 @@ def execute_model( # Rare case. assert not self.is_pooling_model + sample_hidden_states = hidden_states[logits_indices] if not get_pp_group().is_last_rank: all_gather_tensors = { "residual": not is_residual_scattered_for_sp( @@ -2572,7 +2594,6 @@ def execute_model( ) logits = None else: - sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states) model_output_broadcast_data = {} @@ -2585,9 +2606,45 @@ def execute_model( assert model_output_broadcast_data is not None logits = model_output_broadcast_data["logits"] - # Apply structured output bitmasks if present - if scheduler_output.structured_output_request_ids: - apply_grammar_bitmask(scheduler_output, self.input_batch, logits) + self.execute_model_state = ExecuteModelState( + scheduler_output, + logits, + spec_decode_metadata, + spec_decode_common_attn_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + kv_connector_output, + ) + return None + + @torch.inference_mode + def sample_tokens( + self, grammar_output: "GrammarOutput | None" + ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + if self.execute_model_state is None: + # Nothing to do (PP non-final rank case), output isn't used. + return None # noqa + + # Unpack ephemeral state. + ( + scheduler_output, + logits, + spec_decode_metadata, + spec_decode_common_attn_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + kv_connector_output, + ) = self.execute_model_state + # Clear ephemeral state. + self.execute_model_state = None + + # Apply structured output bitmasks if present. 
+ if grammar_output is not None: + apply_grammar_bitmask( + scheduler_output, grammar_output, self.input_batch, logits + ) with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) @@ -2646,7 +2703,7 @@ def propose_draft_token_ids(sampled_token_ids): sampler_output, logits, hidden_states, - num_scheduled_tokens, + scheduler_output.total_num_scheduled_tokens, spec_decode_metadata, ) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 5b11bdf5282f..c2bf1419bebd 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -6,6 +6,7 @@ import gc import os from contextlib import AbstractContextManager, nullcontext +from types import NoneType from typing import TYPE_CHECKING, Any import torch @@ -37,6 +38,7 @@ from vllm.tasks import SupportedTask from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_utils import MemorySnapshot, memory_profiling +from vllm.v1.core.sched.output import GrammarOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ( @@ -508,11 +510,16 @@ def get_model(self) -> nn.Module: def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return self.model_runner.get_supported_tasks() + @torch.inference_mode() + def sample_tokens( + self, grammar_output: "GrammarOutput" + ) -> ModelRunnerOutput | AsyncModelRunnerOutput: + return self.model_runner.sample_tokens(grammar_output) + @torch.inference_mode() def execute_model( - self, - scheduler_output: "SchedulerOutput", - ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None: + self, scheduler_output: "SchedulerOutput" + ) -> ModelRunnerOutput | None: intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -531,13 +538,13 @@ def execute_model( ) output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) - if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)): + if isinstance(output, (ModelRunnerOutput, NoneType)): return output assert isinstance(output, IntermediateTensors) parallel_config = self.vllm_config.parallel_config assert ( - parallel_config.distributed_executor_backend != ("external_launcher") + parallel_config.distributed_executor_backend != "external_launcher" and not get_pp_group().is_last_rank ) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 0ced138b940d..0e34504a5e26 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -92,7 +92,7 @@ ) if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput logger = init_logger(__name__) @@ -372,6 +372,11 @@ def __init__( else: self.sample_from_logits_func = self.sample_from_logits + # For passing scheduler_output between successive + # execute_model() and sample_tokens() calls. 
+ self.scheduler_output: SchedulerOutput | None = None + self.mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None + def reset_mm_cache(self) -> None: if self.mm_budget: self.mm_budget.reset_cache() @@ -1078,7 +1083,12 @@ def execute_model( self, scheduler_output: "SchedulerOutput", intermediate_tensors: IntermediateTensors | None = None, - ) -> ModelRunnerOutput: + ) -> ModelRunnerOutput | None: + if self.scheduler_output is not None: + raise RuntimeError( + "State error: sample_tokens() must be called " + "after execute_model() returns None." + ) # Update cached state self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: @@ -1088,14 +1098,30 @@ def execute_model( return self.kv_connector_no_forward(scheduler_output, self.vllm_config) + mm_embed_inputs = None if self.supports_mm_inputs: # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) mm_embed_inputs = self._gather_mm_embeddings(scheduler_output) - else: - mm_embed_inputs = None torch_xla.sync(wait=False) + + self.scheduler_output = scheduler_output + self.mm_embed_inputs = mm_embed_inputs + return None + + @torch.no_grad() + def sample_tokens( + self, grammar_output: "GrammarOutput | None" + ) -> ModelRunnerOutput: + if self.scheduler_output is None: + # Nothing to do (PP non-final rank case), output isn't used. + return None # noqa + scheduler_output = self.scheduler_output + mm_embed_inputs = self.mm_embed_inputs + self.scheduler_output = None + self.mm_embed_inputs = None + # Prepare inputs, the requests might be split into multiple # executions, combine the result of each execution. start_index = 0 @@ -1131,9 +1157,9 @@ def execute_model( tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch( self.input_batch, padded_num_reqs, self.device ) - if scheduler_output.grammar_bitmask is not None: + if grammar_output is not None: require_struct_decoding, grammar_bitmask_padded, arange = ( - self.prepare_structured_decoding_input(logits, scheduler_output) + self.prepare_structured_decoding_input(logits, grammar_output) ) logits = self.structured_decode( require_struct_decoding, grammar_bitmask_padded, logits, arange @@ -1954,10 +1980,9 @@ def get_input_embeddings(self, *args, **kwargs): return self.model.get_input_embeddings(*args, **kwargs) def prepare_structured_decoding_input( - self, logits: torch.Tensor, scheduler_output: "SchedulerOutput" + self, logits: torch.Tensor, grammar_output: "GrammarOutput" ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - grammar_bitmask = scheduler_output.grammar_bitmask - assert grammar_bitmask is not None + grammar_bitmask = grammar_output.grammar_bitmask num_reqs, _ = logits.shape # Reset pre-allocated tensors @@ -1965,7 +1990,7 @@ def prepare_structured_decoding_input( self.require_structured_out_cpu.zero_() cumulative_mask_idx = 0 - for req_id in scheduler_output.structured_output_request_ids: + for req_id in grammar_output.structured_output_request_ids: if req_id not in self.input_batch.req_id_to_index: continue batch_index = self.input_batch.req_id_to_index[req_id] diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index e867e3c07caa..a716a9c3aa82 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -17,7 +17,6 @@ ) from vllm.distributed.kv_transfer import ( ensure_kv_transfer_initialized, - has_kv_transfer_group, ) from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -27,7 +26,7 @@ from vllm.tasks import SupportedTask 
from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import report_usage_stats @@ -255,13 +254,13 @@ def determine_available_memory(self) -> int: tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size return int(tpu_kv_cache_bytes) + def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput: + return self.model_runner.sample_tokens(grammar_output) + def execute_model( - self, - scheduler_output: "SchedulerOutput", + self, scheduler_output: "SchedulerOutput" ) -> ModelRunnerOutput | None: - output = self.model_runner.execute_model(scheduler_output) - # every worker's output is needed when kv_transfer_group is set up - return output if self.is_driver_worker or has_kv_transfer_group() else None + return self.model_runner.execute_model(scheduler_output) def profile(self, is_start: bool = True): if self.rank < 1: diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 9162e2e85a51..30ea0ab77bd9 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -20,10 +20,12 @@ from vllm.v1.serial_utils import run_method if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.outputs import ModelRunnerOutput + from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput + from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput else: SchedulerOutput = object + GrammarOutput = object + AsyncModelRunnerOutput = object ModelRunnerOutput = object logger = init_logger(__name__) @@ -122,7 +124,21 @@ def load_model(self) -> None: """Load model onto target device.""" raise NotImplementedError - def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput: + def execute_model( + self, scheduler_output: SchedulerOutput + ) -> ModelRunnerOutput | None: + """If this method returns None, sample_tokens should be called immediately after + to obtain the ModelRunnerOutput. + + Note that this design may be changed in future if/when structured outputs + parallelism is re-architected. + """ + raise NotImplementedError + + def sample_tokens( + self, grammar_output: GrammarOutput + ) -> ModelRunnerOutput | AsyncModelRunnerOutput: + """Should be called immediately after execute_model iff it returned None.""" raise NotImplementedError def get_cache_block_size_bytes(self) -> int: @@ -344,7 +360,7 @@ def execute_model( scheduler_output: SchedulerOutput, *args, **kwargs, - ) -> ModelRunnerOutput: + ) -> ModelRunnerOutput | None: self._apply_mm_cache(scheduler_output) assert self.worker is not None
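
A minimal, self-contained sketch of the two-phase executor protocol this diff introduces: execute_model() kicks off the forward pass and returns None (or a Future resolving to None), after which sample_tokens() must be called with the scheduler's GrammarOutput to obtain the ModelRunnerOutput. The StubScheduler and StubExecutor classes and their toy return values are hypothetical stand-ins rather than real vLLM APIs; only the call ordering mirrors EngineCore.step() above.

from concurrent.futures import Future


class StubScheduler:
    """Stand-in for the scheduler; only the method names and shapes matter here."""

    def schedule(self):
        return object()  # stands in for SchedulerOutput

    def get_grammar_bitmask(self, scheduler_output):
        return None  # no structured-output requests in this toy step


class StubExecutor:
    """Stand-in for an executor implementing the deferred-sampling protocol."""

    def execute_model(self, scheduler_output, non_block=False):
        # The real executor launches the forward pass and defers sampling,
        # so the (future) result here is None.
        fut = Future()
        fut.set_result(None)
        return fut if non_block else None

    def sample_tokens(self, grammar_output, non_block=False):
        # The real executor applies the grammar bitmask (if any) and samples.
        result = {"sampled_token_ids": [[42]]}
        if not non_block:
            return result
        fut = Future()
        fut.set_result(result)
        return fut


def step(scheduler, executor):
    # Same ordering as EngineCore.step(): overlap the grammar bitmask
    # computation with the non-blocking forward pass, then sample.
    scheduler_output = scheduler.schedule()
    future = executor.execute_model(scheduler_output, non_block=True)
    grammar_output = scheduler.get_grammar_bitmask(scheduler_output)
    model_output = future.result()
    if model_output is None:
        model_output = executor.sample_tokens(grammar_output)
    return model_output


print(step(StubScheduler(), StubExecutor()))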