@@ -60,6 +60,7 @@ class RequestFuncInput:
6060 prompt_token_ids : Optional [list ] = None
6161 tokenizer_model : str = None
6262 tokenizer_path : str = None
63+ stream : bool = True
6364
6465
6566@dataclass
@@ -226,6 +227,88 @@ def load_tokenizer(model, actor_tokenizer_path):
226227 return tokenizer
227228
228229
async def handle_non_stream_response(
    response,
    output,
    st,
):
    """Parse a non-streaming chat-completions response and fill ``output``.

    The whole response body is read, decoded as JSON, and the usage
    counters, generated text, reasoning text, token ids, tool calls and
    timing fields are written onto ``output`` in place.

    Args:
        response: Object exposing an awaitable ``text()`` that yields the
            response body as a JSON string.
        output: Mutable result object. ``prompt_len``, ``output_ids``,
            ``tool_calls``, ``generated_text`` and ``reasoning_content``
            are read, so they must already exist on it.
        st: ``time.perf_counter()`` value captured when the request was sent.

    Returns:
        A ``(data, request_id)`` tuple: the decoded JSON body and its ``id``
        field, or the string ``"None"`` when the body carries no ``id``.

    Raises:
        json.JSONDecodeError: If the response body is not valid JSON.
    """
    text = await response.text()

    # Stamp the time before decoding so JSON parsing is not counted as latency.
    timestamp = time.perf_counter()
    data = json.loads(text)

    request_id = data.get("id", "None")

    # A field may be present with an explicit null, so use ``or`` rather than
    # a ``.get`` default, which only covers a missing key.
    usage = data.get("usage") or {}
    output.output_tokens = usage.get("completion_tokens", 0)
    output.prompt_tokens = usage.get("prompt_tokens", 0)

    # Only fill prompt_len from the cached-token count when it is still unset.
    if output.prompt_len == 0:
        prompt_details = usage.get("prompt_tokens_details") or {}
        if prompt_details:
            output.prompt_len = prompt_details.get("cached_tokens", 0)

    choices = data.get("choices") or []
    if choices:
        message = choices[0].get("message") or {}

        output.generated_text = message.get("content", "") or ""
        output.reasoning_content = message.get("reasoning_content", "") or ""

        completion_token_ids = message.get("completion_token_ids") or []
        if completion_token_ids:
            output.output_ids.extend(completion_token_ids)

        # Tool calls: arguments are decoded from their JSON string form.
        for tool_call in message.get("tool_calls") or []:
            func = tool_call.get("function") or {}

            try:
                args = json.loads(func.get("arguments", "{}"))
            except (TypeError, ValueError):
                # Best effort: null, non-string or malformed arguments are
                # recorded as an empty dict instead of failing the request.
                args = {}

            output.tool_calls.append(
                {
                    "id": tool_call.get("id"),
                    "name": func.get("name"),
                    "arguments": args,
                }
            )

    latency = timestamp - st

    # A non-streaming response has no first-token time, so the full request
    # latency is reported for both TTFT fields.
    output.ttft = latency
    output.res_ttft = latency

    output.end_timestamp = timestamp
    output.latency = latency

    # There are no stream chunks; keep arrival_time as an empty list so the
    # result has the same shape the streaming benchmark logic produces.
    output.arrival_time = []

    has_text = output.generated_text.strip() or output.reasoning_content.strip()
    has_tool = bool(output.tool_calls)

    if not has_text and not has_tool:
        output.success = False
        output.error = "No generated text found!"
    else:
        output.success = True

    return data, request_id
