Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 104 additions & 46 deletions flexkv/cache/cache_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ def insert(self,
def lock_node(self, node: CRadixNode) -> None:
    """Lock ``node`` by delegating to ``self.index.lock``."""
    index = self.index
    index.lock(node)

def cleanup(self, node: CRadixNode, cleanup_length: int) -> None:
def unlock(self, node: CRadixNode) -> None:
    """Unlock ``node`` by delegating to ``self.index.unlock``.

    Only the unlock is performed here; the ready flag is handled
    separately through ``set_ready``.
    """
    index = self.index
    index.unlock(node)
self.index.set_ready(node, True, cleanup_length)

def set_ready(self, node: CRadixNode, ready: bool, ready_length: int) -> None:
    """Forward a ready-state update for ``node`` to ``self.index.set_ready``.

    All three arguments are passed through positionally and unchanged.
    """
    index = self.index
    index.set_ready(node, ready, ready_length)

def take(self,
num_required_blocks: int,
Expand All @@ -121,7 +123,10 @@ def take(self,
if num_required_blocks > self.mempool.num_free_blocks:
if protected_node is not None:
self.index.lock(protected_node)
evict_block_num = max(num_required_blocks - self.mempool.num_free_blocks, int(self.mempool.num_total_blocks * self.evict_ratio))
evict_block_num = max(
num_required_blocks - self.mempool.num_free_blocks,
int(self.mempool.num_total_blocks * self.evict_ratio)
)
target_blocks = torch.zeros(evict_block_num, dtype=torch.int64)
num_evicted = self.index.evict(target_blocks, evict_block_num)
if num_evicted != evict_block_num:
Expand Down Expand Up @@ -188,9 +193,11 @@ def insert(self,
def lock_node(self, node: RadixNode) -> None:
    """Lock ``node`` by delegating to ``self.index.lock``."""
    index = self.index
    index.lock(node)

def cleanup(self, node: RadixNode, cleanup_length: int) -> None:
def unlock(self, node: RadixNode) -> None:
    """Unlock ``node`` by delegating to ``self.index.unlock``.

    Only the unlock is performed here; the ready flag is handled
    separately through ``set_ready``.
    """
    index = self.index
    index.unlock(node)
self.index.set_ready(node, True, cleanup_length)

def set_ready(self, node: RadixNode, ready: bool, ready_length: int) -> None:
    """Forward a ready-state update for ``node`` to ``self.index.set_ready``.

    All three arguments are passed through positionally and unchanged.
    """
    index = self.index
    index.set_ready(node, ready, ready_length)

def take(self,
num_required_blocks: int,
Expand All @@ -199,7 +206,10 @@ def take(self,
if num_required_blocks > self.mempool.num_free_blocks:
if protected_node is not None:
self.index.lock(protected_node)
evict_block_num = max(num_required_blocks - self.mempool.num_free_blocks, int(self.mempool.num_total_blocks * self.evict_ratio))
evict_block_num = max(
num_required_blocks - self.mempool.num_free_blocks,
int(self.mempool.num_total_blocks * self.evict_ratio)
)
self.mempool.recycle_blocks(
self.index.evict(evict_block_num)
)
Expand Down Expand Up @@ -264,10 +274,10 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig):
cache_config.evict_ratio)
self.cache_engines[DeviceType.REMOTE] = self.remote_cache_engine

self._empty_get_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, int]] = \
lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, 0)
self._empty_put_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]] = \
lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, 0, 0)
self._empty_get_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int]] = \
lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, {}, 0)
self._empty_put_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int, int]] = \
lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, {}, 0, 0)

def reset(self) -> None:
if self.cpu_cache_engine:
Expand All @@ -284,7 +294,7 @@ def get(self,
slot_mapping: np.ndarray,
layer_num: int = -1,
layer_granularity: int = -1,
dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, int]:
dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, Dict, int]:
self._check_input(token_ids, token_mask, slot_mapping)

if layer_num == -1:
Expand Down Expand Up @@ -312,25 +322,27 @@ def get(self,
tokens_per_block=self.cache_config.tokens_per_block)

