Commit 0d450f8

[V1] Aggregate prompt logprobs in model runner
Signed-off-by: Nick Hill <[email protected]>
1 parent 61c6a5a commit 0d450f8

6 files changed: +50 -43 lines changed

vllm/v1/core/scheduler.py

Lines changed: 3 additions & 2 deletions

@@ -625,8 +625,7 @@ def update_from_output(

            # Get prompt logprobs for this request.
            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
-           # Transmit partial if chunked prefill & prompt logprobs is enabled
-           if new_token_ids or prompt_logprobs_tensors is not None:
+           if new_token_ids:
                # Add EngineCoreOutput for this Request.
                outputs.append(
                    EngineCoreOutput(
@@ -637,6 +636,8 @@ def update_from_output(
                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                        stop_reason=request.stop_reason,
                        events=request.take_events()))
+           else:
+               assert not prompt_logprobs_tensors

            self.scheduled_req_ids.remove(request.request_id)
            if not stopped:

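Note: for orientation, a minimal sketch (not the vLLM scheduler itself) of the emission rule this hunk establishes: an output exists only for a step that sampled new tokens, and any aggregated prompt logprobs ride along with that output, while partial prefills emit nothing. The helper name maybe_emit and the emit_output callback are hypothetical; new_token_ids and prompt_logprobs_tensors are the names from the hunk above.

def maybe_emit(new_token_ids, prompt_logprobs_tensors, emit_output):
    # Emission rule after this commit: outputs are produced only for steps
    # that sampled tokens; aggregated prompt logprobs attach to that output.
    if new_token_ids:
        emit_output(new_token_ids, prompt_logprobs_tensors)
    else:
        # A partial prefill produces nothing; the model runner holds any
        # partial prompt logprobs until the prompt is fully processed.
        assert not prompt_logprobs_tensors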
vllm/v1/engine/logprobs.py

Lines changed: 0 additions & 1 deletion

@@ -115,7 +115,6 @@ def _update_prompt_logprobs(
        num_prompt_tokens, num_logprobs = logprobs.shape

        # Pythonize the torch tensors.
-       # TODO(rob): experiment with doing this in EngineCore?
        prompt_token_ranks = ranks.tolist()
        prompt_logprobs = logprobs.tolist()
        token_ids = token_ids.tolist()

vllm/v1/engine/output_processor.py

Lines changed: 3 additions & 18 deletions

@@ -105,9 +105,7 @@ def make_request_output(
        finished = finish_reason is not None
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

-       # In follow up, we will switch to invariant where EngineCore
-       # does not stream partial prefills.
-       if not finished and (self.is_prefilling or final_only):
+       if not finished and final_only:
            # Only the final output is required in FINAL_ONLY mode.
            return None

@@ -285,19 +283,7 @@ def process_outputs(
            finish_reason = engine_core_output.finish_reason
            stop_reason = engine_core_output.stop_reason

-           # TODO(andy): prompt logprobs + chunked prefill can
-           # result in engine core returning an output for a
-           # partial prefill (in order to send back partial
-           # prompt logprobs.) This breaks the invariant that
-           # process_outputs is only operating on engine core
-           # outputs associated with non-partial completions.
-           # Currently this is handled by having `is_prefilling`
-           # check for new decoded tokens, indicating that
-           # the completion is not partial.
-           #
-           # Follow up will aggregate partial prompt logprobs
-           # in the EngineCore.
-           req_state.is_prefilling = not new_token_ids
+           req_state.is_prefilling = False

            # 2) Detokenize the token ids into text and perform stop checks.
            stop_string = req_state.detokenizer.update(
@@ -306,8 +292,7 @@ def process_outputs(
                finish_reason = FinishReason.STOP
                stop_reason = stop_string

-           # 3) Compute sample and prompt logprobs for request,
-           # if required.
+           # 3) Compute sample and prompt logprobs for request, if required.
            req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.

vllm/v1/metrics/stats.py

Lines changed: 4 additions & 15 deletions

@@ -100,15 +100,8 @@ def update_from_output(self, output: "EngineCoreOutput",
        num_new_generation_tokens = len(output.new_token_ids)

        self.num_generation_tokens += num_new_generation_tokens
-       if is_prefilling and num_new_generation_tokens > 0:
-           # TODO(andy): we used to assert that num_new_generation_tokens
-           # > 0 with an invariant that EngineCore does not stream outputs
-           # for partially completed prefills (scheduler.update_from_output
-           # makes EngineCoreOutput iff num_computed_tokens == num_tokens).
-           # When prompt logprobs are enabled, we currently stream out the
-           # partially completed prompt.
-           # This will be reverted in a follow up PR and we should re-enable
-           # this assertion / invariant.
+       if is_prefilling:
+           assert num_new_generation_tokens > 0
            self.num_prompt_tokens += prompt_len

            first_token_latency = self._time_since(req_stats.arrival_time)
@@ -123,16 +116,12 @@ def update_from_output(self, output: "EngineCoreOutput",

        # Process the batch-level "new tokens" engine core event
        if is_prefilling:
-           # TODO: re-enable no-output-for-partial-prefills invariant as above
-           if num_new_generation_tokens > 0:
-               req_stats.first_token_ts = engine_core_timestamp
+           req_stats.first_token_ts = engine_core_timestamp
        else:
            tpot = engine_core_timestamp - req_stats.last_token_ts
            self.time_per_output_tokens_iter.append(tpot)

-       # TODO: re-enable no-output-for-partial-prefills invariant as above
-       if num_new_generation_tokens > 0:
-           req_stats.last_token_ts = engine_core_timestamp
+       req_stats.last_token_ts = engine_core_timestamp

    def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
                           is_prefilling: bool, req_stats: RequestStateStats,

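Note: the stats change relies on the restored invariant that every engine-core output carries generated tokens, so the first output of a request stamps time-to-first-token and every output advances the last-token timestamp. A rough sketch of that bookkeeping, assuming a simplified stand-in for RequestStateStats (the real class has more fields) and a hypothetical record_output helper:

from dataclasses import dataclass

@dataclass
class _ReqTiming:  # simplified stand-in for RequestStateStats
    first_token_ts: float = 0.0
    last_token_ts: float = 0.0

def record_output(stats: _ReqTiming, tpots: list[float],
                  is_prefilling: bool, now: float) -> None:
    if is_prefilling:
        # The first output for a request always carries generated tokens now,
        # so time-to-first-token can be stamped unconditionally here.
        stats.first_token_ts = now
    else:
        # Time per output token is the gap between consecutive outputs.
        tpots.append(now - stats.last_token_ts)
    stats.last_token_ts = now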
vllm/v1/worker/gpu_input_batch.py

Lines changed: 4 additions & 0 deletions

@@ -11,6 +11,7 @@
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
+from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import BlockTable
@@ -197,6 +198,9 @@ def __init__(
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: dict[str, int] = {}

+       # To accumulate prompt logprobs tensor chunks across prefill steps.
+       self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
+
        self.logit_bias: list[Optional[dict[int,
                                            float]]] = [None] * max_num_reqs
        self.has_allowed_token_ids: set[str] = set()

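Note: the new attribute keys per-request CPU buffers by request id. A hedged sketch of the buffer layout implied by the model-runner diff below; LogprobsTensors here is a local stand-in mirroring only the three fields this commit touches, allocate_prompt_buffers is a hypothetical helper, and the extra "+ 1" column follows the allocation shape in the hunk below.

from typing import NamedTuple
import torch

class LogprobsTensors(NamedTuple):      # local stand-in, not vllm.v1.outputs
    logprob_token_ids: torch.Tensor     # (num_prompt_tokens - 1, num_logprobs + 1)
    logprobs: torch.Tensor              # same shape, floating point
    selected_token_ranks: torch.Tensor  # (num_prompt_tokens - 1,)

def allocate_prompt_buffers(num_prompt_tokens: int,
                            num_prompt_logprobs: int,
                            dtype: torch.dtype = torch.float32) -> LogprobsTensors:
    # One row per prompt position after the first token, allocated once on
    # CPU so chunked prefill steps can fill it slice by slice.
    token_ids = torch.empty((num_prompt_tokens - 1, num_prompt_logprobs + 1),
                            dtype=torch.int32, device="cpu")
    return LogprobsTensors(
        logprob_token_ids=token_ids,
        logprobs=torch.empty_like(token_ids, dtype=dtype),
        selected_token_ranks=torch.empty(num_prompt_tokens - 1,
                                         dtype=torch.int32, device="cpu"),
    )

# Accumulator keyed by request id, as added to the input batch above:
in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}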
vllm/v1/worker/gpu_model_runner.py

Lines changed: 36 additions & 7 deletions

@@ -1134,6 +1134,7 @@ def _get_prompt_logprobs_dict(
        if not num_prompt_logprobs_dict:
            return {}

+       in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

        # Since prompt logprobs are a rare feature, prioritize simple,
@@ -1149,16 +1150,41 @@ def _get_prompt_logprobs_dict(
            prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                self.device, non_blocking=True)

+           # Set up target LogprobsTensors object.
+           logprobs_tensors = in_progress_dict.get(req_id)
+           if not logprobs_tensors:
+               # Create empty logprobs CPU tensors for the entire prompt.
+               # If chunked, we'll copy in slice by slice.
+               logprob_token_ids = torch.empty(
+                   (num_prompt_tokens - 1, num_prompt_logprobs + 1),
+                   dtype=torch.int32,
+                   device="cpu")
+               logprobs_tensors = LogprobsTensors(
+                   logprob_token_ids=logprob_token_ids,
+                   logprobs=torch.empty_like(logprob_token_ids,
+                                             dtype=hidden_states.dtype),
+                   selected_token_ranks=torch.empty(num_prompt_tokens - 1,
+                                                    dtype=torch.int32,
+                                                    device="cpu"),
+               )
+
            # Determine number of logits to retrieve.
-           start_tok = request.num_computed_tokens + 1
+           start_idx = request.num_computed_tokens
+           start_tok = start_idx + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens < num_remaining_tokens:
                # This is a chunk, more tokens remain.
                num_logits = num_tokens
+               if start_idx == 0:
+                   # Store the tensors for subsequent iterations.
+                   in_progress_dict[req_id] = logprobs_tensors
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
+               if start_idx != 0:
+                   del in_progress_dict[req_id]
                completed_prefill_reqs.append(req_id)
+               prompt_logprobs_dict[req_id] = logprobs_tensors

            # Get the logits corresponding to this req's prompt tokens.
            # If this is a partial request (i.e. chunked prefill),
@@ -1179,19 +1205,22 @@ def _get_prompt_logprobs_dict(
                logprobs, num_prompt_logprobs, tgt_token_ids)

            # Transfer GPU->CPU async.
-           prompt_logprobs_dict[req_id] = LogprobsTensors(
-               token_ids.to("cpu", non_blocking=True),
-               logprobs.to("cpu", non_blocking=True),
-               ranks.to("cpu", non_blocking=True),
-           )
+           chunk_slice = slice(start_idx, start_idx + num_logits)
+           logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
+               token_ids, non_blocking=True)
+           logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
+                                                        non_blocking=True)
+           logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
+               ranks, non_blocking=True)

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict.
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]

        # Must synchronize the non-blocking GPU->CPU transfers.
-       torch.cuda.synchronize()
+       if prompt_logprobs_dict:
+           torch.cuda.synchronize()

        return prompt_logprobs_dict

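Note: the net effect of the model-runner change is that each prefill step copies its chunk of prompt logprobs into the preallocated CPU tensors at offset request.num_computed_tokens, and only a request whose prompt finishes in that step appears in the returned dict. A minimal sketch of the per-chunk copy; accumulate_chunk and the buffers parameter are hypothetical, while the tensor field names and the non_blocking copy pattern come from the hunk above.

import torch

def accumulate_chunk(buffers, start_idx: int, token_ids: torch.Tensor,
                     logprobs: torch.Tensor, ranks: torch.Tensor) -> None:
    # Copy this chunk's rows into the preallocated CPU tensors at the offset
    # of already-computed prompt tokens. Copies are issued non-blocking and
    # fenced later by a single torch.cuda.synchronize() per step.
    chunk = slice(start_idx, start_idx + token_ids.shape[0])
    buffers.logprob_token_ids[chunk].copy_(token_ids, non_blocking=True)
    buffers.logprobs[chunk].copy_(logprobs, non_blocking=True)
    buffers.selected_token_ranks[chunk].copy_(ranks, non_blocking=True)

Since torch.cuda.synchronize() waits for all outstanding device work, a synchronize run when a later chunk completes also covers copies issued in earlier steps, which is presumably why the sync is now guarded by the new "if prompt_logprobs_dict:" check.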