310+
311+
229312async def async_request_eb_openai_chat_completions (
230313 request_func_input : RequestFuncInput ,
231314 pbar : Optional [tqdm ] = None ,
@@ -250,14 +333,17 @@ async def async_request_eb_openai_chat_completions(
250333 payload = {
251334 "model" : request_func_input .model ,
252335 "messages" : request_func_input .history_QA ,
253- "stream" : True ,
254- "stream_options" : {
255- "include_usage" : True ,
256- "continuous_usage_stats" : True ,
257- },
336+ "stream" : request_func_input .stream ,
258337 "max_tokens" : request_func_input .output_len ,
259338 "collect_metrics" : request_func_input .pd_metrics ,
260339 }
340+
341+ # 流式模式返回usage
342+ if request_func_input .stream :
343+ payload ["stream_options" ] = {
344+ "include_usage" : True ,
345+ "continuous_usage_stats" : True ,
346+ }
261347 if request_func_input .json_data :
262348 json_data = request_func_input .json_data
263349
@@ -341,126 +427,133 @@ async def async_request_eb_openai_chat_completions(
341427 async with session .post (url = api_url , json = payload , headers = headers , read_bufsize = 10 * 1024 * 1024 ) as response :
342428 data = {}
343429 if response .status == 200 :
344- async for chunk_bytes in response .content :
345- chunk_bytes = chunk_bytes .strip ()
346- if not chunk_bytes :
347- continue
348-
349- chunk = chunk_bytes .decode ("utf-8" ).removeprefix ("data: " )
350- if chunk != "[DONE]" :
351- # print("####chunk:", chunk, type(chunk))
352- timestamp = time .perf_counter ()
353- data = json .loads (chunk )
354- # print("####data:", json.dumps(data, indent=2, ensure_ascii=False))
355-
356- if "metrics" in data :
357- metrics_list .append (data ["metrics" ])
358-
359- if request_id == "None" and "id" in data :
360- request_id = data ["id" ]
361-
362- if choices := data .get ("choices" ):
363- content = choices [0 ]["delta" ].get ("content" )
364- reason_content = choices [0 ]["delta" ].get ("reasoning_content" )
365- tool_calls = choices [0 ]["delta" ].get ("tool_calls" )
366- completion_token_ids = choices [0 ]["delta" ].get ("completion_token_ids" , [])
367- if tool_calls :
368- for tc in tool_calls :
369- idx = tc .get ("index" , 0 )
370-
371- if idx not in tool_call_buffer :
372- tool_call_buffer [idx ] = {
373- "id" : tc .get ("id" ),
374- "name" : "" ,
375- "arguments" : "" ,
376- }
377-
378- func = tc .get ("function" , {})
379-
380- if "name" in func :
381- tool_call_buffer [idx ]["name" ] = func ["name" ]
382-
383- if "arguments" in func :
384- tool_call_buffer [idx ]["arguments" ] += func ["arguments" ]
385-
386- # First token
387- if ttft == 0.0 :
388- ttft = timestamp - st
389- output .ttft = ttft
390- # cached_tokens
391- if data ["usage" ] and data ["usage" ].get ("prompt_tokens_details" , {}):
392- output .prompt_len = (
393- data ["usage" ].get ("prompt_tokens_details" , {}).get ("cached_tokens" , 0 )
394- )
430+ # 默认流式模式
431+ if request_func_input .stream :
432+ async for chunk_bytes in response .content :
433+ chunk_bytes = chunk_bytes .strip ()
434+ if not chunk_bytes :
435+ continue
436+
437+ chunk = chunk_bytes .decode ("utf-8" ).removeprefix ("data: " )
438+ if chunk != "[DONE]" :
439+ # print("####chunk:", chunk, type(chunk))
440+ timestamp = time .perf_counter ()
441+ data = json .loads (chunk )
442+ # print("####data:", json.dumps(data, indent=2, ensure_ascii=False))
443+
444+ if "metrics" in data :
445+ metrics_list .append (data ["metrics" ])
446+
447+ if request_id == "None" and "id" in data :
448+ request_id = data ["id" ]
449+
450+ if choices := data .get ("choices" ):
451+ content = choices [0 ]["delta" ].get ("content" )
452+ reason_content = choices [0 ]["delta" ].get ("reasoning_content" )
453+ tool_calls = choices [0 ]["delta" ].get ("tool_calls" )
454+ completion_token_ids = choices [0 ]["delta" ].get ("completion_token_ids" , [])
455+ if tool_calls :
456+ for tc in tool_calls :
457+ idx = tc .get ("index" , 0 )
458+
459+ if idx not in tool_call_buffer :
460+ tool_call_buffer [idx ] = {
461+ "id" : tc .get ("id" ),
462+ "name" : "" ,
463+ "arguments" : "" ,
464+ }
465+
466+ func = tc .get ("function" , {})
467+
468+ if "name" in func :
469+ tool_call_buffer [idx ]["name" ] = func ["name" ]
470+
471+ if "arguments" in func :
472+ tool_call_buffer [idx ]["arguments" ] += func ["arguments" ]
473+
474+ # First token
475+ if ttft == 0.0 :
476+ ttft = timestamp - st
477+ output .ttft = ttft
478+ # cached_tokens
479+ if data ["usage" ] and data ["usage" ].get ("prompt_tokens_details" , {}):
480+ output .prompt_len = (
481+ data ["usage" ].get ("prompt_tokens_details" , {}).get ("cached_tokens" , 0 )
482+ )
483+ else :
484+ output .prompt_len = 0
485+
486+ # Decoding phase
395487 else :
396- output .prompt_len = 0
397-
398- # Decoding phase
399- else :
400- output .itl .append (timestamp - most_recent_timestamp )
401-
402- # response首token
403- if res_ttft == 0.0 :
404- if content :
405- res_ttft = choices [0 ].get ("arrival_time" , timestamp )
406- output .res_ttft = res_ttft
407- usage = data .get ("usage" ) or {}
408- output .reasoning_tokens = max (usage .get ("completion_tokens" , 0 ) - 1 , 0 )
409-
410- output .generated_text += content or ""
411- output .reasoning_content += reason_content or ""
412- if completion_token_ids :
413- output .output_ids .extend (completion_token_ids )
414- # print(f"####content:{data}")
415- output .arrival_time .append (choices [0 ].get ("arrival_time" , timestamp ))
416- elif usage := data .get ("usage" , {}):
417- output .output_tokens = usage .get ("completion_tokens" , 0 )
418- output .prompt_tokens = usage .get ("prompt_tokens" , 0 )
419- if output .prompt_len == 0 :
420- if data ["usage" ] and data ["usage" ].get ("prompt_tokens_details" , {}):
421- output .prompt_len = (
422- data ["usage" ].get ("prompt_tokens_details" , {}).get ("cached_tokens" , 0 )
423- )
488+ output .itl .append (timestamp - most_recent_timestamp )
424489
425- most_recent_timestamp = timestamp
426- token_timestamps .append (time .time ())
427-
428- # output.generated_text = generated_text
429- # 在流式结束时,记录最后一个 chunk 收到的时间戳
430- output .end_timestamp = most_recent_timestamp
431- # 截断case也记录usage
432- usage = data .get ("usage" , {})
433- if usage :
490+ # response首token
491+ if res_ttft == 0.0 :
492+ if content :
493+ res_ttft = choices [0 ].get ("arrival_time" , timestamp )
494+ output .res_ttft = res_ttft
495+ usage = data .get ("usage" ) or {}
496+ output .reasoning_tokens = max (usage .get ("completion_tokens" , 0 ) - 1 , 0 )
497+
498+ output .generated_text += content or ""
499+ output .reasoning_content += reason_content or ""
500+ if completion_token_ids :
501+ output .output_ids .extend (completion_token_ids )
502+ # print(f"####content:{data}")
503+ output .arrival_time .append (choices [0 ].get ("arrival_time" , timestamp ))
504+ elif usage := data .get ("usage" , {}):
505+ output .output_tokens = usage .get ("completion_tokens" , 0 )
506+ output .prompt_tokens = usage .get ("prompt_tokens" , 0 )
507+ if output .prompt_len == 0 :
508+ if data ["usage" ] and data ["usage" ].get ("prompt_tokens_details" , {}):
509+ output .prompt_len = (
510+ data ["usage" ].get ("prompt_tokens_details" , {}).get ("cached_tokens" , 0 )
511+ )
512+
513+ most_recent_timestamp = timestamp
514+ token_timestamps .append (time .time ())
515+
516+ # output.generated_text = generated_text
517+ # 在流式结束时,记录最后一个 chunk 收到的时间戳
518+ output .end_timestamp = most_recent_timestamp
519+ # 截断case
520+ usage = data .get ("usage" , {})
434521 output .output_tokens = usage .get ("completion_tokens" , 0 )
435522 output .prompt_tokens = usage .get ("prompt_tokens" , 0 )
436523 if output .prompt_len == 0 :
437- prompt_details = usage .get ("prompt_tokens_details" , {})
438- if prompt_details :
439- output .prompt_len = prompt_details .get ("cached_tokens" , 0 )
524+ if data ["usage" ] and data ["usage" ].get ("prompt_tokens_details" , {}):
525+ output .prompt_len = data ["usage" ].get ("prompt_tokens_details" , {}).get ("cached_tokens" , 0 )
440526
441- if tool_call_buffer :
442- for _ , tc in tool_call_buffer .items ():
443- try :
444- args = json .loads (tc ["arguments" ]) if tc ["arguments" ] else {}
445- except :
446- args = {}
527+ if tool_call_buffer :
528+ for _ , tc in tool_call_buffer .items ():
529+ try :
530+ args = json .loads (tc ["arguments" ]) if tc ["arguments" ] else {}
531+ except :
532+ args = {}
447533
448- output .tool_calls .append ({"id" : tc ["id" ], "name" : tc ["name" ], "arguments" : args })
534+ output .tool_calls .append ({"id" : tc ["id" ], "name" : tc ["name" ], "arguments" : args })
449535
450- # 新增metrics统计,计算首token过滤空包
451- output .metrics = metrics_summary (metrics_list , token_timestamps [1 :])
536+ # 新增metrics统计,计算首token过滤空包
537+ output .metrics = metrics_summary (metrics_list , token_timestamps [1 :])
452538
453- has_text = output .generated_text .strip () or output .reasoning_content .strip ()
454- has_tool = getattr (output , "tool_calls" , None )
539+ has_text = output .generated_text .strip () or output .reasoning_content .strip ()
540+ has_tool = getattr (output , "tool_calls" , None )
455541
456- # 兼容思考内容超长截断的情况,此时回复内容为空
457- if not has_text and not has_tool :
458- output .success = False
459- output .reasoning_tokens = output .output_tokens
460- output .error = "No generated text found!"
542+ # 兼容思考内容超长截断的情况,此时回复内容为空
543+ if not has_text and not has_tool :
544+ output .success = False
545+ output .reasoning_tokens = output .output_tokens
546+ output .error = "No generated text found!"
547+ else :
548+ output .success = True
549+ output .latency = most_recent_timestamp - st
461550 else :
462- output .success = True
463- output .latency = most_recent_timestamp - st
551+ # 非流式模式
552+ data , request_id = await handle_non_stream_response (
553+ response = response ,
554+ output = output ,
555+ st = st ,
556+ )
464557 else :
465558 error_text = await response .text ()
466559 print (
0 commit comments