if not self.cache_config.enable_remote:
transfer_graph, finished_ops_ids, node_to_unlock, buffer_to_free, num_gpu_blocks_to_transfer = \
(transfer_graph, finished_ops_ids, node_to_unlock,
op_node_to_ready, buffer_to_free, num_gpu_blocks_to_transfer) = \
self._get_impl_local(
request_id,
sequence_meta,
block_start_idx,
block_end_idx,
gpu_block_ids,
layer_num
)
request_id,
sequence_meta,
block_start_idx,
block_end_idx,
gpu_block_ids,
layer_num
)
else:
transfer_graph, finished_ops_ids, node_to_unlock, buffer_to_free, num_gpu_blocks_to_transfer = \
(transfer_graph, finished_ops_ids, node_to_unlock,
op_node_to_ready, buffer_to_free, num_gpu_blocks_to_transfer) = \
self._get_impl_global(
request_id,
sequence_meta,
block_start_idx,
block_end_idx,
gpu_block_ids,
layer_num
)
request_id,
sequence_meta,
block_start_idx,
block_end_idx,
gpu_block_ids,
layer_num
)

transfer_graph, task_end_op_id = add_virtal_op_for_mutiple_finished_ops(
transfer_graph,
Expand All @@ -355,15 +367,21 @@ def get(self,
node_to_unlock=node_to_unlock,
buffer_to_free=buffer_to_free)

return transfer_graph, return_mask, callback, task_end_op_id
op_callback_dict = {} # dict, op_id -> callback
for op_id in op_node_to_ready:
op_callback_dict[op_id] = partial(self._op_callback,
device_type=op_node_to_ready[op_id][0],
node_to_ready=op_node_to_ready[op_id][1],
ready_length=op_node_to_ready[op_id][2])
return transfer_graph, return_mask, callback, op_callback_dict, task_end_op_id

def _get_impl_global(self,
request_id: int,
sequence_meta: SequenceMeta,
block_mask_start: int,
block_mask_end: int,
gpu_block_ids: np.ndarray,
layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int]:
layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int]:
"""
transfer pattern:

Expand Down Expand Up @@ -521,15 +539,18 @@ def _get_impl_global(self,
buffer_to_free = {DeviceType.CPU: cpu_blocks_to_free}

# NOTE: for now in build transfer graph, we assume that cpu works as a cache for ssd
return transfer_graph, finished_ops_ids, node_to_unlock, buffer_to_free, len(fragment123_gpu_blocks)
return (
transfer_graph, finished_ops_ids, node_to_unlock, {}, buffer_to_free,
len(fragment123_gpu_blocks) # op_node_to_ready: {}
)

def _get_impl_local(self,
request_id: int,
sequence_meta: SequenceMeta,
block_mask_start: int,
block_mask_end: int,
gpu_block_ids: np.ndarray,
layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int]:
layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int]:
"""
transfer pattern:

