diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index d002a19b08a4..c71eb9a0445c 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -627,8 +627,7 @@ def update_from_output(

             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
-            # Transmit partial if chunked prefill & prompt logprobs is enabled
-            if new_token_ids or prompt_logprobs_tensors is not None:
+            if new_token_ids:
                 # Add EngineCoreOutput for this Request.
                 outputs.append(
                     EngineCoreOutput(
@@ -639,6 +638,9 @@ def update_from_output(
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                         stop_reason=request.stop_reason,
                         events=request.take_events()))
+            else:
+                # Invariant: EngineCore returns no partial prefill outputs.
+                assert not prompt_logprobs_tensors

             self.scheduled_req_ids.remove(request.request_id)
             if not stopped:
diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py
index 500de14e57d6..03d82b6bbc1d 100644
--- a/vllm/v1/engine/logprobs.py
+++ b/vllm/v1/engine/logprobs.py
@@ -115,7 +115,6 @@ def _update_prompt_logprobs(
         num_prompt_tokens, num_logprobs = logprobs.shape

         # Pythonize the torch tensors.
-        # TODO(rob): experiment with doing this in EngineCore?
         prompt_token_ranks = ranks.tolist()
         prompt_logprobs = logprobs.tolist()
         token_ids = token_ids.tolist()
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 04235eda0926..12df341772f5 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -105,9 +105,7 @@ def make_request_output(
         finished = finish_reason is not None
         final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

-        # In follow up, we will switch to invariant where EngineCore
-        # does not stream partial prefills.
-        if not finished and (self.is_prefilling or final_only):
+        if not finished and final_only:
             # Only the final output is required in FINAL_ONLY mode.
             return None

@@ -285,19 +283,7 @@ def process_outputs(
             finish_reason = engine_core_output.finish_reason
             stop_reason = engine_core_output.stop_reason

-            # TODO(andy): prompt logprobs + chunked prefill can
-            # result in engine core returning an output for a
-            # partial prefill (in order to send back partial
-            # prompt logprobs.) This breaks the invariant that
-            # process_outputs is only operating on engine core
-            # outputs associated with non-partial completions.
-            # Currently this is handled by having `is_prefilling`
-            # check for new decoded tokens, indicating that
-            # the completion is not partial.
-            #
-            # Follow up will aggregate partial prompt logprobs
-            # in the EngineCore.
-            req_state.is_prefilling = not new_token_ids
+            req_state.is_prefilling = False

             # 2) Detokenize the token ids into text and perform stop checks.
             stop_string = req_state.detokenizer.update(
@@ -306,8 +292,7 @@ def process_outputs(
                 finish_reason = FinishReason.STOP
                 stop_reason = stop_string

-            # 3) Compute sample and prompt logprobs for request,
-            # if required.
+            # 3) Compute sample and prompt logprobs for request, if required.
             req_state.logprobs_processor.update_from_output(engine_core_output)

             # 4) Create and handle RequestOutput objects.
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 83383ce1f3f7..6f3d34447426 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -100,15 +100,8 @@ def update_from_output(self, output: "EngineCoreOutput",
         num_new_generation_tokens = len(output.new_token_ids)

         self.num_generation_tokens += num_new_generation_tokens
-        if is_prefilling and num_new_generation_tokens > 0:
-            # TODO(andy): we used to assert that num_new_generation_tokens
-            # > 0 with an invariant that EngineCore does not stream outputs
-            # for partially completed prefills (scheduler.update_from_output
-            # makes EngineCoreOutput iff num_computed_tokens == num_tokens).
-            # When prompt logprobs are enabled, we currently stream out the
-            # partially completed prompt.
-            # This will be reverted in a follow up PR and we should re-enable
-            # this assertion / invariant.
+        if is_prefilling:
+            assert num_new_generation_tokens > 0
             self.num_prompt_tokens += prompt_len

             first_token_latency = self._time_since(req_stats.arrival_time)
@@ -123,16 +116,12 @@ def update_from_output(self, output: "EngineCoreOutput",

         # Process the batch-level "new tokens" engine core event
         if is_prefilling:
-            # TODO: re-enable no-output-for-partial-prefills invariant as above
-            if num_new_generation_tokens > 0:
-                req_stats.first_token_ts = engine_core_timestamp
+            req_stats.first_token_ts = engine_core_timestamp
         else:
             tpot = engine_core_timestamp - req_stats.last_token_ts
             self.time_per_output_tokens_iter.append(tpot)

-        # TODO: re-enable no-output-for-partial-prefills invariant as above
-        if num_new_generation_tokens > 0:
-            req_stats.last_token_ts = engine_core_timestamp
+        req_stats.last_token_ts = engine_core_timestamp

     def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
                            is_prefilling: bool, req_stats: RequestStateStats,
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index 6f46417170f6..2732b933c28a 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -39,6 +39,25 @@ def tolists(self):
             self.selected_token_ranks.tolist(),
         )

+    @staticmethod
+    def empty_cpu(num_positions: int,
+                  num_tokens_per_position: int) -> "LogprobsTensors":
+        """Create empty LogprobsTensors on CPU."""
+
+        logprob_token_ids = torch.empty(
+            (num_positions, num_tokens_per_position),
+            dtype=torch.int32,
+            device="cpu")
+        logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
+        selected_token_ranks = torch.empty(num_positions,
+                                           dtype=torch.int32,
+                                           device="cpu")
+        return LogprobsTensors(
+            logprob_token_ids=logprob_token_ids,
+            logprobs=logprobs,
+            selected_token_ranks=selected_token_ranks,
+        )
+

 @dataclass
 class SamplerOutput:
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 55d5429a8935..01a5cb5548bb 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -11,6 +11,7 @@
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import swap_dict_values
+from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.utils import copy_slice
 from vllm.v1.worker.block_table import BlockTable
@@ -197,6 +198,9 @@ def __init__(
         # that are currently in the prefill phase.
         self.num_prompt_logprobs: dict[str, int] = {}

+        # To accumulate prompt logprobs tensor chunks across prefill steps.
+        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
+
         self.logit_bias: list[Optional[dict[int,
                                             float]]] = [None] * max_num_reqs
         self.has_allowed_token_ids: set[str] = set()
@@ -362,6 +366,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.num_prompt_logprobs.pop(req_id, None)
+        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

         # LoRA
         lora_id = self.request_lora_mapping[req_index]
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 229849e4439b..898ffd75466c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1190,6 +1190,7 @@ def _get_prompt_logprobs_dict(
         if not num_prompt_logprobs_dict:
             return {}

+        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
         prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

         # Since prompt logprobs are a rare feature, prioritize simple,
@@ -1205,16 +1206,36 @@
             prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                 self.device, non_blocking=True)

+            # Set up target LogprobsTensors object.
+            logprobs_tensors = in_progress_dict.get(req_id)
+            if not logprobs_tensors:
+                # Create empty logprobs CPU tensors for the entire prompt.
+                # If chunked, we'll copy in slice by slice.
+                logprobs_tensors = LogprobsTensors.empty_cpu(
+                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                in_progress_dict[req_id] = logprobs_tensors
+
             # Determine number of logits to retrieve.
-            start_tok = request.num_computed_tokens + 1
+            start_idx = request.num_computed_tokens
+            start_tok = start_idx + 1
             num_remaining_tokens = num_prompt_tokens - start_tok
-            if num_tokens < num_remaining_tokens:
+            if num_tokens <= num_remaining_tokens:
                 # This is a chunk, more tokens remain.
+                # In the == case, there are no more prompt logprobs to produce
+                # but we want to defer returning them to the next step where we
+                # have new generated tokens to return.
                 num_logits = num_tokens
             else:
                 # This is the last chunk of prompt tokens to return.
                 num_logits = num_remaining_tokens
                 completed_prefill_reqs.append(req_id)
+                prompt_logprobs_dict[req_id] = logprobs_tensors
+
+            if num_logits <= 0:
+                # This can happen for the final chunk if we prefilled exactly
+                # (num_prompt_tokens - 1) tokens for this request in the prior
+                # step. There are no more prompt logprobs to produce.
+                continue

             # Get the logits corresponding to this req's prompt tokens.
             # If this is a partial request (i.e. chunked prefill),
@@ -1235,19 +1256,23 @@
                 logprobs, num_prompt_logprobs, tgt_token_ids)

             # Transfer GPU->CPU async.
-            prompt_logprobs_dict[req_id] = LogprobsTensors(
-                token_ids.to("cpu", non_blocking=True),
-                logprobs.to("cpu", non_blocking=True),
-                ranks.to("cpu", non_blocking=True),
-            )
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
+                token_ids, non_blocking=True)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
+                                                         non_blocking=True)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
+                ranks, non_blocking=True)

         # Remove requests that have completed prefill from the batch
         # num_prompt_logprobs_dict.
         for req_id in completed_prefill_reqs:
             del num_prompt_logprobs_dict[req_id]
+            del in_progress_dict[req_id]

         # Must synchronize the non-blocking GPU->CPU transfers.
-        torch.cuda.synchronize()
+        if prompt_logprobs_dict:
+            torch.cuda.synchronize()

         return prompt_logprobs_dict
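
Note: the accumulation pattern this diff introduces in _get_prompt_logprobs_dict can be illustrated standalone. The sketch below is a toy, not vLLM code: the prompt length, chunk sizes, and zero-filled stand-ins for the gathered token_ids/logprobs/ranks are invented for illustration, and it compresses several prefill steps into one loop with a single synchronize at the end (the real code gathers from the prompt logits each scheduler step and synchronizes once per step).

import torch

# Assumed values (not from the PR): a 10-token prompt prefilled in chunks.
num_prompt_tokens = 10
num_prompt_logprobs = 4
num_positions = num_prompt_tokens - 1  # one logprob row per non-first token

# Mirrors LogprobsTensors.empty_cpu(num_positions, num_prompt_logprobs + 1);
# the extra column holds the prompt token itself alongside the top-k tokens.
token_ids_cpu = torch.empty((num_positions, num_prompt_logprobs + 1),
                            dtype=torch.int32)
logprobs_cpu = torch.empty_like(token_ids_cpu, dtype=torch.float32)
ranks_cpu = torch.empty(num_positions, dtype=torch.int32)

device = "cuda" if torch.cuda.is_available() else "cpu"
num_computed = 0
for num_logits in (4, 3, 2):  # chunk sizes summing to num_positions
    # Stand-ins for the per-chunk gather results living on the device.
    token_ids = torch.zeros((num_logits, num_prompt_logprobs + 1),
                            dtype=torch.int32, device=device)
    logprobs = torch.zeros_like(token_ids, dtype=torch.float32)
    ranks = torch.zeros(num_logits, dtype=torch.int32, device=device)

    # Copy this chunk into its slice of the preallocated CPU tensors,
    # as the diff does with chunk_slice and Tensor.copy_.
    chunk = slice(num_computed, num_computed + num_logits)
    token_ids_cpu[chunk].copy_(token_ids, non_blocking=True)
    logprobs_cpu[chunk].copy_(logprobs, non_blocking=True)
    ranks_cpu[chunk].copy_(ranks, non_blocking=True)
    num_computed += num_logits

if device == "cuda":
    torch.cuda.synchronize()  # make the non-blocking copies visible on CPU
assert num_computed == num_positions

This is why the PR can drop partial-prefill outputs from EngineCore entirely: the per-chunk results accumulate in the preallocated CPU tensors across steps, and the complete LogprobsTensors is returned only once, on the step that finishes the prompt.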