diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 967f70016b..7c27428df4 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -320,7 +320,8 @@ def get(self, raise NotImplementedError(f"Layerwise transfer is not supported yet, " f"layer_num: {layer_num}, layer_granularity: {layer_granularity}") - if not os.getenv("FLEXKV_WITH_TRTLLM", "0") == "1": + combine_with_trtllm = os.getenv("FLEXKV_WITH_TRTLLM", "0") == "1" + if not combine_with_trtllm: aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block else: # When using FlexKV with TensorRT-LLM, we ignore the last incomplete block. diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index 1b91b5f4d8..cf507414d5 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -109,7 +109,8 @@ def __init__(self, if self.is_multinode_tp and not self.model_config.use_mla: model_config_for_transfer.num_kv_heads = self.tp_size_per_node - if os.getenv("FLEXKV_WITH_TRTLLM", "0") == "1": + combine_with_trtllm = os.getenv("FLEXKV_WITH_TRTLLM", "0") == "1" + if not combine_with_trtllm: self.transfer_handles = [TransferManagerHandle( model_config_for_transfer, self.cache_config, @@ -173,6 +174,8 @@ def shutdown(self) -> None: for transfer_handle in self.transfer_handles: transfer_handle.shutdown() if hasattr(self, "remote_process") and self.remote_process is not None: + assert self.remote_process.is_alive() + self.remote_process.terminate() self.remote_process.join() self.remote_process.close() self.remote_process = None diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py index 2129362f7f..246955cdfe 100644 --- a/flexkv/transfer_manager.py +++ b/flexkv/transfer_manager.py @@ -135,7 +135,8 @@ def start(self) -> None: self.transfer_engine.start() def shutdown(self) -> None: - self.transfer_engine.shutdown() + if hasattr(self, 'transfer_engine'): + self.transfer_engine.shutdown() def get_master_host_and_ports_from_env() -> Tuple[str, Tuple[str, str, str]]: master_host = os.getenv("FLEXKV_MASTER_HOST", "localhost") @@ -408,6 +409,12 @@ def __init__(self, popen_process): self._popen = popen_process self.pid = popen_process.pid + def is_alive(self): + return self._popen.poll() is None + + def terminate(self): + self._popen.terminate() + def join(self, timeout=None): return self._popen.wait(timeout) @@ -644,6 +651,13 @@ def _bind_master_ports(self) -> None: except Exception as e: flexkv_logger.error(f"Master failed to bind ports: {e}") + try: + self.command_socket.close() + self.result_socket.close() + self.query_socket.close() + self.context.term() + except Exception: + pass raise def send_config_to_remotes(self) -> None: