@@ -60,20 +60,25 @@ def __init__(
         enable_prompt_tokens_details: bool = False,
         enable_force_include_usage: bool = False,
     ):
-        super().__init__(engine_client=engine_client,
-                         model_config=model_config,
-                         models=models,
-                         request_logger=request_logger,
-                         return_tokens_as_token_ids=return_tokens_as_token_ids,
-                         enable_force_include_usage=enable_force_include_usage)
+        super().__init__(
+            engine_client=engine_client,
+            model_config=model_config,
+            models=models,
+            request_logger=request_logger,
+            return_tokens_as_token_ids=return_tokens_as_token_ids,
+            enable_force_include_usage=enable_force_include_usage,
+        )
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
         if self.default_sampling_params:
             source = self.model_config.generation_config
             source = "model" if source == "auto" else source
-            logger.info("Using default completion sampling params from %s: %s",
-                        source, self.default_sampling_params)
+            logger.info(
+                "Using default completion sampling params from %s: %s",
+                source,
+                self.default_sampling_params,
+            )

     async def create_completion(
         self,
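Note: this hunk only reflows the constructor, but the behavior it formats is worth spelling out: get_diff_sampling_param() collects default sampling parameters from the model's generation config (the hunk's log line reports their source), and those act as server-side defaults while explicit request values still win. A rough sketch of that precedence with plain dicts; merge_sampling_defaults is an illustrative helper, not vLLM's API:

from typing import Any, Optional


def merge_sampling_defaults(
        request_params: dict[str, Any],
        default_sampling_params: Optional[dict[str, Any]]) -> dict[str, Any]:
    # Server-side defaults (e.g. from generation_config) fill in anything the
    # request left unset; values supplied by the client take precedence.
    merged = dict(default_sampling_params or {})
    merged.update({k: v for k, v in request_params.items() if v is not None})
    return merged


defaults = {"temperature": 0.6, "top_p": 0.9}  # assumed generation_config values
request = {"temperature": None, "max_tokens": 32}
assert merge_sampling_defaults(request, defaults) == {
    "temperature": 0.6,
    "top_p": 0.9,
    "max_tokens": 32,
}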
@@ -172,23 +177,28 @@ async def create_completion(
                     max_model_len=self.max_model_len,
                     request=request,
                     input_length=input_length,
-                    default_sampling_params=self.default_sampling_params)
+                    default_sampling_params=self.default_sampling_params,
+                )

                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
                         max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
-                        max_tokens, self.model_config.logits_processor_pattern,
-                        self.default_sampling_params)
+                        max_tokens,
+                        self.model_config.logits_processor_pattern,
+                        self.default_sampling_params,
+                    )

                 request_id_item = f"{request_id}-{i}"

-                self._log_inputs(request_id_item,
-                                 request_prompts[i],
-                                 params=sampling_params,
-                                 lora_request=lora_request,
-                                 prompt_adapter_request=prompt_adapter_request)
+                self._log_inputs(
+                    request_id_item,
+                    request_prompts[i],
+                    params=sampling_params,
+                    lora_request=lora_request,
+                    prompt_adapter_request=prompt_adapter_request,
+                )

                 trace_headers = (None if raw_request is None else await
                                  self._get_trace_headers(raw_request.headers))
@@ -245,7 +255,8 @@ async def create_completion(
                 num_prompts=num_prompts,
                 tokenizer=tokenizer,
                 request_metadata=request_metadata,
-                enable_force_include_usage=self.enable_force_include_usage)
+                enable_force_include_usage=self.enable_force_include_usage,
+            )

         # Non-streaming response
         final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
@@ -321,10 +332,10 @@ async def completion_stream_generator(

         stream_options = request.stream_options
         if stream_options:
-            include_usage = stream_options.include_usage or \
-                enable_force_include_usage
-            include_continuous_usage = include_usage and \
-                stream_options.continuous_usage_stats
+            include_usage = (stream_options.include_usage
+                             or enable_force_include_usage)
+            include_continuous_usage = (include_usage and
+                                        stream_options.continuous_usage_stats)
         else:
             include_usage, include_continuous_usage = False, False

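The reflowed block keeps the original flag resolution: usage chunks are forced on when the server sets enable_force_include_usage, per-chunk usage additionally requires continuous_usage_stats, and a request without stream_options gets neither. A minimal sketch of that logic in isolation, with a hypothetical StreamOptionsStub standing in for the request's stream_options payload:

from dataclasses import dataclass
from typing import Optional


@dataclass
class StreamOptionsStub:
    # Hypothetical stand-in for request.stream_options.
    include_usage: bool = False
    continuous_usage_stats: bool = False


def resolve_usage_flags(
        stream_options: Optional[StreamOptionsStub],
        enable_force_include_usage: bool) -> tuple[bool, bool]:
    # Mirrors the branch above: no stream_options means no usage chunks at all.
    if stream_options:
        include_usage = (stream_options.include_usage
                         or enable_force_include_usage)
        include_continuous_usage = (include_usage
                                    and stream_options.continuous_usage_stats)
    else:
        include_usage, include_continuous_usage = False, False
    return include_usage, include_continuous_usage


assert resolve_usage_flags(None, True) == (False, False)
assert resolve_usage_flags(StreamOptionsStub(), True) == (True, False)
assert resolve_usage_flags(
    StreamOptionsStub(include_usage=True, continuous_usage_stats=True),
    False) == (True, True)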
@@ -370,7 +381,8 @@ async def completion_stream_generator(
                             # echo the prompt and first token
                             delta_text = prompt_text + output.text
                             delta_token_ids = [
-                                *prompt_token_ids, *output.token_ids
+                                *prompt_token_ids,
+                                *output.token_ids,
                             ]
                             out_logprobs = [
                                 *(prompt_logprobs or []),
@@ -383,8 +395,8 @@ async def completion_stream_generator(
                         delta_token_ids = output.token_ids
                         out_logprobs = output.logprobs

-                    if not delta_text and not delta_token_ids \
-                            and not previous_num_tokens[i]:
+                    if (not delta_text and not delta_token_ids
+                            and not previous_num_tokens[i]):
                         # Chunked prefill case, don't return empty chunks
                         continue

@@ -420,7 +432,8 @@ async def completion_stream_generator(
                                 finish_reason=finish_reason,
                                 stop_reason=stop_reason,
                             )
-                        ])
+                        ],
+                    )
                     if include_continuous_usage:
                         prompt_tokens = num_prompt_tokens[prompt_idx]
                         completion_tokens = previous_num_tokens[i]
@@ -438,7 +451,8 @@ async def completion_stream_generator(
             final_usage_info = UsageInfo(
                 prompt_tokens=total_prompt_tokens,
                 completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens)
+                total_tokens=total_prompt_tokens + total_completion_tokens,
+            )

             if self.enable_prompt_tokens_details and num_cached_tokens:
                 final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
@@ -452,8 +466,8 @@ async def completion_stream_generator(
                     choices=[],
                     usage=final_usage_info,
                 )
-                final_usage_data = (final_usage_chunk.model_dump_json(
-                    exclude_unset=False, exclude_none=True))
+                final_usage_data = final_usage_chunk.model_dump_json(
+                    exclude_unset=False, exclude_none=True)
                 yield f"data: {final_usage_data}\n\n"

             # report to FastAPI middleware aggregate usage across all choices
@@ -478,8 +492,10 @@ def request_output_to_completion_response(
         choices: list[CompletionResponseChoice] = []
         num_prompt_tokens = 0
         num_generated_tokens = 0
-
+        kv_transfer_params = None
+        last_final_res = None
         for final_res in final_res_batch:
+            last_final_res = final_res
             prompt_token_ids = final_res.prompt_token_ids
             assert prompt_token_ids is not None
             prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
@@ -548,19 +564,22 @@ def request_output_to_completion_response(
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )

-        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
+        if (self.enable_prompt_tokens_details and last_final_res
+                and last_final_res.num_cached_tokens):
             usage.prompt_tokens_details = PromptTokenUsageInfo(
-                cached_tokens=final_res.num_cached_tokens)
+                cached_tokens=last_final_res.num_cached_tokens)

         request_metadata.final_usage_info = usage
-
+        if final_res_batch:
+            kv_transfer_params = final_res_batch[0].kv_transfer_params
         return CompletionResponse(
             id=request_id,
             created=created_time,
             model=model_name,
             choices=choices,
             usage=usage,
-            kv_transfer_params=final_res_batch[0].kv_transfer_params)
+            kv_transfer_params=kv_transfer_params,
+        )

     def _create_completion_logprobs(
         self,
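The substantive change in the last two hunks is defensive: kv_transfer_params and the cached-token details are now derived from the batch only when it is non-empty, instead of indexing final_res_batch[0] and reading final_res after the loop unconditionally. A minimal sketch of the guarded pattern, using a hypothetical FakeRequestOutput rather than vLLM's RequestOutput:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeRequestOutput:
    # Hypothetical stand-in for the two fields the fix touches.
    num_cached_tokens: int = 0
    kv_transfer_params: Optional[dict[str, Any]] = None


def summarize(batch: list[FakeRequestOutput]):
    kv_transfer_params = None
    last_final_res = None
    for final_res in batch:
        last_final_res = final_res  # remember the last output for post-loop reads
    cached_tokens = (last_final_res.num_cached_tokens
                     if last_final_res and last_final_res.num_cached_tokens
                     else None)
    if batch:
        kv_transfer_params = batch[0].kv_transfer_params
    return cached_tokens, kv_transfer_params


# Safe on an empty batch, where batch[0] would raise IndexError and the
# post-loop variable would be unbound.
assert summarize([]) == (None, None)
assert summarize([FakeRequestOutput(num_cached_tokens=4,
                                    kv_transfer_params={"demo": True})
                  ]) == (4, {"demo": True})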
@@ -579,8 +598,9 @@ def _create_completion_logprobs(

         last_token_len = 0

-        should_return_as_token_id = return_as_token_id if \
-            return_as_token_id is not None else self.return_tokens_as_token_ids
+        should_return_as_token_id = (return_as_token_id
+                                     if return_as_token_id is not None else
+                                     self.return_tokens_as_token_ids)
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
@@ -612,10 +632,12 @@ def _create_completion_logprobs(
                 out_top_logprobs.append({
                     # Convert float("-inf") to the
                     # JSON-serializable float that OpenAI uses
-                    self._get_decoded_token(top_lp[1],
-                                            top_lp[0],
-                                            tokenizer,
-                                            return_as_token_id=should_return_as_token_id):
+                    self._get_decoded_token(
+                        top_lp[1],
+                        top_lp[0],
+                        tokenizer,
+                        return_as_token_id=should_return_as_token_id,
+                    ):
                     max(top_lp[1].logprob, -9999.0)
                     for i, top_lp in enumerate(step_top_logprobs.items())
                     if num_output_top_logprobs >= i