vllm/entrypoints/openai/serving_completion.py (100 changes: 61 additions & 39 deletions)
@@ -60,20 +60,25 @@ def __init__(
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage)
super().__init__(
engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage,
)
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
source = self.model_config.generation_config
source = "model" if source == "auto" else source
logger.info("Using default completion sampling params from %s: %s",
source, self.default_sampling_params)
logger.info(
"Using default completion sampling params from %s: %s",
source,
self.default_sampling_params,
)

async def create_completion(
self,
@@ -172,23 +177,28 @@ async def create_completion(
max_model_len=self.max_model_len,
request=request,
input_length=input_length,
default_sampling_params=self.default_sampling_params)
default_sampling_params=self.default_sampling_params,
)

if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params)
else:
sampling_params = request.to_sampling_params(
max_tokens, self.model_config.logits_processor_pattern,
self.default_sampling_params)
max_tokens,
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)

request_id_item = f"{request_id}-{i}"

self._log_inputs(request_id_item,
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
self._log_inputs(
request_id_item,
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)

trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
@@ -245,7 +255,8 @@ async def create_completion(
num_prompts=num_prompts,
tokenizer=tokenizer,
request_metadata=request_metadata,
enable_force_include_usage=self.enable_force_include_usage)
enable_force_include_usage=self.enable_force_include_usage,
)

# Non-streaming response
final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
@@ -321,10 +332,10 @@ async def completion_stream_generator(

stream_options = request.stream_options
if stream_options:
include_usage = stream_options.include_usage or \
enable_force_include_usage
include_continuous_usage = include_usage and \
stream_options.continuous_usage_stats
include_usage = (stream_options.include_usage
or enable_force_include_usage)
include_continuous_usage = (include_usage and
stream_options.continuous_usage_stats)
else:
include_usage, include_continuous_usage = False, False

@@ -370,7 +381,8 @@ async def completion_stream_generator(
# echo the prompt and first token
delta_text = prompt_text + output.text
delta_token_ids = [
*prompt_token_ids, *output.token_ids
*prompt_token_ids,
*output.token_ids,
]
out_logprobs = [
*(prompt_logprobs or []),
@@ -383,8 +395,8 @@ async def completion_stream_generator(
delta_token_ids = output.token_ids
out_logprobs = output.logprobs

if not delta_text and not delta_token_ids \
and not previous_num_tokens[i]:
if (not delta_text and not delta_token_ids
and not previous_num_tokens[i]):
# Chunked prefill case, don't return empty chunks
continue

@@ -420,7 +432,8 @@ async def completion_stream_generator(
finish_reason=finish_reason,
stop_reason=stop_reason,
)
])
],
)
if include_continuous_usage:
prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = previous_num_tokens[i]
@@ -438,7 +451,8 @@ async def completion_stream_generator(
final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)
total_tokens=total_prompt_tokens + total_completion_tokens,
)

if self.enable_prompt_tokens_details and num_cached_tokens:
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
@@ -452,8 +466,8 @@ async def completion_stream_generator(
choices=[],
usage=final_usage_info,
)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=False, exclude_none=True))
final_usage_data = final_usage_chunk.model_dump_json(
exclude_unset=False, exclude_none=True)
yield f"data: {final_usage_data}\n\n"

# report to FastAPI middleware aggregate usage across all choices
@@ -478,8 +492,10 @@ def request_output_to_completion_response(
choices: list[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0

kv_transfer_params = None
last_final_res = None
for final_res in final_res_batch:
last_final_res = final_res
prompt_token_ids = final_res.prompt_token_ids
assert prompt_token_ids is not None
prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
@@ -548,19 +564,22 @@ def request_output_to_completion_response(
total_tokens=num_prompt_tokens + num_generated_tokens,
)

if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
Note (review comment by the PR author): final_res is the loop variable from the for-loop above. If final_res_batch is an empty list ([]), final_res is never assigned, so referencing it here would raise a NameError.
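To make the note concrete, here is a minimal standalone sketch of the failure mode and of the guard pattern the new code adopts. It is not vLLM code: the empty batch and the printed value are hypothetical, and only the `num_cached_tokens` attribute and the `last_final_res` name are taken from the diff above.

```python
# Hypothetical empty batch: the request produced no outputs at all.
final_res_batch = []

# Old pattern: the loop variable is only bound if the loop body runs at
# least once, so an empty batch leaves `final_res` undefined.
for final_res in final_res_batch:
    pass
# final_res.num_cached_tokens   # would raise NameError on an empty batch

# New pattern (as in the diff): track the last element in a name that is
# always initialized, then guard on it before touching any attribute.
last_final_res = None
for final_res in final_res_batch:
    last_final_res = final_res

cached_tokens = (last_final_res.num_cached_tokens
                 if last_final_res is not None else None)
print(cached_tokens)  # -> None for an empty batch
```

The same empty-batch concern motivates the `if final_res_batch:` guard around `kv_transfer_params` later in this hunk.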

if (self.enable_prompt_tokens_details and last_final_res
and last_final_res.num_cached_tokens):
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=final_res.num_cached_tokens)
cached_tokens=last_final_res.num_cached_tokens)

request_metadata.final_usage_info = usage

if final_res_batch:
kv_transfer_params = final_res_batch[0].kv_transfer_params
return CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
kv_transfer_params=final_res_batch[0].kv_transfer_params)
kv_transfer_params=kv_transfer_params,
)

def _create_completion_logprobs(
self,
@@ -579,8 +598,9 @@ def _create_completion_logprobs(

last_token_len = 0

should_return_as_token_id = return_as_token_id if \
return_as_token_id is not None else self.return_tokens_as_token_ids
should_return_as_token_id = (return_as_token_id
if return_as_token_id is not None else
self.return_tokens_as_token_ids)
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
@@ -612,10 +632,12 @@ def _create_completion_logprobs(
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=should_return_as_token_id):
self._get_decoded_token(
top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=should_return_as_token_id,
):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i