
Commit 2ecf7b1

[core] [3/N] multi-step args and sequence.py (#7452)
1 parent: 3f674a4

4 files changed: +100 −5 lines

vllm/config.py

Lines changed: 13 additions & 1 deletion

@@ -847,7 +847,8 @@ def __init__(self,
                  delay_factor: float = 0.0,
                  enable_chunked_prefill: bool = False,
                  embedding_mode: Optional[bool] = False,
-                 preemption_mode: Optional[str] = None) -> None:
+                 preemption_mode: Optional[str] = None,
+                 num_scheduler_steps: int = 1) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
         else:
@@ -876,6 +877,7 @@ def __init__(self,
         self.chunked_prefill_enabled = enable_chunked_prefill
         self.embedding_mode = embedding_mode
         self.preemption_mode = preemption_mode
+        self.num_scheduler_steps = num_scheduler_steps
         self._verify_args()

     def _verify_args(self) -> None:
@@ -901,6 +903,16 @@ def _verify_args(self) -> None:
                 f"({self.num_lookahead_slots}) must be greater than or "
                 "equal to 0.")

+        if self.num_scheduler_steps < 1:
+            raise ValueError(
+                "num_scheduler_steps "
+                f"({self.num_scheduler_steps}) must be greater than or "
+                "equal to 1.")
+
+    @property
+    def is_multi_step(self) -> bool:
+        return self.num_scheduler_steps > 1
+

 class DeviceConfig:
     device: Optional[torch.device]
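At this stage the new knob is purely declarative: SchedulerConfig stores num_scheduler_steps, rejects values below 1 in _verify_args, and derives the is_multi_step flag from it. A minimal sketch of that behaviour, assuming a vLLM build that contains this commit; the three capacity arguments are illustrative placeholder values, not recommended settings:

from vllm.config import SchedulerConfig

# Illustrative capacity values; only num_scheduler_steps matters here.
multi = SchedulerConfig(max_num_batched_tokens=2048,
                        max_num_seqs=256,
                        max_model_len=2048,
                        num_scheduler_steps=8)
assert multi.is_multi_step            # more than one step per scheduler call

single = SchedulerConfig(max_num_batched_tokens=2048,
                         max_num_seqs=256,
                         max_model_len=2048)   # defaults to a single step
assert not single.is_multi_step

try:
    SchedulerConfig(max_num_batched_tokens=2048,
                    max_num_seqs=256,
                    max_model_len=2048,
                    num_scheduler_steps=0)
except ValueError as err:
    print(err)   # num_scheduler_steps (0) must be greater than or equal to 1.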

vllm/core/scheduler.py

Lines changed: 5 additions & 0 deletions

@@ -805,6 +805,9 @@ def _schedule_prefills(
                 curr_loras.add(lora_int_id)
             waiting_queue.popleft()
             self._allocate_and_set_running(seq_group)
+            seq_group.init_multi_step(
+                num_scheduler_steps=self._get_num_lookahead_slots(
+                    is_prefill=True) + 1)
             seq_groups.append(
                 ScheduledSequenceGroup(seq_group=seq_group,
                                        token_chunk_size=num_new_tokens))
@@ -1108,6 +1111,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
                 computed_block_nums=common_computed_block_nums,
                 encoder_seq_data=encoder_seq_data,
                 cross_block_table=cross_block_table,
+                state=seq_group.state,
                 # `multi_modal_data` will only be present for the 1st comm
                 # between engine and worker.
                 # the subsequent comms can still use delta, but
@@ -1184,6 +1188,7 @@ def _append_slots(
             slots.
         """
         num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
+        seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)

         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             cows = self.block_manager.append_slots(seq, num_lookahead_slots)
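Both call sites (prefill admission and decode slot appending) encode the same rule: a group given N lookahead slots is initialised for N + 1 forward steps before control returns to the scheduler, and state=seq_group.state passes that same step-tracking object into the SequenceGroupMetadata sent to the workers. A standalone restatement of the rule; scheduler_steps is a hypothetical helper written only for illustration, not part of the vLLM scheduler API:

def scheduler_steps(num_lookahead_slots: int) -> int:
    # One base forward step plus one extra step per pre-allocated lookahead slot.
    return num_lookahead_slots + 1

assert scheduler_steps(0) == 1   # default single-step decoding
assert scheduler_steps(7) == 8   # e.g. 8 scheduler steps reserve 7 lookahead slots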

vllm/engine/arg_utils.py

Lines changed: 25 additions & 3 deletions

@@ -115,6 +115,7 @@ class EngineArgs:
     lora_dtype: str = 'auto'
     max_cpu_loras: Optional[int] = None
     device: str = 'auto'
+    num_scheduler_steps: int = 1
     ray_workers_use_nsight: bool = False
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
@@ -543,6 +544,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                 "tpu", "xpu"
             ],
             help='Device type for vLLM execution.')
+        parser.add_argument('--num-scheduler-steps',
+                            type=int,
+                            default=1,
+                            help=('Maximum number of forward steps per '
+                                  'scheduler call.'))

         parser.add_argument(
             '--scheduler-delay-factor',
@@ -858,18 +864,34 @@ def create_engine_config(self, ) -> EngineConfig:
             disable_logprobs=self.disable_logprobs_during_spec_decoding,
         )

+        if self.num_scheduler_steps > 1:
+            raise NotImplementedError("Multi-step is not yet supported.")
+            if speculative_config is not None:
+                raise ValueError("Speculative decoding is not supported with "
+                                 "multi-step (--num-scheduler-steps > 1)")
+            if self.enable_chunked_prefill:
+                raise ValueError("Chunked prefill is not supported with "
+                                 "multi-step (--num-scheduler-steps > 1)")
+
+        # make sure num_lookahead_slots is set to the higher value depending on
+        # if we are using speculative decoding or multi-step
+        num_lookahead_slots = max(self.num_lookahead_slots,
+                                  self.num_scheduler_steps - 1)
+        num_lookahead_slots = num_lookahead_slots \
+            if speculative_config is None \
+            else speculative_config.num_lookahead_slots
+
         scheduler_config = SchedulerConfig(
             max_num_batched_tokens=self.max_num_batched_tokens,
             max_num_seqs=self.max_num_seqs,
             max_model_len=model_config.max_model_len,
             use_v2_block_manager=self.use_v2_block_manager,
-            num_lookahead_slots=(self.num_lookahead_slots
-                                 if speculative_config is None else
-                                 speculative_config.num_lookahead_slots),
+            num_lookahead_slots=num_lookahead_slots,
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
             embedding_mode=model_config.embedding_mode,
             preemption_mode=self.preemption_mode,
+            num_scheduler_steps=self.num_scheduler_steps,
         )
         lora_config = LoRAConfig(
             max_lora_rank=self.max_lora_rank,
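The new flag is accepted but deliberately gated: in this commit any value above 1 still raises NotImplementedError, with the speculative-decoding and chunked-prefill incompatibility checks staged behind that gate for later. The lookahead-slot reconciliation is easier to read in isolation. The helper below is hypothetical, written only to restate the create_engine_config logic above; a speculative-decoding config, when present, still dictates the slot count outright:

from typing import Optional

def resolve_lookahead_slots(num_lookahead_slots: int,
                            num_scheduler_steps: int,
                            spec_decode_slots: Optional[int]) -> int:
    # Multi-step needs one lookahead slot per extra step; take whichever
    # of the two user-facing knobs asks for more.
    slots = max(num_lookahead_slots, num_scheduler_steps - 1)
    # A speculative-decoding config, if present, overrides both.
    return slots if spec_decode_slots is None else spec_decode_slots

assert resolve_lookahead_slots(0, 1, None) == 0   # defaults: no lookahead
assert resolve_lookahead_slots(0, 8, None) == 7   # multi-step drives the slots
assert resolve_lookahead_slots(4, 2, None) == 4   # explicit flag still wins
assert resolve_lookahead_slots(0, 8, 5) == 5      # spec decode takes precedence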

vllm/sequence.py

Lines changed: 57 additions & 1 deletion

@@ -8,6 +8,7 @@
 from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
                     Union, cast)

+import numpy
 import torch

 from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
@@ -489,6 +490,19 @@ def __repr__(self) -> str:
                 f"num_blocks={self.n_blocks}, ")


+@dataclass
+class SequenceGroupState:
+    """Mutable state tied to a specific sequence group"""
+
+    # for multi-step decoding
+    num_steps: int = 1
+    current_step: int = 0
+
+    @property
+    def remaining_steps(self) -> int:
+        return self.num_steps - self.current_step
+
+
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.

@@ -534,6 +548,7 @@ def __init__(
             time_in_queue=None)
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
+        self.state = SequenceGroupState()
         self.embeddings = embeddings
         self.pooling_params = pooling_params
         self.prompt_adapter_request = prompt_adapter_request
@@ -588,6 +603,10 @@ def prompt_adapter_num_virtual_tokens(self) -> int:
         return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
             if self.prompt_adapter_request else 0

+    def init_multi_step(self, num_scheduler_steps: int) -> None:
+        self.state.num_steps = num_scheduler_steps
+        self.state.current_step = 0
+
     def get_last_latency(self, now: float) -> Optional[float]:
         """Sets the last token time for Request level timings."""
         # If still in prefill phase, raise Error.
@@ -756,6 +775,7 @@ class SequenceGroupMetadata:
         lora_request: LoRA request.
         computed_block_nums: The block numbers that are already computed,
             used in prefix caching.
+        state: Internal state tied to this sequence group.
         multi_modal_data: Multi modal data.
         encoder_seq_data: Optional sequence data for encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
@@ -781,6 +801,7 @@ def __init__(
         token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
+        state: Optional[SequenceGroupState] = None,
         multi_modal_data: Optional["MultiModalDataDict"] = None,
         encoder_seq_data: Optional[SequenceData] = None,
         cross_block_table: Optional[List[int]] = None,
@@ -796,6 +817,7 @@ def __init__(
         self.prompt_adapter_request = prompt_adapter_request
         self.computed_block_nums = computed_block_nums
         self.multi_modal_data = multi_modal_data
+        self.state = SequenceGroupState() if state is None else state
         self.encoder_seq_data = encoder_seq_data
         self.cross_block_table = cross_block_table
         self._token_chunk_size = token_chunk_size
@@ -834,6 +856,10 @@ def token_chunk_size(self) -> int:
         assert self._token_chunk_size is not None
         return self._token_chunk_size

+    def finish_step(self) -> None:
+        assert self.state.current_step < self.state.num_steps
+        self.state.current_step += 1
+

 class SequenceOutput:
     """The model output associated with a sequence.
@@ -971,6 +997,7 @@ class SamplerOutput:

     # On-device tensor containing the sampled token ids.
     sampled_token_ids: Optional[torch.Tensor] = None
+    sampled_token_ids_numpy: Optional[numpy.ndarray] = None

     # Spec decode metrics populated by workers.
     spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
@@ -1112,6 +1139,33 @@ class ExecuteModelRequest:
     num_steps: int = 1
     # Finished request ids since last step.
     finished_requests_ids: List[str] = field(default_factory=list)
+    # The last sampled token ids for multi step decoding.
+    last_sampled_token_ids: Optional[torch.Tensor] = None
+
+    @property
+    def is_first_multi_step(self) -> bool:
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        return first_seq_group.state.current_step == 0
+
+    @property
+    def is_last_step(self) -> bool:
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        num_steps = first_seq_group.state.num_steps
+        current_step = first_seq_group.state.current_step
+        return num_steps - current_step == 1
+
+    @property
+    def current_step(self) -> int:
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        return self.seq_group_metadata_list[0].state.current_step

     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -1127,4 +1181,6 @@ def clone(
             running_queue_size=self.running_queue_size,
             previous_hidden_states=self.previous_hidden_states,
             num_steps=self.num_steps,
-            finished_requests_ids=self.finished_requests_ids)
+            finished_requests_ids=self.finished_requests_ids,
+            last_sampled_token_ids=self.last_sampled_token_ids.clone()
+            if self.last_sampled_token_ids is not None else None)
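ExecuteModelRequest now exposes the same counters from the worker's point of view, reading them off the first group's metadata (per the TODOs, batches with differing step counts are not handled yet), and clone() deep-copies the carried last_sampled_token_ids tensor so a cloned request cannot alias the original's sampled ids. A minimal sketch of the clone behaviour, assuming seq_group_metadata_list is the only constructor field without a default; the tensor value is illustrative:

import torch
from vllm.sequence import ExecuteModelRequest

req = ExecuteModelRequest(seq_group_metadata_list=[],
                          last_sampled_token_ids=torch.tensor([[42]]))
copy = req.clone(seq_group_metadata_list=[])

# Same values, but a distinct tensor object thanks to .clone() in clone().
assert torch.equal(copy.last_sampled_token_ids, req.last_sampled_token_ids)
assert copy.last_sampled_token_ids is not req.last_sampled_token_ids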
