@@ -962,6 +962,45 @@ def _process_sequence_group_outputs(
 
         return
 
+    def _update_num_computed_tokens_for_multi_step_prefill(
+            self, seq_group: SequenceGroup,
+            seq_group_meta: SequenceGroupMetadata,
+            is_first_step_output: Optional[bool]):
+        """
+        This function updates num_computed_tokens for prompt sequences
+        when Multi-Step is enabled.
+
+        seq_group: SequenceGroup to update the num_computed_tokens for.
+        seq_group_meta: Metadata of the given SequenceGroup.
+        is_first_step_output: Optional[bool] -
+            When available, is_first_step_output indicates if the appended
+            output token is the output of the first step in multi-step.
+            A value of None indicates that outputs from all steps in
+            multi-step are submitted in a single burst.
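+
+        Example: with chunked-prefill, a prompt chunk scheduled into a
+        multi-step run is fully computed during the first step and is
+        treated as a decode on the remaining steps, so only the first
+        step (or a single burst, where is_first_step_output is None)
+        triggers the token_chunk_size update below.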
+        """
+
+        assert self.scheduler_config.is_multi_step
+
+        if not seq_group_meta.is_prompt:
+            # num_computed_token updates for multi-step decodes happen after
+            # the tokens are appended to the sequence.
+            return
+
+        do_update: bool = False
+        if self.scheduler_config.chunked_prefill_enabled:
+            # In multi-step + chunked-prefill case, the prompt sequences
+            # that are scheduled are fully processed in the first step.
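+            # A value of None means the outputs of every step arrived in
+            # one burst; the prompt chunk was still fully computed in the
+            # first of those steps, so the update applies in that case too.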
+            do_update = is_first_step_output is None or is_first_step_output
+        else:
+            # Normal multi-step decoding case. In this case prompt-sequences
+            # are actually single-stepped. Always update in this case.
+            assert seq_group.state.num_steps == 1
+            do_update = True
+
+        if do_update:
+            seq_group.update_num_computed_tokens(
+                seq_group_meta.token_chunk_size)
+
     def _process_model_outputs(self,
                                ctx: SchedulerContext,
                                request_id: Optional[str] = None) -> None:
@@ -972,64 +1011,6 @@ def _process_model_outputs(self,
         request_id: If provided, then only this request is going to be processed
         """
 
-        def update_prefill_num_computed_tokens(
-                seq_group: SequenceGroup,
-                seq_group_meta: SequenceGroupMetadata, num_outputs: int,
-                is_first_step_output: Optional[bool]) -> None:
-            """
-            When multi-step and chunked-prefill are enabled together, the
-            prefill sequence scheduled for multi-step execution turn into
-            decodes in the first step itself. This function accounts
-            for that conversion.
-
-            seq_group: SequenceGroup - A prefill seq_group
-            seq_group_meta: SequenceGroupMetadata - Metadata of the given
-                prefill seq_group
-            num_outputs: int - number of output tokens being processed for the
-                given seq_group
-            is_first_step_output: Optional[bool] -
-                If multi-step is enabled and num_outputs is 1, this value
-                indicates if this outputs belongs to the first step in the
-                multi-step.
-                If multi-step is enabled and num_outputs > 1, this value
-                must be None, as num_outputs > 1 indicates that outputs from
-                all the steps in multi-step are submitted in a single burst.
-                When multi-step is disabled, this value is always True.
-            """
-
-            assert seq_group_meta.is_prompt
-
-            token_chunk_size = seq_group_meta.token_chunk_size
-
-            if num_outputs == 1:
-                assert is_first_step_output is not None
-
-                if seq_group_meta.state.num_steps == 1:
-                    assert is_first_step_output is True
-                    seq_group.update_num_computed_tokens(token_chunk_size)
-                    return
-
-                # multi-step prefill is only supported when multi-step is
-                # enabled with chunked prefill
-                assert self.scheduler_config.is_multi_step and \
-                    self.scheduler_config.chunked_prefill_enabled
-                if is_first_step_output is True:
-                    # This sequence is a prompt during the first step only.
-                    seq_group.update_num_computed_tokens(token_chunk_size)
-                return
-
-            assert is_first_step_output is None
-
-            # multi-step prefill is only supported when multi-step is
-            # enabled with chunked prefill. Outputs from all the steps are
-            # submitted in a single burst.
-            assert self.scheduler_config.is_multi_step and \
-                self.scheduler_config.chunked_prefill_enabled
-            assert num_outputs == seq_group_meta.state.num_steps, \
-                f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}"  # noqa
-            # This sequence is a prompt during the first step only.
-            seq_group.update_num_computed_tokens(token_chunk_size)
-
         now = time.time()
 
         if len(ctx.output_queue) == 0:
@@ -1090,7 +1071,7 @@ def update_prefill_num_computed_tokens(
             seq_group_meta = seq_group_metadata_list[i]
             scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
 
-            seq_group = scheduled_seq_group.seq_group
+            seq_group: SequenceGroup = scheduled_seq_group.seq_group
 
             if seq_group.is_finished():
                 finished_before.append(i)
@@ -1101,14 +1082,14 @@ def update_prefill_num_computed_tokens(
             else:
                 output = [outputs_by_sequence_group[0][i]]
 
-            if not is_async and seq_group_meta.is_prompt:
-                # Updates for all decodes happen when we actually append the
-                # token ids to the seq in process_outputs.
-                update_prefill_num_computed_tokens(seq_group, seq_group_meta,
-                                                   len(output),
-                                                   is_first_step_output)
-            elif not is_async:
-                seq_group.update_num_computed_tokens(1)
+            if not is_async:
+                if self.scheduler_config.is_multi_step:
+                    # Updates happen only if the sequence is a prefill.
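+                    # Decode updates are applied after the sampled token is
+                    # appended to the sequence.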
+                    self._update_num_computed_tokens_for_multi_step_prefill(
+                        seq_group, seq_group_meta, is_first_step_output)
+                else:
+                    seq_group.update_num_computed_tokens(
+                        seq_group_meta.token_chunk_size)
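+                    # For decodes token_chunk_size is 1, so this also covers
+                    # the decode case previously updated by a constant 1.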
 
             if outputs:
                 for o in outputs:
@@ -1132,16 +1113,8 @@ def update_prefill_num_computed_tokens(
             else:
                 self.output_processor.process_prompt_logprob(seq_group, output)
                 if seq_group_meta.do_sample:
-                    output_token_num = self.output_processor.process_outputs(
+                    self.output_processor.process_outputs(
                         seq_group, output, is_async)
-                    if self.speculative_config:
-                        # We -1 here because we always
-                        # (w/o speculative decoding) add the number of
-                        # computed tokens by one in the decoding phase.
-                        # Therefore, we remove that one token that
-                        # is already added.
-                        seq_group.update_num_computed_tokens(output_token_num -
-                                                             1)
 
             if seq_group.is_finished():
                 finished_now.append(i)
@@ -1250,20 +1223,15 @@ def _advance_to_next_step(
             if seq_group.is_finished():
                 continue
 
-            if seq_group_metadata.is_prompt:
-                if self.scheduler_config.is_multi_step and \
-                        self.scheduler_config.chunked_prefill_enabled:
-                    # Prompts are scheduled in multi-step only when
-                    # chunking is enabled. These prompts turn into
-                    # decodes after the very first step. Therefore,
-                    # we skip the update to the num_computed_tokens
-                    # here.
-                    seq_group.update_num_computed_tokens(1)
-                else:
-                    seq_group.update_num_computed_tokens(
-                        seq_group_metadata.token_chunk_size)
+            if self.scheduler_config.is_multi_step:
+                # Updates happen only if the sequence is a prefill.
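+                # Passing (num_steps == 1) as is_first_step_output: with a
+                # single step, the appended token is necessarily the output
+                # of the first step.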
+                self._update_num_computed_tokens_for_multi_step_prefill(
+                    seq_group, seq_group_metadata,
+                    seq_group.state.num_steps == 1)
             else:
-                seq_group.update_num_computed_tokens(1)
+                seq_group.update_num_computed_tokens(
+                    seq_group_metadata.token_chunk_size)
+
             if seq_group_metadata.do_sample:
                 assert len(sequence_group_outputs.samples) == 1, (
                     "Async output processor expects a single sample"
@@ -1273,7 +1241,15 @@ def _advance_to_next_step(
 
             assert len(seq_group.seqs) == 1
             seq = seq_group.seqs[0]
-            seq.append_token_id(sample.output_token, sample.logprobs)
+
+            if self.scheduler_config.is_multi_step:
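+                # Before the append, a sequence with no uncomputed tokens
+                # has just had its full prompt chunk counted, so this token
+                # is the prefill output; only decode appends need to advance
+                # num_computed_tokens by one here.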
+                is_prefill_append = seq.data.get_num_uncomputed_tokens() == 0
+                seq.append_token_id(sample.output_token, sample.logprobs)
+                if not is_prefill_append:
+                    seq_group.update_num_computed_tokens(1)
+            else:
+                seq.append_token_id(sample.output_token, sample.logprobs)
 
     def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.