diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000..0fe668086a --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,30 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [1.0.0] - 2025-09-11 + +### Added +- C++ radix tree for fast matching; set `"index_accel": true` in `cache_config` to enable it +- Synchronous kernel launch +- A major change that moves the cache engine into a library used directly by accelerators (e.g., vLLM), replacing the server-client mode. + This accelerates `get` and `put` when no KVCache is matched. This version includes breaking API changes and is not backward compatible. +- `evict_ratio` option; set `"evict_ratio": 0.05` in `cache_config` to configure it +- Reduced the bubble inside kernel launch +- vLLM 0.10.1.1 adapter + +### Fixed +- Cython release package + + +## [0.1.0] - 2025-08-29 + +### Added +- Initial version +- License + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..a395746aa0 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,13 @@ +# Contributing to FlexKV + +Thank you for your interest in contributing to FlexKV! + +## PR Title and Classification +Use a prefixed PR title to indicate the type of change. Please use one of the following: + +- `[bugfix]` for bug fixes +- `[feature]` for new features +- `[test]` for test cases +- `[ci/build]` for build or continuous integration improvements +- `[doc]` for documentation fixes +- `[misc]` for PRs that do not fit the above categories. Please use this sparingly. \ No newline at end of file diff --git a/README.md b/README.md index 56875811e3..de5bbc5acf 100644 --- a/README.md +++ b/README.md @@ -8,29 +8,27 @@ FlexKV is released under the **Apache-2.0 License**. 
See the [LICENSE](LICENSE) ## How to Use -### Build FlexKV ```bash -./build.sh +apt install liburing-dev +apt install libxxhash-dev ``` -### Use FlexKV with vLLM (v0.8.4) - -Apply the patch `examples/vllm_adaption/flexkv_vllm_0_8_4.patch` to vLLM 0.8.4, then start FlexKV, vLLM, and the benchmark script: +### Build FlexKV ```bash -# Start FlexKV as server -bash benchmarks/flexkv_benchmark/run_flexkv_server.sh +./build.sh +# ./build.sh --release builds the Cython release package +``` -# Start vLLM as client -bash benchmarks/flexkv_benchmark/serving_vllm.sh +### Use FlexKV with vLLM -# Start benchmark -bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh -``` -Apply the patch `examples/vllm_adaption/flexkv_vllm_0_10_0.patch` to vLLM 0.10.0, and use the same testing method as above. +See [docs/vllm_adapter/README_en.md](docs/vllm_adapter/README_en.md) + +### FlexKV Integration with Dynamo -> **Note**: The current script is only compatible with the `main` branch. Support for the latest features in the `dev` branch is under development. +See [docs/dynamo_integration/README_en.md](docs/dynamo_integration/README_en.md) ## Design Architecture @@ -88,6 +86,7 @@ FlexKV performs: - The main branch is the stable branch, which maintains already tested commits. Please pull from main branch if you need stable code. - The dev branch is the development branch, which contains newer features. Please branch from and merge into dev if you need new features or are developing new functionality. - The bugfix branch is for bug fixes, maintaining urgent bugs that need immediate resolution or documentation that requires prompt updates. If you need to fix a bug or update documentation urgently, please branch from and merge into the bugfix branch. +- The stable branch refers to the previous main branch state, intended only for rollback or extremely conservative use cases (e.g., production deployment). Its use is discouraged. 
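As a quick illustration of the branch policy described in the README hunk above, the flow below simulates it in a throwaway repository. Illustrative only: branch names such as `feature/my-change` are assumptions, and in the real repository you would branch from `origin/dev` rather than a freshly created local `dev`.

```shell
# Set up a disposable repository with the three long-lived branches.
repo="$(mktemp -d)"
cd "$repo"
git init -q
git -c user.name=demo -c user.email=demo@example.com commit -q --allow-empty -m "init"
git branch dev
git branch bugfix

# New feature: branch from dev, later merge back into dev via a PR.
git checkout -q dev
git checkout -q -b feature/my-change

# Urgent bug or doc fix: branch from and merge into bugfix.
git checkout -q bugfix
git checkout -q -b fix/urgent-doc
```

The PR title for each branch would then carry the matching prefix from CONTRIBUTING.md, e.g. `[feature]` or `[bugfix]`.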
## Roadmap diff --git a/README_zh.md b/README_zh.md index 8223a5d9c0..1654ff17ae 100644 --- a/README_zh.md +++ b/README_zh.md @@ -8,29 +8,27 @@ FlexKV 采用 **Apache-2.0 开源协议**,详细信息请参见 [LICENSE](LICE ## 如何使用 +### 安装依赖 + +```bash +apt install liburing-dev +apt install libxxhash-dev +``` + ### 编译 FlexKV ```bash ./build.sh +# ./build.sh --release builds the Cython release package ``` ### 以 vLLM 为例使用 FlexKV -在 vLLM 0.8.4 版本中应用patch `examples/vllm_adaption/flexkv_vllm_0_8_4.patch`,分别启动 FlexKV、vLLM 和测试脚本: +见[docs/vllm_adapter/README_zh.md](docs/vllm_adapter/README_zh.md) -```bash -# 启动 FlexKV 作为服务端 -bash benchmarks/flexkv_benchmark/run_flexkv_server.sh - -# 启动 vLLM 作为客户端 -bash benchmarks/flexkv_benchmark/serving_vllm.sh - -# 启动性能测试 -bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh -``` -在 vLLM 0.10.0 版本中应用patch `examples/vllm_adaption/flexkv_vllm_0_10_0.patch`,测试方法同上。 +### FlexKV和Dynamo框架的集成 -> **注意**:当前脚本仅适配 `main` 分支。`dev` 分支的最新特性支持脚本正在开发中。 +见[docs/dynamo_integration/README_zh.md](docs/dynamo_integration/README_zh.md) ## 设计框架 @@ -88,6 +86,7 @@ FlexKV 在处理 *get* 请求时: - main 为稳定分支,维护已经测试过的commit。需要稳定的代码请从此分支拉取。 - dev 为开发分支,维护较新特性。需要新特性和开发新特性请从此分支拉取和合入。 - bugfix 为bug分支,维护需要立即解决的bug或需要立即更新的文档。需要解决bug和立即更新的文档请从此分支拉取和合入。 +- stable 为上一个版本的main分支位置,仅用于回滚以及极其保守的情况使用(如产品化)。不鼓励使用此版本。 ## Roadmap diff --git a/VERSION b/VERSION new file mode 100644 index 0000000000..3eefcb9dd5 --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +1.0.0 diff --git a/benchmarks/example_config.json b/benchmarks/example_config.json index 4a710f41ca..d4854557c3 100644 --- a/benchmarks/example_config.json +++ b/benchmarks/example_config.json @@ -14,7 +14,6 @@ "enable_remote": false, "tokens_per_block": 16, "use_gds": false, - "use_pinned_memory": true, "gpu_kv_layout_type": "LAYERWISE", "cpu_kv_layout_type": "BLOCKWISE", "ssd_kv_layout_type": "BLOCKWISE", diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 9ffa6489f7..32fe238e45 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -29,6 +29,7 @@ namespace py = 
pybind11; +#ifdef CUDA_AVAILABLE void transfer_kv_blocks_binding( torch::Tensor &gpu_block_id_tensor, torch::Tensor &gpu_layer_ptrs_tensor, int64_t gpu_kv_stride_in_bytes, int64_t gpu_block_stride_in_bytes, @@ -60,7 +61,9 @@ void transfer_kv_blocks_binding( throw std::runtime_error(cudaGetErrorString(err)); } } +#endif +#ifdef CUDA_AVAILABLE void transfer_kv_blocks_ssd_binding( flexkv::SSDIOCTX &ioctx, const torch::Tensor &cpu_layer_id_list, int64_t cpu_tensor_ptr, @@ -82,6 +85,7 @@ void transfer_kv_blocks_ssd_binding( block_stride_in_bytes, is_read, num_blocks_per_file, round_robin, num_threads_per_device, is_mla); } +#endif #ifdef FLEXKV_ENABLE_CFS void transfer_kv_blocks_remote( const py::list &file_nodeid_list, const torch::Tensor &cpu_layer_id_list, @@ -162,6 +166,7 @@ void shared_transfer_kv_blocks_remote_read_binding( #endif PYBIND11_MODULE(c_ext, m) { +#ifdef CUDA_AVAILABLE m.def("transfer_kv_blocks", &transfer_kv_blocks_binding, "Transfer multi-layer KV-cache between CPU and GPU"); m.def("transfer_kv_blocks_ssd", &transfer_kv_blocks_ssd_binding, @@ -174,6 +179,7 @@ PYBIND11_MODULE(c_ext, m) { py::arg("block_stride_in_bytes"), py::arg("is_read"), py::arg("num_blocks_per_file"), py::arg("round_robin") = 1, py::arg("num_threads_per_device") = 16, py::arg("is_mla") = false); +#endif #ifdef FLEXKV_ENABLE_CFS m.def("transfer_kv_blocks_remote", &transfer_kv_blocks_remote, "Transfer KV blocks between remote and CPU memory", @@ -249,6 +255,7 @@ PYBIND11_MODULE(c_ext, m) { m.def("call_pcfs_write", &flexkv::call_pcfs_write, "Call Pcfs::write from C++", py::arg("file_nodeid"), py::arg("offset"), py::arg("buffer"), py::arg("size"), py::arg("thread_id")); +#ifdef CUDA_AVAILABLE m.def("shared_transfer_kv_blocks_remote_read", &shared_transfer_kv_blocks_remote_read_binding, "Shared transfer KV blocks from remote PCFS to CPU memory", @@ -266,6 +273,7 @@ PYBIND11_MODULE(c_ext, m) { py::arg("total_layers"), py::arg("is_mla") = false, py::arg("num_threads_per_file") = 8); 
+#endif #endif py::class_(m, "CRadixTreeIndex") @@ -297,7 +305,7 @@ PYBIND11_MODULE(c_ext, m) { .def("has_block_node_ids", &flexkv::CRadixNode::has_block_node_ids); py::class_>(m, "CMatchResult") - .def(py::init *>()) + .def(py::init()) .def_readonly("last_ready_node", &flexkv::CMatchResult::last_ready_node) .def_readonly("last_node", &flexkv::CMatchResult::last_node) .def_readonly("physical_blocks", &flexkv::CMatchResult::physical_blocks) @@ -318,14 +326,14 @@ PYBIND11_MODULE(c_ext, m) { // RedisMetaChannel binding py::class_(m, "RedisMetaChannel") - .def(py::init(), - py::arg("host"), py::arg("port"), py::arg("node_id"), py::arg("local_ip"), py::arg("blocks_key") = std::string("blocks")) + .def(py::init(), + py::arg("host"), py::arg("port"), py::arg("node_id"), py::arg("local_ip"), py::arg("blocks_key") = std::string("blocks"), py::arg("password") = std::string("")) .def("connect", &flexkv::RedisMetaChannel::connect) .def("get_node_id", &flexkv::RedisMetaChannel::get_node_id) .def("get_local_ip", &flexkv::RedisMetaChannel::get_local_ip) .def("make_block_key", &flexkv::RedisMetaChannel::make_block_key, py::arg("node_id"), py::arg("hash")) - .def("publish_one", [](flexkv::RedisMetaChannel &ch, const flexkv::BlockMeta &m){ ch.publish(m); }) - .def("publish_batch", [](flexkv::RedisMetaChannel &ch, const std::vector &metas, size_t batch_size){ ch.publish(metas, batch_size); }, py::arg("metas"), py::arg("batch_size")=100) + .def("publish_one", [](flexkv::RedisMetaChannel &ch, const flexkv::BlockMeta &m){ return ch.publish(m); }) + .def("publish_batch", [](flexkv::RedisMetaChannel &ch, const std::vector &metas, size_t batch_size){ return ch.publish(metas, batch_size); }, py::arg("metas"), py::arg("batch_size")=100) .def("load", [](flexkv::RedisMetaChannel &ch, size_t max_items){ std::vector out; ch.load(out, max_items); return out; }, py::arg("max_items")) .def("renew_node_leases", &flexkv::RedisMetaChannel::renew_node_leases, py::arg("node_id"), py::arg("new_lt"), 
py::arg("batch_size")=200) .def("list_keys", [](flexkv::RedisMetaChannel &ch, const std::string &pattern){ std::vector keys; ch.list_keys(pattern, keys); return keys; }, py::arg("pattern")) @@ -334,19 +342,18 @@ PYBIND11_MODULE(c_ext, m) { .def("hmget_field_for_keys", [](flexkv::RedisMetaChannel &ch, const std::vector &keys, const std::string &field){ std::vector values; ch.hmget_field_for_keys(keys, field, values); return values; }, py::arg("keys"), py::arg("field")) .def("hmget_two_fields_for_keys", [](flexkv::RedisMetaChannel &ch, const std::vector &keys, const std::string &f1, const std::string &f2){ std::vector> out; ch.hmget_two_fields_for_keys(keys, f1, f2, out); return out; }, py::arg("keys"), py::arg("field1"), py::arg("field2")) .def("load_metas_by_keys", [](flexkv::RedisMetaChannel &ch, const std::vector &keys){ std::vector out; ch.load_metas_by_keys(keys, out); return out; }, py::arg("keys")) - .def("update_block_state_batch", [](flexkv::RedisMetaChannel &ch, uint32_t node_id, const std::vector &hashes, flexkv::NodeState state, size_t batch_size){ std::deque dq(hashes.begin(), hashes.end()); ch.update_block_state_batch(node_id, &dq, state, batch_size); }, py::arg("node_id"), py::arg("hashes"), py::arg("state"), py::arg("batch_size")=200) - .def("delete_blockmeta_batch", [](flexkv::RedisMetaChannel &ch, uint32_t node_id, const std::vector &hashes, size_t batch_size){ std::deque dq(hashes.begin(), hashes.end()); ch.delete_blockmeta_batch(node_id, &dq, batch_size); }, py::arg("node_id"), py::arg("hashes"), py::arg("batch_size")=200); + .def("update_block_state_batch", [](flexkv::RedisMetaChannel &ch, uint32_t node_id, const std::vector &hashes, int state, size_t batch_size){ std::deque dq(hashes.begin(), hashes.end()); return ch.update_block_state_batch(node_id, &dq, state, batch_size); }, py::arg("node_id"), py::arg("hashes"), py::arg("state"), py::arg("batch_size")=200) + .def("delete_blockmeta_batch", [](flexkv::RedisMetaChannel &ch, uint32_t node_id, 
const std::vector &hashes, size_t batch_size){ std::deque dq(hashes.begin(), hashes.end()); return ch.delete_blockmeta_batch(node_id, &dq, batch_size); }, py::arg("node_id"), py::arg("hashes"), py::arg("batch_size")=200); // LocalRadixTree bindings (derived from CRadixTreeIndex) py::class_(m, "LocalRadixTree") - .def(py::init(), + .def(py::init(), py::arg("tokens_per_block"), py::arg("max_num_blocks") = 1000000, py::arg("lease_ttl_ms") = 100000, py::arg("renew_lease_ms") = 0, py::arg("refresh_batch_size") = 256, - py::arg("idle_sleep_ms") = 10, - py::arg("lt_pool_initial_capacity") = 0) + py::arg("idle_sleep_ms") = 10) .def("set_meta_channel", &flexkv::LocalRadixTree::set_meta_channel, py::arg("channel")) .def("start", &flexkv::LocalRadixTree::start, py::arg("channel")) .def("stop", &flexkv::LocalRadixTree::stop) @@ -374,14 +381,14 @@ PYBIND11_MODULE(c_ext, m) { // DistributedRadixTree bindings (remote reference tree manager) py::class_(m, "DistributedRadixTree") - .def(py::init(), + .def(py::init(), py::arg("tokens_per_block"), py::arg("max_num_blocks"), py::arg("node_id"), - py::arg("lt_pool_initial_capacity") = 0, py::arg("refresh_batch_size") = 128, py::arg("rebuild_interval_ms") = 1000, - py::arg("idle_sleep_ms") = 10) + py::arg("idle_sleep_ms") = 10, + py::arg("lease_renew_ms") = 5000) .def("start", &flexkv::DistributedRadixTree::start, py::arg("channel")) .def("stop", &flexkv::DistributedRadixTree::stop) .def("remote_tree_refresh", &flexkv::DistributedRadixTree::remote_tree_refresh, py::return_value_policy::reference) @@ -392,4 +399,14 @@ PYBIND11_MODULE(c_ext, m) { .def("unlock", &flexkv::DistributedRadixTree::unlock, py::arg("node")) .def("is_empty", &flexkv::DistributedRadixTree::is_empty) .def("set_ready", &flexkv::DistributedRadixTree::set_ready, py::arg("node"), py::arg("ready") = true, py::arg("ready_length") = -1); + + // RefRadixTree bindings (for type information) + py::class_(m, "RefRadixTree") + .def(py::init*>(), + py::arg("tokens_per_block"), + 
py::arg("max_num_blocks") = 1000000, + py::arg("lease_renew_ms") = 5000, + py::arg("renew_lease_queue") = nullptr) + .def("dec_ref_cnt", &flexkv::RefRadixTree::dec_ref_cnt) + .def("inc_ref_cnt", &flexkv::RefRadixTree::inc_ref_cnt); } diff --git a/csrc/block_meta.h b/csrc/block_meta.h index 2a8a0ae7e9..4bb0f6f928 100644 --- a/csrc/block_meta.h +++ b/csrc/block_meta.h @@ -2,7 +2,7 @@ #include -#include "radix_tree.h" // for NodeState +#include "lease_meta_mempool.h" // for NODE_STATE_* macros namespace flexkv { @@ -12,7 +12,7 @@ struct BlockMeta { uint32_t nid; // node id int64_t hash; // current block hash uint32_t lt; // lease time - NodeState state; // lease state + int state; // lease state }; } // namespace flexkv diff --git a/csrc/distributed_radix_tree.cpp b/csrc/distributed_radix_tree.cpp index a376d50624..41767c4333 100644 --- a/csrc/distributed_radix_tree.cpp +++ b/csrc/distributed_radix_tree.cpp @@ -12,13 +12,14 @@ namespace flexkv { DistributedRadixTree::DistributedRadixTree(int tokens_per_block, int max_num_blocks, - uint32_t nid, size_t lt_pool_initial_capacity, - size_t refresh_batch_size, uint32_t rebuild_interval_ms, uint32_t idle_sleep_ms) + uint32_t nid, + size_t refresh_batch_size, uint32_t rebuild_interval_ms, uint32_t idle_sleep_ms, + uint32_t lease_renew_ms) : channel(nullptr), node_id(nid), tokens_per_block(tokens_per_block), max_num_blocks(max_num_blocks), - lt_pool(lt_pool_initial_capacity) { + lt_pool(max_num_blocks) { refresh_batch_size_ = refresh_batch_size; rebuild_interval_ms_ = rebuild_interval_ms; - idle_sleep_ms_ = idle_sleep_ms; + lease_renew_ms_ = lease_renew_ms; old_index.store(nullptr, std::memory_order_relaxed); c_index.store(nullptr, std::memory_order_relaxed); } @@ -54,12 +55,17 @@ void* DistributedRadixTree::refresh_worker_trampoline(void* arg) { return nullptr; } -void DistributedRadixTree::start(RedisMetaChannel *ch) { - if (refresh_started) return; +bool DistributedRadixTree::start(RedisMetaChannel *ch) { + if 
(refresh_started) return true; channel = ch; refresh_should_stop = false; refresh_started = true; - pthread_create(&refresh_tid, nullptr, &DistributedRadixTree::refresh_worker_trampoline, this); + int result = pthread_create(&refresh_tid, nullptr, &DistributedRadixTree::refresh_worker_trampoline, this); + if (result != 0) { + refresh_started = false; + return false; + } + return true; } void DistributedRadixTree::stop() { @@ -134,14 +140,19 @@ void DistributedRadixTree::refresh_nodes_lease_from_redis(const std::vectorlease_time = min_lt; + lm->lease_time = min_lt == UINT32_MAX ? 0 : min_lt; } } } @@ -175,7 +186,12 @@ RefRadixTree* DistributedRadixTree::remote_tree_refresh() { if (k.size() <= 5) continue; if (ips.size() <= i) continue; // parse nid - uint32_t nid = (uint32_t)std::stoul(k.substr(5)); + uint32_t nid = 0; + try { + nid = (uint32_t)std::stoul(k.substr(5)); + } catch (const std::exception&) { + continue; + } if (nid == node_id) continue; // skip self std::string ip = ips[i]; uint32_t dst = compute_ip_distance(ip, self_ip); @@ -186,7 +202,8 @@ RefRadixTree* DistributedRadixTree::remote_tree_refresh() { std::sort(nodes.begin(), nodes.end(), [](const NodeInfo &a, const NodeInfo &b){ return a.dst < b.dst; }); // 4) iterate nodes and load their block metas - RefRadixTree* new_index = new RefRadixTree(tokens_per_block, max_num_blocks); + RefRadixTree* new_index = new RefRadixTree(tokens_per_block, max_num_blocks, lease_renew_ms_, + &renew_lease_queue, <_pool); for (const auto &nfo : nodes) { // list keys block::* std::vector bkeys; @@ -200,7 +217,10 @@ RefRadixTree* DistributedRadixTree::remote_tree_refresh() { // Merge into new_index via DFS+merge helpers if (!metas.empty()) { std::unordered_map> parent_to_children; - for (const auto& meta : metas) parent_to_children[meta.ph].push_back(&meta); + for (int i = 0; i < metas.size(); ++i) { + if (metas[i].state != NODE_STATE_NORMAL) continue; + parent_to_children[metas[i].ph].push_back(&metas[i]); + } 
std::unordered_set processed_hashes; std::vector temp_root_children; for (const auto& root_child_ptr : parent_to_children[0]) { @@ -222,7 +242,9 @@ std::shared_ptr DistributedRadixTree::match_prefix( torch::Tensor &block_hashes, int num_blocks, bool update_cache_info) { RefRadixTree *idx = c_index.load(std::memory_order_acquire); if (idx == nullptr) { - return std::make_shared(0, 0, 0, nullptr, nullptr, new std::vector()); + auto empty_i64 = torch::empty({0}, torch::dtype(torch::kInt64)); + auto empty_u32 = torch::empty({0}, torch::dtype(torch::kInt32)); + return std::make_shared(0, 0, 0, nullptr, nullptr, empty_i64, empty_u32); } RefCntGuard guard{idx}; return idx->match_prefix(block_hashes, num_blocks, update_cache_info); @@ -274,13 +296,23 @@ void DistributedRadixTree::set_ready(CRadixNode *node, bool ready, int ready_len idx->set_ready(node, ready, ready_length); } -RefRadixTree::RefRadixTree(int tokens_per_block, int max_num_blocks) +RefRadixTree::RefRadixTree(int tokens_per_block, int max_num_blocks, uint32_t lease_renew_ms, + LockFreeQueue *renew_lease_queue, LeaseMetaMemPool* lt_pool) : CRadixTreeIndex(tokens_per_block, max_num_blocks) { + lease_renew_ms_ = lease_renew_ms; + renew_lease_queue_ = renew_lease_queue; + lt_pool_ = lt_pool; ref_cnt.store(1); } RefRadixTree::~RefRadixTree() { - // No special cleanup beyond base class for now + while (node_list.size()) { + auto node = node_list.front(); + auto lm = node->get_lease_meta(); + if (lm != nullptr) { + lt_pool_->free(lm); + } + } } void RefRadixTree::dec_ref_cnt() { @@ -337,9 +369,15 @@ std::shared_ptr RefRadixTree::match_prefix( auto prefix_blocks_num = 0; auto ready_prefix_blocks_num = 0; auto last_node_matched_length = 0; - auto physical_blocks = new std::vector(); + auto physical_blocks_tensor = torch::empty({num_blocks}, torch::dtype(torch::kInt64)); + auto *pb_out = physical_blocks_tensor.data_ptr(); + int64_t pb_write = 0; auto block_hashes_ptr = block_hashes.data_ptr(); HashType child_hash; + // 
node ids stored as int32 tensor (PyTorch lacks uint32 dtype) + auto node_ids_tensor = torch::empty({num_blocks}, torch::dtype(torch::kInt32)); + auto *ni_out = node_ids_tensor.data_ptr(); + int32_t ni_write = 0; // now in ms struct timeval now_tv; gettimeofday(&now_tv, nullptr); @@ -359,8 +397,10 @@ std::shared_ptr RefRadixTree::match_prefix( // expired: stop matching and return what we have so far break; } - if ((int64_t)lt - (int64_t)now_ms <= (int64_t)renew_threshold_ms_) { - renew_lease_queue.push(current_node); + if ((int64_t)lt - (int64_t)now_ms <= (int64_t)lease_renew_ms_) { + if (renew_lease_queue_ != nullptr) { + renew_lease_queue_->push(current_node); + } } } } @@ -372,8 +412,17 @@ std::shared_ptr RefRadixTree::match_prefix( ready_prefix_blocks_num += current_node->size(); } prefix_blocks_num += current_node->size(); - physical_blocks->insert(physical_blocks->end(), current_node->get_physical_blocks().begin(), - current_node->get_physical_blocks().end()); + for (auto v : current_node->get_physical_blocks()) { + pb_out[pb_write++] = v; + } + auto bnis = current_node->get_block_node_ids(); + if (bnis != nullptr) { + for (auto v : *bnis) { + ni_out[ni_write++] = v; + } + } else { + std::cerr << "block_node_ids is nullptr" << std::endl; + } current_node = current_node->get_child(child_hash); } else { auto matched_length = 0; @@ -391,8 +440,18 @@ std::shared_ptr RefRadixTree::match_prefix( } } matched_length = left; - physical_blocks->insert(physical_blocks->end(), current_node->get_physical_blocks().begin(), - current_node->get_physical_blocks().begin() + matched_length); + auto &dq = current_node->get_physical_blocks(); + for (int i = 0; i < matched_length; ++i) { + pb_out[pb_write++] = dq[i]; + } + auto bnis = current_node->get_block_node_ids(); + if (bnis != nullptr) { + for (auto v : *bnis) { + ni_out[ni_write++] = v; + } + } else { + std::cerr << "block_node_ids is nullptr" << std::endl; + } } else { matched_length = 0; } @@ -408,8 +467,10 @@ 
std::shared_ptr RefRadixTree::match_prefix( } } + auto physical_blocks = physical_blocks_tensor.narrow(0, 0, pb_write); + auto node_ids = node_ids_tensor.narrow(0, 0, ni_write); return std::make_shared(prefix_blocks_num, ready_prefix_blocks_num, last_node_matched_length, - last_ready_node, current_node, physical_blocks); + last_ready_node, current_node, physical_blocks, node_ids); } // DFS function to build subtree from BlockMeta with chain compression @@ -442,7 +503,7 @@ CRadixNode* dfs_build_subtree_from_meta(const BlockMeta* current_meta, auto& cbh = child_node->get_block_hashes(); auto& cpb = child_node->get_physical_blocks(); - auto& bni = child_node->get_block_node_ids(); + auto bni = child_node->get_block_node_ids(); if (bni == nullptr) return nullptr; // Seed with the current meta diff --git a/csrc/distributed_radix_tree.h b/csrc/distributed_radix_tree.h index eebb11ddd1..6af5dc3a79 100644 --- a/csrc/distributed_radix_tree.h +++ b/csrc/distributed_radix_tree.h @@ -28,7 +28,8 @@ class RedisMetaChannel; // forward declaration class RefRadixTree : public CRadixTreeIndex { public: - RefRadixTree(int tokens_per_block, int max_num_blocks = 1000000); + RefRadixTree(int tokens_per_block, int max_num_blocks = 1000000, uint32_t lease_renew_ms = 5000, + LockFreeQueue *renew_lease_queue = nullptr, LeaseMetaMemPool* lt_pool = nullptr); ~RefRadixTree(); // Decrement reference count; when it reaches zero, delete this instance void dec_ref_cnt(); @@ -50,7 +51,9 @@ class RefRadixTree : public CRadixTreeIndex { int evict(torch::Tensor &evicted_blocks, int num_evicted) override; std::atomic ref_cnt; - + uint32_t lease_renew_ms_; + LockFreeQueue *renew_lease_queue_; + LeaseMetaMemPool* lt_pool_; }; struct RefCntGuard { @@ -76,7 +79,7 @@ class DistributedRadixTree { size_t refresh_batch_size_ = 128; uint32_t rebuild_interval_ms_ = 1000; uint32_t idle_sleep_ms_ = 10; - uint32_t renew_threshold_ms_ = 5000; // 5 seconds before expiry to renew lease + uint32_t lease_renew_ms_ = 
5000; LockFreeQueue renew_lease_queue; bool refresh_started = false; volatile bool refresh_should_stop = false; @@ -90,16 +93,16 @@ class DistributedRadixTree { public: DistributedRadixTree(int tokens_per_block, int max_num_blocks, uint32_t node_id, - size_t lt_pool_initial_capacity = 0, size_t refresh_batch_size = 128, uint32_t rebuild_interval_ms = 1000, - uint32_t idle_sleep_ms = 10); + uint32_t idle_sleep_ms = 10, + uint32_t lease_renew_ms = 5000); ~DistributedRadixTree(); void set_meta_channel(RedisMetaChannel *ch) { channel = ch; } void set_node_id(uint32_t nid) { node_id = nid; } - void start(RedisMetaChannel *channel); + bool start(RedisMetaChannel *channel); void stop(); RefRadixTree* remote_tree_refresh(); std::shared_ptr match_prefix(torch::Tensor &block_hashes, diff --git a/csrc/lease_meta_mempool.cpp b/csrc/lease_meta_mempool.cpp index 7840d3a259..2274bcc023 100644 --- a/csrc/lease_meta_mempool.cpp +++ b/csrc/lease_meta_mempool.cpp @@ -81,7 +81,7 @@ void LeaseMetaMemPool::free(LeaseMeta *ptr) { std::lock_guard lk(allocated_mu); auto it = allocated_set.find(ptr); if (it == allocated_set.end()) { - // not allocated from this pool or double free, ignore safely + // not allocated from this pool or already freed: ignore safely (idempotent) return; } allocated_set.erase(it); diff --git a/csrc/lease_meta_mempool.h b/csrc/lease_meta_mempool.h index 9112f0ca7c..90e455d77a 100644 --- a/csrc/lease_meta_mempool.h +++ b/csrc/lease_meta_mempool.h @@ -10,14 +10,13 @@ #include namespace flexkv { -enum NodeState { - NODE_STATE_NORMAL, - NODE_STATE_ABOUT_TO_EVICT, - NODE_STATE_EVICTED, -}; + +#define NODE_STATE_NORMAL 0 +#define NODE_STATE_ABOUT_TO_EVICT 1 +#define NODE_STATE_EVICTED 2 struct LeaseMeta { - volatile NodeState state; + volatile int state; volatile uint32_t lease_time; LeaseMeta() : state(NODE_STATE_NORMAL), lease_time(0) { } diff --git a/csrc/local_radix_tree.cpp b/csrc/local_radix_tree.cpp index c8e3bcdfb8..8a9c4ba46c 100644 --- 
a/csrc/local_radix_tree.cpp +++ b/csrc/local_radix_tree.cpp @@ -20,8 +20,8 @@ static inline uint64_t get_now_ms() { return (uint64_t)now.tv_sec * 1000 + (uint64_t)(now.tv_usec / 1000); } -LocalRadixTree::LocalRadixTree(int tokens_per_block, int max_num_blocks, uint32_t ttl_ms, uint32_t renew_ms, uint32_t batch_sz, uint32_t idle_sleep_ms, size_t lt_pool_initial_capacity) - : CRadixTreeIndex(tokens_per_block, max_num_blocks), channel(nullptr), node_id(0), lease_ttl_ms(ttl_ms), refresh_batch_size(batch_sz), lease_pool(lt_pool_initial_capacity) { +LocalRadixTree::LocalRadixTree(int tokens_per_block, int max_num_blocks, uint32_t ttl_ms, uint32_t renew_ms, uint32_t batch_sz, uint32_t idle_sleep_ms) + : CRadixTreeIndex(tokens_per_block, max_num_blocks), channel(nullptr), node_id(0), lease_ttl_ms(ttl_ms), refresh_batch_size(batch_sz), lease_pool(max_num_blocks) { this->idle_sleep_ms = idle_sleep_ms; if (renew_ms == 0) { renew_lease_ms = (uint32_t)(ttl_ms * 2 / 10); @@ -33,6 +33,11 @@ LocalRadixTree::LocalRadixTree(int tokens_per_block, int max_num_blocks, uint32_ } } +LocalRadixTree::~LocalRadixTree() { + // Ensure background worker is stopped before nodes/lease pool destruction + stop(); +} + CRadixNode *LocalRadixTree::insert(torch::Tensor &physical_block_ids, torch::Tensor &block_hashes, int num_blocks, int num_insert_blocks, bool ready, CRadixNode *last_node, int num_matched_blocks, int last_node_matched_length) { @@ -199,14 +204,19 @@ void LocalRadixTree::refresh_worker() { } } -void LocalRadixTree::start(RedisMetaChannel *ch) { - if (refresh_started) return; +bool LocalRadixTree::start(RedisMetaChannel *ch) { + if (refresh_started) return true; // Initialize channel and node_id from ch set_meta_channel(ch); - if (channel == nullptr) return; + if (channel == nullptr) return false; refresh_should_stop = false; refresh_started = true; - pthread_create(&refresh_tid, nullptr, &LocalRadixTree::refresh_worker_trampoline, this); + int result = pthread_create(&refresh_tid, 
nullptr, &LocalRadixTree::refresh_worker_trampoline, this); + if (result != 0) { + refresh_started = false; + return false; + } + return true; } void LocalRadixTree::renew_relese_time() { // compute new lease expiry (ms) @@ -394,7 +404,13 @@ int LocalRadixTree::total_cached_blocks() { return CRadixTreeIndex::total_cached int LocalRadixTree::total_node_num() { return CRadixTreeIndex::total_node_num(); } void LocalRadixTree::reset() { CRadixTreeIndex::reset(); } bool LocalRadixTree::is_root(CRadixNode *node) { return CRadixTreeIndex::is_root(node); } -void LocalRadixTree::remove_node(CRadixNode *node) { CRadixTreeIndex::remove_node(node); } +void LocalRadixTree::remove_node(CRadixNode *node) { + auto lm = node->get_lease_meta(); + if (lm != nullptr) { + lease_pool.free(lm); + } + CRadixTreeIndex::remove_node(node); +} void LocalRadixTree::remove_leaf(CRadixNode *node) { CRadixTreeIndex::remove_leaf(node); } void LocalRadixTree::add_node(CRadixNode *node) { CRadixTreeIndex::add_node(node); } void LocalRadixTree::add_leaf(CRadixNode *node) { CRadixTreeIndex::add_leaf(node); } diff --git a/csrc/local_radix_tree.h b/csrc/local_radix_tree.h index f82abb7886..25b104389a 100644 --- a/csrc/local_radix_tree.h +++ b/csrc/local_radix_tree.h @@ -44,7 +44,7 @@ class LocalRadixTree : public CRadixTreeIndex { void refresh_worker(); static void* refresh_worker_trampoline(void* arg); - void publish_node_blocks(CRadixNode *node); + void publish_node_blocks(NewBlockMeta *node); // Pop at most max_batch nodes from new_block_queue and publish their BlockMeta to Redis. // Returns number of nodes published. 
size_t local_block_report(size_t max_batch = 1024); @@ -56,8 +56,8 @@ class LocalRadixTree : public CRadixTreeIndex { uint32_t lease_ttl_ms = 100000, uint32_t renew_lease_ms = 0, uint32_t refresh_batch_size = 256, - uint32_t idle_sleep_ms = 10, - size_t lt_pool_initial_capacity = 0); + uint32_t idle_sleep_ms = 10); + ~LocalRadixTree(); void set_meta_channel(RedisMetaChannel *ch); @@ -65,7 +65,7 @@ class LocalRadixTree : public CRadixTreeIndex { void insert_and_publish(const CRadixNode *node); // Start background thread; initialize channel and node_id from ch first - void start(RedisMetaChannel *ch); + bool start(RedisMetaChannel *ch); // Stop background thread gracefully void stop(); diff --git a/csrc/radix_tree.cpp b/csrc/radix_tree.cpp index 4e92688594..a637aa0e9d 100644 --- a/csrc/radix_tree.cpp +++ b/csrc/radix_tree.cpp @@ -41,7 +41,7 @@ CRadixNode::~CRadixNode() { delete block_node_ids; } if (lease_meta != nullptr) { - LeaseMetaMemPool::release(lease_meta); + // Avoid returning to pool during teardown to prevent double-free on shutdown lease_meta = nullptr; } index->dec_node_count(); @@ -241,7 +241,10 @@ std::shared_ptr CRadixTreeIndex::match_prefix( auto prefix_blocks_num = 0; auto ready_prefix_blocks_num = 0; auto last_node_matched_length = 0; - auto physical_blocks = new std::vector(); + // Preallocate tensor for up to num_blocks entries and fill directly to avoid extra copy + auto physical_blocks_tensor = torch::empty({num_blocks}, torch::dtype(torch::kInt64)); + auto *pb_out = physical_blocks_tensor.data_ptr(); + int64_t pb_write = 0; auto block_hashes_ptr = block_hashes.data_ptr(); HashType child_hash; @@ -257,8 +260,9 @@ std::shared_ptr CRadixTreeIndex::match_prefix( ready_prefix_blocks_num += current_node->size(); } prefix_blocks_num += current_node->size(); - physical_blocks->insert(physical_blocks->end(), current_node->get_physical_blocks().begin(), - current_node->get_physical_blocks().end()); + for (auto v : current_node->get_physical_blocks()) { + 
pb_out[pb_write++] = v; + } current_node = current_node->get_child(child_hash); } else { auto matched_length = 0; @@ -276,8 +280,10 @@ std::shared_ptr CRadixTreeIndex::match_prefix( } } matched_length = left; - physical_blocks->insert(physical_blocks->end(), current_node->get_physical_blocks().begin(), - current_node->get_physical_blocks().begin() + matched_length); + auto &dq = current_node->get_physical_blocks(); + for (int i = 0; i < matched_length; ++i) { + pb_out[pb_write++] = dq[i]; + } } else { matched_length = 0; } @@ -293,8 +299,10 @@ std::shared_ptr CRadixTreeIndex::match_prefix( } } + auto physical_blocks = physical_blocks_tensor.narrow(0, 0, pb_write); + auto empty_uint32 = torch::Tensor(); return std::make_shared(prefix_blocks_num, ready_prefix_blocks_num, last_node_matched_length, - last_ready_node, current_node, physical_blocks); + last_ready_node, current_node, physical_blocks, empty_uint32); } } // namespace flexkv diff --git a/csrc/radix_tree.h b/csrc/radix_tree.h index df4aac285a..edc8a72477 100644 --- a/csrc/radix_tree.h +++ b/csrc/radix_tree.h @@ -50,6 +50,12 @@ class CRadixNode { void set_lease_meta(LeaseMeta* lease_meta) { this->lease_meta = lease_meta; } + + void set_lease_time(uint32_t lease_time) { + if (this->lease_meta != nullptr) { + this->lease_meta->lease_time = lease_time; + } + } void for_each_child(std::function func) { for (auto& child : children) { @@ -198,22 +204,17 @@ class CMatchResult { CRadixNode *last_ready_node; CRadixNode *last_node; - std::vector *physical_blocks; - std::vector *block_node_ids; + torch::Tensor physical_blocks; + torch::Tensor block_node_ids; CMatchResult(int _num_ready_matched_blocks, int _num_matched_blocks, int _last_node_matched_length, - CRadixNode *_last_ready_node, CRadixNode *_last_node, std::vector *blocks, std::vector *block_node_ids = nullptr) + CRadixNode *_last_ready_node, CRadixNode *_last_node, torch::Tensor blocks, torch::Tensor block_node_ids = torch::Tensor()) : 
num_ready_matched_blocks(_num_ready_matched_blocks), num_matched_blocks(_num_matched_blocks), last_node_matched_length(_last_node_matched_length), last_ready_node(_last_ready_node), last_node(_last_node), physical_blocks(blocks), block_node_ids(block_node_ids) { } - ~CMatchResult() { - delete physical_blocks; - if (block_node_ids) { - delete block_node_ids; - } - }; + ~CMatchResult() {} }; class CRadixTreeIndex { diff --git a/csrc/redis_meta_channel.cpp b/csrc/redis_meta_channel.cpp index 031d1bd21c..0a5b2dca53 100644 --- a/csrc/redis_meta_channel.cpp +++ b/csrc/redis_meta_channel.cpp @@ -1,252 +1,210 @@ #include "redis_meta_channel.h" -#include -#include -#include -#include -#include -#include -#include - +#include #include #include #include +#include +#include +#include +#include +#include namespace flexkv { -static std::string resp_bulk(const std::string &s) { - std::ostringstream oss; - oss << "$" << s.size() << "\r\n" << s << "\r\n"; - return oss.str(); -} - -static std::string resp_array(const std::vector &argv) { - std::ostringstream oss; - oss << "*" << argv.size() << "\r\n"; - for (auto &a : argv) { - oss << resp_bulk(a); - } - return oss.str(); -} - -RedisTCPClient::RedisTCPClient() : sockfd(-1), port(0), timeout_ms(3000) {} +RedisHiredisClient::RedisHiredisClient() : context_(nullptr), port_(0), timeout_ms_(3000), password_("") {} -RedisTCPClient::~RedisTCPClient() { close(); } - -bool RedisTCPClient::connect(const std::string &h, int p, int t_ms) { - host = h; port = p; timeout_ms = t_ms; - sockfd = ::socket(AF_INET, SOCK_STREAM, 0); - if (sockfd < 0) return false; - sockaddr_in addr{}; - addr.sin_family = AF_INET; - addr.sin_port = htons(port); - if (::inet_pton(AF_INET, host.c_str(), &addr.sin_addr) <= 0) return false; - if (::connect(sockfd, (sockaddr*)&addr, sizeof(addr)) < 0) return false; - return true; -} - -void RedisTCPClient::close() { - if (sockfd >= 0) { - ::shutdown(sockfd, SHUT_RDWR); - ::close(sockfd); - } - sockfd = -1; 
+RedisHiredisClient::~RedisHiredisClient() { + close(); } -bool RedisTCPClient::send_all(const std::string &buf) { - size_t sent = 0; - while (sent < buf.size()) { - ssize_t n = ::send(sockfd, buf.data() + sent, buf.size() - sent, 0); - if (n <= 0) return false; - sent += (size_t)n; +bool RedisHiredisClient::connect(const std::string &host, int port, int timeout_ms, const std::string &password) { + host_ = host; + port_ = port; + timeout_ms_ = timeout_ms; + password_ = password; + + // Create connection with timeout + struct timeval timeout = { timeout_ms / 1000, (timeout_ms % 1000) * 1000 }; + context_ = redisConnectWithTimeout(host.c_str(), port, timeout); + + if (context_ == nullptr || context_->err) { + if (context_) { + redisFree(context_); + context_ = nullptr; + } + return false; } - return true; -} - -bool RedisTCPClient::recv_line(std::string &line) { - line.clear(); - char c; - while (true) { - ssize_t n = ::recv(sockfd, &c, 1, 0); - if (n <= 0) return false; - if (c == '\r') { - char lf; - if (::recv(sockfd, &lf, 1, 0) <= 0) return false; - if (lf != '\n') return false; - break; + + // Authenticate if password is provided + if (!password_.empty()) { + redisReply* reply = (redisReply*)redisCommand(context_, "AUTH %s", password_.c_str()); + if (!reply) { + redisFree(context_); + context_ = nullptr; + return false; + } + + bool auth_success = (reply->type == REDIS_REPLY_STATUS && + strcmp(reply->str, "OK") == 0); + freeReplyObject(reply); + + if (!auth_success) { + redisFree(context_); + context_ = nullptr; + return false; } - line.push_back(c); } + return true; } -bool RedisTCPClient::recv_nbytes(size_t n, std::string &out) { - out.resize(n); - size_t got = 0; - while (got < n) { - ssize_t r = ::recv(sockfd, &out[got], n - got, 0); - if (r <= 0) return false; - got += (size_t)r; +void RedisHiredisClient::close() { + if (context_) { + redisFree(context_); + context_ = nullptr; } - // consume CRLF - char crlf[2]; - if (::recv(sockfd, crlf, 2, 0) != 2) return 
false; - return true; } -bool RedisTCPClient::command(const std::vector &argv, std::vector &out) { - out.clear(); - std::string req = resp_array(argv); - if (!send_all(req)) return false; - std::string line; - if (!recv_line(line)) return false; - if (line.empty()) return false; - if (line[0] == '+') { // simple string - out.push_back(line.substr(1)); - return true; - } else if (line[0] == ':') { // integer - out.push_back(line.substr(1)); - return true; - } else if (line[0] == '$') { // bulk string - int len = std::stoi(line.substr(1)); - if (len < 0) return true; // nil - std::string bulk; - if (!recv_nbytes((size_t)len, bulk)) return false; - out.push_back(bulk); - return true; - } else if (line[0] == '*') { // array - int cnt = std::stoi(line.substr(1)); - for (int i = 0; i < cnt; ++i) { - if (!recv_line(line)) return false; - if (line.empty() || line[0] != '$') return false; - int len = std::stoi(line.substr(1)); - if (len < 0) { out.emplace_back(); continue; } - std::string bulk; - if (!recv_nbytes((size_t)len, bulk)) return false; - out.push_back(bulk); - } - return true; +bool RedisHiredisClient::command(const std::vector &argv, std::vector &out) { + if (!context_) return false; + + // Convert vector to char* array + std::vector args; + std::vector arglens; + + for (const auto& arg : argv) { + args.push_back(arg.c_str()); + arglens.push_back(arg.length()); + } + + redisReply* reply = (redisReply*)redisCommandArgv(context_, args.size(), args.data(), arglens.data()); + if (!reply) { + return false; } - return false; + + bool result = parse_reply(reply, out); + freeReplyObject(reply); + return result; } -bool RedisTCPClient::pipeline(const std::vector> &batch, - std::vector> &replies) { +bool RedisHiredisClient::pipeline(const std::vector> &batch, + std::vector> &replies) { + if (!context_ || batch.empty()) return false; + replies.clear(); - if (batch.empty()) return true; - // Build one big request - std::ostringstream req; - for (const auto &argv : batch) 
req << resp_array(argv); - std::string payload = req.str(); - if (!send_all(payload)) return false; - - // Receive replies sequentially replies.reserve(batch.size()); + + // Append all commands to pipeline + for (const auto& cmd : batch) { + std::vector args; + std::vector arglens; + + for (const auto& arg : cmd) { + args.push_back(arg.c_str()); + arglens.push_back(arg.length()); + } + + int ret = redisAppendCommandArgv(context_, args.size(), args.data(), arglens.data()); + if (ret != REDIS_OK) { + return false; + } + } + + // Get all replies for (size_t i = 0; i < batch.size(); ++i) { - std::vector one; - // Parse a single reply using same logic as command() - std::string line; - if (!recv_line(line)) return false; - if (line.empty()) return false; - if (line[0] == '+') { - one.push_back(line.substr(1)); - } else if (line[0] == ':') { - one.push_back(line.substr(1)); - } else if (line[0] == '$') { - int len = std::stoi(line.substr(1)); - if (len >= 0) { - std::string bulk; - if (!recv_nbytes((size_t)len, bulk)) return false; - one.push_back(bulk); - } else { - one.emplace_back(); - } - } else if (line[0] == '*') { - int cnt = std::stoi(line.substr(1)); - for (int j = 0; j < cnt; ++j) { - if (!recv_line(line)) return false; - if (line.empty() || line[0] != '$') return false; - int len = std::stoi(line.substr(1)); - if (len < 0) { one.emplace_back(); continue; } - std::string bulk; - if (!recv_nbytes((size_t)len, bulk)) return false; - one.push_back(bulk); - } - } else { + redisReply* reply = nullptr; + int ret = redisGetReply(context_, (void**)&reply); + if (ret != REDIS_OK || !reply) { + if (reply) freeReplyObject(reply); return false; } - replies.push_back(std::move(one)); + + std::vector reply_vec; + bool success = parse_reply(reply, reply_vec); + freeReplyObject(reply); + + if (!success) { + return false; + } + + replies.push_back(std::move(reply_vec)); } + return true; } -bool RedisMetaChannel::hmget_two_fields_for_keys(const std::vector &keys, - const 
std::string &field1, - const std::string &field2, - std::vector> &out) { + +redisContext* RedisHiredisClient::get_context() const { + return context_; +} + +bool RedisHiredisClient::parse_reply(redisReply* reply, std::vector &out) { + if (!reply) return false; + out.clear(); - if (keys.empty()) return true; - std::vector> batch; - batch.reserve(keys.size()); - for (const auto &k : keys) batch.push_back({"HMGET", k, field1, field2}); - std::vector> replies; - if (!client.pipeline(batch, replies)) return false; - out.reserve(replies.size()); - for (const auto &r : replies) { - if (r.size() >= 2) out.emplace_back(r[0], r[1]); - else if (r.size() == 1) out.emplace_back(r[0], std::string()); - else out.emplace_back(std::string(), std::string()); + + switch (reply->type) { + case REDIS_REPLY_STRING: + case REDIS_REPLY_STATUS: + out.push_back(std::string(reply->str, reply->len)); + break; + + case REDIS_REPLY_INTEGER: + out.push_back(std::to_string(reply->integer)); + break; + + case REDIS_REPLY_ARRAY: + for (size_t i = 0; i < reply->elements; ++i) { + if (reply->element[i]->type == REDIS_REPLY_STRING) { + out.push_back(std::string(reply->element[i]->str, reply->element[i]->len)); + } else if (reply->element[i]->type == REDIS_REPLY_NIL) { + out.push_back(""); // Empty string for NIL + } else { + // For other types, convert to string representation + out.push_back(std::to_string(reply->element[i]->integer)); + } + } + break; + + case REDIS_REPLY_NIL: + out.push_back(""); // Empty string for NIL + break; + + case REDIS_REPLY_ERROR: + return false; // Error reply + + default: + return false; } - return out.size() == keys.size(); + + return true; } -static std::string to_hex_u64(uint64_t value) { - std::ostringstream oss; - oss << std::hex << std::nouppercase << value; - return oss.str(); -} + + RedisMetaChannel::RedisMetaChannel(const std::string &h, int p, uint32_t node_id, const std::string &lip, - const std::string &bk) - : host(h), port(p), node_id(node_id), 
blocks_key(bk), local_ip(lip) { + const std::string &bk, + const std::string &pwd) + : host(h), port(p), node_id(node_id), blocks_key(bk), local_ip(lip), password(pwd) { } bool RedisMetaChannel::connect() { - return client.connect(host, port, 3000); + return client.connect(host, port, 3000, password); } + std::string RedisMetaChannel::make_block_key(uint32_t node_id, uint64_t hash) const { std::ostringstream oss; oss << blocks_key << ":block:" << node_id << ":" << std::hex << std::nouppercase << hash; return oss.str(); } - -// register_node removed; node id is now set via constructor - -std::string RedisMetaChannel::to_string(const BlockMeta &m) { - // ph|pb|nid|hash|lt|state - std::ostringstream oss; - oss << m.ph << '|' << m.pb << '|' << m.nid << '|' << m.hash << '|' << m.lt << '|' << (int)m.state; - return oss.str(); -} - -bool RedisMetaChannel::from_string(const std::string &s, BlockMeta &m) { - std::istringstream iss(s); - std::string tok; - if (!std::getline(iss, tok, '|')) return false; m.ph = std::stoll(tok); - if (!std::getline(iss, tok, '|')) return false; m.pb = std::stoll(tok); - if (!std::getline(iss, tok, '|')) return false; m.nid = (uint32_t)std::stoul(tok); - if (!std::getline(iss, tok, '|')) return false; m.hash = std::stoll(tok); - if (!std::getline(iss, tok, '|')) return false; m.lt = (uint32_t)std::stoul(tok); - if (!std::getline(iss, tok, '|')) return false; m.state = (NodeState)std::stoi(tok); - return true; -} - -void RedisMetaChannel::publish(const BlockMeta &meta) { +bool RedisMetaChannel::publish(const BlockMeta &meta) { std::vector resp; // Key format: :block:: std::string key = make_block_key(meta.nid, (uint64_t)meta.hash); - client.command({ + bool ret = client.command({ "HSET", key, "ph", std::to_string(meta.ph), "pb", std::to_string(meta.pb), @@ -255,10 +213,11 @@ void RedisMetaChannel::publish(const BlockMeta &meta) { "lt", std::to_string(meta.lt), "state", std::to_string((int)meta.state) }, resp); + return ret; } -void 
RedisMetaChannel::publish(const std::vector &metas, size_t batch_size) { - if (metas.empty()) return; +bool RedisMetaChannel::publish(const std::vector &metas, size_t batch_size) { + if (metas.empty()) return true; if (batch_size == 0) batch_size = 100; size_t total = metas.size(); @@ -282,59 +241,88 @@ void RedisMetaChannel::publish(const std::vector &metas, size_t batch }); } std::vector> replies; - client.pipeline(batch, replies); + bool ret = client.pipeline(batch, replies); + if (!ret) { + return false; + } idx = end; } + return true; } size_t RedisMetaChannel::load(std::vector &out, size_t max_items) { out.clear(); if (max_items == 0) return 0; - // Fetch keys: KEYS :block:* + // Use SCAN instead of KEYS to avoid blocking std::vector keys; - if (!client.command({"KEYS", blocks_key + ":block:*"}, keys)) return 0; + std::string pattern = blocks_key + ":block:*"; + std::string cursor = "0"; + + do { + std::vector scan_result; + if (!client.command({"SCAN", cursor, "MATCH", pattern, "COUNT", "100"}, scan_result)) { + return 0; + } + + if (scan_result.size() >= 2) { + cursor = scan_result[0]; + // scan_result[1] contains the array of keys + // Parse the array response + for (size_t i = 1; i < scan_result.size(); ++i) { + keys.push_back(scan_result[i]); + if (keys.size() >= max_items) break; + } + } else { + break; + } + } while (cursor != "0" && keys.size() < max_items); + if (keys.empty()) return 0; - size_t total = std::min(keys.size(), max_items); - out.reserve(total); - - const size_t batch_size = 100; - size_t idx = 0; - while (idx < total) { - size_t end = std::min(idx + batch_size, total); - std::vector> batch; - batch.reserve(end - idx); - for (size_t i = idx; i < end; ++i) { - batch.push_back({"HMGET", keys[i], "ph", "pb", "nid", "hash", "lt", "state"}); - } - std::vector> replies; - if (!client.pipeline(batch, replies)) break; - for (const auto &fields : replies) { - if (fields.size() != 6) continue; - BlockMeta m{}; - if (!fields[0].empty()) m.ph = 
std::stoll(fields[0]); else m.ph = 0; - if (!fields[1].empty()) m.pb = std::stoll(fields[1]); else m.pb = 0; - if (!fields[2].empty()) m.nid = (uint32_t)std::stoul(fields[2]); else m.nid = 0; - if (!fields[3].empty()) m.hash = std::stoll(fields[3]); else m.hash = 0; - if (!fields[4].empty()) m.lt = (uint32_t)std::stoul(fields[4]); else m.lt = 0; - if (!fields[5].empty()) m.state = (NodeState)std::stoi(fields[5]); else m.state = (NodeState)0; - out.push_back(m); + // Batch HMGET for all fields + std::vector> batch; + batch.reserve(keys.size()); + + for (const auto& key : keys) { + batch.push_back({"HMGET", key, "ph", "pb", "nid", "hash", "lt", "state"}); + } + + std::vector> replies; + if (!client.pipeline(batch, replies)) return 0; + + // Parse replies into BlockMeta objects + for (size_t i = 0; i < replies.size() && i < keys.size(); ++i) { + const auto& reply = replies[i]; + if (reply.size() == 6) { + BlockMeta meta; + if (reply[0].empty() || reply[1].empty() || reply[2].empty() + || reply[3].empty() || reply[4].empty() || reply[5].empty()) { + meta.state = NODE_STATE_EVICTED; + } else { + meta.ph = std::stoll(reply[0]); + meta.pb = std::stoll(reply[1]); + meta.nid = std::stoul(reply[2]); + meta.hash = std::stoll(reply[3]); + meta.lt = std::stoul(reply[4]); + meta.state = std::stoi(reply[5]); + } + out.push_back(meta); + } else { + BlockMeta meta; + meta.state = NODE_STATE_EVICTED; + out.push_back(meta); } - idx = end; } + return out.size(); } -uint32_t RedisMetaChannel::get_node_id() const { - return node_id; -} - -void RedisMetaChannel::renew_node_leases(uint32_t node_id, uint32_t new_lt, size_t batch_size) { +bool RedisMetaChannel::renew_node_leases(uint32_t node_id, uint32_t new_lt, size_t batch_size) { // Discover keys for this node and update lt via pipeline std::vector keys; - if (!list_block_keys(node_id, keys)) return; - if (keys.empty()) return; + if (!list_block_keys(node_id, keys)) return false; + if (keys.empty()) return true; if (batch_size == 0) 
batch_size = 200; size_t idx = 0, total = keys.size(); while (idx < total) { @@ -345,72 +333,180 @@ void RedisMetaChannel::renew_node_leases(uint32_t node_id, uint32_t new_lt, size batch.push_back({"HSET", keys[i], "lt", std::to_string(new_lt)}); } std::vector> replies; - client.pipeline(batch, replies); + if (!client.pipeline(batch, replies)) return false; idx = end; } + return true; +} + +uint32_t RedisMetaChannel::get_node_id() const { + return node_id; } bool RedisMetaChannel::list_keys(const std::string &pattern, std::vector &keys) { keys.clear(); - return client.command({"KEYS", pattern}, keys); + std::string cursor = "0"; + + do { + // Use raw command to get proper SCAN response parsing + std::vector scan_cmd = {"SCAN", cursor, "MATCH", pattern, "COUNT", "100"}; + + // Get raw response from Redis + redisContext* context = client.get_context(); + if (!context) return false; + + // Prepare command arguments + std::vector argv; + std::vector arglen; + for (const auto& arg : scan_cmd) { + argv.push_back(arg.c_str()); + arglen.push_back(arg.length()); + } + + redisReply* reply = nullptr; + int result = redisAppendCommandArgv(context, argv.size(), argv.data(), arglen.data()); + if (result != REDIS_OK) return false; + + result = redisGetReply(context, (void**)&reply); + if (result != REDIS_OK || !reply) return false; + + // Parse SCAN response: [cursor, [keys...]] + if (reply->type == REDIS_REPLY_ARRAY && reply->elements >= 2) { + // First element is cursor + if (reply->element[0]->type == REDIS_REPLY_STRING) { + cursor = std::string(reply->element[0]->str, reply->element[0]->len); + } else if (reply->element[0]->type == REDIS_REPLY_INTEGER) { + cursor = std::to_string(reply->element[0]->integer); + } + + // Second element is array of keys + if (reply->element[1]->type == REDIS_REPLY_ARRAY) { + for (size_t i = 0; i < reply->element[1]->elements; ++i) { + if (reply->element[1]->element[i]->type == REDIS_REPLY_STRING) { + 
keys.push_back(std::string(reply->element[1]->element[i]->str, + reply->element[1]->element[i]->len)); + } + } + } + } + + freeReplyObject(reply); + + } while (cursor != "0"); + + return true; } bool RedisMetaChannel::list_node_keys(std::vector &keys) { - keys.clear(); - return client.command({"KEYS", "node:*"}, keys); + return list_keys("node:*", keys); } bool RedisMetaChannel::list_block_keys(uint32_t node_id, std::vector &keys) { - keys.clear(); std::string pattern = blocks_key + ":block:" + std::to_string(node_id) + ":*"; - return client.command({"KEYS", pattern}, keys); + return list_keys(pattern, keys); } bool RedisMetaChannel::hmget_field_for_keys(const std::vector &keys, const std::string &field, std::vector &values) { + if (keys.empty()) return true; + values.clear(); + values.reserve(keys.size()); + + // Batch HMGET for single field + std::vector> batch; + batch.reserve(keys.size()); + + for (const auto& key : keys) { + batch.push_back({"HMGET", key, field}); + } + + std::vector> replies; + if (!client.pipeline(batch, replies)) return false; + + for (const auto& reply : replies) { + if (!reply.empty()) { + values.push_back(reply[0]); + } else { + values.push_back(""); + } + } + + return true; +} + +bool RedisMetaChannel::hmget_two_fields_for_keys(const std::vector &keys, + const std::string &field1, + const std::string &field2, + std::vector> &out) { if (keys.empty()) return true; + + out.clear(); + out.reserve(keys.size()); + + // Batch HMGET for two fields std::vector> batch; batch.reserve(keys.size()); - for (const auto &k : keys) batch.push_back({"HMGET", k, field}); + + for (const auto& key : keys) { + batch.push_back({"HMGET", key, field1, field2}); + } + std::vector> replies; if (!client.pipeline(batch, replies)) return false; - values.reserve(replies.size()); - for (const auto &r : replies) { - if (!r.empty()) values.push_back(r[0]); else values.emplace_back(); + + for (const auto& reply : replies) { + if (reply.size() >= 2) { + 
out.emplace_back(reply[0], reply[1]); + } else { + out.emplace_back("", ""); + } } - return values.size() == keys.size(); + + return true; } size_t RedisMetaChannel::load_metas_by_keys(const std::vector &keys, std::vector &out) { out.clear(); if (keys.empty()) return 0; - const size_t batch_size = 100; - size_t idx = 0, total = keys.size(); - while (idx < total) { - size_t end = std::min(idx + batch_size, total); - std::vector> batch; - batch.reserve(end - idx); - for (size_t i = idx; i < end; ++i) { - batch.push_back({"HMGET", keys[i], "ph", "pb", "nid", "hash", "lt", "state"}); - } - std::vector> replies; - if (!client.pipeline(batch, replies)) break; - for (const auto &fields : replies) { - if (fields.size() != 6) continue; - BlockMeta m{}; - if (!fields[0].empty()) m.ph = std::stoll(fields[0]); else m.ph = 0; - if (!fields[1].empty()) m.pb = std::stoll(fields[1]); else m.pb = 0; - if (!fields[2].empty()) m.nid = (uint32_t)std::stoul(fields[2]); else m.nid = 0; - if (!fields[3].empty()) m.hash = std::stoll(fields[3]); else m.hash = 0; - if (!fields[4].empty()) m.lt = (uint32_t)std::stoul(fields[4]); else m.lt = 0; - if (!fields[5].empty()) m.state = (NodeState)std::stoi(fields[5]); else m.state = (NodeState)0; - out.push_back(m); + + // Batch HMGET for all fields + std::vector> batch; + batch.reserve(keys.size()); + + for (const auto& key : keys) { + batch.push_back({"HMGET", key, "ph", "pb", "nid", "hash", "lt", "state"}); + } + + std::vector> replies; + if (!client.pipeline(batch, replies)) return 0; + + // Parse replies into BlockMeta objects + for (size_t i = 0; i < replies.size() && i < keys.size(); ++i) { + const auto& reply = replies[i]; + if (reply.size() == 6) { + BlockMeta meta; + if (reply[0].empty() || reply[1].empty() || reply[2].empty() + || reply[3].empty() || reply[4].empty() || reply[5].empty()) { + meta.state = NODE_STATE_EVICTED; + } else { + meta.ph = std::stoll(reply[0]); + meta.pb = std::stoll(reply[1]); + meta.nid = std::stoul(reply[2]); + 
meta.hash = std::stoll(reply[3]); + meta.lt = std::stoul(reply[4]); + meta.state = std::stoi(reply[5]); + } + out.push_back(meta); + } else { + BlockMeta meta; + meta.state = NODE_STATE_EVICTED; + out.push_back(meta); } - idx = end; } + return out.size(); } @@ -418,11 +514,11 @@ static std::string key_for_block(RedisMetaChannel* ch, uint32_t node_id, int64_t return ch->make_block_key(node_id, (uint64_t)hash); } -void RedisMetaChannel::update_block_state_batch(uint32_t node_id, +bool RedisMetaChannel::update_block_state_batch(uint32_t node_id, std::deque *hashes, - NodeState state, + int state, size_t batch_size) { - if (hashes == nullptr || hashes->empty()) return; + if (hashes == nullptr || hashes->empty()) return true; if (batch_size == 0) batch_size = 200; size_t idx = 0, total = hashes->size(); while (idx < total) { @@ -434,15 +530,16 @@ void RedisMetaChannel::update_block_state_batch(uint32_t node_id, batch.push_back({"HSET", key, "state", std::to_string((int)state)}); } std::vector> replies; - client.pipeline(batch, replies); + if (!client.pipeline(batch, replies)) return false; idx = end; } + return true; } -void RedisMetaChannel::delete_blockmeta_batch(uint32_t node_id, +bool RedisMetaChannel::delete_blockmeta_batch(uint32_t node_id, std::deque *hashes, size_t batch_size) { - if (hashes == nullptr || hashes->empty()) return; + if (hashes == nullptr || hashes->empty()) return true; if (batch_size == 0) batch_size = 200; size_t idx = 0, total = hashes->size(); while (idx < total) { @@ -454,11 +551,10 @@ void RedisMetaChannel::delete_blockmeta_batch(uint32_t node_id, batch.push_back({"DEL", key}); } std::vector> replies; - client.pipeline(batch, replies); + if (!client.pipeline(batch, replies)) return false; idx = end; } + return true; } -} // namespace flexkv - - +} // namespace flexkv \ No newline at end of file diff --git a/csrc/redis_meta_channel.h b/csrc/redis_meta_channel.h index ae872b5287..af216a665d 100644 --- a/csrc/redis_meta_channel.h +++ 
b/csrc/redis_meta_channel.h @@ -6,28 +6,34 @@ #include #include #include +#include +#include +#include +#include +#include #include "block_meta.h" +// Forward declaration for hiredis +struct redisContext; +struct redisReply; + namespace flexkv { -// Minimal RESP/Redis TCP client (single-threaded, blocking) for a few commands we need. -class RedisTCPClient { +// Hiredis-based Redis client wrapper +class RedisHiredisClient { private: - int sockfd; - std::string host; - int port; - int timeout_ms; - - bool send_all(const std::string &buf); - bool recv_line(std::string &line); - bool recv_nbytes(size_t n, std::string &out); + redisContext* context_; + std::string host_; + int port_; + int timeout_ms_; + std::string password_; public: - RedisTCPClient(); - ~RedisTCPClient(); + RedisHiredisClient(); + ~RedisHiredisClient(); - bool connect(const std::string &host, int port, int timeout_ms = 3000); + bool connect(const std::string &host, int port, int timeout_ms = 3000, const std::string &password = ""); void close(); // Sends a RESP array command and parses a single reply into raw components. @@ -38,55 +44,63 @@ class RedisTCPClient { // replies[i] corresponds to batch[i]. Returns false if send/receive fails at any point. 
bool pipeline(const std::vector> &batch, std::vector> &replies); -}; + // Get raw Redis context for advanced operations + redisContext* get_context() const; + +private: + bool parse_reply(redisReply* reply, std::vector &out); + void free_reply(redisReply* reply); +}; class RedisMetaChannel { private: - RedisTCPClient client; + RedisHiredisClient client; std::string host; int port; uint32_t node_id; std::string blocks_key; // legacy, unused for list storage std::string local_ip; + std::string password; public: RedisMetaChannel(const std::string &host, int port, uint32_t node_id, const std::string &local_ip, - const std::string &blocks_key = "blocks"); + const std::string &blocks_key = "blocks", + const std::string &password = ""); bool connect(); // Build Redis block key: :block:: std::string make_block_key(uint32_t node_id, uint64_t hash) const; - void publish(const BlockMeta &meta); - void publish(const std::vector &metas, size_t batch_size = 100); + bool publish(const BlockMeta &meta); + bool publish(const std::vector &metas, size_t batch_size = 100); size_t load(std::vector &out, size_t max_items); // Batch update lt for all block metas belonging to node_id - void renew_node_leases(uint32_t node_id, uint32_t new_lt, size_t batch_size = 200); + bool renew_node_leases(uint32_t node_id, uint32_t new_lt, size_t batch_size = 200); // Returns the global node id assigned to this process, or UINT32_MAX if uninitialized. 
uint32_t get_node_id() const; const std::string &get_local_ip() const { return local_ip; } // Batch update state for given hashes belonging to node_id - void update_block_state_batch(uint32_t node_id, + bool update_block_state_batch(uint32_t node_id, std::deque *hashes, - NodeState state, + int state, size_t batch_size = 200); // Batch delete metas for given hashes belonging to node_id - void delete_blockmeta_batch(uint32_t node_id, + bool delete_blockmeta_batch(uint32_t node_id, std::deque *hashes, size_t batch_size = 200); // Generic helpers for metadata queries bool list_keys(const std::string &pattern, std::vector &keys); - // List node keys: KEYS node:* + // List node keys: SCAN node:* bool list_node_keys(std::vector &keys); - // List block keys for a specific node: KEYS :block::* + // List block keys for a specific node: SCAN :block::* bool list_block_keys(uint32_t node_id, std::vector &keys); // Pipeline HMGET for a single field over many keys. values.size()==keys.size() on success diff --git a/csrc/tp_transfer_thread_group.cpp b/csrc/tp_transfer_thread_group.cpp index 06cb45c4fe..43a31f6a1d 100644 --- a/csrc/tp_transfer_thread_group.cpp +++ b/csrc/tp_transfer_thread_group.cpp @@ -15,7 +15,9 @@ * limitations under the License. */ #include "tp_transfer_thread_group.h" +#ifdef CUDA_AVAILABLE #include "transfer.cuh" +#endif #include namespace flexkv { @@ -146,6 +148,7 @@ void TPTransferThreadGroup::tp_group_transfer( int64_t cpu_startoff_inside_chunks = is_mla ? 
0 : i * gpu_chunk_sizes_in_bytes_[i];
+#ifdef CUDA_AVAILABLE
 flexkv::transfer_kv_blocks(
 num_blocks, layer_id, layer_granularity, gpu_block_ids, gpu_layer_ptrs, gpu_kv_strides_in_bytes_[i], gpu_block_strides_in_bytes_[i],
@@ -154,12 +157,18 @@ void TPTransferThreadGroup::tp_group_transfer(
 cpu_startoff_inside_chunks, gpu_chunk_sizes_in_bytes_[i], streams_[i], transfer_sms, is_host_to_device, use_ce_transfer, is_mla );
+#else
+ // CUDA not available: fail fast rather than silently skipping the transfer
+ throw std::runtime_error("CUDA not available, cannot perform transfer");
+#endif
+#ifdef CUDA_AVAILABLE
 cudaError_t err = cudaGetLastError();
 if (err != cudaSuccess) {
 failed = true;
 error_msg = cudaGetErrorString(err);
 }
+#endif
 } catch (const std::exception &e) {
 failed = true;
 error_msg = e.what();
diff --git a/docs/dynamo_integration/README_en.md b/docs/dynamo_integration/README_en.md
new file mode 100644
index 0000000000..6f3988e23e
--- /dev/null
+++ b/docs/dynamo_integration/README_en.md
@@ -0,0 +1,155 @@
+# FlexKV and Dynamo Integration Guide
+
+This document shows how to integrate FlexKV with NVIDIA's [Dynamo](https://github.com/ai-dynamo/dynamo) framework and how to run performance tests against the integration.
+
+Dynamo is a framework designed by NVIDIA for large-scale disaggregated deployment, supporting multiple backend engines including TensorRT-LLM, vLLM, and SGLang. The KV Router is an intelligent request-routing component that tracks and manages the KV caches stored on different workers. It assigns each request to the most suitable worker based on the overlap between the request and the cached KV blocks, as well as the current worker load, thereby reducing expensive KV cache recomputation and improving inference efficiency. This document also explains how to integrate FlexKV into Dynamo when the KV Router is enabled.
+
+## 1. Environment Setup
+
+### Dynamo Image
+
+We use the Dynamo 0.4.1 image with the vLLM backend, which includes vLLM 0.10.1.1.
+
+```bash
+docker pull nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.4.1
+```
+
+### FlexKV Code Preparation
+
+```bash
+git clone https://github.com/taco-project/FlexKV
+```
+
+### Install FlexKV
+
+```bash
+apt update && apt install liburing-dev
+
+cd FlexKV && ./build.sh
+```
+
+### Apply the vLLM Patch
+
+```bash
+# Navigate to the vLLM directory
+cd /opt/vllm
+# Apply the patch
+git apply /your/path/to/FlexKV/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch
+```
+
+### FlexKV Verification
+
+Please refer to the test scripts in [vLLM online serving](../../docs/vllm_adapter/README_zh.md#%E7%A4%BA%E4%BE%8B).
+
+## 2. Dynamo Modifications
+
+### kv_transfer_config
+
+To integrate with FlexKV, you need to modify the `kv_transfer_config` in the Dynamo image. Change lines 245-248 in `/opt/dynamo/venv/lib/python3.12/site-packages/dynamo/vllm/args.py` to:
+
+```python
+kv_transfer_config = KVTransferConfig(
+    kv_connector="FlexKVConnectorV1", kv_role="kv_both"
+)
+logger.info("Using FlexKVConnectorV1 configuration")
+```
+
+### CPU Offloading
+
+In Dynamo, the KV Router updates its KV index by receiving events sent from workers, allowing it to track the KV cache status on each worker. When CPU offloading is enabled in FlexKV, we remove the [BlockRemove](https://github.com/vllm-project/vllm/blob/v0.10.1.1/vllm/v1/core/block_pool.py#L221) call in vLLM, so that FlexKV can cache all KV blocks in CPU memory during serving. This ensures that the index maintained by the KV Router accurately reflects the actual index in FlexKV.
+
+## 3. Starting and Verifying Dynamo Services
+
+### Starting Dynamo + FlexKV
+
+```bash
+#!/bin/bash
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+set -e
+trap 'echo Cleaning up...; kill 0' EXIT
+
+# Start nats and etcd
+nats-server -js &
+
+etcd --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls http://0.0.0.0:2379 --data-dir /tmp/etcd &
+
+sleep 3
+
+# Run the ingress; set the routing mode with --router-mode (options: kv, round-robin, random)
+python -m dynamo.frontend --router-mode kv --http-port 8000 &
+
+# Define the number of worker nodes
+NUM_WORKERS=4
+
+# When using multiple workers, make sure the FlexKV ports differ to avoid hanging at flexkv init
+# Adjust the num_cpu_blocks and num_ssd_blocks values according to your server configuration
+for i in $(seq 0 $((NUM_WORKERS-1))); do
+  cat <<EOF > ./flexkv_config_${i}.json
+{
+    "enable_flexkv": true,
+    "server_recv_port": "ipc:///tmp/flexkv_${i}_test",
+    "cache_config": {
+        "enable_cpu": true,
+        "enable_ssd": false,
+        "enable_remote": false,
+        "use_gds": false,
+        "enable_trace": false,
+        "ssd_cache_iouring_entries": 512,
+        "tokens_per_block": 64,
+        "num_cpu_blocks": 10240,
+        "num_ssd_blocks": 256000,
+        "ssd_cache_dir": "/data/flexkv_ssd/",
+        "evict_ratio": 0.05,
+        "index_accel": true
+    },
+    "num_log_interval_requests": 200
+}
+EOF
+done
+
+# Use a loop to start the worker nodes
+for i in $(seq 0 $((NUM_WORKERS-1))); do
+  # Calculate GPU device IDs
+  GPU_START=$((i*2))
+  GPU_END=$((i*2+1))
+
+  if [ $i -lt $((NUM_WORKERS-1)) ]; then
+    FLEXKV_CONFIG_PATH="./flexkv_config_${i}.json" CUDA_VISIBLE_DEVICES=${GPU_START},${GPU_END} python3 -m dynamo.vllm --model deepseek-ai/DeepSeek-R1-Distill-Llama-70B --tensor_parallel_size 2 --block-size 64 --gpu-memory-utilization 0.9 --max-model-len 100310 &
+  else
+    FLEXKV_CONFIG_PATH="./flexkv_config_${i}.json" CUDA_VISIBLE_DEVICES=${GPU_START},${GPU_END} python3 -m dynamo.vllm --model deepseek-ai/DeepSeek-R1-Distill-Llama-70B --tensor_parallel_size 2 --block-size 64 --gpu-memory-utilization 0.9 --max-model-len 100310
+  fi
+done
+```
+
+> Note: The `flexkv_config.json` configuration is
provided as a simple example only. For full parameter options, please refer to [`docs/flexkv_config_reference/README_en.md`](../../docs/flexkv_config_reference/README_en.md)
+
+### Verification
+
+You can verify that the Dynamo service has started correctly with the following command:
+```bash
+curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
+ "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
+ "messages": [
+ {
+ "role": "user",
+ "content": "Tell me a joke."
+ }
+ ],
+ "stream":false,
+ "max_tokens": 30
+ }'
+```
+
+## 4. Benchmark
+
+We use [genai-perf](https://github.com/triton-inference-server/perf_analyzer/tree/main/genai-perf) as our benchmark tool and [mooncake trace](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#-open-source-trace) as our dataset to evaluate the performance of Dynamo + FlexKV.
+
+Mooncake Trace is an open-source request file saved in jsonl format. It records timestamps of request arrivals, ISL (input sequence length), OSL (output sequence length), and KV cache-related hash IDs, containing 23,608 requests over a 1-hour period. For our experiment with 4 LLaMA-70B workers, the concurrency in the mooncake trace was too high, so we sampled every 6th request from the trace to build our benchmark dataset.
+
+genai-perf can send requests according to the timestamps in the trace file and calculate metrics such as TTFT (Time To First Token) and TPOT (Time Per Output Token) for the LLM service. The command is as follows. Please use genai-perf==0.0.13, as newer versions have a bug in timestamp parsing.
+ +```bash +genai-perf profile --model deepseek-ai/DeepSeek-R1-Distill-Llama-70B --tokenizer deepseek-ai/DeepSeek-R1-Distill-Llama-70B --endpoint-type chat --endpoint /v1/chat/completions --streaming --url http://localhost:8000 --input-file payload:mooncake_trace_1_6.jsonl --random-seed 100 -v -H 'Authorization: Bearer NOT USED' -H 'Accept: text/event-stream' -- --stability-percentage 99 +``` \ No newline at end of file diff --git a/docs/dynamo_integration/README_zh.md b/docs/dynamo_integration/README_zh.md new file mode 100644 index 0000000000..651b0d9aef --- /dev/null +++ b/docs/dynamo_integration/README_zh.md @@ -0,0 +1,155 @@ +# FlexKV 与 Dynamo 集成指南 + +该文档展示了如何将FlexKV和NVIDIA [Dynamo](https://github.com/ai-dynamo/dynamo) 框架集成,并完成性能测试的步骤。 + +Dynamo是NVIDIA专为大规模分离式部署而设计的框架,支持TensorRT-LLM, vLLM, SGLang等多个后端引擎。其中KV 路由器(KV Router)是一个智能的请求路由组件, 它能够追踪和管理存储在不同worker上的 KV cache,并根据请求与缓存的重叠程度和worker当前负载,智能地将请求分配给最合适的 GPU 节点,从而减少昂贵的 KV 缓存重新计算,提高推理效率。文档也介绍了如何在开启KV Router时,将FlexKV集成进Dynamo。 + +## 1. 环境准备 + +### Dynamo 镜像 + +该文档使用的是后端为vLLM的Dynamo 0.4.1 镜像,内置了vLLM 0.10.1.1。 + +```bash +docker pull nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.4.1 +``` + +### FlexKV代码准备 + +```bash +git clone https://github.com/taco-project/FlexKV +``` + +### 安装 FlexKV + +```bash +apt update && apt install liburing-dev + +cd FlexKV && ./build.sh +``` + +### vLLM Apply Patch + +```bash +# 进入 vLLM 目录 +cd /opt/vllm +# apply patch +git apply /your/path/to/FlexKV/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +``` + +### FlexKV 验证 + +请参考[vLLM online serving](../../docs/vllm_adapter/README_zh.md#%E7%A4%BA%E4%BE%8B)里的测试脚本。 + + +## 2. 
Dynamo 配置修改
+
+### kv_transfer_config
+
+为了和FlexKV集成,需要修改Dynamo镜像内的kv_transfer_config。将/opt/dynamo/venv/lib/python3.12/site-packages/dynamo/vllm/args.py 的245-248行修改为:
+
+```python
+kv_transfer_config = KVTransferConfig(
+ kv_connector="FlexKVConnectorV1", kv_role="kv_both"
+)
+logger.info("Using FlexKVConnectorV1 configuration")
+```
+
+### CPU Offloading
+
+在Dynamo中,KV router通过接收worker发送的event来更新KV index,从而感知每个worker上的KV cache情况。当FlexKV开启CPU offloading时,我们删掉vLLM里[BlockRemove](https://github.com/vllm-project/vllm/blob/v0.10.1.1/vllm/v1/core/block_pool.py#L221),让FlexKV通过CPU能够缓存住所有serving过程中的KV block,这样KV router维护的index就能反映FlexKV的真实index了。
+
+## 3. 启动和验证Dynamo服务
+
+### 启动Dynamo + FlexKV
+
+```bash
+#!/bin/bash
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+set -e
+trap 'echo Cleaning up...; kill 0' EXIT
+
+# 启动nats和etcd
+nats-server -js &
+
+etcd --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls http://0.0.0.0:2379 --data-dir /tmp/etcd &
+
+sleep 3
+
+# run ingress, 通过--router-mode设置路由方式,可选项为kv, round-robin, random
+python -m dynamo.frontend --router-mode kv --http-port 8000 &
+
+# 定义工作节点数量
+NUM_WORKERS=4
+
+# 多个worker时注意FlexKV的端口应不同,否则会卡在flexkv init这一步
+# 请根据服务器的配置,调整num_cpu_blocks和num_ssd_blocks的数值
+for i in $(seq 0 $((NUM_WORKERS-1))); do
+ cat > ./flexkv_config_${i}.json <<EOF
+{
+ "enable_flexkv": true,
+ "server_recv_port": "ipc:///tmp/flexkv_${i}_test",
+ "cache_config": {
+ "enable_cpu": true,
+ "enable_ssd": false,
+ "enable_remote": false,
+ "use_gds": false,
+ "enable_trace": false,
+ "ssd_cache_iouring_entries": 512,
+ "tokens_per_block": 64,
+ "num_cpu_blocks": 10240,
+ "num_ssd_blocks": 256000,
+ "ssd_cache_dir": "/data/flexkv_ssd/",
+ "evict_ratio": 0.05,
+ "index_accel": true
+ },
+ "num_log_interval_requests": 200
+}
+EOF
+done
+
+# 使用for循环启动工作节点
+for i in $(seq 0 $((NUM_WORKERS-1))); do
+ # 计算GPU设备ID
+ GPU_START=$((i*2))
+
GPU_END=$((i*2+1)) + + if [ $i -lt $((NUM_WORKERS-1)) ]; then + FLEXKV_CONFIG_PATH="./flexkv_config_${i}.json" CUDA_VISIBLE_DEVICES=${GPU_START},${GPU_END} python3 -m dynamo.vllm --model deepseek-ai/DeepSeek-R1-Distill-Llama-70B --tensor_parallel_size 2 --block-size 64 --gpu-memory-utilization 0.9 --max-model-len 100310 & + else + FLEXKV_CONFIG_PATH="./flexkv_config_${i}.json" CUDA_VISIBLE_DEVICES=${GPU_START},${GPU_END} python3 -m dynamo.vllm --model deepseek-ai/DeepSeek-R1-Distill-Llama-70B --tensor_parallel_size 2 --block-size 64 --gpu-memory-utilization 0.9 --max-model-len 100310 + fi +done +``` + +> 注:`flexkv_config.json`配置仅为简单示例,选项请参考[`docs/flexkv_config_reference/README_zh.md`](../../docs/flexkv_config_reference/README_zh.md) + +### 验证 + +可通过如下命令验证Dynamo服务是否正确启动: +```bash +curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{ + "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", + "messages": [ + { + "role": "user", + "content": "Tell me a joke." + } + ], + "stream":false, + "max_tokens": 30 + }' +``` +## 4. 
Benchmark + +我们使用[genai-perf](https://github.com/triton-inference-server/perf_analyzer/tree/main/genai-perf)作为benchmark工具、[mooncake trace](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#-open-source-trace)作为数据集来评估Dynamo + FlexKV的性能。 + +Mooncake Trace 是一个开源请求记录文件,以jsonl格式保存。它记录了请求到达的时间戳、输入文本长度、输出文本长度以及与缓存有关的hash id等信息,包含了1小时内的23608个请求。我们的实验资源是4个LLaMA-70B worker,mooncake trace对于该配置来说并发太高了,于是我们从mooncake trace里每6个抽取1个request,构建了用于benchmark的数据集。 + +genai-perf可以根据trace文件里的时间戳来发送请求,统计LLM服务的TTFT、TPOT等指标,命令如下。请使用genai-perf==0.0.13,更新的版本存在解析时间戳的bug。 + +```bash + genai-perf profile --model deepseek-ai/DeepSeek-R1-Distill-Llama-70B --tokenizer deepseek-ai/DeepSeek-R1-Distill-Llama-70B --endpoint-type chat --endpoint /v1/chat/completions --streaming --url http://localhost:8000 --input-file payload:mooncake_trace_1_6.jsonl --random-seed 100 -v -H 'Authorization: Bearer NOT USED' -H 'Accept: text/event-stream' -- --stability-percentage 99 +``` \ No newline at end of file diff --git a/docs/flexkv_config_reference/README_en.md b/docs/flexkv_config_reference/README_en.md new file mode 100644 index 0000000000..f91ca77ba9 --- /dev/null +++ b/docs/flexkv_config_reference/README_en.md @@ -0,0 +1,147 @@ +# FlexKV Configuration Guide + +This guide explains how to configure and use the FlexKV online serving configuration file (`flexkv_config.json`), including the meaning of all parameters, recommended values, and typical usage scenarios. 
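
Since FlexKV reads this file as strict JSON, a quick sanity check before launching the server catches common hand-editing mistakes (trailing commas, missing keys) early. A minimal sketch — the key names follow this reference, while the exact validation FlexKV itself performs may differ:

```python
import json

# Minimal sanity check for a hand-written flexkv_config.json.
# Strict JSON parsing rejects trailing commas and comments,
# two common hand-editing mistakes.
raw = """
{
  "enable_flexkv": true,
  "server_recv_port": "ipc:///tmp/flexkv_test",
  "cache_config": {
    "enable_cpu": true,
    "tokens_per_block": 64,
    "num_cpu_blocks": 10240
  },
  "num_log_interval_requests": 200
}
"""
cfg = json.loads(raw)  # raises json.JSONDecodeError on syntax errors

assert cfg["enable_flexkv"] is True
assert cfg["server_recv_port"].startswith("ipc://")
# tokens_per_block must match the block_size used by the serving engine
assert cfg["cache_config"]["tokens_per_block"] == 64
print("config OK")
```

In practice, point the same check at the file named by `FLEXKV_CONFIG_PATH` before starting the service.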
+ +--- + +## Recommended Configuration + +Below is a production-grade recommended configuration that balances performance and stability: + +```json +{ + "enable_flexkv": true, + "server_recv_port": "ipc:///tmp/flexkv_test", + "cache_config": { + "enable_cpu": true, + "enable_ssd": true, + "enable_remote": false, + "use_gds": false, + "enable_trace": false, + "ssd_cache_iouring_entries": 512, + "tokens_per_block": 64, + "num_cpu_blocks": 233000, + "num_ssd_blocks": 4096000, + "ssd_cache_dir": "/data/flexkv_ssd/", + "evict_ratio": 0.05, + "index_accel": true + }, + "num_log_interval_requests": 2000 +} +``` +- `num_cpu_blocks` and `num_ssd_blocks` represent the total number of blocks in CPU memory and SSD respectively. These values must be configured according to your machine specs and model size. See [Cache Capacity Configuration](#cache-capacity-config) for calculation details. +- `ssd_cache_dir` specifies the directory where SSD-stored KV cache files are saved. + +--- + +## Configuration File Structure Overview + +The FlexKV configuration file is a JSON file, primarily consisting of three parts: + +- `enable_flexkv`: Whether to enable FlexKV (must be set to `true` to take effect). +- `server_recv_port`: The IPC port on which the FlexKV service listens. +- `cache_config`: The core cache configuration object, containing all cache behavior parameters. +- `num_log_interval_requests`: Log statistics interval (outputs performance log every N requests). + +--- + +## Complete `cache_config` Parameter Reference (from [`flexkv/common/config.py`](../../flexkv/common/config.py)) + +### Basic Configuration + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `tokens_per_block` | int | 16 | Number of tokens per KV block. Must match the `block_size` used in the acceleration framework (e.g., vLLM). | +| `enable_cpu` | bool | true | Whether to enable CPU memory as a cache layer. Strongly recommended to enable. 
| +| `enable_ssd` | bool | false | Whether to enable SSD as a cache layer. Recommended if NVMe SSD is available. | +| `enable_remote` | bool | false | Whether to enable remote cache (e.g., scalable cloud storage). Requires remote cache engine and custom implementation. | +| `use_gds` | bool | false | Whether to use GPU Direct Storage (GDS) to accelerate SSD I/O. Not currently supported. | +| `index_accel` | bool | false | Whether to enable C++ RadixTree. Recommended to enable. | + +--- + +### KV Cache Layout Types (Generally No Need to Modify) + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `gpu_kv_layout_type` | enum | LAYERWISE | Organization of KV cache on GPU (layer-wise or block-wise). Must match vLLM’s layout (currently `LAYERWISE`). | +| `cpu_kv_layout_type` | enum | BLOCKWISE | Organization on CPU. Recommended to use `BLOCKWISE`. Does not need to match vLLM. | +| `ssd_kv_layout_type` | enum | BLOCKWISE | Organization on SSD. Recommended to use `BLOCKWISE`. Does not need to match vLLM. | +| `remote_kv_layout_type` | enum | BLOCKWISE | Organization for remote cache. Must be defined according to remote backend’s layout. | + +> Note: Do not modify layout types unless you have specific performance requirements. + +--- + +### Cache Capacity Configuration + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `num_cpu_blocks` | int | 1000000 | Number of blocks allocated in CPU memory. Adjust based on available RAM. | +| `num_ssd_blocks` | int | 10000000 | Number of blocks allocated on SSD. | +| `num_remote_blocks` | int \| None | None | Number of blocks allocated in remote cache. | + +> Note: Block size in all cache levels (CPU/SSD/Remote) matches the GPU block size. Estimate cache capacities based on GPU KV cache memory usage and block count. + +> Note: `block_size = num_layer * _kv_dim * tokens_per_block * num_head * head_size * dtype_size`. 
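
To make the capacity estimate concrete, here is a sketch of the arithmetic for a LLaMA-70B-style model (80 layers, 8 KV heads under GQA, head_size 128, fp16/bf16); the model figures are illustrative assumptions, and `kv_dim = 2` accounts for the separate K and V tensors:

```python
# Back-of-the-envelope KV block sizing; the model shape below
# (LLaMA-70B-style GQA) is an illustrative assumption.
num_layers       = 80
kv_dim           = 2     # K and V each stored once per token
tokens_per_block = 64    # must match vLLM's block_size
num_kv_heads     = 8
head_size        = 128
dtype_size       = 2     # bytes per element for fp16/bf16

block_size = (num_layers * kv_dim * tokens_per_block
              * num_kv_heads * head_size * dtype_size)
print(f"block_size = {block_size / 2**20:.1f} MiB")  # 20.0 MiB

# How many blocks fit in a 200 GiB host-memory budget:
num_cpu_blocks = 200 * 2**30 // block_size
print(f"num_cpu_blocks = {num_cpu_blocks}")  # 10240
```

With these assumed numbers, a 200 GiB CPU budget yields `num_cpu_blocks = 10240`, matching the value used in the Dynamo integration example.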
+ +--- + +### CPU-GPU Transfer Optimization + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `use_ce_transfer_h2d` | bool | false | Whether to use CUDA Copy Engine for Host→Device transfers. Reduces SM usage but may slightly reduce bandwidth. Real-world difference is minimal. | +| `use_ce_transfer_d2h` | bool | false | Whether to use CUDA Copy Engine for Device→Host transfers. | +| `transfer_sms_h2d` | int | 8 | Number of SMs (Streaming Multiprocessors) allocated for H2D transfers. | +| `transfer_sms_d2h` | int | 8 | Number of SMs allocated for D2H transfers. | + +--- + +### SSD Cache Configuration + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `max_blocks_per_file` | int | 32000 | Maximum number of blocks per SSD file. `-1` means unlimited. | +| `ssd_cache_dir` | str \| List[str] | None | **Required.** Path to SSD cache directory, e.g., `"/data/flexkv_ssd/"`. | +| `ssd_cache_iouring_entries` | int | 0 | io_uring queue depth. Recommended: `512` for significantly improved concurrent I/O performance. | +| `ssd_cache_iouring_flags` | int | 0 | io_uring flags. Keep as `0` in most cases. | + +> Note: To maximize bandwidth across multiple SSDs, bind each SSD to a separate directory and specify them as a list: +> `"ssd_cache_dir": ["/data0/flexkv_ssd/", "/data1/flexkv_ssd/"]`. +> KV blocks will be evenly distributed across all SSDs. + +> Note: Setting `ssd_cache_iouring_entries` to `0` disables io_uring. Not recommended. + +--- + +### Remote Cache Configuration (Skip if not enabled) + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `remote_cache_size_mode` | str | "file_size" | Allocate remote cache space by file size or block count. | +| `remote_file_size` | int \| None | None | Size (in bytes) of each remote file. | +| `remote_file_num` | int \| None | None | Number of remote files. 
| +| `remote_file_prefix` | str \| None | None | Prefix for remote file names. | +| `remote_cache_path` | str \| List[str] | None | Remote cache path (e.g., Redis URL, S3 path). | +| `remote_config_custom` | dict \| None | None | Custom remote cache configurations (e.g., timeout, authentication). | + +--- + +### Tracing and Logging + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `enable_trace` | bool | true | Whether to enable performance tracing. Disable (`false`) in production to reduce overhead. | +| `trace_file_path` | str | "./flexkv_trace.log" | Path to trace log file. | +| `trace_max_file_size_mb` | int | 100 | Maximum size (MB) per trace log file. | +| `trace_max_files` | int | 5 | Maximum number of trace log files to retain. | +| `trace_flush_interval_ms` | int | 1000 | Trace log flush interval (milliseconds). | + +--- + +### Cache Eviction Policy + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `evict_ratio` | float | 0.0 | Ratio of blocks to proactively evict from CPU/SSD per eviction cycle. `0.0` = evict only the minimal necessary blocks (more eviction cycles may impact performance). Recommended: `0.05` (evict 5% of least recently used blocks per cycle). 
| \ No newline at end of file diff --git a/docs/flexkv_config_reference/README_zh.md b/docs/flexkv_config_reference/README_zh.md new file mode 100644 index 0000000000..1752f844bf --- /dev/null +++ b/docs/flexkv_config_reference/README_zh.md @@ -0,0 +1,145 @@ +# FlexKV 配置使用指南 + +本指南详细说明如何配置和使用 FlexKV 的在线服务配置文件(`flexkv_config.json`),涵盖所有参数的含义、推荐值及典型使用场景。 + +--- + +## 推荐配置方案 + +以下是一个兼顾性能与稳定性的生产级推荐配置: + +```json +{ + "enable_flexkv": true, + "server_recv_port": "ipc:///tmp/flexkv_test", + "cache_config": { + "enable_cpu": true, + "enable_ssd": true, + "enable_remote": false, + "use_gds": false, + "enable_trace": false, + "ssd_cache_iouring_entries": 512, + "tokens_per_block": 64, + "num_cpu_blocks": 233000, + "num_ssd_blocks": 4096000, + "ssd_cache_dir": "/data/flexkv_ssd/", + "evict_ratio": 0.05, + "index_accel": true + }, + "num_log_interval_requests": 2000 +} +``` +- 其中的`num_cpu_blocks`和`num_ssd_blocks`分别代表内存和SSD中block的总数量,需要根据实际机器配置和模型来配置,具体计算方式见下文[缓存容量配置](#cache-capacity-config) +- `ssd_cache_dir`为ssd中KVCache存放的文件目录 + +--- + +## 配置文件结构概览 + +FlexKV 的配置文件是一个 JSON 文件,主要包含三个部分: + +- `enable_flexkv`: 是否启用 FlexKV 功能(必须设为 `true` 才生效) +- `server_recv_port`: FlexKV 服务监听的 IPC 端口 +- `cache_config`: 核心缓存配置对象,包含所有缓存行为参数 +- `num_log_interval_requests`: 日志统计间隔(每处理 N 个请求输出一次性能日志) + +--- + +## cache_config完整参数详解(来自 [`flexkv/common/config.py`](../../flexkv/common/config.py)) + +### 基础配置 + +| 参数名 | 类型 | 默认值 | 说明 | +|--------|------|--------|------| +| `tokens_per_block` | int | 16 | 每个 KV Block 包含的 token 数量。需要与加速框架(如vLLM)中`block_size`保持一致 | +| `enable_cpu` | bool | true | 是否启用 CPU 内存作为缓存层。强烈建议开启。 | +| `enable_ssd` | bool | false | 是否启用 SSD 作为缓存层。如配备 NVMe SSD,建议开启。 | +| `enable_remote` | bool | false | 是否启用远程缓存(如可扩展云存储等)。需要配合远程缓存和自定义的远程缓存引擎使用 | +| `use_gds` | bool | false | 是否使用 GPU Direct Storage(GDS)加速 SSD 读写。目前暂不支持。 | +| `index_accel` | bool | false | 是否启用C++ RadixTree。推荐开启。 | + +--- + +### KV 缓存布局类型(一般无需修改) + +| 参数名 | 类型 | 默认值 | 说明 | +|--------|------|--------|------| +| 
`gpu_kv_layout_type` | enum | LAYERWISE | GPU 上 KV Cache 的组织方式(按层或按块)。目前vLLM在GPU组织方式为`LAYERWISE`,因此FlexKV的`gpu_kv_layout_type`须与vLLM保持一致 |
+| `cpu_kv_layout_type` | enum | BLOCKWISE | CPU 上按块组织, 推荐使用`BLOCKWISE`,不需要与vLLM保持一致 |
+| `ssd_kv_layout_type` | enum | BLOCKWISE | SSD 上按块组织, 推荐使用`BLOCKWISE`,不需要与vLLM保持一致 |
+| `remote_kv_layout_type` | enum | BLOCKWISE | 远程缓存按块组织, 需要按照remote组织形式定义 |
+
+> 注:除非有特殊性能需求,否则不建议修改布局类型。
+
+---
+
+### 缓存容量配置
+
+| 参数名 | 类型 | 默认值 | 说明 |
+|--------|------|--------|------|
+| `num_cpu_blocks` | int | 1000000 | CPU 缓存块数。根据内存大小调整。|
+| `num_ssd_blocks` | int | 10000000 | SSD 缓存块数。|
+| `num_remote_blocks` | int \| None | None | 远程缓存块数。|
+
+> 注:FlexKV里的各级缓存的block大小与GPU中的block大小保持一致,可以参考GPU的KVCache显存大小与block数量估算各级缓存中的block数量。
+
+> 注:block_size = num_layer * kv_dim * tokens_per_block * num_head * head_size * dtype_size,其中 dtype_size 为每个元素的字节数。
+
+---
+
+### CPU-GPU 传输优化
+
+| 参数名 | 类型 | 默认值 | 说明 |
+|--------|------|--------|------|
+| `use_ce_transfer_h2d` | bool | false | 是否使用 cuda copy engine 优化 Host→Device 传输,使用CE可以减少GPU SM在传输上的使用,但是传输速度会降低,实际测试差距不大 |
+| `use_ce_transfer_d2h` | bool | false | 是否使用 cuda copy engine 优化 Device→Host 传输 |
+| `transfer_sms_h2d` | int | 8 | H2D 传输使用的流处理器数量 |
+| `transfer_sms_d2h` | int | 8 | D2H 传输使用的流处理器数量 |
+
+---
+
+### SSD 缓存配置
+
+| 参数名 | 类型 | 默认值 | 说明 |
+|--------|------|--------|------|
+| `max_blocks_per_file` | int | 32000 | 单个 SSD 文件最多包含的 block 数。-1 表示无限制 |
+| `ssd_cache_dir` | str \| List[str] | None | SSD 缓存目录路径,**必须设置**,如 `"/data/flexkv_ssd/"` |
+| `ssd_cache_iouring_entries` | int | 0 | io_uring 队列深度,推荐设为 `512`,实测相比不使用 io_uring 并发 IO 性能提升明显 |
+| `ssd_cache_iouring_flags` | int | 0 | io_uring 标志位,一般保持 0 |
+
+> 注:为了充分利用多块SSD的带宽上限,可以将多块SSD绑定至不同目录,并使用如 `"ssd_cache_dir": ["/data0/flexkv_ssd/", "/data1/flexkv_ssd/"]`方式初始化,SSD KVCache会均匀分布在所有SSD中,充分利用多个SSD带宽。
+
+> 注:`ssd_cache_iouring_entries`设置为0即不使用io_uring,不推荐设置为0。
+
+---
+
+### 远程缓存配置(不启用时无需配置)
+
+| 参数名 | 类型 | 默认值 | 说明 |
+|--------|------|--------|------|
+|
`remote_cache_size_mode` | str | "file_size" | 按文件大小或块数分配远程缓存空间 | +| `remote_file_size` | int \| None | None | 单个远程文件大小(字节) | +| `remote_file_num` | int \| None | None | 远程文件数量 | +| `remote_file_prefix` | str \| None | None | 远程文件名前缀 | +| `remote_cache_path` | str \| List[str] | None | 远程缓存路径(如 Redis URL、S3 路径等) | +| `remote_config_custom` | dict \| None | None | 自定义远程缓存配置(如超时、认证等) | + +--- + +### 追踪与日志 + +| 参数名 | 类型 | 默认值 | 说明 | +|--------|------|--------|------| +| `enable_trace` | bool | true | 是否启用性能追踪。生产环境建议关闭(`false`)以减少开销 | +| `trace_file_path` | str | "./flexkv_trace.log" | 追踪日志路径 | +| `trace_max_file_size_mb` | int | 100 | 单个追踪文件最大大小(MB) | +| `trace_max_files` | int | 5 | 最多保留的追踪文件数 | +| `trace_flush_interval_ms` | int | 1000 | 追踪日志刷新间隔(毫秒) | + +--- + +### 缓存淘汰策略 + +| 参数名 | 类型 | 默认值 | 说明 | +|--------|------|--------|------| +| `evict_ratio` | float | 0.0 | cpu,ssd一次evict主动淘汰比例(0.0 = 只淘汰最小的必要的block数量,较多的淘汰次数会影响性能)。建议保持 `0.05`,即每一次淘汰5%的最久未使用的block | diff --git a/docs/vllm_adapter/README_en.md b/docs/vllm_adapter/README_en.md new file mode 100644 index 0000000000..972cade803 --- /dev/null +++ b/docs/vllm_adapter/README_en.md @@ -0,0 +1,86 @@ +# Using FlexKV in vLLM + +## Current Version vs. Legacy Version +In commit [`0290841dce65ae9b036a23d733cf94e47e814934`](https://github.com/taco-project/FlexKV/commit/0290841dce65ae9b036a23d733cf94e47e814934), we introduced a major update: +**FlexKV has transitioned from a client-server architecture to a library function that inference acceleration engines (such as vLLM) can directly invoke**, reducing inter-process communication overhead. + +This change involves significant API adjustments. Therefore, please note: + +- **Version >= `1.0.0`**: Use the **current version API**; the vLLM patch is located in `examples/vllm_adaption/`. +- **Version == `0.1.0`**: Supports the **legacy version API**; the vLLM patch is located in `examples/vllm_adaption_legacy/`. 
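
The version-to-patch-directory mapping above can be expressed as a small helper; this is an illustrative sketch (the function name is ours, not part of the FlexKV API):

```python
# Illustrative helper (not part of FlexKV): pick the vLLM patch directory
# matching an installed FlexKV version string.
def patch_dir(flexkv_version: str) -> str:
    major = int(flexkv_version.split(".")[0])
    if major >= 1:
        return "examples/vllm_adaption/"      # current library-mode API
    return "examples/vllm_adaption_legacy/"   # legacy client-server API

print(patch_dir("1.0.0"))  # examples/vllm_adaption/
print(patch_dir("0.1.0"))  # examples/vllm_adaption_legacy/
```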
+
+---
+
+## Current Version (>= 1.0.0)
+
+### Supported Versions
+- FlexKV >= `1.0.0`
+- vLLM versions >= `0.8.5` can generally follow this version for adaptation
+
+### Example
+We provide an adaptation example based on **vLLM 0.10.1.1**:
+
+1. apply patch
+```bash
+# FLEXKV_DIR/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch
+git apply examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch
+```
+
+2. offline test
+```bash
+# VLLM_DIR/examples/offline_inference/prefix_caching_flexkv.py
+python examples/offline_inference/prefix_caching_flexkv.py
+```
+
+3. online serving
+```bash
+# generate config
+cat > ./flexkv_config.json <<EOF
+{
+ "server_recv_port": "ipc:///tmp/flexkv_test",
+ "cache_config": {
+ "enable_cpu": true,
+ "num_cpu_blocks": 10240
+ },
+ "num_log_interval_requests": 200
+}
+EOF
+export FLEXKV_CONFIG_PATH="./flexkv_config.json"
+
+VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \
+ --tensor-parallel-size 8 \
+ --trust-remote-code \
+ --port 30001 \
+ --max-num-seqs 128 \
+ --max-num-batched-tokens 8192 \
+ --max-model-len 8192 \
+ --max-seq-len-to-capture 8192 \
+ --gpu-memory-utilization 0.8 \
+ --enable-chunked-prefill \
+ --enable-prefix-caching \
+ --kv-transfer-config \
+ '{"kv_connector":"FlexKVConnectorV1","kv_role":"kv_both"}'
+
+```
+
+> Note: The `flexkv_config.json` configuration is provided as a simple example only.
For full parameter options, please refer to [`docs/flexkv_config_reference/README_en.md`](../../docs/flexkv_config_reference/README_en.md) + +## Legacy Version (<= 0.1.0) – Not Recommended for Current Use + +### Supported Versions +- FlexKV <= `0.1.0` + +### Example +Apply the patch `examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch` to vLLM 0.8.4, then start FlexKV, vLLM, and the benchmark script: + +```bash +# Start FlexKV as server +bash benchmarks/flexkv_benchmark/run_flexkv_server.sh + +# Start vLLM as client +bash benchmarks/flexkv_benchmark/serving_vllm.sh + +# Start benchmark +bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh +``` +Apply the patch `examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch` to vLLM 0.10.0, and use the same testing method as above. diff --git a/docs/vllm_adapter/README_zh.md b/docs/vllm_adapter/README_zh.md new file mode 100644 index 0000000000..bb9b51c292 --- /dev/null +++ b/docs/vllm_adapter/README_zh.md @@ -0,0 +1,85 @@ +# 在 vLLM 中使用 FlexKV + +## 当前版本与 Legacy 版本说明 +在 commit [`0290841dce65ae9b036a23d733cf94e47e814934`](https://github.com/taco-project/FlexKV/commit/0290841dce65ae9b036a23d733cf94e47e814934),我们更新了一个重要功能: + **FlexKV 从 client-server 模式,变为推理加速引擎(如 vLLM)可直接调用的库函数**,以减少进程间消息传递的开销。 +这一变更引发了较大的 API 调整。因此,请注意: + +- **版本 >= `1.0.0`**:应使用 **当前版本 API**,vLLM patch位于 `examples/vllm_adaption/`。 +- **版本 == `0.1.0`**:仅支持 **Legacy 版本 API**, vLLM patch位于`examples/vllm_adaption_legacy/`。 + +--- + +## 当前版本(>= 1.0.0) + +### 适用版本 +- FlexKV >= `1.0.0` +- vLLM 原则上>= `0.8.5`版本均可参考示例代码进行修改 + +### 示例 +我们提供了基于 **vLLM 0.10.1.1** 的适配示例: + +1. apply patch +```bash +# FLEXKV_DIR/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +git apply examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +``` + +2. offline test +```bash +# VLLM_DIR/examples/offline_inference/prefix_caching_flexkv.py +python examples/offline_inference/prefix_caching_flexkv.py +``` + +3. 
online serving
+```bash
+# generate config
+cat > ./flexkv_config.json <<EOF
+{
+ "server_recv_port": "ipc:///tmp/flexkv_test",
+ "cache_config": {
+ "enable_cpu": true,
+ "num_cpu_blocks": 10240
+ },
+ "num_log_interval_requests": 200
+}
+EOF
+export FLEXKV_CONFIG_PATH="./flexkv_config.json"
+
+VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \
+ --tensor-parallel-size 8 \
+ --trust-remote-code \
+ --port 30001 \
+ --max-num-seqs 128 \
+ --max-num-batched-tokens 8192 \
+ --max-model-len 8192 \
+ --max-seq-len-to-capture 8192 \
+ --gpu-memory-utilization 0.8 \
+ --enable-chunked-prefill \
+ --enable-prefix-caching \
+ --kv-transfer-config \
+ '{"kv_connector":"FlexKVConnectorV1","kv_role":"kv_both"}'
+
+```
+
+> 注:`flexkv_config.json`配置仅为简单示例,选项请参考[`docs/flexkv_config_reference/README_zh.md`](../../docs/flexkv_config_reference/README_zh.md)
+
+## Legacy版本(<= 0.1.0),目前的版本尽量不要使用
+
+### 适用版本
+- FlexKV <= `0.1.0`
+
+### 示例
+在 vLLM 0.8.4 版本中应用patch `examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch`,分别启动 FlexKV、vLLM 和测试脚本:
+
+```bash
+# 启动 FlexKV 作为服务端
+bash benchmarks/flexkv_benchmark/run_flexkv_server.sh
+
+# 启动 vLLM 作为客户端
+bash benchmarks/flexkv_benchmark/serving_vllm.sh
+
+# 启动性能测试
+bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh
+```
+在 vLLM 0.10.0 版本中应用patch `examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch`,测试方法同上。
diff --git a/examples/run_server.py b/examples/run_server.py
index d5b6a182ec..48b24ecad1 100644
--- a/examples/run_server.py
+++ b/examples/run_server.py
@@ -12,16 +12,16 @@ def parse_args() -> argparse.Namespace:
 parser = argparse.ArgumentParser()
-
+
 # NAME
- parser.add_argument("--enable-cpu",
- action=argparse.BooleanOptionalAction,
+ parser.add_argument("--enable-cpu",
+ action=argparse.BooleanOptionalAction,
 default=True)
- parser.add_argument("--enable-ssd",
- action=argparse.BooleanOptionalAction,
+ parser.add_argument("--enable-ssd",
+ action=argparse.BooleanOptionalAction,
 default=False,)
-
parser.add_argument("--enable-remote", - action=argparse.BooleanOptionalAction, + parser.add_argument("--enable-remote", + action=argparse.BooleanOptionalAction, default=False,) parser.add_argument("--model-path", type=str, help="model path", default="") parser.add_argument("--tp-size", type=int, default=1) @@ -54,7 +54,7 @@ def parse_args() -> argparse.Namespace: if __name__ == "__main__": args = parse_args() hf_config = AutoConfig.from_pretrained(args.model_path) - + num_layers=hf_config.num_hidden_layers if hasattr(hf_config, 'num_key_value_heads'): num_kv_heads=hf_config.num_key_value_heads @@ -65,7 +65,7 @@ def parse_args() -> argparse.Namespace: head_size=(hf_config.head_dim if hasattr(hf_config, 'head_dim') else hf_config.hidden_size//hf_config.num_attention_heads) use_mla=hf_config.architectures[0].startswith("Deepseek") - + # TODO: different model config may have different attribute name model_config = ModelConfig( num_layers=num_layers, @@ -76,14 +76,13 @@ def parse_args() -> argparse.Namespace: dp_size=args.dp_size, dtype=hf_config.torch_dtype ) - + cache_config = CacheConfig( enable_cpu=args.enable_cpu, enable_ssd=args.enable_ssd, enable_remote=args.enable_remote, use_gds=False, enable_trace=False, - use_pinned_memory=False, ssd_cache_iouring_entries=512, tokens_per_block=args.block_size, num_cpu_blocks=args.num_cpu_blocks, @@ -93,6 +92,6 @@ def parse_args() -> argparse.Namespace: remote_cache_size_mode=args.remote_cache_size_mode, remote_cache_path=args.remote_cache_path, ) - + kvserver = KVServer(model_config, cache_config, args.server_recv_port) - kvserver.run() \ No newline at end of file + kvserver.run() diff --git a/examples/scheduler_server_example.py b/examples/scheduler_server_example.py index 29826afc9a..059cc467aa 100644 --- a/examples/scheduler_server_example.py +++ b/examples/scheduler_server_example.py @@ -16,9 +16,9 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, model_config, gpu_kv_layout): """Run TP 
client process""" from flexkv.server.client import KVTPClient - + print(f"Starting TP client: dp_client_id={dp_client_id}, tp_rank={tp_rank}, device_id={device_id}") - + try: # Set CUDA device for this process if torch.cuda.is_available(): @@ -27,7 +27,7 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, mo torch.cuda.init() # Clear cache torch.cuda.empty_cache() - + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) # Create GPU blocks for this TP client @@ -51,7 +51,7 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, mo # Keep TP client running while True: time.sleep(1) - + except Exception as e: print(f"TP client {tp_rank} error: {e}") import traceback @@ -84,7 +84,6 @@ def main(): enable_ssd=False, enable_remote=False, use_gds=False, - use_pinned_memory=True, tokens_per_block=tokens_per_block, num_cpu_blocks=num_cpu_blocks, ) @@ -106,14 +105,14 @@ def main(): cache_config=cache_config, server_recv_port="ipc:///tmp/scheduler_server_example" # TPClient connects to this port ) - + # Start background server thread to handle TPClient registration scheduler_server.start_server_thread() - - print(f"SchedulerServer started!") + + print("SchedulerServer started!") print(f"TPClient can connect to: {scheduler_server.get_server_port()}") print("Starting TP client processes...") - + # Start TP client processes tp_client_processes = [] for tp_rank in range(tp_size): @@ -123,7 +122,7 @@ def main(): if device_id >= available_gpus: device_id = device_id % available_gpus print(f"Warning: Using GPU {device_id} for TP rank {tp_rank} (not enough GPUs)") - + tp_client_process = Process( target=run_tp_client_process, args=(0, tp_rank, device_id, scheduler_server.get_server_port(), model_config, gpu_kv_layout), @@ -134,32 +133,32 @@ def main(): print(f"Started TP client process for rank {tp_rank} on device {device_id}") print("Waiting for all TP clients to register...") - + time.sleep(5) - + # Now we can 
directly use scheduler_server without network communication # Example: Create some test data (following benchmark_kvmanager.py pattern) batch_size = 4 seq_len = 128 - + print("\n=== Generating test data ===") # Generate separate sequences for each request (correct approach) batch_token_ids = [] batch_slot_mappings = [] batch_token_masks = [] - + for i in range(batch_size): # Each sequence is independent (seq_len,) shape token_ids = torch.randint(0, 1000, (seq_len,)) slot_mapping = torch.arange(i * seq_len, (i + 1) * seq_len) token_mask = torch.ones(seq_len, dtype=torch.bool) - + batch_token_ids.append(token_ids) batch_slot_mappings.append(slot_mapping) batch_token_masks.append(token_mask) - + print(f"Generated {batch_size} sequences, each with {seq_len} tokens") - + print("\n=== Executing PUT Operations ===") # PUT operations - each sequence processed separately start_time = time.time() @@ -173,7 +172,7 @@ def main(): if task_id: put_task_ids.append(task_id) print(f"PUT task {task_id} created for sequence {i}") - + put_time = (time.time() - start_time) * 1000 print(f"Created {len(put_task_ids)} PUT tasks, time: {put_time:.2f}ms") time.sleep(2) @@ -190,10 +189,10 @@ def main(): if task_id: get_task_ids.append(task_id) print(f"GET task {task_id} created for sequence {i}") - + get_time = (time.time() - start_time) * 1000 print(f"Created {len(get_task_ids)} GET tasks, time: {get_time:.2f}ms") - + print("\n=== Waiting for All Tasks to Complete ===") # Wait for all tasks to complete - can wait for multiple tasks at once all_task_ids = put_task_ids + get_task_ids @@ -202,7 +201,7 @@ def main(): masks = scheduler_server.wait(all_task_ids) wait_time = (time.time() - start_time) * 1000 print(f"All {len(all_task_ids)} tasks completed, time: {wait_time:.2f}ms") - + # Analyze results if masks: total_tokens = 0 @@ -211,7 +210,7 @@ def main(): tokens = mask.sum().item() if hasattr(mask, 'sum') else len(mask) total_tokens += tokens print(f"Task {task_id}: {tokens} tokens 
processed") - + print("\n=== Trying Non-blocking Wait ===") # Create a few more tasks and try non-blocking wait extra_task_ids = [] @@ -223,7 +222,7 @@ def main(): ) if task_id: extra_task_ids.append(task_id) - + if extra_task_ids: # Immediately try to wait (might not be completed yet) masks = scheduler_server.try_wait(extra_task_ids) @@ -233,15 +232,15 @@ def main(): print(f"Tasks {extra_task_ids} not ready yet, will wait...") masks = scheduler_server.wait(extra_task_ids) print(f"Tasks {extra_task_ids} completed after wait") - + print("\n✅ All operations completed successfully!") - - + + # Clean up resources print("\n=== Shutting down SchedulerServer ===") scheduler_server.shutdown() print("SchedulerServer has been shut down") - + # Terminate TP client processes print("Terminating TP client processes...") for i, process in enumerate(tp_client_processes): @@ -253,4 +252,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/flexkv/integration/vllm/0001-add-flexkv-connector.patch b/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch similarity index 93% rename from flexkv/integration/vllm/0001-add-flexkv-connector.patch rename to examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch index fc0a558d03..812a1d6e2f 100644 --- a/flexkv/integration/vllm/0001-add-flexkv-connector.patch +++ b/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch @@ -1,24 +1,9 @@ -From a434b67b8097990f20d8c020a8c713b10dd3d5b0 Mon Sep 17 00:00:00 2001 -From: zuogan -Date: Wed, 3 Sep 2025 05:11:50 -0700 -Subject: [PATCH] add flexkv connector - ---- - .../prefix_caching_flexkv.py | 163 +++++++++++++++ - .../kv_transfer/kv_connector/factory.py | 5 + - .../kv_connector/v1/flexkv_connector.py | 191 ++++++++++++++++++ - vllm/v1/core/sched/scheduler.py | 13 +- - .../worker/kv_connector_model_runner_mixin.py | 6 +- - 5 files changed, 373 insertions(+), 5 deletions(-) - create mode 100644 
examples/offline_inference/prefix_caching_flexkv.py - create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/flexkv_connector.py - diff --git a/examples/offline_inference/prefix_caching_flexkv.py b/examples/offline_inference/prefix_caching_flexkv.py new file mode 100644 -index 000000000..4cfe2ef7f +index 000000000..a57328ffd --- /dev/null +++ b/examples/offline_inference/prefix_caching_flexkv.py -@@ -0,0 +1,163 @@ +@@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +import time @@ -36,7 +21,6 @@ index 000000000..4cfe2ef7f + "cache_config": { + "enable_cpu": True, + "num_cpu_blocks": 10240, -+ "use_pinned_memory": True + }, + "num_log_interval_requests": 200 +} @@ -84,7 +68,7 @@ index 000000000..4cfe2ef7f + +def main(): + # Create an LLM without prefix caching as a baseline. -+ regular_llm = LLM(model=model_path, ++ regular_llm = LLM(model=model_path, + enable_prefix_caching=False, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tp_size @@ -114,7 +98,7 @@ index 000000000..4cfe2ef7f + # return + + # Create an LLM with prefix caching enabled. -+ prefix_cached_llm = LLM(model=model_path, ++ prefix_cached_llm = LLM(model=model_path, + enable_prefix_caching=True, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tp_size, @@ -124,7 +108,7 @@ index 000000000..4cfe2ef7f + # Warmup so that the shared prompt's KV cache is computed. + prefix_cached_llm.generate(generating_prompts[0], sampling_params) + -+ # wait for offload kv task finished. ++ # wait for offload kv task finished. + time.sleep(2) + + # Generate with prefix caching. @@ -149,7 +133,7 @@ index 000000000..4cfe2ef7f + ]) + print(f"Generated answers are the same: {generated_same}") + -+ # wait for offload kv task finished. ++ # wait for offload kv task finished. 
+ time.sleep(2) + + # reset prefix cache to use flexkv @@ -249,9 +233,9 @@ index 000000000..bdfa9f321 + **kwargs: additional arguments for the load operation + + Note: -+ The number of elements in kv_caches and layer_names should be ++ The number of elements in kv_caches and layer_names should be + the same. -+ ++ + """ + self._flexkv_connector.start_load_kv(forward_context, **kwargs) + @@ -260,7 +244,7 @@ index 000000000..bdfa9f321 + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. -+ ++ + This interface will be useful for layer-by-layer pipelining. + + Args: @@ -271,13 +255,13 @@ index 000000000..bdfa9f321 + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ -+ Start saving the a layer of KV cache from vLLM's paged buffer ++ Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. -+ kv_layer (torch.Tensor): the paged KV buffer of the current ++ kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. @@ -310,7 +294,7 @@ index 000000000..bdfa9f321 + call to this method (this call or a prior one). + """ + return self._flexkv_connector.get_finished(finished_req_ids) -+ ++ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the @@ -332,14 +316,14 @@ index 000000000..bdfa9f321 + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. -+ ++ + Args: + request (Request): the request object. 
+ num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: -+ the number of tokens that can be loaded from the ++ the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + return self._flexkv_connector.get_num_new_matched_tokens( @@ -398,30 +382,30 @@ index 981023409..a6c8fac38 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -118,6 +118,7 @@ class Scheduler(SchedulerInterface): - + # KV Connector: requests in process of async KV loading or recving self.finished_recving_kv_req_ids: set[str] = set() + self.sending_kv_reqs: dict[str, Request] = {} - + # Encoder-related. # Calculate encoder cache size if applicable @@ -1029,7 +1030,8 @@ class Scheduler(SchedulerInterface): - + if not delay_free_blocks: self._free_blocks(request) - + else: + self.sending_kv_reqs[request.request_id] = request return kv_xfer_params - + def _free_blocks(self, request: Request): @@ -1041,7 +1043,7 @@ class Scheduler(SchedulerInterface): return len(self.waiting) + len(self.running) - + def has_finished_requests(self) -> bool: - return len(self.finished_req_ids) > 0 + return len(self.finished_req_ids) > 0 or len(self.sending_kv_reqs) > 0 - + def reset_prefix_cache(self) -> bool: return self.kv_cache_manager.reset_prefix_cache() @@ -1082,6 +1084,8 @@ class Scheduler(SchedulerInterface): @@ -430,20 +414,20 @@ index 981023409..a6c8fac38 100644 self.kv_event_publisher.shutdown() + if self.connector and hasattr(self.connector, "shutdown"): + self.connector.shutdown() - + ######################################################################## # KV Connector Related Methods @@ -1149,6 +1153,10 @@ class Scheduler(SchedulerInterface): scheduler the request during the next step. 
""" - + + # avoid busy checking + if len(self.running) == 0: + time.sleep(0.01) + if self.connector is not None: self.connector.update_connector_output(kv_connector_output) - + @@ -1158,4 +1166,5 @@ class Scheduler(SchedulerInterface): self.finished_recving_kv_req_ids.add(req_id) for req_id in (kv_connector_output.finished_sending or ()): @@ -457,16 +441,13 @@ index a03ebe35d..8e4460957 100644 @@ -66,9 +66,9 @@ class KVConnectorModelRunnerMixin: scheduler_output, wait_for_save=False) as kv_connector_output: pass - + - if (not kv_connector_output.finished_sending - and not kv_connector_output.finished_recving): - return EMPTY_MODEL_RUNNER_OUTPUT + # if (not kv_connector_output.finished_sending + # and not kv_connector_output.finished_recving): + # return EMPTY_MODEL_RUNNER_OUTPUT - + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) output.kv_connector_output = kv_connector_output --- -2.34.1 - diff --git a/examples/vllm_adaption/flexkv_vllm_0_10_0.patch b/examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch similarity index 100% rename from examples/vllm_adaption/flexkv_vllm_0_10_0.patch rename to examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch diff --git a/examples/vllm_adaption/flexkv_vllm_0_8_4.patch b/examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch similarity index 100% rename from examples/vllm_adaption/flexkv_vllm_0_8_4.patch rename to examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch diff --git a/flexkv/cache/__init__.py b/flexkv/cache/__init__.py index cabe880787..e4782fc36e 100644 --- a/flexkv/cache/__init__.py +++ b/flexkv/cache/__init__.py @@ -1,9 +1,15 @@ -from .redis_meta import RedisMetaChannel, BlockMeta +# Ensure C++ extensions are loaded first +import flexkv.c_ext + +# Import other modules from .radix_remote import DistributedRadixTree, LocalRadixTree +from .redis_meta import RedisMetaChannel, BlockMeta, RedisNodeInfo __all__ = [ + "RedisMeta", "RedisMetaChannel", - "BlockMeta", + "RedisNodeInfo", + "BlockMeta", 
"DistributedRadixTree", "LocalRadixTree", ] diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 1bc73c744f..515148913e 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -71,12 +71,15 @@ def match(self, sequence_meta: SequenceMeta) -> MatchResultAccel: sequence_meta.gen_hashes() match_result = self.index.match_prefix(torch.from_numpy(sequence_meta.block_hashes).to(torch.int64), sequence_meta.num_blocks, True) - # physical blocks - phys = torch.tensor(match_result.physical_blocks, dtype=torch.int64).numpy() + # physical blocks (torch.Tensor -> numpy, zero-copy on CPU) + phys = match_result.physical_blocks.cpu().numpy() # optional block_node_ids try: - bnis = getattr(match_result, "block_node_ids") - bnids_np = torch.tensor(bnis, dtype=torch.uint32).numpy() if bnis is not None else None + bnis = getattr(match_result, "block_node_ids", None) + if isinstance(bnis, torch.Tensor) and bnis.numel() > 0: + bnids_np = bnis.cpu().numpy() + else: + bnids_np = None except Exception: bnids_np = None return MatchResultAccel(match_result.num_ready_matched_blocks, match_result.num_matched_blocks, @@ -111,9 +114,11 @@ def insert(self, def lock_node(self, node: CRadixNode) -> None: self.index.lock(node) - def cleanup(self, node: CRadixNode, cleanup_length: int) -> None: + def unlock(self, node: CRadixNode) -> None: self.index.unlock(node) - self.index.set_ready(node, True, cleanup_length) + + def set_ready(self, node: CRadixNode, ready: bool, ready_length: int) -> None: + self.index.set_ready(node, ready, ready_length) def take(self, num_required_blocks: int, @@ -122,7 +127,10 @@ def take(self, if num_required_blocks > self.mempool.num_free_blocks: if protected_node is not None: self.index.lock(protected_node) - evict_block_num = max(num_required_blocks - self.mempool.num_free_blocks, int(self.mempool.num_total_blocks * self.evict_ratio)) + evict_block_num = max( + num_required_blocks - self.mempool.num_free_blocks, + 
int(self.mempool.num_total_blocks * self.evict_ratio) + ) target_blocks = torch.zeros(evict_block_num, dtype=torch.int64) num_evicted = self.index.evict(target_blocks, evict_block_num) if num_evicted != evict_block_num: @@ -189,9 +197,11 @@ def insert(self, def lock_node(self, node: RadixNode) -> None: self.index.lock(node) - def cleanup(self, node: RadixNode, cleanup_length: int) -> None: + def unlock(self, node: RadixNode) -> None: self.index.unlock(node) - self.index.set_ready(node, True, cleanup_length) + + def set_ready(self, node: RadixNode, ready: bool, ready_length: int) -> None: + self.index.set_ready(node, ready, ready_length) def take(self, num_required_blocks: int, @@ -200,7 +210,10 @@ def take(self, if num_required_blocks > self.mempool.num_free_blocks: if protected_node is not None: self.index.lock(protected_node) - evict_block_num = max(num_required_blocks - self.mempool.num_free_blocks, int(self.mempool.num_total_blocks * self.evict_ratio)) + evict_block_num = max( + num_required_blocks - self.mempool.num_free_blocks, + int(self.mempool.num_total_blocks * self.evict_ratio) + ) self.mempool.recycle_blocks( self.index.evict(evict_block_num) ) @@ -232,8 +245,10 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig): cache_config.redis_password, cache_config.local_ip, ) - self.redis_meta.init_meta() - self.node_id = self.redis_meta.get_node_id() + node_id = self.redis_meta.init_meta() + if node_id is None: + raise RuntimeError("Failed to initialize Redis metadata") + self.node_id = node_id self.enable_kv_sharing = True else: self.enable_kv_sharing = False @@ -283,10 +298,10 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig): cache_config.evict_ratio) self.cache_engines[DeviceType.REMOTE] = self.remote_cache_engine - self._empty_get_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, int]] = \ - lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, 0) - 
self._empty_put_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]] = \ - lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, 0, 0) + self._empty_get_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int]] = \ + lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, {}, 0) + self._empty_put_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int, int]] = \ + lambda request_id: (TransferOpGraph.create_empty_graph(), [], {}, {}, {}, 0, 0) def reset(self) -> None: if self.cpu_cache_engine: @@ -303,7 +318,7 @@ def get(self, slot_mapping: np.ndarray, layer_num: int = -1, layer_granularity: int = -1, - dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, int]: + dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, Dict, int]: self._check_input(token_ids, token_mask, slot_mapping) if layer_num == -1: @@ -320,7 +335,7 @@ def get(self, # ignore the last incomplete block aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block aligned_token_ids = token_ids[:aligned_length] - token_mask[aligned_length:] = False + token_mask = token_mask[:aligned_length] block_start_idx, block_end_idx = self._get_block_range(token_mask) assert block_end_idx == aligned_length // self.tokens_per_block @@ -331,25 +346,27 @@ def get(self, tokens_per_block=self.cache_config.tokens_per_block) if not self.cache_config.enable_remote: - transfer_graph, finished_ops_ids, node_to_unlock, buffer_to_free, num_gpu_blocks_to_transfer = \ + (transfer_graph, finished_ops_ids, node_to_unlock, + op_node_to_ready, buffer_to_free, num_gpu_blocks_to_transfer) = \ self._get_impl_local( - request_id, - sequence_meta, - block_start_idx, - block_end_idx, - gpu_block_ids, - layer_num - ) + request_id, + sequence_meta, + block_start_idx, + block_end_idx, + gpu_block_ids, + layer_num + ) else: - transfer_graph, finished_ops_ids, node_to_unlock, buffer_to_free, 
num_gpu_blocks_to_transfer = \ + (transfer_graph, finished_ops_ids, node_to_unlock, + op_node_to_ready, buffer_to_free, num_gpu_blocks_to_transfer) = \ self._get_impl_global( - request_id, - sequence_meta, - block_start_idx, - block_end_idx, - gpu_block_ids, - layer_num - ) + request_id, + sequence_meta, + block_start_idx, + block_end_idx, + gpu_block_ids, + layer_num + ) transfer_graph, task_end_op_id = add_virtal_op_for_mutiple_finished_ops( transfer_graph, @@ -374,7 +391,13 @@ def get(self, node_to_unlock=node_to_unlock, buffer_to_free=buffer_to_free) - return transfer_graph, return_mask, callback, task_end_op_id + op_callback_dict = {} # dict, op_id -> callback + for op_id in op_node_to_ready: + op_callback_dict[op_id] = partial(self._op_callback, + device_type=op_node_to_ready[op_id][0], + node_to_ready=op_node_to_ready[op_id][1], + ready_length=op_node_to_ready[op_id][2]) + return transfer_graph, return_mask, callback, op_callback_dict, task_end_op_id def _get_impl_global(self, request_id: int, @@ -382,7 +405,7 @@ def _get_impl_global(self, block_mask_start: int, block_mask_end: int, gpu_block_ids: np.ndarray, - layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int]: + layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int]: """ transfer pattern: @@ -547,7 +570,10 @@ def _get_impl_global(self, buffer_to_free = {DeviceType.CPU: cpu_blocks_to_free} # NOTE: for now in build transfer graph, we assume that cpu works as a cache for ssd - return transfer_graph, finished_ops_ids, node_to_unlock, buffer_to_free, len(fragment123_gpu_blocks) + return ( + transfer_graph, finished_ops_ids, node_to_unlock, {}, buffer_to_free, + len(fragment123_gpu_blocks) # op_node_to_ready: {} + ) def _get_impl_local(self, request_id: int, @@ -555,7 +581,7 @@ def _get_impl_local(self, block_mask_start: int, block_mask_end: int, gpu_block_ids: np.ndarray, - layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int]: + layer_num: int) -> 
Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int]: """ transfer pattern: @@ -592,6 +618,7 @@ def _get_impl_local(self, transfer_graph = TransferOpGraph() finished_ops_ids = [] + op_node_to_ready = {} fragment12_gpu_blocks = gpu_block_ids[:fragment12_num_blocks] fragment2_ssd_blocks = ssd_matched_blocks[-fragment2_num_blocks:] @@ -636,6 +663,7 @@ def _get_impl_local(self, block_mask_start, is_ready=False, match_result=cpu_matched_result) + op_node_to_ready[op_disk2h.op_id] = (DeviceType.CPU, cpu_node_to_unlock, cpu_node_to_unlock.size()) else: cpu_blocks_to_free = fragment2_cpu_blocks if fragment2_cpu_blocks is not None: @@ -662,7 +690,10 @@ def _get_impl_local(self, node_to_unlock[DeviceType.SSD] = (ssd_node_to_unlock, ssd_node_to_unlock.size()) buffer_to_free = {DeviceType.CPU: cpu_blocks_to_free} - return transfer_graph, finished_ops_ids, node_to_unlock, buffer_to_free, len(fragment12_gpu_blocks) + return ( + transfer_graph, finished_ops_ids, node_to_unlock, op_node_to_ready, + buffer_to_free, len(fragment12_gpu_blocks) + ) def put(self, request_id: int, @@ -670,7 +701,7 @@ def put(self, token_mask: np.ndarray, slot_mapping: np.ndarray, layer_num : int = -1, - dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, int]: + dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, Dict, int]: self._check_input(token_ids, token_mask, slot_mapping) if layer_num == -1: @@ -678,7 +709,7 @@ def put(self, # ignore the last incomplete block aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block aligned_token_ids = token_ids[:aligned_length] - token_mask[aligned_length:] = False + token_mask = token_mask[:aligned_length] block_start_idx, block_end_idx = self._get_block_range(token_mask) # the mask should has a prefix of True @@ -691,7 +722,7 @@ def put(self, tokens_per_block=self.cache_config.tokens_per_block) if not self.cache_config.enable_remote: - (transfer_graph, finished_ops_ids, node_to_unlock, + (transfer_graph, 
finished_ops_ids, node_to_unlock, op_node_to_ready, buffer_to_free, num_gpu_blocks_to_transfer, skipped_gpu_blocks) = \ self._put_impl_local( request_id, @@ -702,7 +733,7 @@ def put(self, layer_num ) else: - (transfer_graph, finished_ops_ids, node_to_unlock, + (transfer_graph, finished_ops_ids, node_to_unlock, op_node_to_ready, buffer_to_free, num_gpu_blocks_to_transfer, skipped_gpu_blocks) = \ self._put_impl_global( request_id, @@ -728,9 +759,17 @@ def put(self, callback = partial(self._transfer_callback, node_to_unlock=node_to_unlock, - buffer_to_free=buffer_to_free) + buffer_to_free=buffer_to_free, + is_put=True) - return transfer_graph, return_mask, callback, task_end_op_id + op_callback_dict = {} + for op_id in op_node_to_ready: + op_callback_dict[op_id] = partial(self._op_callback, + device_type=op_node_to_ready[op_id][0], + node_to_ready=op_node_to_ready[op_id][1], + ready_length=op_node_to_ready[op_id][2]) + + return transfer_graph, return_mask, callback, op_callback_dict, task_end_op_id def _put_impl_global(self, request_id: int, @@ -738,7 +777,7 @@ def _put_impl_global(self, block_mask_start: int, block_mask_end: int, gpu_block_ids: np.ndarray, - layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]: + layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int, int]: """ transfer pattern: @@ -885,7 +924,10 @@ def _put_impl_global(self, node_to_unlock[DeviceType.REMOTE] = (remote_node_to_unlock, remote_node_to_unlock.size()) skipped_gpu_blocks = len(cpu_matched_blocks) - return transfer_graph, finished_ops_ids, node_to_unlock, {}, len(fragment12_gpu_blocks), skipped_gpu_blocks + return ( + transfer_graph, finished_ops_ids, node_to_unlock, {}, {}, + len(fragment12_gpu_blocks), skipped_gpu_blocks # op_node_to_ready: {} + ) def _put_impl_local(self, request_id: int, @@ -893,7 +935,7 @@ def _put_impl_local(self, block_mask_start: int, block_mask_end: int, gpu_block_ids: np.ndarray, - layer_num : int) -> 
Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]: + layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, Dict, int, int]: """ transfer pattern: @@ -949,6 +991,7 @@ def _put_impl_local(self, transfer_graph = TransferOpGraph() finished_ops_ids = [] + op_node_to_ready = {} op_d2h = TransferOp( graph_id = transfer_graph.graph_id, @@ -980,12 +1023,14 @@ def _put_impl_local(self, fragment12_cpu_blocks, is_ready=False, match_result=cpu_matched_result) + op_node_to_ready[op_d2h.op_id] = (DeviceType.CPU, cpu_node_to_unlock, cpu_node_to_unlock.size()) ssd_node_to_unlock = None if len(fragment2_ssd_blocks) > 0: ssd_node_to_unlock = self.ssd_cache_engine.insert(sequence_meta, fragment2_ssd_blocks, is_ready=False, match_result=ssd_matched_result) + op_node_to_ready[op_h2disk.op_id] = (DeviceType.SSD, ssd_node_to_unlock, ssd_node_to_unlock.size()) node_to_unlock = {} if cpu_node_to_unlock is not None: node_to_unlock[DeviceType.CPU] = (cpu_node_to_unlock, cpu_node_to_unlock.size()) @@ -993,20 +1038,31 @@ def _put_impl_local(self, node_to_unlock[DeviceType.SSD] = (ssd_node_to_unlock, ssd_node_to_unlock.size()) skipped_gpu_blocks = len(cpu_matched_blocks) - return transfer_graph, finished_ops_ids, node_to_unlock, {}, len(fragment12_gpu_blocks), skipped_gpu_blocks + return ( + transfer_graph, finished_ops_ids, node_to_unlock, op_node_to_ready, {}, + len(fragment12_gpu_blocks), skipped_gpu_blocks + ) def _transfer_callback(self, node_to_unlock: Dict[DeviceType, Tuple[RadixNode, int]], - buffer_to_free: Optional[Dict[DeviceType, np.ndarray]] = None) -> None: + buffer_to_free: Optional[Dict[DeviceType, np.ndarray]] = None, + is_put: bool = False) -> None: if DeviceType.CPU in node_to_unlock: assert self.cpu_cache_engine is not None - self.cpu_cache_engine.cleanup(node_to_unlock[DeviceType.CPU][0], node_to_unlock[DeviceType.CPU][1]) + self.cpu_cache_engine.unlock(node_to_unlock[DeviceType.CPU][0]) + self.cpu_cache_engine.set_ready(node_to_unlock[DeviceType.CPU][0], 
True, node_to_unlock[DeviceType.CPU][1]) if DeviceType.SSD in node_to_unlock: assert self.ssd_cache_engine is not None - self.ssd_cache_engine.cleanup(node_to_unlock[DeviceType.SSD][0], node_to_unlock[DeviceType.SSD][1]) + self.ssd_cache_engine.unlock(node_to_unlock[DeviceType.SSD][0]) + self.ssd_cache_engine.set_ready(node_to_unlock[DeviceType.SSD][0], True, node_to_unlock[DeviceType.SSD][1]) if DeviceType.REMOTE in node_to_unlock: assert self.remote_cache_engine is not None - self.remote_cache_engine.cleanup(node_to_unlock[DeviceType.REMOTE][0], node_to_unlock[DeviceType.REMOTE][1]) + self.remote_cache_engine.unlock(node_to_unlock[DeviceType.REMOTE][0]) + self.remote_cache_engine.set_ready( + node_to_unlock[DeviceType.REMOTE][0], True, node_to_unlock[DeviceType.REMOTE][1] + ) + if is_put and self.enable_kv_sharing: + self.remote_cache_engine.insert_and_publish(node_to_unlock[DeviceType.REMOTE][0]) if buffer_to_free is not None: if DeviceType.CPU in buffer_to_free: assert self.cpu_cache_engine is not None @@ -1018,6 +1074,17 @@ def _transfer_callback(self, assert self.remote_cache_engine is not None self.remote_cache_engine.recycle(buffer_to_free[DeviceType.REMOTE]) + def _op_callback(self, device_type: DeviceType, node_to_ready: RadixNode, ready_length: int) -> None: + if device_type == DeviceType.CPU: + assert self.cpu_cache_engine is not None + self.cpu_cache_engine.set_ready(node_to_ready, True, ready_length) + elif device_type == DeviceType.SSD: + assert self.ssd_cache_engine is not None + self.ssd_cache_engine.set_ready(node_to_ready, True, ready_length) + elif device_type == DeviceType.REMOTE: + assert self.remote_cache_engine is not None + self.remote_cache_engine.set_ready(node_to_ready, True, ready_length) + @nvtx.annotate("Match Prefix Accel", color="yellow") def match_local_accel(self, sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel]: cpu_matched_result = MatchResultAccel() @@ -1028,7 +1095,7 @@ def match_local_accel(self, 
sequence_meta: SequenceMeta) -> Tuple[MatchResultAcc ssd_matched_result = self.ssd_cache_engine.match(sequence_meta) return cpu_matched_result, ssd_matched_result - + @nvtx.annotate("Match Prefix", color="yellow") def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResult]: cpu_matched_result = MatchResult() @@ -1039,7 +1106,7 @@ def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchRe ssd_matched_result = self.ssd_cache_engine.match(sequence_meta) return cpu_matched_result, ssd_matched_result - + @nvtx.annotate("Match All Prefix accel", color="yellow") def match_all_accel(self, sequence_meta: SequenceMeta, @@ -1062,7 +1129,7 @@ def match_all_accel(self, remote_matched_result = self.remote_cache_engine.match(sequence_meta) return cpu_matched_result, ssd_matched_result, remote_matched_result - + @nvtx.annotate("Match All Prefix", color="yellow") def match_all(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResult, MatchResult]: cpu_matched_result = MatchResult() @@ -1101,7 +1168,7 @@ def _get_block_range(self, token_mask: np.ndarray) -> Tuple[int, int]: mask_idx = np.where(token_mask)[0] if len(mask_idx) == 0: - return 0, 0 + return len(token_mask)//self.tokens_per_block, len(token_mask)//self.tokens_per_block start_idx = mask_idx[0].item() // self.tokens_per_block end_idx = mask_idx[-1].item() // self.tokens_per_block return start_idx, end_idx + 1 diff --git a/flexkv/cache/pcfs_cache_engine.py b/flexkv/cache/pcfs_cache_engine.py index 735c60b7a6..b567e5c65f 100644 --- a/flexkv/cache/pcfs_cache_engine.py +++ b/flexkv/cache/pcfs_cache_engine.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import Optional, Tuple, TYPE_CHECKING, List, Dict import numpy as np @@ -26,16 +24,13 @@ def __init__(self, evict_ratio: float, *, # Optional runtime wiring for remote/local trees - local_max_num_blocks: Optional[int] = None, local_lease_ttl_ms: int = 100000, - local_renew_lease_ms: int = 0, - 
local_refresh_batch_size: int = 256, + local_renew_lease_ms: int = 10000, + local_refresh_batch_size: int = 1000, local_idle_sleep_ms: int = 10, - local_lt_pool_initial_capacity: int = 0, - remote_max_num_blocks: Optional[int] = None, - remote_node_id: int = 0, - remote_lt_pool_initial_capacity: int = 0, - remote_refresh_batch_size: int = 128, + remote_max_num_blocks: int = 4000000, + redis_node_id: int = 0, + remote_refresh_batch_size: int = 1000, remote_rebuild_interval_ms: int = 1000, remote_idle_sleep_ms: int = 10, meta: Optional[RedisMeta] = None) -> None: @@ -57,24 +52,23 @@ def __init__(self, # Local index (authoritative for mutations) self.local_index = LocalRadixTree( tokens_per_block=tokens_per_block, - max_num_blocks=int(local_max_num_blocks or num_total_blocks), + max_num_blocks=int(num_total_blocks), lease_ttl_ms=int(local_lease_ttl_ms), renew_lease_ms=int(local_renew_lease_ms), refresh_batch_size=int(local_refresh_batch_size), idle_sleep_ms=int(local_idle_sleep_ms), - lt_pool_initial_capacity=int(local_lt_pool_initial_capacity), ) # Remote reference index (read-only, built from Redis) self.remote_index = DistributedRadixTree( tokens_per_block=tokens_per_block, - max_num_blocks=int(remote_max_num_blocks or num_total_blocks), - node_id=int(remote_node_id), - lt_pool_initial_capacity=int(remote_lt_pool_initial_capacity), + max_num_blocks=int(remote_max_num_blocks or (num_total_blocks * 10)), + node_id=int(redis_node_id), refresh_batch_size=int(remote_refresh_batch_size), rebuild_interval_ms=int(remote_rebuild_interval_ms), idle_sleep_ms=int(remote_idle_sleep_ms), + lease_renew_ms=int(local_renew_lease_ms), ) # defer channel start to start(meta) @@ -106,6 +100,20 @@ def reset(self) -> None: self.local_index.reset() self.mempool.reset() + def match(self, sequence_meta: SequenceMeta) -> MatchResultAccel: + """Match a sequence against the cache index. 
+
+        This method provides a simple interface similar to CacheEngine.match(),
+        delegating to match_all() for consistency.
+
+        Args:
+            sequence_meta: The sequence metadata to match
+
+        Returns:
+            MatchResultAccel: The match result
+        """
+        return self.match_all(sequence_meta)
+
     def match_all(self, sequence_meta: SequenceMeta) -> MatchResultAccel:
         sequence_meta.gen_hashes()
         block_hashes_t = torch.from_numpy(sequence_meta.block_hashes).to(torch.int64)
@@ -121,20 +129,20 @@ def match_all(self, sequence_meta: SequenceMeta) -> MatchResultAccel:
         chosen = mr_local if local_key >= remote_key else mr_remote
 
         # physical blocks
-        phys_np = torch.tensor(chosen.physical_blocks, dtype=torch.int64).numpy()
-        #block_node_ids = torch.tensor(chosen.block_node_ids, dtype=torch.uint32).numpy() if chosen.block_node_ids is not None else None
-        # optional block_node_ids
         bnids_np = None
         if chosen is mr_remote:
-            block_node_ids = torch.tensor(chosen.block_node_ids, dtype=torch.uint32).numpy() if chosen.block_node_ids is not None else None
-            if block_node_ids is None:
-                raise Exception("Failed to get block_node_ids")
-            bnids_np = self.nodeids_to_file_nodeids(block_node_ids, phys_np)
+            # Try to use the block_node_ids from the DistributedRadixTree;
+            # if the check fails, fall back to the LocalRadixTree match result
+            nids = chosen.block_node_ids
+            nps = chosen.physical_blocks
+            # Convert tensors to numpy views (CPU) if present
+            if isinstance(nids, torch.Tensor) and nids.numel() > 0:
+                bnids_np = self.nodeids_to_file_nodeids(nids.cpu().numpy(), nps.cpu().numpy())
+            else:
+                bnids_np = None
             if bnids_np is None:
-                raise Exception("Failed to get file_nodeids")
-            bnids_len = bnids_np.shape[0]
-            if bnids_len != phys_np.shape[0]:
-                raise Exception("bnids_len != phys_np.shape[0]")
+                chosen = mr_local
+        phys_np = chosen.physical_blocks.cpu().numpy()
         return MatchResultAccel(
             num_ready_matched_blocks=int(chosen.num_ready_matched_blocks),
             num_matched_blocks=int(chosen.num_matched_blocks),
@@ -146,38 +154,44 @@ def match_all(self, sequence_meta: SequenceMeta) -> 
MatchResultAccel:
         )
 
     def nodeids_to_file_nodeids(self,
-                                block_node_ids: np.ndarray,
-                                physical_blocks: np.ndarray) -> Optional[np.ndarray]:
+                                bnids_np: np.ndarray,
+                                phys: np.ndarray) -> Optional[np.ndarray]:
         """Convert per-block node ids to per-block PCFS file_nodeids.
 
-        For each i:
-          nid = block_node_ids[i]
-          file_nodeids_list = self.nid_to_file_nodeids[nid]
-          remote_file_num = len(file_nodeids_list)
-          block_id = physical_blocks[i]
-          f_idx = (block_id // self.round_robin) % remote_file_num
-          out[i] = file_nodeids_list[f_idx]
+        Args:
+            bnids_np: block_node_ids from MatchResultAccel.block_node_ids
+            phys: physical_blocks from MatchResultAccel.physical_blocks
+
+        Returns:
+            file_nodeids array with dtype=uint32, or None if conversion fails
         """
+        if bnids_np is None or phys is None:
+            return None
         try:
-            bnids_np = np.asarray(block_node_ids, dtype=np.uint32)
-            phys_np = np.asarray(physical_blocks, dtype=np.int64)
+            bnids_np = np.asarray(bnids_np, dtype=np.uint32)
+            phys_np = np.asarray(phys, dtype=np.int64)
         except Exception:
             return None
         if bnids_np.shape[0] != phys_np.shape[0]:
-            raise Exception("block_node_ids and physical_blocks must have the same length")
-        out = np.full(phys_np.shape, fill_value=-1, dtype=np.int64)
+            return None
+        out = np.full(phys_np.shape, fill_value=0, dtype=np.uint32)
         rr = max(1, int(self.round_robin))
+
         for i in range(bnids_np.shape[0]):
             nid = int(bnids_np[i])
+            # Check whether the node is still active
+            if not self._meta.is_node_active(nid):
+                return None
             file_list = self.nid_to_file_nodeids.get(nid)
+            # Check whether the file list is empty
            if not file_list:
-                break
+                return None
             remote_file_num = len(file_list)
             if remote_file_num <= 0:
-                break
+                return None
             block_id = int(phys_np[i])
             f_idx = (block_id // rr) % remote_file_num
-            out[i] = int(file_list[f_idx])
+            out[i] = np.uint32(file_list[f_idx])
         return out
 
     def match_local(self, sequence_meta: SequenceMeta) -> MatchResultAccel:
@@ -187,7 +201,7 @@ def match_local(self, sequence_meta: SequenceMeta) -> MatchResultAccel:
         mr_local = 
self.local_index.match_prefix(block_hashes_t, int(num_blocks), True) - phys_np = torch.tensor(mr_local.physical_blocks, dtype=torch.int64).numpy() + phys_np = mr_local.physical_blocks.cpu().numpy() return MatchResultAccel( num_ready_matched_blocks=int(mr_local.num_ready_matched_blocks), @@ -230,6 +244,42 @@ def lock_node(self, node: CRadixNode) -> None: else: self.local_index.lock(node) + def unlock(self, node: CRadixNode) -> None: + """Unlock a node in the appropriate index (local or remote). + + Args: + node: The radix node to unlock + """ + if node is None: + return + try: + is_remote_node = bool(node.has_block_node_ids()) + except Exception: + is_remote_node = False + if is_remote_node: + self.remote_index.unlock(node) + else: + self.local_index.unlock(node) + + def set_ready(self, node: CRadixNode, ready: bool = True, ready_length: int = -1) -> None: + """Set the ready state of a node in the appropriate index (local or remote). + + Args: + node: The radix node to set ready state + ready: Whether the node is ready (default: True) + ready_length: The ready length (default: -1, meaning use node's current length) + """ + if node is None: + return + try: + is_remote_node = bool(node.has_block_node_ids()) + except Exception: + is_remote_node = False + if is_remote_node: + self.remote_index.set_ready(node, ready, ready_length) + else: + self.local_index.set_ready(node, ready, ready_length) + def cleanup(self, node: CRadixNode, cleanup_length: int) -> None: if node is None: return @@ -340,16 +390,13 @@ def from_cache_config(cls, cache_config: "CacheConfig", node_id: int, meta: Opti num_total_blocks=num_blocks, tokens_per_block=int(cache_config.tokens_per_block), evict_ratio=float(cache_config.evict_ratio), - local_max_num_blocks=num_blocks, local_lease_ttl_ms=int(getattr(cache_config, "lease_ttl_ms", 100000)), - local_renew_lease_ms=int(getattr(cache_config, "renew_lease_ms", 0)), + local_renew_lease_ms=int(getattr(cache_config, "renew_lease_ms", 10)), 
local_refresh_batch_size=int(getattr(cache_config, "refresh_batch_size", 256)), local_idle_sleep_ms=int(getattr(cache_config, "idle_sleep_ms", 10)), - local_lt_pool_initial_capacity=int(getattr(cache_config, "lt_pool_initial_capacity", 0)), remote_max_num_blocks=num_blocks, - remote_node_id=int(node_id), - remote_lt_pool_initial_capacity=int(getattr(cache_config, "lt_pool_initial_capacity", 0)), - remote_refresh_batch_size=int(getattr(cache_config, "refresh_batch_size", 128)), + redis_node_id=int(node_id), + remote_refresh_batch_size=int(getattr(cache_config, "refresh_batch_size", 256)), remote_rebuild_interval_ms=int(getattr(cache_config, "rebuild_interval_ms", 1000)), remote_idle_sleep_ms=int(getattr(cache_config, "idle_sleep_ms", 10)), meta=meta, diff --git a/flexkv/cache/radix_remote.py b/flexkv/cache/radix_remote.py index d65f45f284..012c195620 100644 --- a/flexkv/cache/radix_remote.py +++ b/flexkv/cache/radix_remote.py @@ -1,14 +1,15 @@ -from __future__ import annotations +from typing import Optional, List, Tuple from typing import Optional, Any from flexkv.cache.redis_meta import RedisMetaChannel as _PyRedisMetaChannel import torch -from c_ext import DistributedRadixTree as _CDistributedRadixTree -from c_ext import LocalRadixTree as _CLocalRadixTree -from c_ext import RedisMetaChannel as _CRedisMetaChannel -from c_ext import CMatchResult -from c_ext import CRadixNode +from flexkv.c_ext import DistributedRadixTree as _CDistributedRadixTree +from flexkv.c_ext import LocalRadixTree as _CLocalRadixTree +from flexkv.c_ext import RedisMetaChannel as _CRedisMetaChannel +from flexkv.c_ext import CMatchResult +from flexkv.c_ext import CRadixNode +from flexkv.c_ext import RefRadixTree class DistributedRadixTree: @@ -16,23 +17,56 @@ def __init__(self, tokens_per_block: int, max_num_blocks: int, node_id: int, - lt_pool_initial_capacity: int = 0, refresh_batch_size: int = 128, rebuild_interval_ms: int = 1000, - idle_sleep_ms: int = 10) -> None: + idle_sleep_ms: int = 10, 
+                 lease_renew_ms: int = 5000) -> None:
         if _CDistributedRadixTree is None:
             raise ImportError("c_ext.DistributedRadixTree is not available")
         self._c = _CDistributedRadixTree(int(tokens_per_block), int(max_num_blocks), int(node_id),
-                                         int(lt_pool_initial_capacity), int(refresh_batch_size), int(rebuild_interval_ms), int(idle_sleep_ms))
-
-    def start(self, channel: _PyRedisMetaChannel) -> None:
-        ch = getattr(channel, "_c", channel)
-        self._c.start(ch)
+                                         int(refresh_batch_size), int(rebuild_interval_ms), int(idle_sleep_ms), int(lease_renew_ms))
+        self._started = False
+
+    def __del__(self) -> None:
+        """Destructor: ensure stop() is called when the object is destroyed."""
+        try:
+            if hasattr(self, '_started') and self._started:
+                self.stop()
+        except Exception:
+            # Ignore exceptions in the destructor so they cannot disturb interpreter shutdown
+            pass
+
+    def start(self, channel: _PyRedisMetaChannel) -> bool:
+        """Start the DistributedRadixTree with the given Redis meta channel.
+
+        Args:
+            channel: RedisMetaChannel instance to use for Redis communication
+
+        Returns:
+            bool: True if start was successful, False otherwise
+        """
+        try:
+            ch = getattr(channel, "_c", channel)
+            if not self._c.start(ch):
+                return False
+            self._started = True
+            return True
+        except Exception:
+            self._started = False
+            return False

     def stop(self) -> None:
         self._c.stop()
-
-    def remote_tree_refresh(self):
+        self._started = False
+
+    def remote_tree_refresh(self) -> Optional["RefRadixTree"]:
+        """Refresh the remote tree by loading block metadata from Redis.
+
+        Returns:
+            Optional[RefRadixTree]: The refreshed reference tree, or None if refresh fails
+        """
+        if not self._started:
+            raise RuntimeError("DistributedRadixTree must be started before calling remote_tree_refresh")
         return self._c.remote_tree_refresh()

     def match_prefix(self, block_hashes: torch.Tensor, num_blocks: int, update_cache_info: bool = True):
@@ -58,22 +92,52 @@ def __init__(self,
                  lease_ttl_ms: int = 100000,
                  renew_lease_ms: int = 0,
                  refresh_batch_size: int = 256,
-                 idle_sleep_ms: int = 10,
-                 lt_pool_initial_capacity: int = 0) -> None:
+                 idle_sleep_ms: int = 10) -> None:
         if _CLocalRadixTree is None:
             raise ImportError("c_ext.LocalRadixTree is not available")
-        self._c = _CLocalRadixTree(int(tokens_per_block), int(max_num_blocks), int(lease_ttl_ms), int(renew_lease_ms), int(refresh_batch_size), int(idle_sleep_ms), int(lt_pool_initial_capacity))
+        self._c = _CLocalRadixTree(int(tokens_per_block), int(max_num_blocks), int(lease_ttl_ms), int(renew_lease_ms), int(refresh_batch_size), int(idle_sleep_ms))
+        self._started = False
+
+    def __del__(self) -> None:
+        """Destructor: ensure stop() is called when the object is destroyed."""
+        try:
+            if hasattr(self, '_started') and self._started:
+                self.stop()
+        except Exception:
+            # Ignore exceptions in the destructor so they cannot disturb interpreter shutdown
+            pass

     def set_meta_channel(self, channel: _PyRedisMetaChannel) -> None:
+        """Set the Redis meta channel for this LocalRadixTree.
+
+        Args:
+            channel: RedisMetaChannel instance to use for Redis communication
+        """
         ch = getattr(channel, "_c", channel)
         self._c.set_meta_channel(ch)

-    def start(self, channel: _PyRedisMetaChannel) -> None:
-        ch = getattr(channel, "_c", channel)
-        self._c.start(ch)
+    def start(self, channel: _PyRedisMetaChannel) -> bool:
+        """Start the LocalRadixTree with the given Redis meta channel.
+
+        Args:
+            channel: RedisMetaChannel instance to use for Redis communication
+
+        Returns:
+            bool: True if start was successful, False otherwise
+        """
+        try:
+            ch = getattr(channel, "_c", channel)
+            if not self._c.start(ch):
+                return False
+            self._started = True
+            return True
+        except Exception:
+            self._started = False
+            return False

     def stop(self) -> None:
         self._c.stop()
+        self._started = False

     # Mirror base class methods on LocalRadixTree
     def match_prefix(self, block_hashes: torch.Tensor, num_blocks: int, update_cache_info: bool = True):
@@ -128,7 +192,45 @@ def dec_node_count(self) -> None:

     def set_ready(self, node: "CRadixNode", ready: bool = True, ready_length: int = -1) -> None:
         self._c.set_ready(node, bool(ready), int(ready_length))

+    def insert(self, physical_block_ids: torch.Tensor, block_hashes: torch.Tensor,
+               num_blocks: int, num_insert_blocks: int = -1, ready: bool = True,
+               node: "CRadixNode" = None, num_matched_blocks: int = -1,
+               last_node_matched_length: int = -1) -> "CRadixNode":
+        """Insert blocks into the LocalRadixTree.
+
+        Args:
+            physical_block_ids: Tensor containing physical block IDs
+            block_hashes: Tensor containing block hash values
+            num_blocks: Total number of blocks
+            num_insert_blocks: Number of blocks to insert (-1 for all)
+            ready: Whether the inserted blocks are ready
+            node: Last matched node to continue from (None for auto-match)
+            num_matched_blocks: Number of matched blocks (-1 for auto-match)
+            last_node_matched_length: Length of last node match (-1 for auto-match)
+
+        Returns:
+            CRadixNode: The newly inserted node, or None if no insertion occurred
+        """
+        return self._c.insert(
+            physical_block_ids, block_hashes, int(num_blocks), int(num_insert_blocks),
+            bool(ready), node, int(num_matched_blocks), int(last_node_matched_length)
+        )
+
+    def evict(self, evicted_blocks: torch.Tensor, num_evicted: int) -> int:
+        """Evict blocks from the LocalRadixTree.
+ + Args: + evicted_blocks: Tensor to store evicted block IDs + num_evicted: Number of blocks to evict + + Returns: + int: Number of blocks actually evicted + """ + return int(self._c.evict(evicted_blocks, int(num_evicted))) + def insert_and_publish(self, node: "CRadixNode") -> None: + if not self._started: + raise RuntimeError("LocalRadixTree must be started before calling insert_and_publish") self._c.insert_and_publish(node) diff --git a/flexkv/cache/redis_meta.py b/flexkv/cache/redis_meta.py index 1eb78d60a8..02e0f8b615 100644 --- a/flexkv/cache/redis_meta.py +++ b/flexkv/cache/redis_meta.py @@ -1,17 +1,24 @@ -from __future__ import annotations - -from typing import Iterable, List, Tuple +from typing import Iterable, List, Tuple, Optional, Union, Dict from dataclasses import dataclass from enum import IntEnum from uuid import uuid1 +import threading +import time +import atexit +import signal +import sys try: # redis-py import redis as _redis except Exception: # pragma: no cover _redis = None # type: ignore +# Import C++ extensions with explicit error handling try: - from c_ext import RedisMetaChannel as _CRedisMetaChannel, BlockMeta as _CBlockMeta + # Ensure flexkv.c_ext is loaded first + import flexkv.c_ext + from flexkv.c_ext import RedisMetaChannel as _CRedisMetaChannel, BlockMeta as _CBlockMeta except Exception as e: # pragma: no cover + raise ImportError(f"Failed to import C++ extensions: {e}") _CRedisMetaChannel = None # type: ignore _CBlockMeta = None # type: ignore @@ -54,10 +61,10 @@ def from_c(cm: "_CBlockMeta") -> "BlockMeta": class RedisMetaChannel: - def __init__(self, host: str, port: int, node_id: int, local_ip: str, blocks_key: str = "blocks") -> None: + def __init__(self, host: str, port: int, node_id: int, local_ip: str, blocks_key: str = "blocks", password: str = "") -> None: if _CRedisMetaChannel is None: raise ImportError("c_ext.RedisMetaChannel is not available") - self._c = _CRedisMetaChannel(host, int(port), int(node_id), str(local_ip), 
str(blocks_key))
+        self._c = _CRedisMetaChannel(host, int(port), int(node_id), str(local_ip), str(blocks_key), str(password))

     def connect(self) -> bool:
         return bool(self._c.connect())
@@ -73,16 +80,14 @@ def local_ip(self) -> str:

     def make_block_key(self, node_id: int, hash_value: int) -> str:
         return str(self._c.make_block_key(int(node_id), int(hash_value)))

-    def publish_one(self, meta: BlockMeta) -> None:
-        self._c.publish_one(meta.to_c())
+    def publish_one(self, meta: BlockMeta) -> bool:
+        """Publish a single BlockMeta to Redis."""
+        return self._c.publish_one(meta.to_c())

-    def publish_batch(self, metas: Iterable[BlockMeta], batch_size: int = 100) -> None:
+    def publish_batch(self, metas: Iterable[BlockMeta], batch_size: int = 100) -> bool:
+        """Publish a batch of BlockMeta entries to Redis."""
         cms = [m.to_c() for m in metas]
-        self._c.publish_batch(cms, int(batch_size))
-
-    def load(self, max_items: int) -> List[BlockMeta]:
-        cms = self._c.load(int(max_items))
-        return [BlockMeta.from_c(cm) for cm in cms]
+        return self._c.publish_batch(cms, int(batch_size))

     def list_keys(self, pattern: str) -> List[str]:
         return list(self._c.list_keys(pattern))
@@ -99,56 +104,377 @@ def hmget_field_for_keys(self, keys: Iterable[str], field: str) -> List[str]:

     def hmget_two_fields_for_keys(self, keys: Iterable[str], f1: str, f2: str) -> List[Tuple[str, str]]:
         return [(a, b) for a, b in self._c.hmget_two_fields_for_keys(list(keys), f1, f2)]

-    def update_block_state_batch(self, node_id: int, hashes: Iterable[int], state: int, batch_size: int = 200) -> None:
-        self._c.update_block_state_batch(int(node_id), list(int(h) for h in hashes), int(state), int(batch_size))
+    def renew_node_leases(self, node_id: int, new_lt: int, batch_size: int = 200) -> bool:
+        """Renew lease times for all blocks of the given node in batches."""
+        return self._c.renew_node_leases(int(node_id), int(new_lt), int(batch_size))
+
+    def update_block_state_batch(self, node_id: int, hashes: Iterable[int], state: int, batch_size: int = 200) -> bool:
+        """Update block states for the given node in batches."""
+        return self._c.update_block_state_batch(int(node_id), list(int(h) for h in hashes), int(state), int(batch_size))
+
+    def delete_blockmeta_batch(self, node_id: int, hashes: Iterable[int], batch_size: int = 200) -> bool:
+        """Delete block metadata for the given node in batches."""
+        return self._c.delete_blockmeta_batch(int(node_id), list(int(h) for h in hashes), int(batch_size))
+
+
+class RedisNodeInfo:
+    """Redis node information management class implemented in Python"""
+
+    def __init__(self, host: str, port: int, local_ip: str, password: str = "") -> None:
+        if _redis is None:
+            raise ImportError("redis-py is required: pip install redis")
+        self.host = host
+        self.port = int(port)
+        self.local_ip = str(local_ip)
+        self.password = str(password)
+        self.uuid = str(uuid1())
+        self._node_id: Optional[int] = None
+        self._running = False
+        self._listener_thread: Optional[threading.Thread] = None
+        self.current_node_id_set: set = set()
+        self._client: Optional[_redis.Redis] = None
+        self._sub_client: Optional[_redis.Redis] = None
+        self._cleanup_done = False
+
+        # Register cleanup hooks that run at process exit
+        atexit.register(self._cleanup_on_exit)
+        signal.signal(signal.SIGINT, self._signal_handler)
+        signal.signal(signal.SIGTERM, self._signal_handler)
+
+    def __del__(self) -> None:
+        """Destructor: make sure cleanup runs when the object is destroyed."""
+        try:
+            self._cleanup_on_exit()
+        except Exception:
+            # Ignore exceptions in the destructor so they cannot disturb interpreter shutdown
+            pass
+
+    def _get_client(self) -> _redis.Redis:
+        """Get Redis client with connection settings"""
+        return _redis.Redis(
+            host=self.host,
+            port=self.port,
+            password=self.password if self.password else None,
+            decode_responses=True,
+            health_check_interval=30,
+            socket_keepalive=True
+        )
+
+    def connect(self) -> bool:
+        """Connect to Redis and start listener thread"""
+        try:
+            self._client = self._get_client()
+            # Test connection
+            self._client.ping()
+
+            # Start listener thread
+            self._running = True
+            self._listener_thread = threading.Thread(
+                target=self._listener_worker,
+                name="redis-node-info-listener",
+                daemon=True
+            )
+            self._listener_thread.start()
+
+            return True
+        except Exception:
+            return False
+
+    def disconnect(self) -> None:
+        """Disconnect from Redis and stop listener thread"""
+        self._running = False
+        if self._listener_thread and self._listener_thread.is_alive():
+            self._listener_thread.join(timeout=2.0)
+        self._listener_thread = None
+
+        if self._client:
+            self._client.close()
+            self._client = None
+        if self._sub_client:
+            self._sub_client.close()
+            self._sub_client = None
+
+    def _signal_handler(self, signum: int, frame) -> None:
+        """Signal handler for graceful shutdown"""
+        print(f"Received signal {signum}, cleaning up RedisNodeInfo...")
+        self._cleanup()
+        sys.exit(0)
+
+    def _cleanup_on_exit(self) -> None:
+        """Cleanup function registered with atexit"""
+        self._cleanup()
+
+    def _cleanup(self) -> None:
+        """Internal cleanup method"""
+        if self._cleanup_done:
+            return
+
+        self._cleanup_done = True
+
+        try:
+            # Unregister the node
+            if self._node_id is not None:
+                self.unregister_node()
+
+            # Close the connections
+            self.disconnect()
+        except Exception:
+            # Ignore exceptions raised during cleanup
+            pass
+
+    def register_node(self) -> Optional[int]:
+        """Register a new node and get node_id"""
+        if not self._client:
+            return None
+
+        try:
+            # Atomically increment global:node_id to get new node_id
+            node_id = self._client.incr("global:node_id")
+            self._node_id = node_id
+
+            # Store node information in node:node_id hash
+            node_key = f"node:{node_id}"
+            self._client.hset(node_key, mapping={
+                "node_id": str(node_id),
+                "local_ip": self.local_ip,
+                "uuid": self.uuid,
+                "status": "active",
+                "timestamp": str(int(time.time()))
+            })
+
+            # Publish node update event
+            self._client.publish("flexkv_node_id_updated", str(node_id))
+
+            return node_id
+        except Exception:
+            return None
+
+    def unregister_node(self) -> bool:
+        """Unregister current node"""
+        if not self._client or self._node_id is None:
+            return False
+
+        try:
+            # Delete node:node_id key
+            node_key = f"node:{self._node_id}"
+            self._client.delete(node_key)
+
+            # Publish node update event
+            self._client.publish("flexkv_node_id_updated", str(self._node_id))
+
+            self._node_id = None
+            return True
+        except Exception:
+            return False
+
+    @property
+    def node_id(self) -> Optional[int]:
+        """Get current node_id"""
+        return self._node_id
+
+    def get_uuid(self) -> str:
+        """Get the UUID of this node"""
+        return self.uuid
+
+    def get_active_node_ids(self) -> List[int]:
+        """Get all active node IDs via a lock-free RCU-style read"""
+        return list(self.current_node_id_set)
+
+    def is_node_active(self, node_id: int) -> bool:
+        """Check if a node_id is active via a lock-free RCU-style read"""
+        return node_id in self.current_node_id_set
+
+    def _listener_worker(self) -> None:
+        """Background thread that listens for node updates"""
+        backoff = 0.5
+        while self._running:
+            try:
+                # Create a separate connection for pub/sub
+                self._sub_client = self._get_client()
+
+                # Subscribe to flexkv_node_id_updated channel
+                pubsub = self._sub_client.pubsub()
+                pubsub.subscribe("flexkv_node_id_updated")
+
+                # Listen for messages with blocking read
+                for message in pubsub.listen():
+                    if not self._running:
+                        break
+
+                    if message["type"] == "message" and message["channel"] == "flexkv_node_id_updated":
+                        # Scan active nodes when we receive an update
+                        self.scan_active_nodes()
+
+                # Normal exit from listen loop
+                break
+
+            except Exception:
+                # Network/reconnection exception: exponential backoff
+                time.sleep(backoff)
+                backoff = min(backoff * 2, 5.0)
+            finally:
+                if self._sub_client:
+                    try:
+                        self._sub_client.close()
+                    except Exception:
+                        pass
+                    self._sub_client = None
+
+    def scan_active_nodes(self) -> None:
+        """Scan Redis for active node keys and update current_node_id_set
+
+        This method can be called externally to manually refresh the active nodes list.
+        It uses SCAN to avoid blocking Redis server.
+        """
+        if not self._client:
+            return
+
+        try:
+            new_active_nodes = set()
+            cursor = 0
+
+            while True:
+                cursor, keys = self._client.scan(cursor=cursor, match="node:*", count=100)
+
+                for key in keys:
+                    if key.startswith("node:"):
+                        try:
+                            node_id = int(key[5:])  # Remove "node:" prefix
+                            new_active_nodes.add(node_id)
+                        except (ValueError, IndexError):
+                            # Skip invalid node IDs
+                            continue
+
+                if cursor == 0:
+                    break
+
+            # Lock-free RCU-style switch: replace the set with one atomic assignment
+            self.current_node_id_set = new_active_nodes
+
+        except Exception:
+            # If scan fails, continue with current active nodes
+            pass
-
-    def delete_blockmeta_batch(self, node_id: int, hashes: Iterable[int], batch_size: int = 200) -> None:
-        self._c.delete_blockmeta_batch(int(node_id), list(int(h) for h in hashes), int(batch_size))


 class RedisMeta:
-    def __init__(self, host: str, port: int, password: str | None = None, local_ip: str = "127.0.0.1", decode_responses: bool = True) -> None:
+    def __init__(self, host: str, port: int, password: Optional[str] = None, local_ip: str = "127.0.0.1", decode_responses: bool = True) -> None:
         if _redis is None:  # pragma: no cover
             raise ImportError("redis-py is required: pip install redis")
         self.host = host
         self.port = int(port)
         self.local_ip = str(local_ip)
-        self._uuid = str(uuid1())
         self.db = 0
         self.password = password
         self.decode_responses = bool(decode_responses)
-        self._node_id: int | None = None
+        self._node_id: Optional[int] = None
+
+        # Initialization state management
+        self._init_lock = threading.Lock()
+        self._initialized = False
+        self._init_error: Optional[Exception] = None
+
+        # Create the RedisNodeInfo helper
+        self.nodeinfo = RedisNodeInfo(host, port, local_ip, password or "")
+        # Obtain the UUID through nodeinfo
+        self._uuid = self.nodeinfo.get_uuid()

     def _client(self):
         return _redis.Redis(host=self.host, port=self.port, db=self.db, password=self.password, decode_responses=self.decode_responses)

-    def init_meta(self) -> int:
-        r = self._client()
-        node_id = int(r.incr("global:node_id"))
-        r.hset(f"node:{node_id}", mapping={"ip": self.local_ip, "uuid": self._uuid})
-        self._node_id = node_id
-        return node_id
+    def init_meta(self) -> Optional[int]:
+        """Initialize Redis metadata. Thread-safe; the actual initialization runs only once per instance.
+
+        Returns:
+            Optional[int]: The registered node ID, or None if initialization fails
+
+        Raises:
+            Exception: The recorded initialization error, when called again after a failed attempt
+        """
+        with self._init_lock:
+            # Return early if already initialized
+            if self._initialized:
+                if self._init_error:
+                    raise self._init_error
+                return self._node_id
+
+            try:
+                # Connect RedisNodeInfo
+                if not self.nodeinfo.connect():
+                    raise RuntimeError("Failed to connect to Redis via RedisNodeInfo")
+
+                # Register the node
+                node_id = self.nodeinfo.register_node()
+                if node_id is None:
+                    raise RuntimeError("Failed to register node via RedisNodeInfo")
+
+                self._node_id = node_id
+                # During initialization, scan the active nodes once up front
+                self.nodeinfo.scan_active_nodes()
+
+                # Mark as initialized
+                self._initialized = True
+
+                return node_id
+
+            except Exception as e:
+                # Record the initialization error
+                self._init_error = e
+                return None

     def get_node_id(self) -> int:
         if self._node_id is None:
             raise RuntimeError("node_id is not registered yet. Call init_meta() first.")
         return int(self._node_id)
+
+    def is_initialized(self) -> bool:
+        """Check if RedisMeta has been initialized.
+
+        Returns:
+            bool: True if initialized, False otherwise
+        """
+        with self._init_lock:
+            return self._initialized
+
+    def get_init_error(self) -> Optional[Exception]:
+        """Get the initialization error if any.
+
+        Returns:
+            Optional[Exception]: The initialization error, or None if no error
+        """
+        with self._init_lock:
+            return self._init_error

     def get_redis_meta_channel(self, blocks_key: str = "blocks") -> "RedisMetaChannel":
         nid = self.get_node_id()
-        return RedisMetaChannel(self.host, int(self.port), int(nid), self.local_ip, str(blocks_key))
-
-    def unregister_node(self, node_id: int | None = None) -> None:
-        r = self._client()
-        nid = int(node_id) if node_id is not None else (self._node_id if self._node_id is not None else -1)
-        if nid >= 0:
-            r.delete(f"node:{nid}")
+        # Avoid passing string "None" when no password is set
+        pwd = "" if (self.password is None or str(self.password).lower() == "none") else str(self.password)
+        channel = RedisMetaChannel(self.host, int(self.port), int(nid), self.local_ip, str(blocks_key), pwd)
+        if not channel.connect():
+            raise RuntimeError("Failed to connect to Redis")
+        return channel
+
+    def unregister_node(self, node_id: Optional[int] = None) -> None:
+        # Unregister the node through RedisNodeInfo
+        if self.nodeinfo:
+            self.nodeinfo.unregister_node()
         self._node_id = None

     def get_uuid(self) -> str:
         return self._uuid
+
+    def get_active_node_ids(self) -> List[int]:
+        """Get the list of all active node IDs."""
+        if self.nodeinfo:
+            return self.nodeinfo.get_active_node_ids()
+        return []
+
+    def is_node_active(self, node_id: int) -> bool:
+        """Check whether the given node is active."""
+        if self.nodeinfo:
+            return self.nodeinfo.is_node_active(node_id)
+        return False

-    def add_node_ids(self, node_ids: Iterable[int | str]) -> int:
+    def add_node_ids(self, node_ids: Iterable[Union[int, str]]) -> int:
         # Append a list of pcfs file node ids to Redis list key pcfs:
         nid = self.get_node_id()
         values = [str(v) for v in node_ids]
@@ -188,7 +514,7 @@ def regist_buffer(self, mrs: Iterable[object]) -> int:
             pipe.execute()
         return processed

-    def unregist_buffer(self, buffer_ptr: int | str) -> bool:
+    def unregist_buffer(self, buffer_ptr: Union[int, str]) -> bool:
         """Unregister a previously registered RDMA memory region by
buffer_ptr. Looks up key buffer:: and deletes it if present. @@ -230,7 +556,7 @@ def get_node_meta(self, node_id: int) -> dict: data = r.hgetall(key) if not data: return {} - out: dict[str, int | str] = {} + out: Dict[str, Union[int, str]] = {} nid = data.get("node_id") out["node_id"] = int(nid) if nid is not None and nid != "" else int(node_id) out["addr"] = data.get("addr", "") @@ -247,34 +573,40 @@ def unregist_node_meta(self, node_id: int) -> bool: return bool(r.delete(key)) - def load_pcfs_file_nodeids(self) -> dict[int, list[int]]: + def load_pcfs_file_nodeids(self) -> Dict[int, List[int]]: """Load all PCFS file node IDs grouped by node id from Redis. + - Uses SCAN instead of KEYS to avoid blocking Redis server - Scans keys matching pattern "pcfs:*" (each is a list for a node's file node IDs) - For each key, fetches the list via LRANGE and converts elements to ints - Returns dict: { node_id: [file_nodeid, ...], ... } """ r = self._client() - result: dict[int, list[int]] = {} + result: Dict[int, List[int]] = {} try: - keys = r.keys("pcfs:*") + # Use SCAN instead of KEYS to avoid blocking + cursor = 0 + while True: + cursor, keys = r.scan(cursor=cursor, match="pcfs:*", count=100) + for key in keys: + try: + if not isinstance(key, str): + key = str(key) + if not key.startswith("pcfs:"): + continue + nid_part = key.split(":", 1)[1] + node_id = int(nid_part) + except Exception: + continue + try: + values = r.lrange(key, 0, -1) + file_nodeids = [int(v) for v in values] + except Exception: + file_nodeids = [] + result[node_id] = file_nodeids + + if cursor == 0: + break except Exception: return result - for key in keys: - try: - if not isinstance(key, str): - key = str(key) - if not key.startswith("pcfs:"): - continue - nid_part = key.split(":", 1)[1] - node_id = int(nid_part) - except Exception: - continue - try: - values = r.lrange(key, 0, -1) - file_nodeids = [int(v) for v in values] - except Exception: - file_nodeids = [] - result[node_id] = file_nodeids return 
result - diff --git a/flexkv/common/config.py b/flexkv/common/config.py index 81152ebf4a..b3b7486b9b 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -33,14 +33,13 @@ class CacheConfig: enable_remote: bool = False enable_kv_sharing: bool = False use_gds: bool = False - use_pinned_memory: bool = False index_accel: bool = False # kv cache layout configs gpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - cpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - ssd_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - remote_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE + cpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE + ssd_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE + remote_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE # mempool capacity configs num_cpu_blocks: int = 1000000 @@ -68,7 +67,6 @@ class CacheConfig: remote_config_custom: Optional[Dict[str, Any]] = None # KV sharing / distributed radix tree tunables - lt_pool_initial_capacity: int = 10000000 refresh_batch_size: int = 128 rebuild_interval_ms: int = 1000 idle_sleep_ms: int = 10 @@ -87,6 +85,17 @@ class CacheConfig: trace_max_file_size_mb: int = 100 trace_max_files: int = 5 trace_flush_interval_ms: int = 1000 - + #evict ratio evict_ratio: float = 0.0 + + def __post_init__(self): + layout_fields = ['gpu_kv_layout_type', + 'cpu_kv_layout_type', + 'ssd_kv_layout_type', + 'remote_kv_layout_type'] + for field in layout_fields: + value = getattr(self, field) + if isinstance(value, str): + setattr(self, field, KVCacheLayoutType[value.upper()]) + diff --git a/flexkv/common/debug.py b/flexkv/common/debug.py index 0f79cf869b..a522c5549a 100644 --- a/flexkv/common/debug.py +++ b/flexkv/common/debug.py @@ -16,14 +16,18 @@ def __init__(self, debug_level: str = "INFO"): self.enabled = False self.logger = logging.getLogger("FLEXKV") - formatter = logging.Formatter( - fmt=_FORMAT, - 
datefmt=_DATE_FORMAT, + has_console_handler = any( + isinstance(handler, logging.StreamHandler) + for handler in self.logger.handlers ) - - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setFormatter(formatter) - self.logger.addHandler(console_handler) + if not has_console_handler: + formatter = logging.Formatter( + fmt=_FORMAT, + datefmt=_DATE_FORMAT, + ) + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + self.logger.addHandler(console_handler) self.set_level(debug_level) diff --git a/flexkv/common/memory_handle.py b/flexkv/common/memory_handle.py index 5013308b40..12d84b612f 100644 --- a/flexkv/common/memory_handle.py +++ b/flexkv/common/memory_handle.py @@ -19,10 +19,17 @@ class TensorSharedHandle: rebuild_args: Tuple[Any] device: torch.device - def __init__(self, tensor: torch.Tensor): + def __init__(self, tensor: torch.Tensor, device_id: int = -1): if not tensor.is_cuda: raise ValueError("Only support CUDA tensor sharing") - self.rebuild_func, self.rebuild_args, self.device = self._export_tensor_handle(tensor) + self.rebuild_func, self.rebuild_args, tensor_device_id = self._export_tensor_handle(tensor) + if device_id == -1: + self.device = tensor_device_id + else: + self.device = torch.device(f"cuda:{device_id}") + tmp_list = list(self.rebuild_args) + tmp_list[6] = device_id + self.rebuild_args = tuple(tmp_list) def get_tensor(self) -> torch.Tensor: tensor = self._import_tensor_handle(self.rebuild_func, self.rebuild_args, self.device) @@ -40,10 +47,10 @@ def _export_tensor_handle(tensor: torch.Tensor) -> Tuple[Callable, Tuple[Any], t def _import_tensor_handle(rebuild_func: Callable, rebuild_args: Tuple[Any], device: torch.device) -> torch.Tensor: try: tensor = rebuild_func(*rebuild_args) - assert isinstance(tensor, torch.Tensor) if tensor.device != device: + flexkv_logger.warning(f"Tensor device {tensor.device} is not the same as the target device {device}") tensor = tensor.to(device) return 
tensor diff --git a/flexkv/common/tracer.py b/flexkv/common/tracer.py index 92668ae3f8..dff6b1ff3a 100644 --- a/flexkv/common/tracer.py +++ b/flexkv/common/tracer.py @@ -121,7 +121,6 @@ def trace_config(self, model_config, cache_config, gpu_layout=None): "ssd_kv_layout_type": str(cache_config.ssd_kv_layout_type), "remote_kv_layout_type": str(cache_config.remote_kv_layout_type), "use_gds": cache_config.use_gds, - "use_pinned_memory": cache_config.use_pinned_memory, "remote_cache_size_mode": cache_config.remote_cache_size_mode, "num_cpu_blocks": cache_config.num_cpu_blocks, "num_ssd_blocks": cache_config.num_ssd_blocks, diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 91229b7834..248d658b2f 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -1,7 +1,7 @@ import threading from dataclasses import dataclass, field from enum import Enum -from typing import ClassVar, List, Set, Dict +from typing import ClassVar, List, Set, Dict, Optional import numpy as np @@ -48,6 +48,7 @@ class TransferOp: dst_block_ids: np.ndarray layer_id: int = 0 layer_granularity: int = -1 + src_block_node_ids: Optional[np.ndarray] = None # this will change dynamically as transfer ops executed predecessors: Set[int] = field(default_factory=set) # this will keep the full info diff --git a/flexkv/common/type.py b/flexkv/common/type.py index 0a79a7f7c8..c7185a2798 100644 --- a/flexkv/common/type.py +++ b/flexkv/common/type.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from dataclasses import dataclass, field from typing import Optional import numpy as np diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py index 76f27f5b34..b437657c3e 100644 --- a/flexkv/integration/config.py +++ b/flexkv/integration/config.py @@ -31,7 +31,7 @@ class FlexKVConfig: dtype: torch.dtype = None use_mla: bool = False tp_size: int = 1 - + dp_size: int = 1 # log config num_log_interval_requests: int = 200 @@ -65,4 +65,5 @@ def 
post_init_from_vllm_config( self.head_size = vllm_config.model_config.get_head_size() self.dtype = vllm_config.model_config.dtype self.use_mla = vllm_config.model_config.is_deepseek_mla - self.tp_size = vllm_config.parallel_config.tensor_parallel_size \ No newline at end of file + self.tp_size = vllm_config.parallel_config.tensor_parallel_size + self.dp_size = vllm_config.parallel_config.data_parallel_size \ No newline at end of file diff --git a/flexkv/integration/vllm/README.md b/flexkv/integration/vllm/README.md deleted file mode 100644 index 136f8b8682..0000000000 --- a/flexkv/integration/vllm/README.md +++ /dev/null @@ -1,44 +0,0 @@ -Use flexkv on vllm v0.10.1.1 - -1. apply patch -```bash -cd vllm -git apply 0001-add-flexkv-connector.patch -``` - -2. offline test -```bash -python examples/offline_inference/prefix_caching_flexkv.py -``` - -3. online serving -```bash -# generate config -cat < ./flexkv_config.json -{ - "server_recv_port": "ipc:///tmp/flexkv_test", - "cache_config": { - "enable_cpu": true, - "num_cpu_blocks": 10240, - "use_pinned_memory": true - }, - "num_log_interval_requests": 200 -} -EOF -export FLEXKV_CONFIG_PATH="./flexkv_config.json" - -VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \ - --tensor-parallel-size 8 \ - --trust-remote-code \ - --port 30001 \ - --max-num-seqs 128 \ - --max-num-batched-tokens 8192 \ - --max_model_len 8192 \ - --max-seq-len-to-capture 8192 \ - --gpu-memory-utilization 0.8 \ - --enable-chunked-prefill \ - --enable-prefix-caching \ - --kv-transfer-config \ - '{"kv_connector":"FlexKVConnectorV1","kv_role":"kv_both"}' - -``` \ No newline at end of file diff --git a/flexkv/integration/vllm/vllm_v1_adapter.py b/flexkv/integration/vllm/vllm_v1_adapter.py index 951259474b..7bec7141fd 100644 --- a/flexkv/integration/vllm/vllm_v1_adapter.py +++ b/flexkv/integration/vllm/vllm_v1_adapter.py @@ -46,29 +46,29 @@ class FlexKVResponse: class FlexKVTask(ABC): task_id: int = 0 request: "Request" = 0 - + # 
slot mapping slot_mapping: Optional[np.ndarray] = None - + # timer match_start_time: float = 0 match_end_time: float = 0 task_launch_time: float = 0 task_finished_time: float = 0 - + @property def match_cost(self) -> float: return (self.match_end_time - self.match_start_time) - + @property def task_execute_cost(self) -> float: return (self.task_finished_time - self.task_launch_time) - + @property @abstractmethod def task_type(self) -> str: ... - + def __str__(self): return (f"FlexKVTask(task_id={self.task_id}, " f"request={self.request.request_id}, " @@ -80,11 +80,11 @@ def __str__(self): class FlexKVGetTask(FlexKVTask): num_computed_tokens: int num_new_matched_tokens: int - + @property def task_type(self) -> str: return "get" - + def __str__(self): return (f"FlexKVGetTask(task_id={self.task_id}, " f"request={self.request.request_id}, " @@ -93,16 +93,16 @@ def __str__(self): f"match_cost {self.match_cost*1000:.2f} ms, " f"task execute cost {self.task_execute_cost*1000:.2f} ms)") - + @dataclass(kw_only=True) class FlexKVPutTask(FlexKVTask): num_matched_tokens: int num_unmatched_tokens: int - + @property def task_type(self) -> str: return "put" - + def __str__(self): return (f"FlexKVPutTask(task_id={self.task_id}, " f"request={self.request.request_id}, " @@ -110,17 +110,19 @@ def __str__(self): f"num_unmatched_tokens={self.num_unmatched_tokens}, " f"match_cost {self.match_cost*1000:.2f} ms, " f"task execute cost {self.task_execute_cost*1000:.2f} ms)") - + class FlexKVSchedulerConnector: def __init__( self, - flexkv_config: FlexKVConfig + flexkv_config: FlexKVConfig, + dp_rank: int = 0, ): logger.info(f"Start init FlexKVSchedulerConnector with {flexkv_config}") self.flexkv_config = flexkv_config self.server_recv_port = flexkv_config.server_recv_port self.tp_size = flexkv_config.tp_size + self.dp_size = flexkv_config.dp_size self.block_size = flexkv_config.block_size self.model_config = ModelConfig( num_layers=flexkv_config.num_layers, @@ -129,6 +131,7 @@ def __init__( 
use_mla=flexkv_config.use_mla, dtype=flexkv_config.dtype, tp_size=flexkv_config.tp_size, + dp_size=flexkv_config.dp_size, ) if "tokens_per_block" in flexkv_config.cache_config: assert flexkv_config.cache_config.pop("tokens_per_block") == flexkv_config.block_size @@ -136,12 +139,13 @@ def __init__( tokens_per_block=flexkv_config.block_size, **flexkv_config.cache_config, ) - self.flexkv_manager = KVManager(model_config=self.model_config, + self.flexkv_manager = KVManager(model_config=self.model_config, cache_config=self.cache_config, - gpu_register_port=flexkv_config.server_recv_port) + gpu_register_port=flexkv_config.server_recv_port, + dp_client_id=dp_rank) self.flexkv_manager.start() # self.dp_client = KVDPClient(self.server_recv_port, self.model_config) - + # request_id -> task_id self.req_id_to_task_dict: dict[str, int] = {} # launched but unfinished tasks @@ -150,32 +154,32 @@ def __init__( # unlaunched tasks self.tasks_to_launch: dict[int, FlexKVTask] = {} self.tasks_to_cancel: dict[int, FlexKVTask] = {} - + self.flexkv_stats = FlexKVStats(flexkv_config.num_log_interval_requests) while not self.is_ready(): - logger.info(f"Waiting for flexkv init...") + logger.info("Waiting for flexkv init...") time.sleep(5) - logger.info(f"Finish init FlexKVSchedulerConnector") - + logger.info("Finish init FlexKVSchedulerConnector") + def is_ready( self, ) -> bool: " Ask flexkv is ready " return self.flexkv_manager.is_ready() - + def shutdown(self) -> None: self.flexkv_manager.shutdown() - + @property def dp_client_id(self) -> int: return self.flexkv_manager.dp_client_id - + #################### #### Get Method #### - #################### - + #################### + def get_num_new_matched_tokens( self, request: "Request", @@ -188,24 +192,24 @@ def get_num_new_matched_tokens( which means not need to transfer from flexkv. 
Returns: - tuple[int, bool]: A tuple containing two integer values representing the - number of new matched tokens and whether it is necessary + tuple[int, bool]: A tuple containing two integer values representing the + number of new matched tokens and whether it is necessary to get the new matched blocks from flexkv. """ - task_id, num_new_matched_tokens = self._get_match(request=request, + task_id, num_new_matched_tokens = self._get_match(request=request, num_computed_tokens=num_computed_tokens) - self.flexkv_stats.record_get(num_prompt_tokens=request.num_prompt_tokens, + self.flexkv_stats.record_get(num_prompt_tokens=request.num_tokens, num_gpu_matched_tokens=num_computed_tokens, num_flexkv_matched_tokens=num_new_matched_tokens) - if not self._need_to_get(num_prompt_tokens=request.num_prompt_tokens, + if not self._need_to_get(num_prompt_tokens=request.num_tokens, num_computed_tokens=num_computed_tokens, num_new_matched_tokens=num_new_matched_tokens): return 0, False - + return num_new_matched_tokens, True - - + + def _get_match( self, request: "Request", @@ -222,22 +226,23 @@ def _get_match( the task_id and number of new matched tokens. 
""" match_start_time = time.perf_counter() - num_tokens_to_get = (cdiv(request.num_prompt_tokens, self.block_size)-1)*self.block_size - token_ids = request.prompt_token_ids[:num_tokens_to_get] - - assert num_computed_tokens <= num_tokens_to_get + num_tokens_to_get = (request.num_tokens//self.block_size)*self.block_size + token_ids = request.all_token_ids[:num_tokens_to_get] + + assert num_computed_tokens <= num_tokens_to_get, ( + f"{num_computed_tokens=} must less equal to {num_tokens_to_get=}") assert num_computed_tokens % self.block_size == 0 - + if num_tokens_to_get == num_computed_tokens: return -1, 0 - + np_token_ids = np.array(token_ids) np_token_mask = np.ones_like(np_token_ids, dtype=bool) np_token_mask[:num_computed_tokens] = False task_id, matched_mask = self.flexkv_manager.get_match(token_ids=np_token_ids, token_mask=np_token_mask) num_new_matched_tokens = matched_mask.sum().item() - + # Auto cancel if not call update_state_after_alloc() match_end_time = time.perf_counter() logger.debug(f"Get match cost {(match_end_time-match_start_time)*1000:.2f} ms.") @@ -249,11 +254,11 @@ def _get_match( num_new_matched_tokens=num_new_matched_tokens, match_start_time=match_start_time, match_end_time=match_end_time) - + logger.debug(f"FlexKV create get task: {self.tasks_to_cancel[task_id]}") - + return task_id, num_new_matched_tokens - + def _need_to_get( self, num_prompt_tokens: int, @@ -264,21 +269,21 @@ def _need_to_get( Determine whether it is necessary to get the new matched blocks from flexkv. """ return num_new_matched_tokens > 0 - + def update_state_after_alloc( self, request: "Request", - blocks: "KVCacheBlocks", + blocks: "KVCacheBlocks", num_new_matched_tokens: int, ) -> None: """ Compute slot mapping and prepare to launch task. Only call after get_num_new_matched_tokens(). - + Args: request: Request to get. blocks: All blocks of the request. 
- num_new_matched_tokens: Number of new matched tokens returned by + num_new_matched_tokens: Number of new matched tokens returned by get_num_new_matched_tokens(). Returns: @@ -290,27 +295,27 @@ def update_state_after_alloc( task_id = self.req_id_to_task_dict[request.request_id] task: FlexKVGetTask = self.tasks_to_cancel.pop(task_id) self.tasks_to_launch[task_id] = task - + # compute slot_mapping num_computed_blocks = task.num_computed_tokens // self.block_size num_blocks_to_get = num_new_matched_tokens // self.block_size all_block_ids = blocks.get_block_ids()[0] block_ids_to_get = all_block_ids[num_computed_blocks:num_computed_blocks+num_blocks_to_get] task.slot_mapping = np.array(block_ids_to_get).repeat(self.block_size)*self.block_size - + def wait_for_all_get_tasks(self) -> list[FlexKVResponse]: """ Blocking wait for all get tasks. - + Returns: list[FlexKVResponse]: Responses of all get tasks. """ return self._blocking_waiting_for_tasks(self.get_tasks) - + #################### #### Put Method #### #################### - + def request_finished( self, request: "Request", @@ -327,34 +332,34 @@ def request_finished( # Task not finished, can't free blocks if request.request_id in self.req_id_to_task_dict: return True - + # Abnormal finished, don't put if not (request.is_finished() and request.get_finished_reason() < 2): return False - + task_id, num_matched_tokens, num_unmatched_tokens = self._put_match(request=request) - + self.flexkv_stats.record_put(num_all_tokens=request.num_tokens, num_unmatched_tokens=num_unmatched_tokens) - + if not self._need_to_put(num_all_tokens=request.num_tokens, num_matched_tokens=num_matched_tokens, num_unmatched_tokens=num_unmatched_tokens): return False - + # prepare to launch task task: FlexKVPutTask = self.tasks_to_cancel.pop(task_id) self.tasks_to_launch[task_id] = task - + # compute slot mapping # num_blocks_to_put = (num_matched_tokens+num_unmatched_tokens) // self.block_size num_matched_blocks = num_matched_tokens // 
self.block_size num_unmatched_tokens = num_unmatched_tokens // self.block_size block_ids_to_put = block_ids[num_matched_blocks:num_matched_blocks+num_unmatched_tokens] task.slot_mapping = np.array(block_ids_to_put).repeat(self.block_size)*self.block_size - + return True - + def _put_match( self, request: "Request" @@ -368,22 +373,22 @@ def _put_match( the task_id, number of matched tokens and number of unmatched tokens. """ match_start_time = time.perf_counter() - num_tokens_to_put = (cdiv(request.num_tokens, self.block_size)-1)*self.block_size + num_tokens_to_put = (cdiv(request.num_tokens+1, self.block_size)-1)*self.block_size token_ids = request.all_token_ids[:num_tokens_to_put] if num_tokens_to_put == 0: return -1, 0, 0 - + np_token_ids = np.array(token_ids) task_id, unmatched_mask = self.flexkv_manager.put_match(token_ids=np_token_ids) - + num_unmatched_tokens = unmatched_mask.sum().item() num_matched_tokens = num_tokens_to_put - num_unmatched_tokens - + # Auto cancel if not need to put. match_end_time = time.perf_counter() logger.debug(f"Put match cost {(match_end_time-match_start_time)*1000:.2f} ms.") - + if num_unmatched_tokens > 0: self.req_id_to_task_dict[request.request_id] = task_id self.tasks_to_cancel[task_id] = FlexKVPutTask(task_id=task_id, @@ -393,9 +398,9 @@ def _put_match( match_start_time=match_start_time, match_end_time=match_end_time) logger.debug(f"FlexKV create put task: {self.tasks_to_cancel[task_id]}") - + return task_id, num_matched_tokens, num_unmatched_tokens - + def _need_to_put( self, num_all_tokens: int, @@ -406,23 +411,23 @@ def _need_to_put( Determine whether it is necessary to put the unmatched blocks from flexkv. """ return num_unmatched_tokens > 0 - + def wait_for_all_put_tasks(self) -> list[FlexKVResponse]: """ Blocking wait for all put tasks. - + Returns: list[FlexKVResponse]: Responses of all put tasks. 
""" return self._blocking_waiting_for_tasks(self.put_tasks) - + ####################### #### Common Method #### ####################### - + def cancel_tasks(self) -> None: """ - Cancel tasks in self.cancel_tasks. + Cancel tasks in self.cancel_tasks. Call before launch_tasks() to delete req_id in self.req_id_to_task_dict """ # TODO: check if this method is inproc. @@ -433,7 +438,7 @@ def cancel_tasks(self) -> None: logger.info(f"FlexKV Cancel task: {task}") self.flexkv_manager.cancel(task_ids=list(self.tasks_to_cancel.keys())) self.tasks_to_cancel.clear() - + def launch_tasks(self) -> None: """ Launch tasks in self.unlaunched_tasks @@ -443,7 +448,7 @@ def launch_tasks(self) -> None: task_launch_time = time.perf_counter() task_ids: list[int] = [] slot_mappings: list[np.ndarray] = [] - + for task_id, task in self.tasks_to_launch.items(): logger.info(f"FlexKV Launch task: {task}") task.task_launch_time = task_launch_time @@ -456,11 +461,11 @@ def launch_tasks(self) -> None: self.flexkv_manager.launch(task_ids=task_ids, slot_mappings=slot_mappings) self.tasks_to_launch.clear() - + def query_finished_task(self) -> tuple[set[str], set[str]]: """ Get response of finished task. - + Returns: list[FlexKVResponse]: Responses of finished tasks. """ @@ -493,17 +498,17 @@ def query_finished_task(self) -> tuple[set[str], set[str]]: # request=task.request, success=success)) self.flexkv_stats.record_faild(num_failed_requests=num_failed_tasks) return finished_sending, finished_recving - + def _blocking_waiting_for_tasks(self, task_dict: dict[int, FlexKVTask]) -> list[FlexKVResponse]: """ Blocking wait for tasks in task_dict. - + Returns: list[FlexKVResponse]: Responses of all tasks in task_dict. 
""" if len(task_dict) == 0: return [] - + task_ids = list(task_dict.keys()) response_from_manager = self.flexkv_manager.wait(task_ids=task_ids) task_finished_time = time.perf_counter() @@ -516,24 +521,25 @@ def _blocking_waiting_for_tasks(self, task_dict: dict[int, FlexKVTask]) -> list[ logger.info(f"{task} finished successfully.") else: logger.error(f"{task} failed, status: {response.status}.") - responses_to_return.append(FlexKVResponse(task_id=task_id, task_type=task.task_type, + responses_to_return.append(FlexKVResponse(task_id=task_id, task_type=task.task_type, request=task.request, success=success)) return responses_to_return - - + + class FlexKVWorkerConnector: def __init__( self, flexkv_config: FlexKVConfig, + dp_client_id: int, ): - current_device_id = torch.cuda.current_device() + current_device_id = torch.cuda.current_device() + dp_client_id * flexkv_config.tp_size self.flexkv_config = flexkv_config - logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.server_recv_port}") - self.tp_client = KVTPClient(flexkv_config.server_recv_port, 0, current_device_id) - logger.info(f"Finish init FlexKVWorkerConnector") + logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.server_recv_port}, dp_client_id: {dp_client_id}") + self.tp_client = KVTPClient(flexkv_config.server_recv_port, dp_client_id, current_device_id) + logger.info("Finish init FlexKVWorkerConnector") def register_to_server(self, kv_caches: dict[str, torch.Tensor]): - logger.info(f"Start register kv_caches") + logger.info("Start register kv_caches") gpu_blocks = list(kv_caches.values()) num_layer = len(kv_caches) if self.flexkv_config.use_mla: @@ -560,19 +566,20 @@ def register_to_server(self, kv_caches: dict[str, torch.Tensor]): is_mla=self.flexkv_config.use_mla, ) self.tp_client.register_to_server(gpu_blocks, gpu_layout) - logger.info(f"Finish register kv_caches") + logger.info("Finish register kv_caches") + - class FlexKVConnectorV1Impl: def __init__(self, vllm_config: 
"VllmConfig", role: "KVConnectorRole"): self.role = role flexkv_config = FlexKVConfig.from_env() flexkv_config.post_init_from_vllm_config(vllm_config) + dp_rank = vllm_config.parallel_config.data_parallel_rank if role == KVConnectorRole.SCHEDULER: - self.connector = FlexKVSchedulerConnector(flexkv_config) + self.connector = FlexKVSchedulerConnector(flexkv_config, dp_rank) elif role == KVConnectorRole.WORKER: - self.connector = FlexKVWorkerConnector(flexkv_config) + self.connector = FlexKVWorkerConnector(flexkv_config, dp_rank) else: raise ValueError(f"Unrecognized KVConnectorRole: {role}.") @@ -595,9 +602,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs: additional arguments for the load operation Note: - The number of elements in kv_caches and layer_names should be + The number of elements in kv_caches and layer_names should be the same. - + """ pass @@ -606,7 +613,7 @@ def wait_for_layer_load(self, layer_name: str) -> None: Block until the KV for a specific layer is loaded into vLLM's paged buffer. This is called from within attention layer to ensure async copying from start_load_kv is complete. - + This interface will be useful for layer-by-layer pipelining. Args: @@ -617,13 +624,13 @@ def wait_for_layer_load(self, layer_name: str) -> None: def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs) -> None: """ - Start saving the a layer of KV cache from vLLM's paged buffer + Start saving the a layer of KV cache from vLLM's paged buffer to the connector. This is called from within attention layer to enable async copying during execution. Args: layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current + kv_layer (torch.Tensor): the paged KV buffer of the current layer in vLLM. attn_metadata (AttentionMetadata): the attention metadata. **kwargs: additional arguments for the save operation. 
@@ -677,14 +684,14 @@ def get_num_new_matched_tokens( """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. - + Args: request (Request): the request object. num_computed_tokens (int): the number of locally computed tokens for this request Returns: - the number of tokens that can be loaded from the + the number of tokens that can be loaded from the external KV cache beyond what is already computed. """ return self.connector.get_num_new_matched_tokens( @@ -742,4 +749,4 @@ def request_finished( Optional KVTransferParams to be included in the request outputs returned by the engine. """ - return self.connector.request_finished(request, block_ids), None \ No newline at end of file + return self.connector.request_finished(request, block_ids), None diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index 8e85ee9959..cac6355da7 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -37,18 +37,22 @@ def __init__(self, flexkv_logger.info(f"{cache_config = }") self.model_config = model_config self.cache_config = cache_config - self.gpu_register_port = gpu_register_port - self.server_recv_port = server_recv_port + self.gpu_register_port = gpu_register_port if gpu_register_port is not None else "ipc:///tmp/flexkv_test_gpu_register" + self.server_recv_port = server_recv_port if server_recv_port is not None else "ipc:///tmp/flexkv_test_server" self.server_client_mode = model_config.dp_size > 1 self.dp_client_id = dp_client_id flexkv_logger.info(f"server_client_mode: {self.server_client_mode}") if self.server_client_mode: # server should only be created once but kvmanager will init in every dp rank. 
if dp_client_id == 0: - self.server_handle = KVServer.create_server(model_config, - cache_config, - gpu_register_port, - server_recv_port) + # You can control child process environment variables here + # Example: child_env = {"CUDA_VISIBLE_DEVICES": "0"} + # Example: inherit_env = False # to not inherit parent env + self.server_handle = KVServer.create_server(model_config=model_config, + cache_config=cache_config, + gpu_register_port=self.gpu_register_port, + server_recv_port=self.server_recv_port, + inherit_env=False) else: self.server_handle = None diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index 3d661a66bf..aeea4bc5b2 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -56,6 +56,7 @@ class KVTask: graph: TransferOpGraph return_mask: np.ndarray callback: Optional[Callable] + op_callback_dict: Dict[int, Callable] def is_completed(self) -> bool: return self.status in [TaskStatus.COMPLETED, TaskStatus.CANCELLED, TaskStatus.FAILED] @@ -127,7 +128,7 @@ def create_get_task(self, ) -> None: if task_id in self.tasks: raise ValueError(f"Task ID {task_id} already exists") - graph, return_mask, callback, task_end_op_id = self.cache_engine.get(task_id, + graph, return_mask, callback, op_callback_dict, task_end_op_id = self.cache_engine.get(task_id, token_ids, token_mask, slot_mapping, @@ -146,7 +147,8 @@ dp_id=dp_id, graph=graph, return_mask=return_mask, - callback=callback) + callback=callback, + op_callback_dict=op_callback_dict) self.graph_to_task[graph.graph_id] = task_id @@ -160,7 +162,7 @@ def create_put_task(self, ) -> None: if task_id in self.tasks: raise ValueError(f"Task ID {task_id} already exists") - graph, return_mask, callback, task_end_op_id = self.cache_engine.put(task_id, + graph, return_mask, callback, op_callback_dict, task_end_op_id = self.cache_engine.put(task_id, token_ids, token_mask, slot_mapping, @@ -178,7 +180,8 @@ dp_id=dp_id, graph=graph, return_mask=return_mask, -
callback=callback, + op_callback_dict=op_callback_dict) self.graph_to_task[graph.graph_id] = task_id def _launch_task(self, task_id: int) -> None: @@ -204,6 +207,8 @@ def _update_tasks(self, timeout: float = 0.001) -> None: self._mark_completed(task_id) elif completed_op_id == task.task_end_op_id: self.tasks[task_id].task_end_op_finished = True + if completed_op_id in task.op_callback_dict: + task.op_callback_dict[completed_op_id]() def _cancel_task(self, task_id: int) -> None: task = self.tasks[task_id] diff --git a/flexkv/server/client.py b/flexkv/server/client.py index 1643af98a3..0376035157 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -236,10 +236,7 @@ def register_to_server( handles = [] for _, tensor in enumerate(kv_caches): - if tensor.device.index != self.device_id: - raise ValueError(f"All tensors must be on specified device: {self.device_id}") - - handle = TensorSharedHandle(tensor) + handle = TensorSharedHandle(tensor, self.device_id) handles.append(handle) register_req = RegisterTPClientRequest( diff --git a/flexkv/server/server.py b/flexkv/server/server.py index daf25ff7ca..65c4f5892b 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -1,5 +1,5 @@ from collections import deque -from typing import Optional, Dict, List +from typing import Optional, Dict, List, Union import tempfile import zmq @@ -10,6 +10,8 @@ import multiprocessing as mp import socket import os +import subprocess +import textwrap from flexkv.common.config import CacheConfig, ModelConfig from flexkv.common.debug import flexkv_logger @@ -64,11 +66,13 @@ def register_dp_client( context: zmq.Context, client_recv_port: str, tp_size: int = 1, + client_id: Optional[int] = None, ) -> int: - if len(self.free_client_ids) == 0: - flexkv_logger.error("Client full. DP client registration failed.") - raise - client_id = self.free_client_ids.popleft() + if client_id is None: + if len(self.free_client_ids) == 0: + flexkv_logger.error("Client full. 
DP client registration failed.") + raise + client_id = self.free_client_ids.popleft() send_to_client = get_zmq_socket( context, zmq.SocketType.PUSH, client_recv_port, False ) @@ -81,7 +85,7 @@ def register_dp_client( flexkv_logger.info(f"DP client {client_id} registered successfully") return client_id - + def delete_dp_client(self, client_id: int) -> None: if client_id not in self.client_dict: flexkv_logger.error(f"DP client: {client_id} dosen't exist. Delete failed.") @@ -103,19 +107,8 @@ def is_dp_client_ready(self, dp_client_id: int) -> bool: return False class KVServerHandle: - def __init__(self, process: mp.Process): + def __init__(self, process: Union[mp.Process, 'subprocess.Popen']): self.process = process - - def shutdown(self) -> None: - self.process.join(timeout=5) - if self.process.is_alive(): - flexkv_logger.info("force terminate the server process") - self.process.terminate() - self.process.join() - - def __del__(self) -> None: - if self.process.is_alive(): - self.shutdown() class KVServer: def __init__( @@ -137,7 +130,7 @@ def __init__( self.req_counter = 0 self._is_ready = False self._running = False - + # Request handler dispatch table self.request_handlers = { StartRequest: self._handle_start_request, @@ -162,35 +155,71 @@ def start_server(self) -> None: self._is_ready = True @staticmethod - def _server_process(model_config: ModelConfig, + def _server_process(model_config: ModelConfig, cache_config: CacheConfig, gpu_register_port: str, server_recv_port: str) -> None: - + server = KVServer(model_config, cache_config, gpu_register_port, server_recv_port) server.run() - + @classmethod def create_server(cls, model_config: ModelConfig, cache_config: CacheConfig, gpu_register_port: str, - server_recv_port: Optional[str] = None) -> 'KVServerHandle': - #if server_recv_port is None: - # server_recv_port = f"ipc:///tmp/flexkv_srv_{uuid.uuid4().hex[:8]}" #TODO unify this - + server_recv_port: Optional[str] = None, + child_env: Optional[dict] = None, + 
inherit_env: bool = True) -> 'KVServerHandle': + # Set spawn method for CUDA compatibility - try: + with contextlib.suppress(RuntimeError): mp.set_start_method("spawn") - except RuntimeError: - # If already set, just continue - pass - process = mp.Process(target=cls._server_process, - args=(model_config, cache_config, gpu_register_port, server_recv_port)) - process.start() - flexkv_logger.info(f"KVServer process started, PID: {process.pid}") - return KVServerHandle(process) + # Prepare environment variables for child process + if child_env is not None or not inherit_env: + # Use subprocess for better environment control + import subprocess + import pickle + import sys + + # Prepare environment + if inherit_env: + env = os.environ.copy() + if child_env: + env.update(child_env) + else: + env = child_env or {} + + # Serialize arguments + args_data = pickle.dumps((model_config, cache_config, gpu_register_port, server_recv_port)) + + # Start subprocess + flexkv_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + server_script = textwrap.dedent(f''' + import pickle + import sys + sys.path.insert(0, "{flexkv_root}") + from flexkv.server.server import KVServer + + args_data = {args_data!r} + model_config, cache_config, gpu_register_port, server_recv_port = pickle.loads(args_data) + server = KVServer(model_config, cache_config, gpu_register_port, server_recv_port) + server.run() + ''').strip() + process = subprocess.Popen([ + sys.executable, '-c', server_script + ], env=env) + + flexkv_logger.info(f"KVServer subprocess started, PID: {process.pid}") + return KVServerHandle(process) + else: + # Use multiprocessing as before + process = mp.Process(target=cls._server_process, + args=(model_config, cache_config, gpu_register_port, server_recv_port)) + process.start() + flexkv_logger.info(f"KVServer process started, PID: {process.pid}") + return KVServerHandle(process) def run(self) -> None: """Main server loop""" @@ -211,18 +240,18 @@ def run(self) -> None: try: 
flexkv_logger.info("start waiting for req") req = self.recv_from_client.recv_pyobj() - flexkv_logger.info(f"recv req: {type(req)}") + flexkv_logger.info(f"recv req: {type(req)} from DP client {req.dp_client_id}") # Use dispatch table for request handling req_type = type(req) handler = self.request_handlers.get(req_type) - + if handler is None: raise TypeError(f"Unrecognized RequestType: {req_type}") - + # Call the corresponding handler method handler(req) - + # If the request is a shutdown request, exit the loop if req_type == ShutdownRequest: break @@ -246,7 +275,7 @@ def _verify_model_config( return True # Request Handler Methods - + def _handle_start_request(self, req: StartRequest) -> None: """Handle start request""" flexkv_logger.info(f"Received start request from DP client {req.dp_client_id}") @@ -257,9 +286,9 @@ def _handle_register_dp_client_request(self, req: RegisterDPClientRequest) -> No client_id = self.client_manager.register_dp_client( self.context, req.client_recv_port, - req.model_config.tp_size + req.model_config.tp_size, + req.dp_client_id, ) - flexkv_logger.info(f"DP client {client_id} registered successfully") def _handle_is_ready_request(self, req: IsReadyRequest) -> None: """Handle ready state check request""" @@ -317,7 +346,7 @@ def _handle_put_match_request(self, req: PutMatchRequest) -> None: def _handle_launch_task_request(self, req: LaunchTaskRequest) -> None: """Handle LaunchTask request""" self.kv_task_engine.launch_tasks(req.task_ids, req.slot_mappings) - + def _handle_cancel_task_request(self, req: CancelTaskRequest) -> None: """Handle CancelTask request""" self.kv_task_engine.cancel_tasks(req.task_ids) @@ -381,7 +410,6 @@ def __del__(self) -> None: enable_ssd=False, enable_remote=False, use_gds=False, - use_pinned_memory=True, tokens_per_block=tokens_per_block, num_cpu_blocks=num_cpu_blocks,) diff --git a/flexkv/storage/allocator.py b/flexkv/storage/allocator.py index 7cd38156e0..ed683e6505 100644 --- a/flexkv/storage/allocator.py 
+++ b/flexkv/storage/allocator.py
@@ -95,7 +95,6 @@ def allocate(cls,
                  layout: KVCacheLayout,
                  dtype: torch.dtype,
                  **kwargs: Any) -> StorageHandle:
-        pin_memory = kwargs.get("pin_memory", True)
         total_size = layout.get_total_elements()
         # although the kv layout may have multiple dimensions, we only have one-dim CPU tensor
         flexkv_logger.info(f"CPU allocate total_size: {2 * total_size/1024/1024/1024} GB")
@@ -103,7 +102,7 @@ def allocate(cls,
             size=(total_size,),
             dtype=dtype,
             device="cpu",
-            pin_memory=pin_memory,
+            pin_memory=False,
         )
         return StorageHandle(
             handle_type=AccessHandleType.TENSOR,
diff --git a/flexkv/storage/storage_engine.py b/flexkv/storage/storage_engine.py
index 0762b0062d..0d48fe6230 100644
--- a/flexkv/storage/storage_engine.py
+++ b/flexkv/storage/storage_engine.py
@@ -35,7 +35,6 @@ def __init__(self,
             device_type=DeviceType.CPU,
             layout=self._cpu_layout,
             dtype=self._model_config.dtype,
-            pin_memory=self._cache_config.use_pinned_memory,
         )
         if self._cache_config.enable_ssd:
             if not self._cache_config.ssd_kv_layout_type == self._cpu_layout.type:
diff --git a/pyproject.toml b/pyproject.toml
index db54deba99..7568848bc9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,9 @@
 [build-system]
 requires = [
-    "setuptools>=61",
-    "torch==2.3.0"
+    #"setuptools>=61",
+    #"torch==2.3.0"
+    "setuptools>=40.0.0",
+    "torch>=1.10.0"
 ]
 build-backend = "setuptools.build_meta"
diff --git a/requirements.txt b/requirements.txt
index e857eb9cbe..4c1ec7be69 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
-setuptools>=61
-torch>=2.3.0
-nvtx==0.2.11
+setuptools>=40.0.0
+torch>=1.10.0
+# nvtx==0.2.11  # Skip nvtx for now due to compatibility issues
 Cython>=3.0.10
-pytest==8.4.0
-pytest-benchmark==5.1.0
+pytest>=6.0.0
+pytest-benchmark>=3.0.0
 expiring-dict==1.1.2
diff --git a/setup.py b/setup.py
index d11ffc5143..af8996b482 100755
--- a/setup.py
+++ b/setup.py
@@ -7,6 +7,9 @@
 from setuptools.command.build_ext import build_ext
 from torch.utils import cpp_extension

+def get_version():
+    with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f:
+        return f.read().strip()

 build_dir = "build"
 os.makedirs(build_dir, exist_ok=True)
@@ -21,11 +24,15 @@
 # Define C++ extensions
 cpp_sources = [
     "csrc/bindings.cpp",
-    "csrc/transfer.cu",
+    # "csrc/transfer.cu",  # Skip CUDA file for now
     "csrc/hash.cpp",
     "csrc/tp_transfer_thread_group.cpp",
     "csrc/transfer_ssd.cpp",
     "csrc/radix_tree.cpp",
+    "csrc/distributed_radix_tree.cpp",
+    "csrc/local_radix_tree.cpp",
+    "csrc/redis_meta_channel.cpp",
+    "csrc/lease_meta_mempool.cpp",
 ]

 hpp_sources = [
@@ -35,9 +42,10 @@
     "csrc/radix_tree.h",
 ]

-extra_link_args = ["-lcuda", "-lxxhash", "-lpthread", "-lrt", "-luring"]
+#extra_link_args = ["-lcuda", "-lxxhash", "-lpthread", "-lrt", "-luring"]
+extra_link_args = ["-lxxhash", "-lpthread", "-lrt", "-luring", "-lhiredis"]
 extra_compile_args = ["-std=c++17"]
-include_dirs = [os.path.join(build_dir, "include")]
+include_dirs = [os.path.abspath(os.path.join(build_dir, "include"))]

 # Add rpath to find libraries at runtime
 lib_dir = os.path.join(build_dir, "lib")
@@ -130,7 +138,7 @@ def copy_shared_libraries(self):
 setup(
     name="flexkv",
     description="A global KV-Cache manager for LLM inference",
-    version="0.1.0",
+    version=get_version(),
     packages=find_packages(exclude=("benchmarks", "csrc", "examples", "tests")),
     package_data={
         "flexkv": ["*.so", "lib/*.so", "lib/*.so.*"],
@@ -145,5 +153,6 @@ def copy_shared_libraries(self):
         build_temp=os.path.join(build_dir, "temp"),  # Temporary build files
     )
 },
-    python_requires=">=3.8",
+    #python_requires=">=3.8",
+    python_requires=">=3.6",
 )
diff --git a/setup_env.sh b/setup_env.sh
new file mode 100755
index 0000000000..61071b47ba
--- /dev/null
+++ b/setup_env.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+# FlexKV environment setup script
+# Sets the environment variables required to run FlexKV
+
+# Resolve the directory containing this script (the repository root)
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+FLEXKV_ROOT="$SCRIPT_DIR"
+
+# Set the library search path
+export LD_LIBRARY_PATH="$FLEXKV_ROOT/build/lib:$FLEXKV_ROOT/flexkv/lib:$LD_LIBRARY_PATH"
+
+# Set the Python path
+export PYTHONPATH="$FLEXKV_ROOT:$PYTHONPATH"
+
+echo "FlexKV environment variables set:"
+echo "  LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
+echo "  PYTHONPATH: $PYTHONPATH"
+echo ""
+echo "You can now run the FlexKV test programs!"
diff --git a/tests/replay_from_tracer.py b/tests/replay_from_tracer.py
index 3ddc0ce810..fad6a20ea9 100644
--- a/tests/replay_from_tracer.py
+++ b/tests/replay_from_tracer.py
@@ -113,7 +113,6 @@ def parse_config_event(self, event: Dict[str, Any]):
             ssd_kv_layout_type=self._parse_layout_type(cache_config_data['ssd_kv_layout_type']),
             remote_kv_layout_type=self._parse_layout_type(cache_config_data['remote_kv_layout_type']),
             use_gds=cache_config_data['use_gds'],
-            use_pinned_memory=False,#cache_config_data['use_pinned_memory'], # for local test
             remote_cache_size_mode=cache_config_data['remote_cache_size_mode'],
             num_cpu_blocks=cache_config_data['num_cpu_blocks'],
             num_ssd_blocks=cache_config_data['num_ssd_blocks'],
diff --git a/tests/test_cache_engine.py b/tests/test_cache_engine.py
index e224d4305b..70a12ffeb0 100644
--- a/tests/test_cache_engine.py
+++ b/tests/test_cache_engine.py
@@ -176,7 +176,8 @@ def test_take_and_recycle(cache_engine: CacheEngine):
     cache_engine.lock_node(radixnode)
     with pytest.raises(NotEnoughSpaceError):
         cache_engine.take(num_total_blocks, protected_node=radixnode, strict=True)
-    cache_engine.cleanup(radixnode, radixnode.size())
+    cache_engine.unlock(radixnode)
+    cache_engine.set_ready(radixnode, True, radixnode.size())

     physical_blocks = cache_engine.take(num_total_blocks, protected_node=None, strict=True)
     assert physical_blocks.shape == (num_total_blocks, )
@@ -227,11 +228,14 @@ def test_cleanup(cache_engine: CacheEngine):
     assert cache_engine.index.total_unready_blocks() == total_insert_blocks
     assert cache_engine.index.total_ready_blocks() == 0

-    cache_engine.cleanup(radixnode2, radixnode2_size)
+    cache_engine.unlock(radixnode2)
+    cache_engine.set_ready(radixnode2, True, radixnode2_size)
     assert cache_engine.index.total_ready_blocks() == num_insert_blocks2
-    cache_engine.cleanup(radixnode1, radixnode1_size)
+    cache_engine.unlock(radixnode1)
+    cache_engine.set_ready(radixnode1, True, radixnode1_size)
     assert cache_engine.index.total_ready_blocks() == num_insert_blocks1 + num_insert_blocks2
-    cache_engine.cleanup(radixnode0, radixnode0_size)
+    cache_engine.unlock(radixnode0)
+    cache_engine.set_ready(radixnode0, True, radixnode0_size)
     assert cache_engine.index.total_ready_blocks() == num_insert_blocks0 + num_insert_blocks1 + num_insert_blocks2
diff --git a/tests/test_cache_engine_accel.py b/tests/test_cache_engine_accel.py
index 15fef43ec3..3c0e2b3cbe 100644
--- a/tests/test_cache_engine_accel.py
+++ b/tests/test_cache_engine_accel.py
@@ -172,7 +172,8 @@ def test_take_and_recycle(cache_engine: CacheEngineAccel):
     cache_engine.lock_node(radixnode)
     with pytest.raises(NotEnoughSpaceError):
         cache_engine.take(num_total_blocks, protected_node=radixnode, strict=True)
-    cache_engine.cleanup(radixnode, radixnode.size())
+    cache_engine.unlock(radixnode)
+    cache_engine.set_ready(radixnode, True, radixnode.size())

    physical_blocks = cache_engine.take(num_total_blocks, protected_node=None, strict=True)
    assert physical_blocks.shape == (num_total_blocks, )
@@ -222,11 +223,14 @@ def test_cleanup(cache_engine: CacheEngineAccel):
     assert cache_engine.index.total_unready_blocks() == total_insert_blocks
     assert cache_engine.index.total_ready_blocks() == 0

-    cache_engine.cleanup(radixnode2, radixnode2_size)
+    cache_engine.unlock(radixnode2)
+    cache_engine.set_ready(radixnode2, True, radixnode2_size)
     assert cache_engine.index.total_ready_blocks() == num_insert_blocks2
-    cache_engine.cleanup(radixnode1, radixnode1_size)
+    cache_engine.unlock(radixnode1)
+    cache_engine.set_ready(radixnode1, True, radixnode1_size)
     assert cache_engine.index.total_ready_blocks() == num_insert_blocks1 + num_insert_blocks2
-    cache_engine.cleanup(radixnode0, radixnode0_size)
+    cache_engine.unlock(radixnode0)
+    cache_engine.set_ready(radixnode0, True, radixnode0_size)
     assert cache_engine.index.total_ready_blocks() == num_insert_blocks0 + num_insert_blocks1 + num_insert_blocks2
diff --git a/tests/test_dis_radixtree_basic.py b/tests/test_dis_radixtree_basic.py
new file mode 100644
index 0000000000..af42e47005
--- /dev/null
+++ b/tests/test_dis_radixtree_basic.py
@@ -0,0 +1,591 @@
+#!/usr/bin/env python3
+"""
+FlexKV DistributedRadixTree basic functionality test
+Verifies that FlexKV compiled and installed correctly and that basic DistributedRadixTree functionality works
+"""
+
+import sys
+import os
+import torch
+import time
+import threading
+from typing import Optional, List, Dict, Any
+
+# Add the project root directory to the Python path
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+def check_environment():
+    """Check the runtime environment"""
+    print("=== Checking runtime environment ===")
+
+    # Check the Python version
+    python_version = sys.version_info
+    print(f"Python version: {python_version.major}.{python_version.minor}.{python_version.micro}")
+
+    # Check the PyTorch version
+    try:
+        torch_version = torch.__version__
+        print(f"PyTorch version: {torch_version}")
+    except Exception as e:
+        print(f"[ERROR] PyTorch version check failed: {e}")
+        return False
+
+    # Check CUDA availability
+    cuda_available = torch.cuda.is_available()
+    print(f"CUDA available: {cuda_available}")
+
+    # Check environment variables
+    ld_library_path = os.environ.get('LD_LIBRARY_PATH', '')
+    if 'flexkv' in ld_library_path.lower() or 'build/lib' in ld_library_path:
+        print("[OK] LD_LIBRARY_PATH contains the FlexKV library path")
+    else:
+        print("[WARN] LD_LIBRARY_PATH may be missing the FlexKV library path")
+
+    return True
+
+def test_imports():
+    """Test module imports"""
+    print("\n=== Testing module imports ===")
+    try:
+        import flexkv
+        print("[OK] flexkv module imported successfully")
+
+        from flexkv import c_ext
+        print("[OK] flexkv.c_ext module imported successfully")
+
+        from flexkv.cache.radix_remote import LocalRadixTree, DistributedRadixTree
+        print("[OK] LocalRadixTree and DistributedRadixTree imported successfully")
+
+        from flexkv.cache.pcfs_cache_engine import PCFSCacheEngine
+        print("[OK] PCFSCacheEngine imported successfully")
+
+        from flexkv.cache.redis_meta import RedisMeta, RedisMetaChannel
+        print("[OK] RedisMeta and RedisMetaChannel imported successfully")
+
+        return True
+    except ImportError as e:
+        print(f"[ERROR] Import failed: {e}")
+        return False
+
+def test_distributed_radix_tree():
+    """Test DistributedRadixTree functionality"""
+    print("\n=== Testing DistributedRadixTree functionality ===")
+    try:
+        from flexkv.cache.radix_remote import DistributedRadixTree
+
+        # Create a DistributedRadixTree instance
+        drt = DistributedRadixTree(
+            tokens_per_block=4,
+            max_num_blocks=1000,
+            node_id=1,
+            refresh_batch_size=32,
+            rebuild_interval_ms=1,
+            idle_sleep_ms=1,
+            lease_renew_ms=1
+        )
+        print("[OK] DistributedRadixTree created successfully")
+
+        # Test basic operations
+        test_tokens = torch.tensor([1, 2, 3, 4], dtype=torch.long)
+
+        # Test lookup (DistributedRadixTree has no insert method)
+        try:
+            match_result = drt.match_prefix(test_tokens, 0, False)
+            print(f"[OK] Lookup succeeded: {match_result}")
+        except Exception as e:
+            print(f"[WARN] Lookup failed: {e}")
+
+        # Test other methods
+        try:
+            is_empty = drt.is_empty()
+            print(f"[OK] is_empty method succeeded: {is_empty}")
+        except Exception as e:
+            print(f"[WARN] is_empty method failed: {e}")
+
+        return True
+    except Exception as e:
+        print(f"[ERROR] DistributedRadixTree test failed: {e}")
+        return False
+
+def test_local_radix_tree():
+    """Test LocalRadixTree functionality"""
+    print("\n=== Testing LocalRadixTree functionality ===")
+    try:
+        from flexkv.cache.radix_remote import LocalRadixTree
+
+        # Create a LocalRadixTree instance
+        lrt = LocalRadixTree(
+            tokens_per_block=4,
+            max_num_blocks=1000,
+            lease_ttl_ms=10000,
+            renew_lease_ms=2,
+            refresh_batch_size=64,
+            idle_sleep_ms=1
+        )
+        print("[OK] LocalRadixTree created successfully")
+
+        # Test basic operations
+        test_tokens = torch.tensor([5, 6, 7, 8], dtype=torch.long)
+
+        # Test lookup
+        try:
+            lrt.insert(test_tokens, test_tokens, 4, 4, True, None, -1, -1)
+            match_result = lrt.match_prefix(test_tokens, 0, False)
+            if match_result is None:
+                print(f"[WARN] Lookup failed: {match_result}")
+                return False
+            elif match_result.num_matched_blocks == 0:
+                print(f"[WARN] Lookup failed: {match_result}")
+                return False
+            print(f"[OK] Lookup succeeded: {match_result}")
+        except Exception as e:
+            print(f"[WARN] Lookup failed: {e}")
+            return False
+
+        # Test other methods
+        try:
+            is_empty = lrt.is_empty()
+            if is_empty:
+                print(f"[WARN] is_empty returned an unexpected value: {is_empty}")
+                return False
+            total_nodes = lrt.total_node_num()
+            if total_nodes == 0:
+                print(f"[WARN] total_node_num returned an unexpected value: {total_nodes}")
+                return False
+            print(f"[OK] is_empty method succeeded: {is_empty}")
+            print(f"[OK] total_node_num method succeeded: {total_nodes}")
+        except Exception as e:
+            print(f"[WARN] Other methods failed: {e}")
+            return False
+
+        return True
+    except Exception as e:
+        print(f"[ERROR] LocalRadixTree test failed: {e}")
+        return False
+
+def test_pcfs_cache_engine():
+    """Test PCFSCacheEngine functionality"""
+    print("\n=== Testing PCFSCacheEngine functionality ===")
+    try:
+        from flexkv.cache.pcfs_cache_engine import PCFSCacheEngine
+
+        # Create a PCFSCacheEngine instance
+        cache_engine = PCFSCacheEngine(
+            num_total_blocks=1000,
+            tokens_per_block=4,
+            evict_ratio=0.1
+        )
+        print("[OK] PCFSCacheEngine created successfully")
+
+        # Test basic attributes
+        print(f"  - num_total_blocks: {cache_engine.num_total_blocks}")
+        print(f"  - tokens_per_block: {cache_engine.tokens_per_block}")
+        print(f"  - evict_ratio: {cache_engine.evict_ratio}")
+
+        return True
+    except Exception as e:
+        print(f"[ERROR] PCFSCacheEngine test failed: {e}")
+        return False
+
+def test_distributed_radix_tree_integration():
+    """Test distributed RadixTree integration"""
+    print("\n=== Testing distributed RadixTree integration ===")
+    try:
+        from flexkv.cache.radix_remote import LocalRadixTree, DistributedRadixTree
+        from flexkv.cache.redis_meta import RedisMeta
+
+        # Step 1: create two RedisMeta instances
+        print("Step 1: creating RedisMeta instances...")
+        redis_meta1 = RedisMeta(host="127.0.0.1", port=6379, local_ip="127.0.0.1")
+        redis_meta2 = RedisMeta(host="127.0.0.1", port=6379, local_ip="127.0.0.2")
+        print("[OK] RedisMeta instances created - Meta1(127.0.0.1), Meta2(127.0.0.2)")
+
+        # Step 2: initialize RedisMeta
+        print("Step 2: initializing RedisMeta...")
+        node_id1 = redis_meta1.init_meta()
+        if node_id1 is None:
+            raise RuntimeError("RedisMeta1 initialization failed, could not obtain node_id")
+
+        node_id2 = redis_meta2.init_meta()
+        if node_id2 is None:
+            raise RuntimeError("RedisMeta2 initialization failed, could not obtain node_id")
+
+        print(f"[OK] RedisMeta initialized - Node1: {node_id1}, Node2: {node_id2}")
+
+        # Step 3: create 2 LocalRadixTree and 2 DistributedRadixTree instances
+        print("Step 3: creating RadixTree instances...")
+
+        # Create the LocalRadixTree instances
+        local_tree1 = LocalRadixTree(
+            tokens_per_block=4,
+            max_num_blocks=1000,
+            lease_ttl_ms=10000,
+            renew_lease_ms=2,
+            refresh_batch_size=64,
+            idle_sleep_ms=1
+        )
+
+        local_tree2 = LocalRadixTree(
+            tokens_per_block=4,
+            max_num_blocks=1000,
+            lease_ttl_ms=10000,
+            renew_lease_ms=2,
+            refresh_batch_size=64,
+            idle_sleep_ms=1
+        )
+
+        # Create the DistributedRadixTree instances
+        distributed_tree1 = DistributedRadixTree(
+            tokens_per_block=4,
+            max_num_blocks=1000,
+            node_id=node_id1,
+            refresh_batch_size=32,
+            rebuild_interval_ms=1,
+            idle_sleep_ms=1,
+            lease_renew_ms=2
+        )
+
+        distributed_tree2 = DistributedRadixTree(
+            tokens_per_block=4,
+            max_num_blocks=1000,
+            node_id=node_id2,
+            refresh_batch_size=32,
+            rebuild_interval_ms=1,
+            idle_sleep_ms=1,
+            lease_renew_ms=2
+        )
+        print("[OK] RadixTree instances created - 2 LocalRadixTree, 2 DistributedRadixTree")
+
+        # Step 4: get the RedisMetaChannels and start all RadixTrees together
+        print("Step 4: starting all RadixTrees...")
+        channel1 = redis_meta1.get_redis_meta_channel()
+        if not channel1:
+            raise RuntimeError("RedisMeta1 failed to get a RedisMetaChannel")
+        channel2 = redis_meta2.get_redis_meta_channel()
+        if not channel2:
+            raise RuntimeError("RedisMeta2 failed to get a RedisMetaChannel")
+
+        # Start all RadixTrees together
+        if not local_tree1.start(channel1):
+            raise RuntimeError("LocalRadixTree1 failed to start")
+        if not local_tree2.start(channel2):
+            raise RuntimeError("LocalRadixTree2 failed to start")
+        if not distributed_tree1.start(channel1):
+            raise RuntimeError("DistributedRadixTree1 failed to start")
+        if not distributed_tree2.start(channel2):
+            raise RuntimeError("DistributedRadixTree2 failed to start")
+        print("[OK] All RadixTrees started successfully")
+
+        # Step 5: create test data - 4 blocks each
+        print("Step 5: creating test data...")
+
+        # Test data for LocalRadixTree1 - 4 blocks
+        physical_blocks1 = torch.tensor([1001, 1002, 1003, 1004], dtype=torch.long)
+        block_hashes1 = torch.tensor([2001, 2002, 2003, 2004], dtype=torch.long)
+
+        # Test data for LocalRadixTree2 - 4 blocks
+        physical_blocks2 = torch.tensor([2001, 2002, 2003, 2004], dtype=torch.long)
+        block_hashes2 = torch.tensor([3001, 3002, 3003, 3004], dtype=torch.long)
+
+        print("[OK] Test data created - 4 blocks each")
+        print(f"  - LocalRadixTree1: physical_blocks={physical_blocks1.tolist()}, hashes={block_hashes1.tolist()}")
+        print(f"  - LocalRadixTree2: physical_blocks={physical_blocks2.tolist()}, hashes={block_hashes2.tolist()}")
+
+        # Step 6: add nodes to the LocalRadixTrees
+        print("Step 6: adding nodes to the LocalRadixTrees...")
+
+        # LocalRadixTree1 uses the insert method
+        print("  - LocalRadixTree1 using the insert method...")
+        node1 = local_tree1.insert(
+            physical_blocks1, block_hashes1, 4, 4, True, None, -1, -1
+        )
+        if node1 is not None:
+            print("    [OK] LocalRadixTree1 insert succeeded, the inserted node contains 4 blocks")
+            local_tree1.insert_and_publish(node1)
+        else:
+            print("    [WARN] LocalRadixTree1 insert returned None")
+
+        # LocalRadixTree2 uses the insert_and_publish method
+        print("  - LocalRadixTree2 using the insert_and_publish method...")
+        node2 = local_tree2.insert(
+            physical_blocks2, block_hashes2, 4, 4, True, None, -1, -1
+        )
+        if node2 is not None:
+            local_tree2.insert_and_publish(node2)
+            print("    [OK] LocalRadixTree2 insert_and_publish succeeded, the inserted and published node contains 4 blocks")
+        else:
+            print("    [WARN] LocalRadixTree2 insert returned None")
+
+        # Step 7: wait for data synchronization
+        print("Step 7: waiting for data synchronization...")
+        time.sleep(3)  # Extra wait time to make sure the data is synchronized
+        print("[OK] Data synchronization wait finished")
+
+        # Step 8: load the Redis data via the DistributedRadixTrees
+        print("Step 8: loading Redis data via the DistributedRadixTrees...")
+
+        # Refresh DistributedRadixTree1
+        print("  - DistributedRadixTree1 running remote_tree_refresh...")
+        refresh_result1 = distributed_tree1.remote_tree_refresh()
+        if refresh_result1 is None:
+            raise RuntimeError("DistributedRadixTree1 remote_tree_refresh failed")
+        print("    [OK] DistributedRadixTree1 remote_tree_refresh finished")
+
+        # Refresh DistributedRadixTree2
+        print("  - DistributedRadixTree2 running remote_tree_refresh...")
+        refresh_result2 = distributed_tree2.remote_tree_refresh()
+        if refresh_result2 is None:
+            raise RuntimeError("DistributedRadixTree2 remote_tree_refresh failed")
+        print("    [OK] DistributedRadixTree2 remote_tree_refresh finished")
+
+        # Step 9: verify the results in detail
+        print("Step 9: verifying results...")
+
+        # Verify LocalRadixTree state
+        print("LocalRadixTree state:")
+        lrt1_nodes = local_tree1.total_node_num()
+        if lrt1_nodes == 0:
+            raise RuntimeError("LocalRadixTree1 total_node_num check failed")
+        lrt1_cached = local_tree1.total_cached_blocks()
+        if lrt1_cached == 0:
+            raise RuntimeError("LocalRadixTree1 total_cached_blocks check failed")
+        lrt1_ready = local_tree1.total_ready_blocks()
+        if lrt1_ready == 0:
+            raise RuntimeError("LocalRadixTree1 total_ready_blocks check failed")
+        lrt1_unready = local_tree1.total_unready_blocks()
+        if lrt1_unready == 0:
+            raise RuntimeError("LocalRadixTree1 total_unready_blocks check failed")
+
+        lrt2_nodes = local_tree2.total_node_num()
+        if lrt2_nodes == 0:
+            raise RuntimeError("LocalRadixTree2 total_node_num check failed")
+        lrt2_cached = local_tree2.total_cached_blocks()
+        if lrt2_cached == 0:
+            raise RuntimeError("LocalRadixTree2 total_cached_blocks check failed")
+        lrt2_ready = local_tree2.total_ready_blocks()
+        if lrt2_ready == 0:
+            raise RuntimeError("LocalRadixTree2 total_ready_blocks check failed")
+        lrt2_unready = local_tree2.total_unready_blocks()
+        if lrt2_unready == 0:
+            raise RuntimeError("LocalRadixTree2 total_unready_blocks check failed")
+
+        print(f"  - LocalRadixTree1: nodes={lrt1_nodes}, cached blocks={lrt1_cached}, ready blocks={lrt1_ready}, unready blocks={lrt1_unready}")
+        print(f"  - LocalRadixTree2: nodes={lrt2_nodes}, cached blocks={lrt2_cached}, ready blocks={lrt2_ready}, unready blocks={lrt2_unready}")
+
+        # Verify DistributedRadixTree state
+        print("DistributedRadixTree state:")
+        drt1_empty = distributed_tree1.is_empty()
+        if drt1_empty:
+            raise RuntimeError("DistributedRadixTree1 is_empty check failed")
+        drt2_empty = distributed_tree2.is_empty()
+        if drt2_empty:
+            raise RuntimeError("DistributedRadixTree2 is_empty check failed")
+        print(f"  - DistributedRadixTree1: empty={drt1_empty}")
+        print(f"  - DistributedRadixTree2: empty={drt2_empty}")
+
+        # Test prefix matching
+        print("Testing prefix matching...")
+        test_hashes1 = torch.tensor([2001, 2002, 2003, 2004], dtype=torch.long)
+        test_hashes2 = torch.tensor([3001, 3002, 3003, 3004], dtype=torch.long)
+
+        # Test matching in the LocalRadixTrees
+        match_result1 = local_tree1.match_prefix(test_hashes1, 4, True)
+        if match_result1 is None:
+            raise RuntimeError("LocalRadixTree1 match_prefix failed")
+        if match_result1.num_matched_blocks == 0:
+            raise RuntimeError("LocalRadixTree1 match_prefix failed")
+        match_result2 = local_tree2.match_prefix(test_hashes2, 4, True)
+        if match_result2 is None:
+            raise RuntimeError("LocalRadixTree2 match_prefix failed")
+        if match_result2.num_matched_blocks == 0:
+            raise RuntimeError("LocalRadixTree2 match_prefix failed")
+        print(f"  - LocalRadixTree1 match result: matched blocks={match_result1.num_matched_blocks if match_result1 else 0}")
+        print(f"  - LocalRadixTree2 match result: matched blocks={match_result2.num_matched_blocks if match_result2 else 0}")
+
+        # Test matching in the DistributedRadixTrees
+        drt_match1 = distributed_tree1.match_prefix(test_hashes2, 4, True)
+        if drt_match1 is None:
+            raise RuntimeError("DistributedRadixTree1 match_prefix failed")
+        if drt_match1.num_matched_blocks == 0:
+            raise RuntimeError("DistributedRadixTree1 match_prefix failed")
+        drt_match2 = distributed_tree2.match_prefix(test_hashes1, 4, True)
+        if drt_match2 is None:
+            raise RuntimeError("DistributedRadixTree2 match_prefix failed")
+        if drt_match2.num_matched_blocks == 0:
+            raise RuntimeError("DistributedRadixTree2 match_prefix failed")
+        print(f"  - DistributedRadixTree1 match result: matched blocks={drt_match1.num_matched_blocks if drt_match1 else 0}")
+        print(f"  - DistributedRadixTree2 match result: matched blocks={drt_match2.num_matched_blocks if drt_match2 else 0}")
+
+        # Step 10: performance test
+        print("Step 10: performance test...")
+
+        # Reuse the existing local_tree1 and distributed_tree1 for the performance test
+        num_operations = 100
+        print(f"  - Starting performance test with {num_operations} operations...")
+
+        # LocalRadixTree1 performance test
+        print("  - LocalRadixTree1 performance test...")
+        lrt_start_time = time.time()
+
+        for i in range(num_operations):
+            test_physical = torch.tensor([i % 100, (i + 1) % 100, (i + 2) % 100, (i + 3) % 100], dtype=torch.long)
+            test_hashes = torch.tensor([(i + 1000) % 2000, (i + 1001) % 2000, (i + 1002) % 2000, (i + 1003) % 2000], dtype=torch.long)
+            try:
+                # Measure insert performance
+                lrt_node = local_tree1.insert(test_physical, test_hashes, 4, 4, True, None, -1, -1)
+                # Measure match_prefix performance
+                lrt_match = local_tree1.match_prefix(test_hashes, 4, True)
+            except Exception as e:
+                print(f"    [WARN] Operation failed during LocalRadixTree1 performance test: {e}")
+                break
+
+        lrt_end_time = time.time()
+        lrt_duration = lrt_end_time - lrt_start_time
+        lrt_ops_per_sec = num_operations / lrt_duration if lrt_duration > 0 else 0
+
+        print("    [OK] LocalRadixTree1 performance test finished:")
+        print(f"      - Operations: {num_operations}")
+        print(f"      - Total time: {lrt_duration:.3f} s")
+        print(f"      - Operations per second: {lrt_ops_per_sec:.0f}")
+
+        # DistributedRadixTree1 performance test
+        print("  - DistributedRadixTree1 performance test...")
+        drt_start_time = time.time()
+
+        for i in range(num_operations):
+            test_hashes = torch.tensor([(i + 2000) % 3000, (i + 2001) % 3000, (i + 2002) % 3000, (i + 2003) % 3000], dtype=torch.long)
+            try:
+                # Measure match_prefix performance
+                drt_match = distributed_tree1.match_prefix(test_hashes, 4, True)
+            except Exception as e:
+                print(f"    [WARN] Operation failed during DistributedRadixTree1 performance test: {e}")
+                break
+
+        drt_end_time = time.time()
+        drt_duration = drt_end_time - drt_start_time
+        drt_ops_per_sec = num_operations / drt_duration if drt_duration > 0 else 0
+
+        print("    [OK] DistributedRadixTree1 performance test finished:")
+        print(f"      - Operations: {num_operations}")
+        print(f"      - Total time: {drt_duration:.3f} s")
+        print(f"      - Operations per second: {drt_ops_per_sec:.0f}")
+
+        # Performance comparison
+        print("  - Performance comparison:")
+        if lrt_ops_per_sec > 0 and drt_ops_per_sec > 0:
+            ratio = lrt_ops_per_sec / drt_ops_per_sec
+            print(f"    LocalRadixTree1 vs DistributedRadixTree1: {ratio:.2f}x")
+
+        print("[OK] Performance test finished")
+
+        # Step 11: clean up resources
+        print("Step 11: cleaning up resources...")
+        local_tree1.stop()
+        local_tree2.stop()
+        distributed_tree1.stop()
+        distributed_tree2.stop()
+        redis_meta1.unregister_node()
+        redis_meta2.unregister_node()
+        print("[OK] Resource cleanup finished")
+
+        # Verify the test results
+        success = True
+        if lrt1_cached == 0 and lrt2_cached == 0:
+            print("[WARN] The LocalRadixTrees cached no blocks, the inserts may have failed")
+            success = False
+
+        if drt1_empty and drt2_empty:
+            print("[WARN] The DistributedRadixTrees are empty, remote_tree_refresh may have failed")
+            success = False
+
+        if success:
+            print("\n[SUCCESS] Distributed RadixTree integration test finished, all features work!")
+            print(f"Performance results: LocalRadixTree1({lrt_ops_per_sec:.0f} ops/s), DistributedRadixTree1({drt_ops_per_sec:.0f} ops/s)")
+        else:
+            print("\n[WARN] Distributed RadixTree integration test finished, but with warnings")
+
+        return success
+
+    except Exception as e:
+        print(f"[ERROR] Distributed RadixTree integration test failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+def test_cuda_skipping():
+    """Test CUDA skipping"""
+    print("\n=== Testing CUDA skipping ===")
+    try:
+        cuda_available = torch.cuda.is_available()
+        print(f"CUDA available: {cuda_available}")
+
+        if cuda_available:
+            print("[WARN] CUDA is available, but CUDA-related tests are skipped")
+        else:
+            print("[OK] CUDA is unavailable, CUDA-related tests are skipped automatically")
+
+        # Test CPU tensor operations
+        cpu_tensor = torch.tensor([1, 2, 3, 4], dtype=torch.long)
+        print(f"[OK] CPU tensor created successfully: {cpu_tensor}")
+
+        return True
+    except Exception as e:
+        print(f"[ERROR] CUDA skipping test failed: {e}")
+        return False
+
+def main():
+    """Main test function"""
+    print("FlexKV DistributedRadixTree basic functionality test")
+    print("=" * 60)
+
+    # Record the test results
+    test_results = []
+
+    # Run all tests
+    tests = [
+        ("Environment check", check_environment),
+        ("Module imports", test_imports),
+        ("LocalRadixTree", test_local_radix_tree),
+        ("DistributedRadixTree", test_distributed_radix_tree),
+        ("PCFSCacheEngine", test_pcfs_cache_engine),
+        ("Distributed RadixTree integration", test_distributed_radix_tree_integration),
+        ("CUDA skipping", test_cuda_skipping),
+    ]
+
+    for test_name, test_func in tests:
+        try:
+            result = test_func()
+            test_results.append((test_name, result))
+        except Exception as e:
+            print(f"[ERROR] {test_name} test raised an exception: {e}")
+            test_results.append((test_name, False))
+
+    # Print the test summary
+    print("\n" + "=" * 60)
+    print("Test summary:")
+    print("=" * 60)
+
+    passed = 0
+    failed = 0
+
+    for test_name, result in test_results:
+        status = "[OK] passed" if result else "[ERROR] failed"
+        print(f"{test_name}: {status}")
+        if result:
+            passed += 1
+        else:
+            failed += 1
+
+    print(f"\nTotal: {passed} tests passed, {failed} tests failed")
+
+    if failed == 0:
+        print("\n[SUCCESS] All FlexKV DistributedRadixTree tests passed!")
+        print("All basic features work; DistributedRadixTree is ready to use.")
+        return 0
+    else:
+        print(f"\n[WARN] {failed} tests failed, please check the errors above")
+        return 1
+
+if __name__ == "__main__":
+    exit_code = main()
+    sys.exit(exit_code)
diff --git a/tests/test_redis_meta.py b/tests/test_redis_meta.py
new file mode 100644
index 0000000000..a72fce8e80
--- /dev/null
+++ b/tests/test_redis_meta.py
@@ -0,0 +1,397 @@
+#!/usr/bin/env python3
+"""
+FlexKV RedisMeta test program
+Tests RedisMetaChannel and the code in redis_meta.py
+"""
+
+import sys
+import os
+import time
+import threading
+from typing import Optional, List, Dict, Any
+
+# Add the project root directory to the Python path
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+def check_environment():
+    """Check the runtime environment"""
+    print("=== Checking runtime environment ===")
+
+    # Check the Python version
+    python_version = sys.version_info
+    print(f"Python version: {python_version.major}.{python_version.minor}.{python_version.micro}")
+
+    # Check environment variables
+    ld_library_path = os.environ.get('LD_LIBRARY_PATH', '')
+    if 'flexkv' in ld_library_path.lower() or 'build/lib' in ld_library_path:
+        print("[OK] LD_LIBRARY_PATH contains the FlexKV library path")
+    else:
+        print("[WARN] LD_LIBRARY_PATH may be missing the FlexKV library path")
+
+    return True
+
+def test_imports():
+    """Test module imports"""
+    print("\n=== Testing module imports ===")
+    try:
+        import flexkv
+        print("[OK] flexkv module imported successfully")
+
+        # Try importing c_ext; a failure here does not affect the other tests
+        try:
+            from flexkv import c_ext
+            print("[OK] flexkv.c_ext module imported successfully")
+        except Exception as e:
+            print(f"[WARN] flexkv.c_ext module import failed: {e}")
+
+        from flexkv.cache.redis_meta import RedisMeta, RedisMetaChannel, BlockMeta, NodeState
+        print("[OK] RedisMeta, RedisMetaChannel, BlockMeta, NodeState imported successfully")
+
+        # Check whether the C++ extensions are available
+        from flexkv.cache.redis_meta import _CRedisMetaChannel, _CBlockMeta
+        if _CRedisMetaChannel is not None:
+            print("[OK] C++ RedisMetaChannel extension available")
+        else:
+            print("[WARN] C++ RedisMetaChannel extension unavailable")
+
+        if _CBlockMeta is not None:
+            print("[OK] C++ BlockMeta extension available")
+        else:
+            print("[WARN] C++ BlockMeta extension unavailable")
+
+        return True
+    except ImportError as e:
+        print(f"[ERROR] Import failed: {e}")
+        return False
+
+def test_node_state():
+    """Test the NodeState enum"""
+    print("\n=== Testing the NodeState enum ===")
+    try:
+        from flexkv.cache.redis_meta import NodeState
+
+        # Test the enum values
+        print(f"[OK] NODE_STATE_NORMAL: {NodeState.NODE_STATE_NORMAL}")
+        print(f"[OK] NODE_STATE_ABOUT_TO_EVICT: {NodeState.NODE_STATE_ABOUT_TO_EVICT}")
+        print(f"[OK] NODE_STATE_EVICTED: {NodeState.NODE_STATE_EVICTED}")
+
+        # Test enum conversion
+        normal_state = NodeState(0)
+        evict_state = NodeState(1)
+        evicted_state = NodeState(2)
+
+        print(f"[OK] Enum conversion test passed: {normal_state}, {evict_state}, {evicted_state}")
+
+        return True
+    except Exception as e:
+        print(f"[ERROR] NodeState test failed: {e}")
+        return False
+
+def test_block_meta():
+    """Test the BlockMeta class"""
+    print("\n=== Testing the BlockMeta class ===")
+    try:
+        from flexkv.cache.redis_meta import BlockMeta, NodeState
+
+        # Create a BlockMeta instance
+        meta = BlockMeta(
+            ph=12345,
+            pb=67890,
+            nid=1,
+            hash=987654321,
+            lt=1000000,
+            state=NodeState.NODE_STATE_NORMAL
+        )
+        print("[OK] BlockMeta created successfully")
+
+        # Test the attributes
+        print(f"  - ph: {meta.ph}")
+        print(f"  - pb: {meta.pb}")
+        print(f"  - nid: {meta.nid}")
+        print(f"  - hash: {meta.hash}")
+        print(f"  - lt: {meta.lt}")
+        print(f"  - state: {meta.state}")
+
+        # Test the default values
+        default_meta = BlockMeta()
+        print("[OK] BlockMeta created with defaults")
+        print(f"  - default ph: {default_meta.ph}")
+        print(f"  - default state: {default_meta.state}")
+
+        # Test C++ conversion (if the C++ extension is available)
+        try:
+            from flexkv.cache.redis_meta import _CBlockMeta
+            if _CBlockMeta is not None:
+                c_meta = meta.to_c()
+                print("[OK] BlockMeta.to_c() conversion succeeded")
+
+                restored_meta = BlockMeta.from_c(c_meta)
+                print("[OK] BlockMeta.from_c() conversion succeeded")
+
+                # Verify conversion correctness
+                if (restored_meta.ph == meta.ph and
+                    restored_meta.pb == meta.pb and
+                    restored_meta.nid == meta.nid and
+                    restored_meta.hash == meta.hash and
+                    restored_meta.lt == meta.lt and
+                    restored_meta.state == meta.state):
+                    print("[OK] C++ conversion verified")
+                else:
+                    print("[ERROR] C++ conversion verification failed")
+                    return False
+            else:
+                print("[WARN] C++ BlockMeta extension unavailable, skipping conversion test")
+
+        except Exception as e:
+            print(f"[WARN] C++ conversion test failed: {e}")
+
+        return True
+    except Exception as e:
+        print(f"[ERROR] BlockMeta test failed: {e}")
+        return False
+
+def test_redis_meta_channel():
+    """Test the RedisMetaChannel class"""
+    print("\n=== Testing the RedisMetaChannel class ===")
+    try:
+        from flexkv.cache.redis_meta import RedisMetaChannel, BlockMeta, NodeState, _CRedisMetaChannel
+
+        # Check whether the C++ extension is available
+        if _CRedisMetaChannel is None:
+            print("[WARN] C++ RedisMetaChannel extension unavailable, skipping RedisMetaChannel test")
+            return True
+
+        # Create a RedisMetaChannel instance
+        channel = RedisMetaChannel(
+            host="127.0.0.1",
+            port=6379,
+            node_id=1,
+            local_ip="127.0.0.1",
+            blocks_key="flexkv_test_blocks"
+        )
+        print("[OK] RedisMetaChannel created successfully")
+
+        # Test the attributes
+        try:
+            node_id = channel.node_id
+            local_ip = channel.local_ip
+            print(f"[OK] node_id: {node_id}")
+            print(f"[OK] local_ip: {local_ip}")
+        except Exception as e:
+            print(f"[WARN] Attribute access failed: {e}")
+
+        # Test the connection
+        try:
+            connected = channel.connect()
+            if connected:
+                print("[OK] RedisMetaChannel connected")
+            else:
+                print("[WARN] RedisMetaChannel connection failed")
+                return False
+        except Exception as e:
+            print(f"[ERROR] RedisMetaChannel connection raised an exception: {e}")
+            return False
+
+        # Test the make_block_key method
+        try:
+            key = channel.make_block_key(1, 12345)
+            print(f"[OK] make_block_key succeeded: {key}")
+        except Exception as e:
+            print(f"[WARN] make_block_key failed: {e}")
+
+        # Test the publish_one method (requires a connection)
+        try:
+            meta = BlockMeta(ph=1, pb=2, nid=1, hash=12345, lt=1000000, state=NodeState.NODE_STATE_NORMAL)
+            channel.publish_one(meta)
+            print("[OK] publish_one succeeded")
+        except Exception as e:
+            print(f"[WARN] publish_one failed: {e}")
+
+        # Test the publish_batch method
+        try:
+            metas = [
+                BlockMeta(ph=1, pb=2, nid=1, hash=12345, lt=1000000, state=NodeState.NODE_STATE_NORMAL),
+                BlockMeta(ph=2, pb=3, nid=1, hash=12346, lt=1000001, state=NodeState.NODE_STATE_NORMAL)
+            ]
+            channel.publish_batch(metas, batch_size=10)
+            print("[OK] publish_batch succeeded")
+        except Exception as e:
+            print(f"[WARN] publish_batch failed: {e}")
+
+        # Test the list_keys method
+        try:
+            keys = channel.list_keys("*")
+            print(f"[OK] list_keys succeeded, found {len(keys)} keys")
+        except Exception as e:
+            print(f"[WARN] list_keys failed: {e}")
+
+        # Test the hmget_field_for_keys method
+        try:
+            # Create some test data first
+            test_keys = []
+            for i in range(5):
+                meta = BlockMeta(ph=1, pb=2, nid=1, hash=50000+i, lt=1000000+i, state=NodeState.NODE_STATE_NORMAL)
+                channel.publish_one(meta)
+                test_keys.append(channel.make_block_key(1, 50000+i))
+
+            # Test fetching a single field
+            values = channel.hmget_field_for_keys(test_keys, "ph")
+            print(f"[OK] hmget_field_for_keys succeeded, fetched {len(values)} field values")
+            print(f"    Sample field values: {values[:3]}")
+        except Exception as e:
+            print(f"[WARN] hmget_field_for_keys failed: {e}")
+
+        # Test the hmget_two_fields_for_keys method
+        try:
+            # Test fetching two fields
+            field_pairs = channel.hmget_two_fields_for_keys(test_keys, "ph", "pb")
+            print(f"[OK] hmget_two_fields_for_keys succeeded, fetched {len(field_pairs)} field pairs")
+            print(f"    Sample field pairs: {field_pairs[:2]}")
+        except Exception as e:
+            print(f"[WARN] hmget_two_fields_for_keys failed: {e}")
+
+        # Test the renew_node_leases method
+        try:
+            result = channel.renew_node_leases(1, 2000000, batch_size=10)
+            print(f"[OK] renew_node_leases succeeded, result: {result}")
+        except Exception as e:
+            print(f"[WARN] renew_node_leases failed: {e}")
+
+        # Test the update_block_state_batch method
+        try:
+            # Prepare test hash values
+            test_hashes = [50000 + i for i in range(5)]
+            result = channel.update_block_state_batch(1, test_hashes, NodeState.NODE_STATE_ABOUT_TO_EVICT, batch_size=10)
+            print(f"[OK] update_block_state_batch succeeded, result: {result}")
+
+            # Verify the states were updated
+            values = channel.hmget_field_for_keys(test_keys, "state")
+            print(f"    Updated state values: {values}")
+        except Exception as e:
+            print(f"[WARN] update_block_state_batch failed: {e}")
+
+        # Test the delete_blockmeta_batch method
+        try:
+            # Delete the test data created earlier
+            result = channel.delete_blockmeta_batch(1, test_hashes, batch_size=10)
+            print(f"[OK] delete_blockmeta_batch succeeded, result: {result}")
+
+            # Verify the data was deleted
+            remaining_keys = channel.list_keys("flexkv_test_blocks:block:1:5000*")
+            print(f"    Keys remaining after deletion: {len(remaining_keys)}")
+        except Exception as e:
+            print(f"[WARN] delete_blockmeta_batch failed: {e}")
+
+        return True
+    except Exception as e:
+        print(f"[ERROR] RedisMetaChannel test failed: {e}")
+        return False
+
+def test_redis_meta():
+    """Test the RedisMeta class"""
+    print("\n=== Testing the RedisMeta class ===")
+    try:
+        from flexkv.cache.redis_meta import RedisMeta
+
+        # Create a RedisMeta instance
+        redis_meta = RedisMeta(
+            host="127.0.0.1",
+            port=6379,
+            password=None,
+            local_ip="127.0.0.1",
+            decode_responses=True
+        )
+        print("[OK] RedisMeta created successfully")
+
+        # Test the attributes
+        print(f"  - host: {redis_meta.host}")
+        print(f"  - port: {redis_meta.port}")
+        print(f"  - local_ip: {redis_meta.local_ip}")
+        print(f"  - decode_responses: {redis_meta.decode_responses}")
+
+        # Test the UUID
+        uuid = redis_meta.get_uuid()
+        print(f"[OK] UUID generated: {uuid}")
+
+        # Test init_meta
+        try:
+            node_id = redis_meta.init_meta()
+            print(f"[OK] init_meta succeeded, node_id: {node_id}")
+
+            # Test get_node_id
+            retrieved_node_id = redis_meta.get_node_id()
+            print(f"[OK] get_node_id succeeded: {retrieved_node_id}")
+
+            # Test get_redis_meta_channel
+            channel = redis_meta.get_redis_meta_channel("flexkv_test_blocks")
+            print("[OK] get_redis_meta_channel succeeded")
+
+            # Test unregister_node
+            redis_meta.unregister_node()
+            print("[OK] unregister_node succeeded")
+
+        except Exception as e:
+            print(f"[ERROR] Redis operation failed: {e}")
+            print("[INFO] Please make sure a Redis server is running and reachable")
+            return False
+
+        return True
+    except Exception as e:
+        print(f"[ERROR] RedisMeta test failed: {e}")
+        return False
+
+def main():
+    """Main test function"""
+    print("FlexKV RedisMeta test program")
+    print("=" * 50)
+
+    # Record the test results
+    test_results = []
+
+    # Run all tests
+    tests = [
+        ("Environment check", check_environment),
+        ("Module imports", test_imports),
+        ("NodeState enum", test_node_state),
+        ("BlockMeta class", test_block_meta),
+        ("RedisMetaChannel class", test_redis_meta_channel),
+        ("RedisMeta class", test_redis_meta),
+    ]
+
+    for test_name, test_func in tests:
+        try:
+            result = test_func()
+            test_results.append((test_name, result))
+        except Exception as e:
+            print(f"[ERROR] {test_name} test raised an exception: {e}")
+            test_results.append((test_name, False))
+
+    # Print the test summary
+    print("\n" + "=" * 50)
+    print("Test summary:")
+    print("=" * 50)
+
+    passed = 0
+    failed = 0
+
+    for test_name, result in test_results:
+        status = "[OK] passed" if result else "[ERROR] failed"
+        print(f"{test_name}: {status}")
+        if result:
+            passed += 1
+        else:
+            failed += 1
+
+    print(f"\nTotal: {passed} tests passed, {failed} tests failed")
+
+    if failed == 0:
+        print("\n[SUCCESS] All FlexKV RedisMeta tests passed!")
+        print("All RedisMeta-related features work.")
+        return 0
+    else:
+        print(f"\n[WARN] {failed} tests failed, please check the errors above")
+        return 1
+
+if __name__ == "__main__":
+    exit_code = main()
+    sys.exit(exit_code)
diff --git a/tests/test_redis_node_info.py b/tests/test_redis_node_info.py
new file mode 100644
index 0000000000..eb2ea1676f
--- /dev/null
+++ b/tests/test_redis_node_info.py
@@ -0,0 +1,213 @@
+#!/usr/bin/env python3
+"""
+Test the functionality of the RedisNodeInfo class
+"""
+
+import sys
+import time
+import threading
+sys.path.insert(0, '..')
+
+from flexkv.cache.redis_meta import RedisNodeInfo
+
+def test_redis_node_info():
+    """Test the basic functionality of RedisNodeInfo"""
+    print("=== Testing RedisNodeInfo basic functionality ===")
+
+    try:
+        # Create a RedisNodeInfo instance
+        node_info = RedisNodeInfo(
+            host='127.0.0.1',
+            port=6379,
+            local_ip='127.0.0.1',
+            password=''
+        )
+        print("[OK] RedisNodeInfo created successfully")
+
+        # Connect to Redis
+        if not node_info.connect():
+            print("[ERROR] Could not connect to Redis")
+            return False
+        print("[OK] RedisNodeInfo connected")
+
+        # Register the node
+        node_id = node_info.register_node()
+        if node_id == 0xFFFFFFFF:  # UINT32_MAX
+            print("[ERROR] Node registration failed")
+            return False
+        print(f"[OK] Node registered, node_id: {node_id}")
+
+        # Get the current node ID
+        current_node_id = node_info.node_id
+        print(f"[OK] Current node ID: {current_node_id}")
+
+        # Get the list of active nodes
+        active_nodes = node_info.get_active_node_ids()
+        print(f"[OK] Active node list: {active_nodes}")
+
+        # Check whether the node is active
+        is_active = node_info.is_node_active(node_id)
+        print(f"[OK] Node {node_id} active: {is_active}")
+
+        # Wait a while for the listener thread to work
+        print("[INFO] Waiting 3 seconds for the listener thread...")
+        time.sleep(3)
+
+        # Get the list of active nodes again
+        active_nodes_after = node_info.get_active_node_ids()
+        print(f"[OK] Active node list after 3 s: {active_nodes_after}")
+
+        # Unregister the node
+        unregister_result = node_info.unregister_node()
+        if not unregister_result:
+            print("[ERROR] Node unregistration failed")
+            return False
+        print("[OK] Node unregistered")
+
+        # Disconnect
+        node_info.disconnect()
+        print("[OK] RedisNodeInfo disconnected")
+
+        return True
+
+    except Exception as e:
+        print(f"[ERROR] RedisNodeInfo test failed: {e}")
+        return False
+
+def test_multiple_nodes():
+    """Test registering and unregistering multiple nodes"""
+    print("\n=== Testing multiple node registration ===")
+
+    try:
+        # Create multiple nodes
+        nodes = []
+        for i in range(3):
+            node_info = RedisNodeInfo(
+                host='127.0.0.1',
+                port=6379,
+                local_ip=f'127.0.0.{i+1}',
+                password=''
+            )
+            if node_info.connect():
+                node_id = node_info.register_node()
+                if node_id != 0xFFFFFFFF:
+                    nodes.append(node_info)
+                    print(f"[OK] Node {i+1} registered, node_id: {node_id}, IP: 127.0.0.{i+1}")
+                else:
+                    print(f"[ERROR] Node {i+1} registration failed")
+            else:
+                print(f"[ERROR] Node {i+1} connection failed")
+
+        if not nodes:
+            print("[ERROR] No node was registered successfully")
+            return False
+
+        # Wait for the listener thread to update
+        print("[INFO] Waiting 2 seconds for the listener thread to update...")
+        time.sleep(2)
+
+        # Check the active node list as seen by the first node
+        active_nodes = nodes[0].get_active_node_ids()
+        print(f"[OK] Active node list seen by the first node: {active_nodes}")
+
+        # Unregister all nodes
+        for i, node_info in enumerate(nodes):
+            if node_info.unregister_node():
+                print(f"[OK] Node {i+1} unregistered")
+            else:
+                print(f"[ERROR] Node {i+1} unregistration failed")
+            node_info.disconnect()
+
+        return True
+
+    except Exception as e:
+        print(f"[ERROR] Multi-node test failed: {e}")
+        return False
+
+def test_pub_sub_notification():
+    """Test the publish/subscribe notification feature"""
+    print("\n=== Testing publish/subscribe notification ===")
+
+    try:
+        # Create two nodes
+        node1 = RedisNodeInfo('127.0.0.1', 6379, '127.0.0.1', '')
+        node2 = RedisNodeInfo('127.0.0.1', 6379, '127.0.0.2', '')
+
+        if not node1.connect() or not node2.connect():
+            print("[ERROR] Node connection failed")
+            return False
+
+        # Register both nodes
+        node1_id = node1.register_node()
+        node2_id = node2.register_node()
+
+        print(f"[OK] 节点1
ID: {node1_id}, 节点2 ID: {node2_id}") + + # 等待通知传播 + print("[INFO] 等待 2 秒让通知传播...") + time.sleep(2) + + # 检查两个节点是否都能看到对方 + active_nodes_1 = node1.get_active_node_ids() + active_nodes_2 = node2.get_active_node_ids() + + print(f"[OK] 节点1看到的活跃节点: {active_nodes_1}") + print(f"[OK] 节点2看到的活跃节点: {active_nodes_2}") + + # 解注册节点 + node1.unregister_node() + node2.unregister_node() + + # 等待通知传播 + print("[INFO] 等待 2 秒让解注册通知传播...") + time.sleep(2) + + # 检查解注册后的状态 + active_nodes_1_after = node1.get_active_node_ids() + active_nodes_2_after = node2.get_active_node_ids() + + print(f"[OK] 解注册后节点1看到的活跃节点: {active_nodes_1_after}") + print(f"[OK] 解注册后节点2看到的活跃节点: {active_nodes_2_after}") + + # 断开连接 + node1.disconnect() + node2.disconnect() + + return True + + except Exception as e: + print(f"[ERROR] 发布订阅通知测试失败: {e}") + return False + +def main(): + print("RedisNodeInfo 功能测试") + print("=" * 50) + + success = True + + # 测试基本功能 + if not test_redis_node_info(): + success = False + + # 测试多个节点 + if not test_multiple_nodes(): + success = False + + # 测试发布订阅通知 + if not test_pub_sub_notification(): + success = False + + print("\n" + "=" * 50) + if success: + print("[SUCCESS] 所有 RedisNodeInfo 测试通过!") + print("\n功能总结:") + print("1. 节点注册:通过原子递增 global:node_id 获取唯一 node_id") + print("2. 节点信息存储:在 node:node_id 哈希中存储节点信息") + print("3. 发布订阅通知:注册/解注册时发布 flexkv_node_id_updated 消息") + print("4. 监听线程:订阅 flexkv_node_id_updated 并更新活跃节点列表") + print("5. 节点扫描:通过 SCAN 0 MATCH node:* 扫描所有活跃节点") + else: + print("[ERROR] 部分测试失败") + +if __name__ == "__main__": + main() diff --git a/tests/test_utils.py b/tests/test_utils.py index ba1392eabc..93541b612b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -38,7 +38,6 @@ 'remote_file_prefix': "remote_cache", 'use_gds': False, 'enable_trace': False, - 'use_pinned_memory': False, 'ssd_cache_dir': ["./ssd_cache", "./ssd_cache2/"], 'ssd_cache_iouring_entries': 32, 'remote_cache_path': ["remote_cache1", "remote_cache2"],
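The feature summary printed by `test_redis_node_info.py` describes a simple registration protocol: atomically increment `global:node_id`, store node info in a `node:<id>` hash, publish on `flexkv_node_id_updated`, and discover peers by scanning `node:*` keys. The sketch below is a schematic, self-contained re-implementation of that flow; it uses an in-memory stub in place of a live Redis server (with `redis-py` you would pass a `redis.Redis` client and use `incr`, `hset`, `publish`, and `scan_iter` instead), and all function names are illustrative, not FlexKV API.

```python
class FakeRedis:
    """In-memory stand-in for the few Redis commands the protocol uses."""
    def __init__(self):
        self.counters = {}
        self.hashes = {}
        self.published = []  # recorded (channel, message) pairs

    def incr(self, key):
        # Atomic increment; real Redis INCR gives the same uniqueness guarantee.
        self.counters[key] = self.counters.get(key, 0) + 1
        return self.counters[key]

    def hset(self, key, mapping):
        self.hashes.setdefault(key, {}).update(mapping)

    def delete(self, key):
        self.hashes.pop(key, None)

    def publish(self, channel, message):
        self.published.append((channel, message))

    def scan_keys(self, prefix):
        # Simplified stand-in for SCAN 0 MATCH node:*
        return [k for k in self.hashes if k.startswith(prefix)]


def register_node(r, local_ip):
    # 1. Obtain a unique node_id by atomically incrementing global:node_id.
    node_id = r.incr("global:node_id")
    # 2. Store the node's info in the node:<node_id> hash.
    r.hset(f"node:{node_id}", {"ip": local_ip, "state": "active"})
    # 3. Notify peers so their listener threads refresh their node lists.
    r.publish("flexkv_node_id_updated", f"register:{node_id}")
    return node_id


def unregister_node(r, node_id):
    r.delete(f"node:{node_id}")
    r.publish("flexkv_node_id_updated", f"unregister:{node_id}")


def active_node_ids(r):
    # 5. Discover active nodes by scanning node:* keys.
    return sorted(int(k.split(":")[1]) for k in r.scan_keys("node:"))


r = FakeRedis()
a = register_node(r, "127.0.0.1")
b = register_node(r, "127.0.0.2")
print(active_node_ids(r))  # → [1, 2]
unregister_node(r, a)
print(active_node_ids(r))  # → [2]
```

In the real tests a background listener thread reacts to the `flexkv_node_id_updated` messages; here the published messages are simply recorded so the notification step stays visible.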