diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index c03ad74c40..1deb723143 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -196,7 +196,7 @@ PYBIND11_MODULE(c_ext, m) { #endif py::class_(m, "CRadixTreeIndex") - .def(py::init()) + .def(py::init()) .def("is_empty", &flexkv::CRadixTreeIndex::is_empty) .def("reset", &flexkv::CRadixTreeIndex::reset) .def("lock", &flexkv::CRadixTreeIndex::lock, py::arg("node")) diff --git a/csrc/radix_tree.cpp b/csrc/radix_tree.cpp index 32dc8ab0e7..9259678d0b 100644 --- a/csrc/radix_tree.cpp +++ b/csrc/radix_tree.cpp @@ -21,7 +21,7 @@ CRadixNode::CRadixNode(CRadixTreeIndex *index, bool ready, int lock_cnt) { struct timeval now; gettimeofday(&now, nullptr); - last_access_time = now.tv_sec * 1000 + now.tv_usec / 10000; + grace_time = now.tv_sec * 1000 + now.tv_usec / 10000; index->inc_node_count(); } @@ -215,7 +215,7 @@ std::shared_ptr CRadixTreeIndex::match_prefix( while (prefix_blocks_num < num_blocks) { if (update_cache_info) { - current_node->update_time(); + current_node->update_time(hit_reward_seconds); } child_hash = HashType(block_hashes_ptr[prefix_blocks_num + current_node->size()]); diff --git a/csrc/radix_tree.h b/csrc/radix_tree.h index 63560a3a8d..cf448a5009 100644 --- a/csrc/radix_tree.h +++ b/csrc/radix_tree.h @@ -16,7 +16,7 @@ class CRadixNode { bool on_leaf; bool ready; int lock_cnt; - time_t last_access_time; + time_t grace_time; std::deque block_hashes; std::deque physical_blocks; @@ -48,18 +48,25 @@ class CRadixNode { } void set_time(time_t time) { - last_access_time = time; + grace_time = time; } time_t get_time() { - return last_access_time; + return grace_time; } - void update_time() { + void update_time(int hit_reward_seconds) { struct timeval now; + time_t now_time; gettimeofday(&now, nullptr); - last_access_time = now.tv_sec * 1000 + now.tv_usec / 10000; + now_time = now.tv_sec * 1000 + now.tv_usec / 10000; + + if (grace_time > now_time) { + grace_time += hit_reward_seconds; + } else { + grace_time = now_time + hit_reward_seconds; + } } CRadixNode *get_parent() { @@ -188,12 +195,14 @@ class CRadixTreeIndex { int max_num_blocks; int tokens_per_block; int node_count; + int hit_reward_seconds; public: - CRadixTreeIndex(int tokens_per_block, int max_num_blocks = 1000000) { + CRadixTreeIndex(int tokens_per_block, int max_num_blocks = 1000000, int hit_reward_seconds = 0) { this->tokens_per_block = tokens_per_block; this->max_num_blocks = max_num_blocks; this->node_count = 0; + this->hit_reward_seconds = hit_reward_seconds; root = new CRadixNode(this, true, 0); node_list.push_back(root); diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index e113aeb69b..5a03f9da74 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -52,7 +52,8 @@ def __init__(self, device_type: DeviceType, num_total_blocks: int, tokens_per_block: int, - evict_ratio: float): + evict_ratio: float, + hit_reward_seconds: int): if not isinstance(device_type, DeviceType): raise InvalidConfigError(f"Unknown device type: {device_type}") if num_total_blocks <= 0: @@ -63,7 +64,7 @@ def __init__(self, self.device_type = device_type - self.index = CRadixTreeIndex(tokens_per_block, num_total_blocks) + self.index = CRadixTreeIndex(tokens_per_block, num_total_blocks, hit_reward_seconds) self.mempool = Mempool(num_total_blocks=num_total_blocks) @@ -145,7 +146,8 @@ def __init__(self, device_type: DeviceType, num_total_blocks: int, tokens_per_block: int, - evict_ratio: float): + evict_ratio: float, + hit_reward_seconds: int): if not isinstance(device_type, DeviceType): raise InvalidConfigError(f"Unknown device type: {device_type}") if num_total_blocks <= 0: @@ -156,7 +158,7 @@ def __init__(self, self.device_type = device_type - self.index = RadixTreeIndex(tokens_per_block=tokens_per_block) + self.index = RadixTreeIndex(tokens_per_block=tokens_per_block, hit_reward_seconds=hit_reward_seconds) self.mempool = Mempool(num_total_blocks=num_total_blocks) @@ -232,36 +234,42 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig): self.cpu_cache_engine = CacheEngineAccel(DeviceType.CPU, cache_config.num_cpu_blocks, cache_config.tokens_per_block, - cache_config.evict_ratio) + cache_config.evict_ratio, + cache_config.hit_reward_seconds) else: self.cpu_cache_engine = CacheEngine(DeviceType.CPU, cache_config.num_cpu_blocks, cache_config.tokens_per_block, - cache_config.evict_ratio) + cache_config.evict_ratio, + cache_config.hit_reward_seconds) self.cache_engines[DeviceType.CPU] = self.cpu_cache_engine if cache_config.enable_ssd: if cache_config.index_accel: self.ssd_cache_engine = CacheEngineAccel(DeviceType.SSD, cache_config.num_ssd_blocks, cache_config.tokens_per_block, - cache_config.evict_ratio) + cache_config.evict_ratio, + cache_config.hit_reward_seconds) else: self.ssd_cache_engine = CacheEngine(DeviceType.SSD, cache_config.num_ssd_blocks, cache_config.tokens_per_block, - cache_config.evict_ratio) + cache_config.evict_ratio, + cache_config.hit_reward_seconds) self.cache_engines[DeviceType.SSD] = self.ssd_cache_engine if cache_config.enable_remote: if cache_config.index_accel: self.remote_cache_engine = CacheEngineAccel(DeviceType.REMOTE, cache_config.num_remote_blocks, cache_config.tokens_per_block, - cache_config.evict_ratio) + cache_config.evict_ratio, + cache_config.hit_reward_seconds) else: self.remote_cache_engine = CacheEngine(DeviceType.REMOTE, cache_config.num_remote_blocks, cache_config.tokens_per_block, - cache_config.evict_ratio) + cache_config.evict_ratio, + cache_config.hit_reward_seconds) self.cache_engines[DeviceType.REMOTE] = self.remote_cache_engine self._empty_get_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, int]] = \ diff --git a/flexkv/cache/radixtree.py b/flexkv/cache/radixtree.py index a9f69c6a35..c2dec039d0 100644 --- a/flexkv/cache/radixtree.py +++ b/flexkv/cache/radixtree.py @@ -48,7 +48,7 @@ class RadixNode: is_ready: bool lock_cnt: int - last_access_time: float + grace_time: float parent: Optional['RadixNode'] = None children: Dict[Optional[HashType], 'RadixNode'] = field(default_factory=dict) @@ -59,7 +59,7 @@ def __post_init__(self) -> None: assert self.block_hashes.size == self.physical_blocks.size def __lt__(self, other: 'RadixNode') -> bool: - return self.last_access_time < other.last_access_time + return self.grace_time < other.grace_time def size(self) -> int: return self.block_hashes.size @@ -93,7 +93,7 @@ def split(self, prefix_length: int) -> 'RadixNode': physical_blocks=self.physical_blocks[:prefix_length], is_ready=self.is_ready, lock_cnt=0, # Note: only lock near-leaf node - last_access_time=self.last_access_time, + grace_time=self.grace_time, ) self.block_hashes = self.block_hashes[prefix_length:] self.physical_blocks = self.physical_blocks[prefix_length:] @@ -120,16 +120,16 @@ def merge_child(self) -> None: # ignore status child = list(self.children.values())[0] self.block_hashes = np.concatenate([self.block_hashes, child.block_hashes]) self.physical_blocks = np.concatenate([self.physical_blocks, child.physical_blocks]) - self.last_access_time = max(self.last_access_time, child.last_access_time) + self.grace_time = max(self.grace_time, child.grace_time) self.children.clear() class RadixTreeIndex: - def __init__(self, tokens_per_block: int, max_num_blocks: int = 1000000): + def __init__(self, tokens_per_block: int, max_num_blocks: int = 1000000, hit_reward_seconds: int = 0): self.root_node: RadixNode = RadixNode(block_hashes=np.array([], dtype=np.int64), physical_blocks=np.array([], dtype=np.int64), is_ready=True, lock_cnt=0, - last_access_time=time.time()) + grace_time=time.time()) self.tokens_per_block = tokens_per_block @@ -137,12 +137,14 @@ def __init__(self, tokens_per_block: int, max_num_blocks: int = 1000000): self.max_num_blocks = max_num_blocks + self.hit_reward_seconds = hit_reward_seconds + def reset(self) -> None: self.root_node = RadixNode(block_hashes=np.array([], dtype=np.int64), physical_blocks=np.array([], dtype=np.int64), is_ready=True, lock_cnt=0, - last_access_time=time.time()) + grace_time=time.time()) self.leaf_nodes.clear() def is_empty(self) -> bool: @@ -160,7 +162,10 @@ def match_prefix(self, physical_blocks = np.array([], dtype=np.int64) while prefix_blocks_num < sequence.num_blocks: if update_cache_info: - current_node.last_access_time = time.time() + if current_node.grace_time < time.time(): + current_node.grace_time = time.time() + hit_reward_seconds + else: + current_node.grace_time += hit_reward_seconds child_hash = sequence.get_hash(prefix_blocks_num + current_node.size()) if child_hash in current_node.children: if current_node.is_ready: @@ -238,7 +243,7 @@ def insert(self, physical_blocks=physical_block_ids, is_ready=is_ready, lock_cnt=0, - last_access_time=time.time() + grace_time=time.time() ) last_node_leaf = last_node.is_leaf() and not last_node.is_root() diff --git a/flexkv/common/config.py b/flexkv/common/config.py index 2805ae606c..99b3be1347 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -33,6 +33,7 @@ class CacheConfig: enable_remote: bool = False use_gds: bool = False index_accel: bool = False + hit_reward_seconds: int = 0 # kv cache layout configs gpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE