24 changes: 21 additions & 3 deletions benchmarks/benchmark_throughput.py
@@ -74,6 +74,9 @@ def run_vllm(
kv_cache_dtype: str,
device: str,
enable_prefix_caching: bool,
scheduler_policy: str,
scheduler_reorder_window: float,
swap_space: int,
gpu_memory_utilization: float = 0.9,
) -> float:
from vllm import LLM, SamplingParams
@@ -89,8 +92,10 @@ def run_vllm(
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
enable_prefix_caching=enable_prefix_caching)

enable_prefix_caching=enable_prefix_caching,
scheduler_policy=scheduler_policy,
scheduler_reorder_window=scheduler_reorder_window,
swap_space=swap_space)
# Add the requests to the engine.
for prompt, _, output_len in requests:
sampling_params = SamplingParams(
@@ -213,7 +218,9 @@ def main(args: argparse.Namespace):
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype, args.device,
args.enable_prefix_caching, args.gpu_memory_utilization)
args.enable_prefix_caching, args.vllm_scheduler_policy,
args.vllm_scheduler_reorder_window, args.swap_space,
args.gpu_memory_utilization)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -236,6 +243,13 @@ def main(args: argparse.Namespace):
type=str,
choices=["vllm", "hf", "mii"],
default="vllm")
parser.add_argument("--vllm-scheduler-policy",
type=str,
choices=["fcfs", "reorder"],
default="fcfs")
parser.add_argument("--vllm-scheduler-reorder-window",
type=float,
default=0)
parser.add_argument("--dataset",
type=str,
default=None,
@@ -314,6 +328,10 @@ def main(args: argparse.Namespace):
"--enable-prefix-caching",
action='store_true',
help="enable automatic prefix caching for vLLM backend.")
parser.add_argument('--swap-space',
type=int,
default=16,
help='CPU swap space size (GiB) per GPU')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
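Note (usage sketch, not part of the diff): the three new flags combine with the script's existing options; the model name and the length flags below are placeholders taken from the current benchmark script and may need adjusting.

python benchmarks/benchmark_throughput.py \
    --backend vllm --model facebook/opt-125m \
    --input-len 512 --output-len 128 \
    --vllm-scheduler-policy reorder \
    --vllm-scheduler-reorder-window 5 \
    --swap-space 16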
13 changes: 13 additions & 0 deletions vllm/config.py
@@ -457,6 +457,8 @@ class SchedulerConfig:
max_model_len: Maximum length of a sequence (including prompt
and generated text).
max_paddings: Maximum number of paddings to be added to a batch.
policy: Policy for sequence scheduling (`fcfs` or `reorder`).
reorder_window: Allowed reorder window size (in seconds) for the `reorder` policy.
"""

def __init__(
@@ -465,6 +467,8 @@ def __init__(
max_num_seqs: int,
max_model_len: int,
max_paddings: int,
policy: str = 'fcfs',
reorder_window: float = 0,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
@@ -475,6 +479,8 @@ def __init__(
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.max_paddings = max_paddings
self.policy = policy
self.reorder_window = reorder_window
self._verify_args()

def _verify_args(self) -> None:
@@ -491,6 +497,13 @@ def _verify_args(self) -> None:
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
"be greater than or equal to max_num_seqs "
f"({self.max_num_seqs}).")
if self.reorder_window < 0:
raise ValueError(f"reorder_window ({self.reorder_window}) must "
"not be negative.")
if self.reorder_window != 0 and self.policy != 'reorder':
raise ValueError(
f"reorder_window ({self.reorder_window}) is only supported "
"with the 'reorder' policy.")


class DeviceConfig:
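Note (validation sketch, assuming the keyword names shown in the diff above; the other values are arbitrary but satisfy the existing checks):

from vllm.config import SchedulerConfig

# Accepted: the reorder policy with a 2-second reorder window.
SchedulerConfig(max_num_batched_tokens=4096, max_num_seqs=256,
                max_model_len=4096, max_paddings=256,
                policy='reorder', reorder_window=2.0)

# Raises ValueError: a non-zero reorder_window requires policy='reorder'.
SchedulerConfig(max_num_batched_tokens=4096, max_num_seqs=256,
                max_model_len=4096, max_paddings=256,
                policy='fcfs', reorder_window=2.0)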
86 changes: 68 additions & 18 deletions vllm/core/policy.py
@@ -1,45 +1,95 @@
from collections import deque
from typing import Deque
import enum
import bisect

from vllm.sequence import SequenceGroup


class PreemptionMode(enum.Enum):
"""Preemption modes.

1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
and swap them back in when the sequences are resumed.
2. Recomputation: Discard the blocks of the preempted sequences and
recompute them when the sequences are resumed, treating the sequences as
new prompts.
"""
SWAP = enum.auto()
RECOMPUTE = enum.auto()


class Policy:
"""Base class policy"""

def get_priority(
def sort(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
seq_groups: Deque[SequenceGroup],
) -> Deque[SequenceGroup]:
raise NotImplementedError

def sort_by_priority(
def get_preemption_mode(self, seq_group: SequenceGroup) -> PreemptionMode:
raise NotImplementedError


class FCFS(Policy):

def __init__(self, **kwargs) -> None:
super().__init__()

def sort(
self,
now: float,
seq_groups: Deque[SequenceGroup],
) -> Deque[SequenceGroup]:
return deque(
sorted(
seq_groups,
key=lambda seq_group: self.get_priority(now, seq_group),
reverse=True,
))
"""We can just sort `Deque[SequenceGroup]` by `arrival_time`"""
return deque(sorted(seq_groups, key=lambda x: x.metrics.arrival_time))

def get_preemption_mode(self, seq_group: SequenceGroup) -> PreemptionMode:
if seq_group.get_max_num_running_seqs() == 1:
return PreemptionMode.RECOMPUTE
else:
return PreemptionMode.SWAP

class FCFS(Policy):

def get_priority(
class ReorderPolicy(Policy):
"""ReorderPolicy tries to maximize throughput by reordering incoming requests by length.

Args:
reorder_window: Window size in seconds within which the `List[SequenceGroup]` may be reordered. 0 means no reordering.
"""

def __init__(self, reorder_window: float = 0, **kwargs) -> None:
super().__init__()
self.reorder_window = reorder_window

def sort(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
return now - seq_group.metrics.arrival_time
seq_groups: Deque[SequenceGroup],
) -> Deque[SequenceGroup]:
"""Sort head within `reorder_window` of the `seq_groups` by length. It reduces padding computation overhead."""
if len(seq_groups) == 0:
return seq_groups
arrival_time_sorted = sorted(seq_groups,
key=lambda x: x.metrics.arrival_time)
pos = bisect.bisect_left(arrival_time_sorted,
arrival_time_sorted[0].metrics.arrival_time +
self.reorder_window,
key=lambda x: x.metrics.arrival_time)
return deque(
sorted(arrival_time_sorted[:pos],
key=lambda x: x.get_seqs()[0].get_len()) +
arrival_time_sorted[pos:])

def get_preemption_mode(self, seq_group: SequenceGroup) -> PreemptionMode:
"""Always use SWAP, as it is faster than `RECOMPUTE` for heavy models like llama."""
return PreemptionMode.SWAP


class PolicyFactory:

_POLICY_REGISTRY = {
'fcfs': FCFS,
'reorder': ReorderPolicy,
}

@classmethod
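Note (illustrative sketch, not the PR's code): the reordering idea mirrored with plain (arrival_time, prompt_len) tuples standing in for SequenceGroup objects.

import bisect
from collections import deque

def reorder_head(requests, reorder_window):
    # requests: (arrival_time, prompt_len) stand-ins for SequenceGroup.
    by_arrival = sorted(requests, key=lambda r: r[0])
    if not by_arrival:
        return deque(by_arrival)
    # Only requests that arrived within reorder_window of the oldest one
    # may be reordered by length; the rest keep arrival order.
    cutoff = by_arrival[0][0] + reorder_window
    pos = bisect.bisect_left([r[0] for r in by_arrival], cutoff)
    return deque(sorted(by_arrival[:pos], key=lambda r: r[1]) + by_arrival[pos:])

reqs = [(0.0, 512), (0.1, 128), (0.2, 1024), (0.3, 256), (5.0, 64)]
print(reorder_head(reqs, reorder_window=1.0))
# deque([(0.1, 128), (0.3, 256), (0.0, 512), (0.2, 1024), (5.0, 64)])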
42 changes: 14 additions & 28 deletions vllm/core/scheduler.py
@@ -1,32 +1,18 @@
from collections import deque
import enum
import time
from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.block_manager import AllocStatus, BlockSpaceManager
from vllm.core.policy import PolicyFactory
from vllm.lora.request import LoRARequest
from vllm.core.policy import PolicyFactory, PreemptionMode, FCFS
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
import time

logger = init_logger(__name__)


class PreemptionMode(enum.Enum):
"""Preemption modes.

1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
and swap them back in when the sequences are resumed.
2. Recomputation: Discard the blocks of the preempted sequences and
recompute them when the sequences are resumed, treating the sequences as
new prompts.
"""
SWAP = enum.auto()
RECOMPUTE = enum.auto()


class SchedulerOutputs:

def __init__(
@@ -88,7 +74,10 @@ def __init__(
self.scheduler_config.max_num_batched_tokens)

# Instantiate the scheduling policy.
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
self.policy = PolicyFactory.get_policy(
policy_name=self.scheduler_config.policy,
reorder_window=self.scheduler_config.reorder_window,
)
# Create the block space manager.
self.block_manager = BlockSpaceManager(
block_size=self.cache_config.block_size,
@@ -160,9 +149,6 @@ def _schedule(self) -> SchedulerOutputs:
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}

# Fix the current time.
now = time.monotonic()

# Join waiting sequences if possible.
if not self.swapped:
ignored_seq_groups: List[SequenceGroup] = []
@@ -176,10 +162,12 @@ def _schedule(self) -> SchedulerOutputs:
for seq_group in self.running) if self.lora_enabled else None
seq_lens: List[int] = []

# Optimization: We do not sort the waiting queue since the preempted
# Optimization: We do not sort the waiting queue when using the FCFS policy since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
leftover_waiting_sequences = deque()
if not isinstance(self.policy, FCFS):
self.waiting = self.policy.sort(self.waiting)
while self.waiting:
seq_group = self.waiting[0]
waiting_seqs = seq_group.get_seqs(
@@ -269,11 +257,11 @@ def _schedule(self) -> SchedulerOutputs:
# to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt.
self.running = self.policy.sort_by_priority(now, self.running)

# Reserve new token slots for the running sequence groups.
running: Deque[SequenceGroup] = deque()
preempted: List[SequenceGroup] = []
self.running = self.policy.sort(self.running)
while self.running:
seq_group = self.running.popleft()
while not self.block_manager.can_append_slot(seq_group):
@@ -294,8 +282,6 @@ def _schedule(self) -> SchedulerOutputs:
running.append(seq_group)
self.running = running

# Swap in the sequence groups in the SWAPPED state if possible.
self.swapped = self.policy.sort_by_priority(now, self.swapped)
if not preempted:
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running)
@@ -305,6 +291,9 @@ def _schedule(self) -> SchedulerOutputs:

leftover_swapped = deque()

# Swap in the sequence groups in the SWAPPED state if possible.
self.swapped = self.policy.sort(self.swapped)

while self.swapped:
seq_group = self.swapped[0]
lora_int_id = 0
@@ -439,10 +428,7 @@ def _preempt(
# TODO(woosuk): Support recomputation for sequence groups with multiple
# sequences. This may require a more sophisticated CUDA kernel.
if preemption_mode is None:
if seq_group.get_max_num_running_seqs() == 1:
preemption_mode = PreemptionMode.RECOMPUTE
else:
preemption_mode = PreemptionMode.SWAP
preemption_mode = self.policy.get_preemption_mode(seq_group)
if preemption_mode == PreemptionMode.RECOMPUTE:
self._preempt_by_recompute(seq_group)
elif preemption_mode == PreemptionMode.SWAP:
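Note (hypothetical extension, not part of this PR): with sort() and get_preemption_mode() moved onto the policy, the scheduler no longer hard-codes the preemption choice, so a custom policy could be plugged in along these lines.

from vllm.core.policy import FCFS, PolicyFactory, PreemptionMode
from vllm.sequence import SequenceGroup

class AlwaysRecompute(FCFS):
    # Hypothetical policy: keep FCFS ordering but always preempt by recomputation.
    def get_preemption_mode(self, seq_group: SequenceGroup) -> PreemptionMode:
        return PreemptionMode.RECOMPUTE

# Registering directly in the factory's registry is a sketch, not a supported API.
PolicyFactory._POLICY_REGISTRY['always-recompute'] = AlwaysRecompute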
15 changes: 14 additions & 1 deletion vllm/engine/arg_utils.py
@@ -32,6 +32,8 @@ class EngineArgs:
max_num_seqs: int = 256
max_paddings: int = 256
max_logprobs: int = 5 # OpenAI default value
scheduler_policy: str = 'fcfs'
scheduler_reorder_window: float = 0
disable_log_stats: bool = False
revision: Optional[str] = None
code_revision: Optional[str] = None
@@ -219,6 +221,15 @@ def add_cli_args(
default=EngineArgs.max_logprobs,
help=('max number of log probs to return when logprobs is'
' specified in SamplingParams'))
parser.add_argument('--scheduler-policy',
type=str,
default=EngineArgs.scheduler_policy,
choices=['fcfs', 'reorder'],
help='scheduler policy')
parser.add_argument('--scheduler-reorder-window',
type=float,
default=EngineArgs.scheduler_reorder_window,
help='allowed sequence reorder window (in seconds)')
parser.add_argument('--disable-log-stats',
action='store_true',
help='disable logging statistics')
@@ -323,7 +334,9 @@ def create_engine_configs(
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
self.max_paddings)
self.max_paddings,
self.scheduler_policy,
self.scheduler_reorder_window)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
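Note (usage sketch with a placeholder model): because run_vllm in the benchmark forwards these arguments straight to LLM, the same knobs are available from Python as well.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m",
          scheduler_policy="reorder",      # new in this PR
          scheduler_reorder_window=5.0,    # seconds the queue head may be reordered
          swap_space=16)                   # CPU swap space (GiB) per GPU
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))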