5151class ReqState :
5252 """Store the state a request."""
5353
54- out_list : List
54+ out_list : List [ Dict [ Any , Any ]]
5555 finished : bool
5656 event : asyncio .Event
57- obj : Any
57+ obj : Union [ GenerateReqInput , EmbeddingReqInput ]
5858
5959 # For metrics
6060 created_time : float
61- first_token_time : Optional [float ] = None
61+ finished_time : float = 0.0
62+ first_token_time : float = 0.0
63+ last_time : float = 0.0
64+ last_completion_tokens : int = 1
6265
6366 # For streaming output
6467 last_output_offset : int = 0
68+ # For incremental state update.
69+ text : str = ""
70+ output_ids : List [int ] = dataclasses .field (default_factory = list )
71+ input_token_logprobs_val : List [float ] = dataclasses .field (default_factory = list )
72+ input_token_logprobs_idx : List [int ] = dataclasses .field (default_factory = list )
73+ output_token_logprobs_val : List [float ] = dataclasses .field (default_factory = list )
74+ output_token_logprobs_idx : List [int ] = dataclasses .field (default_factory = list )
75+ input_top_logprobs_val : List [List [float ]] = dataclasses .field (default_factory = list )
76+ input_top_logprobs_idx : List [List [int ]] = dataclasses .field (default_factory = list )
77+ output_top_logprobs_val : List [List [float ]] = dataclasses .field (default_factory = list )
78+ output_top_logprobs_idx : List [List [int ]] = dataclasses .field (default_factory = list )
79+ input_token_ids_logprobs_val : List = dataclasses .field (default_factory = list )
80+ input_token_ids_logprobs_idx : List = dataclasses .field (default_factory = list )
81+ output_token_ids_logprobs_val : List = dataclasses .field (default_factory = list )
82+ output_token_ids_logprobs_idx : List = dataclasses .field (default_factory = list )
6583
6684
6785class TokenizerManager :
@@ -77,10 +95,16 @@ def __init__(
7795 # Init inter-process communication
7896 context = zmq .asyncio .Context (2 )
7997 self .recv_from_detokenizer = get_zmq_socket (
80- context , zmq .PULL , server_args .tokenizer_ipc_name
98+ context ,
99+ zmq .PULL ,
100+ server_args .tokenizer_ipc_name ,
101+ True ,
81102 )
82103 self .send_to_scheduler = get_zmq_socket (
83- context , zmq .PUSH , server_args .scheduler_input_ipc_name
104+ context ,
105+ zmq .PUSH ,
106+ server_args .scheduler_input_ipc_name ,
107+ True ,
84108 )
85109
86110 # Read model args
@@ -503,6 +527,7 @@ async def handle_loop(self):
503527 recv_obj : Union [
504528 BatchStrOut , BatchEmbeddingOut , BatchTokenIDOut , UpdateWeightReqOutput
505529 ] = await self .recv_from_detokenizer .recv_pyobj ()
530+
506531 if isinstance (recv_obj , UpdateWeightReqOutput ):
507532 if self .server_args .dp_size == 1 :
508533 self .model_update_result .set_result (recv_obj )
@@ -512,6 +537,7 @@ async def handle_loop(self):
512537 if len (self .model_update_tmp ) == self .server_args .dp_size :
513538 self .model_update_result .set_result (self .model_update_tmp )
514539 continue
540+
515541 elif isinstance (recv_obj , GetMemPoolSizeReqOutput ):
516542 if self .server_args .dp_size == 1 :
517543 self .mem_pool_size .set_result (recv_obj )
@@ -525,6 +551,7 @@ async def handle_loop(self):
525551 assert isinstance (
526552 recv_obj , (BatchStrOut , BatchEmbeddingOut , BatchTokenIDOut )
527553 ), f"Unexpected obj received: { type (recv_obj )} "
554+
528555 for i , rid in enumerate (recv_obj .rids ):
529556 state = self .rid_to_state .get (rid , None )
530557 if state is None :
@@ -537,21 +564,21 @@ async def handle_loop(self):
537564 if getattr (state .obj , "return_logprob" , False ):
538565 self .convert_logprob_style (
539566 meta_info ,
567+ state ,
540568 state .obj .top_logprobs_num ,
541569 state .obj .token_ids_logprob ,
542- state .obj .return_text_in_logprobs ,
570+ state .obj .return_text_in_logprobs
571+ and not self .server_args .skip_tokenizer_init ,
543572 recv_obj ,
544573 i ,
545574 )
546-
547575 if not isinstance (recv_obj , BatchEmbeddingOut ):
548576 meta_info .update (
549577 {
550578 "completion_tokens" : recv_obj .completion_tokens [i ],
551579 "cached_tokens" : recv_obj .cached_tokens [i ],
552580 }
553581 )
554-
555582 if isinstance (recv_obj , BatchStrOut ):
556583 out_dict = {
557584 "text" : recv_obj .output_strs [i ],
@@ -569,80 +596,132 @@ async def handle_loop(self):
569596 "embedding" : recv_obj .embeddings [i ],
570597 "meta_info" : meta_info ,
571598 }
599+
572600 state .out_list .append (out_dict )
573601 state .finished = recv_obj .finished_reasons [i ] is not None
574602 state .event .set ()
575603
576604 def convert_logprob_style (
577605 self ,
578606 meta_info : dict ,
607+ state : ReqState ,
579608 top_logprobs_num : int ,
580609 token_ids_logprob : List [int ],
581610 return_text_in_logprobs : bool ,
582611 recv_obj : BatchStrOut ,
583612 recv_obj_index : int ,
584- ):
613+ ) -> None :
614+ if len (recv_obj .input_token_logprobs_val ) > 0 :
615+ state .input_token_logprobs_val .extend (
616+ recv_obj .input_token_logprobs_val [recv_obj_index ]
617+ )
618+ state .input_token_logprobs_idx .extend (
619+ recv_obj .input_token_logprobs_idx [recv_obj_index ]
620+ )
621+ state .output_token_logprobs_val .extend (
622+ recv_obj .output_token_logprobs_val [recv_obj_index ]
623+ )
624+ state .output_token_logprobs_idx .extend (
625+ recv_obj .output_token_logprobs_idx [recv_obj_index ]
626+ )
585627 meta_info ["input_token_logprobs" ] = self .detokenize_logprob_tokens (
586- recv_obj .input_token_logprobs_val [ recv_obj_index ] ,
587- recv_obj .input_token_logprobs_idx [ recv_obj_index ] ,
628+ state .input_token_logprobs_val ,
629+ state .input_token_logprobs_idx ,
588630 return_text_in_logprobs ,
589631 )
590632 meta_info ["output_token_logprobs" ] = self .detokenize_logprob_tokens (
591- recv_obj .output_token_logprobs_val [ recv_obj_index ] ,
592- recv_obj .output_token_logprobs_idx [ recv_obj_index ] ,
633+ state .output_token_logprobs_val ,
634+ state .output_token_logprobs_idx ,
593635 return_text_in_logprobs ,
594636 )
595637
596638 if top_logprobs_num > 0 :
639+ if len (recv_obj .input_top_logprobs_val ) > 0 :
640+ state .input_top_logprobs_val .extend (
641+ recv_obj .input_top_logprobs_val [recv_obj_index ]
642+ )
643+ state .input_top_logprobs_idx .extend (
644+ recv_obj .input_top_logprobs_idx [recv_obj_index ]
645+ )
646+ state .output_top_logprobs_val .extend (
647+ recv_obj .output_top_logprobs_val [recv_obj_index ]
648+ )
649+ state .output_top_logprobs_idx .extend (
650+ recv_obj .output_top_logprobs_idx [recv_obj_index ]
651+ )
597652 meta_info ["input_top_logprobs" ] = self .detokenize_top_logprobs_tokens (
598- recv_obj .input_top_logprobs_val [ recv_obj_index ] ,
599- recv_obj .input_top_logprobs_idx [ recv_obj_index ] ,
653+ state .input_top_logprobs_val ,
654+ state .input_top_logprobs_idx ,
600655 return_text_in_logprobs ,
601656 )
602657 meta_info ["output_top_logprobs" ] = self .detokenize_top_logprobs_tokens (
603- recv_obj .output_top_logprobs_val [ recv_obj_index ] ,
604- recv_obj .output_top_logprobs_idx [ recv_obj_index ] ,
658+ state .output_top_logprobs_val ,
659+ state .output_top_logprobs_idx ,
605660 return_text_in_logprobs ,
606661 )
607662
608663 if token_ids_logprob is not None :
664+ if len (recv_obj .input_token_ids_logprobs_val ) > 0 :
665+ state .input_token_ids_logprobs_val .extend (
666+ recv_obj .input_token_ids_logprobs_val [recv_obj_index ]
667+ )
668+ state .input_token_ids_logprobs_idx .extend (
669+ recv_obj .input_token_ids_logprobs_idx [recv_obj_index ]
670+ )
671+ state .output_token_ids_logprobs_val .extend (
672+ recv_obj .output_token_ids_logprobs_val [recv_obj_index ]
673+ )
674+ state .output_token_ids_logprobs_idx .extend (
675+ recv_obj .output_token_ids_logprobs_idx [recv_obj_index ]
676+ )
609677 meta_info ["input_token_ids_logprobs" ] = self .detokenize_top_logprobs_tokens (
610- recv_obj .input_token_ids_logprobs_val [ recv_obj_index ] ,
611- recv_obj .input_token_ids_logprobs_idx [ recv_obj_index ] ,
678+ state .input_token_ids_logprobs_val ,
679+ state .input_token_ids_logprobs_idx ,
612680 return_text_in_logprobs ,
613681 )
614682 meta_info [
615683 "output_token_ids_logprobs"
616684 ] = self .detokenize_top_logprobs_tokens (
617- recv_obj .output_token_ids_logprobs_val [ recv_obj_index ] ,
618- recv_obj .output_token_ids_logprobs_idx [ recv_obj_index ] ,
685+ state .output_token_ids_logprobs_val ,
686+ state .output_token_ids_logprobs_idx ,
619687 return_text_in_logprobs ,
620688 )
621689
622690 def detokenize_logprob_tokens (
623- self , token_logprobs : List [Tuple [float , int ]], decode_to_text : bool
691+ self ,
692+ token_logprobs_val : List [float ],
693+ token_logprobs_idx : List [int ],
694+ decode_to_text : bool ,
624695 ):
625- # TODO(lianmin): This should run on DetokenizerManager
626696 if not decode_to_text :
627- return [(logprob , token_id , None ) for logprob , token_id in token_logprobs ]
628-
629- assert self .tokenizer is not None
630- token_ids = [tid for _ , tid in token_logprobs ]
631- token_texts = self .tokenizer .batch_decode (token_ids )
632- return [
633- (logprob , token_id , token_text )
634- for (logprob , token_id ), token_text in zip (token_logprobs , token_texts )
635- ]
697+ return [
698+ (logprob , token_id , None )
699+ for logprob , token_id in zip (token_logprobs_val , token_logprobs_idx )
700+ ]
701+ else :
702+ assert self .tokenizer is not None
703+ token_texts = self .tokenizer .batch_decode (token_logprobs_idx )
704+ return list (zip (token_logprobs_val , token_logprobs_idx , token_texts ))
636705
637- def detokenize_top_logprobs_tokens (self , top_logprobs , decode_to_text : bool ):
706+ def detokenize_top_logprobs_tokens (
707+ self ,
708+ token_logprobs_val : List [float ],
709+ token_logprobs_idx : List [int ],
710+ decode_to_text : bool ,
711+ ):
638712 # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
639713 # We should batch all top-k tokens in all positions.
640- for i , token_top_logprobs in enumerate (top_logprobs ):
641- if token_top_logprobs :
642- top_logprobs [i ] = self .detokenize_logprob_tokens (
643- token_top_logprobs , decode_to_text
714+ ret = []
715+ for i in range (len (token_logprobs_val )):
716+ if token_logprobs_val [i ]:
717+ ret .append (
718+ self .detokenize_logprob_tokens (
719+ token_logprobs_val [i ], token_logprobs_idx [i ], decode_to_text
720+ )
644721 )
645- return top_logprobs
722+ else :
723+ ret .append (None )
724+ return ret
646725
647726
648727class SignalHandler :
0 commit comments