diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py
index 5b1de3a66ceb..0201a9dadd92 100644
--- a/vllm/v1/core/sched/interface.py
+++ b/vllm/v1/core/sched/interface.py
@@ -6,7 +6,7 @@
 
 if TYPE_CHECKING:
     from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
-    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.core.sched.output import GrammarBitmask, SchedulerOutput
     from vllm.v1.engine import EngineCoreOutputs
     from vllm.v1.metrics.stats import SchedulerStats
     from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@@ -41,6 +41,14 @@ def schedule(self) -> "SchedulerOutput":
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def get_grammar_bitmask(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> Optional["GrammarBitmask"]:
+        """Get the grammar bitmask for the scheduled requests."""
+        raise NotImplementedError
+
     @abstractmethod
     def update_from_output(
         self,
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index 9ba7ec9d9693..ea264af582fa 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -147,11 +147,15 @@ class SchedulerOutput:
     # Used to free the encoder cache.
     free_encoder_input_ids: list[tuple[str, int]]
 
+    # KV Cache Connector metadata.
+    kv_connector_metadata: Optional[KVConnectorMetadata] = None
+
+
+@dataclass
+class GrammarBitmask:
+    # Dict of request ids to their index within the batch
     # for filling the next token bitmask
     structured_output_request_ids: dict[str, int]
     # the bitmask for the whole batch
     grammar_bitmask: Optional[npt.NDArray[np.int32]]
-
-    # KV Cache Connector metadata.
-    kv_connector_metadata: Optional[KVConnectorMetadata] = None
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 60d5720b6bef..a05d60fab941 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -21,8 +21,8 @@
                                                 compute_encoder_budget)
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
 from vllm.v1.core.sched.interface import SchedulerInterface
-from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
-                                       SchedulerOutput)
+from vllm.v1.core.sched.output import (CachedRequestData, GrammarBitmask,
+                                       NewRequestData, SchedulerOutput)
 from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
                                               create_request_queue)
 from vllm.v1.core.sched.utils import check_stop, remove_all
@@ -534,9 +534,6 @@ def schedule(self) -> SchedulerOutput:
             scheduled_spec_decode_tokens,
             req_to_new_blocks,
         )
-        structured_output_request_ids, grammar_bitmask = (
-            self.get_grammar_bitmask(self.running,
-                                     scheduled_spec_decode_tokens))
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
             scheduled_cached_reqs=cached_reqs_data,
@@ -551,8 +548,6 @@
             # the previous and the current steps.
             finished_req_ids=self.finished_req_ids,
             free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
-            structured_output_request_ids=structured_output_request_ids,
-            grammar_bitmask=grammar_bitmask,
         )
 
         # NOTE(Kuntai): this function is designed for multiple purposes:
@@ -736,9 +731,8 @@ def _try_schedule_encoder_inputs(
 
     def get_grammar_bitmask(
         self,
-        requests: list[Request],
-        scheduled_spec_decode_tokens: dict[str, list[int]],
-    ):
+        scheduler_output: SchedulerOutput,
+    ) -> Optional[GrammarBitmask]:
         # NOTE: structured_output_request_ids maps
         # a request's (request that uses structured output)
         # request_id to its index in the batch.
