From 51a2557ded46c30358ee20e93f6b7733db0a80a8 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 27 Mar 2023 20:55:36 +0000 Subject: [PATCH 01/19] Add watermark to avoid thrashing --- cacheflow/master/block_manager.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/cacheflow/master/block_manager.py b/cacheflow/master/block_manager.py index 571ee247eedd..203ab5438643 100644 --- a/cacheflow/master/block_manager.py +++ b/cacheflow/master/block_manager.py @@ -60,11 +60,15 @@ def __init__( block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, + watermark: float = 0.01, ) -> None: self.block_size = block_size self.num_total_gpu_blocks = num_gpu_blocks self.num_total_cpu_blocks = num_cpu_blocks + self.watermark = watermark + assert watermark >= 0.0 + self.watermark_blocks = int(watermark * num_gpu_blocks) self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks) self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks) @@ -72,11 +76,13 @@ def __init__( self.block_tables: Dict[int, BlockTable] = {} def can_allocate(self, seq_group: SequenceGroup) -> bool: - # NOTE: Here we assume that all sequences in the group have the same prompt. + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. This is not true for preempted sequences. seq = seq_group.seqs[0] num_required_blocks = len(seq.logical_token_blocks) num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() - return num_required_blocks <= num_free_gpu_blocks + # Use watermark to avoid thrashing. + return num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks def allocate(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same prompt. 
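The admission rule above is the core of this patch: a new sequence group is scheduled only if, after allocating all of its prompt blocks, at least watermark * num_gpu_blocks free GPU blocks remain. What follows is a minimal, self-contained sketch of that rule; the FreeBlockPool class and its method names are illustrative stand-ins, not part of cacheflow.

# Sketch of the watermark admission check introduced in this patch.
# FreeBlockPool and its API are illustrative, not the cacheflow classes.

class FreeBlockPool:

    def __init__(self, num_gpu_blocks: int, watermark: float = 0.01) -> None:
        assert watermark >= 0.0
        self.num_free_blocks = num_gpu_blocks
        # Blocks held back as a reserve; admissions may not dip into it.
        self.watermark_blocks = int(watermark * num_gpu_blocks)

    def can_admit(self, num_required_blocks: int) -> bool:
        # Mirrors can_allocate() in the patch: admit only if the reserve
        # survives the allocation. Without it, a prompt could take the last
        # free blocks and force running sequences to be preempted at once.
        return (self.num_free_blocks - num_required_blocks
                >= self.watermark_blocks)

pool = FreeBlockPool(num_gpu_blocks=1000)  # reserve = 10 blocks
print(pool.can_admit(990))   # True: exactly the reserve remains
print(pool.can_admit(991))   # False: would cut into the reserve

With 1000 GPU blocks and the default watermark of 0.01, an admission may not leave fewer than 10 blocks free. The reserve gives already-running sequences headroom to append new blocks without immediately forcing a preemption, which is the thrashing the commit message refers to.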
From ab21ab8efa70fcb66276850c6f5e69b7af2bb770 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 27 Mar 2023 20:55:55 +0000 Subject: [PATCH 02/19] Add Policy class --- cacheflow/master/policy.py | 45 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 cacheflow/master/policy.py diff --git a/cacheflow/master/policy.py b/cacheflow/master/policy.py new file mode 100644 index 000000000000..0bd552e24378 --- /dev/null +++ b/cacheflow/master/policy.py @@ -0,0 +1,45 @@ +from typing import List + +from cacheflow.sequence import SequenceGroup + + +class Policy: + + def get_priority( + self, + now: float, + seq_group: SequenceGroup, + ) -> float: + raise NotImplementedError + + def sort_by_priority( + self, + now: float, + seq_groups: List[SequenceGroup], + ) -> List[SequenceGroup]: + return sorted( + seq_groups, + key=lambda seq_group: self.get_priority(now, seq_group), + reverse=True, + ) + + +class PolicyFactory: + + def __init__(self) -> None: + self.policies = { + 'fcfs': FCFS, + } + + def get_policy(self, policy_name: str, **kwargs) -> Policy: + return self.policies[policy_name](**kwargs) + + +class FCFS(Policy): + + def get_priority( + self, + now: float, + seq_group: SequenceGroup, + ) -> float: + return now - seq_group.arrival_time From 15e1e412f772f845e18bef0b85c324444d56ee0b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 27 Mar 2023 21:07:17 +0000 Subject: [PATCH 03/19] Refactor Policy class --- cacheflow/master/policy.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/cacheflow/master/policy.py b/cacheflow/master/policy.py index 0bd552e24378..7d8afbff10c7 100644 --- a/cacheflow/master/policy.py +++ b/cacheflow/master/policy.py @@ -24,17 +24,6 @@ def sort_by_priority( ) -class PolicyFactory: - - def __init__(self) -> None: - self.policies = { - 'fcfs': FCFS, - } - - def get_policy(self, policy_name: str, **kwargs) -> Policy: - return self.policies[policy_name](**kwargs) - - class FCFS(Policy): def get_priority( @@ -43,3 +32,14 @@ def get_priority( seq_group: SequenceGroup, ) -> float: return now - seq_group.arrival_time + + +class PolicyFactory: + + _POLICY_REGISTRY = { + 'fcfs': FCFS, + } + + @classmethod + def get_policy(cls, policy_name: str, **kwargs) -> Policy: + return cls._POLICY_REGISTRY[policy_name](**kwargs) From 0e2b47a99e1340377983b675d4914bdb1327091e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 27 Mar 2023 21:10:22 +0000 Subject: [PATCH 04/19] Use recomputation as preemption mechanism --- cacheflow/master/frontend.py | 4 +- cacheflow/master/scheduler.py | 244 +++++++++++++++++----------------- cacheflow/sequence.py | 6 +- server.py | 7 +- 4 files changed, 136 insertions(+), 125 deletions(-) diff --git a/cacheflow/master/frontend.py b/cacheflow/master/frontend.py index cfa17684fd56..ae9461f06e50 100644 --- a/cacheflow/master/frontend.py +++ b/cacheflow/master/frontend.py @@ -1,3 +1,4 @@ +import time from typing import List, Optional, Set, Tuple from transformers import AutoTokenizer @@ -54,6 +55,7 @@ def _add_query( token_ids: List[int], sampling_params: SamplingParams, ) -> None: + arrival_time = time.time() seqs: List[Sequence] = [] for _ in range(sampling_params.n): seq_id = next(self.seq_counter) @@ -61,7 +63,7 @@ def _add_query( seqs.append(seq) group_id = next(self.seq_group_counter) - seq_group = SequenceGroup(group_id, seqs) + seq_group = SequenceGroup(group_id, seqs, arrival_time) self.inputs.append((seq_group, sampling_params)) def get_inputs(self) -> 
List[Tuple[SequenceGroup, SamplingParams]]: diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 3931d92684f3..8aaf1e34728d 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -1,7 +1,9 @@ +import time from typing import Dict, List from cacheflow.master.block_manager import BlockSpaceManager from cacheflow.master.frontend import Frontend +from cacheflow.master.policy import PolicyFactory from cacheflow.sampling_params import SamplingParams from cacheflow.sequence import Sequence from cacheflow.sequence import SequenceGroup @@ -28,6 +30,8 @@ def __init__( self.num_cpu_blocks = num_cpu_blocks self.max_num_batched_tokens = max_num_batched_tokens + # Instantiate the scheduling policy. + self.policy = PolicyFactory.get_policy('fcfs') # Create the block space manager. self.block_manager = BlockSpaceManager( block_size=block_size, @@ -35,153 +39,95 @@ def __init__( num_cpu_blocks=num_cpu_blocks, ) - # Running sequence groups (FIFO). + # Sequence groups in the WAITING state. + self.waiting: List[SequenceGroup] = [] + # Sequence groups in the RUNNING state. self.running: List[SequenceGroup] = [] # Mapping: group_id -> num_steps. self.num_steps: Dict[int, int] = {} # Mapping: group_id -> sampling params. self.sampling_params: Dict[int, SamplingParams] = {} - - # Swapped sequence groups (LIFO). + # Sequence groups in the SWAPPED state. + # NOTE(woosuk): This is not used for now. self.swapped: List[SequenceGroup] = [] - # Pending sequence groups (FIFO). - self.pending: List[SequenceGroup] = [] - def _fetch_inputs(self) -> None: + def _fetch_requests(self) -> int: inputs = self.frontend.get_inputs() for seq_group, sampling_params in inputs: - self.pending.append(seq_group) + self.waiting.append(seq_group) self.sampling_params[seq_group.group_id] = sampling_params - - def _free_seq(self, seq: Sequence) -> None: - seq.status = SequenceStatus.FINISHED - self.block_manager.free(seq) - - def _allocate(self, seq_group: SequenceGroup) -> None: - self.block_manager.allocate(seq_group) - for seq in seq_group.seqs: - seq.status = SequenceStatus.RUNNING - self.running.append(seq_group) - # FIXME(woosuk): Support interactive generation. - self.num_steps[seq_group.group_id] = 0 - - def _append( - self, - seq_group: SequenceGroup, - blocks_to_copy: Dict[int, List[int]], - ) -> None: - for seq in seq_group.seqs: - if seq.status == SequenceStatus.FINISHED: - continue - ret = self.block_manager.append(seq) - if ret is not None: - src_block, dst_block = ret - if src_block in blocks_to_copy: - blocks_to_copy[src_block].append(dst_block) - else: - blocks_to_copy[src_block] = [dst_block] - - def _swap_in( - self, - seq_group: SequenceGroup, - blocks_to_swap_in: Dict[int, int], - ) -> None: - mapping = self.block_manager.swap_in(seq_group) - blocks_to_swap_in.update(mapping) - for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - seq.status = SequenceStatus.RUNNING - self.running.append(seq_group) - - def _swap_out( - self, - seq_group: SequenceGroup, - blocks_to_swap_out: Dict[int, int], - ) -> None: - assert self.block_manager.can_swap_out(seq_group) - mapping = self.block_manager.swap_out(seq_group) - blocks_to_swap_out.update(mapping) - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - seq.status = SequenceStatus.SWAPPED - self.swapped.append(seq_group) + return len(inputs) def step(self) -> None: # Blocks that need to be swaped or copied before model execution. 
- blocks_to_swap_in: Dict[int, int] = {} - blocks_to_swap_out: Dict[int, int] = {} blocks_to_copy: Dict[int, List[int]] = {} - # 1. Reserve new slots for the running sequences. - # NOTE: Here we implicitly assume FCFS scheduling. - # That is, the most recently added sequence group is the first - # to be swapped out. - victim_idx = len(self.running) - 1 - for i, seq_group in enumerate(self.running): - if i > victim_idx: - # The i-th sequence group has already been swapped out. - break - # OOM. Swap out the victim sequence groups. + # Fetch new requests. + num_requests = self._fetch_requests() + + # Fix the current time. + now = time.time() + + # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state + # in order to minimize the preemption overheads. + # Preemption happens only when there is no available slot to keep all + # the sequence groups in the RUNNING state. + # In this case, the policy is responsible for deciding which sequence + # groups to preempt. + self.running = self.policy.sort_by_priority(now, self.running) + + # Reserve new token slots for the running sequence groups. + running: List[SequenceGroup] = [] + while self.running: + seq_group = self.running.pop(0) while not self.block_manager.can_append(seq_group): - victim_seq_group = self.running[victim_idx] - self._swap_out(victim_seq_group, blocks_to_swap_out) - victim_idx -= 1 - if i > victim_idx: - # No other sequence groups can be swapped out. + if self.running: + # Preempt the lowest-priority sequence groups. + victim_seq_group = self.running.pop(-1) + self._preempt(victim_seq_group) + else: + # No other sequence groups can be preempted. + # Preempt the current sequence group. + self._preempt(seq_group) break else: + # Append new slots to the sequence group. self._append(seq_group, blocks_to_copy) - self.running = self.running[:victim_idx + 1] - - # 2. Swap in the swapped sequences if possible. - # NOTE: Here we implicitly assume FCFS scheduling. - # The swapped sequences are in LIFO order. - for i, seq_group in enumerate(reversed(self.swapped)): - if self.block_manager.can_swap_in(seq_group): - self._swap_in(seq_group, blocks_to_swap_in) - self._append(seq_group, blocks_to_copy) - else: - # OOM. Stop swapping. - self.swapped = self.swapped[:len(self.swapped) - i] - break - else: - # All swapped sequences are swapped in. - self.swapped.clear() - - # Ensure that swap-in and swap-out never happen at the same timestep. - if blocks_to_swap_in: - assert not blocks_to_swap_out + running.append(seq_group) + self.running = running num_batched_tokens = sum( seq_group.num_seqs(status=SequenceStatus.RUNNING) for seq_group in self.running ) - # 3. Join new sequences if possible. - # NOTE: Here we implicitly assume FCFS scheduling. - # TODO(woosuk): Add a batching policy to control the batch size. - self._fetch_inputs() - if not self.swapped: - for i, seq_group in enumerate(self.pending): - num_prompt_tokens = seq_group.seqs[0].get_len() - if self.block_manager.can_allocate(seq_group): - if (num_batched_tokens + num_prompt_tokens - <= self.max_num_batched_tokens): - self._allocate(seq_group) - num_batched_tokens += num_prompt_tokens - continue - - self.pending = self.pending[i:] + # Join new sequences if possible. + self.waiting = self.policy.sort_by_priority(now, self.waiting) + # FIXME(woosuk): This does not work if sequence groups have more than + # one sequence. + while self.waiting: + seq_group = self.waiting[0] + # If the sequence group cannot be allocated, stop joining. 
+ if not self.block_manager.can_allocate(seq_group): + break + + # If the number of batched tokens exceeds the limit, stop joining. + num_prompt_tokens = seq_group.seqs[0].get_len() + if (num_batched_tokens + num_prompt_tokens + > self.max_num_batched_tokens): break - else: - self.pending.clear() - # 4. Create input data structures. + seq_group = self.waiting.pop(0) + self._allocate(seq_group) + self.running.append(seq_group) + num_batched_tokens += num_prompt_tokens + + # Create input data structures. input_seq_groups: List[SequenceGroupInputs] = [] for seq_group in self.running: group_id = seq_group.group_id num_steps = self.num_steps[group_id] - # NOTE(woosuk): We assume that the number of steps is 0 - # for the prompt sequences. is_prompt = num_steps == 0 input_tokens: Dict[int, List[int]] = {} @@ -210,13 +156,13 @@ def step(self) -> None: ) input_seq_groups.append(input_seq_group) - # 5. Execute the first stage of the pipeline. - if (input_seq_groups or blocks_to_swap_in or blocks_to_swap_out): + # Execute the first stage of the pipeline. + if input_seq_groups: self.controllers[0].execute_stage( input_seq_groups, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy, + blocks_to_copy=blocks_to_copy, + blocks_to_swap_in={}, + blocks_to_swap_out={}, ) def post_step( @@ -273,8 +219,66 @@ def post_step( running.append(seq_group) self.running = running + def _allocate(self, seq_group: SequenceGroup) -> None: + self.block_manager.allocate(seq_group) + for seq in seq_group.seqs: + seq.status = SequenceStatus.RUNNING + # FIXME(woosuk): Support interactive generation. + if seq_group.group_id not in self.num_steps: + self.num_steps[seq_group.group_id] = 0 + + def _append( + self, + seq_group: SequenceGroup, + blocks_to_copy: Dict[int, List[int]], + ) -> None: + for seq in seq_group.seqs(status=SequenceStatus.RUNNING): + ret = self.block_manager.append(seq) + if ret is not None: + src_block, dst_block = ret + if src_block in blocks_to_copy: + blocks_to_copy[src_block].append(dst_block) + else: + blocks_to_copy[src_block] = [dst_block] + + def _preempt(self, seq_group: SequenceGroup) -> None: + # NOTE(woosuk): There are two preemption mechanisms. + # 1. Swapping: Swap out the blocks of the preempted sequences to CPU + # memory and swap them back in when the sequences are resumed. + # 2. Recomputation: Discard the blocks of the preempted sequences and + # recompute them when the sequences are resumed. + # We originally used swapping, but it turned out that recomputation + # is more efficient. We keep the swapping code for future reference. 
+ self.status = SequenceStatus.WAITING + self.block_manager.free(seq_group) + + def _free_seq(self, seq: Sequence) -> None: + seq.status = SequenceStatus.FINISHED + self.block_manager.free(seq) + def _return(self, seq_group: SequenceGroup) -> None: group_id = seq_group.group_id del self.num_steps[group_id] del self.sampling_params[group_id] self.frontend.print_response(seq_group) + + def _swap_in( + self, + seq_group: SequenceGroup, + blocks_to_swap_in: Dict[int, int], + ) -> None: + mapping = self.block_manager.swap_in(seq_group) + blocks_to_swap_in.update(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + seq.status = SequenceStatus.RUNNING + + def _swap_out( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + ) -> None: + assert self.block_manager.can_swap_out(seq_group) + mapping = self.block_manager.swap_out(seq_group) + blocks_to_swap_out.update(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq.status = SequenceStatus.SWAPPED diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index 8cdd977237f1..607ea328c30a 100644 --- a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -7,7 +7,7 @@ class SequenceStatus(enum.Enum): - PENDING = enum.auto() + WAITING = enum.auto() RUNNING = enum.auto() SWAPPED = enum.auto() FINISHED = enum.auto() @@ -28,7 +28,7 @@ def __init__( # Initialize the logical token blocks with the given token ids. self.add(token_ids) - self.status = SequenceStatus.PENDING + self.status = SequenceStatus.WAITING self.output_logprobs: List[Dict[int, float]] = [] self.cumulative_logprobs = 0.0 @@ -88,9 +88,11 @@ def __init__( self, group_id: int, seqs: List[Sequence], + arrival_time: float, ) -> None: self.group_id = group_id self.seqs = seqs + self.arrival_time = arrival_time def get_seqs( self, diff --git a/server.py b/server.py index 5838f439f53e..c9da9626e44d 100644 --- a/server.py +++ b/server.py @@ -99,6 +99,9 @@ def main(args: argparse.Namespace): max_num_batched_tokens=args.max_batch_size) num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks( swap_space=args.swap_space) + if num_cpu_blocks > 0: + raise ValueError( + 'CPU blocks are not used. Please set --swap-space to 0.') print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}') # Create a controller for each pipeline stage. @@ -152,7 +155,7 @@ def main(args: argparse.Namespace): text, sampling_params = test_inputs.pop(0) frontend.query(text, **sampling_params) scheduler.step() - if not (scheduler.pending or scheduler.running or test_inputs): + if not (scheduler.waiting or scheduler.running or test_inputs): break @@ -171,7 +174,7 @@ def main(args: argparse.Namespace): parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type') # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). 
parser.add_argument('--seed', type=int, default=0, help='random seed') - parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU') + parser.add_argument('--swap-space', type=int, default=0, help='CPU swap space size (GiB) per GPU') parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens') args = parser.parse_args() From 3c2da3e374ed1a589e4b2d794a9aa574465666cd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 27 Mar 2023 21:20:17 +0000 Subject: [PATCH 05/19] Bug fix in is_prompt --- cacheflow/master/scheduler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 8aaf1e34728d..35e1aca477fe 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -102,6 +102,7 @@ def step(self) -> None: ) # Join new sequences if possible. + prompt_group_ids: List[int] = [] self.waiting = self.policy.sort_by_priority(now, self.waiting) # FIXME(woosuk): This does not work if sequence groups have more than # one sequence. @@ -121,14 +122,13 @@ def step(self) -> None: self._allocate(seq_group) self.running.append(seq_group) num_batched_tokens += num_prompt_tokens + prompt_group_ids.append(seq_group.group_id) # Create input data structures. input_seq_groups: List[SequenceGroupInputs] = [] for seq_group in self.running: group_id = seq_group.group_id - num_steps = self.num_steps[group_id] - - is_prompt = num_steps == 0 + is_prompt = group_id in prompt_group_ids input_tokens: Dict[int, List[int]] = {} seq_logprobs: Dict[int, float] = {} From bcf6a6fd70c1cec434338ae99fe02c9f75e118f4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 27 Mar 2023 21:23:55 +0000 Subject: [PATCH 06/19] Minor --- cacheflow/master/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 35e1aca477fe..cb81e0e37434 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -232,7 +232,7 @@ def _append( seq_group: SequenceGroup, blocks_to_copy: Dict[int, List[int]], ) -> None: - for seq in seq_group.seqs(status=SequenceStatus.RUNNING): + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): ret = self.block_manager.append(seq) if ret is not None: src_block, dst_block = ret From 4b56dc139082715fb6c933cc1d57078d1b7b7deb Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 27 Mar 2023 22:12:06 +0000 Subject: [PATCH 07/19] Minor --- cacheflow/master/scheduler.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index cb81e0e37434..98838449da4c 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -85,10 +85,12 @@ def step(self) -> None: # Preempt the lowest-priority sequence groups. victim_seq_group = self.running.pop(-1) self._preempt(victim_seq_group) + self.waiting.append(victim_seq_group) else: # No other sequence groups can be preempted. # Preempt the current sequence group. self._preempt(seq_group) + self.waiting.append(seq_group) break else: # Append new slots to the sequence group. @@ -249,8 +251,9 @@ def _preempt(self, seq_group: SequenceGroup) -> None: # recompute them when the sequences are resumed. # We originally used swapping, but it turned out that recomputation # is more efficient. We keep the swapping code for future reference. 
- self.status = SequenceStatus.WAITING - self.block_manager.free(seq_group) + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq.status = SequenceStatus.WAITING + self.block_manager.free(seq) def _free_seq(self, seq: Sequence) -> None: seq.status = SequenceStatus.FINISHED From 9578fea44d13bee44fb7bd66e507225d4f5e3b31 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 28 Mar 2023 07:10:09 +0000 Subject: [PATCH 08/19] Add back swapping --- cacheflow/master/scheduler.py | 71 ++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 18 deletions(-) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 98838449da4c..10f5a6848e59 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -48,7 +48,6 @@ def __init__( # Mapping: group_id -> sampling params. self.sampling_params: Dict[int, SamplingParams] = {} # Sequence groups in the SWAPPED state. - # NOTE(woosuk): This is not used for now. self.swapped: List[SequenceGroup] = [] def _fetch_requests(self) -> int: @@ -60,10 +59,12 @@ def _fetch_requests(self) -> int: def step(self) -> None: # Blocks that need to be swaped or copied before model execution. + blocks_to_swap_in: Dict[int, int] = {} + blocks_to_swap_out: Dict[int, int] = {} blocks_to_copy: Dict[int, List[int]] = {} # Fetch new requests. - num_requests = self._fetch_requests() + self._fetch_requests() # Fix the current time. now = time.time() @@ -84,13 +85,11 @@ def step(self) -> None: if self.running: # Preempt the lowest-priority sequence groups. victim_seq_group = self.running.pop(-1) - self._preempt(victim_seq_group) - self.waiting.append(victim_seq_group) + self._preempt(victim_seq_group, blocks_to_swap_out) else: # No other sequence groups can be preempted. # Preempt the current sequence group. - self._preempt(seq_group) - self.waiting.append(seq_group) + self._preempt(seq_group, blocks_to_swap_out) break else: # Append new slots to the sequence group. @@ -98,18 +97,34 @@ def step(self) -> None: running.append(seq_group) self.running = running + # Swap in the sequence groups in the SWAPPED state if possible. + # NOTE(woosuk): The sequence groups in the SWAPPED state are + # prioritized over the sequence groups in the WAITING state. + # This is because the sequence groups in the SWAPPED state take up + # CPU memory, which is limited. + self.swapped = self.policy.sort_by_priority(now, self.swapped) + while self.swapped: + seq_group = self.swapped[0] + # If the sequence group cannot be swapped in, stop joining. + if not self.block_manager.can_swap_in(seq_group): + break + + seq_group = self.swapped.pop(0) + self._swap_in(seq_group, blocks_to_swap_in) + self._append(seq_group, blocks_to_copy) + self.running.append(seq_group) + num_batched_tokens = sum( seq_group.num_seqs(status=SequenceStatus.RUNNING) for seq_group in self.running ) - # Join new sequences if possible. + # Join waiting sequences if possible. prompt_group_ids: List[int] = [] self.waiting = self.policy.sort_by_priority(now, self.waiting) - # FIXME(woosuk): This does not work if sequence groups have more than - # one sequence. while self.waiting: seq_group = self.waiting[0] + assert seq_group.num_seqs() == 1 # If the sequence group cannot be allocated, stop joining. if not self.block_manager.can_allocate(seq_group): break @@ -159,12 +174,13 @@ def step(self) -> None: input_seq_groups.append(input_seq_group) # Execute the first stage of the pipeline. 
- if input_seq_groups: + if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out: + assert not (blocks_to_swap_in and blocks_to_swap_out) self.controllers[0].execute_stage( input_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - blocks_to_swap_in={}, - blocks_to_swap_out={}, ) def post_step( @@ -243,17 +259,36 @@ def _append( else: blocks_to_copy[src_block] = [dst_block] - def _preempt(self, seq_group: SequenceGroup) -> None: + def _preempt( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + ) -> None: # NOTE(woosuk): There are two preemption mechanisms. # 1. Swapping: Swap out the blocks of the preempted sequences to CPU # memory and swap them back in when the sequences are resumed. # 2. Recomputation: Discard the blocks of the preempted sequences and # recompute them when the sequences are resumed. - # We originally used swapping, but it turned out that recomputation - # is more efficient. We keep the swapping code for future reference. - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - seq.status = SequenceStatus.WAITING - self.block_manager.free(seq) + + # How to choose between the two? + # Recomputation is more efficient than swapping if the sequence group + # consists of a single sequence. When the sequence group has multiple + # sequences, we only support swapping. + # TODO(woosuk): Support recomputation for sequence groups with multiple + # sequences. + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + if len(seqs) == 1: + # Recomputation. + for seq in seqs: + seq.status = SequenceStatus.WAITING + self.block_manager.free(seq) + self.waiting.append(seq_group) + else: + # Swapping. + for seq in seqs: + seq.status = SequenceStatus.SWAPPED + self._swap_out(seq_group, blocks_to_swap_out) + self.swapped.append(seq_group) def _free_seq(self, seq: Sequence) -> None: seq.status = SequenceStatus.FINISHED From dd999f45c20baac1fafb8a7cb18787361655a6e4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 28 Mar 2023 07:10:24 +0000 Subject: [PATCH 09/19] Apply watermark to can_swap_in --- cacheflow/master/block_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cacheflow/master/block_manager.py b/cacheflow/master/block_manager.py index 203ab5438643..aad3a164ae9c 100644 --- a/cacheflow/master/block_manager.py +++ b/cacheflow/master/block_manager.py @@ -160,7 +160,8 @@ def can_swap_in(self, seq_group: SequenceGroup) -> bool: # NOTE: Conservatively, we assume that every sequence will allocate # at least one free block right after the swap-in. # NOTE: This should match the logic in can_append(). - return len(blocks) + num_swapped_seqs <= num_free_blocks + num_required_blocks = len(blocks) + num_swapped_seqs + return num_free_blocks - num_required_blocks >= self.watermark_blocks def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: # CPU block -> GPU block. 
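This patch applies the same watermark reserve to swap-in decisions. Below is a hedged sketch of the test as a standalone function; in the patch this logic lives in BlockSpaceManager.can_swap_in, and the parameter names here are stand-ins.

# Illustrative sketch of the swap-in admission test from this patch.

def can_swap_in(num_swapped_blocks: int,
                num_swapped_seqs: int,
                num_free_gpu_blocks: int,
                watermark_blocks: int) -> bool:
    # Every CPU-resident block needs a GPU block and, conservatively
    # (matching the NOTE in the patch), each swapped sequence is assumed
    # to append one fresh block immediately after resuming, which mirrors
    # the logic in can_append().
    num_required_blocks = num_swapped_blocks + num_swapped_seqs
    # Keep the watermark reserve intact, exactly as in can_allocate(),
    # so a freshly resumed group cannot retrigger preemption right away.
    return num_free_gpu_blocks - num_required_blocks >= watermark_blocks

# E.g., a group with 2 sequences and 30 swapped-out blocks needs 32 free
# GPU blocks on top of the 10-block reserve:
print(can_swap_in(30, 2, num_free_gpu_blocks=42, watermark_blocks=10))  # True
print(can_swap_in(30, 2, num_free_gpu_blocks=41, watermark_blocks=10))  # False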
From 86b7e9ab031880cc1b5bce4cad2bc33035bd2442 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 28 Mar 2023 07:19:28 +0000 Subject: [PATCH 10/19] Refactor & Ensure priority swapped > waiting --- cacheflow/master/scheduler.py | 50 ++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 10f5a6848e59..f62bfea70696 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -98,14 +98,10 @@ def step(self) -> None: self.running = running # Swap in the sequence groups in the SWAPPED state if possible. - # NOTE(woosuk): The sequence groups in the SWAPPED state are - # prioritized over the sequence groups in the WAITING state. - # This is because the sequence groups in the SWAPPED state take up - # CPU memory, which is limited. self.swapped = self.policy.sort_by_priority(now, self.swapped) while self.swapped: seq_group = self.swapped[0] - # If the sequence group cannot be swapped in, stop joining. + # If the sequence group cannot be swapped in, stop. if not self.block_manager.can_swap_in(seq_group): break @@ -119,27 +115,33 @@ def step(self) -> None: for seq_group in self.running ) - # Join waiting sequences if possible. - prompt_group_ids: List[int] = [] - self.waiting = self.policy.sort_by_priority(now, self.waiting) - while self.waiting: - seq_group = self.waiting[0] - assert seq_group.num_seqs() == 1 - # If the sequence group cannot be allocated, stop joining. - if not self.block_manager.can_allocate(seq_group): - break + # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly + # prioritized over the sequence groups in the WAITING state. + # This is because we want to bound the amount of CPU memory taken by + # the swapped sequence groups. + if not self.swapped: + # Join waiting sequences if possible. + prompt_group_ids: List[int] = [] + self.waiting = self.policy.sort_by_priority(now, self.waiting) + + while self.waiting: + seq_group = self.waiting[0] + assert seq_group.num_seqs() == 1 + # If the sequence group cannot be allocated, stop. + if not self.block_manager.can_allocate(seq_group): + break - # If the number of batched tokens exceeds the limit, stop joining. - num_prompt_tokens = seq_group.seqs[0].get_len() - if (num_batched_tokens + num_prompt_tokens - > self.max_num_batched_tokens): - break + # If the number of batched tokens exceeds the limit, stop. + num_prompt_tokens = seq_group.seqs[0].get_len() + if (num_batched_tokens + num_prompt_tokens + > self.max_num_batched_tokens): + break - seq_group = self.waiting.pop(0) - self._allocate(seq_group) - self.running.append(seq_group) - num_batched_tokens += num_prompt_tokens - prompt_group_ids.append(seq_group.group_id) + seq_group = self.waiting.pop(0) + self._allocate(seq_group) + self.running.append(seq_group) + num_batched_tokens += num_prompt_tokens + prompt_group_ids.append(seq_group.group_id) # Create input data structures. 
input_seq_groups: List[SequenceGroupInputs] = [] From f2f13f59a9cc0b2b8e355c7892393fd6f8e00be4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 28 Mar 2023 07:27:08 +0000 Subject: [PATCH 11/19] Revert server.py --- server.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/server.py b/server.py index c9da9626e44d..07e186054d03 100644 --- a/server.py +++ b/server.py @@ -99,9 +99,6 @@ def main(args: argparse.Namespace): max_num_batched_tokens=args.max_batch_size) num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks( swap_space=args.swap_space) - if num_cpu_blocks > 0: - raise ValueError( - 'CPU blocks are not used. Please set --swap-space to 0.') print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}') # Create a controller for each pipeline stage. @@ -174,7 +171,7 @@ def main(args: argparse.Namespace): parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type') # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). parser.add_argument('--seed', type=int, default=0, help='random seed') - parser.add_argument('--swap-space', type=int, default=0, help='CPU swap space size (GiB) per GPU') + parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU') parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens') args = parser.parse_args() From 24968fca78bd58d80eada5f31e3d8a32a47f64b7 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 28 Mar 2023 07:30:33 +0000 Subject: [PATCH 12/19] Bugfix --- cacheflow/master/scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index f62bfea70696..f98283a2a224 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -126,7 +126,6 @@ def step(self) -> None: while self.waiting: seq_group = self.waiting[0] - assert seq_group.num_seqs() == 1 # If the sequence group cannot be allocated, stop. if not self.block_manager.can_allocate(seq_group): break From 9b159e4b298780eba49bd8fa29332c4cc070cdf4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 28 Mar 2023 08:03:59 +0000 Subject: [PATCH 13/19] Bugfix --- cacheflow/master/scheduler.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index f98283a2a224..0503abd4ba9f 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -65,7 +65,6 @@ def step(self) -> None: # Fetch new requests. self._fetch_requests() - # Fix the current time. now = time.time() @@ -79,6 +78,7 @@ def step(self) -> None: # Reserve new token slots for the running sequence groups. running: List[SequenceGroup] = [] + preempted: List[SequenceGroup] = [] while self.running: seq_group = self.running.pop(0) while not self.block_manager.can_append(seq_group): @@ -86,10 +86,12 @@ def step(self) -> None: # Preempt the lowest-priority sequence groups. victim_seq_group = self.running.pop(-1) self._preempt(victim_seq_group, blocks_to_swap_out) + preempted.append(victim_seq_group) else: # No other sequence groups can be preempted. # Preempt the current sequence group. self._preempt(seq_group, blocks_to_swap_out) + preempted.append(seq_group) break else: # Append new slots to the sequence group. 
@@ -101,6 +103,9 @@ def step(self) -> None: self.swapped = self.policy.sort_by_priority(now, self.swapped) while self.swapped: seq_group = self.swapped[0] + # If the sequence group has been preempted in this step, stop. + if seq_group in preempted: + break # If the sequence group cannot be swapped in, stop. if not self.block_manager.can_swap_in(seq_group): break @@ -115,17 +120,19 @@ def step(self) -> None: for seq_group in self.running ) + # Join waiting sequences if possible. + prompt_group_ids: List[int] = [] # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly # prioritized over the sequence groups in the WAITING state. # This is because we want to bound the amount of CPU memory taken by # the swapped sequence groups. if not self.swapped: - # Join waiting sequences if possible. - prompt_group_ids: List[int] = [] self.waiting = self.policy.sort_by_priority(now, self.waiting) - while self.waiting: seq_group = self.waiting[0] + # If the sequence group has been preempted in this step, stop. + if seq_group in preempted: + break # If the sequence group cannot be allocated, stop. if not self.block_manager.can_allocate(seq_group): break From 2452758da63368b816ae2c65ed688d5d72510f55 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 28 Mar 2023 08:33:51 +0000 Subject: [PATCH 14/19] Minor fix --- cacheflow/master/block_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cacheflow/master/block_manager.py b/cacheflow/master/block_manager.py index aad3a164ae9c..abc0486eefcb 100644 --- a/cacheflow/master/block_manager.py +++ b/cacheflow/master/block_manager.py @@ -77,11 +77,11 @@ def __init__( def can_allocate(self, seq_group: SequenceGroup) -> bool: # FIXME(woosuk): Here we assume that all sequences in the group share - # the same prompt. This is not true for preempted sequences. + # the same prompt. This may not be true for preempted sequences. seq = seq_group.seqs[0] num_required_blocks = len(seq.logical_token_blocks) num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() - # Use watermark to avoid thrashing. + # Use watermark to avoid frequent preemptions. return num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks def allocate(self, seq_group: SequenceGroup) -> None: From 0a55d43b3968410293cec96b4d4063faaef59382 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 29 Mar 2023 23:19:41 +0000 Subject: [PATCH 15/19] Fix merge errors --- cacheflow/master/scheduler.py | 8 ++++---- cacheflow/master/server.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 43b45ddb4fae..f5c606412b1b 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -51,12 +51,12 @@ def add_sequence_groups( self, seq_groups: List[Tuple[SequenceGroup, SamplingParams]], ) -> None: - # Add sequence groups to the pending queue. + # Add sequence groups to the waiting queue. for seq_group, sampling_params in seq_groups: - self.pending.append(seq_group) + self.waiting.append(seq_group) self.sampling_params[seq_group.group_id] = sampling_params - def step(self) -> None: + def step(self) -> List[SequenceGroup]: # Blocks that need to be swaped or copied before model execution. 
blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {} @@ -303,7 +303,7 @@ def _free_seq(self, seq: Sequence) -> None: seq.status = SequenceStatus.FINISHED self.block_manager.free(seq) - def _return(self, seq_group: SequenceGroup) -> None: + def _free_seq_group(self, seq_group: SequenceGroup) -> None: group_id = seq_group.group_id del self.num_steps[group_id] del self.sampling_params[group_id] diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index 4498649ad936..1f224316c01b 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -92,7 +92,7 @@ def step(self): return self.scheduler.step() def has_unfinished_requests(self): - return (self.scheduler.pending or self.scheduler.running or + return (self.scheduler.waiting or self.scheduler.running or self.scheduler.swapped) From 9edf814b149f445cbc8db3b660a4a084e8e68196 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 29 Mar 2023 23:26:40 +0000 Subject: [PATCH 16/19] Add more comments --- cacheflow/master/scheduler.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index f5c606412b1b..bf9a57241240 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -182,6 +182,7 @@ def step(self) -> List[SequenceGroup]: # Execute the first stage of the pipeline. if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out: + # Swap in and swap out should never happen at the same time. assert not (blocks_to_swap_in and blocks_to_swap_out) self.controllers[0].execute_stage( input_seq_groups, @@ -285,6 +286,11 @@ def _preempt( # sequences, we only support swapping. # TODO(woosuk): Support recomputation for sequence groups with multiple # sequences. + + # FIXME(woosuk): This makes our scheduling policy a bit bizarre. + # Because swapped sequences are prioritized over waiting sequences, + # sequence groups with multiple sequences are implicitly prioritized + # over sequence groups with a single sequence. seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) if len(seqs) == 1: # Recomputation. From 8d81e01ff9aee8d6bc060dd6ff74f9c2714bdf61 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 29 Mar 2023 23:34:14 +0000 Subject: [PATCH 17/19] Minor --- cacheflow/master/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index bf9a57241240..80e60ed5366f 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -28,7 +28,7 @@ def __init__( self.max_num_batched_tokens = max_num_batched_tokens # Instantiate the scheduling policy. - self.policy = PolicyFactory.get_policy('fcfs') + self.policy = PolicyFactory.get_policy(policy_name='fcfs') # Create the block space manager. 
self.block_manager = BlockSpaceManager( block_size=block_size, From ea8c27c398ad7ae5611a5b1cec0f15fd62981af6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 30 Mar 2023 00:03:33 +0000 Subject: [PATCH 18/19] Add arrival time in fastapi frontend --- cacheflow/http_frontend/fastapi_frontend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cacheflow/http_frontend/fastapi_frontend.py b/cacheflow/http_frontend/fastapi_frontend.py index dff7f7526ac6..d901baac82fc 100644 --- a/cacheflow/http_frontend/fastapi_frontend.py +++ b/cacheflow/http_frontend/fastapi_frontend.py @@ -84,8 +84,9 @@ async def generate(self, request_dict: Dict): seq = Sequence(seq_id, token_ids, block_size=self.block_size) seqs.append(seq) + arrival_time = time.time() group_id = next(self.seq_group_counter) - seq_group = SequenceGroup(group_id, seqs) + seq_group = SequenceGroup(group_id, seqs, arrival_time) group_event = asyncio.Event() self.sequence_group_events[group_id] = group_event await self.server.add_sequence_groups.remote([(seq_group, sampling_params)]) From 4c798c3b6fe8039ef60d8bc7731428a5a6226eb7 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 30 Mar 2023 21:43:49 +0000 Subject: [PATCH 19/19] Move scheduling logic to _schedule & Add preemption mode --- cacheflow/master/scheduler.py | 103 +++++++++++++++++++++++++--------- 1 file changed, 75 insertions(+), 28 deletions(-) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 80e60ed5366f..c0ab33066c97 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -1,5 +1,6 @@ +import enum import time -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from cacheflow.master.block_manager import BlockSpaceManager from cacheflow.master.policy import PolicyFactory @@ -11,6 +12,19 @@ from cacheflow.sequence import SequenceStatus +class PreemptionMode(enum.Enum): + """Preemption modes. + + 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory + and swap them back in when the sequences are resumed. + 2. Recomputation: Discard the blocks of the preempted sequences and + recompute them when the sequences are resumed, treating the sequences as + new prompts. + """ + SWAP = enum.auto() + RECOMPUTE = enum.auto() + + class Scheduler: def __init__( @@ -56,7 +70,9 @@ def add_sequence_groups( self.waiting.append(seq_group) self.sampling_params[seq_group.group_id] = sampling_params - def step(self) -> List[SequenceGroup]: + def _schedule( + self, + ) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]: # Blocks that need to be swaped or copied before model execution. blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {} @@ -146,6 +162,21 @@ def step(self) -> List[SequenceGroup]: num_batched_tokens += num_prompt_tokens prompt_group_ids.append(seq_group.group_id) + return (blocks_to_swap_in, + blocks_to_swap_out, + blocks_to_copy, + prompt_group_ids) + + def step(self) -> List[SequenceGroup]: + # Schedule sequence groups. + # This function call changes the internal states of the scheduler + # such as self.running, self.swapped, and self.waiting. + scheduler_output = self._schedule() + blocks_to_swap_in = scheduler_output[0] + blocks_to_swap_out = scheduler_output[1] + blocks_to_copy = scheduler_output[2] + prompt_group_ids = scheduler_output[3] + # Create input data structures. 
input_seq_groups: List[SequenceGroupInputs] = [] updated_seq_groups: List[SequenceGroup] = self.running.copy() @@ -273,37 +304,53 @@ def _preempt( self, seq_group: SequenceGroup, blocks_to_swap_out: Dict[int, int], + preemption_mode: Optional[PreemptionMode] = None, ) -> None: - # NOTE(woosuk): There are two preemption mechanisms. - # 1. Swapping: Swap out the blocks of the preempted sequences to CPU - # memory and swap them back in when the sequences are resumed. - # 2. Recomputation: Discard the blocks of the preempted sequences and - # recompute them when the sequences are resumed. - - # How to choose between the two? - # Recomputation is more efficient than swapping if the sequence group - # consists of a single sequence. When the sequence group has multiple - # sequences, we only support swapping. - # TODO(woosuk): Support recomputation for sequence groups with multiple - # sequences. - + # If preemption mode is not specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not supported. In such a case, + # we use swapping instead. # FIXME(woosuk): This makes our scheduling policy a bit bizarre. - # Because swapped sequences are prioritized over waiting sequences, + # As swapped sequences are prioritized over waiting sequences, # sequence groups with multiple sequences are implicitly prioritized # over sequence groups with a single sequence. - seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - if len(seqs) == 1: - # Recomputation. - for seq in seqs: - seq.status = SequenceStatus.WAITING - self.block_manager.free(seq) - self.waiting.append(seq_group) + # TODO(woosuk): Support recomputation for sequence groups with multiple + # sequences. This may require a more sophisticated CUDA kernel. + if preemption_mode is None: + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + if len(seqs) == 1: + preemption_mode = PreemptionMode.RECOMPUTE + else: + preemption_mode = PreemptionMode.SWAP + if preemption_mode == PreemptionMode.RECOMPUTE: + self._preempt_by_recompute(seq_group) + elif preemption_mode == PreemptionMode.SWAP: + self._preempt_by_swap(seq_group, blocks_to_swap_out) else: - # Swapping. - for seq in seqs: - seq.status = SequenceStatus.SWAPPED - self._swap_out(seq_group, blocks_to_swap_out) - self.swapped.append(seq_group) + assert False, 'Invalid preemption mode.' + + def _preempt_by_recompute( + self, + seq_group: SequenceGroup, + ) -> None: + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + assert len(seqs) == 1 + for seq in seqs: + seq.status = SequenceStatus.WAITING + self.block_manager.free(seq) + self.waiting.append(seq_group) + + def _preempt_by_swap( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + ) -> None: + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + for seq in seqs: + seq.status = SequenceStatus.SWAPPED + self._swap_out(seq_group, blocks_to_swap_out) + self.swapped.append(seq_group) def _free_seq(self, seq: Sequence) -> None: seq.status = SequenceStatus.FINISHED
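The series ends with the preemption design made explicit. Here is a hedged, self-contained sketch of the final flow, using simplified stand-ins for the cacheflow types: recomputation for single-sequence groups, which re-enter the waiting queue as fresh prompts, and swapping for multi-sequence groups such as beam search.

import enum
from typing import List


class PreemptionMode(enum.Enum):
    SWAP = enum.auto()       # move KV-cache blocks to CPU, restore later
    RECOMPUTE = enum.auto()  # drop KV-cache blocks, replay as a new prompt


class ToyGroup:
    """Simplified stand-in for cacheflow's SequenceGroup."""

    def __init__(self, group_id: int, num_seqs: int) -> None:
        self.group_id = group_id
        self.num_seqs = num_seqs


def choose_mode(group: ToyGroup) -> PreemptionMode:
    # Recomputation is cheaper (no CPU<->GPU copies), but replaying the
    # prompt reproduces only a single continuation, so groups with several
    # in-flight sequences (e.g. beam search) must be swapped instead.
    if group.num_seqs == 1:
        return PreemptionMode.RECOMPUTE
    return PreemptionMode.SWAP


def preempt(group: ToyGroup,
            waiting: List[ToyGroup],
            swapped: List[ToyGroup]) -> None:
    # Block-table bookkeeping (freeing blocks, GPU->CPU mappings) is
    # omitted; this only shows how a preempted group changes queues.
    if choose_mode(group) is PreemptionMode.RECOMPUTE:
        waiting.append(group)  # rescheduled later as a fresh prompt
    else:
        swapped.append(group)  # swapped groups outrank waiting ones


waiting: List[ToyGroup] = []
swapped: List[ToyGroup] = []
preempt(ToyGroup(0, num_seqs=1), waiting, swapped)  # goes to waiting
preempt(ToyGroup(1, num_seqs=4), waiting, swapped)  # goes to swapped

As the FIXME in the patch notes, because _schedule() drains the swapped queue before admitting waiting groups, this default implicitly prioritizes multi-sequence groups over single-sequence ones after preemption.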