Expand Down Expand Up @@ -566,6 +587,7 @@ def _get_impl_local(self,

transfer_graph = TransferOpGraph()
finished_ops_ids = []
op_node_to_ready = {}

fragment12_gpu_blocks = gpu_block_ids[:fragment12_num_blocks]
fragment2_ssd_blocks = ssd_matched_blocks[-fragment2_num_blocks:]
Expand Down Expand Up @@ -610,6 +632,7 @@ def _get_impl_local(self,
block_mask_start,
is_ready=False,
match_result=cpu_matched_result)
op_node_to_ready[op_disk2h.op_id] = (DeviceType.CPU, cpu_node_to_unlock, cpu_node_to_unlock.size())
else:
cpu_blocks_to_free = fragment2_cpu_blocks
if fragment2_cpu_blocks is not None:
Expand All @@ -636,15 +659,18 @@ def _get_impl_local(self,
node_to_unlock[DeviceType.SSD] = (ssd_node_to_unlock, ssd_node_to_unlock.size())
buffer_to_free = {DeviceType.CPU: cpu_blocks_to_free}

return transfer_graph, finished_ops_ids, node_to_unlock, buffer_to_free, len(fragment12_gpu_blocks)
return (
transfer_graph, finished_ops_ids, node_to_unlock, op_node_to_ready,
buffer_to_free, len(fragment12_gpu_blocks)
)

def put(self,
request_id: int,
token_ids: np.ndarray,
token_mask: np.ndarray,
slot_mapping: np.ndarray,
layer_num : int = -1,
dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, int]:
dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, Dict, int]:
self._check_input(token_ids, token_mask, slot_mapping)

if layer_num == -1:
Expand All @@ -665,7 +691,7 @@ def put(self,
tokens_per_block=self.cache_config.tokens_per_block)

if not self.cache_config.enable_remote:
(transfer_graph, finished_ops_ids, node_to_unlock,
(transfer_graph, finished_ops_ids, node_to_unlock, op_node_to_ready,
buffer_to_free, num_gpu_blocks_to_transfer, skipped_gpu_blocks) = \
self._put_impl_local(
request_id,
Expand All @@ -676,7 +702,7 @@ def put(self,
layer_num
)
else:
(transfer_graph, finished_ops_ids, node_to_unlock,
(transfer_graph, finished_ops_ids, node_to_unlock, op_node_to_ready,
buffer_to_free, num_gpu_blocks_to_transfer, skipped_gpu_blocks) = \
self._put_impl_global(
request_id,
Expand Down Expand Up @@ -704,15 +730,22 @@ def put(self,
node_to_unlock=node_to_unlock,
buffer_to_free=buffer_to_free)

return transfer_graph, return_mask, callback, task_end_op_id
op_callback_dict = {}
for op_id in op_node_to_ready:
op_callback_dict[op_id] = partial(self._op_callback,
device_type=op_node_to_ready[op_id][0],
node_to_ready=op_node_to_ready[op_id][1],
ready_length=op_node_to_ready[op_id][2])

return transfer_graph, return_mask, callback, op_callback_dict, task_end_op_id

def _put_impl_global(self,
request_id: int,
sequence_meta: SequenceMeta,
block_mask_start: int,
block_mask_end: int,
gpu_block_ids: np.ndarray,
layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]:
layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int, int]:
"""
transfer pattern:

Expand Down Expand Up @@ -859,15 +892,18 @@ def _put_impl_global(self,
node_to_unlock[DeviceType.REMOTE] = (remote_node_to_unlock, remote_node_to_unlock.size())

skipped_gpu_blocks = len(cpu_matched_blocks)
return transfer_graph, finished_ops_ids, node_to_unlock, {}, len(fragment12_gpu_blocks), skipped_gpu_blocks
return (
transfer_graph, finished_ops_ids, node_to_unlock, {}, {},
len(fragment12_gpu_blocks), skipped_gpu_blocks # op_node_to_ready: {}
)

def _put_impl_local(self,
request_id: int,
sequence_meta: SequenceMeta,
block_mask_start: int,
block_mask_end: int,
gpu_block_ids: np.ndarray,
layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]:
layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int, int]:
"""
transfer pattern:

Expand Down Expand Up @@ -923,6 +959,7 @@ def _put_impl_local(self,

transfer_graph = TransferOpGraph()
finished_ops_ids = []
op_node_to_ready = {}

op_d2h = TransferOp(
graph_id = transfer_graph.graph_id,
Expand Down Expand Up @@ -954,33 +991,43 @@ def _put_impl_local(self,
fragment12_cpu_blocks,
is_ready=False,
match_result=cpu_matched_result)
op_node_to_ready[op_d2h.op_id] = (DeviceType.CPU, cpu_node_to_unlock, cpu_node_to_unlock.size())
ssd_node_to_unlock = None
if len(fragment2_ssd_blocks) > 0:
ssd_node_to_unlock = self.ssd_cache_engine.insert(sequence_meta,
fragment2_ssd_blocks,
is_ready=False,
match_result=ssd_matched_result)
op_node_to_ready[op_h2disk.op_id] = (DeviceType.SSD, ssd_node_to_unlock, ssd_node_to_unlock.size())
node_to_unlock = {}
if cpu_node_to_unlock is not None:
node_to_unlock[DeviceType.CPU] = (cpu_node_to_unlock, cpu_node_to_unlock.size())
if ssd_node_to_unlock is not None:
node_to_unlock[DeviceType.SSD] = (ssd_node_to_unlock, ssd_node_to_unlock.size())

skipped_gpu_blocks = len(cpu_matched_blocks)
return transfer_graph, finished_ops_ids, node_to_unlock, {}, len(fragment12_gpu_blocks), skipped_gpu_blocks
return (
transfer_graph, finished_ops_ids, node_to_unlock, op_node_to_ready, {},
len(fragment12_gpu_blocks), skipped_gpu_blocks
)

def _transfer_callback(self,
node_to_unlock: Dict[DeviceType, Tuple[RadixNode, int]],
buffer_to_free: Optional[Dict[DeviceType, np.ndarray]] = None) -> None:
if DeviceType.CPU in node_to_unlock:
assert self.cpu_cache_engine is not None
self.cpu_cache_engine.cleanup(node_to_unlock[DeviceType.CPU][0], node_to_unlock[DeviceType.CPU][1])
self.cpu_cache_engine.unlock(node_to_unlock[DeviceType.CPU][0])
self.cpu_cache_engine.set_ready(node_to_unlock[DeviceType.CPU][0], True, node_to_unlock[DeviceType.CPU][1])
if DeviceType.SSD in node_to_unlock:
assert self.ssd_cache_engine is not None
self.ssd_cache_engine.cleanup(node_to_unlock[DeviceType.SSD][0], node_to_unlock[DeviceType.SSD][1])
self.ssd_cache_engine.unlock(node_to_unlock[DeviceType.SSD][0])
self.ssd_cache_engine.set_ready(node_to_unlock[DeviceType.SSD][0], True, node_to_unlock[DeviceType.SSD][1])
if DeviceType.REMOTE in node_to_unlock:
assert self.remote_cache_engine is not None
self.remote_cache_engine.cleanup(node_to_unlock[DeviceType.REMOTE][0], node_to_unlock[DeviceType.REMOTE][1])
self.remote_cache_engine.unlock(node_to_unlock[DeviceType.REMOTE][0])
self.remote_cache_engine.set_ready(
node_to_unlock[DeviceType.REMOTE][0], True, node_to_unlock[DeviceType.REMOTE][1]
)
if buffer_to_free is not None:
if DeviceType.CPU in buffer_to_free:
assert self.cpu_cache_engine is not None
Expand All @@ -992,6 +1039,17 @@ def _transfer_callback(self,
assert self.remote_cache_engine is not None
self.remote_cache_engine.recycle(buffer_to_free[DeviceType.REMOTE])

def _op_callback(self, device_type: DeviceType, node_to_ready: RadixNode, ready_length: int) -> None:
    """Mark ``node_to_ready`` as ready on the cache engine for ``device_type``.

    Args:
        device_type: Selects the CPU, SSD or REMOTE cache engine.
        node_to_ready: Node passed to that engine's ``set_ready``.
        ready_length: Length passed to that engine's ``set_ready``.

    A ``device_type`` other than CPU, SSD or REMOTE is ignored.
    """
    # Pick the engine first so the assert and the call are written once.
    if device_type == DeviceType.CPU:
        engine = self.cpu_cache_engine
    elif device_type == DeviceType.SSD:
        engine = self.ssd_cache_engine
    elif device_type == DeviceType.REMOTE:
        engine = self.remote_cache_engine
    else:
        return

    assert engine is not None
    engine.set_ready(node_to_ready, True, ready_length)

@nvtx.annotate("Match Prefix Accel", color="yellow")
def match_local_accel(self, sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel]:
cpu_matched_result = MatchResultAccel()
Expand All @@ -1002,7 +1060,7 @@ def match_local_accel(self, sequence_meta: SequenceMeta) -> Tuple[MatchResultAcc
ssd_matched_result = self.ssd_cache_engine.match(sequence_meta)

return cpu_matched_result, ssd_matched_result

@nvtx.annotate("Match Prefix", color="yellow")
def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResult]:
cpu_matched_result = MatchResult()
Expand All @@ -1013,7 +1071,7 @@ def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchRe
ssd_matched_result = self.ssd_cache_engine.match(sequence_meta)

return cpu_matched_result, ssd_matched_result

@nvtx.annotate("Match All Prefix accel", color="yellow")
def match_all_accel(self,
sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel, MatchResultAccel]:
Expand All @@ -1029,7 +1087,7 @@ def match_all_accel(self,
remote_matched_result = self.remote_cache_engine.match(sequence_meta)

return cpu_matched_result, ssd_matched_result, remote_matched_result

@nvtx.annotate("Match All Prefix", color="yellow")
def match_all(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResult, MatchResult]:
cpu_matched_result = MatchResult()
Expand Down
Loading