diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index e113aeb69b..bac0e30ec9 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -110,9 +110,11 @@ def insert(self, def lock_node(self, node: CRadixNode) -> None: self.index.lock(node) - def cleanup(self, node: CRadixNode, cleanup_length: int) -> None: + def unlock(self, node: CRadixNode) -> None: self.index.unlock(node) - self.index.set_ready(node, True, cleanup_length) + + def set_ready(self, node: CRadixNode, ready: bool, ready_length: int) -> None: + self.index.set_ready(node, ready, ready_length) def take(self, num_required_blocks: int, @@ -121,7 +123,10 @@ def take(self, if num_required_blocks > self.mempool.num_free_blocks: if protected_node is not None: self.index.lock(protected_node) - evict_block_num = max(num_required_blocks - self.mempool.num_free_blocks, int(self.mempool.num_total_blocks * self.evict_ratio)) + evict_block_num = max( + num_required_blocks - self.mempool.num_free_blocks, + int(self.mempool.num_total_blocks * self.evict_ratio) + ) target_blocks = torch.zeros(evict_block_num, dtype=torch.int64) num_evicted = self.index.evict(target_blocks, evict_block_num) if num_evicted != evict_block_num: @@ -188,9 +193,11 @@ def insert(self, def lock_node(self, node: RadixNode) -> None: self.index.lock(node) - def cleanup(self, node: RadixNode, cleanup_length: int) -> None: + def unlock(self, node: RadixNode) -> None: self.index.unlock(node) - self.index.set_ready(node, True, cleanup_length) + + def set_ready(self, node: RadixNode, ready: bool, ready_length: int) -> None: + self.index.set_ready(node, ready, ready_length) def take(self, num_required_blocks: int, @@ -199,7 +206,10 @@ def take(self, if num_required_blocks > self.mempool.num_free_blocks: if protected_node is not None: self.index.lock(protected_node) - evict_block_num = max(num_required_blocks - self.mempool.num_free_blocks, int(self.mempool.num_total_blocks * self.evict_ratio)) + 
evict_block_num = max( + num_required_blocks - self.mempool.num_free_blocks, + int(self.mempool.num_total_blocks * self.evict_ratio) + ) self.mempool.recycle_blocks( self.index.evict(evict_block_num) ) @@ -264,10 +274,10 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig): cache_config.evict_ratio) self.cache_engines[DeviceType.REMOTE] = self.remote_cache_engine - self._empty_get_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, int]] = \ - lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, 0) - self._empty_put_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]] = \ - lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, 0, 0) + self._empty_get_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int]] = \ + lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, {}, 0) + self._empty_put_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int, int]] = \ + lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, {}, 0, 0) def reset(self) -> None: if self.cpu_cache_engine: @@ -284,7 +294,7 @@ def get(self, slot_mapping: np.ndarray, layer_num: int = -1, layer_granularity: int = -1, - dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, int]: + dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, Dict, int]: self._check_input(token_ids, token_mask, slot_mapping) if layer_num == -1: @@ -312,25 +322,27 @@ def get(self, tokens_per_block=self.cache_config.tokens_per_block) if not self.cache_config.enable_remote: - transfer_graph, finished_ops_ids, node_to_unlock, buffer_to_free, num_gpu_blocks_to_transfer = \ + (transfer_graph, finished_ops_ids, node_to_unlock, + op_node_to_ready, buffer_to_free, num_gpu_blocks_to_transfer) = \ self._get_impl_local( - request_id, - sequence_meta, - block_start_idx, - block_end_idx, - gpu_block_ids, - layer_num - ) + request_id, + 
sequence_meta, + block_start_idx, + block_end_idx, + gpu_block_ids, + layer_num + ) else: - transfer_graph, finished_ops_ids, node_to_unlock, buffer_to_free, num_gpu_blocks_to_transfer = \ + (transfer_graph, finished_ops_ids, node_to_unlock, + op_node_to_ready, buffer_to_free, num_gpu_blocks_to_transfer) = \ self._get_impl_global( - request_id, - sequence_meta, - block_start_idx, - block_end_idx, - gpu_block_ids, - layer_num - ) + request_id, + sequence_meta, + block_start_idx, + block_end_idx, + gpu_block_ids, + layer_num + ) transfer_graph, task_end_op_id = add_virtal_op_for_mutiple_finished_ops( transfer_graph, @@ -355,7 +367,13 @@ def get(self, node_to_unlock=node_to_unlock, buffer_to_free=buffer_to_free) - return transfer_graph, return_mask, callback, task_end_op_id + op_callback_dict = {} # dict, op_id -> callback + for op_id in op_node_to_ready: + op_callback_dict[op_id] = partial(self._op_callback, + device_type=op_node_to_ready[op_id][0], + node_to_ready=op_node_to_ready[op_id][1], + ready_length=op_node_to_ready[op_id][2]) + return transfer_graph, return_mask, callback, op_callback_dict, task_end_op_id def _get_impl_global(self, request_id: int, @@ -363,7 +381,7 @@ def _get_impl_global(self, block_mask_start: int, block_mask_end: int, gpu_block_ids: np.ndarray, - layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int]: + layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int]: """ transfer pattern: @@ -521,7 +539,10 @@ def _get_impl_global(self, buffer_to_free = {DeviceType.CPU: cpu_blocks_to_free} # NOTE: for now in build transfer graph, we assume that cpu works as a cache for ssd - return transfer_graph, finished_ops_ids, node_to_unlock, buffer_to_free, len(fragment123_gpu_blocks) + return ( + transfer_graph, finished_ops_ids, node_to_unlock, {}, buffer_to_free, + len(fragment123_gpu_blocks) # the bare {} in this tuple is op_node_to_ready + ) def _get_impl_local(self, request_id: int, @@ -529,7 +550,7 @@ def _get_impl_local(self, 
block_mask_start: int, block_mask_end: int, gpu_block_ids: np.ndarray, - layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int]: + layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int]: """ transfer pattern: @@ -566,6 +587,7 @@ def _get_impl_local(self, transfer_graph = TransferOpGraph() finished_ops_ids = [] + op_node_to_ready = {} fragment12_gpu_blocks = gpu_block_ids[:fragment12_num_blocks] fragment2_ssd_blocks = ssd_matched_blocks[-fragment2_num_blocks:] @@ -610,6 +632,7 @@ def _get_impl_local(self, block_mask_start, is_ready=False, match_result=cpu_matched_result) + op_node_to_ready[op_disk2h.op_id] = (DeviceType.CPU, cpu_node_to_unlock, cpu_node_to_unlock.size()) else: cpu_blocks_to_free = fragment2_cpu_blocks if fragment2_cpu_blocks is not None: @@ -636,7 +659,10 @@ def _get_impl_local(self, node_to_unlock[DeviceType.SSD] = (ssd_node_to_unlock, ssd_node_to_unlock.size()) buffer_to_free = {DeviceType.CPU: cpu_blocks_to_free} - return transfer_graph, finished_ops_ids, node_to_unlock, buffer_to_free, len(fragment12_gpu_blocks) + return ( + transfer_graph, finished_ops_ids, node_to_unlock, op_node_to_ready, + buffer_to_free, len(fragment12_gpu_blocks) + ) def put(self, request_id: int, @@ -644,7 +670,7 @@ def put(self, token_mask: np.ndarray, slot_mapping: np.ndarray, layer_num : int = -1, - dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, int]: + dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, Dict, int]: self._check_input(token_ids, token_mask, slot_mapping) if layer_num == -1: @@ -665,7 +691,7 @@ def put(self, tokens_per_block=self.cache_config.tokens_per_block) if not self.cache_config.enable_remote: - (transfer_graph, finished_ops_ids, node_to_unlock, + (transfer_graph, finished_ops_ids, node_to_unlock, op_node_to_ready, buffer_to_free, num_gpu_blocks_to_transfer, skipped_gpu_blocks) = \ self._put_impl_local( request_id, @@ -676,7 +702,7 @@ def put(self, layer_num ) else: - (transfer_graph, 
finished_ops_ids, node_to_unlock, + (transfer_graph, finished_ops_ids, node_to_unlock, op_node_to_ready, buffer_to_free, num_gpu_blocks_to_transfer, skipped_gpu_blocks) = \ self._put_impl_global( request_id, @@ -704,7 +730,14 @@ def put(self, node_to_unlock=node_to_unlock, buffer_to_free=buffer_to_free) - return transfer_graph, return_mask, callback, task_end_op_id + op_callback_dict = {} + for op_id in op_node_to_ready: + op_callback_dict[op_id] = partial(self._op_callback, + device_type=op_node_to_ready[op_id][0], + node_to_ready=op_node_to_ready[op_id][1], + ready_length=op_node_to_ready[op_id][2]) + + return transfer_graph, return_mask, callback, op_callback_dict, task_end_op_id def _put_impl_global(self, request_id: int, @@ -712,7 +745,7 @@ def _put_impl_global(self, block_mask_start: int, block_mask_end: int, gpu_block_ids: np.ndarray, - layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]: + layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int, int]: """ transfer pattern: @@ -859,7 +892,10 @@ def _put_impl_global(self, node_to_unlock[DeviceType.REMOTE] = (remote_node_to_unlock, remote_node_to_unlock.size()) skipped_gpu_blocks = len(cpu_matched_blocks) - return transfer_graph, finished_ops_ids, node_to_unlock, {}, len(fragment12_gpu_blocks), skipped_gpu_blocks + return ( + transfer_graph, finished_ops_ids, node_to_unlock, {}, {}, + len(fragment12_gpu_blocks), skipped_gpu_blocks # the two {} are op_node_to_ready, buffer_to_free + ) def _put_impl_local(self, request_id: int, @@ -867,7 +903,7 @@ def _put_impl_local(self, block_mask_start: int, block_mask_end: int, gpu_block_ids: np.ndarray, - layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]: + layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int, int]: """ transfer pattern: @@ -923,6 +959,7 @@ def _put_impl_local(self, transfer_graph = TransferOpGraph() finished_ops_ids = [] + op_node_to_ready = {} op_d2h = TransferOp( graph_id = 
transfer_graph.graph_id, @@ -954,12 +991,14 @@ def _put_impl_local(self, fragment12_cpu_blocks, is_ready=False, match_result=cpu_matched_result) + op_node_to_ready[op_d2h.op_id] = (DeviceType.CPU, cpu_node_to_unlock, cpu_node_to_unlock.size()) ssd_node_to_unlock = None if len(fragment2_ssd_blocks) > 0: ssd_node_to_unlock = self.ssd_cache_engine.insert(sequence_meta, fragment2_ssd_blocks, is_ready=False, match_result=ssd_matched_result) + op_node_to_ready[op_h2disk.op_id] = (DeviceType.SSD, ssd_node_to_unlock, ssd_node_to_unlock.size()) node_to_unlock = {} if cpu_node_to_unlock is not None: node_to_unlock[DeviceType.CPU] = (cpu_node_to_unlock, cpu_node_to_unlock.size()) @@ -967,20 +1006,28 @@ def _put_impl_local(self, node_to_unlock[DeviceType.SSD] = (ssd_node_to_unlock, ssd_node_to_unlock.size()) skipped_gpu_blocks = len(cpu_matched_blocks) - return transfer_graph, finished_ops_ids, node_to_unlock, {}, len(fragment12_gpu_blocks), skipped_gpu_blocks + return ( + transfer_graph, finished_ops_ids, node_to_unlock, op_node_to_ready, {}, + len(fragment12_gpu_blocks), skipped_gpu_blocks + ) def _transfer_callback(self, node_to_unlock: Dict[DeviceType, Tuple[RadixNode, int]], buffer_to_free: Optional[Dict[DeviceType, np.ndarray]] = None) -> None: if DeviceType.CPU in node_to_unlock: assert self.cpu_cache_engine is not None - self.cpu_cache_engine.cleanup(node_to_unlock[DeviceType.CPU][0], node_to_unlock[DeviceType.CPU][1]) + self.cpu_cache_engine.unlock(node_to_unlock[DeviceType.CPU][0]) + self.cpu_cache_engine.set_ready(node_to_unlock[DeviceType.CPU][0], True, node_to_unlock[DeviceType.CPU][1]) if DeviceType.SSD in node_to_unlock: assert self.ssd_cache_engine is not None - self.ssd_cache_engine.cleanup(node_to_unlock[DeviceType.SSD][0], node_to_unlock[DeviceType.SSD][1]) + self.ssd_cache_engine.unlock(node_to_unlock[DeviceType.SSD][0]) + self.ssd_cache_engine.set_ready(node_to_unlock[DeviceType.SSD][0], True, node_to_unlock[DeviceType.SSD][1]) if DeviceType.REMOTE in 
node_to_unlock: assert self.remote_cache_engine is not None - self.remote_cache_engine.cleanup(node_to_unlock[DeviceType.REMOTE][0], node_to_unlock[DeviceType.REMOTE][1]) + self.remote_cache_engine.unlock(node_to_unlock[DeviceType.REMOTE][0]) + self.remote_cache_engine.set_ready( + node_to_unlock[DeviceType.REMOTE][0], True, node_to_unlock[DeviceType.REMOTE][1] + ) if buffer_to_free is not None: if DeviceType.CPU in buffer_to_free: assert self.cpu_cache_engine is not None @@ -992,6 +1039,17 @@ def _transfer_callback(self, assert self.remote_cache_engine is not None self.remote_cache_engine.recycle(buffer_to_free[DeviceType.REMOTE]) + def _op_callback(self, device_type: DeviceType, node_to_ready: RadixNode, ready_length: int) -> None: + if device_type == DeviceType.CPU: + assert self.cpu_cache_engine is not None + self.cpu_cache_engine.set_ready(node_to_ready, True, ready_length) + elif device_type == DeviceType.SSD: + assert self.ssd_cache_engine is not None + self.ssd_cache_engine.set_ready(node_to_ready, True, ready_length) + elif device_type == DeviceType.REMOTE: + assert self.remote_cache_engine is not None + self.remote_cache_engine.set_ready(node_to_ready, True, ready_length) + @nvtx.annotate("Match Prefix Accel", color="yellow") def match_local_accel(self, sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel]: cpu_matched_result = MatchResultAccel() @@ -1002,7 +1060,7 @@ def match_local_accel(self, sequence_meta: SequenceMeta) -> Tuple[MatchResultAcc ssd_matched_result = self.ssd_cache_engine.match(sequence_meta) return cpu_matched_result, ssd_matched_result - + @nvtx.annotate("Match Prefix", color="yellow") def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResult]: cpu_matched_result = MatchResult() @@ -1013,7 +1071,7 @@ def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchRe ssd_matched_result = self.ssd_cache_engine.match(sequence_meta) return cpu_matched_result, ssd_matched_result - 
+ @nvtx.annotate("Match All Prefix accel", color="yellow") def match_all_accel(self, sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel, MatchResultAccel]: @@ -1029,7 +1087,7 @@ def match_all_accel(self, remote_matched_result = self.remote_cache_engine.match(sequence_meta) return cpu_matched_result, ssd_matched_result, remote_matched_result - + @nvtx.annotate("Match All Prefix", color="yellow") def match_all(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResult, MatchResult]: cpu_matched_result = MatchResult() diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index 3d661a66bf..aeea4bc5b2 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -56,6 +56,7 @@ class KVTask: graph: TransferOpGraph return_mask: np.ndarray callback: Optional[Callable] + op_callback_dict: Dict[int, Callable] def is_completed(self) -> bool: return self.status in [TaskStatus.COMPLETED, TaskStatus.CANCELLED, TaskStatus.FAILED] @@ -127,7 +128,7 @@ def create_get_task(self, ) -> None: if task_id in self.tasks: raise ValueError(f"Task ID {task_id} already exists") - graph, return_mask, callback, task_end_op_id = self.cache_engine.get(task_id, + graph, return_mask, callback, op_callback_dict, task_end_op_id = self.cache_engine.get(task_id, token_ids, token_mask, slot_mapping, @@ -146,7 +147,8 @@ def create_get_task(self, dp_id=dp_id, graph=graph, return_mask=return_mask, - callback=callback) + callback=callback, + op_callback_dict=op_callback_dict) self.graph_to_task[graph.graph_id] = task_id @@ -160,7 +162,7 @@ def create_put_task(self, ) -> None: if task_id in self.tasks: raise ValueError(f"Task ID {task_id} already exists") - graph, return_mask, callback, task_end_op_id = self.cache_engine.put(task_id, + graph, return_mask, callback, op_callback_dict, task_end_op_id = self.cache_engine.put(task_id, token_ids, token_mask, slot_mapping, @@ -178,7 +180,8 @@ def create_put_task(self, dp_id=dp_id, graph=graph, return_mask=return_mask, - callback=callback) + 
callback=callback, + op_callback_dict=op_callback_dict) self.graph_to_task[graph.graph_id] = task_id def _launch_task(self, task_id: int) -> None: @@ -204,6 +207,8 @@ def _update_tasks(self, timeout: float = 0.001) -> None: self._mark_completed(task_id) elif completed_op_id == task.task_end_op_id: self.tasks[task_id].task_end_op_finished = True + if completed_op_id in task.op_callback_dict: + task.op_callback_dict[completed_op_id]() def _cancel_task(self, task_id: int) -> None: task = self.tasks[task_id] diff --git a/tests/test_cache_engine.py b/tests/test_cache_engine.py index e224d4305b..70a12ffeb0 100644 --- a/tests/test_cache_engine.py +++ b/tests/test_cache_engine.py @@ -176,7 +176,8 @@ def test_take_and_recycle(cache_engine: CacheEngine): cache_engine.lock_node(radixnode) with pytest.raises(NotEnoughSpaceError): cache_engine.take(num_total_blocks, protected_node=radixnode, strict=True) - cache_engine.cleanup(radixnode, radixnode.size()) + cache_engine.unlock(radixnode) + cache_engine.set_ready(radixnode, True, radixnode.size()) physical_blocks = cache_engine.take(num_total_blocks, protected_node=None, strict=True) assert physical_blocks.shape == (num_total_blocks, ) @@ -227,11 +228,14 @@ def test_cleanup(cache_engine: CacheEngine): assert cache_engine.index.total_unready_blocks() == total_insert_blocks assert cache_engine.index.total_ready_blocks() == 0 - cache_engine.cleanup(radixnode2, radixnode2_size) + cache_engine.unlock(radixnode2) + cache_engine.set_ready(radixnode2, True, radixnode2_size) assert cache_engine.index.total_ready_blocks() == num_insert_blocks2 - cache_engine.cleanup(radixnode1, radixnode1_size) + cache_engine.unlock(radixnode1) + cache_engine.set_ready(radixnode1, True, radixnode1_size) assert cache_engine.index.total_ready_blocks() == num_insert_blocks1 + num_insert_blocks2 - cache_engine.cleanup(radixnode0, radixnode0_size) + cache_engine.unlock(radixnode0) + cache_engine.set_ready(radixnode0, True, radixnode0_size) assert 
cache_engine.index.total_ready_blocks() == num_insert_blocks0 + num_insert_blocks1 + num_insert_blocks2 diff --git a/tests/test_cache_engine_accel.py b/tests/test_cache_engine_accel.py index 15fef43ec3..3c0e2b3cbe 100644 --- a/tests/test_cache_engine_accel.py +++ b/tests/test_cache_engine_accel.py @@ -172,7 +172,8 @@ def test_take_and_recycle(cache_engine: CacheEngineAccel): cache_engine.lock_node(radixnode) with pytest.raises(NotEnoughSpaceError): cache_engine.take(num_total_blocks, protected_node=radixnode, strict=True) - cache_engine.cleanup(radixnode, radixnode.size()) + cache_engine.unlock(radixnode) + cache_engine.set_ready(radixnode, True, radixnode.size()) physical_blocks = cache_engine.take(num_total_blocks, protected_node=None, strict=True) assert physical_blocks.shape == (num_total_blocks, ) @@ -222,11 +223,14 @@ def test_cleanup(cache_engine: CacheEngineAccel): assert cache_engine.index.total_unready_blocks() == total_insert_blocks assert cache_engine.index.total_ready_blocks() == 0 - cache_engine.cleanup(radixnode2, radixnode2_size) + cache_engine.unlock(radixnode2) + cache_engine.set_ready(radixnode2, True, radixnode2_size) assert cache_engine.index.total_ready_blocks() == num_insert_blocks2 - cache_engine.cleanup(radixnode1, radixnode1_size) + cache_engine.unlock(radixnode1) + cache_engine.set_ready(radixnode1, True, radixnode1_size) assert cache_engine.index.total_ready_blocks() == num_insert_blocks1 + num_insert_blocks2 - cache_engine.cleanup(radixnode0, radixnode0_size) + cache_engine.unlock(radixnode0) + cache_engine.set_ready(radixnode0, True, radixnode0_size) assert cache_engine.index.total_ready_blocks() == num_insert_blocks0 + num_insert_blocks1 + num_insert_blocks2