Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions vllm/v1/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,7 @@ def update_from_output(

# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
# Transmit partial if chunked prefill & prompt logprobs is enabled
if new_token_ids or prompt_logprobs_tensors is not None:
if new_token_ids:
# Add EngineCoreOutput for this Request.
outputs.append(
EngineCoreOutput(
Expand All @@ -637,6 +636,8 @@ def update_from_output(
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason,
events=request.take_events()))
else:
assert not prompt_logprobs_tensors

self.scheduled_req_ids.remove(request.request_id)
if not stopped:
Expand Down
1 change: 0 additions & 1 deletion vllm/v1/engine/logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def _update_prompt_logprobs(
num_prompt_tokens, num_logprobs = logprobs.shape

# Pythonize the torch tensors.
# TODO(rob): experiment with doing this in EngineCore?
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()
Expand Down
21 changes: 3 additions & 18 deletions vllm/v1/engine/output_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ def make_request_output(
finished = finish_reason is not None
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

# In follow up, we will switch to invariant where EngineCore
# does not stream partial prefills.
if not finished and (self.is_prefilling or final_only):
if not finished and final_only:
# Only the final output is required in FINAL_ONLY mode.
return None

Expand Down Expand Up @@ -285,19 +283,7 @@ def process_outputs(
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason

# TODO(andy): prompt logprobs + chunked prefill can
# result in engine core returning an output for a
# partial prefill (in order to send back partial
# prompt logprobs.) This breaks the invariant that
# process_outputs is only operating on engine core
# outputs associated with non-partial completions.
# Currently this is handled by having `is_prefilling`
# check for new decoded tokens, indicating that
# the completion is not partial.
#
# Follow up will aggregate partial prompt logprobs
# in the EngineCore.
req_state.is_prefilling = not new_token_ids
req_state.is_prefilling = False

# 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update(
Expand All @@ -306,8 +292,7 @@ def process_outputs(
finish_reason = FinishReason.STOP
stop_reason = stop_string

# 3) Compute sample and prompt logprobs for request,
# if required.
# 3) Compute sample and prompt logprobs for request, if required.
req_state.logprobs_processor.update_from_output(engine_core_output)

# 4) Create and handle RequestOutput objects.
Expand Down
19 changes: 4 additions & 15 deletions vllm/v1/metrics/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,8 @@ def update_from_output(self, output: "EngineCoreOutput",
num_new_generation_tokens = len(output.new_token_ids)

self.num_generation_tokens += num_new_generation_tokens
if is_prefilling and num_new_generation_tokens > 0:
# TODO(andy): we used to assert that num_new_generation_tokens
# > 0 with an invariant that EngineCore does not stream outputs
# for partially completed prefills (scheduler.update_from_output
# makes EngineCoreOutput iff num_computed_tokens == num_tokens).
# When prompt logprobs are enabled, we currently stream out the
# partially completed prompt.
# This will be reverted in a follow up PR and we should re-enable
# this assertion / invariant.
if is_prefilling:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want @markmc's review on stats, although this all looks good to me.

assert num_new_generation_tokens > 0
self.num_prompt_tokens += prompt_len

first_token_latency = self._time_since(req_stats.arrival_time)
Expand All @@ -123,16 +116,12 @@ def update_from_output(self, output: "EngineCoreOutput",

# Process the batch-level "new tokens" engine core event
if is_prefilling:
# TODO: re-enable no-output-for-partial-prefills invariant as above
if num_new_generation_tokens > 0:
req_stats.first_token_ts = engine_core_timestamp
req_stats.first_token_ts = engine_core_timestamp
else:
tpot = engine_core_timestamp - req_stats.last_token_ts
self.time_per_output_tokens_iter.append(tpot)

# TODO: re-enable no-output-for-partial-prefills invariant as above
if num_new_generation_tokens > 0:
req_stats.last_token_ts = engine_core_timestamp
req_stats.last_token_ts = engine_core_timestamp

def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
is_prefilling: bool, req_stats: RequestStateStats,
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/worker/gpu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import BlockTable
Expand Down Expand Up @@ -197,6 +198,9 @@ def __init__(
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}

# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

self.logit_bias: list[Optional[dict[int,
float]]] = [None] * max_num_reqs
self.has_allowed_token_ids: set[str] = set()
Expand Down
54 changes: 46 additions & 8 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,6 +1186,7 @@ def _get_prompt_logprobs_dict(
if not num_prompt_logprobs_dict:
return {}

in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

# Since prompt logprobs are a rare feature, prioritize simple,
Expand All @@ -1201,16 +1202,50 @@ def _get_prompt_logprobs_dict(
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
self.device, non_blocking=True)

# Set up target LogprobsTensors object.
logprobs_tensors = in_progress_dict.get(req_id)
if not logprobs_tensors:
# Create empty logprobs CPU tensors for the entire prompt.
# If chunked, we'll copy in slice by slice.
logprob_token_ids = torch.empty(
(num_prompt_tokens - 1, num_prompt_logprobs + 1),
dtype=torch.int32,
device="cpu")
logprobs_tensors = LogprobsTensors(
logprob_token_ids=logprob_token_ids,
logprobs=torch.empty_like(logprob_token_ids,
dtype=torch.float32),
selected_token_ranks=torch.empty(num_prompt_tokens - 1,
dtype=torch.int32,
device="cpu"),
)

# Determine number of logits to retrieve.
start_tok = request.num_computed_tokens + 1
start_idx = request.num_computed_tokens
start_tok = start_idx + 1
num_remaining_tokens = num_prompt_tokens - start_tok
if num_tokens < num_remaining_tokens:
if num_tokens <= num_remaining_tokens:
# This is a chunk, more tokens remain.
# In the == case, there are no more prompt logprobs to produce
# but we want to defer returning them to the next step where we
# have new generated tokens to return.
num_logits = num_tokens
if start_idx == 0:
# Store the tensors for subsequent iterations.
in_progress_dict[req_id] = logprobs_tensors
else:
# This is the last chunk of prompt tokens to return.
num_logits = num_remaining_tokens
if start_idx != 0:
del in_progress_dict[req_id]
completed_prefill_reqs.append(req_id)
prompt_logprobs_dict[req_id] = logprobs_tensors

if num_logits <= 0:
# This can happen for the final chunk if we prefilled exactly
# (num_prompt_tokens - 1) tokens for this request in the prior
# step. There are no more prompt logprobs to produce.
continue

# Get the logits corresponding to this req's prompt tokens.
# If this is a partial request (i.e. chunked prefill),
Expand All @@ -1231,19 +1266,22 @@ def _get_prompt_logprobs_dict(
logprobs, num_prompt_logprobs, tgt_token_ids)

# Transfer GPU->CPU async.
prompt_logprobs_dict[req_id] = LogprobsTensors(
token_ids.to("cpu", non_blocking=True),
logprobs.to("cpu", non_blocking=True),
ranks.to("cpu", non_blocking=True),
)
chunk_slice = slice(start_idx, start_idx + num_logits)
logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
token_ids, non_blocking=True)
logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
non_blocking=True)
logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
ranks, non_blocking=True)

# Remove requests that have completed prefill from the batch
# num_prompt_logprobs_dict.
for req_id in completed_prefill_reqs:
del num_prompt_logprobs_dict[req_id]

# Must synchronize the non-blocking GPU->CPU transfers.
torch.cuda.synchronize()
if prompt_logprobs_dict:
torch.cuda.synchronize()

return prompt_logprobs_dict

Expand Down