Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ PYBIND11_MODULE(c_ext, m) {
#endif

py::class_<flexkv::CRadixTreeIndex>(m, "CRadixTreeIndex")
.def(py::init<int, int>())
.def(py::init<int, int, int>())
.def("is_empty", &flexkv::CRadixTreeIndex::is_empty)
.def("reset", &flexkv::CRadixTreeIndex::reset)
.def("lock", &flexkv::CRadixTreeIndex::lock, py::arg("node"))
Expand Down
4 changes: 2 additions & 2 deletions csrc/radix_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ CRadixNode::CRadixNode(CRadixTreeIndex *index, bool ready, int lock_cnt) {

// Stamp the new node with the current wall-clock time on a millisecond
// scale. tv_usec is in microseconds, so dividing by 1000 (not 10000) gives
// the millisecond component that matches tv_sec * 1000.
struct timeval now;
gettimeofday(&now, nullptr);
grace_time = now.tv_sec * 1000 + now.tv_usec / 1000;

index->inc_node_count();
}
Expand Down Expand Up @@ -215,7 +215,7 @@ std::shared_ptr<CMatchResult> CRadixTreeIndex::match_prefix(

while (prefix_blocks_num < num_blocks) {
if (update_cache_info) {
current_node->update_time();
current_node->update_time(hit_reward_seconds);
}

child_hash = HashType(block_hashes_ptr[prefix_blocks_num + current_node->size()]);
Expand Down
21 changes: 15 additions & 6 deletions csrc/radix_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class CRadixNode {
bool on_leaf;
bool ready;
int lock_cnt;
time_t last_access_time;
time_t grace_time;

std::deque<int64_t> block_hashes;
std::deque<int64_t> physical_blocks;
Expand Down Expand Up @@ -48,18 +48,25 @@ class CRadixNode {
}

// Overwrite this node's grace timestamp with the caller-supplied value.
// Only grace_time is written; the former last_access_time member was
// renamed to grace_time and must not be assigned here.
void set_time(time_t time) {
    grace_time = time;
}

// Return this node's grace timestamp (the value maintained by set_time and
// update_time).
time_t get_time() {
    return grace_time;
}

void update_time() {
void update_time(int hit_reward_seconds) {
struct timeval now;
time_t now_time;

gettimeofday(&now, nullptr);
last_access_time = now.tv_sec * 1000 + now.tv_usec / 10000;
now_time = now.tv_sec * 1000 + now.tv_usec / 10000;

if (grace_time > now_time) {
grace_time += hit_reward_seconds;
} else {
grace_time = now_time + hit_reward_seconds;
}
}

CRadixNode *get_parent() {
Expand Down Expand Up @@ -188,12 +195,14 @@ class CRadixTreeIndex {
int max_num_blocks;
int tokens_per_block;
int node_count;
int hit_reward_seconds;

public:
CRadixTreeIndex(int tokens_per_block, int max_num_blocks = 1000000) {
CRadixTreeIndex(int tokens_per_block, int max_num_blocks = 1000000, int hit_reward_seconds = 0) {
this->tokens_per_block = tokens_per_block;
this->max_num_blocks = max_num_blocks;
this->node_count = 0;
this->hit_reward_seconds = hit_reward_seconds;

root = new CRadixNode(this, true, 0);
node_list.push_back(root);
Expand Down
28 changes: 18 additions & 10 deletions flexkv/cache/cache_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def __init__(self,
device_type: DeviceType,
num_total_blocks: int,
tokens_per_block: int,
evict_ratio: float):
evict_ratio: float,
hit_reward_seconds: int):
if not isinstance(device_type, DeviceType):
raise InvalidConfigError(f"Unknown device type: {device_type}")
if num_total_blocks <= 0:
Expand All @@ -63,7 +64,7 @@ def __init__(self,

self.device_type = device_type

self.index = CRadixTreeIndex(tokens_per_block, num_total_blocks)
self.index = CRadixTreeIndex(tokens_per_block, num_total_blocks, hit_reward_seconds)

self.mempool = Mempool(num_total_blocks=num_total_blocks)

Expand Down Expand Up @@ -145,7 +146,8 @@ def __init__(self,
device_type: DeviceType,
num_total_blocks: int,
tokens_per_block: int,
evict_ratio: float):
evict_ratio: float,
hit_reward_seconds: int):
if not isinstance(device_type, DeviceType):
raise InvalidConfigError(f"Unknown device type: {device_type}")
if num_total_blocks <= 0:
Expand All @@ -156,7 +158,7 @@ def __init__(self,

self.device_type = device_type

self.index = RadixTreeIndex(tokens_per_block=tokens_per_block)
self.index = RadixTreeIndex(tokens_per_block=tokens_per_block, hit_reward_seconds=hit_reward_seconds)

self.mempool = Mempool(num_total_blocks=num_total_blocks)

Expand Down Expand Up @@ -232,36 +234,42 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig):
self.cpu_cache_engine = CacheEngineAccel(DeviceType.CPU,
cache_config.num_cpu_blocks,
cache_config.tokens_per_block,
cache_config.evict_ratio)
cache_config.evict_ratio,
cache_config.hit_reward_seconds)
else:
self.cpu_cache_engine = CacheEngine(DeviceType.CPU,
cache_config.num_cpu_blocks,
cache_config.tokens_per_block,
cache_config.evict_ratio)
cache_config.evict_ratio,
cache_config.hit_reward_seconds)
self.cache_engines[DeviceType.CPU] = self.cpu_cache_engine
if cache_config.enable_ssd:
if cache_config.index_accel:
self.ssd_cache_engine = CacheEngineAccel(DeviceType.SSD,
cache_config.num_ssd_blocks,
cache_config.tokens_per_block,
cache_config.evict_ratio)
cache_config.evict_ratio,
cache_config.hit_reward_seconds)
else:
self.ssd_cache_engine = CacheEngine(DeviceType.SSD,
cache_config.num_ssd_blocks,
cache_config.tokens_per_block,
cache_config.evict_ratio)
cache_config.evict_ratio,
cache_config.hit_reward_seconds)
self.cache_engines[DeviceType.SSD] = self.ssd_cache_engine
if cache_config.enable_remote:
if cache_config.index_accel:
self.remote_cache_engine = CacheEngineAccel(DeviceType.REMOTE,
cache_config.num_remote_blocks,
cache_config.tokens_per_block,
cache_config.evict_ratio)
cache_config.evict_ratio,
cache_config.hit_reward_seconds)
else:
self.remote_cache_engine = CacheEngine(DeviceType.REMOTE,
cache_config.num_remote_blocks,
cache_config.tokens_per_block,
cache_config.evict_ratio)
cache_config.evict_ratio,
cache_config.hit_reward_seconds)
self.cache_engines[DeviceType.REMOTE] = self.remote_cache_engine

self._empty_get_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, int]] = \
Expand Down
23 changes: 14 additions & 9 deletions flexkv/cache/radixtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class RadixNode:

is_ready: bool
lock_cnt: int
last_access_time: float
grace_time: float

parent: Optional['RadixNode'] = None
children: Dict[Optional[HashType], 'RadixNode'] = field(default_factory=dict)
Expand All @@ -59,7 +59,7 @@ def __post_init__(self) -> None:
assert self.block_hashes.size == self.physical_blocks.size

def __lt__(self, other: 'RadixNode') -> bool:
return self.last_access_time < other.last_access_time
return self.grace_time < other.grace_time

def size(self) -> int:
    """Return the ``size`` of this node's ``block_hashes`` array."""
    hashes = self.block_hashes
    return hashes.size
Expand Down Expand Up @@ -93,7 +93,7 @@ def split(self, prefix_length: int) -> 'RadixNode':
physical_blocks=self.physical_blocks[:prefix_length],
is_ready=self.is_ready,
lock_cnt=0, # Note: only lock near-leaf node
last_access_time=self.last_access_time,
grace_time=self.grace_time,
)
self.block_hashes = self.block_hashes[prefix_length:]
self.physical_blocks = self.physical_blocks[prefix_length:]
Expand All @@ -120,29 +120,31 @@ def merge_child(self) -> None: # ignore status
child = list(self.children.values())[0]
self.block_hashes = np.concatenate([self.block_hashes, child.block_hashes])
self.physical_blocks = np.concatenate([self.physical_blocks, child.physical_blocks])
self.last_access_time = max(self.last_access_time, child.last_access_time)
self.grace_time = max(self.grace_time, child.grace_time)
self.children.clear()

class RadixTreeIndex:
def __init__(self, tokens_per_block: int, max_num_blocks: int = 1000000, hit_reward_seconds: int = 0):
    """Create an empty radix tree index.

    Args:
        tokens_per_block: stored on the index as ``tokens_per_block``.
        max_num_blocks: stored on the index as ``max_num_blocks``.
        hit_reward_seconds: stored on the index as ``hit_reward_seconds``;
            ``match_prefix`` adds it to a visited node's ``grace_time``
            when cache-info updates are enabled.
    """
    # The root starts as an empty, ready, unlocked node stamped with the current time.
    self.root_node: RadixNode = RadixNode(block_hashes=np.array([], dtype=np.int64),
                                          physical_blocks=np.array([], dtype=np.int64),
                                          is_ready=True,
                                          lock_cnt=0,
                                          grace_time=time.time())

    self.tokens_per_block = tokens_per_block

    self.leaf_nodes: Dict[HashType, RadixNode] = {}

    self.max_num_blocks = max_num_blocks

    self.hit_reward_seconds = hit_reward_seconds

def reset(self) -> None:
    """Drop the current tree: install a fresh empty root node and clear ``leaf_nodes``."""
    # Same shape as the root built in __init__: empty, ready, unlocked, stamped with now.
    self.root_node = RadixNode(block_hashes=np.array([], dtype=np.int64),
                               physical_blocks=np.array([], dtype=np.int64),
                               is_ready=True,
                               lock_cnt=0,
                               grace_time=time.time())
    self.leaf_nodes.clear()

def is_empty(self) -> bool:
Expand All @@ -160,7 +162,10 @@ def match_prefix(self,
physical_blocks = np.array([], dtype=np.int64)
while prefix_blocks_num < sequence.num_blocks:
if update_cache_info:
    # Reward the visited node. The reward lives on the index instance, so it
    # must be read through self; a bare `hit_reward_seconds` is not defined
    # in this scope and raises NameError.
    now = time.time()  # read the clock once so the test and the update agree
    if current_node.grace_time < now:
        # Grace period already lapsed: measure the new one from now.
        current_node.grace_time = now + self.hit_reward_seconds
    else:
        # Still within the grace period: stack the reward on top of it.
        current_node.grace_time += self.hit_reward_seconds
child_hash = sequence.get_hash(prefix_blocks_num + current_node.size())
if child_hash in current_node.children:
if current_node.is_ready:
Expand Down Expand Up @@ -238,7 +243,7 @@ def insert(self,
physical_blocks=physical_block_ids,
is_ready=is_ready,
lock_cnt=0,
last_access_time=time.time()
grace_time=time.time()
)

last_node_leaf = last_node.is_leaf() and not last_node.is_root()
Expand Down
1 change: 1 addition & 0 deletions flexkv/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class CacheConfig:
enable_remote: bool = False
use_gds: bool = False
index_accel: bool = False
hit_reward_seconds: int = 0

# kv cache layout configs
gpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE
Expand Down
Loading