@@ -746,8 +740,10 @@ def get_grammar_bitmask(
         # and only applies valid mask for requests that
         # uses structured decoding.
         structured_output_request_ids: dict[str, int] = {}
-        for i, req in enumerate(requests):
-            if req.use_structured_output:
+        req_ids = scheduler_output.num_scheduled_tokens.keys()
+        for i, req_id in enumerate(req_ids):
+            req = self.requests.get(req_id)
+            if req is not None and req.use_structured_output:
                 # PERF: in case of chunked prefill,
                 # request might not include any new tokens.
                 # Therefore, we might introduce some additional
@@ -755,14 +751,13 @@
                 structured_output_request_ids[req.request_id] = i
 
         if not structured_output_request_ids:
-            bitmask = None
-        else:
-            bitmask = self.structured_output_manager.grammar_bitmask(
-                self.requests,
-                structured_output_request_ids,
-                scheduled_spec_decode_tokens,
-            )
-        return structured_output_request_ids, bitmask
+            return None
+        bitmask = self.structured_output_manager.grammar_bitmask(
+            self.requests,
+            structured_output_request_ids,
+            scheduler_output.scheduled_spec_decode_tokens,
+        )
+        return GrammarBitmask(structured_output_request_ids, bitmask)
 
     def update_from_output(
         self,
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 32765cda6482..dbc829fab375 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -286,12 +286,12 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         if not self.scheduler.has_requests():
             return {}, False
         scheduler_output = self.scheduler.schedule()
-        model_output = self.execute_model_with_error_logging(
-            self.model_executor.execute_model,  # type: ignore
-            scheduler_output)
+        self.model_executor.prepare_inputs(scheduler_output)
+        self.model_executor.execute_model()
+        bitmask = self.scheduler.get_grammar_bitmask(scheduler_output)
+        model_output = self.model_executor.sample(bitmask, non_block=False)
         engine_core_outputs = self.scheduler.update_from_output(
             scheduler_output, model_output)  # type: ignore
-
         return (engine_core_outputs,
                 scheduler_output.total_num_scheduled_tokens > 0)
 
@@ -327,7 +327,10 @@ def step_with_batch_queue(
 
         if not self.batch_queue.full():
             scheduler_output = self.scheduler.schedule()
             if scheduler_output.total_num_scheduled_tokens > 0:
-                future = self.model_executor.execute_model(scheduler_output)
+                self.model_executor.prepare_inputs(scheduler_output)
+                self.model_executor.execute_model()
+                bitmask = self.scheduler.get_grammar_bitmask(scheduler_output)
+                future = self.model_executor.sample(bitmask)
                 self.batch_queue.put_nowait(
                     (future, scheduler_output))  # type: ignore
@@ -353,6 +356,32 @@ def step_with_batch_queue(
 
         return engine_core_outputs, scheduled_batch
 
+    def step_async(
+            self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
+        model_output = None
+        engine_core_outputs = None
+        bitmask = None
+
+        scheduler_output = self.scheduler.schedule()
+        is_scheduled = scheduler_output.total_num_scheduled_tokens > 0
+        if is_scheduled:
+            self.model_executor.prepare_inputs(scheduler_output)
+            if self.inflight_batch:
+                model_output = self.model_executor.sample(self.prev_bitmask)
+            self.model_executor.execute_model()
+            bitmask = self.scheduler.get_grammar_bitmask(scheduler_output)
+        elif self.inflight_batch:
+            model_output = self.model_executor.sample(self.prev_bitmask)
+
+        if model_output is not None:
+            engine_core_outputs = self.scheduler.update_from_output(
+                self.prev_scheduler_output, model_output.result())
+
+        self.inflight_batch = is_scheduled
+        self.prev_scheduler_output = scheduler_output
+        self.prev_bitmask = bitmask
+        return engine_core_outputs, is_scheduled
+
     def shutdown(self):
         self.structured_output_manager.clear_backend()
         if self.model_executor:
@@ -529,8 +558,14 @@ def __init__(
             assert addresses.coordinator_input is not None
             logger.info("Waiting for READY message from DP Coordinator...")
 
-        self.step_fn = (self.step if self.batch_queue is None else
-                        self.step_with_batch_queue)
+        if self.batch_queue is None:
+            if self.vllm_config.scheduler_config.async_scheduling:
+                self.step_fn = self.step_async
+                self.inflight_batch = False
+            else:
+                self.step_fn = self.step
+        else:
+            self.step_fn = self.step_with_batch_queue
 
     @contextmanager
     def _perform_handshakes(
diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py
index 063a5f592e1a..4f7d64310001 100644
--- a/vllm/v1/executor/abstract.py
+++ b/vllm/v1/executor/abstract.py
@@ -80,12 +80,19 @@ def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
         output = self.collective_rpc("get_kv_cache_spec")
         return output
 
-    def execute_model(
+    def prepare_inputs(self, scheduler_output) -> None:
+        self.collective_rpc("prepare_inputs", args=(scheduler_output, ))
+
+    def execute_model(self) -> None:
+        self.collective_rpc("execute_model")
+
+    def sample(
         self,
-        scheduler_output,
+        grammar_bitmask,
+        non_block: bool = True,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
-        output = self.collective_rpc("execute_model",
-                                     args=(scheduler_output, ))
+        del non_block
+        output = self.collective_rpc("sample", args=(grammar_bitmask, ))
         return output[0]
 
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index 15b88a212899..f48678b85a26 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -115,12 +115,11 @@ def _init_executor(self) -> None:
 
         # For pipeline parallel, we use a thread pool for asynchronous
        # execute_model.
-        if self.max_concurrent_batches > 1:
-            # Note: must use only 1 IO thread to keep dequeue sequence
-            # from the response queue
-            # _async_aggregate_workers_output also assumes a single IO thread
-            self.io_thread_pool = ThreadPoolExecutor(
-                max_workers=1, thread_name_prefix="mp_exec_io")
+        # Note: must use only 1 IO thread to keep dequeue sequence
+        # from the response queue
+        # _async_aggregate_workers_output also assumes a single IO thread
+        self.io_thread_pool = ThreadPoolExecutor(
+            max_workers=1, thread_name_prefix="mp_exec_io")
 
         self.output_rank = self._get_output_rank()
         self.has_connector = self.vllm_config.kv_transfer_config is not None
@@ -162,17 +161,33 @@ def register_failure_callback(self, callback: FailureCallback):
         else:
             self.failure_callback = callback
 
-    def execute_model(
+    def prepare_inputs(self, scheduler_output) -> None:
+        self.collective_rpc(
+            "prepare_inputs",
+            args=(scheduler_output, ),
+            non_block=True,
+            skip_response=True,
+            timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
+        )
+
+    def execute_model(self) -> None:
+        self.collective_rpc(
+            "execute_model",
+            non_block=True,
+            skip_response=True,
+            timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
+        )
+
+    def sample(
         self,
-        scheduler_output,
+        grammar_bitmask,
+        non_block: bool = True,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
-        non_block = self.max_concurrent_batches > 1
-
         if not self.has_connector:
             # get output only from a single worker (output_rank)
             (output, ) = self.collective_rpc(
-                "execute_model",
-                args=(scheduler_output, ),
+                "sample",
+                args=(grammar_bitmask, ),
                 unique_reply_rank=self.output_rank,
                 non_block=non_block,
                 timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
@@ -180,8 +195,8 @@ def execute_model(
 
             # get output from all workers
             outputs = self.collective_rpc(
-                "execute_model",
-                args=(scheduler_output, ),
+                "sample",
+                args=(grammar_bitmask, ),
                 non_block=non_block,
                 timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
 
@@ -203,6 +218,7 @@ def collective_rpc(self,
                        args: tuple = (),
                        kwargs: Optional[dict] = None,
                        non_block: bool = False,
+                       skip_response: bool = False,
                        unique_reply_rank: Optional[int] = None) -> list[Any]:
         if self.is_failed:
             raise RuntimeError("Executor failed.")
@@ -219,6 +235,15 @@ def collective_rpc(self,
             else:
                 send_method = cloudpickle.dumps(
                     method, protocol=pickle.HIGHEST_PROTOCOL)
+
+            if skip_response:
+                if unique_reply_rank is not None:
+                    raise ValueError(
+                        "unique_reply_rank must be None "
+                        f"when skip_response is True. got {unique_reply_rank}")
+                self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, -1))
+                return []
+
             self.rpc_broadcast_mq.enqueue(
                 (send_method, args, kwargs, unique_reply_rank))
 
@@ -309,8 +334,6 @@ def check_health(self) -> None:
 
     @property
     def max_concurrent_batches(self) -> int:
-        if self.scheduler_config.async_scheduling:
-            return 2
         return self.parallel_config.pipeline_parallel_size
 
     def _get_output_rank(self) -> int:
diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py
index c05ad1966d61..d9659dd536b9 100644
--- a/vllm/v1/executor/ray_distributed_executor.py
+++ b/vllm/v1/executor/ray_distributed_executor.py
@@ -58,8 +58,6 @@ def max_concurrent_batches(self) -> int:
         """Ray distributed executor supports pipeline parallelism,
         meaning that it allows PP size batches to be executed concurrently.
         """
-        if self.scheduler_config.async_scheduling:
-            return 2
         return self.parallel_config.pipeline_parallel_size
 
     def execute_model(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 870aca41ec2a..7663b4457009 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -90,7 +90,7 @@
     import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile  # noqa: E501
 
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
-    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.core.sched.output import GrammarBitmask, SchedulerOutput
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
     xgr_torch_compile = LazyLoader(
@@ -1313,11 +1313,13 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
     def apply_grammar_bitmask(
         self,
         scheduler_output: "SchedulerOutput",
+        bitmask: Optional["GrammarBitmask"],
         logits: torch.Tensor,
     ):
-        grammar_bitmask = scheduler_output.grammar_bitmask
-        if grammar_bitmask is None:
+        if bitmask is None:
             return
+        structured_output_request_ids = bitmask.structured_output_request_ids
+        grammar_bitmask = bitmask.grammar_bitmask
 
         # We receive the structured output bitmask from the scheduler,
         # compacted to contain bitmasks only for structured output requests.
@@ -1336,7 +1338,7 @@
             logit_index = batch_index + cumulative_offset
             cumulative_offset += len(
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            if req_id in scheduler_output.structured_output_request_ids:
+            if req_id in structured_output_request_ids:
                 struct_out_req_batch_indices[req_id] = logit_index
 
         out_indices = []
@@ -1347,8 +1349,7 @@
                                  fill_value=-1,
                                  dtype=grammar_bitmask.dtype)
         cumulative_index = 0
-        seq = sorted(scheduler_output.structured_output_request_ids.items(),
-                     key=lambda x: x[1])
+        seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1])
         for req_id, _ in seq:
             logit_index = struct_out_req_batch_indices[req_id]
             num_spec_tokens = len(
@@ -1497,20 +1498,22 @@ def _pool(
             kv_connector_output=kv_connector_output,
         )
 
+    def prepare_inputs(self, scheduler_output: "SchedulerOutput"):
+        # NOTE(woosuk): For now, this method only exists to fetch the
+        # scheduler output from shared memory and cache it.
+        # TODO(woosuk): Move more ops to this method.
+        self._scheduler_output = scheduler_output
+
     @torch.inference_mode()
     def execute_model(
         self,
-        scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+    ) -> Optional[IntermediateTensors]:
+        scheduler_output = self._scheduler_output
+        self._sample_scheduler_output = scheduler_output
         self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
-            if not has_kv_transfer_group():
-                # Return empty ModelRunnerOutput if there's no work to do.
-                return EMPTY_MODEL_RUNNER_OUTPUT
-
-            return self.kv_connector_no_forward(scheduler_output,
-                                                self.vllm_config)
+            return None
 
         # Prepare the decoder inputs.
         (attn_metadata, logits_indices, spec_decode_metadata,
@@ -1634,14 +1637,16 @@ def execute_model(
                 return hidden_states
             get_pp_group().send_tensor_dict(hidden_states.tensors,
                                             all_gather_group=get_tp_group())
+            sample_hidden_states = None
             logits = None
         else:
             if self.input_batch.pooling_params:
-                return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np, kv_connector_output)
+                sample_hidden_states = None
+                logits = None
+            else:
+                sample_hidden_states = hidden_states[logits_indices]
+                logits = self.model.compute_logits(sample_hidden_states, None)
 
-            sample_hidden_states = hidden_states[logits_indices]
-            logits = self.model.compute_logits(sample_hidden_states, None)
         if broadcast_pp_output:
             model_output_broadcast_data = {
                 "logits": logits.contiguous(),
@@ -1651,9 +1656,50 @@
             assert model_output_broadcast_data is not None
             logits = model_output_broadcast_data["logits"]
 
+        # FIXME(woosuk): This is hacky. Refactor this.
+        self._num_scheduled_tokens_np = num_scheduled_tokens_np
+        self._spec_decode_metadata = spec_decode_metadata
+        self._spec_decode_common_attn_metadata = spec_decode_common_attn_metadata
+        self._hidden_states = hidden_states
+        self._aux_hidden_states = aux_hidden_states
+        self._sample_hidden_states = sample_hidden_states
+        self._logits = logits
+        self._kv_connector_output = kv_connector_output
+        return None
+
+    def sample(
+        self,
+        grammar_bitmask: Optional["GrammarBitmask"],
+    ) -> ModelRunnerOutput:
+        # Compute logits.
+        if not get_pp_group().is_last_rank:
+            return EMPTY_MODEL_RUNNER_OUTPUT
+
+        scheduler_output = self._sample_scheduler_output
+        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        if num_scheduled_tokens == 0:
+            if not has_kv_transfer_group():
+                return EMPTY_MODEL_RUNNER_OUTPUT
+            return self.kv_connector_no_forward(scheduler_output,
+                                                self.vllm_config)
+
+        num_scheduled_tokens_np = self._num_scheduled_tokens_np
+        spec_decode_metadata = self._spec_decode_metadata
+        spec_decode_common_attn_metadata = self._spec_decode_common_attn_metadata
+        hidden_states = self._hidden_states
+        aux_hidden_states = self._aux_hidden_states
+        sample_hidden_states = self._sample_hidden_states
+        logits = self._logits
+        kv_connector_output = self._kv_connector_output
+
+        if self.input_batch.pooling_params:
+            return self._pool(hidden_states, num_scheduled_tokens,
+                              num_scheduled_tokens_np, kv_connector_output)
+
         # Apply structured output bitmasks if present
-        if scheduler_output.grammar_bitmask is not None:
-            self.apply_grammar_bitmask(scheduler_output, logits)
+        if grammar_bitmask is not None:
+            self.apply_grammar_bitmask(scheduler_output, grammar_bitmask,
+                                       logits)
 
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index d61177d4245d..5ba66421ce30 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -38,7 +38,7 @@
 
 if TYPE_CHECKING:
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
-    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.core.sched.output import GrammarBitmask, SchedulerOutput
 
 
 class Worker(WorkerBase):
@@ -349,19 +349,23 @@ def get_model(self) -> nn.Module:
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.model_runner.get_supported_tasks()
 
+    def prepare_inputs(self, scheduler_output: "SchedulerOutput") -> None:
+        self.model_runner.prepare_inputs(scheduler_output)
+
     @torch.inference_mode()
-    def execute_model(
-        self,
-        scheduler_output: "SchedulerOutput",
-    ) -> Optional[ModelRunnerOutput]:
+    def execute_model(self) -> None:
         intermediate_tensors = None
         if not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
                 get_pp_group().recv_tensor_dict(
                     all_gather_group=get_tp_group()))
+        self.model_runner.execute_model(intermediate_tensors)
 
-        output = self.model_runner.execute_model(scheduler_output,
-                                                 intermediate_tensors)
+    def sample(
+        self,
+        grammar_bitmask: Optional["GrammarBitmask"],
+    ) -> Optional[ModelRunnerOutput]:
+        output = self.model_runner.sample(grammar_bitmask)
 
         parallel_config = self.vllm_config.parallel_config
         if parallel_config.distributed_executor_backend != "external_launcher" \