Skip to content

Commit 12c6ae0

Browse files
authored
[benchmark] Refactor async request handling for streaming responses (#7835)
1 parent e3541c2 commit 12c6ae0

1 file changed

Lines changed: 207 additions & 114 deletions

File tree

benchmarks/backend_request_func_swe.py

Lines changed: 207 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class RequestFuncInput:
6060
prompt_token_ids: Optional[list] = None
6161
tokenizer_model: str = None
6262
tokenizer_path: str = None
63+
stream: bool = True
6364

6465

6566
@dataclass
@@ -226,6 +227,88 @@ def load_tokenizer(model, actor_tokenizer_path):
226227
return tokenizer
227228

228229

230+
async def handle_non_stream_response(
231+
response,
232+
output,
233+
st,
234+
):
235+
"""
236+
处理非流式返回
237+
"""
238+
text = await response.text()
239+
240+
timestamp = time.perf_counter()
241+
data = json.loads(text)
242+
# print("data:", data)
243+
244+
request_id = data.get("id", "None")
245+
246+
usage = data.get("usage", {})
247+
248+
output.output_tokens = usage.get("completion_tokens", 0)
249+
output.prompt_tokens = usage.get("prompt_tokens", 0)
250+
251+
if output.prompt_len == 0:
252+
if usage.get("prompt_tokens_details", {}):
253+
output.prompt_len = usage.get("prompt_tokens_details", {}).get("cached_tokens", 0)
254+
255+
choices = data.get("choices", [])
256+
257+
if choices:
258+
message = choices[0].get("message", {})
259+
260+
output.generated_text = message.get("content", "") or ""
261+
output.reasoning_content = message.get("reasoning_content", "") or ""
262+
263+
completion_token_ids = message.get("completion_token_ids", [])
264+
if completion_token_ids:
265+
output.output_ids.extend(completion_token_ids)
266+
267+
# tool calls
268+
tool_calls = message.get("tool_calls") or []
269+
270+
for tc in tool_calls:
271+
func = tc.get("function", {})
272+
273+
try:
274+
args = json.loads(func.get("arguments", "{}"))
275+
except Exception:
276+
args = {}
277+
278+
output.tool_calls.append(
279+
{
280+
"id": tc.get("id"),
281+
"name": func.get("name"),
282+
"arguments": args,
283+
}
284+
)
285+
286+
latency = timestamp - st
287+
288+
# 非流式没有ttft
289+
output.ttft = latency
290+
output.res_ttft = latency
291+
292+
output.end_timestamp = timestamp
293+
output.latency = latency
294+
# 非流式没有stream chunk
295+
# 非流式兼容stream benchmark逻辑
296+
# arrival_time:
297+
output.arrival_time = []
298+
299+
has_text = output.generated_text.strip() or output.reasoning_content.strip()
300+
301+
has_tool = bool(output.tool_calls)
302+
303+
if not has_text and not has_tool:
304+
output.success = False
305+
output.error = "No generated text found!"
306+
else:
307+
output.success = True
308+
309+
return data, request_id
310+
311+
229312
async def async_request_eb_openai_chat_completions(
230313
request_func_input: RequestFuncInput,
231314
pbar: Optional[tqdm] = None,
@@ -250,14 +333,17 @@ async def async_request_eb_openai_chat_completions(
250333
payload = {
251334
"model": request_func_input.model,
252335
"messages": request_func_input.history_QA,
253-
"stream": True,
254-
"stream_options": {
255-
"include_usage": True,
256-
"continuous_usage_stats": True,
257-
},
336+
"stream": request_func_input.stream,
258337
"max_tokens": request_func_input.output_len,
259338
"collect_metrics": request_func_input.pd_metrics,
260339
}
340+
341+
# 流式模式返回usage
342+
if request_func_input.stream:
343+
payload["stream_options"] = {
344+
"include_usage": True,
345+
"continuous_usage_stats": True,
346+
}
261347
if request_func_input.json_data:
262348
json_data = request_func_input.json_data
263349

@@ -341,126 +427,133 @@ async def async_request_eb_openai_chat_completions(
341427
async with session.post(url=api_url, json=payload, headers=headers, read_bufsize=10 * 1024 * 1024) as response:
342428
data = {}
343429
if response.status == 200:
344-
async for chunk_bytes in response.content:
345-
chunk_bytes = chunk_bytes.strip()
346-
if not chunk_bytes:
347-
continue
348-
349-
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
350-
if chunk != "[DONE]":
351-
# print("####chunk:", chunk, type(chunk))
352-
timestamp = time.perf_counter()
353-
data = json.loads(chunk)
354-
# print("####data:", json.dumps(data, indent=2, ensure_ascii=False))
355-
356-
if "metrics" in data:
357-
metrics_list.append(data["metrics"])
358-
359-
if request_id == "None" and "id" in data:
360-
request_id = data["id"]
361-
362-
if choices := data.get("choices"):
363-
content = choices[0]["delta"].get("content")
364-
reason_content = choices[0]["delta"].get("reasoning_content")
365-
tool_calls = choices[0]["delta"].get("tool_calls")
366-
completion_token_ids = choices[0]["delta"].get("completion_token_ids", [])
367-
if tool_calls:
368-
for tc in tool_calls:
369-
idx = tc.get("index", 0)
370-
371-
if idx not in tool_call_buffer:
372-
tool_call_buffer[idx] = {
373-
"id": tc.get("id"),
374-
"name": "",
375-
"arguments": "",
376-
}
377-
378-
func = tc.get("function", {})
379-
380-
if "name" in func:
381-
tool_call_buffer[idx]["name"] = func["name"]
382-
383-
if "arguments" in func:
384-
tool_call_buffer[idx]["arguments"] += func["arguments"]
385-
386-
# First token
387-
if ttft == 0.0:
388-
ttft = timestamp - st
389-
output.ttft = ttft
390-
# cached_tokens
391-
if data["usage"] and data["usage"].get("prompt_tokens_details", {}):
392-
output.prompt_len = (
393-
data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
394-
)
430+
# 默认流式模式
431+
if request_func_input.stream:
432+
async for chunk_bytes in response.content:
433+
chunk_bytes = chunk_bytes.strip()
434+
if not chunk_bytes:
435+
continue
436+
437+
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
438+
if chunk != "[DONE]":
439+
# print("####chunk:", chunk, type(chunk))
440+
timestamp = time.perf_counter()
441+
data = json.loads(chunk)
442+
# print("####data:", json.dumps(data, indent=2, ensure_ascii=False))
443+
444+
if "metrics" in data:
445+
metrics_list.append(data["metrics"])
446+
447+
if request_id == "None" and "id" in data:
448+
request_id = data["id"]
449+
450+
if choices := data.get("choices"):
451+
content = choices[0]["delta"].get("content")
452+
reason_content = choices[0]["delta"].get("reasoning_content")
453+
tool_calls = choices[0]["delta"].get("tool_calls")
454+
completion_token_ids = choices[0]["delta"].get("completion_token_ids", [])
455+
if tool_calls:
456+
for tc in tool_calls:
457+
idx = tc.get("index", 0)
458+
459+
if idx not in tool_call_buffer:
460+
tool_call_buffer[idx] = {
461+
"id": tc.get("id"),
462+
"name": "",
463+
"arguments": "",
464+
}
465+
466+
func = tc.get("function", {})
467+
468+
if "name" in func:
469+
tool_call_buffer[idx]["name"] = func["name"]
470+
471+
if "arguments" in func:
472+
tool_call_buffer[idx]["arguments"] += func["arguments"]
473+
474+
# First token
475+
if ttft == 0.0:
476+
ttft = timestamp - st
477+
output.ttft = ttft
478+
# cached_tokens
479+
if data["usage"] and data["usage"].get("prompt_tokens_details", {}):
480+
output.prompt_len = (
481+
data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
482+
)
483+
else:
484+
output.prompt_len = 0
485+
486+
# Decoding phase
395487
else:
396-
output.prompt_len = 0
397-
398-
# Decoding phase
399-
else:
400-
output.itl.append(timestamp - most_recent_timestamp)
401-
402-
# response首token
403-
if res_ttft == 0.0:
404-
if content:
405-
res_ttft = choices[0].get("arrival_time", timestamp)
406-
output.res_ttft = res_ttft
407-
usage = data.get("usage") or {}
408-
output.reasoning_tokens = max(usage.get("completion_tokens", 0) - 1, 0)
409-
410-
output.generated_text += content or ""
411-
output.reasoning_content += reason_content or ""
412-
if completion_token_ids:
413-
output.output_ids.extend(completion_token_ids)
414-
# print(f"####content:{data}")
415-
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
416-
elif usage := data.get("usage", {}):
417-
output.output_tokens = usage.get("completion_tokens", 0)
418-
output.prompt_tokens = usage.get("prompt_tokens", 0)
419-
if output.prompt_len == 0:
420-
if data["usage"] and data["usage"].get("prompt_tokens_details", {}):
421-
output.prompt_len = (
422-
data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
423-
)
488+
output.itl.append(timestamp - most_recent_timestamp)
424489

425-
most_recent_timestamp = timestamp
426-
token_timestamps.append(time.time())
427-
428-
# output.generated_text = generated_text
429-
# 在流式结束时,记录最后一个 chunk 收到的时间戳
430-
output.end_timestamp = most_recent_timestamp
431-
# 截断case也记录usage
432-
usage = data.get("usage", {})
433-
if usage:
490+
# response首token
491+
if res_ttft == 0.0:
492+
if content:
493+
res_ttft = choices[0].get("arrival_time", timestamp)
494+
output.res_ttft = res_ttft
495+
usage = data.get("usage") or {}
496+
output.reasoning_tokens = max(usage.get("completion_tokens", 0) - 1, 0)
497+
498+
output.generated_text += content or ""
499+
output.reasoning_content += reason_content or ""
500+
if completion_token_ids:
501+
output.output_ids.extend(completion_token_ids)
502+
# print(f"####content:{data}")
503+
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
504+
elif usage := data.get("usage", {}):
505+
output.output_tokens = usage.get("completion_tokens", 0)
506+
output.prompt_tokens = usage.get("prompt_tokens", 0)
507+
if output.prompt_len == 0:
508+
if data["usage"] and data["usage"].get("prompt_tokens_details", {}):
509+
output.prompt_len = (
510+
data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
511+
)
512+
513+
most_recent_timestamp = timestamp
514+
token_timestamps.append(time.time())
515+
516+
# output.generated_text = generated_text
517+
# 在流式结束时,记录最后一个 chunk 收到的时间戳
518+
output.end_timestamp = most_recent_timestamp
519+
# 截断case
520+
usage = data.get("usage", {})
434521
output.output_tokens = usage.get("completion_tokens", 0)
435522
output.prompt_tokens = usage.get("prompt_tokens", 0)
436523
if output.prompt_len == 0:
437-
prompt_details = usage.get("prompt_tokens_details", {})
438-
if prompt_details:
439-
output.prompt_len = prompt_details.get("cached_tokens", 0)
524+
if data["usage"] and data["usage"].get("prompt_tokens_details", {}):
525+
output.prompt_len = data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
440526

441-
if tool_call_buffer:
442-
for _, tc in tool_call_buffer.items():
443-
try:
444-
args = json.loads(tc["arguments"]) if tc["arguments"] else {}
445-
except:
446-
args = {}
527+
if tool_call_buffer:
528+
for _, tc in tool_call_buffer.items():
529+
try:
530+
args = json.loads(tc["arguments"]) if tc["arguments"] else {}
531+
except:
532+
args = {}
447533

448-
output.tool_calls.append({"id": tc["id"], "name": tc["name"], "arguments": args})
534+
output.tool_calls.append({"id": tc["id"], "name": tc["name"], "arguments": args})
449535

450-
# 新增metrics统计,计算首token过滤空包
451-
output.metrics = metrics_summary(metrics_list, token_timestamps[1:])
536+
# 新增metrics统计,计算首token过滤空包
537+
output.metrics = metrics_summary(metrics_list, token_timestamps[1:])
452538

453-
has_text = output.generated_text.strip() or output.reasoning_content.strip()
454-
has_tool = getattr(output, "tool_calls", None)
539+
has_text = output.generated_text.strip() or output.reasoning_content.strip()
540+
has_tool = getattr(output, "tool_calls", None)
455541

456-
# 兼容思考内容超长截断的情况,此时回复内容为空
457-
if not has_text and not has_tool:
458-
output.success = False
459-
output.reasoning_tokens = output.output_tokens
460-
output.error = "No generated text found!"
542+
# 兼容思考内容超长截断的情况,此时回复内容为空
543+
if not has_text and not has_tool:
544+
output.success = False
545+
output.reasoning_tokens = output.output_tokens
546+
output.error = "No generated text found!"
547+
else:
548+
output.success = True
549+
output.latency = most_recent_timestamp - st
461550
else:
462-
output.success = True
463-
output.latency = most_recent_timestamp - st
551+
# 非流式模式
552+
data, request_id = await handle_non_stream_response(
553+
response=response,
554+
output=output,
555+
st=st,
556+
)
464557
else:
465558
error_text = await response.text()
466559
print(

0 commit comments

Comments
 (0)