diff --git a/benchmarks/example_config.json b/benchmarks/example_config.json index 4a710f41ca..d4854557c3 100644 --- a/benchmarks/example_config.json +++ b/benchmarks/example_config.json @@ -14,7 +14,6 @@ "enable_remote": false, "tokens_per_block": 16, "use_gds": false, - "use_pinned_memory": true, "gpu_kv_layout_type": "LAYERWISE", "cpu_kv_layout_type": "BLOCKWISE", "ssd_kv_layout_type": "BLOCKWISE", diff --git a/docs/vllm_adapter/README_en.md b/docs/vllm_adapter/README_en.md index acc2f36de2..79a38f62ab 100644 --- a/docs/vllm_adapter/README_en.md +++ b/docs/vllm_adapter/README_en.md @@ -41,7 +41,6 @@ cat < ./flexkv_config.json "cache_config": { "enable_cpu": true, - "num_cpu_blocks": 10240, - "use_pinned_memory": true + "num_cpu_blocks": 10240 }, "num_log_interval_requests": 200 } diff --git a/docs/vllm_adapter/README_zh.md b/docs/vllm_adapter/README_zh.md index f13815db1e..81e291b5cc 100644 --- a/docs/vllm_adapter/README_zh.md +++ b/docs/vllm_adapter/README_zh.md @@ -40,7 +40,6 @@ cat < ./flexkv_config.json "cache_config": { "enable_cpu": true, - "num_cpu_blocks": 10240, - "use_pinned_memory": true + "num_cpu_blocks": 10240 }, "num_log_interval_requests": 200 } @@ -81,4 +80,4 @@ bash benchmarks/flexkv_benchmark/serving_vllm.sh # 启动性能测试 bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh ``` -在 vLLM 0.10.0 版本中应用patch `examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch`,测试方法同上。 \ No newline at end of file +在 vLLM 0.10.0 版本中应用patch `examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch`,测试方法同上。 diff --git a/examples/run_server.py b/examples/run_server.py index d5b6a182ec..48b24ecad1 100644 --- a/examples/run_server.py +++ b/examples/run_server.py @@ -12,16 +12,16 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - + # NAME - parser.add_argument("--enable-cpu", - action=argparse.BooleanOptionalAction, + parser.add_argument("--enable-cpu", + action=argparse.BooleanOptionalAction, default=True) - parser.add_argument("--enable-ssd", - action=argparse.BooleanOptionalAction, + parser.add_argument("--enable-ssd", + action=argparse.BooleanOptionalAction, default=False,) - parser.add_argument("--enable-remote", - action=argparse.BooleanOptionalAction, + parser.add_argument("--enable-remote", + action=argparse.BooleanOptionalAction, default=False,) parser.add_argument("--model-path", type=str, help="model path", default="") parser.add_argument("--tp-size", type=int, default=1) @@ -54,7 +54,7 @@ def parse_args() -> argparse.Namespace: if __name__ == "__main__": args = parse_args() hf_config = AutoConfig.from_pretrained(args.model_path) - + num_layers=hf_config.num_hidden_layers if hasattr(hf_config, 'num_key_value_heads'): num_kv_heads=hf_config.num_key_value_heads @@ -65,7 +65,7 @@ def parse_args() -> argparse.Namespace: head_size=(hf_config.head_dim if hasattr(hf_config, 'head_dim') else hf_config.hidden_size//hf_config.num_attention_heads) use_mla=hf_config.architectures[0].startswith("Deepseek") - + # TODO: different model config may have different attribute name model_config = ModelConfig( num_layers=num_layers, @@ -76,14 +76,13 @@ def parse_args() -> argparse.Namespace: dp_size=args.dp_size, dtype=hf_config.torch_dtype ) - + cache_config = CacheConfig( enable_cpu=args.enable_cpu, enable_ssd=args.enable_ssd, enable_remote=args.enable_remote, use_gds=False, enable_trace=False, - use_pinned_memory=False, ssd_cache_iouring_entries=512, tokens_per_block=args.block_size, num_cpu_blocks=args.num_cpu_blocks, @@ -93,6 +92,6 @@ def parse_args() -> argparse.Namespace: remote_cache_size_mode=args.remote_cache_size_mode,
remote_cache_path=args.remote_cache_path, ) - + kvserver = KVServer(model_config, cache_config, args.server_recv_port) - kvserver.run() \ No newline at end of file + kvserver.run() diff --git a/examples/scheduler_server_example.py b/examples/scheduler_server_example.py index 29826afc9a..059cc467aa 100644 --- a/examples/scheduler_server_example.py +++ b/examples/scheduler_server_example.py @@ -16,9 +16,9 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, model_config, gpu_kv_layout): """Run TP client process""" from flexkv.server.client import KVTPClient - + print(f"Starting TP client: dp_client_id={dp_client_id}, tp_rank={tp_rank}, device_id={device_id}") - + try: # Set CUDA device for this process if torch.cuda.is_available(): @@ -27,7 +27,7 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, mo torch.cuda.init() # Clear cache torch.cuda.empty_cache() - + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) # Create GPU blocks for this TP client @@ -51,7 +51,7 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, mo # Keep TP client running while True: time.sleep(1) - + except Exception as e: print(f"TP client {tp_rank} error: {e}") import traceback @@ -84,7 +84,6 @@ def main(): enable_ssd=False, enable_remote=False, use_gds=False, - use_pinned_memory=True, tokens_per_block=tokens_per_block, num_cpu_blocks=num_cpu_blocks, ) @@ -106,14 +105,14 @@ def main(): cache_config=cache_config, server_recv_port="ipc:///tmp/scheduler_server_example" # TPClient connects to this port ) - + # Start background server thread to handle TPClient registration scheduler_server.start_server_thread() - - print(f"SchedulerServer started!") + + print("SchedulerServer started!") print(f"TPClient can connect to: {scheduler_server.get_server_port()}") print("Starting TP client processes...") - + # Start TP client processes tp_client_processes = [] for tp_rank in range(tp_size): @@ -123,7 +122,7 @@ def main(): if device_id >= available_gpus: device_id = device_id % available_gpus print(f"Warning: Using GPU {device_id} for TP rank {tp_rank} (not enough GPUs)") - + tp_client_process = Process( target=run_tp_client_process, args=(0, tp_rank, device_id, scheduler_server.get_server_port(), model_config, gpu_kv_layout), @@ -134,32 +133,32 @@ def main(): print(f"Started TP client process for rank {tp_rank} on device {device_id}") print("Waiting for all TP clients to register...") - + time.sleep(5) - + # Now we can directly use scheduler_server without network communication # Example: Create some test data (following benchmark_kvmanager.py pattern) batch_size = 4 seq_len = 128 - + print("\n=== Generating test data ===") # Generate separate sequences for each request (correct approach) batch_token_ids = [] batch_slot_mappings = [] batch_token_masks = [] - + for i in range(batch_size): # Each sequence is independent (seq_len,) shape token_ids = torch.randint(0, 1000, (seq_len,)) slot_mapping = torch.arange(i * seq_len, (i + 1) * seq_len) token_mask = torch.ones(seq_len, dtype=torch.bool) - + batch_token_ids.append(token_ids) batch_slot_mappings.append(slot_mapping) batch_token_masks.append(token_mask) - + print(f"Generated {batch_size} sequences, each with {seq_len} tokens") - + print("\n=== Executing PUT Operations ===") # PUT operations - each sequence processed separately start_time = time.time() @@ -173,7 +172,7 @@ def main(): if task_id: put_task_ids.append(task_id) print(f"PUT task {task_id} created for sequence {i}") - + 
put_time = (time.time() - start_time) * 1000 print(f"Created {len(put_task_ids)} PUT tasks, time: {put_time:.2f}ms") time.sleep(2) @@ -190,10 +189,10 @@ def main(): if task_id: get_task_ids.append(task_id) print(f"GET task {task_id} created for sequence {i}") - + get_time = (time.time() - start_time) * 1000 print(f"Created {len(get_task_ids)} GET tasks, time: {get_time:.2f}ms") - + print("\n=== Waiting for All Tasks to Complete ===") # Wait for all tasks to complete - can wait for multiple tasks at once all_task_ids = put_task_ids + get_task_ids @@ -202,7 +201,7 @@ def main(): masks = scheduler_server.wait(all_task_ids) wait_time = (time.time() - start_time) * 1000 print(f"All {len(all_task_ids)} tasks completed, time: {wait_time:.2f}ms") - + # Analyze results if masks: total_tokens = 0 @@ -211,7 +210,7 @@ def main(): tokens = mask.sum().item() if hasattr(mask, 'sum') else len(mask) total_tokens += tokens print(f"Task {task_id}: {tokens} tokens processed") - + print("\n=== Trying Non-blocking Wait ===") # Create a few more tasks and try non-blocking wait extra_task_ids = [] @@ -223,7 +222,7 @@ def main(): ) if task_id: extra_task_ids.append(task_id) - + if extra_task_ids: # Immediately try to wait (might not be completed yet) masks = scheduler_server.try_wait(extra_task_ids) @@ -233,15 +232,15 @@ def main(): print(f"Tasks {extra_task_ids} not ready yet, will wait...") masks = scheduler_server.wait(extra_task_ids) print(f"Tasks {extra_task_ids} completed after wait") - + print("\n✅ All operations completed successfully!") - - + + # Clean up resources print("\n=== Shutting down SchedulerServer ===") scheduler_server.shutdown() print("SchedulerServer has been shut down") - + # Terminate TP client processes print("Terminating TP client processes...") for i, process in enumerate(tp_client_processes): @@ -253,4 +252,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch b/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch index fc0a558d03..812a1d6e2f 100644 --- a/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +++ b/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch @@ -1,24 +1,9 @@ -From a434b67b8097990f20d8c020a8c713b10dd3d5b0 Mon Sep 17 00:00:00 2001 -From: zuogan -Date: Wed, 3 Sep 2025 05:11:50 -0700 -Subject: [PATCH] add flexkv connector - ---- - .../prefix_caching_flexkv.py | 163 +++++++++++++++ - .../kv_transfer/kv_connector/factory.py | 5 + - .../kv_connector/v1/flexkv_connector.py | 191 ++++++++++++++++++ - vllm/v1/core/sched/scheduler.py | 13 +- - .../worker/kv_connector_model_runner_mixin.py | 6 +- - 5 files changed, 373 insertions(+), 5 deletions(-) - create mode 100644 examples/offline_inference/prefix_caching_flexkv.py - create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/flexkv_connector.py - diff --git a/examples/offline_inference/prefix_caching_flexkv.py b/examples/offline_inference/prefix_caching_flexkv.py new file mode 100644 -index 000000000..4cfe2ef7f +index 000000000..a57328ffd --- /dev/null +++ b/examples/offline_inference/prefix_caching_flexkv.py -@@ -0,0 +1,163 @@ +@@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +import time @@ -36,7 +21,6 @@ index 000000000..4cfe2ef7f + "cache_config": { + "enable_cpu": True, + "num_cpu_blocks": 10240, -+ "use_pinned_memory": True + }, + "num_log_interval_requests": 200 +} @@ -84,7 +68,7 @@ index 000000000..4cfe2ef7f + +def main(): + # Create an LLM without prefix 
caching as a baseline. -+ regular_llm = LLM(model=model_path, ++ regular_llm = LLM(model=model_path, + enable_prefix_caching=False, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tp_size @@ -114,7 +98,7 @@ index 000000000..4cfe2ef7f + # return + + # Create an LLM with prefix caching enabled. -+ prefix_cached_llm = LLM(model=model_path, ++ prefix_cached_llm = LLM(model=model_path, + enable_prefix_caching=True, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tp_size, @@ -124,7 +108,7 @@ index 000000000..4cfe2ef7f + # Warmup so that the shared prompt's KV cache is computed. + prefix_cached_llm.generate(generating_prompts[0], sampling_params) + -+ # wait for offload kv task finished. ++ # wait for offload kv task finished. + time.sleep(2) + + # Generate with prefix caching. @@ -149,7 +133,7 @@ index 000000000..4cfe2ef7f + ]) + print(f"Generated answers are the same: {generated_same}") + -+ # wait for offload kv task finished. ++ # wait for offload kv task finished. + time.sleep(2) + + # reset prefix cache to use flexkv @@ -249,9 +233,9 @@ index 000000000..bdfa9f321 + **kwargs: additional arguments for the load operation + + Note: -+ The number of elements in kv_caches and layer_names should be ++ The number of elements in kv_caches and layer_names should be + the same. -+ ++ + """ + self._flexkv_connector.start_load_kv(forward_context, **kwargs) + @@ -260,7 +244,7 @@ index 000000000..bdfa9f321 + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. -+ ++ + This interface will be useful for layer-by-layer pipelining. + + Args: @@ -271,13 +255,13 @@ index 000000000..bdfa9f321 + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ -+ Start saving the a layer of KV cache from vLLM's paged buffer ++ Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. -+ kv_layer (torch.Tensor): the paged KV buffer of the current ++ kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. @@ -310,7 +294,7 @@ index 000000000..bdfa9f321 + call to this method (this call or a prior one). + """ + return self._flexkv_connector.get_finished(finished_req_ids) -+ ++ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the @@ -332,14 +316,14 @@ index 000000000..bdfa9f321 + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. -+ ++ + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: -+ the number of tokens that can be loaded from the ++ the number of tokens that can be loaded from the + external KV cache beyond what is already computed. 
+ """ + return self._flexkv_connector.get_num_new_matched_tokens( @@ -398,30 +382,30 @@ index 981023409..a6c8fac38 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -118,6 +118,7 @@ class Scheduler(SchedulerInterface): - + # KV Connector: requests in process of async KV loading or recving self.finished_recving_kv_req_ids: set[str] = set() + self.sending_kv_reqs: dict[str, Request] = {} - + # Encoder-related. # Calculate encoder cache size if applicable @@ -1029,7 +1030,8 @@ class Scheduler(SchedulerInterface): - + if not delay_free_blocks: self._free_blocks(request) - + else: + self.sending_kv_reqs[request.request_id] = request return kv_xfer_params - + def _free_blocks(self, request: Request): @@ -1041,7 +1043,7 @@ class Scheduler(SchedulerInterface): return len(self.waiting) + len(self.running) - + def has_finished_requests(self) -> bool: - return len(self.finished_req_ids) > 0 + return len(self.finished_req_ids) > 0 or len(self.sending_kv_reqs) > 0 - + def reset_prefix_cache(self) -> bool: return self.kv_cache_manager.reset_prefix_cache() @@ -1082,6 +1084,8 @@ class Scheduler(SchedulerInterface): @@ -430,20 +414,20 @@ index 981023409..a6c8fac38 100644 self.kv_event_publisher.shutdown() + if self.connector and hasattr(self.connector, "shutdown"): + self.connector.shutdown() - + ######################################################################## # KV Connector Related Methods @@ -1149,6 +1153,10 @@ class Scheduler(SchedulerInterface): scheduler the request during the next step. """ - + + # avoid busy checking + if len(self.running) == 0: + time.sleep(0.01) + if self.connector is not None: self.connector.update_connector_output(kv_connector_output) - + @@ -1158,4 +1166,5 @@ class Scheduler(SchedulerInterface): self.finished_recving_kv_req_ids.add(req_id) for req_id in (kv_connector_output.finished_sending or ()): @@ -457,16 +441,13 @@ index a03ebe35d..8e4460957 100644 @@ -66,9 +66,9 @@ class KVConnectorModelRunnerMixin: scheduler_output, wait_for_save=False) as kv_connector_output: pass - + - if (not kv_connector_output.finished_sending - and not kv_connector_output.finished_recving): - return EMPTY_MODEL_RUNNER_OUTPUT + # if (not kv_connector_output.finished_sending + # and not kv_connector_output.finished_recving): + # return EMPTY_MODEL_RUNNER_OUTPUT - + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) output.kv_connector_output = kv_connector_output --- -2.34.1 - diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 7d280c7ead..e113aeb69b 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -301,7 +301,7 @@ def get(self, # ignore the last incomplete block aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block aligned_token_ids = token_ids[:aligned_length] - token_mask[aligned_length:] = False + token_mask = token_mask[:aligned_length] block_start_idx, block_end_idx = self._get_block_range(token_mask) assert block_end_idx == aligned_length // self.tokens_per_block @@ -652,7 +652,7 @@ def put(self, # ignore the last incomplete block aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block aligned_token_ids = token_ids[:aligned_length] - token_mask[aligned_length:] = False + token_mask = token_mask[:aligned_length] block_start_idx, block_end_idx = self._get_block_range(token_mask) # the mask should has a prefix of True @@ -1068,7 +1068,7 @@ def _get_block_range(self, token_mask: np.ndarray) -> Tuple[int, int]: mask_idx = 
np.where(token_mask)[0] if len(mask_idx) == 0: - return 0, 0 + return len(token_mask)//self.tokens_per_block, len(token_mask)//self.tokens_per_block start_idx = mask_idx[0].item() // self.tokens_per_block end_idx = mask_idx[-1].item() // self.tokens_per_block return start_idx, end_idx + 1 diff --git a/flexkv/common/config.py b/flexkv/common/config.py index fbcf465727..d20d7518dd 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -32,14 +32,13 @@ class CacheConfig: enable_ssd: bool = False enable_remote: bool = False use_gds: bool = False - use_pinned_memory: bool = False index_accel: bool = False # kv cache layout configs gpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - cpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - ssd_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - remote_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE + cpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE + ssd_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE + remote_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE # mempool capacity configs num_cpu_blocks: int = 1000000 @@ -72,6 +71,6 @@ class CacheConfig: trace_max_file_size_mb: int = 100 trace_max_files: int = 5 trace_flush_interval_ms: int = 1000 - + #evict ratio evict_ratio: float = 0.0 diff --git a/flexkv/common/debug.py b/flexkv/common/debug.py index 0f79cf869b..a522c5549a 100644 --- a/flexkv/common/debug.py +++ b/flexkv/common/debug.py @@ -16,14 +16,18 @@ def __init__(self, debug_level: str = "INFO"): self.enabled = False self.logger = logging.getLogger("FLEXKV") - formatter = logging.Formatter( - fmt=_FORMAT, - datefmt=_DATE_FORMAT, + has_console_handler = any( + isinstance(handler, logging.StreamHandler) + for handler in self.logger.handlers ) - - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setFormatter(formatter) - self.logger.addHandler(console_handler) + if not has_console_handler: + formatter = logging.Formatter( + fmt=_FORMAT, + datefmt=_DATE_FORMAT, + ) + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + self.logger.addHandler(console_handler) self.set_level(debug_level) diff --git a/flexkv/common/tracer.py b/flexkv/common/tracer.py index 92668ae3f8..dff6b1ff3a 100644 --- a/flexkv/common/tracer.py +++ b/flexkv/common/tracer.py @@ -121,7 +121,6 @@ def trace_config(self, model_config, cache_config, gpu_layout=None): "ssd_kv_layout_type": str(cache_config.ssd_kv_layout_type), "remote_kv_layout_type": str(cache_config.remote_kv_layout_type), "use_gds": cache_config.use_gds, - "use_pinned_memory": cache_config.use_pinned_memory, "remote_cache_size_mode": cache_config.remote_cache_size_mode, "num_cpu_blocks": cache_config.num_cpu_blocks, "num_ssd_blocks": cache_config.num_ssd_blocks, diff --git a/flexkv/integration/vllm/vllm_v1_adapter.py b/flexkv/integration/vllm/vllm_v1_adapter.py index 5c5f6ed27c..c129e44563 100644 --- a/flexkv/integration/vllm/vllm_v1_adapter.py +++ b/flexkv/integration/vllm/vllm_v1_adapter.py @@ -194,11 +194,11 @@ def get_num_new_matched_tokens( """ task_id, num_new_matched_tokens = self._get_match(request=request, num_computed_tokens=num_computed_tokens) - self.flexkv_stats.record_get(num_prompt_tokens=request.num_prompt_tokens, + self.flexkv_stats.record_get(num_prompt_tokens=request.num_tokens, num_gpu_matched_tokens=num_computed_tokens, num_flexkv_matched_tokens=num_new_matched_tokens) - if not 
self._need_to_get(num_prompt_tokens=request.num_prompt_tokens, + if not self._need_to_get(num_prompt_tokens=request.num_tokens, num_computed_tokens=num_computed_tokens, num_new_matched_tokens=num_new_matched_tokens): return 0, False @@ -222,10 +222,11 @@ def _get_match( the task_id and number of new matched tokens. """ match_start_time = time.perf_counter() - num_tokens_to_get = (cdiv(request.num_prompt_tokens+1, self.block_size)-1)*self.block_size - token_ids = request.prompt_token_ids[:num_tokens_to_get] + num_tokens_to_get = (request.num_tokens//self.block_size)*self.block_size + token_ids = request.all_token_ids[:num_tokens_to_get] - assert num_computed_tokens <= num_tokens_to_get + assert num_computed_tokens <= num_tokens_to_get, ( + f"{num_computed_tokens=} must be less than or equal to {num_tokens_to_get=}") assert num_computed_tokens % self.block_size == 0 if num_tokens_to_get == num_computed_tokens: diff --git a/flexkv/server/server.py b/flexkv/server/server.py index daf25ff7ca..1849c1e304 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -81,7 +81,7 @@ def register_dp_client( flexkv_logger.info(f"DP client {client_id} registered successfully") return client_id - + def delete_dp_client(self, client_id: int) -> None: if client_id not in self.client_dict: flexkv_logger.error(f"DP client: {client_id} dosen't exist. Delete failed.") @@ -105,7 +105,7 @@ def is_dp_client_ready(self, dp_client_id: int) -> bool: class KVServerHandle: def __init__(self, process: mp.Process): self.process = process - + def shutdown(self) -> None: self.process.join(timeout=5) if self.process.is_alive(): @@ -137,7 +137,7 @@ def __init__( self.req_counter = 0 self._is_ready = False self._running = False - + # Request handler dispatch table self.request_handlers = { StartRequest: self._handle_start_request, @@ -162,14 +162,14 @@ def start_server(self) -> None: self._is_ready = True @staticmethod - def _server_process(model_config: ModelConfig, + def _server_process(model_config: ModelConfig, cache_config: CacheConfig, gpu_register_port: str, server_recv_port: str) -> None: - + server = KVServer(model_config, cache_config, gpu_register_port, server_recv_port) server.run() - + @classmethod def create_server(cls, model_config: ModelConfig, @@ -178,18 +178,15 @@ def create_server(cls, server_recv_port: Optional[str] = None) -> 'KVServerHandle': #if server_recv_port is None: # server_recv_port = f"ipc:///tmp/flexkv_srv_{uuid.uuid4().hex[:8]}" #TODO unify this - + # Set spawn method for CUDA compatibility - try: + with contextlib.suppress(RuntimeError): mp.set_start_method("spawn") - except RuntimeError: - # If already set, just continue - pass process = mp.Process(target=cls._server_process, args=(model_config, cache_config, gpu_register_port, server_recv_port)) process.start() flexkv_logger.info(f"KVServer process started, PID: {process.pid}") - + return KVServerHandle(process) def run(self) -> None: @@ -216,13 +213,13 @@ def run(self) -> None: # Use dispatch table for request handling req_type = type(req) handler = self.request_handlers.get(req_type) - + if handler is None: raise TypeError(f"Unrecognized RequestType: {req_type}") - + # Call the corresponding handler method handler(req) - + # If the request is a shutdown request, exit the loop if req_type == ShutdownRequest: break @@ -246,7 +243,7 @@ def _verify_model_config( return True # Request Handler Methods - + def _handle_start_request(self, req: StartRequest) -> None: """Handle start request""" flexkv_logger.info(f"Received start request from DP 
client {req.dp_client_id}") @@ -317,7 +314,7 @@ def _handle_put_match_request(self, req: PutMatchRequest) -> None: def _handle_launch_task_request(self, req: LaunchTaskRequest) -> None: """Handle LaunchTask request""" self.kv_task_engine.launch_tasks(req.task_ids, req.slot_mappings) - + def _handle_cancel_task_request(self, req: CancelTaskRequest) -> None: """Handle CancelTask request""" self.kv_task_engine.cancel_tasks(req.task_ids) @@ -381,7 +378,6 @@ def __del__(self) -> None: enable_ssd=False, enable_remote=False, use_gds=False, - use_pinned_memory=True, tokens_per_block=tokens_per_block, num_cpu_blocks=num_cpu_blocks,) diff --git a/flexkv/storage/allocator.py b/flexkv/storage/allocator.py index 7cd38156e0..ed683e6505 100644 --- a/flexkv/storage/allocator.py +++ b/flexkv/storage/allocator.py @@ -95,7 +95,6 @@ def allocate(cls, layout: KVCacheLayout, dtype: torch.dtype, **kwargs: Any) -> StorageHandle: - pin_memory = kwargs.get("pin_memory", True) total_size = layout.get_total_elements() # although the kv layout may have multiple dimensions, we only have one-dim CPU tensor flexkv_logger.info(f"CPU allocate total_size: {2 * total_size/1024/1024/1024} GB") @@ -103,7 +102,7 @@ def allocate(cls, size=(total_size,), dtype=dtype, device="cpu", - pin_memory=pin_memory, + pin_memory=False, ) return StorageHandle( handle_type=AccessHandleType.TENSOR, diff --git a/flexkv/storage/storage_engine.py b/flexkv/storage/storage_engine.py index 0762b0062d..0d48fe6230 100644 --- a/flexkv/storage/storage_engine.py +++ b/flexkv/storage/storage_engine.py @@ -35,7 +35,6 @@ def __init__(self, device_type=DeviceType.CPU, layout=self._cpu_layout, dtype=self._model_config.dtype, - pin_memory=self._cache_config.use_pinned_memory, ) if self._cache_config.enable_ssd: if not self._cache_config.ssd_kv_layout_type == self._cpu_layout.type: diff --git a/tests/replay_from_tracer.py b/tests/replay_from_tracer.py index 3ddc0ce810..fad6a20ea9 100644 --- a/tests/replay_from_tracer.py +++ b/tests/replay_from_tracer.py @@ -113,7 +113,6 @@ def parse_config_event(self, event: Dict[str, Any]): ssd_kv_layout_type=self._parse_layout_type(cache_config_data['ssd_kv_layout_type']), remote_kv_layout_type=self._parse_layout_type(cache_config_data['remote_kv_layout_type']), use_gds=cache_config_data['use_gds'], - use_pinned_memory=False,#cache_config_data['use_pinned_memory'], # for local test remote_cache_size_mode=cache_config_data['remote_cache_size_mode'], num_cpu_blocks=cache_config_data['num_cpu_blocks'], num_ssd_blocks=cache_config_data['num_ssd_blocks'], diff --git a/tests/test_utils.py b/tests/test_utils.py index ba1392eabc..93541b612b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -38,7 +38,6 @@ 'remote_file_prefix': "remote_cache", 'use_gds': False, 'enable_trace': False, - 'use_pinned_memory': False, 'ssd_cache_dir': ["./ssd_cache", "./ssd_cache2/"], 'ssd_cache_iouring_entries': 32, 'remote_cache_path': ["remote_cache1", "remote_cache2"],
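The behavioral core of this series is the flexkv/cache/cache_engine.py change: get() and put() now truncate token_mask to the block-aligned length instead of zeroing its tail, and _get_block_range() maps an all-False mask to an empty range at the aligned end rather than (0, 0), which keeps the `block_end_idx == aligned_length // self.tokens_per_block` assertion in get() satisfied. Below is a minimal standalone sketch of that convention, assuming only numpy; get_block_range here is a local re-implementation mirroring the patched code, not an import from flexkv.

import numpy as np

TOKENS_PER_BLOCK = 16  # same value as "tokens_per_block" in benchmarks/example_config.json

def get_block_range(token_mask: np.ndarray) -> tuple:
    # Mirrors the patched _get_block_range: an all-False mask returns an
    # empty range at the end of the aligned region instead of (0, 0).
    mask_idx = np.where(token_mask)[0]
    if len(mask_idx) == 0:
        end = len(token_mask) // TOKENS_PER_BLOCK
        return end, end
    start_idx = mask_idx[0].item() // TOKENS_PER_BLOCK
    end_idx = mask_idx[-1].item() // TOKENS_PER_BLOCK
    return start_idx, end_idx + 1

# get()/put() now drop the trailing incomplete block by truncating the mask,
# so the mask length always equals aligned_length afterwards.
token_ids = np.arange(40)                          # two full blocks + 8 leftover tokens
token_mask = np.ones(len(token_ids), dtype=bool)
aligned_length = (len(token_ids) // TOKENS_PER_BLOCK) * TOKENS_PER_BLOCK   # 32
token_mask = token_mask[:aligned_length]
print(get_block_range(token_mask))                 # (0, 2)
print(get_block_range(np.zeros(32, dtype=bool)))   # (2, 2): empty range at the end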