Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,8 +627,7 @@ def update_from_output(

# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
# Transmit partial if chunked prefill & prompt logprobs is enabled
if new_token_ids or prompt_logprobs_tensors is not None:
if new_token_ids:
# Add EngineCoreOutput for this Request.
outputs.append(
EngineCoreOutput(
Expand All @@ -639,6 +638,9 @@ def update_from_output(
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason,
events=request.take_events()))
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors

self.scheduled_req_ids.remove(request.request_id)
if not stopped:
Expand Down
1 change: 0 additions & 1 deletion vllm/v1/engine/logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def _update_prompt_logprobs(
num_prompt_tokens, num_logprobs = logprobs.shape

# Pythonize the torch tensors.
# TODO(rob): experiment with doing this in EngineCore?
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()
Expand Down
21 changes: 3 additions & 18 deletions vllm/v1/engine/output_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ def make_request_output(
finished = finish_reason is not None
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

# In follow up, we will switch to invariant where EngineCore
# does not stream partial prefills.
if not finished and (self.is_prefilling or final_only):
if not finished and final_only:
# Only the final output is required in FINAL_ONLY mode.
return None

Expand Down Expand Up @@ -285,19 +283,7 @@ def process_outputs(
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason

# TODO(andy): prompt logprobs + chunked prefill can
# result in engine core returning an output for a
# partial prefill (in order to send back partial
# prompt logprobs.) This breaks the invariant that
# process_outputs is only operating on engine core
# outputs associated with non-partial completions.
# Currently this is handled by having `is_prefilling`
# check for new decoded tokens, indicating that
# the completion is not partial.
#
# Follow up will aggregate partial prompt logprobs
# in the EngineCore.
req_state.is_prefilling = not new_token_ids
req_state.is_prefilling = False

# 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update(
Expand All @@ -306,8 +292,7 @@ def process_outputs(
finish_reason = FinishReason.STOP
stop_reason = stop_string

# 3) Compute sample and prompt logprobs for request,
# if required.
# 3) Compute sample and prompt logprobs for request, if required.
req_state.logprobs_processor.update_from_output(engine_core_output)

# 4) Create and handle RequestOutput objects.
Expand Down
19 changes: 4 additions & 15 deletions vllm/v1/metrics/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,8 @@ def update_from_output(self, output: "EngineCoreOutput",
num_new_generation_tokens = len(output.new_token_ids)

self.num_generation_tokens += num_new_generation_tokens
if is_prefilling and num_new_generation_tokens > 0:
# TODO(andy): we used to assert that num_new_generation_tokens
# > 0 with an invariant that EngineCore does not stream outputs
# for partially completed prefills (scheduler.update_from_output
# makes EngineCoreOutput iff num_computed_tokens == num_tokens).
# When prompt logprobs are enabled, we currently stream out the
# partially completed prompt.
# This will be reverted in a follow up PR and we should re-enable
# this assertion / invariant.
if is_prefilling:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want @markmc's review on stats, although this all looks good to me.

assert num_new_generation_tokens > 0
self.num_prompt_tokens += prompt_len

first_token_latency = self._time_since(req_stats.arrival_time)
Expand All @@ -123,16 +116,12 @@ def update_from_output(self, output: "EngineCoreOutput",

# Process the batch-level "new tokens" engine core event
if is_prefilling:
# TODO: re-enable no-output-for-partial-prefills invariant as above
if num_new_generation_tokens > 0:
req_stats.first_token_ts = engine_core_timestamp
req_stats.first_token_ts = engine_core_timestamp
else:
tpot = engine_core_timestamp - req_stats.last_token_ts
self.time_per_output_tokens_iter.append(tpot)

# TODO: re-enable no-output-for-partial-prefills invariant as above
if num_new_generation_tokens > 0:
req_stats.last_token_ts = engine_core_timestamp
req_stats.last_token_ts = engine_core_timestamp

def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
is_prefilling: bool, req_stats: RequestStateStats,
Expand Down
19 changes: 19 additions & 0 deletions vllm/v1/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,25 @@ def tolists(self):
self.selected_token_ranks.tolist(),
)

@staticmethod
def empty_cpu(num_positions: int,
              num_tokens_per_position: int) -> "LogprobsTensors":
    """Allocate an uninitialized LogprobsTensors backed by CPU tensors.

    Args:
        num_positions: Number of rows in every allocated tensor.
        num_tokens_per_position: Number of columns in the two 2-D tensors.

    Returns:
        A LogprobsTensors holding:
          * ``logprob_token_ids``: int32, shape
            ``(num_positions, num_tokens_per_position)``
          * ``logprobs``: float32, same shape as ``logprob_token_ids``
          * ``selected_token_ranks``: int32, shape ``(num_positions,)``

        The tensors come from ``torch.empty``, so their contents are
        uninitialized until the caller writes into them.
    """
    target_device = "cpu"
    grid_shape = (num_positions, num_tokens_per_position)

    token_id_buf = torch.empty(grid_shape,
                               dtype=torch.int32,
                               device=target_device)
    # Same shape and device as the token ids; only the dtype differs.
    logprob_buf = torch.empty(grid_shape,
                              dtype=torch.float32,
                              device=target_device)
    # One rank per position, hence a 1-D tensor.
    rank_buf = torch.empty((num_positions, ),
                           dtype=torch.int32,
                           device=target_device)

    return LogprobsTensors(
        logprob_token_ids=token_id_buf,
        logprobs=logprob_buf,
        selected_token_ranks=rank_buf,
    )


@dataclass
class SamplerOutput:
Expand Down
5 changes: 5 additions & 0 deletions vllm/v1/worker/gpu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import BlockTable
Expand Down Expand Up @@ -197,6 +198,9 @@ def __init__(
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}

# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

self.logit_bias: list[Optional[dict[int,
float]]] = [None] * max_num_reqs
self.has_allowed_token_ids: set[str] = set()
Expand Down Expand Up @@ -362,6 +366,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

# LoRA
lora_id = self.request_lora_mapping[req_index]
Expand Down
41 changes: 33 additions & 8 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,7 @@ def _get_prompt_logprobs_dict(
if not num_prompt_logprobs_dict:
return {}

in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

# Since prompt logprobs are a rare feature, prioritize simple,
Expand All @@ -1205,16 +1206,36 @@ def _get_prompt_logprobs_dict(
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
self.device, non_blocking=True)

# Set up target LogprobsTensors object.
logprobs_tensors = in_progress_dict.get(req_id)
if not logprobs_tensors:
# Create empty logprobs CPU tensors for the entire prompt.
# If chunked, we'll copy in slice by slice.
logprobs_tensors = LogprobsTensors.empty_cpu(
num_prompt_tokens - 1, num_prompt_logprobs + 1)
in_progress_dict[req_id] = logprobs_tensors

# Determine number of logits to retrieve.
start_tok = request.num_computed_tokens + 1
start_idx = request.num_computed_tokens
start_tok = start_idx + 1
num_remaining_tokens = num_prompt_tokens - start_tok
if num_tokens < num_remaining_tokens:
if num_tokens <= num_remaining_tokens:
# This is a chunk, more tokens remain.
# In the == case, there are no more prompt logprobs to produce
# but we want to defer returning them to the next step where we
# have new generated tokens to return.
num_logits = num_tokens
else:
# This is the last chunk of prompt tokens to return.
num_logits = num_remaining_tokens
completed_prefill_reqs.append(req_id)
prompt_logprobs_dict[req_id] = logprobs_tensors

if num_logits <= 0:
# This can happen for the final chunk if we prefilled exactly
# (num_prompt_tokens - 1) tokens for this request in the prior
# step. There are no more prompt logprobs to produce.
continue

# Get the logits corresponding to this req's prompt tokens.
# If this is a partial request (i.e. chunked prefill),
Expand All @@ -1235,19 +1256,23 @@ def _get_prompt_logprobs_dict(
logprobs, num_prompt_logprobs, tgt_token_ids)

# Transfer GPU->CPU async.
prompt_logprobs_dict[req_id] = LogprobsTensors(
token_ids.to("cpu", non_blocking=True),
logprobs.to("cpu", non_blocking=True),
ranks.to("cpu", non_blocking=True),
)
chunk_slice = slice(start_idx, start_idx + num_logits)
logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
token_ids, non_blocking=True)
logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
non_blocking=True)
logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
ranks, non_blocking=True)

# Remove requests that have completed prefill from the batch
# num_prompt_logprobs_dict.
for req_id in completed_prefill_reqs:
del num_prompt_logprobs_dict[req_id]
del in_progress_dict[req_id]

# Must synchronize the non-blocking GPU->CPU transfers.
torch.cuda.synchronize()
if prompt_logprobs_dict:
torch.cuda.synchronize()

return prompt_logprobs_dict

Expand Down