14 changes: 12 additions & 2 deletions tests/entrypoints/openai/test_chat.py
@@ -433,18 +433,28 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
model=model_name,
messages=messages,
max_tokens=10,
extra_body=dict(min_tokens=10),
temperature=0.0,
stream=True,
stream_options={
"include_usage": True,
"continuous_usage_stats": True
"continuous_usage_stats": True,
},
)
last_completion_tokens = 0
async for chunk in stream:
assert chunk.usage.prompt_tokens >= 0
assert chunk.usage.completion_tokens >= 0
assert last_completion_tokens == 0 or \
chunk.usage.completion_tokens > last_completion_tokens or \
(
not chunk.choices and
chunk.usage.completion_tokens == last_completion_tokens
)
assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
chunk.usage.completion_tokens)
last_completion_tokens = chunk.usage.completion_tokens

assert last_completion_tokens == 10


# NOTE: Not sure why, but when I place this after `test_guided_regex_chat`
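The updated test pins the output to exactly 10 completion tokens (max_tokens=10 together with min_tokens=10) and checks that, with continuous_usage_stats enabled, completion_tokens never decreases across chunks and only stays flat on the final, choices-less usage chunk. Below is a minimal client-side sketch of the same behaviour; the base_url, api_key, and model name are placeholders and not part of this PR.

import asyncio

import openai


async def main():
    # Placeholder endpoint and model: point these at a running vLLM
    # OpenAI-compatible server before trying this.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    stream = await client.chat.completions.create(
        model="your-model-name",
        messages=[{"role": "user", "content": "Say hello"}],
        max_tokens=10,
        extra_body=dict(min_tokens=10),  # vLLM-specific parameter
        temperature=0.0,
        stream=True,
        stream_options={
            "include_usage": True,
            "continuous_usage_stats": True,
        },
    )
    last = 0
    async for chunk in stream:
        # With continuous_usage_stats, every chunk carries a usage block,
        # and completion_tokens grows monotonically (it stays flat only on
        # the final usage-only chunk, which has no choices).
        assert chunk.usage is not None
        assert chunk.usage.completion_tokens >= last
        last = chunk.usage.completion_tokens
    print("final completion_tokens:", last)


asyncio.run(main())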
115 changes: 41 additions & 74 deletions vllm/entrypoints/openai/serving_chat.py
@@ -330,6 +330,14 @@ async def chat_completion_stream_generator(
yield "data: [DONE]\n\n"
return

stream_options = request.stream_options
if stream_options:
include_usage = stream_options.include_usage
include_continuous_usage = include_usage and \
stream_options.continuous_usage_stats
else:
include_usage, include_continuous_usage = False, False

try:
async for res in result_generator:
if res.prompt_token_ids is not None:
@@ -348,7 +356,6 @@ async def chat_completion_stream_generator(
# NOTE num_choices defaults to 1 so this usually executes
# once per request
for i in range(num_choices):
tool_parser = tool_parsers[i]
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
@@ -364,19 +371,12 @@
choices=[choice_data],
model=model_name)

# if usage should be included
if (request.stream_options
and request.stream_options.include_usage):
# if continuous usage stats are requested, add it
if request.stream_options.continuous_usage_stats:
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
chunk.usage = usage
# otherwise don't
else:
chunk.usage = None
# if continuous usage stats are requested, add it
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)

data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
@@ -404,17 +404,11 @@ async def chat_completion_stream_generator(
created=created_time,
choices=[choice_data],
model=model_name)
if (request.stream_options and
request.stream_options.include_usage):
if (request.stream_options.
continuous_usage_stats):
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
chunk.usage = usage
else:
chunk.usage = None
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)

data = chunk.model_dump_json(
exclude_unset=True)
@@ -494,36 +488,11 @@ async def chat_completion_stream_generator(

if output.finish_reason is None:
# Send token-by-token response for each request.n

choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
logprobs=logprobs,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)

# handle usage stats if requested & if continuous
if (request.stream_options
and request.stream_options.include_usage):
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
else:
chunk.usage = None

data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"

# if the model is finished generating
else:
@@ -573,34 +542,32 @@ async def chat_completion_stream_generator(
finish_reason=output.finish_reason
if not auto_tools_called else "tool_calls",
stop_reason=output.stop_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"

finish_reason_sent[i] = True

chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)

# handle usage stats if requested & if continuous
if include_continuous_usage:
completion_tokens = previous_num_tokens[i]
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)

data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"

# once the final token is handled, if stream_options.include_usage
# is sent, send the usage
if (request.stream_options
and request.stream_options.include_usage):
completion_tokens = previous_num_tokens[i]
if include_usage:
completion_tokens = sum(previous_num_tokens)
final_usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
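Net effect of the serving_chat.py changes: the include_usage / include_continuous_usage decision is resolved once per request instead of being re-derived inside every branch, the duplicated ChatCompletionStreamResponse construction in the finished and not-finished branches is collapsed into a single chunk build, and the final usage chunk now sums previous_num_tokens across all choices. A simplified, self-contained sketch of the hoisted logic is below; the classes are stand-ins for vLLM's protocol models, not the actual serving_chat.py code.

from dataclasses import dataclass
from typing import Optional


@dataclass
class UsageInfo:  # stand-in for vLLM's UsageInfo protocol model
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


@dataclass
class StreamOptions:  # stand-in for the request's stream_options field
    include_usage: bool = False
    continuous_usage_stats: bool = False


def resolve_usage_flags(stream_options: Optional[StreamOptions]):
    # Hoisted once per request, mirroring the block added at the top of
    # chat_completion_stream_generator.
    if stream_options:
        include_usage = stream_options.include_usage
        include_continuous_usage = (include_usage
                                    and stream_options.continuous_usage_stats)
    else:
        include_usage, include_continuous_usage = False, False
    return include_usage, include_continuous_usage


def maybe_usage(include_continuous_usage: bool, num_prompt_tokens: int,
                completion_tokens: int) -> Optional[UsageInfo]:
    # Per-chunk usage is attached only when continuous stats were requested;
    # otherwise chunk.usage stays unset and model_dump_json(exclude_unset=True)
    # drops it, so the explicit `chunk.usage = None` branches go away.
    if not include_continuous_usage:
        return None
    return UsageInfo(prompt_tokens=num_prompt_tokens,
                     completion_tokens=completion_tokens,
                     total_tokens=num_prompt_tokens + completion_tokens)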