diff --git a/README.md b/README.md index 8923a3a7f7..ea1a95cfbe 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ FlexKV is released under the **Apache-2.0 License**. See the [LICENSE](LICENSE) ```bash apt install liburing-dev -apt install libxxhash-dev +apt install libxxhash-dev ``` ### Build FlexKV @@ -26,6 +26,10 @@ apt install libxxhash-dev See [docs/vllm_adapter/README_en.md](docs/vllm_adapter/README_en.md) +### Use FlexKV with TensorRT-LLM + +See [docs/trtllm_adaption/README_en.md](docs/trtllm_adaption/README_en.md) + ### FlexKV Integration with Dynamo See [docs/dynamo_integration/README_en.md](docs/dynamo_integration/README_en.md) @@ -90,7 +94,7 @@ FlexKV performs: ## Roadmap -- **In-Process Cache Engine Integration**: In the dev branch, the implementation, integration, and invocation of the Cache Engine will be further optimized, along with synchronized updates to related APIs. +- **In-Process Cache Engine Integration**: In the dev branch, the implementation, integration, and invocation of the Cache Engine will be further optimized, along with synchronized updates to related APIs. - **Framework Integration**: Support works for vLLM, SGLang, and other acceleration frameworks will be updated soon. - **Distributed Query Support**: Enable scalable, distributed KVCache lookup. - **Latency Optimization**: Further reduce *get* latency via smarter prefetching and compression. 
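The roadmap's "Distributed Query Support" item implies content-addressed KVCache keys, and the `libxxhash-dev` dependency above suggests FlexKV keys blocks by token hashes. As an illustrative sketch only (the chained-prefix scheme, the 64-token block size, and the use of SHA-256 here are assumptions, not FlexKV's actual key derivation), prefix-aware block keys can be derived like this:

```python
import hashlib

def block_keys(token_ids, tokens_per_block=64):
    """Derive one cache key per full block of tokens. Each key hashes the
    whole prefix up to that block, so two sequences share leading keys
    exactly as far as their token prefixes agree."""
    keys = []
    prev = b""
    n_full = len(token_ids) - len(token_ids) % tokens_per_block
    for start in range(0, n_full, tokens_per_block):
        chunk = token_ids[start:start + tokens_per_block]
        # Fold the previous block's digest into this one to chain the prefix.
        digest = hashlib.sha256(prev + repr(chunk).encode()).digest()
        keys.append(digest.hex()[:16])
        prev = digest
    return keys
```

Because each key folds in the previous block's digest, a lookup can walk the key chain and stop at the first miss, which is what makes shared prompt prefixes cheap to detect block by block.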
diff --git a/README_zh.md b/README_zh.md index 1654ff17ae..626b269d79 100644 --- a/README_zh.md +++ b/README_zh.md @@ -12,7 +12,7 @@ FlexKV 采用 **Apache-2.0 开源协议**,详细信息请参见 [LICENSE](LICE ```bash apt install liburing-dev -apt install libxxhash-dev +apt install libxxhash-dev ``` ### 编译 FlexKV @@ -22,10 +22,14 @@ apt install libxxhash-dev ``` #./build.sh --release for cython package ``` -### 以 vLLM 为例使用 FlexKV +### 在 vLLM 中使用 FlexKV 见[docs/vllm_adapter/README_zh.md](docs/vllm_adapter/README_zh.md) +### 在 TensorRT-LLM 中使用 FlexKV + +见[docs/trtllm_adaption/README_zh.md](docs/trtllm_adaption/README_zh.md) + ### FlexKV和Dynamo框架的集成 见[docs/dynamo_integration/README_zh.md](docs/dynamo_integration/README_zh.md) @@ -93,4 +97,4 @@ FlexKV 在处理 *get* 请求时: - **缓存引擎共进程化**:dev 分支将进一步优化 Cache Engine 的实现、集成和调用,并同步更新相关 API 支持 - **加速框架支持**:对 vLLM、SGLang 等主流推理框架的适配将陆续发布 - **分布式查询支持**:实现可扩展的分布式 KVCache 查询能力 -- **延迟优化**:通过预取、压缩等手段进一步降低 *get* 请求延迟 \ No newline at end of file +- **延迟优化**:通过预取、压缩等手段进一步降低 *get* 请求延迟 diff --git a/docs/trtllm_adaption/README_en.md b/docs/trtllm_adaption/README_en.md new file mode 100644 index 0000000000..f15e92fbbc --- /dev/null +++ b/docs/trtllm_adaption/README_en.md @@ -0,0 +1,74 @@ +# Using FlexKV in TensorRT-LLM +## 1. Environment Setup + +### 1.1 Install TensorRT-LLM (Tag v1.1.0.rc2) +We are currently working with the community to merge TensorRT-LLM adaptation code. 
Before it is merged into the main branch, there are two methods: +#### 1.1.1 Method 1 +You can use the patch we provide and recompile: +```bash +cd TensorRT-LLM +git apply FLEXKV_DIR/examples/trtllm_adaption/trtllm_v1.1.0rc2.patch +``` +Note: For TensorRT-LLM compilation instructions, please refer to [here](https://nvidia.github.io/TensorRT-LLM/installation/build-from-source-linux.html#build-from-source-linux) + +#### 1.1.2 Method 2 +You can also install our pre-compiled package: +```bash +pip install https://flexkv-1252113659.cos.ap-shanghai.myqcloud.com/TensorRT-LLM/tensorrt_llm-1.1.0rc2-cp312-cp312-linux_x86_64.whl +``` + +## 2. Running + +### 2.1 Configure FlexKV + +First, set the environment variable `TENSORRT_LLM_USE_FLEXKV` to enable FlexKV: +```bash +export TENSORRT_LLM_USE_FLEXKV=1 +``` + +FlexKV can be configured through environment variables and configuration files. For details, please refer to [`docs/flexkv_config_reference/README_en.md`](../../docs/flexkv_config_reference/README_en.md). Below are two simple configuration examples. +##### Example 1: Enable CPU Offloading Only +Use 32GB of CPU memory as a secondary cache. +```bash +unset FLEXKV_CONFIG_PATH +export FLEXKV_CPU_CACHE_GB=32 +``` +##### Example 2: Enable SSD Offloading +Use 32GB of CPU memory and 1TB of SSD storage as secondary and tertiary caches respectively. (Assuming the machine has two SSDs mounted at /data0 and /data1.) +```bash +# generate config +cat <<EOF > ./flexkv_config.yml +cpu_cache_gb: 32 +ssd_cache_gb: 1024 +ssd_cache_dir: /data0/flexkv_ssd/;/data1/flexkv_ssd/ +enable_gds: false +EOF +export FLEXKV_CONFIG_PATH="./flexkv_config.yml" +``` + +### 2.2 Launch TensorRT-LLM +#### 2.2.1 Method 1: Using Our Provided Example Script +```bash +cd FLEXKV_DIR/examples/trtllm_adaption +bash launch.sh YOUR_MODEL_PATH +``` +Note: The `launch.sh` script will launch both TensorRT-LLM and FlexKV, and configure FlexKV through `flexkv_config.json` in the same directory. +#### 2.2.2 
Method 2: Custom Launch +After configuring FlexKV according to the instructions in section [2.1](#21-configure-flexkv), add the following content to your `extra-llm-api-config.yml`: +```yaml +kv_cache_config: + enable_partial_reuse: false +kv_connector_config: + connector_module: "flexkv.integration.tensorrt_llm.trtllm_adapter" + connector_scheduler_class: "FlexKVSchedulerConnector" + connector_worker_class: "FlexKVWorkerConnector" +``` + +### 2.3 Potential TensorRT-LLM Issues +If you send TensorRT-LLM a request longer than `max_seq_len`, you may encounter an error similar to the following: +``` +[W] `default_max_tokens` (-40205) should be greater than 0, `default_max_tokens` (-40205) = max_seq_len (40961) - `splited_prompt_len` (81166) - `query_token_len` (0) +[W] User-specified `max_tokens` (16384) is greater than deduced `default_max_tokens` (-40205), using default_max_tokens instead. +[E] submit request failed: [TensorRT-LLM][ERROR] Assertion failed: mMaxNewTokens > 0 +``` +This is caused by the TensorRT-LLM framework itself not rejecting requests longer than `max_seq_len`, and is not related to FlexKV. We are currently working with the community to fix this issue. diff --git a/docs/trtllm_adaption/README_zh.md b/docs/trtllm_adaption/README_zh.md new file mode 100644 index 0000000000..6b72aa056f --- /dev/null +++ b/docs/trtllm_adaption/README_zh.md @@ -0,0 +1,74 @@ +# 在 TensorRT-LLM 中使用 FlexKV +## 1. 
环境准备 + +### 1.1 安装 TensorRT-LLM(Tag 为 v1.1.0.rc2) +目前我们正在推动社区合入 TensorRT-LLM 侧的适配代码,在合入主分支之前,有如下两种方法: +#### 1.1.1 方法一 +您可以使用我们提供的 patch,然后重新编译: +```bash +cd TensorRT-LLM +git apply FLEXKV_DIR/examples/trtllm_adaption/trtllm_v1.1.0rc2.patch +``` +注:TensorRT-LLM 的编译方式可以参考[这里](https://nvidia.github.io/TensorRT-LLM/installation/build-from-source-linux.html#build-from-source-linux) + +#### 1.1.2 方法二 +您也可以安装我们预先编译好的包: +```bash +pip install https://flexkv-1252113659.cos.ap-shanghai.myqcloud.com/TensorRT-LLM/tensorrt_llm-1.1.0rc2-cp312-cp312-linux_x86_64.whl +``` + +## 2. 运行 + +### 2.1 配置FlexKV + +首先设置环境变量`TENSORRT_LLM_USE_FLEXKV`以启用FlexKV: +```bash +export TENSORRT_LLM_USE_FLEXKV=1 +``` + +可以通过环境变量和配置文件两种方式配置FlexKV,具体请参考[`docs/flexkv_config_reference/README_zh.md`](../../docs/flexkv_config_reference/README_zh.md),下面提供了两个简单的配置示例。 +##### 示例一:仅启用CPU卸载 +使用32GB的CPU内存作为二级缓存。 +```bash +unset FLEXKV_CONFIG_PATH +export FLEXKV_CPU_CACHE_GB=32 +``` +##### 示例二:启用SSD卸载 +使用32GB的CPU内存和1T的SSD存储分别作为二级和三级缓存。(假设机器有两个SSD,并分别挂载在/data0和/data1两个路径上。) +```bash +# generate config +cat <<EOF > ./flexkv_config.yml +cpu_cache_gb: 32 +ssd_cache_gb: 1024 +ssd_cache_dir: /data0/flexkv_ssd/;/data1/flexkv_ssd/ +enable_gds: false +EOF +export FLEXKV_CONFIG_PATH="./flexkv_config.yml" +``` + +### 2.2 启动 TensorRT-LLM +#### 2.2.1 方式一:使用我们提供的示例脚本 +```bash +cd FLEXKV_DIR/examples/trtllm_adaption +bash launch.sh YOUR_MODEL_PATH +``` +注:`launch.sh` 脚本会同时启动 TensorRT-LLM 和 FlexKV,并通过同路径下的`flexkv_config.json`进行FlexKV的配置 +#### 2.2.2 
方式二:自定义启动 +按照 [2.1](#21-配置flexkv) 节的指示配置好FlexKV,接着在您的 `extra-llm-api-config.yml` 中加入下面的内容: +```yaml +kv_cache_config: + enable_partial_reuse: false +kv_connector_config: + connector_module: "flexkv.integration.tensorrt_llm.trtllm_adapter" + connector_scheduler_class: "FlexKVSchedulerConnector" + connector_worker_class: "FlexKVWorkerConnector" +``` + +### 2.3 TensorRT-LLM 潜在的问题 +如果您向 TensorRT-LLM 发送了超过 `max_seq_len` 长度的请求,会出现类似下面的报错: +``` +[W] `default_max_tokens` (-40205) should be greater than 0, `default_max_tokens` (-40205) = max_seq_len (40961) - `splited_prompt_len` (81166) - `query_token_len` (0) +[W] User-specified `max_tokens` (16384) is greater than deduced `default_max_tokens` (-40205), using default_max_tokens instead. +[E] submit request failed: [TensorRT-LLM][ERROR] Assertion failed: mMaxNewTokens > 0 +``` +这是 TensorRT-LLM 框架本身没有过滤超过 `max_seq_len` 长度的请求导致的,和 FlexKV 本身无关,目前我们正在推动社区修复这个问题。 diff --git a/docs/vllm_adapter/README_en.md b/docs/vllm_adapter/README_en.md index c9ca1b7697..79432012b2 100644 --- a/docs/vllm_adapter/README_en.md +++ b/docs/vllm_adapter/README_en.md @@ -43,10 +43,11 @@ export FLEXKV_CONFIG_PATH="./flexkv_config.yml" ### Running We provide an adaptation example based on **vLLM 0.10.1.1**: -1. apply patch +1. apply the patch and install ```bash -# FLEXKV_DIR/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch -git apply examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +cd vllm +git apply FLEXKV_DIR/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +pip install -e . # build and install vllm from source ``` 2. offline test diff --git a/docs/vllm_adapter/README_zh.md b/docs/vllm_adapter/README_zh.md index 52e3199755..6c51b058f7 100644 --- a/docs/vllm_adapter/README_zh.md +++ b/docs/vllm_adapter/README_zh.md @@ -42,10 +42,11 @@ export FLEXKV_CONFIG_PATH="./flexkv_config.yml" ### 运行 我们提供了基于 **vLLM 0.10.1.1** 的适配示例: -1. apply patch +1. 
apply the patch and install ```bash -# FLEXKV_DIR/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch -git apply examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +cd vllm +git apply FLEXKV_DIR/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +pip install -e . # build and install vllm from source ``` 2. offline test diff --git a/examples/trtllm_adaption/extra-llm-api-config-cg.yml b/examples/trtllm_adaption/extra-llm-api-config.yml similarity index 83% rename from examples/trtllm_adaption/extra-llm-api-config-cg.yml rename to examples/trtllm_adaption/extra-llm-api-config.yml index bfcc3dfc94..009bb2a853 100644 --- a/examples/trtllm_adaption/extra-llm-api-config-cg.yml +++ b/examples/trtllm_adaption/extra-llm-api-config.yml @@ -15,6 +15,6 @@ kv_connector_config: connector_module: "flexkv.integration.tensorrt_llm.trtllm_adapter" connector_scheduler_class: "FlexKVSchedulerConnector" connector_worker_class: "FlexKVWorkerConnector" -speculative_config: - decoding_type: MTP - num_nextn_predict_layers: 3 \ No newline at end of file +# speculative_config: +# decoding_type: MTP +# num_nextn_predict_layers: 3 \ No newline at end of file diff --git a/examples/trtllm_adaption/launch.sh b/examples/trtllm_adaption/launch.sh index dcf0d0de6e..221af26af1 100644 --- a/examples/trtllm_adaption/launch.sh +++ b/examples/trtllm_adaption/launch.sh @@ -20,4 +20,4 @@ trtllm-serve serve $MODEL_PATH \ --max_seq_len $MAX_SEQ_LEN \ --max_num_tokens $MAX_NUM_TOKENS \ --max_batch_size $BATCH_SIZE \ - --extra_llm_api_options extra-llm-api-config-cg.yml 2>&1 | tee logs/$TIMESTAMP.log + --extra_llm_api_options extra-llm-api-config.yml 2>&1 | tee logs/$TIMESTAMP.log diff --git a/examples/trtllm_adaption/trtllm_v1.1.0rc2.patch b/examples/trtllm_adaption/trtllm_v1.1.0rc2.patch new file mode 100644 index 0000000000..06a74e3562 --- /dev/null +++ b/examples/trtllm_adaption/trtllm_v1.1.0rc2.patch @@ -0,0 +1,288 @@ +diff --git 
a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +index f069e3ac7..e74e6a01a 100644 +--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h ++++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +@@ -1054,6 +1054,16 @@ public: + return mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget; + } + ++ [[nodiscard]] SizeType32 getNumConnectorMatchedTokens() const ++ { ++ return mNumConnectorMatchedTokens; ++ } ++ ++ void setNumConnectorMatchedTokens(SizeType32 numConnectorMatchedTokens) ++ { ++ mNumConnectorMatchedTokens = numConnectorMatchedTokens; ++ } ++ + void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock) + { + // Add debug log for prepopulatedPromptLen +@@ -1658,6 +1668,15 @@ public: + [](auto reason) { return reason == executor::FinishReason::kLENGTH; }); + } + ++ [[nodiscard]] bool isFinishedNormal() const noexcept ++ { ++ return std::all_of(mFinishReasons.begin(), mFinishReasons.end(), ++ [](auto reason) { ++ return reason == executor::FinishReason::kEND_ID || \ ++ reason == executor::FinishReason::kSTOP_WORDS || \ ++ reason == executor::FinishReason::kLENGTH; }); ++ } ++ + [[nodiscard]] bool isTimedOut() const + { + if (!mAllottedTimeMs.has_value()) +@@ -1906,6 +1925,9 @@ protected: + SizeType32 mPrepopulatedPromptLenTarget{0}; + SizeType32 mPrepopulatedPromptLenDraft{0}; + ++ // Number of tokens matched by KV cache connector for block reuse. 
++ SizeType32 mNumConnectorMatchedTokens{0}; ++ + SizeType32 mMaxSentTokenLen; + + std::optional mEmbeddingBias{std::nullopt}; +diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +index 175b52577..fa01906a7 100644 +--- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp ++++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +@@ -1202,6 +1202,7 @@ void WindowBlockManager::addSequence( + if (mKvCacheConnectorManager && !llmRequest.isDummyRequest()) + { + numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen); ++ llmRequest.setNumConnectorMatchedTokens(numConnectorMatchedTokens); + } + + llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock()); +@@ -2359,6 +2360,18 @@ void KVCacheManager::removeToken(RequestIdType requestId) + + void KVCacheManager::rewindKVCache(RequestIdType requestId, SizeType32 rewindLengths) + { ++ // Check if the sequence still exists before rewinding ++ // In overlap mode with MTP, the request may have been terminated and removed ++ // from mSequences before rewindKVCache is called ++ { ++ std::scoped_lock lck(mSequencesMtx); ++ if (mSequences.find(requestId) == mSequences.end()) ++ { ++ TLLM_LOG_DEBUG("Request %lu has already been removed from KV cache manager, skipping rewind", requestId); ++ return; ++ } ++ } ++ + for (SizeType32 si = 0; si < rewindLengths; ++si) + { + removeToken(requestId); +diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +index c170ca810..7fd5d5afe 100644 +--- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp ++++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +@@ -159,9 +159,11 @@ void initBindings(nb::module_& m) + .def("set_finished_reason", &GenLlmReq::setFinishedReason, nb::arg("finish_reason"), nb::arg("beam")) + .def_prop_ro("is_finished", 
&GenLlmReq::isFinished) + .def_prop_ro("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength) ++ .def_prop_ro("is_finished_normal", &GenLlmReq::isFinishedNormal) + .def_prop_rw( + "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition) + .def_prop_ro("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen) ++ .def_prop_ro("num_connector_matched_tokens", &GenLlmReq::getNumConnectorMatchedTokens) + .def_prop_rw("guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams) + .def_prop_ro("context_phase_params", &GenLlmReq::getContextPhaseParams) + .def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest) +diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +index 53c9ec7ef..c4be230d2 100644 +--- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp ++++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +@@ -164,9 +164,11 @@ void initBindings(pybind11::module_& m) + .def("set_finished_reason", &GenLlmReq::setFinishedReason, py::arg("finish_reason"), py::arg("beam")) + .def_property_readonly("is_finished", &GenLlmReq::isFinished) + .def_property_readonly("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength) ++ .def_property_readonly("is_finished_normal", &GenLlmReq::isFinishedNormal) + .def_property( + "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition) + .def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen) ++ .def_property_readonly("num_connector_matched_tokens", &GenLlmReq::getNumConnectorMatchedTokens) + .def_property( + "guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams) + .def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams) +diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py 
b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py +index 5e8bf6dfa..ec22d269e 100644 +--- a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py ++++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py +@@ -392,6 +392,12 @@ class KvCacheConnectorManager(KvCacheConnectorManagerCpp): + + def get_num_new_matched_tokens(self, request: LlmRequest, + num_computed_tokens: int) -> int: ++ """ Called in C++: KVCacheManager::addSequence ++ if (mKvCacheConnectorManager && !llmRequest.isDummyRequest()) ++ { ++ numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen); ++ } ++ """ + num_tokens, load_kv_async = self._run_on_leader( + lambda: self.scheduler.get_num_new_matched_tokens( + request, num_computed_tokens)) +diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py +index 4b3315560..d098dd48b 100644 +--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py ++++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py +@@ -275,6 +275,7 @@ class PyExecutor: + self.worker_started = False + self.worker_lock = threading.Lock() + ++ self.use_flexkv = os.getenv("TENSORRT_LLM_USE_FLEXKV", "0") == "1" + self.kv_connector_manager = kv_connector_manager + + self._maybe_init_kv_connector_manager() +@@ -282,6 +283,7 @@ class PyExecutor: + if start_worker: + self.start_worker() + ++ + def _maybe_init_kv_connector_manager(self): + if self.kv_connector_manager is not None: + if self.kv_cache_transceiver is not None: +@@ -309,6 +311,15 @@ class PyExecutor: + self.kv_connector_manager.layer_pre_hook) + module.register_forward_hook( + self.kv_connector_manager.layer_post_hook) ++ ++ if self.use_flexkv: ++ self._wait_for_flexkv_manager() ++ ++ def _wait_for_flexkv_manager(self): ++ if self.kv_connector_manager is not None and self.dist.rank == 0: ++ while not self.kv_connector_manager.scheduler.is_ready(): ++ time.sleep(0.1) ++ logger.info("FlexKV manager is ready") + + def 
_event_loop_wrapper(self): + try: +@@ -518,7 +529,7 @@ class PyExecutor: + if prev_device_step_time is None: + prev_device_step_time = "N/A" # Handle first iteration + else: +- prev_device_step_time = f"{prev_device_step_time}ms" ++ prev_device_step_time = f"{prev_device_step_time:.3f} ms" + host_step_time = (end_time - start_time) * 1000 # milliseconds + formatted_timestamp = datetime.datetime.now().strftime( + "%Y-%m-%d %H:%M:%S") +@@ -528,7 +539,7 @@ + f"rank = {self.dist.rank}, " + f"currank_total_requests = {self.executor_request_queue.num_fetch_requests_cur_rank}/" + f"{self.executor_request_queue.num_fetch_requests}, " +- f"host_step_time = {host_step_time}ms, " ++ f"host_step_time = {host_step_time:.3f} ms, " + f"prev_device_step_time = {prev_device_step_time}, " + f"timestamp = {formatted_timestamp}, " + f"num_scheduled_requests: {self.num_scheduled_requests}, " +@@ -965,6 +976,17 @@ + self.kv_connector_manager.worker.start_load_kv( + torch.cuda.current_stream()) + ++ def _kv_connector_refresh_unfinished_tasks(self): ++ if not self.use_flexkv: ++ return ++ if len(self.active_requests) == 0: ++ return ++ if not self.kv_connector_manager: ++ return ++ logger.warning("No scheduled requests, but FlexKV has pending put requests") ++ self.kv_connector_manager.handle_metadata() ++ time.sleep(0.01) ++ + def _kv_connector_terminate_requests(self): + if self.kv_connector_manager: + reqs_to_terminate = self.kv_connector_manager.get_finished() +@@ -992,6 +1014,9 @@ + scheduled_batch, iter_stats = self._prepare_and_schedule_batch() + if scheduled_batch is None: + break ++ + if scheduled_batch.batch_size == 0: ++ self._kv_connector_refresh_unfinished_tasks() + + self._pause_requests(scheduled_batch.paused_requests) + +@@ -1124,6 +1149,9 @@ + break + + self._pause_requests(scheduled_batch.paused_requests) ++ + if scheduled_batch.batch_size == 0: ++ self._kv_connector_refresh_unfinished_tasks() + 
+ if scheduled_batch.batch_size > 0: + if self.kv_cache_transceiver: +@@ -1772,6 +1800,13 @@ class PyExecutor: + new_responses.append((req_id, response)) + + if request_done: ++ # Release slot immediately when decode finishes, before put task completes ++ # This allows new requests to be scheduled earlier. ++ # Note: request_done is True when request.is_finished is True, which happens ++ # after the request state is set to GENERATION_COMPLETE in update_requests(). ++ # We check both to be safe. ++ if self.use_flexkv and (request.is_finished or request.state == LlmRequestState.GENERATION_COMPLETE): ++ self.resource_manager.free_slot_only(request) + if request.is_disagg_context_transmission_state: + self.ctx_in_transmission_requests.append(request) + else: +diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +index e824ee02d..63e0848ca 100644 +--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py ++++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +@@ -1,6 +1,7 @@ + import copy + import enum + import importlib ++import os + from concurrent.futures import ThreadPoolExecutor + from contextlib import contextmanager + from dataclasses import dataclass +@@ -202,6 +203,7 @@ def _get_mapping(executor_config: ExecutorConfig) -> Mapping: + tp_size=tensorrt_llm.mpi_world_size(), + gpus_per_node=tensorrt_llm.default_gpus_per_node(), + rank=tensorrt_llm.mpi_rank()) ++ executor_config.mapping = mapping + else: + mapping = copy.deepcopy(executor_config.mapping) + mapping.rank = tensorrt_llm.mpi_rank() +@@ -388,8 +390,12 @@ def create_py_executor( + f"Initializing kv connector with config: {kv_connector_config}") + + if pytorch_backend_config.use_cuda_graph: +- raise NotImplementedError( +- "CUDA graphs are not supported with KV connector hooks.") ++ use_flexkv = os.getenv("TENSORRT_LLM_USE_FLEXKV", "0") ++ if use_flexkv == "0": ++ raise NotImplementedError( ++ "CUDA graphs are not supported with 
KV connector hooks.") ++ else: ++ logger.info("Using FlexKV for KV connector") + + if executor_config.scheduler_config.capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: + raise NotImplementedError( +diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py +index 883a8d742..fa080a044 100644 +--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py ++++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py +@@ -1033,6 +1033,15 @@ class ResourceManager: + if hasattr(resource_manager, "free_resources"): + resource_manager.free_resources(request) + ++ def free_slot_only(self, request: LlmRequest): ++ """Only free the slot for the request, without freeing other resources. ++ This is used to release the slot early when decode finishes, before ++ the put task completes. ++ """ ++ seq_slot_manager = self.get_resource_manager(ResourceManagerType.SEQ_SLOT_MANAGER) ++ if seq_slot_manager is not None: ++ seq_slot_manager.free_resources(request) ++ + def reorder_pipeline(self, resource_manager_list: list[str]): + assert set(resource_manager_list) == set(self.resource_managers.keys()) + for resource_manager in resource_manager_list:
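The `free_slot_only` hook above exists so the executor can return a finished request's sequence slot to the pool before its FlexKV put task drains. A minimal sketch of that early-release pattern (`SlotPool`, `Request`, and `on_decode_finished` are hypothetical stand-ins, not TensorRT-LLM classes):

```python
# Illustrative sketch of the early-slot-release pattern from the patch:
# when decoding finishes, the sequence slot goes back to the pool at once,
# while the slower KV-offload ("put") task keeps running in the background.
from dataclasses import dataclass, field

@dataclass
class SlotPool:
    free_slots: list = field(default_factory=lambda: [0, 1])

    def acquire(self):
        # Returns a slot id, or None when the pool is exhausted.
        return self.free_slots.pop() if self.free_slots else None

    def release(self, slot):
        self.free_slots.append(slot)

@dataclass
class Request:
    slot: int
    put_pending: bool = True  # KV offload still in flight

def on_decode_finished(pool, request):
    # Mirrors free_slot_only(): hand the slot back immediately so a new
    # request can be scheduled, without waiting for the put task.
    pool.release(request.slot)
    request.slot = None
```

The trade-off matches the comment in the patch: scheduling latency no longer waits on storage-tier writes, because only the slot is freed early while the remaining resources are released once the put task completes.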