@@ -81,6 +81,29 @@ def process_prompt_logprob(self, seq_group: SequenceGroup,
8181
8282 def _process_sequence_group_outputs (self , seq_group : SequenceGroup ,
8383 outputs : SequenceGroupOutput ) -> None :
84+ sampling_params = seq_group .sampling_params
85+ if sampling_params .n == 1 and not sampling_params .use_beam_search :
86+ # only have one output sample
87+ sample = outputs .samples [0 ]
88+ # only have one sequence
89+ seq = seq_group .seqs [0 ]
90+ seq .append_token_id (sample .output_token , sample .logprobs )
91+ if sampling_params .detokenize and self .detokenizer :
92+ new_char_count = self .detokenizer .decode_sequence_inplace (
93+ seq , sampling_params )
94+ else :
95+ new_char_count = 0
96+ self .stop_checker .maybe_stop_sequence (
97+ seq ,
98+ new_char_count ,
99+ sampling_params ,
100+ lora_req = seq_group .lora_request ,
101+ )
102+ if seq .is_finished ():
103+ for scheduler in self .scheduler :
104+ scheduler .free_seq (seq )
105+ return
106+
84107 # Process samples
85108 samples = outputs .samples
86109 parent_seqs = seq_group .get_seqs (status = SequenceStatus .RUNNING )
@@ -127,20 +150,20 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
127150 child_seqs .append ((parent , parent ))
128151
129152 for seq , _ in child_seqs :
130- if seq_group . sampling_params .detokenize and self .detokenizer :
153+ if sampling_params .detokenize and self .detokenizer :
131154 new_char_count = self .detokenizer .decode_sequence_inplace (
132- seq , seq_group . sampling_params )
155+ seq , sampling_params )
133156 else :
134157 new_char_count = 0
135158 self .stop_checker .maybe_stop_sequence (
136159 seq ,
137160 new_char_count ,
138- seq_group . sampling_params ,
161+ sampling_params ,
139162 lora_req = seq_group .lora_request ,
140163 )
141164
142165 # Non-beam search case
143- if not seq_group . sampling_params .use_beam_search :
166+ if not sampling_params .use_beam_search :
144167 # For newly created child sequences, add them to the sequence group
145168 # and fork them in block manager if they are not finished.
146169 for seq , parent in child_seqs :
@@ -164,8 +187,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
164187 # Select the child sequences to keep in the sequence group.
165188 selected_child_seqs : List [Tuple [Sequence , Optional [Sequence ]]] = []
166189 unselected_child_seqs : List [Tuple [Sequence , Optional [Sequence ]]] = []
167- beam_width = seq_group . sampling_params .best_of
168- length_penalty = seq_group . sampling_params .length_penalty
190+ beam_width = sampling_params .best_of
191+ length_penalty = sampling_params .length_penalty
169192
170193 # Select the newly finished sequences with the highest scores
171194 # to replace existing finished sequences.
@@ -219,8 +242,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
219242 best_running_seq = running_child_seqs [0 ][0 ]
220243 current_worst_seq = all_finished_seqs [beam_width - 1 ][0 ]
221244 stop_beam_search = self ._check_beam_search_early_stopping (
222- seq_group . sampling_params .early_stopping ,
223- seq_group . sampling_params , best_running_seq , current_worst_seq )
245+ sampling_params .early_stopping , sampling_params ,
246+ best_running_seq , current_worst_seq )
224247
225248 if stop_beam_search :
226249 # Stop the beam search and remove all the running sequences from
0 commit comments