@@ -8,6 +8,7 @@
 from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
                     Union, cast)

+import numpy
 import torch

 from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
@@ -489,6 +490,19 @@ def __repr__(self) -> str:
                 f"num_blocks={self.n_blocks}, ")


+@dataclass
+class SequenceGroupState:
+    """Mutable state tied to a specific sequence group"""
+
+    # for multi-step decoding
+    num_steps: int = 1
+    current_step: int = 0
+
+    @property
+    def remaining_steps(self) -> int:
+        return self.num_steps - self.current_step
+
+
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.

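Not part of the diff: a minimal sketch of the new dataclass in use, assuming it is importable from `vllm.sequence` once this change lands. The step counts are illustrative.

```python
from vllm.sequence import SequenceGroupState

# A fresh group defaults to single-step decoding.
state = SequenceGroupState()
assert state.num_steps == 1 and state.remaining_steps == 1

# Arm it for a four-step decode window; remaining_steps counts
# down as current_step advances.
state = SequenceGroupState(num_steps=4)
state.current_step = 3
assert state.remaining_steps == 1
```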
@@ -534,6 +548,7 @@ def __init__(
                                       time_in_queue=None)
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
+        self.state = SequenceGroupState()
         self.embeddings = embeddings
         self.pooling_params = pooling_params
         self.prompt_adapter_request = prompt_adapter_request
@@ -588,6 +603,10 @@ def prompt_adapter_num_virtual_tokens(self) -> int:
         return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
             if self.prompt_adapter_request else 0

+    def init_multi_step(self, num_scheduler_steps: int) -> None:
+        self.state.num_steps = num_scheduler_steps
+        self.state.current_step = 0
+
     def get_last_latency(self, now: float) -> Optional[float]:
         """Sets the last token time for Request level timings."""
         # If still in prefill phase, raise Error.
@@ -756,6 +775,7 @@ class SequenceGroupMetadata:
         lora_request: LoRA request.
         computed_block_nums: The block numbers that are already computed,
             used in prefix caching.
+        state: Internal state tied to this sequence group.
         multi_modal_data: Multi modal data.
         encoder_seq_data: Optional sequence data for encoder prompt
             (SequenceGroup.encoder_seq). Should be None
@@ -781,6 +801,7 @@ def __init__(
         token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
+        state: Optional[SequenceGroupState] = None,
         multi_modal_data: Optional["MultiModalDataDict"] = None,
         encoder_seq_data: Optional[SequenceData] = None,
         cross_block_table: Optional[List[int]] = None,
@@ -796,6 +817,7 @@ def __init__(
         self.prompt_adapter_request = prompt_adapter_request
         self.computed_block_nums = computed_block_nums
         self.multi_modal_data = multi_modal_data
+        self.state = SequenceGroupState() if state is None else state
         self.encoder_seq_data = encoder_seq_data
         self.cross_block_table = cross_block_table
         self._token_chunk_size = token_chunk_size
@@ -834,6 +856,10 @@ def token_chunk_size(self) -> int:
         assert self._token_chunk_size is not None
         return self._token_chunk_size

+    def finish_step(self) -> None:
+        assert self.state.current_step < self.state.num_steps
+        self.state.current_step += 1
+

 class SequenceOutput:
     """The model output associated with a sequence.
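Not part of the diff: how `init_multi_step` and `finish_step` are meant to pair up, sketched directly on `SequenceGroupState` (constructing a full `SequenceGroup` or `SequenceGroupMetadata` is elided). The eight-step window is illustrative.

```python
from vllm.sequence import SequenceGroupState

state = SequenceGroupState()

# Scheduler side: init_multi_step(num_scheduler_steps=8) resets the
# group to an 8-step window, equivalent to:
state.num_steps = 8
state.current_step = 0

# Worker side: finish_step() advances the counter once per model step
# and asserts the window is never overrun, equivalent to:
while state.remaining_steps > 0:
    assert state.current_step < state.num_steps
    state.current_step += 1

assert state.remaining_steps == 0
```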
@@ -971,6 +997,7 @@ class SamplerOutput:

     # On-device tensor containing the sampled token ids.
     sampled_token_ids: Optional[torch.Tensor] = None
+    sampled_token_ids_numpy: Optional[numpy.ndarray] = None

     # Spec decode metrics populated by workers.
     spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
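Not part of the diff: one plausible way the new host-side field could be filled. The `.cpu().numpy()` round-trip is a standard torch idiom; nothing here is a claim about what the sampler actually does.

```python
import numpy
import torch

# Stand-in for a sampler result living on the device.
sampled_token_ids = torch.tensor([[101], [102]])

# Keep a host-side numpy copy alongside the tensor so later steps can
# read the token ids without touching the device again.
sampled_token_ids_numpy = sampled_token_ids.cpu().numpy()
assert isinstance(sampled_token_ids_numpy, numpy.ndarray)
```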
@@ -1112,6 +1139,33 @@ class ExecuteModelRequest:
     num_steps: int = 1
     # Finished request ids since last step.
     finished_requests_ids: List[str] = field(default_factory=list)
+    # The last sampled token ids for multi step decoding.
+    last_sampled_token_ids: Optional[torch.Tensor] = None
+
+    @property
+    def is_first_multi_step(self) -> bool:
+        # TODO(will) make this be able to handle batches with variable number
+        # of steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        return first_seq_group.state.current_step == 0
+
+    @property
+    def is_last_step(self) -> bool:
+        # TODO(will) make this be able to handle batches with variable number
+        # of steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        num_steps = first_seq_group.state.num_steps
+        current_step = first_seq_group.state.current_step
+        return num_steps - current_step == 1
+
+    @property
+    def current_step(self) -> int:
+        # TODO(will) make this be able to handle batches with variable number
+        # of steps
+        assert len(self.seq_group_metadata_list) > 0
+        return self.seq_group_metadata_list[0].state.current_step

     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]
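Not part of the diff: what the three properties report over a four-step window, using a bare `SequenceGroupState` as a stand-in for the first group's state (the properties read exactly these fields).

```python
from vllm.sequence import SequenceGroupState

state = SequenceGroupState(num_steps=4)
for step in range(state.num_steps):
    state.current_step = step
    is_first = state.current_step == 0                    # is_first_multi_step
    is_last = state.num_steps - state.current_step == 1   # is_last_step
    print(step, is_first, is_last)
# 0 True False
# 1 False False
# 2 False False
# 3 False True
```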
@@ -1127,4 +1181,6 @@ def clone(
             running_queue_size=self.running_queue_size,
             previous_hidden_states=self.previous_hidden_states,
             num_steps=self.num_steps,
-            finished_requests_ids=self.finished_requests_ids)
+            finished_requests_ids=self.finished_requests_ids,
+            last_sampled_token_ids=self.last_sampled_token_ids.clone()
+            if self.last_sampled_token_ids is not None else None)
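Not part of the diff: why `clone()` copies the tensor instead of aliasing it. A shared tensor would let a later in-place write on the original leak into the cloned request; `torch.Tensor.clone()` prevents that.

```python
import torch

original = torch.tensor([[7]])
copied = original.clone()
original[0, 0] = 8               # later in-place write on the original
assert copied[0, 0].item() == 7  # the clone is unaffected
```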