diff --git a/vllm/config.py b/vllm/config.py index ddd63dfd6ee1..aa0973b1588c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -459,6 +459,8 @@ class SchedulerConfig: max_paddings: Maximum number of paddings to be added to a batch. policy: Policy of sequence scheduling(`fcfs` or `reorder`). reorder_window: Allowed reorder window size(in sec) for `reorder` policy. + swap_tolerance: Maximum acceptable number of swapped sequences to start + a new waiting request. 0 means to process SWAP first. """ def __init__( @@ -469,6 +471,7 @@ def __init__( max_paddings: int, policy: str = 'fcfs', reorder_window: float = 0, + swap_tolerance: int = 0, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -481,6 +484,7 @@ def __init__( self.max_paddings = max_paddings self.policy = policy self.reorder_window = reorder_window + self.swap_tolerance = swap_tolerance self._verify_args() def _verify_args(self) -> None: @@ -504,6 +508,7 @@ def _verify_args(self) -> None: raise ValueError( f"fcfs policy doesn't support reorder_window ({self.reorder_window})." ) + assert self.swap_tolerance >= 0 class DeviceConfig: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 1b3fc4d0e974..a6245955da9b 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -150,7 +150,7 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_copy: Dict[int, List[int]] = {} # Join waiting sequences if possible. - if not self.swapped: + if len(self.swapped) <= self.scheduler_config.swap_tolerance: ignored_seq_groups: List[SequenceGroup] = [] scheduled: List[SequenceGroup] = [] # The total number of sequences on the fly, including the diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7995c0f2ea36..af8cd7184cbf 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -34,6 +34,7 @@ class EngineArgs: max_logprobs: int = 5 # OpenAI default value scheduler_policy: str = 'fcfs' scheduler_reorder_window: float = 0 + scheduler_swap_tolerance: int = 0 disable_log_stats: bool = False revision: Optional[str] = None code_revision: Optional[str] = None @@ -230,6 +231,12 @@ def add_cli_args( type=float, default=EngineArgs.scheduler_reorder_window, help='allowed sequences reorder window(in sec)') + parser.add_argument( + '--scheduler-swap-tolerance', + type=int, + default=EngineArgs.scheduler_swap_tolerance, + help='Maximum acceptable number of swapped sequences to start a ' + 'new waiting request') parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') @@ -331,12 +338,11 @@ def create_engine_configs( self.max_parallel_loading_workers, self.disable_custom_all_reduce, self.ray_workers_use_nsight) - scheduler_config = SchedulerConfig(self.max_num_batched_tokens, - self.max_num_seqs, - model_config.max_model_len, - self.max_paddings, - self.scheduler_policy, - self.scheduler_reorder_window) + scheduler_config = SchedulerConfig( + self.max_num_batched_tokens, self.max_num_seqs, + model_config.max_model_len, self.max_paddings, + self.scheduler_policy, self.scheduler_reorder_window, + self.scheduler_swap_tolerance) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras,