diff --git a/.gitignore b/.gitignore index 03727c4ed6..0b37f68d72 100644 --- a/.gitignore +++ b/.gitignore @@ -73,3 +73,6 @@ cover/ # VSCode .vscode/ + +# logs +*.log \ No newline at end of file diff --git a/csrc/radix_tree.cpp b/csrc/radix_tree.cpp index 073cbe3279..adf056d50a 100644 --- a/csrc/radix_tree.cpp +++ b/csrc/radix_tree.cpp @@ -10,6 +10,38 @@ namespace flexkv { +// Helper function matching Python's get_hash with boundary check and _has_hashes branch +// Returns std::nullopt if block_id is out of bounds (like Python returning None) +// If has_hashes is true, reads from block_hashes_ptr; otherwise computes from token_ids +static std::optional get_hash_safe( + int64_t *block_hashes_ptr, + int64_t *token_ids_ptr, // Can be nullptr if has_hashes is true + int block_id, + int num_blocks, + bool has_hashes, + int tokens_per_block) { + if (block_id >= num_blocks) { + return std::nullopt; // Out of bounds, return None (similar to Python) + } + + if (has_hashes) { + // Read from pre-computed block_hashes (matching Python: if self._has_hashes) + return HashType(block_hashes_ptr[block_id]); + } else { + // Compute hash from token_ids (matching Python: hash_array(self.token_ids[...])) + if (token_ids_ptr == nullptr) { + // Cannot compute without token_ids, return nullopt + return std::nullopt; + } + // Compute hash for tokens up to (block_id+1)*tokens_per_block + // Matching Python: hash_array(self.token_ids[:(block_id+1)*self.tokens_per_block]) + Hasher hasher; + hasher.reset(); // Reset hasher (matching Python's _HASHER.reset()) + hasher.update(token_ids_ptr, (block_id + 1) * tokens_per_block * sizeof(int64_t)); + return hasher.digest(); + } +} + CRadixNode::CRadixNode(CRadixTreeIndex *index, bool ready, int lock_cnt) { assert(index != nullptr); @@ -230,6 +262,11 @@ CRadixTreeIndex::match_prefix(torch::Tensor &block_hashes, int num_blocks, auto physical_blocks = new std::vector(); auto block_hashes_ptr = block_hashes.data_ptr(); HashType child_hash; + + // In C++ 
version, block_hashes is always pre-computed (has_hashes = true) + // token_ids_ptr is nullptr since we don't have token_ids in this function signature + bool has_hashes = true; + int64_t *token_ids_ptr = nullptr; while (prefix_blocks_num < num_blocks) { if (update_cache_info) { @@ -289,4 +326,4 @@ CRadixTreeIndex::match_prefix(torch::Tensor &block_hashes, int num_blocks, last_ready_node, current_node, physical_blocks); } -} // namespace flexkv +} // namespace flexkv diff --git a/examples/trtllm_adaption/extra-llm-api-config-cg.yml b/examples/trtllm_adaption/extra-llm-api-config-cg.yml new file mode 100644 index 0000000000..bfcc3dfc94 --- /dev/null +++ b/examples/trtllm_adaption/extra-llm-api-config-cg.yml @@ -0,0 +1,20 @@ +cuda_graph_config: + enable_padding: true + batch_sizes: + - 1 + - 2 + - 4 + - 8 + - 16 + - 32 +enable_chunked_prefill: true +kv_cache_config: + enable_partial_reuse: false + free_gpu_memory_fraction: 0.75 +kv_connector_config: + connector_module: "flexkv.integration.tensorrt_llm.trtllm_adapter" + connector_scheduler_class: "FlexKVSchedulerConnector" + connector_worker_class: "FlexKVWorkerConnector" +speculative_config: + decoding_type: MTP + num_nextn_predict_layers: 3 \ No newline at end of file diff --git a/examples/trtllm_adaption/flexkv_config.json b/examples/trtllm_adaption/flexkv_config.json new file mode 100644 index 0000000000..10f311c4ba --- /dev/null +++ b/examples/trtllm_adaption/flexkv_config.json @@ -0,0 +1,6 @@ +{ + "cpu_cache_gb": 32, + "ssd_cache_gb": 1024, + "ssd_cache_dir": "/data/flexkv_ssd/", + "enable_gds": false + } \ No newline at end of file diff --git a/examples/trtllm_adaption/launch.sh b/examples/trtllm_adaption/launch.sh new file mode 100644 index 0000000000..dcf0d0de6e --- /dev/null +++ b/examples/trtllm_adaption/launch.sh @@ -0,0 +1,23 @@ +mkdir -p logs +TIMESTAMP=$(date +%Y.%m.%d-%H:%M:%S) +MODEL_PATH=${1:-YOUR_MODEL_PATH} + +BATCH_SIZE=4 +TP_SIZE=8 +EP_SIZE=$TP_SIZE +MAX_SEQ_LEN=155648 +MAX_NUM_TOKENS=16384 + 
+export FLEXKV_CONFIG_PATH="./flexkv_config.json" +export TENSORRT_LLM_USE_FLEXKV=1 + +trtllm-serve serve $MODEL_PATH \ + --host 0.0.0.0 \ + --port 6000 \ + --backend pytorch \ + --tp_size $TP_SIZE \ + --ep_size $EP_SIZE \ + --max_seq_len $MAX_SEQ_LEN \ + --max_num_tokens $MAX_NUM_TOKENS \ + --max_batch_size $BATCH_SIZE \ + --extra_llm_api_options extra-llm-api-config-cg.yml 2>&1 | tee logs/$TIMESTAMP.log diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 2fa0038f21..0a8b1df67b 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -20,6 +20,7 @@ from typing import List, Tuple, Optional, Dict, Callable from dataclasses import dataclass, field +import os import numpy as np import nvtx import torch @@ -319,8 +320,12 @@ def get(self, raise NotImplementedError(f"Layerwise transfer is not supported yet, " f"layer_num: {layer_num}, layer_granularity: {layer_granularity}") - # ignore the last incomplete block - aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block + if not os.getenv("FLEXKV_WITH_TRTLLM", "0") == "1": + aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block + else: + # When using FlexKV with TensorRT-LLM, we ignore the last incomplete block. 
+ aligned_length = ((token_ids.shape[0] - 1) // self.tokens_per_block) * self.tokens_per_block + aligned_token_ids = token_ids[:aligned_length] token_mask = token_mask[:aligned_length] diff --git a/flexkv/common/debug.py b/flexkv/common/debug.py index a522c5549a..0dbea1a457 100644 --- a/flexkv/common/debug.py +++ b/flexkv/common/debug.py @@ -2,13 +2,14 @@ import os import sys import time +import inspect from functools import wraps from typing import Optional, Callable, Any FLEXKV_LOGGING_PREFIX = os.getenv("FLEXKV_LOGGING_PREFIX", "FLEXKV") _FORMAT = (f"[{FLEXKV_LOGGING_PREFIX}] %(levelname)s %(asctime)s.%(msecs)03d " - " %(message)s") + "[%(filename)s:%(lineno)d] %(message)s") _DATE_FORMAT = "%m-%d %H:%M:%S" class FlexkvLogger: @@ -16,6 +17,8 @@ def __init__(self, debug_level: str = "INFO"): self.enabled = False self.logger = logging.getLogger("FLEXKV") + self.logger.propagate = False + has_console_handler = any( isinstance(handler, logging.StreamHandler) for handler in self.logger.handlers @@ -44,25 +47,62 @@ def set_level(self, level: str) -> None: self.logger.setLevel(log_level) self.enabled = log_level != (logging.CRITICAL + 1) + def _get_caller_info(self): + frame = inspect.currentframe() + try: + for _ in range(2): + frame = frame.f_back + if frame is None: + break + + if frame is not None: + filename = os.path.basename(frame.f_code.co_filename) + lineno = frame.f_lineno + return filename, lineno + finally: + del frame + + return "unknown", 0 + def debug(self, msg: str, *args: Any, **kwargs: Any) -> None: if self.enabled: - self.logger.debug(msg, *args, **kwargs) + filename, lineno = self._get_caller_info() + record = self.logger.makeRecord( + self.logger.name, logging.DEBUG, filename, lineno, msg, args, None + ) + self.logger.handle(record) def info(self, msg: str, *args: Any, **kwargs: Any) -> None: if self.enabled: - self.logger.info(msg, *args, **kwargs) + filename, lineno = self._get_caller_info() + record = self.logger.makeRecord( + self.logger.name, 
logging.INFO, filename, lineno, msg, args, None + ) + self.logger.handle(record) def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: if self.enabled: - self.logger.warning(msg, *args, **kwargs) + filename, lineno = self._get_caller_info() + record = self.logger.makeRecord( + self.logger.name, logging.WARNING, filename, lineno, msg, args, None + ) + self.logger.handle(record) def error(self, msg: str, *args: Any, **kwargs: Any) -> None: if self.enabled: - self.logger.error(msg, *args, **kwargs) + filename, lineno = self._get_caller_info() + record = self.logger.makeRecord( + self.logger.name, logging.ERROR, filename, lineno, msg, args, None + ) + self.logger.handle(record) def critical(self, msg: str, *args: Any, **kwargs: Any) -> None: if self.enabled: - self.logger.critical(msg, *args, **kwargs) + filename, lineno = self._get_caller_info() + record = self.logger.makeRecord( + self.logger.name, logging.CRITICAL, filename, lineno, msg, args, None + ) + self.logger.handle(record) flexkv_logger = FlexkvLogger(os.getenv("FLEXKV_LOG_LEVEL", "INFO")) diff --git a/flexkv/common/memory_handle.py b/flexkv/common/memory_handle.py index d8d1b6ff02..abe02aca0c 100644 --- a/flexkv/common/memory_handle.py +++ b/flexkv/common/memory_handle.py @@ -12,6 +12,26 @@ from flexkv.common.debug import flexkv_logger +class cudaIpcMemHandle_t(ctypes.Structure): + _fields_ = [('reserved', ctypes.c_byte * 64)] + +# Load CUDA runtime library +try: + cudart = ctypes.CDLL('libcudart.so') +except: + try: + cudart = ctypes.CDLL('libcudart.so.12') + except: + cudart = ctypes.CDLL('libcudart.so.11') + +# CUDA IPC handle size (64 bytes on Linux) +CUDA_IPC_HANDLE_SIZE = 64 + +# CUDA error codes +cudaSuccess = 0 +cudaErrorInvalidValue = 11 + + class cudaIpcMemHandle_t(ctypes.Structure): _fields_ = [("reserved", ctypes.c_byte * 64)] @@ -211,6 +231,120 @@ def _export_tensor_handle( device = tensor.device rebuild_func, rebuild_args = reductions.reduce_tensor(tensor) return rebuild_func, 
rebuild_args, device + + @staticmethod + def _export_cuda_ipc_handle(tensor: torch.Tensor) -> bytes: + """ + 直接使用 CUDA IPC API 导出 tensor 的 IPC handle + """ + # Get device pointer + data_ptr = tensor.data_ptr() + device = tensor.device + + flexkv_logger.debug(f"Exporting CUDA IPC handle: device={device}, data_ptr={hex(data_ptr)}") + + # Ensure we're on the correct device + torch.cuda.set_device(device) + + # Create IPC handle buffer + # ipc_handle = ctypes.create_string_buffer(CUDA_IPC_HANDLE_SIZE) + ipc_handle = cudaIpcMemHandle_t() + + # Call cudaIpcGetMemHandle + result = cudart.cudaIpcGetMemHandle( + ctypes.byref(ipc_handle), + ctypes.c_void_p(data_ptr) + ) + + if result != cudaSuccess: + error_msg = f"cudaIpcGetMemHandle failed with error code {result} for device {device}, ptr={hex(data_ptr)}" + flexkv_logger.error(error_msg) + raise RuntimeError(error_msg) + + # Return handle as bytes + # handle_bytes = bytes(ipc_handle.raw) + handle_bytes = ctypes.string_at(ctypes.byref(ipc_handle), 64) + flexkv_logger.debug(f"IPC handle exported successfully, first 16 bytes: {handle_bytes.hex()}") + return handle_bytes + + @staticmethod + def _import_cuda_ipc_handle(ipc_handle: bytes, shape: Tuple[int, ...], + dtype: torch.dtype, device: torch.device) -> torch.Tensor: + """ + 直接使用 CUDA IPC API 从 handle 导入 tensor + """ + flexkv_logger.debug(f"Attempting to import CUDA IPC handle for device {device}") + + # Ensure CUDA is initialized in this process + if not torch.cuda.is_initialized(): + flexkv_logger.info("Initializing CUDA in subprocess") + torch.cuda.init() + + # Set device and create a dummy tensor to ensure context is created + device_id = device.index if device.index is not None else 0 + torch.cuda.set_device(device_id) + + # Force CUDA context creation + _ = torch.zeros(1, device=device) + flexkv_logger.debug(f"CUDA context created for device {device_id}, current_device={torch.cuda.current_device()}") + + # Create IPC handle buffer + ipc_handle_buf = 
ctypes.create_string_buffer(ipc_handle, CUDA_IPC_HANDLE_SIZE) + + # 重建 IPC handle + handle = cudaIpcMemHandle_t() + ctypes.memmove(ctypes.byref(handle), ipc_handle, 64) + + # Open IPC memory handle + dev_ptr = ctypes.c_void_p() + result = cudart.cudaIpcOpenMemHandle( + ctypes.byref(dev_ptr), + handle, + ctypes.c_int(1) # cudaIpcMemLazyEnablePeerAccess = 1 + ) + flexkv_logger.debug(f"import CUDA IPC handle: device={device}, dev_ptr={hex(dev_ptr.value)}") + if result != cudaSuccess: + error_msg = f"cudaIpcOpenMemHandle failed with error code {result} for device {device_id}" + flexkv_logger.error(error_msg) + # flexkv_logger.error(f"IPC handle bytes (first 16): {ipc_handle[:16].hex()}") + flexkv_logger.error(f"IPC handle bytes (first 16): {ipc_handle.hex()}") + flexkv_logger.error(f"Current CUDA device: {torch.cuda.current_device()}") + flexkv_logger.error(f"Target device: {device_id}") + raise RuntimeError(error_msg) + + # Create tensor from pointer + numel = 1 + for dim in shape: + numel *= dim + + class CudaArrayInterface: + def __init__(self, data_ptr, shape, dtype, strides=None): + self.__cuda_array_interface__ = { + "data": (data_ptr, False), # (data_ptr, read_only) + "shape": tuple(shape), + "typestr": { + torch.float32: " float: + return (self.match_end_time - self.match_start_time) + + @property + def task_execute_cost(self) -> float: + return (self.task_finished_time - self.task_launch_time) + + @property + @abstractmethod + def task_type(self) -> str: + ... 
+ + def __str__(self): + return (f"FlexKVTask(task_id={self.task_id}, " + f"request={self.request.req_id}, " + f"match_cost {self.match_cost*1000:.2f} ms, " + f"task execute cost {self.task_execute_cost*1000:.2f} ms)") + + +@dataclass(kw_only=True) +class FlexKVGetTask(FlexKVTask): + num_computed_tokens: int + num_new_matched_tokens: int + + @property + def task_type(self) -> str: + return "get" + + def __str__(self): + return (f"FlexKVGetTask(task_id={self.task_id}, " + f"request={self.request.req_id}, " + f"num_computed_tokens={self.num_computed_tokens}, " + f"num_new_matched_tokens={self.num_new_matched_tokens}, " + f"match_cost {self.match_cost*1000:.2f} ms, " + f"task execute cost {self.task_execute_cost*1000:.2f} ms)") + + +@dataclass(kw_only=True) +class FlexKVPutTask(FlexKVTask): + num_matched_tokens: int + num_unmatched_tokens: int + + @property + def task_type(self) -> str: + return "put" + + def __str__(self): + return (f"FlexKVPutTask(task_id={self.task_id}, " + f"request={self.request.req_id}, " + f"num_matched_tokens={self.num_matched_tokens}, " + f"num_unmatched_tokens={self.num_unmatched_tokens}, " + f"match_cost {self.match_cost*1000:.2f} ms, " + f"task execute cost {self.task_execute_cost*1000:.2f} ms)") diff --git a/flexkv/integration/tensorrt_llm/trtllm_adapter.py b/flexkv/integration/tensorrt_llm/trtllm_adapter.py new file mode 100644 index 0000000000..ca07a03e48 --- /dev/null +++ b/flexkv/integration/tensorrt_llm/trtllm_adapter.py @@ -0,0 +1,545 @@ +import os +import time +from typing import TYPE_CHECKING, Optional, Literal, Any, List, Tuple +from dataclasses import dataclass, field +from abc import ABC, abstractmethod + +import numpy as np +import torch + +from flexkv.kvmanager import KVManager +from flexkv.server.client import KVTPClient +from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.config import ModelConfig, CacheConfig +from flexkv.common.request import KVResponseStatus +from flexkv.common.debug 
import flexkv_logger +from flexkv.integration.stats import FlexKVStats +from flexkv.integration.utils import cdiv +from flexkv.integration.config import FlexKVConfig +from flexkv.integration.tensorrt_llm.meta import( + FlexKVResponse, FlexKVTask, FlexKVGetTask, FlexKVPutTask, FlexKVConnectorMetadata) + +from flexkv.integration.tensorrt_llm.utils import RequestWrapper, get_dp_tp_info +from tensorrt_llm.bindings.internal.batch_manager import LlmRequest +from tensorrt_llm.bindings.executor import ExecutorConfig +from tensorrt_llm._torch.pyexecutor.kv_cache_connector import ( + KvCacheConnectorScheduler, KvCacheConnectorWorker, + SchedulerOutput) + +class FlexKVSchedulerConnector(KvCacheConnectorScheduler): + def __init__(self, config: ExecutorConfig): + tp_size, dp_size, dp_rank = get_dp_tp_info(config) + flexkv_config = FlexKVConfig.from_env() + flexkv_config.post_init_from_trt_config(config, tp_size, dp_size, dp_rank) + + flexkv_logger.info(f"Start init FlexKVSchedulerConnector with {flexkv_config}") + self.server_recv_port = flexkv_config.server_recv_port + self.tp_size = flexkv_config.model_config.tp_size + self.dp_size = flexkv_config.model_config.dp_size + self.block_size = flexkv_config.cache_config.tokens_per_block + self.model_config = flexkv_config.model_config + self.cache_config = flexkv_config.cache_config + self.flexkv_manager = KVManager(model_config=self.model_config, + cache_config=self.cache_config, + server_recv_port=flexkv_config.server_recv_port, + dp_client_id=dp_rank) + self.flexkv_manager.start() + # self.dp_client = KVDPClient(self.server_recv_port, self.model_config) + + # request_id -> task_id + self.req_id_to_task_dict: dict[str, int] = {} + # launched but unfinished tasks + self.get_tasks: dict[int, FlexKVGetTask] = {} + self.put_tasks: dict[int, FlexKVPutTask] = {} + # unlaunched tasks + self.tasks_to_launch: dict[int, FlexKVTask] = {} + self.tasks_to_cancel: dict[int, FlexKVTask] = {} + + self.flexkv_stats = 
FlexKVStats(os.getenv('FLEXKV_NUM_LOG_INTERVAL_REQUESTS', 200)) + + flexkv_logger.info("Finish init FlexKVSchedulerConnector") + + # Set environment variable for FlexKV with TensorRT-LLM + os.environ['FLEXKV_WITH_TRTLLM'] = '1' + + def is_ready( + self, + ) -> bool: + " Ask flexkv is ready " + return self.flexkv_manager.is_ready() + + def shutdown(self) -> None: + self.flexkv_manager.shutdown() + + @property + def dp_client_id(self) -> int: + return self.flexkv_manager.dp_client_id + + #################### + #### Get Method #### + #################### + + def get_num_new_matched_tokens( + self, + _request: "LlmRequest", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Args: + request: Request to get. + num_computed_tokens: Number of prefix tokens have already been computed, + which means not need to transfer from flexkv. + + Returns: + tuple[int, bool]: A tuple containing two integer values representing the + number of new matched tokens and whether it is necessary + to get the new matched blocks from flexkv. + """ + request = RequestWrapper(_request) + task_id, num_new_matched_tokens = self._get_match(_request=_request, + num_computed_tokens=num_computed_tokens) + + self.flexkv_stats.record_get(num_prompt_tokens=request.num_prompt_tokens, + num_gpu_matched_tokens=num_computed_tokens, + num_flexkv_matched_tokens=num_new_matched_tokens) + + if not self._need_to_get(num_prompt_tokens=request.num_prompt_tokens, + num_computed_tokens=num_computed_tokens, + num_new_matched_tokens=num_new_matched_tokens): + return 0, False + + return num_new_matched_tokens, True + + + def _get_match( + self, + _request: "LlmRequest", + num_computed_tokens: int = 0, + ) -> tuple[int, int]: + """ + Args: + request: Request to get. + num_computed_tokens: Number of prefix tokens have already been computed, + which means not need to transfer from flexkv. 
+ + Returns: + tuple[int, int]: A tuple containing two integer values representing + the task_id and number of new matched tokens. + """ + request = RequestWrapper(_request) + flexkv_logger.info(f"Get match request: {request}") + + match_start_time = time.perf_counter() + num_tokens_to_get = (request.num_prompt_tokens//self.block_size)*self.block_size + if num_tokens_to_get == 0: + return -1, 0 + + token_ids = request.all_token_ids[:num_tokens_to_get] + + assert num_computed_tokens <= num_tokens_to_get, ( + f"{num_computed_tokens=} must less equal to {num_tokens_to_get=}") + assert num_computed_tokens % self.block_size == 0,( + f"{num_computed_tokens=}, {self.block_size=}") + + if num_tokens_to_get == num_computed_tokens: + return -1, 0 + + np_token_ids = np.array(token_ids) + np_token_mask = np.ones_like(np_token_ids, dtype=bool) + np_token_mask[:num_computed_tokens] = False + task_id, matched_mask = self.flexkv_manager.get_match(token_ids=np_token_ids, + token_mask=np_token_mask) + num_new_matched_tokens = matched_mask.sum().item() + + # Auto cancel if not call update_state_after_alloc() + match_end_time = time.perf_counter() + flexkv_logger.debug(f"Get match cost {(match_end_time-match_start_time)*1000:.2f} ms.") + if num_new_matched_tokens > 0: + self.req_id_to_task_dict[request.req_id] = task_id + self.tasks_to_cancel[task_id] = FlexKVGetTask(task_id=task_id, + request=request, + num_computed_tokens=num_computed_tokens, + num_new_matched_tokens=num_new_matched_tokens, + match_start_time=match_start_time, + match_end_time=match_end_time) + + flexkv_logger.debug(f"FlexKV create get task: {self.tasks_to_cancel[task_id]}") + + return task_id, num_new_matched_tokens + + def _need_to_get( + self, + num_prompt_tokens: int, + num_computed_tokens: int, + num_new_matched_tokens: int, + ) -> bool: + """ + Determine whether it is necessary to get the new matched blocks from flexkv. 
+ """ + return num_new_matched_tokens > 0 + + # def update_state_after_alloc( + # self, + # _request: "LlmRequest", + # blocks: "KVCacheBlocks", + # ) -> None: + def update_state_after_alloc( + self, + _request: "LlmRequest", + block_ids: List[int], + ) -> None: + """ + Compute slot mapping and prepare to launch task. + Only call after get_num_new_matched_tokens(). + + Args: + request: Request to get. + blocks: All blocks of the request. + num_new_matched_tokens: Number of new matched tokens returned by + get_num_new_matched_tokens(). + + Returns: + None. + """ + request = RequestWrapper(_request) + if request.num_new_matched_tokens == 0: + flexkv_logger.info(f"No new matched tokens, skip update state after alloc.") + return + + # prepare to launch task + task_id = self.req_id_to_task_dict[request.req_id] + task: FlexKVGetTask = self.tasks_to_cancel.pop(task_id) + self.tasks_to_launch[task_id] = task + + # compute slot_mapping + num_computed_blocks = task.num_computed_tokens // self.block_size + num_blocks_to_get = request.num_new_matched_tokens // self.block_size + block_ids_to_get = block_ids[num_computed_blocks:num_computed_blocks+num_blocks_to_get] + task.slot_mapping = np.array(block_ids_to_get).repeat(self.block_size)*self.block_size + + def wait_for_all_get_tasks(self) -> list[FlexKVResponse]: + return self._blocking_waiting_for_tasks(self.get_tasks) + + #################### + #### Put Method #### + #################### + + def request_finished( + self, + _request: "LlmRequest", + block_ids: list[int], + ) -> bool: + """ + Args: + request: Request to put. + blocks: All block_ids of the request. + + Returns: + bool: whether thire is unfinished task for this request. 
+ """ + request = RequestWrapper(_request) + + # Task not finished, can't free blocks + if request.req_id in self.req_id_to_task_dict: + return True + + # Abnormal finished, don't put + if not (request.is_finished() and request.is_finished_normal()): + return False + + task_id, num_matched_tokens, num_unmatched_tokens = self._put_match(_request=_request) + + self.flexkv_stats.record_put(num_all_tokens=request.num_tokens, + num_unmatched_tokens=num_unmatched_tokens) + + if not self._need_to_put(num_all_tokens=request.num_tokens, + num_matched_tokens=num_matched_tokens, + num_unmatched_tokens=num_unmatched_tokens): + return False + + # prepare to launch task + task: FlexKVPutTask = self.tasks_to_cancel.pop(task_id) + self.tasks_to_launch[task_id] = task + + # compute slot mapping + num_matched_blocks = num_matched_tokens // self.block_size + num_unmatched_blocks = num_unmatched_tokens // self.block_size + block_ids_to_put = block_ids[num_matched_blocks:num_matched_blocks+num_unmatched_blocks] + flexkv_logger.info(f"{num_matched_blocks=}, {num_matched_blocks+num_unmatched_blocks=}, {len(block_ids)=}") + task.slot_mapping = np.array(block_ids_to_put).repeat(self.block_size)*self.block_size + flexkv_logger.info(f"{task_id=}, {num_matched_tokens=}, {num_unmatched_tokens=}, {len(block_ids_to_put)=}, {self.block_size=}, {task.slot_mapping.shape=}") + return True + + def _put_match( + self, + _request: "LlmRequest" + ) -> tuple[int, int, int]: + """ + Args: + request: Request to put. + + Returns: + tuple[int, int, int]: A tuple containing three integer values representing + the task_id, number of matched tokens and number of unmatched tokens. 
+ """ + request = RequestWrapper(_request) + flexkv_logger.info(f"Put match request: {request}") + match_start_time = time.perf_counter() + num_tokens_to_put = (cdiv(request.num_tokens+1, self.block_size)-1)*self.block_size + token_ids = request.all_token_ids[:num_tokens_to_put] + + if num_tokens_to_put == 0: + return -1, 0, 0 + + np_token_ids = np.array(token_ids) + task_id, unmatched_mask = self.flexkv_manager.put_match(token_ids=np_token_ids) + + num_unmatched_tokens = unmatched_mask.sum().item() + num_matched_tokens = num_tokens_to_put - num_unmatched_tokens + + # Auto cancel if not need to put. + match_end_time = time.perf_counter() + flexkv_logger.debug(f"Put match cost {(match_end_time-match_start_time)*1000:.2f} ms. {num_unmatched_tokens=}") + + if num_unmatched_tokens > 0: + self.req_id_to_task_dict[request.req_id] = task_id + self.tasks_to_cancel[task_id] = FlexKVPutTask(task_id=task_id, + request=request, + num_matched_tokens=num_matched_tokens, + num_unmatched_tokens=num_unmatched_tokens, + match_start_time=match_start_time, + match_end_time=match_end_time) + flexkv_logger.debug(f"FlexKV create put task: {self.tasks_to_cancel[task_id]}") + + return task_id, num_matched_tokens, num_unmatched_tokens + + def _need_to_put( + self, + num_all_tokens: int, + num_matched_tokens: int, + num_unmatched_tokens: int, + ) -> bool: + """ + Determine whether it is necessary to put the unmatched blocks from flexkv. + """ + return num_unmatched_tokens > 0 + + def wait_for_all_put_tasks(self) -> list[FlexKVResponse]: + """ + Blocking wait for all put tasks. + + Returns: + list[FlexKVResponse]: Responses of all put tasks. + """ + return self._blocking_waiting_for_tasks(self.put_tasks) + + ####################### + #### Common Method #### + ####################### + + def cancel_tasks(self) -> None: + """ + Cancel tasks in self.cancel_tasks. 
+ Call before launch_tasks() to delete req_id in self.req_id_to_task_dict + """ + if len(self.tasks_to_cancel) == 0: + return + for task in self.tasks_to_cancel.values(): + del self.req_id_to_task_dict[task.request.req_id] + flexkv_logger.info(f"FlexKV Cancel task: {task}") + self.flexkv_manager.cancel(task_ids=list(self.tasks_to_cancel.keys())) + self.tasks_to_cancel.clear() + + def launch_tasks(self) -> None: + """ + Launch tasks in self.unlaunched_tasks + """ + if len(self.tasks_to_launch) == 0: + return + task_launch_time = time.perf_counter() + task_ids: list[int] = [] + slot_mappings: list[np.ndarray] = [] + + for task_id, task in self.tasks_to_launch.items(): + flexkv_logger.info(f"FlexKV Launch task: {task}") + task.task_launch_time = task_launch_time + task_ids.append(task_id) + slot_mappings.append(task.slot_mapping) + if isinstance(task, FlexKVGetTask): + self.get_tasks[task_id] = task + else: + self.put_tasks[task_id] = task + self.flexkv_manager.launch(task_ids=task_ids, + slot_mappings=slot_mappings) + self.tasks_to_launch.clear() + + def query_finished_task(self) -> tuple[set[str], set[str]]: + """ + Get response of finished task. + + Returns: + list[FlexKVResponse]: Responses of finished tasks. 
+ """ + if len(self.req_id_to_task_dict) == 0: + return set(), set() + task_ids = list(self.get_tasks.keys()) + list(self.put_tasks.keys()) + responses_from_manager = self.flexkv_manager.try_wait(task_ids) + task_finished_time = time.perf_counter() + finished_sending = set() + finished_recving = set() + num_failed_tasks = 0 + for task_id, response in responses_from_manager.items(): + success = (response.status == KVResponseStatus.SUCCESS) + if task_id in self.get_tasks: + task = self.get_tasks.pop(task_id) + finished_recving.add(task.request.req_id) + else: + task = self.put_tasks.pop(task_id) + finished_sending.add(task.request.req_id) + del self.req_id_to_task_dict[task.request.req_id] + task.task_finished_time = task_finished_time + if success: + flexkv_logger.info(f"{task} finished successfully.") + else: + flexkv_logger.error(f"{task} failed, status: {response.status}.") + num_failed_tasks += 1 + flexkv_logger.debug(f"unfinished task: {self.req_id_to_task_dict}") + self.flexkv_stats.record_faild(num_failed_requests=num_failed_tasks) + return finished_sending, finished_recving + + def _blocking_waiting_for_tasks(self, task_dict: dict[int, FlexKVTask]) -> list[FlexKVResponse]: + """ + Blocking wait for tasks in task_dict. + + Returns: + list[FlexKVResponse]: Responses of all tasks in task_dict. 
+ """ + if len(task_dict) == 0: + return [] + + task_ids = list(task_dict.keys()) + response_from_manager = self.flexkv_manager.wait(task_ids=task_ids) + task_finished_time = time.perf_counter() + responses_to_return: list[FlexKVResponse] = [] + for task_id, response in response_from_manager.items(): + success = (response.status == KVResponseStatus.SUCCESS) + task = task_dict.pop(task_id) + task.task_finished_time = task_finished_time + if success: + flexkv_logger.info(f"{task} finished successfully.") + else: + flexkv_logger.error(f"{task} failed, status: {response.status}.") + responses_to_return.append(FlexKVResponse(task_id=task_id, task_type=task.task_type, + request=task.request, success=success)) + return responses_to_return + + def build_connector_meta(self, scheduler_output: SchedulerOutput): + self.cancel_tasks() + self.launch_tasks() + finished_sending, finished_recving = self.query_finished_task() + metadata = FlexKVConnectorMetadata( + finished_sending=list(finished_sending), + finished_recving=list(finished_recving)) + return metadata + + @property + def dp_client_id(self) -> int: + return self.flexkv_manager.dp_client_id + +class FlexKVWorkerConnector(KvCacheConnectorWorker): + def __init__(self, config: ExecutorConfig): + tp_size, dp_size, dp_rank = get_dp_tp_info(config) + flexkv_config = FlexKVConfig.from_env() + flexkv_config.post_init_from_trt_config(config, tp_size, dp_size, dp_rank) + dp_client_id = dp_rank + + current_device_id = torch.cuda.current_device() + dp_client_id * flexkv_config.model_config.tp_size + self.flexkv_config = flexkv_config + flexkv_logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.gpu_register_port}, \ + dp_client_id: {dp_client_id}") + self.tp_client = KVTPClient(flexkv_config.gpu_register_port, dp_client_id, current_device_id) + flexkv_logger.info("Finish init FlexKVWorkerConnector") + + def register_kv_caches(self, kv_cache_tensor: torch.Tensor): + # vllm kv_caches: dict[str, torch.Tensor] + # trt 
kv_caches: torch.Tensor + + # shepe = ITensor::makeShape({mNumPrimaryBlocks, pool.numLayers, mKVFactor, blockSize}); + # 1. mNumPrimaryBlocks{blocksInPrimaryPool} blocksInPrimaryPool = tc::ceilDiv(maxTokens, tokensPerBlock); + # 2. layer_num + # 3. mKVFactor{mCacheType == CacheType::kSELFKONLY ? 1 : 2} + # 4. blockSize((numKvHeads * sizePerHead * tokensPerBlock) / quantSize) + + flexkv_logger.info(f"Start register kv_caches, shape: {kv_cache_tensor.shape}") + + # Get actual device from tensor (more reliable in MPI environment) + logical_device_id = kv_cache_tensor.device.index + flexkv_logger.debug(f"Tensor is on device: {kv_cache_tensor.device}, logical device.index={logical_device_id}") + flexkv_logger.debug(f"self.tp_client.device_id (from init): {self.tp_client.device_id}") + + # Get physical GPU ID (in case CUDA_VISIBLE_DEVICES is set) + import os + cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', None) + if cuda_visible_devices: + # Map logical ID to physical ID + visible_gpus = [int(x) for x in cuda_visible_devices.split(',')] + physical_device_id = visible_gpus[logical_device_id] if logical_device_id < len(visible_gpus) else logical_device_id + flexkv_logger.debug(f"CUDA_VISIBLE_DEVICES={cuda_visible_devices}, mapping logical {logical_device_id} -> physical {physical_device_id}") + else: + physical_device_id = logical_device_id + flexkv_logger.debug(f"No CUDA_VISIBLE_DEVICES set, using logical device ID {logical_device_id}") + + # Use physical device ID for registration + correct_device_id = physical_device_id + + if self.flexkv_config.model_config.use_mla: + assert kv_cache_tensor.ndim == 4, (f"expect kv cached tensor has 4 dim but get shape={kv_cache_tensor.shape}") + + num_blocks = kv_cache_tensor.shape[0] + num_layers = kv_cache_tensor.shape[1] + kv_dim = kv_cache_tensor.shape[2] + block_size = self.flexkv_config.cache_config.tokens_per_block + num_kv_heads = 1 if self.flexkv_config.model_config.use_mla else 
self.flexkv_config.model_config.num_kv_heads + head_size = self.flexkv_config.model_config.head_size + if self.flexkv_config.model_config.use_mla: + assert kv_dim == 1, (f"expect kv_dim eqals to 1 when using MLA but get kv_dim={kv_dim}") + + assert num_kv_heads * head_size * block_size == kv_cache_tensor.shape[3], \ + (f"expect kv cached tensor last dim equals to num_kv_heads*head_size*block_size, " \ + f"but get last_dim = {kv_cache_tensor.shape[3]}, " \ + f"num_kv_heads = {num_kv_heads}, head_size = {head_size}, block_size = {block_size}") + + gpu_blocks = [kv_cache_tensor] # convert to list for flexkv register + + gpu_layout = KVCacheLayout( + type=KVCacheLayoutType.BLOCKFIRST, + num_layer=num_layers, + num_block=num_blocks, + tokens_per_block=block_size, + num_head=num_kv_heads, + head_size=head_size, + is_mla=self.flexkv_config.model_config.use_mla, + ) + # Use correct device_id from tensor's actual device + self.tp_client.register_to_server(gpu_blocks, gpu_layout, override_device_id=correct_device_id) + flexkv_logger.info(f"Finish register kv_caches on device {correct_device_id}") + + def start_load_kv(self, stream: torch.cuda.Stream): + return + + def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream): + return + + def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream): + return + + def wait_for_save(self, stream: torch.cuda.Stream): + return + + def get_finished( + self, finished_gen_req_ids: List[int], + started_loading_req_ids: List[int]) -> Tuple[List[int], List[int]]: + + finished_sending = self._metadata.finished_sending + finished_recving = self._metadata.finished_recving + return finished_sending, finished_recving \ No newline at end of file diff --git a/flexkv/integration/tensorrt_llm/utils.py b/flexkv/integration/tensorrt_llm/utils.py new file mode 100644 index 0000000000..f5619455b9 --- /dev/null +++ b/flexkv/integration/tensorrt_llm/utils.py @@ -0,0 +1,69 @@ + +from dataclasses import dataclass +from flexkv.common 
import request +from flexkv.common.debug import flexkv_logger +from tensorrt_llm.bindings.internal.batch_manager import LlmRequest +from tensorrt_llm.bindings.executor import ExecutorConfig + + +logger = flexkv_logger + +@dataclass +class RequestWrapper: + _request: LlmRequest + + @property + def req_id(self): + return self._request.request_id + + @property + def all_token_ids(self): + all_token_ids = self._request.get_tokens() + assert len(all_token_ids) == 1, "Don't support beam search." + return all_token_ids[0] + + @property + def num_tokens(self): + return len(self.all_token_ids) + + @property + def num_prompt_tokens(self): + return self._request.prompt_len + + @property + def num_new_matched_tokens(self): + return self._request.num_connector_matched_tokens + # return self._request.local_prepopulated_prompt_len + + def is_finished(self): + return self._request.is_finished + + def is_finished_normal(self): + # NORMAL = 0 + # ABNORMAL = 3 + + # if self._request.is_finished_normal(): + # return NORMAL + # else: + # return ABNORMAL + return self._request.is_finished_normal + + def __repr__(self): + return f"RequestWrapper(req_id={self.req_id}, " \ + f"num_prompt_tokens={self.num_prompt_tokens}, " \ + f"num_tokens={len(self.all_token_ids)}, " \ + f"num_new_matched_tokens={self.num_new_matched_tokens})" + +def get_dp_tp_info(config: ExecutorConfig): + mapping = config.mapping + + if mapping.enable_attention_dp: + # TRT-LLM also does not support enabling TP and DP at the same time + tp_size = 1 + dp_size = mapping.tp_size + dp_rank = mapping.rank + else: + tp_size = mapping.tp_size + dp_size = 1 + dp_rank = 0 + return tp_size, dp_size, dp_rank \ No newline at end of file diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index 06aa79e618..8cc0acc85a 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -69,7 +69,6 @@ def __init__(self, else: self.server_handle = None self.kv_task_engine = KVTaskEngine(model_config, cache_config, self.gpu_register_port) - @property def dpclient_id(self) -> int: return
self.dp_client_id diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index 62b1746fba..1b91b5f4d8 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -20,6 +20,7 @@ from flexkv.transfer_manager import TransferManagerHandle, TransferManagerOnRemote from flexkv.common.request import KVResponseStatus, KVResponse from flexkv.transfer_manager import get_master_host_and_ports_from_env +from flexkv.common.debug import flexkv_logger class TaskStatus(Enum): # slot mapping is not ready @@ -107,13 +108,32 @@ def __init__(self, model_config_for_transfer = copy.deepcopy(self.model_config) if self.is_multinode_tp and not self.model_config.use_mla: model_config_for_transfer.num_kv_heads = self.tp_size_per_node + + if os.getenv("FLEXKV_WITH_TRTLLM", "0") != "1": + self.transfer_handles = [TransferManagerHandle( + model_config_for_transfer, + self.cache_config, + mode="process", + gpu_register_port=gpu_register_port + )] + else: + # When using FlexKV with TensorRT-LLM, we use remote mode to transfer data + # so that the way we launch subprocesses in FlexKV does not + # conflict with TensorRT-LLM's MPI initialization + self.remote_process = TransferManagerOnRemote.create_process() + master_host, master_ports = get_master_host_and_ports_from_env() + self.transfer_handles = [ + TransferManagerHandle( + model_config_for_transfer, + self.cache_config, + mode="remote", + gpu_register_port=gpu_register_port, + master_host=master_host, + master_ports=master_ports + ) + ] + self.transfer_handles[0]._handle.send_config_to_remotes() - self.transfer_handles = [TransferManagerHandle( - model_config_for_transfer, - self.cache_config, - mode="process", - gpu_register_port=gpu_register_port - )] if self.is_multinode_tp: master_host, master_ports = get_master_host_and_ports_from_env() self.transfer_handles.append(TransferManagerHandle( @@ -152,6 +172,10 @@ def shutdown(self) -> None: if hasattr(self, "transfer_handles") and self.transfer_handles is not None: for transfer_handle in self.transfer_handles:
transfer_handle.shutdown() + if hasattr(self, "remote_process") and self.remote_process is not None: + self.remote_process.join() + self.remote_process.close() + self.remote_process = None def create_get_task(self, task_id: int, diff --git a/flexkv/server/client.py b/flexkv/server/client.py index 8c91c12a74..d74bc3568c 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -230,18 +230,22 @@ def register_to_server( self, kv_caches: List[torch.Tensor], kv_layout: KVCacheLayout, + override_device_id: Optional[int] = None, ) -> None: if not kv_caches or not kv_caches[0].is_cuda: raise ValueError("GPU blocks must be CUDA tensors") + # Use override_device_id if provided, otherwise use self.device_id + device_id = override_device_id if override_device_id is not None else self.device_id + handles = [] for _, tensor in enumerate(kv_caches): - handle = TensorSharedHandle(tensor, self.device_id) + handle = TensorSharedHandle(tensor, device_id) handles.append(handle) register_req = RegisterTPClientRequest( self.dp_client_id, - self.device_id, + device_id, handles, kv_layout ) diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index 6afe811fbc..726a7a4d93 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -120,7 +120,8 @@ def create_worker(cls, finished_ops_queue: MPQueue, op_buffer_tensor: torch.Tensor, *args: Any, **kwargs: Any) -> 'WorkerHandle': - """Generic worker creation template method""" + """Generic worker creation template method.""" + parent_conn, child_conn = mp_ctx.Pipe() # create pipe ready_event = mp_ctx.Event() worker_id = cls._get_worker_id() @@ -138,6 +139,8 @@ def create_worker(cls, @classmethod def _worker_process(cls, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, op_buffer_tensor: torch.Tensor, ready_event: Any, *args: Any, **kwargs: Any) -> None: + # Note: MPI initialization prevention is handled by create_safe_process + # Environment variables are set before this function is 
called worker = cls(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor, *args, **kwargs) ready_event.set() worker.run() @@ -446,7 +449,6 @@ def __init__(self, cudaHostRegister(cpu_blocks) self.num_layers = gpu_kv_layouts[0].num_layer - # here the chunk size doesn't include the layer info self.gpu_chunk_sizes_in_bytes = [gpu_kv_layout.get_chunk_size() * self.dtype.itemsize \ for gpu_kv_layout in gpu_kv_layouts] diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py index 72c1a5ec38..2129362f7f 100644 --- a/flexkv/transfer_manager.py +++ b/flexkv/transfer_manager.py @@ -11,6 +11,10 @@ import tempfile import threading import numpy as np +import textwrap +import subprocess +import pickle +import sys from flexkv.common.transfer import TransferOpGraph from flexkv.common.config import CacheConfig, ModelConfig @@ -22,6 +26,7 @@ from flexkv.transfer.transfer_engine import TransferEngine from flexkv.server.utils import get_zmq_socket from flexkv.server.request import RegisterTPClientRequest, Response +from flexkv.common.debug import flexkv_logger class TransferManager: @@ -64,9 +69,10 @@ def _register_gpu_blocks_via_socket(self) -> None: flexkv_logger.info(f"GPU tensor registration server started on port {self.gpu_register_port}") expected_gpus = self.model_config.tp_size * self.model_config.dp_size - + flexkv_logger.info(f"{self.model_config.tp_size=}, {self.model_config.dp_size=}, {expected_gpus=}") while len(self.all_gpu_blocks) < expected_gpus: try: + # Recv from: flexkv.server.client.KVTPClient.register_to_server req = self.recv_from_client.recv_pyobj(zmq.NOBLOCK) except zmq.Again: time.sleep(0.001) @@ -311,17 +317,110 @@ def shutdown(self) -> None: def __del__(self) -> None: if not self._shutdown_flag: self.shutdown() - + @classmethod def create_process(cls, **kwargs: Any) -> Process: - def _run(): - instance = cls(**kwargs) - instance.start() - if hasattr(instance, '_worker_thread') and instance._worker_thread is not None: - 
instance._worker_thread.join() # block until worker thread exits - process = Process(target=_run, daemon=False) - process.start() - return process + import tempfile + import os + + # Serialize the class and kwargs + cls_data = pickle.dumps(cls) + kwargs_data = pickle.dumps(kwargs) + + # Create temporary files for serialized data + with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.cls') as f: + f.write(cls_data) + cls_file = f.name + + with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.kwargs') as f: + f.write(kwargs_data) + kwargs_file = f.name + + # Prepare environment - remove MPI-related variables to avoid conflicts + env = os.environ.copy() + # CRITICAL: Remove CUDA_VISIBLE_DEVICES to allow access to all GPUs + # TransferManager needs to access all physical GPUs for IPC + if 'CUDA_VISIBLE_DEVICES' in env: + flexkv_logger.info(f"Removing CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']} for TransferManager subprocess") + env.pop('CUDA_VISIBLE_DEVICES', None) + + # Create the subprocess script + transfer_manager_script = textwrap.dedent(f''' + import os + import sys + import pickle + import tempfile + + # Immediately disable MPI to avoid conflicts + os.environ['MPI4PY_RC_INITIALIZE'] = 'false' + + try: + # Load the class and kwargs + with open("{cls_file}", "rb") as f: + cls = pickle.load(f) + + with open("{kwargs_file}", "rb") as f: + kwargs = pickle.load(f) + + # Create and start TransferManager instance + instance = cls(**kwargs) + instance.start() + + # Keep running until worker thread exits + if hasattr(instance, '_worker_thread') and instance._worker_thread is not None: + instance._worker_thread.join() + + except Exception as e: + print(f"Error in TransferManager subprocess: {{e}}", file=sys.stderr) + sys.exit(1) + finally: + # Clean up temporary files + try: + os.unlink("{cls_file}") + os.unlink("{kwargs_file}") + except Exception: + pass + ''').strip() + + # Start the subprocess + process = subprocess.Popen([ + 
sys.executable, '-c', transfer_manager_script + ], env=env, stdout=None, stderr=None, text=True) # None = inherit parent's stdout/stderr + flexkv_logger.info(f"TransferManager subprocess started, PID: {process.pid}") + + # Clean up temporary files after subprocess completes + def cleanup_files(): + # Wait for subprocess to complete before cleaning up files + process.wait() + try: + os.unlink(cls_file) + os.unlink(kwargs_file) + except Exception: + pass + + import threading + cleanup_thread = threading.Thread(target=cleanup_files, daemon=True) + cleanup_thread.start() + + # Return a wrapper that mimics multiprocessing.Process interface + class SubprocessWrapper: + def __init__(self, popen_process): + self._popen = popen_process + self.pid = popen_process.pid + + def join(self, timeout=None): + return self._popen.wait(timeout) + + def close(self): + # Close the subprocess pipes + if self._popen.stdout: + self._popen.stdout.close() + if self._popen.stderr: + self._popen.stderr.close() + if self._popen.stdin: + self._popen.stdin.close() + + return SubprocessWrapper(process) class TransferManagerHandleBase(ABC): @abstractmethod @@ -386,6 +485,7 @@ def __init__(self, self.result_parent_conn, self.result_child_conn = self.mp_ctx.Pipe() self.process: Optional[Process] = None + self.start_event = self.mp_ctx.Event() self.ready_event = self.mp_ctx.Event() self._completed_results: List[Tuple[int, int]] = [] @@ -393,7 +493,7 @@ def __init__(self, def _start_process(self) -> None: if self.process is not None and self.process.is_alive(): return - + self.process = self.mp_ctx.Process( target=self._process_worker, args=(self.model_config, @@ -401,7 +501,8 @@ def _start_process(self) -> None: self.command_child_conn, self.result_child_conn, self.gpu_register_port, - self.ready_event), + self.ready_event, + self.start_event), daemon=False ) self.process.start() @@ -412,8 +513,11 @@ def _process_worker(self, command_conn, result_conn, gpu_register_port: str, - ready_event) -> None: + 
ready_event, + start_event) -> None: try: + start_event.set() + os.environ['MPI4PY_RC_INITIALIZE'] = 'false' transfer_manager = TransferManager(model_config, cache_config, gpu_register_port) transfer_manager.initialize_transfer_engine() transfer_manager.start() @@ -443,7 +547,10 @@ def _process_worker(self, result_conn.close() def start(self) -> None: + os.environ['MPI4PY_RC_INITIALIZE'] = 'false' self._start_process() + self.start_event.wait() + os.environ['MPI4PY_RC_INITIALIZE'] = 'true' def is_ready(self) -> bool: return self.ready_event.is_set()