From c2a26fabbd904358fb1418081127061fb2aec461 Mon Sep 17 00:00:00 2001 From: linhu-nv Date: Mon, 1 Sep 2025 10:39:40 +0800 Subject: [PATCH 01/42] add vLLM 0.10.0 support; update README --- README.md | 6 +- README_zh.md | 6 +- .../vllm_adaption/flexkv_vllm_0_10_0.patch | 1224 +++++++++++++++++ 3 files changed, 1232 insertions(+), 4 deletions(-) create mode 100644 examples/vllm_adaption/flexkv_vllm_0_10_0.patch diff --git a/README.md b/README.md index c4388efdf0..56875811e3 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ bash benchmarks/flexkv_benchmark/serving_vllm.sh # Start benchmark bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh ``` +Apply the patch `examples/vllm_adaption/flexkv_vllm_0_10_0.patch` to vLLM 0.10.0, then use the same testing method as above. > **Note**: The current script is only compatible with the `main` branch. Support for the latest features in the `dev` branch is under development. @@ -84,8 +85,9 @@ FlexKV performs: - *put* requests can be called asynchronously; the time to copy data from GPU to CPU memory can overlap with subsequent computation. Data transfers between CPU memory, SSD, and scalable storage are fully handled asynchronously by the TransferEngine and transparent to the main process. ## Branch -- main is the stable branch, maintaining commits that have been tested. -- dev is the development branch, maintaining newer features. +- The main branch is the stable branch, which maintains commits that have already been tested. Pull from the main branch if you need stable code. +- The dev branch is the development branch, which contains newer features. Branch from and merge into dev if you need the latest features or are developing new functionality. +- The bugfix branch is for urgent fixes: bugs that need immediate resolution and documentation that requires prompt updates. If you need to fix a bug or update documentation urgently, branch from and merge into the bugfix branch. 
## Roadmap diff --git a/README_zh.md b/README_zh.md index 5ec5c476fb..8223a5d9c0 100644 --- a/README_zh.md +++ b/README_zh.md @@ -28,6 +28,7 @@ bash benchmarks/flexkv_benchmark/serving_vllm.sh # 启动性能测试 bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh ``` +在 vLLM 0.10.0 版本中应用patch `examples/vllm_adaption/flexkv_vllm_0_10_0.patch`,测试方法同上。 > **注意**:当前脚本仅适配 `main` 分支。`dev` 分支的最新特性支持脚本正在开发中。 @@ -84,8 +85,9 @@ FlexKV 在处理 *get* 请求时: - *put*请求可以异步调用,从GPU copy到内存的时间可以与之后的计算重合。内存与SSD以及扩展存储间的传输则完全由TransferEngine之后执行,主进程不感知。 ## Branch -- main 为稳定分支,维护已经测试过的commit。 -- dev 为开发分支,维护较新特性。 +- main 为稳定分支,维护已经测试过的commit。需要稳定的代码请从此分支拉取。 +- dev 为开发分支,维护较新特性。需要新特性和开发新特性请从此分支拉取和合入。 +- bugfix 为bug分支,维护需要立即解决的bug或需要立即更新的文档。需要解决bug和立即更新的文档请从此分支拉取和合入。 ## Roadmap diff --git a/examples/vllm_adaption/flexkv_vllm_0_10_0.patch b/examples/vllm_adaption/flexkv_vllm_0_10_0.patch new file mode 100644 index 0000000000..f6349a0ac7 --- /dev/null +++ b/examples/vllm_adaption/flexkv_vllm_0_10_0.patch @@ -0,0 +1,1224 @@ +diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py +index c7229dbb8..d2325fd3a 100644 +--- a/benchmarks/backend_request_func.py ++++ b/benchmarks/backend_request_func.py +@@ -9,6 +9,7 @@ import time + import traceback + from dataclasses import dataclass, field + from typing import Optional, Union ++import asyncio + + import aiohttp + import huggingface_hub.constants +@@ -23,10 +24,10 @@ AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + + @dataclass + class RequestFuncInput: +- prompt: str ++ prompt: Union[str, list[str]] + api_url: str +- prompt_len: int +- output_len: int ++ prompt_len: Union[int, list[int]] ++ output_len: Union[int, list[int]] + model: str + model_name: Optional[str] = None + logprobs: Optional[int] = None +@@ -555,6 +556,107 @@ async def async_request_openai_audio( + pbar.update(1) + return output + ++async def async_request_openai_chat_completions_multiturns( ++ request_func_input: RequestFuncInput, ++ pbar: Optional[tqdm] = None, ++ turn_interval_time: float = 3.0, ++) -> RequestFuncOutput: ++ api_url = request_func_input.api_url ++ assert api_url.endswith( ++ ("chat/completions", "profile") ++ ), "OpenAI Chat Completions API URL must end with 'chat/completions'." 
++ assert isinstance(request_func_input.prompt, list) ++ assert isinstance(request_func_input.prompt_len, list) ++ assert isinstance(request_func_input.output_len, list) ++ ++ async with aiohttp.ClientSession(trust_env=True, ++ timeout=AIOHTTP_TIMEOUT) as session: ++ payload = { ++ "model": request_func_input.model_name \ ++ if request_func_input.model_name else request_func_input.model, ++ "messages": [ ++ ], ++ "temperature": 0.0, ++ "stream": True, ++ "stream_options": { ++ "include_usage": True, ++ }, ++ } ++ payload["ignore_eos"] = request_func_input.ignore_eos ++ if request_func_input.extra_body: ++ payload.update(request_func_input.extra_body) ++ headers = { ++ "Content-Type": "application/json", ++ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", ++ } ++ ++ output_list = [] ++ for turn_id, prompt in enumerate(request_func_input.prompt): ++ output = RequestFuncOutput() ++ output.prompt_len = request_func_input.prompt_len[turn_id] ++ ++ payload["messages"].append({"role": "user", "content": prompt}) ++ payload["max_tokens"] = request_func_input.output_len[turn_id] ++ ++ generated_text = "" ++ ttft = 0.0 ++ st = time.perf_counter() ++ most_recent_timestamp = st ++ try: ++ async with session.post(url=api_url, json=payload, ++ headers=headers) as response: ++ if response.status == 200: ++ async for chunk_bytes in response.content: ++ chunk_bytes = chunk_bytes.strip() ++ if not chunk_bytes: ++ continue ++ ++ chunk = chunk_bytes.decode("utf-8").removeprefix( ++ "data: ") ++ if chunk != "[DONE]": ++ timestamp = time.perf_counter() ++ data = json.loads(chunk) ++ ++ if choices := data.get("choices"): ++ content = choices[0]["delta"].get("content") ++ # First token ++ if ttft == 0.0: ++ ttft = timestamp - st ++ output.ttft = ttft ++ ++ # Decoding phase ++ else: ++ output.itl.append(timestamp - ++ most_recent_timestamp) ++ ++ generated_text += content or "" ++ elif usage := data.get("usage"): ++ output.output_tokens = usage.get( ++ "completion_tokens") ++ ++ most_recent_timestamp = timestamp ++ ++ output.generated_text = generated_text ++ output.success = True ++ output.latency = most_recent_timestamp - st ++ else: ++ output.error = response.reason or "" ++ output.success = False ++ break ++ except Exception: ++ output.success = False ++ exc_info = sys.exc_info() ++ output.error = "".join(traceback.format_exception(*exc_info)) ++ break ++ payload["messages"].append({"role": "assistant", "content": generated_text}) ++ ++ output_list.append(output) ++ if turn_id != len(request_func_input.prompt) - 1: ++ await asyncio.sleep(turn_interval_time) ++ ++ if pbar: ++ pbar.update(1) ++ return output_list + + def get_model(pretrained_model_name_or_path: str) -> str: + if os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true": +@@ -619,6 +721,7 @@ ASYNC_REQUEST_FUNCS = { + "scalellm": async_request_openai_completions, + "sglang": async_request_openai_completions, + "llama.cpp": async_request_openai_completions, ++ "openai-chat-multiturns": async_request_openai_chat_completions_multiturns, + } + + OPENAI_COMPATIBLE_BACKENDS = [ +diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py +index 1ad6cef7a..9178528d0 100644 +--- a/benchmarks/benchmark_dataset.py ++++ b/benchmarks/benchmark_dataset.py +@@ -49,9 +49,9 @@ class SampleRequest: + Represents a single inference request for benchmarking. 
+ """ + +- prompt: Union[str, Any] +- prompt_len: int +- expected_output_len: int ++ prompt: Union[str, list[str], Any] ++ prompt_len: Union[int, list[int]] ++ expected_output_len: Union[int, list[int]] + multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None + lora_request: Optional[LoRARequest] = None + +@@ -617,6 +617,108 @@ class SonnetDataset(BenchmarkDataset): + ) + return samples + ++ ++# ----------------------------------------------------------------------------- ++# ShareGPT Multiturn Dataset Implementation ++# ----------------------------------------------------------------------------- ++ ++ ++class ShareGPTMultiTurnsDataset(BenchmarkDataset): ++ def __init__(self, min_num_turns: int = 2, **kwargs) -> None: ++ super().__init__(**kwargs) ++ self.load_data(min_num_turns) ++ ++ def load_data(self, min_num_turns: int) -> None: ++ if self.dataset_path is None: ++ raise ValueError("dataset_path must be provided for loading data.") ++ ++ with open(self.dataset_path, encoding="utf-8") as f: ++ self.data = json.load(f) ++ # Filter entries with at least two conversation turns. ++ new_data = [] ++ for entry in self.data: ++ if "conversations" in entry: ++ while len(entry["conversations"]) > 0 and entry["conversations"][0]['from'] != 'human': ++ entry["conversations"].pop(0) ++ if len(entry["conversations"]) % 2 != 0: ++ entry["conversations"].pop(-1) ++ if len(entry["conversations"]) >= 2 * min_num_turns: ++ new_data.append(entry) ++ self.data = new_data ++ random.seed(self.random_seed) ++ random.shuffle(self.data) ++ ++ def sample( ++ self, ++ tokenizer: PreTrainedTokenizerBase, ++ num_requests: int, ++ lora_path: Optional[str] = None, ++ max_loras: Optional[int] = None, ++ output_len: Optional[int] = None, ++ **kwargs, ++ ) -> list: ++ samples: list = [] ++ for entry in self.data: ++ if len(samples) >= num_requests: ++ break ++ ++ prompt_list = [d["value"] for d in entry["conversations"][::2]] ++ completion_list = [d["value"] for d in entry["conversations"][1::2]] ++ # prompt, completion = ( ++ # entry["conversations"][0]["value"], ++ # entry["conversations"][1]["value"], ++ # ) ++ ++ lora_request, tokenizer = self.get_random_lora_request( ++ tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) ++ ++ ++ prompt_ids_list = [] ++ completion_ids_list = [] ++ prompt_len_list = [] ++ new_output_len_list = [] ++ history_len = 0 ++ for turn_id in range(len(prompt_list)): ++ try: ++ prompt_ids = tokenizer(prompt_list[turn_id]).input_ids ++ completion_ids = tokenizer(completion_list[turn_id]).input_ids ++ except: ++ print(entry) ++ raise ++ prompt_len = len(prompt_ids) + history_len ++ new_output_len = len(completion_ids) if output_len is None else output_len ++ if not is_valid_sequence( ++ prompt_len, ++ new_output_len, ++ min_len=4, ++ max_prompt_len=4096, ++ max_total_len=8192, ++ skip_min_output_len_check=output_len ++ is not None): ++ turn_id -= 1 ++ break ++ prompt_ids_list.append(prompt_ids) ++ completion_ids_list.append(completion_ids) ++ prompt_len_list.append(prompt_len) ++ new_output_len_list.append(new_output_len) ++ history_len += prompt_len ++ history_len += new_output_len ++ ++ if turn_id <= 0: ++ continue ++ ++ prompt_list = prompt_list[:turn_id+1] ++ ++ samples.append( ++ SampleRequest( ++ prompt=prompt_list, ++ prompt_len=prompt_len_list, ++ expected_output_len=new_output_len_list, ++ lora_request=lora_request, ++ )) ++ self.maybe_oversample_requests(samples, num_requests) ++ return samples ++ + + # 
----------------------------------------------------------------------------- + # BurstGPT Dataset Implementation +diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py +index c597fb106..74e157927 100644 +--- a/benchmarks/benchmark_serving.py ++++ b/benchmarks/benchmark_serving.py +@@ -71,6 +71,7 @@ from benchmark_dataset import ( + ShareGPTDataset, + SonnetDataset, + VisionArenaDataset, ++ ShareGPTMultiTurnsDataset, + ) + from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json + from vllm.benchmarks.serve import get_request +@@ -142,7 +143,7 @@ def calculate_metrics( + ).input_ids + ) + actual_output_lens.append(output_len) +- total_input += input_requests[i].prompt_len ++ total_input += outputs[i].prompt_len + tpot = 0 + if output_len > 1: + latency_minus_ttft = outputs[i].latency - outputs[i].ttft +@@ -278,6 +279,9 @@ async def benchmark( + ) + + test_output = await request_func(request_func_input=test_input) ++ if backend == "openai-chat-multiturns": ++ print("test_output ", test_output) ++ test_output = test_output[-1] + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " +@@ -394,6 +398,8 @@ async def benchmark( + task = limited_request_func(request_func_input=request_func_input, pbar=pbar) + tasks.append(asyncio.create_task(task)) + outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) ++ if backend == "openai-chat-multiturns": ++ outputs = [o for sub_o in outputs for o in sub_o] + + if profile: + print("Stopping profiler...") +@@ -748,6 +754,15 @@ def main(args: argparse.Namespace): + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + ), ++ "sharegpt_multiturns": ++ lambda: ShareGPTMultiTurnsDataset( ++ min_num_turns=4, ++ random_seed=args.seed, ++ dataset_path=args.dataset_path).sample( ++ tokenizer=tokenizer, ++ num_requests=args.num_prompts, ++ output_len=args.sharegpt_output_len, ++ ), + "burstgpt": lambda: BurstGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample(tokenizer=tokenizer, num_requests=args.num_prompts), +@@ -930,7 +945,7 @@ def create_argument_parser(): + "--dataset-name", + type=str, + default="sharegpt", +- choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], ++ choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "sharegpt_multiturns"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( +diff --git a/benchmarks/flexkv_benchmark/container b/benchmarks/flexkv_benchmark/container +new file mode 100644 +index 000000000..cfc3b5bac +--- /dev/null ++++ b/benchmarks/flexkv_benchmark/container +@@ -0,0 +1,12 @@ ++docker run --shm-size=50g --ipc=host --network host -it --gpus all -v /home/zichengm/FlexKV:/workspace -e GLOO_SOCKET_IFNAME=eno1 --entrypoint /bin/bash vllm/vllm-openai:v0.10.0 ++VLLM_USE_PRECOMPILED=1 pip install -e . 
++apt update && apt install liburing-dev ++> vllm.log 2>&1 & ++export GLOO_SOCKET_IFNAME=eno1 ++nohup bash run_flexkv_server.sh > kvserver.log 2>&1 & ++nohup bash serving_vllm.sh 2 > vllm.log 2>&1 & ++bash multiturn_benchmark.sh ++ ++gdb -q -ex "set pagination off" -ex "set confirm off" -ex "set env PYTHONFAULTHANDLER=1" -ex "handle SIGPIPE noprint nostop pass" -ex "handle SIGBUS stop print" -ex "run" --args python3 examples/run_server.py --model-path Qwen/Qwen3-8B --tp-size 1 --dp-size 1 --block-size 16 --num-cpu-blocks 8192 --server-recv-port ipc:///tmp/tmpe0x8_0gq ++ ++cpu block num = cpu memory size / layer_num / 2 / token_per_block / num_heads / head_size / sizeof(data_type) +\ No newline at end of file +diff --git a/benchmarks/flexkv_benchmark/lmcache_config.yaml b/benchmarks/flexkv_benchmark/lmcache_config.yaml +new file mode 100644 +index 000000000..8016df5b6 +--- /dev/null ++++ b/benchmarks/flexkv_benchmark/lmcache_config.yaml +@@ -0,0 +1,6 @@ ++# Basic configurations ++chunk_size: 16 ++ ++# CPU offloading configurations ++local_cpu: true ++max_local_cpu_size: 32 +\ No newline at end of file +diff --git a/benchmarks/flexkv_benchmark/multiturn_benchmark.sh b/benchmarks/flexkv_benchmark/multiturn_benchmark.sh +new file mode 100644 +index 000000000..4a15ca771 +--- /dev/null ++++ b/benchmarks/flexkv_benchmark/multiturn_benchmark.sh +@@ -0,0 +1,17 @@ ++current_time=$(date +"%Y-%m-%d-%H:%M:%S") ++for workers in 128; do ++ concurrency_multiplier=4 ++ if [ $workers -gt 128 ]; then ++ concurrency_multiplier=2 ++ fi ++ python3 ../benchmark_serving.py \ ++ --backend openai-chat-multiturns \ ++ --model Qwen/Qwen3-8B \ ++ --dataset-name sharegpt_multiturns \ ++ --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \ ++ --num-prompts $((workers*concurrency_multiplier)) \ ++ --max-concurrency $workers \ ++ --host 0.0.0.0 \ ++ --port 12599 \ ++ --endpoint /v1/chat/completions 2>&1 ++done +\ No newline at end of file +diff --git a/benchmarks/flexkv_benchmark/run_flexkv_server.sh b/benchmarks/flexkv_benchmark/run_flexkv_server.sh +new file mode 100644 +index 000000000..56b38fa01 +--- /dev/null ++++ b/benchmarks/flexkv_benchmark/run_flexkv_server.sh +@@ -0,0 +1,15 @@ ++MODEL_PATH=Qwen/Qwen3-8B ++ ++CMD="python3 examples/run_server.py \ ++ --model-path $MODEL_PATH \ ++ --tp-size 1 \ ++ --dp-size 1 \ ++ --block-size 16 \ ++ --num-cpu-blocks 7282 \ ++ --server-recv-port ipc:///tmp/tmpe0x8_0gq \ ++ " ++echo ++echo ++ ++echo $CMD ++eval $CMD +\ No newline at end of file +diff --git a/benchmarks/flexkv_benchmark/serving_vllm.sh b/benchmarks/flexkv_benchmark/serving_vllm.sh +new file mode 100644 +index 000000000..18dee4732 +--- /dev/null ++++ b/benchmarks/flexkv_benchmark/serving_vllm.sh +@@ -0,0 +1,79 @@ ++#!/bin/bash ++# vLLM服务启动脚本 ++# 使用方法: ./serving_vllm.sh ++# type选项: ++# 0: 无前缀缓存 ++# 1: GPU前缀缓存 ++# 2: FlexKV ++# 3: LMCache (需要先配置LMCACHE_CONFIG_FILE环境变量) ++ ++MODEL_PATH=Qwen/Qwen3-8B ++ ++type=${1} ++ ++if [[ $type = 0 ]]; then ++ # no prefix cache ++ prefix_args="--no-enable-prefix-caching" ++ use_lmcache=false ++elif [[ $type = 1 ]]; then ++ # gpu prefix cache ++ prefix_args="" ++ use_lmcache=false ++elif [[ $type = 2 ]]; then ++ # flexkv ++ prefix_args="" ++ export ENABLE_FLEXKV="true" ++ export FLEXKV_SERVER_RECV_PORT="ipc:///tmp/tmpe0x8_0gq" ++ use_lmcache=false ++elif [[ $type = 3 ]]; then ++ # lmcache ++ prefix_args="" ++ use_lmcache=true ++ export LMCACHE_CONFIG_FILE="./lmcache_config.yaml" ++else ++ echo "ERROR: Unknown running type [$type]" ++ exit -1 ++fi ++ ++# nccl envs ++export 
GLOO_SOCKET_IFNAME=eno1 ++export NCCL_SOCKET_IFNAME=eno1 ++export NCCL_IB_GID_INDEX=3 ++export NCCL_IB_DISABLE=0 ++export NCCL_NET_GDR_LEVEL=2 ++export NCCL_IB_QPS_PER_CONNECTION=4 ++export NCCL_IB_TC=160 ++export NCCL_IB_TIMEOUT=22 ++export NCCL_PXN_DISABLE=0 ++ ++if [[ $use_lmcache = true ]]; then ++ # 使用vllm serve命令和LMCache ++ CMD="python3 -m vllm.entrypoints.openai.api_server --model $MODEL_PATH \ ++ --port=12599 \ ++ --tensor-parallel-size=1 \ ++ --data-parallel-size=1 \ ++ --pipeline-parallel-size=1 \ ++ --max-model-len=8192 \ ++ --max-num-seqs=256 \ ++ --gpu-memory-utilization 0.4 \ ++ --max-num-batched-tokens 8192 \ ++ --kv-transfer-config '{\"kv_connector\":\"LMCacheConnectorV1\",\"kv_role\":\"kv_both\"}' \ ++ $prefix_args" ++else ++ # 使用原有的api_server启动方式 ++ CMD="python3 -m vllm.entrypoints.openai.api_server --model $MODEL_PATH \ ++ --port=12599 \ ++ --tensor-parallel-size=1 \ ++ --data-parallel-size=1 \ ++ --pipeline-parallel-size=1 \ ++ --max-model-len=8192 \ ++ --max-num-seqs=256 \ ++ --gpu-memory-utilization 0.4 \ ++ --max-num-batched-tokens 8192 \ ++ $prefix_args" ++fi ++echo ++echo ++ ++echo $CMD ++eval $CMD +\ No newline at end of file +diff --git a/examples/offline_inference/prefix_caching_flexkv.py b/examples/offline_inference/prefix_caching_flexkv.py +new file mode 100644 +index 000000000..6ff17dfca +--- /dev/null ++++ b/examples/offline_inference/prefix_caching_flexkv.py +@@ -0,0 +1,123 @@ ++# SPDX-License-Identifier: Apache-2.0 ++import os ++ ++from vllm import LLM, SamplingParams ++from vllm.distributed import cleanup_dist_env_and_memory ++ ++# NOTE: This is just a running example. For benchmarking purpose, ++# please see benchmarks/benchmark_prefix_caching.py ++ ++os.environ["ENABLE_FLEXKV"] = "true" ++os.environ["FLEXKV_SERVER_RECV_PORT"] = "ipc:///tmp/tmpe0x8_0gq" ++ ++# Common prefix. ++prefix = ( ++ "You are an expert school principal, skilled in effectively managing " ++ "faculty and staff. Draft 10-15 questions for a potential first grade " ++ "Head Teacher for my K-12, all-girls', independent school that emphasizes " ++ "community, joyful discovery, and life-long learning. The candidate is " ++ "coming in for a first-round panel interview for a 8th grade Math " ++ "teaching role. They have 5 years of previous teaching experience " ++ "as an assistant teacher at a co-ed, public school with experience " ++ "in middle school math teaching. Based on these information, fulfill " ++ "the following paragraph: ") ++ ++# Sample prompts. ++prompts = [ ++ "Hello, my name is", ++ "The president of the United States is", ++ "The capital of France is", ++ "The future of AI is", ++] ++ ++generating_prompts = [prefix + prompt for prompt in prompts] ++ ++# Create a sampling params object. ++sampling_params = SamplingParams(temperature=0.0) ++ ++def main(): ++ # Create an LLM without prefix caching as a baseline. ++ regular_llm = LLM(model="facebook/opt-125m", ++ enable_prefix_caching=False, ++ gpu_memory_utilization=0.4) ++ ++ print("Results without `enable_prefix_caching`") ++ ++ # ruff: noqa: E501 ++ # Generate texts from the prompts. The output is a list of RequestOutput objects ++ # that contain the prompt, generated text, and other information. ++ outputs = regular_llm.generate(generating_prompts, sampling_params) ++ ++ regular_generated_texts = [] ++ # Print the outputs. 
++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ regular_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Destroy the LLM object and free up the GPU memory. ++ del regular_llm ++ cleanup_dist_env_and_memory() ++ ++ # Create an LLM with prefix caching enabled. ++ prefix_cached_llm = LLM(model="facebook/opt-125m", ++ enable_prefix_caching=True, ++ gpu_memory_utilization=0.4) ++ ++ # Warmup so that the shared prompt's KV cache is computed. ++ prefix_cached_llm.generate(generating_prompts[0], sampling_params) ++ ++ # Generate with prefix caching. ++ outputs = prefix_cached_llm.generate(generating_prompts, sampling_params) ++ ++ print("Results with `enable_prefix_caching`") ++ ++ cached_generated_texts = [] ++ # Print the outputs. You should see the same outputs as before. ++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ cached_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Compare the results and display the speedup ++ generated_same = all([ ++ regular_generated_texts[i] == cached_generated_texts[i] ++ for i in range(len(prompts)) ++ ]) ++ print(f"Generated answers are the same: {generated_same}") ++ ++ # reset prefix cache to use flexkv ++ prefix_cached_llm.reset_prefix_cache() ++ ++ # Generate with prefix caching. ++ outputs = prefix_cached_llm.generate(generating_prompts, sampling_params) ++ ++ print("Results with `flexkv`") ++ ++ flexkv_generated_texts = [] ++ # Print the outputs. You should see the same outputs as before. ++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ flexkv_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Compare the results and display the speedup ++ generated_same = all([ ++ regular_generated_texts[i] == flexkv_generated_texts[i] ++ for i in range(len(prompts)) ++ ]) ++ print(f"Generated answers are the same: {generated_same}") ++ ++ ++ ++if __name__ == "__main__": ++ main() +diff --git a/vllm/distributed/flexkv_extension/__init__.py b/vllm/distributed/flexkv_extension/__init__.py +new file mode 100644 +index 000000000..e69de29bb +diff --git a/vllm/distributed/flexkv_extension/client.py b/vllm/distributed/flexkv_extension/client.py +new file mode 100644 +index 000000000..478683fa9 +--- /dev/null ++++ b/vllm/distributed/flexkv_extension/client.py +@@ -0,0 +1,101 @@ ++import torch ++from typing import Optional ++ ++from flexkv.server.client import KVDPClient, KVTPClient ++from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType ++from flexkv.common.config import ModelConfig ++from vllm.distributed.flexkv_extension.config import FlexKVConfig ++from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec ++from vllm.logger import init_logger ++ ++logger = init_logger(__name__) ++ ++ ++class FlexKVDPClient: ++ def __init__( ++ self, ++ flexkv_config: FlexKVConfig ++ ): ++ self.flexkv_config = flexkv_config ++ self.server_recv_port = flexkv_config.server_recv_port ++ self.tp_size = flexkv_config.tp_size ++ self.model_config = ModelConfig( ++ num_layers=flexkv_config.num_layers, ++ num_kv_heads=flexkv_config.num_kv_heads, ++ head_size=flexkv_config.head_size, ++ use_mla=flexkv_config.use_mla, ++ 
dtype=flexkv_config.dtype, ++ tp_size=flexkv_config.tp_size, ++ ) ++ ++ logger.info(f"start init FlexKVDPClient to {self.server_recv_port}") ++ self.dp_client = KVDPClient(self.server_recv_port, self.model_config) ++ logger.info(f"finish init FlexKVDPClient") ++ ++ def put_async( ++ self, ++ token_ids: torch.Tensor, ++ slot_mapping: torch.Tensor, ++ token_mask: Optional[torch.Tensor] = None, ++ ) -> int: ++ " return task_id " ++ return self.dp_client.put_async(token_ids, slot_mapping, token_mask) ++ ++ def get_async( ++ self, ++ token_ids: torch.Tensor, ++ slot_mapping: torch.Tensor, ++ token_mask: Optional[torch.Tensor] = None, ++ ) -> int: ++ " return task_id " ++ return self.dp_client.get_async(token_ids, slot_mapping, token_mask) ++ ++ def wait( ++ self, ++ wait_task_ids: list[int], ++ ) -> dict[int, torch.Tensor]: ++ return self.dp_client.wait(wait_task_ids) ++ ++ def try_wait( ++ self, ++ wait_task_ids: list[int], ++ ) -> dict[int, Optional[torch.Tensor]]: ++ # print("--------------------------------") ++ # print(f"[FlexKVDPClient] About to call dp_client.try_wait with {wait_task_ids}") ++ try: ++ result = self.dp_client.try_wait(wait_task_ids) ++ # print(f"[FlexKVDPClient] dp_client.try_wait returned: {result}") ++ return result ++ except Exception as e: ++ # print(f"[FlexKVDPClient ERROR] Exception calling dp_client.try_wait: {e}") ++ import traceback ++ traceback.print_exc() ++ return {} ++ ++ ++class FlexKVTPClient: ++ def __init__( ++ self, ++ flexkv_config: FlexKVConfig, ++ dp_client_id: int, ++ tp_rank: int, ++ device_id: int, ++ gpu_blocks: list[torch.Tensor], ++ kv_shape: tuple[int], ++ ): ++ logger.info(f"start init FlexKVTPClient to {flexkv_config.server_recv_port}") ++ self.tp_client = KVTPClient(flexkv_config.server_recv_port, dp_client_id, device_id, tp_rank) ++ logger.info(f"finish init FlexKVTPClient") ++ gpu_layout = KVCacheLayout( ++ type=KVCacheLayoutType.LAYERWISE, ++ num_layer=flexkv_config.num_layers, ++ num_block=flexkv_config.num_blocks, ++ tokens_per_block=flexkv_config.block_size, ++ num_head=flexkv_config.num_kv_heads, ++ head_size=flexkv_config.head_size, ++ is_mla=flexkv_config.use_mla, ++ ) ++ logger.info(f"start register FlexKVTPClient") ++ self.tp_client.register_to_server(gpu_blocks, gpu_layout) ++ ++ logger.info(f"finish register FlexKVTPClient") +\ No newline at end of file +diff --git a/vllm/distributed/flexkv_extension/config.py b/vllm/distributed/flexkv_extension/config.py +new file mode 100644 +index 000000000..f2724e712 +--- /dev/null ++++ b/vllm/distributed/flexkv_extension/config.py +@@ -0,0 +1,45 @@ ++from dataclasses import dataclass ++import json ++import os ++import torch ++from vllm.v1.kv_cache_interface import KVCacheConfig, FullAttentionSpec ++from vllm.logger import init_logger ++ ++logger = init_logger(__name__) ++ ++ ++@dataclass ++class FlexKVConfig: ++ enable_flexkv: bool ++ server_recv_port: str ++ num_blocks: int = None ++ block_size: int = None ++ num_layers: int = None ++ num_kv_heads: int = None ++ head_size: int = None ++ dtype: torch.dtype = None ++ use_mla: bool = False ++ tp_size: int = 1 ++ ++ @classmethod ++ def from_env(cls) -> 'FlexKVConfig': ++ enable_flexkv = (os.getenv('ENABLE_FLEXKV', "false").lower() == "true") ++ server_recv_port = os.getenv('FLEXKV_SERVER_RECV_PORT', "") ++ ++ return cls(enable_flexkv=enable_flexkv, ++ server_recv_port=server_recv_port) ++ ++ def post_init( ++ self, ++ kv_cache_config: KVCacheConfig, ++ tp_size: int ++ ): ++ self.num_blocks = kv_cache_config.num_blocks ++ self.num_layers = 
len(kv_cache_config.kv_cache_groups) ++ kv_cache_spec: FullAttentionSpec = kv_cache_config.kv_cache_groups[0].kv_cache_spec ++ self.block_size = kv_cache_spec.block_size ++ self.num_kv_heads = kv_cache_spec.num_kv_heads ++ self.head_size = kv_cache_spec.head_size ++ self.dtype = kv_cache_spec.dtype ++ self.use_mla = kv_cache_spec.use_mla ++ self.tp_size = tp_size +\ No newline at end of file +diff --git a/vllm/logger.py b/vllm/logger.py +index 69aaf4390..fe426f420 100644 +--- a/vllm/logger.py ++++ b/vllm/logger.py +@@ -21,7 +21,7 @@ VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH + VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL + VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX + +-_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s " ++_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s.%(msecs)03d " + "[%(filename)s:%(lineno)d] %(message)s") + _DATE_FORMAT = "%m-%d %H:%M:%S" + +diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py +index 5b0218640..aa590eb6f 100644 +--- a/vllm/v1/core/kv_cache_utils.py ++++ b/vllm/v1/core/kv_cache_utils.py +@@ -87,8 +87,9 @@ class PrefixCachingMetrics: + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 ++ self.aggregated_query_flexkv_hit = 0 + # A deque of (requests, queries, hits) for the most recent requests. +- self.query_queue: deque[tuple[int, int, int]] = deque() ++ self.query_queue: deque[tuple[int, int, int, int]] = deque() + + def observe(self, stats: PrefixCacheStats): + """Observe the prefix caching for a set of requests. +@@ -108,14 +109,15 @@ class PrefixCachingMetrics: + self.reset() + + # Update the metrics. +- self.query_queue.append((stats.requests, stats.queries, stats.hits)) ++ self.query_queue.append((stats.requests, stats.queries, stats.hits, stats.flexkv_hits)) + self.aggregated_requests += stats.requests + self.aggregated_query_total += stats.queries + self.aggregated_query_hit += stats.hits ++ self.aggregated_query_flexkv_hit += stats.flexkv_hits + + # Remove the oldest stats if the number of requests exceeds. 
+ if self.aggregated_requests > self.max_recent_requests: +- old_requests, old_queries, old_hits = self.query_queue.popleft() ++ old_requests, old_queries, old_hits, _ = self.query_queue.popleft() + self.aggregated_requests -= old_requests + self.aggregated_query_total -= old_queries + self.aggregated_query_hit -= old_hits +@@ -125,6 +127,7 @@ class PrefixCachingMetrics: + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 ++ self.aggregated_query_flexkv_hit = 0 + self.query_queue.clear() + + @property +@@ -133,6 +136,13 @@ class PrefixCachingMetrics: + if self.aggregated_query_total == 0: + return 0.0 + return self.aggregated_query_hit / self.aggregated_query_total ++ ++ @property ++ def flexkv_hit_rate(self) -> float: ++ """Calculate the hit rate for the past N requests.""" ++ if self.aggregated_query_total == 0: ++ return 0.0 ++ return self.aggregated_query_flexkv_hit / self.aggregated_query_total + + + @dataclass +diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py +index 446f98034..b465c4cf1 100644 +--- a/vllm/v1/core/sched/scheduler.py ++++ b/vllm/v1/core/sched/scheduler.py +@@ -5,6 +5,7 @@ from __future__ import annotations + + import itertools + import time ++import torch + from collections import defaultdict + from collections.abc import Iterable + from typing import Any, Optional, Union +@@ -34,6 +35,9 @@ from vllm.v1.outputs import ModelRunnerOutput + from vllm.v1.request import Request, RequestStatus + from vllm.v1.spec_decode.metrics import SpecDecodingStats + from vllm.v1.structured_output import StructuredOutputManager ++# flexkv ++from vllm.utils import cdiv ++from vllm.distributed.flexkv_extension.config import FlexKVConfig + + logger = init_logger(__name__) + +@@ -162,6 +166,23 @@ class Scheduler(SchedulerInterface): + ) + self.use_pp = self.parallel_config.pipeline_parallel_size > 1 + ++ # flexkv ++ self.enable_flexkv = False ++ self.flexkv_client = None ++ # task_id -> Request ++ self.load_kv_tasks: dict[int, Request] = {} ++ # task_id -> Request ++ self.offload_kv_tasks: dict[int, Request] = {} ++ # request_id -> time info ++ self.flexkv_timer: dict[str, dict[str, float]] = {} ++ ++ ++ def init_flexkv(self, flexkv_config: FlexKVConfig) -> int: ++ self.enable_flexkv = True ++ from vllm.distributed.flexkv_extension.client import FlexKVDPClient ++ self.flexkv_client = FlexKVDPClient(flexkv_config) ++ return self.flexkv_client.dp_client.dp_client_id ++ + def schedule(self) -> SchedulerOutput: + # NOTE(woosuk) on the scheduling algorithm: + # There's no "decoding phase" nor "prefill phase" in the scheduler. +@@ -174,6 +195,13 @@ class Scheduler(SchedulerInterface): + # chunked prefills, prefix caching, speculative decoding, + # and the "jump decoding" optimization in the future. + ++ # flexkv ++ if self.enable_flexkv: ++ # aviod busy loop ++ if self.get_num_unfinished_requests() == 0: ++ time.sleep(0.01) ++ self.check_offload_kv_tasks() ++ + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] +@@ -448,6 +476,27 @@ class Scheduler(SchedulerInterface): + if new_blocks is None: + # The request cannot be scheduled. 
+ break ++ ++ if self.enable_flexkv and num_new_tokens > self.block_size and request.status == RequestStatus.WAITING: ++ # don't match the last block ++ num_new_blocks_to_get = cdiv(num_new_tokens, self.block_size)-1 ++ num_new_tokens_to_match = num_new_blocks_to_get*self.block_size ++ num_tokens_to_get = num_computed_tokens + num_new_tokens_to_match ++ blocks_ids_to_get = [block.block_id for block in new_blocks.blocks[0][:num_new_blocks_to_get]] ++ slot_mapping = torch.tensor(blocks_ids_to_get).repeat_interleave(self.block_size)*self.block_size ++ token_mask_to_get = torch.ones(num_tokens_to_get, dtype=torch.bool) ++ token_mask_to_get[:num_computed_tokens] = False ++ t_async_get_start = time.monotonic() ++ task_id = self.flexkv_client.get_async( ++ token_ids=torch.tensor(request.all_token_ids[:num_tokens_to_get]), ++ slot_mapping=slot_mapping, ++ token_mask=token_mask_to_get) ++ t_async_get_return = time.monotonic() ++ ++ self.load_kv_tasks[task_id] = request ++ self.flexkv_timer[request.request_id] = {} ++ self.flexkv_timer[request.request_id]['get_async_start'] = t_async_get_start ++ self.flexkv_timer[request.request_id]['get_async_return'] = t_async_get_return + + # KVTransfer: the connector uses this info to determine + # if a load is needed. Note that +@@ -505,6 +554,31 @@ class Scheduler(SchedulerInterface): + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget ++ # batch wait ++ ++ # batch wait ++ if self.enable_flexkv: ++ if len(self.load_kv_tasks) != 0: ++ task_ids = list(self.load_kv_tasks.keys()) ++ print(f"[DEBUG] scheduler wait for {task_ids}") ++ results = self.flexkv_client.wait(task_ids) ++ print(f"[DEBUG] scheduler wait result: {results}") ++ t_async_get_end = time.monotonic() ++ for task_id, task_result in results.items(): ++ request = self.load_kv_tasks.pop(task_id) ++ t_get_async_start = self.flexkv_timer[request.request_id]["get_async_start"] ++ t_get_async_return = self.flexkv_timer[request.request_id]["get_async_return"] ++ match_length = task_result.sum().item() ++ self.flexkv_timer.pop(request.request_id) ++ logger.info( ++ f"[FlexKV] req: {request.request_id}, task: {task_id}, " ++ f"get {match_length} tokens cost {(t_async_get_end-t_get_async_start)*1000:.2f} ms, " ++ f"get_async() api cost {(t_get_async_return-t_get_async_start)*1000:.2f} ms") ++ ++ token_budget += match_length ++ num_scheduled_tokens[request.request_id] -= match_length ++ request.num_computed_tokens += match_length ++ self.kv_cache_manager.prefix_cache_stats.flexkv_hits += (match_length//self.block_size) + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: +@@ -1016,11 +1090,49 @@ class Scheduler(SchedulerInterface): + if self.finished_req_ids_dict is not None: + self.finished_req_ids_dict[request.client_index].add(request_id) + +- if not delay_free_blocks: +- self._free_blocks(request) ++ # flexkv: offload BEFORE freeing blocks to preserve req_to_blocks info ++ if self.enable_flexkv: ++ self._offload_kv(request) ++ else: ++ if not delay_free_blocks: ++ self._free_blocks(request) ++ # else: ++ # self._free_block(request) ++ + + return kv_xfer_params + ++ def _free_block(self, request: Request) -> None: ++ self.kv_cache_manager.free(request) ++ self.kv_cache_manager.free_block_hashes(request) ++ del self.requests[request.request_id] ++ ++ def _offload_kv(self, request: Request): ++ # print(f"single_type_managers: {self.kv_cache_manager.coordinator.single_type_managers}") ++ # 
print(f"req_to_blocks: {self.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks}") ++ req_blocks = self.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks.get(request.request_id, []) ++ req_token_ids = torch.tensor(request.all_token_ids[:-1]) ++ req_block_ids = torch.tensor([block.block_id for block in req_blocks]) ++ ++ # Debug information for empty req_blocks ++ # if len(req_blocks) == 0: ++ # print(f"WARNING: Empty req_blocks for request {request.request_id}") ++ # print(f" request.all_token_ids length: {len(request.all_token_ids)}") ++ # print(f" req_token_ids length: {len(req_token_ids)}") ++ # print(f" req_to_blocks keys: {list(self.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks.keys())}") ++ ++ slot_mapping = req_block_ids.repeat_interleave(self.block_size)[:len(req_token_ids)] * self.block_size ++ ++ # Additional debug info ++ # print(f"FlexKV _offload_kv: req_id={request.request_id}, " ++ # f"blocks={len(req_blocks)}, tokens={len(req_token_ids)}, slots={len(slot_mapping)}") ++ ++ self.flexkv_timer[request.request_id] = {} ++ self.flexkv_timer[request.request_id]["put_async_start"] = time.monotonic() ++ task_id = self.flexkv_client.put_async(token_ids=req_token_ids, slot_mapping=slot_mapping) ++ self.offload_kv_tasks[task_id] = request ++ self.flexkv_timer[request.request_id]["put_async_return"] = time.monotonic() ++ + def _free_blocks(self, request: Request): + assert request.is_finished() + self.kv_cache_manager.free(request) +@@ -1068,7 +1180,27 @@ class Scheduler(SchedulerInterface): + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted_tokens) + return spec_decoding_stats +- ++ ++ def check_offload_kv_tasks(self): ++ if len(self.offload_kv_tasks) == 0: ++ return ++ logger.info(f"check_offload_kv_tasks") ++ task_ids = list(self.offload_kv_tasks.keys()) ++ results = self.flexkv_client.try_wait(task_ids) ++ # logger.info(f"results {results}") ++ t_async_put_end = time.monotonic() ++ for task_id, task_result in results.items(): ++ if task_result is not None: ++ request = self.offload_kv_tasks.pop(task_id) ++ t_put_async_start = self.flexkv_timer[request.request_id]["put_async_start"] ++ t_put_async_return = self.flexkv_timer[request.request_id]["put_async_return"] ++ self.flexkv_timer.pop(request.request_id) ++ logger.info( ++ f"[FlexKV] req: {request.request_id}, task: {task_id}, " ++ f"put {sum(task_result).item()} tokens cost {(t_async_put_end-t_put_async_start)*1000:.2f} ms, " ++ f"put_async() api cost {(t_put_async_return-t_put_async_start)*1000:.2f} ms") ++ self._free_block(request) ++ + def shutdown(self) -> None: + if self.kv_event_publisher: + self.kv_event_publisher.shutdown() +diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py +index 7779b559c..2d17908ea 100644 +--- a/vllm/v1/engine/core.py ++++ b/vllm/v1/engine/core.py +@@ -46,6 +46,8 @@ from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder + from vllm.v1.structured_output import StructuredOutputManager + from vllm.version import __version__ as VLLM_VERSION + ++from vllm.distributed.flexkv_extension.config import FlexKVConfig ++ + logger = init_logger(__name__) + + POLLING_TIMEOUT_S = 2.5 +@@ -118,6 +120,8 @@ class EngineCore: + log_stats=self.log_stats, + ) + ++ self.init_flexkv(vllm_config, kv_cache_config) ++ + # Setup MM Input Mapper. 
+ self.mm_input_cache_server = MirroredProcessingCache( + vllm_config.model_config) +@@ -194,6 +198,23 @@ class EngineCore: + "warmup model) took %.2f seconds"), elapsed) + return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config + ++ ++ def init_flexkv( ++ self, ++ taco_llm_config: VllmConfig, ++ kv_cache_config: KVCacheConfig ++ ): ++ self.scheduler: V1Scheduler ++ if taco_llm_config.cache_config.enable_prefix_caching: ++ flexkv_config = FlexKVConfig.from_env() ++ if flexkv_config.enable_flexkv: ++ flexkv_config.post_init( ++ kv_cache_config=kv_cache_config, ++ tp_size=taco_llm_config.parallel_config.tensor_parallel_size, ++ ) ++ dp_client_id = self.scheduler.init_flexkv(flexkv_config) ++ self.model_executor.init_flexkv(flexkv_config, dp_client_id) ++ + def add_request(self, request: EngineCoreRequest): + """Add request to the scheduler.""" + if pooling_params := request.pooling_params: +diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py +index 50b9634a4..3d7bdd4c8 100644 +--- a/vllm/v1/executor/abstract.py ++++ b/vllm/v1/executor/abstract.py +@@ -15,7 +15,7 @@ from vllm.executor.uniproc_executor import ( # noqa + UniProcExecutor as UniProcExecutorV0) + from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec + from vllm.v1.outputs import ModelRunnerOutput +- ++from vllm.distributed.flexkv_extension.config import FlexKVConfig + FailureCallback = Callable[[], None] + + +@@ -88,6 +88,10 @@ class Executor(ExecutorBase): + args=(scheduler_output, )) + return output[0] + ++ def init_flexkv(self, flexkv_config: FlexKVConfig, dp_client_id: int): ++ self.collective_rpc("init_flexkv", ++ args=(flexkv_config, dp_client_id, )) ++ + @property + def max_concurrent_batches(self) -> int: + return 1 +diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py +index 7f2556bab..e7fb79486 100644 +--- a/vllm/v1/metrics/loggers.py ++++ b/vllm/v1/metrics/loggers.py +@@ -125,7 +125,8 @@ class LoggingStatLogger(StatLoggerBase): + "Avg generation throughput: %.1f tokens/s, " + "Running: %d reqs, Waiting: %d reqs, " + "GPU KV cache usage: %.1f%%, " +- "Prefix cache hit rate: %.1f%%", ++ "Prefix cache hit rate: %.1f%%, " ++ "FlexKV hit rate: %.1f%%", + self.engine_index, + prompt_throughput, + generation_throughput, +@@ -133,6 +134,7 @@ class LoggingStatLogger(StatLoggerBase): + scheduler_stats.num_waiting_reqs, + scheduler_stats.kv_cache_usage * 100, + self.prefix_caching_metrics.hit_rate * 100, ++ self.prefix_caching_metrics.flexkv_hit_rate * 100, + ) + self.spec_decoding_logging.log(log_fn=log_fn) + +diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py +index 1eb10ccb6..1073aa571 100644 +--- a/vllm/v1/metrics/stats.py ++++ b/vllm/v1/metrics/stats.py +@@ -24,7 +24,8 @@ class PrefixCacheStats: + queries: int = 0 + # The number of hits in these requests. 
+ hits: int = 0 +- ++ # flexkv ++ flexkv_hits: int = 0 + + @dataclass + class SchedulerStats: +diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py +index a5bf197ba..d10265d0c 100644 +--- a/vllm/v1/worker/gpu_model_runner.py ++++ b/vllm/v1/worker/gpu_model_runner.py +@@ -2494,6 +2494,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): + ) == 0, "Attention backends are already initialized" + for i, kv_cache_group_spec in enumerate( + kv_cache_config.kv_cache_groups): ++ print("init attn backend ", i) + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + if isinstance(kv_cache_spec, AttentionSpec): + attn_backend_i = get_attn_backend( +diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py +index 522946351..31a3bed13 100644 +--- a/vllm/v1/worker/gpu_worker.py ++++ b/vllm/v1/worker/gpu_worker.py +@@ -18,7 +18,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, + set_custom_all_reduce) + from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, + has_kv_transfer_group) +-from vllm.distributed.parallel_state import get_pp_group, get_tp_group ++from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_tensor_model_parallel_rank + from vllm.logger import init_logger + from vllm.lora.request import LoRARequest + from vllm.model_executor import set_random_seed +@@ -33,6 +33,10 @@ from vllm.v1.utils import report_usage_stats + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + from vllm.v1.worker.worker_base import WorkerBase + ++# flexkv ++from vllm.distributed.flexkv_extension.config import FlexKVConfig ++from vllm.distributed.flexkv_extension.client import FlexKVTPClient ++ + logger = init_logger(__name__) + + if TYPE_CHECKING: +@@ -556,6 +560,23 @@ class Worker(WorkerBase): + max_size=max_size, + ) + ++ def init_flexkv( ++ self, ++ flexkv_config: FlexKVConfig, ++ dp_client_id: int, ++ ) -> None: ++ from vllm.distributed.flexkv_extension.client import FlexKVTPClient ++ layer_kv_shape = self.model_runner.attn_backends[0].get_kv_cache_shape( ++ flexkv_config.num_blocks, flexkv_config.block_size, ++ flexkv_config.num_kv_heads, flexkv_config.head_size) ++ kv_shape = (flexkv_config.num_layers, *layer_kv_shape) ++ self.flexkv_client = FlexKVTPClient(flexkv_config=flexkv_config, ++ dp_client_id=dp_client_id, ++ tp_rank=get_tensor_model_parallel_rank(), ++ device_id=self.device.index, ++ gpu_blocks=self.model_runner.kv_caches, ++ kv_shape=kv_shape) ++ + def save_tensorized_model( + self, + tensorizer_config: "TensorizerConfig", From b8a9ff0d47ba2880551c1c66c9bd492cd1c6ecb2 Mon Sep 17 00:00:00 2001 From: charliecgxu Date: Tue, 19 Aug 2025 16:14:41 +0800 Subject: [PATCH 02/42] radix tree c++ impl (#70) * radix tree implementation in c++ Signed-off-by: charliecgxu * support new radix-tree in cache engine Signed-off-by: charliecgxu --------- Signed-off-by: charliecgxu --- csrc/bindings.cpp | 33 +++ csrc/radix_tree.cpp | 268 ++++++++++++++++++++++++ csrc/radix_tree.h | 346 +++++++++++++++++++++++++++++++ flexkv/cache/cache_engine.py | 164 ++++++++++++++- flexkv/common/config.py | 1 + setup.py | 2 + tests/test_cache_engine_accel.py | 231 +++++++++++++++++++++ 7 files changed, 1039 insertions(+), 6 deletions(-) create mode 100644 csrc/radix_tree.cpp create mode 100644 csrc/radix_tree.h create mode 100644 tests/test_cache_engine_accel.py diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index a991954be0..2918cad482 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -19,6 +19,7 @@ #include 
"tp_transfer_thread_group.h" #include "transfer.cuh" #include "transfer_ssd.h" +#include "radix_tree.h" namespace py = pybind11; @@ -195,4 +196,36 @@ PYBIND11_MODULE(c_ext, m) { "Call Pcfs::write from C++", py::arg("file_nodeid"), py::arg("offset"), py::arg("buffer"), py::arg("size"), py::arg("thread_id")); #endif + + py::class_(m, "CRadixTreeIndex") + .def(py::init()) + .def("is_empty", &flexkv::CRadixTreeIndex::is_empty) + .def("reset", &flexkv::CRadixTreeIndex::reset) + .def("lock", &flexkv::CRadixTreeIndex::lock, py::arg("node")) + .def("unlock", &flexkv::CRadixTreeIndex::unlock, py::arg("node")) + .def("set_ready", &flexkv::CRadixTreeIndex::set_ready, + py::arg("node"), py::arg("ready"), py::arg("ready_length")) + .def("insert", &flexkv::CRadixTreeIndex::insert, py::return_value_policy::reference, + py::arg("physical_block_ids"), py::arg("block_hashes"), py::arg("num_blocks"), + py::arg("num_insert_blocks"), py::arg("ready") = true, py::arg("node") = nullptr, + py::arg("num_matched_blocks") = -1, py::arg("last_node_matched_length") = -1) + .def("evict", &flexkv::CRadixTreeIndex::evict, py::arg("evicted_blocks"), py::arg("num_evicted")) + .def("total_cached_blocks", &flexkv::CRadixTreeIndex::total_cached_blocks) + .def("total_unready_blocks", &flexkv::CRadixTreeIndex::total_unready_blocks) + .def("total_ready_blocks", &flexkv::CRadixTreeIndex::total_ready_blocks) + .def("match_prefix", &flexkv::CRadixTreeIndex::match_prefix, + py::arg("block_hashes"), py::arg("num_blocks"), py::arg("update_cache_info")); + + py::class_(m, "CRadixNode") + .def(py::init()) + .def("size", &flexkv::CRadixNode::size); + + py::class_>(m, "CMatchResult") + .def(py::init *>()) + .def_readonly("last_ready_node", &flexkv::CMatchResult::last_ready_node) + .def_readonly("last_node", &flexkv::CMatchResult::last_node) + .def_readonly("physical_blocks", &flexkv::CMatchResult::physical_blocks) + .def_readonly("num_ready_matched_blocks", &flexkv::CMatchResult::num_ready_matched_blocks) + .def_readonly("num_matched_blocks", &flexkv::CMatchResult::num_matched_blocks) + .def_readonly("last_node_matched_length", &flexkv::CMatchResult::last_node_matched_length); } diff --git a/csrc/radix_tree.cpp b/csrc/radix_tree.cpp new file mode 100644 index 0000000000..32dc8ab0e7 --- /dev/null +++ b/csrc/radix_tree.cpp @@ -0,0 +1,268 @@ +#include +#include +#include +#include +#include +#include + +#include "cache_utils.h" +#include "radix_tree.h" + +namespace flexkv { + +CRadixNode::CRadixNode(CRadixTreeIndex *index, bool ready, int lock_cnt) { + assert(index != nullptr); + + this->on_leaf = false; + this->parent = nullptr; + this->index = index; + this->ready = ready; + this->lock_cnt = lock_cnt; + + struct timeval now; + gettimeofday(&now, nullptr); + last_access_time = now.tv_sec * 1000 + now.tv_usec / 10000; + + index->inc_node_count(); +} + +CRadixNode::~CRadixNode() { + assert(parent == nullptr); + + block_hashes.clear(); + physical_blocks.clear(); + children.clear(); + + index->dec_node_count(); +} + +CRadixNode *CRadixNode::split(int prefix_length) { + assert(prefix_length < size()); + assert(prefix_length > 0); + assert(parent != nullptr); + + auto new_node = new CRadixNode(index, is_ready(), 0); + new_node->set_time(get_time()); + new_node->set_parent(parent); + get_index()->add_node(new_node); + + auto &new_block_hashes = new_node->get_block_hashes(); + auto &new_physical_blocks = new_node->get_physical_blocks(); + + new_block_hashes.insert(new_block_hashes.end(), block_hashes.cbegin(), block_hashes.cbegin() + 
prefix_length); + new_physical_blocks.insert(new_physical_blocks.end(), physical_blocks.cbegin(), physical_blocks.cbegin() + prefix_length); + + block_hashes.erase(block_hashes.begin(), block_hashes.begin() + prefix_length); + physical_blocks.erase(physical_blocks.begin(), physical_blocks.begin() + prefix_length); + + parent->set_child(new_node->get_head_hash(), new_node); + new_node->set_parent(parent); + new_node->set_child(get_head_hash(), this); + + set_parent(new_node); + return new_node; +} + +void CRadixNode::merge_child() { + auto child = children.begin()->second; + + assert(get_num_children() == 1); + assert(child->is_leaf()); + + block_hashes.insert(block_hashes.end(), child->get_block_hashes().cbegin(), + child->get_block_hashes().cend()); + physical_blocks.insert(physical_blocks.end(), child->get_physical_blocks().cbegin(), + child->get_physical_blocks().cend()); + + set_time(std::max(get_time(), child->get_time())); + children.clear(); + + child->clear_parent(); + index->remove_leaf(child); + index->remove_node(child); +} + +std::deque *CRadixNode::shrink(int length) { + assert(length < size()); + assert(length > 0); + assert(is_leaf()); + assert(in_use() == false); + + auto remaining_length = size() - length; + auto shrink_blocks = new std::deque(); + + shrink_blocks->insert(shrink_blocks->end(), physical_blocks.begin() + remaining_length, physical_blocks.end()); + + block_hashes.erase(block_hashes.begin() + remaining_length, block_hashes.end()); + physical_blocks.erase(physical_blocks.begin() + remaining_length, physical_blocks.end()); + + return shrink_blocks; +} + +CRadixNode *CRadixTreeIndex::insert(torch::Tensor &physical_block_ids, + torch::Tensor &block_hashes, int num_blocks, int num_insert_blocks, bool ready, + CRadixNode *last_node, int num_matched_blocks, int last_node_matched_length) { + if (num_insert_blocks == -1) { + num_insert_blocks = num_blocks; + } + assert(num_insert_blocks >= 0); + assert(num_insert_blocks <= num_blocks); + assert(physical_block_ids.ndim() == 1); + + if (last_node == nullptr) { + auto match_result = match_prefix(block_hashes, num_blocks, true); + num_matched_blocks = match_result->num_matched_blocks; + last_node_matched_length = match_result->last_node_matched_length; + last_node = match_result->last_node; + } + + assert(last_node != nullptr); + assert(last_node_matched_length != 0 || is_root(last_node)); + assert(physical_block_ids.size() == num_insert_blocks - num_matched_blocks); + + if (num_matched_blocks >= num_insert_blocks) { + return nullptr; + } + + auto new_node = new CRadixNode(this, ready, 0); + auto &new_block_hashes = new_node->get_block_hashes(); + auto &new_physical_blocks = new_node->get_physical_blocks(); + + auto block_hashes_ptr = block_hashes.data_ptr(); + auto physical_block_ids_ptr = physical_block_ids.data_ptr(); + for (auto i = 0; i + num_matched_blocks < num_insert_blocks; i++) { + new_block_hashes.insert(new_block_hashes.end(), block_hashes_ptr[i+num_matched_blocks]); + new_physical_blocks.insert(new_physical_blocks.end(), physical_block_ids_ptr[i]); + } + + if (last_node_matched_length < last_node->size()) { + last_node->split(last_node_matched_length); + last_node = last_node->get_parent(); + assert(last_node != nullptr); + } + + if (last_node->is_leaf()) { + remove_leaf(last_node); + } + + new_node->set_parent(last_node); + last_node->set_child(new_node->get_head_hash(), new_node); + + add_node(new_node); + add_leaf(new_node); + return new_node; +} + +int CRadixTreeIndex::evict(torch::Tensor &evicted_blocks, 
int num_evicted) { + int64_t *evicted_blocks_ptr = evicted_blocks.data_ptr(); + int has_evicted = 0; + std::priority_queue, CRadixNode::Compare> candidate; + + for (auto it = leaf_list.begin(); it != leaf_list.end(); it++) { + if ((*it)->evictable()) { + candidate.push(*it); + } + } + + while ((has_evicted < num_evicted) && candidate.size()) { + auto node = candidate.top(); + candidate.pop(); + + if (node->size() > num_evicted - has_evicted) { + auto blocks = node->shrink(num_evicted - has_evicted); + for (auto it = blocks->begin(); it != blocks->end(); it++) { + evicted_blocks_ptr[has_evicted] = *it; + has_evicted++; + } + delete blocks; + } else { + auto parent = node->get_parent(); + auto &blocks = node->get_physical_blocks(); + + assert(parent != nullptr); + parent->remove_child(node->get_head_hash()); + + for (auto it = blocks.begin(); it != blocks.end(); it++) { + evicted_blocks_ptr[has_evicted] = *it; + has_evicted++; + } + + if (parent->is_leaf() && !is_root(parent)) { + add_leaf(parent); + if (parent->evictable()) { + candidate.push(parent); + } + } + + node->clear_parent(); + remove_leaf(node); + remove_node(node); + } + } + return has_evicted; +} + +std::shared_ptr CRadixTreeIndex::match_prefix( + torch::Tensor &block_hashes, int num_blocks, bool update_cache_info) { + auto current_node = root; + auto last_ready_node = root; + auto prefix_blocks_num = 0; + auto ready_prefix_blocks_num = 0; + auto last_node_matched_length = 0; + auto physical_blocks = new std::vector(); + auto block_hashes_ptr = block_hashes.data_ptr(); + HashType child_hash; + + while (prefix_blocks_num < num_blocks) { + if (update_cache_info) { + current_node->update_time(); + } + + child_hash = HashType(block_hashes_ptr[prefix_blocks_num + current_node->size()]); + if (current_node->lookup_child(child_hash)) { + if (current_node->is_ready()) { + last_ready_node = current_node; + ready_prefix_blocks_num += current_node->size(); + } + prefix_blocks_num += current_node->size(); + physical_blocks->insert(physical_blocks->end(), current_node->get_physical_blocks().begin(), + current_node->get_physical_blocks().end()); + current_node = current_node->get_child(child_hash); + } else { + auto matched_length = 0; + if (is_root(current_node) == false) { + auto cmp_length = std::min(current_node->size(), num_blocks - prefix_blocks_num); + auto left = 0; + auto right = cmp_length; + + while (left < right) { + auto mid = (left + right) / 2; + if (current_node->get_hash(mid) == HashType(block_hashes_ptr[prefix_blocks_num+mid])) { + left = mid + 1; + } else { + right = mid; + } + } + matched_length = left; + physical_blocks->insert(physical_blocks->end(), current_node->get_physical_blocks().begin(), + current_node->get_physical_blocks().begin() + matched_length); + } else { + matched_length = 0; + } + + if (current_node->is_ready()) { + last_ready_node = current_node; + ready_prefix_blocks_num += matched_length; + } + + last_node_matched_length = matched_length; + prefix_blocks_num += matched_length; + break; + } + } + + return std::make_shared(prefix_blocks_num, ready_prefix_blocks_num, last_node_matched_length, + last_ready_node, current_node, physical_blocks); +} + +} // namespace flexkv diff --git a/csrc/radix_tree.h b/csrc/radix_tree.h new file mode 100644 index 0000000000..63560a3a8d --- /dev/null +++ b/csrc/radix_tree.h @@ -0,0 +1,346 @@ +#pragma once +#include +#include +#include +#include +#include + +#include "cache_utils.h" + +namespace flexkv { + +class CRadixTreeIndex; + +class CRadixNode { +private: + bool 
on_leaf; + bool ready; + int lock_cnt; + time_t last_access_time; + + std::deque block_hashes; + std::deque physical_blocks; + std::map children; + + CRadixTreeIndex *index; + CRadixNode *parent; + +public: + CRadixNode(CRadixTreeIndex *index, bool ready, int lock_cnt); + ~CRadixNode(); + + struct Compare { + bool operator() (CRadixNode *a, CRadixNode *b) { + return a->get_time() > b->get_time(); + } + }; + + bool get_leaf_state() { + return on_leaf; + } + + void set_leaf_state(bool on_leaf) { + this->on_leaf = on_leaf; + } + + CRadixTreeIndex *get_index() { + return index; + } + + void set_time(time_t time) { + last_access_time = time; + } + + time_t get_time() { + return last_access_time; + } + + void update_time() { + struct timeval now; + + gettimeofday(&now, nullptr); + last_access_time = now.tv_sec * 1000 + now.tv_usec / 10000; + } + + CRadixNode *get_parent() { + return parent; + } + + void set_parent(CRadixNode *parent) { + this->parent = parent; + } + + void clear_parent() { + this->parent = nullptr; + } + + HashType get_hash(int pos) { + return HashType(block_hashes[pos]); + } + + HashType get_head_hash() { + if (size() > 0) { + return HashType(block_hashes[0]); + } else { + return HashType(0); + } + } + + int size() { + return block_hashes.size(); + } + + int get_num_children() { + return children.size(); + } + + std::deque &get_block_hashes() { + return block_hashes; + } + + std::deque &get_physical_blocks() { + return physical_blocks; + } + + bool lookup_child(HashType hash) { + auto iter = children.find(hash); + if (iter != children.end()) + return true; + else + return false; + } + + CRadixNode *get_child(HashType hash) { + return children.at(hash); + } + + void set_child(HashType hash, CRadixNode *node) { + children[hash] = node; + } + + void remove_child(HashType hash) { + children.erase(hash); + } + + bool is_leaf() { + return get_num_children() == 0; + } + + bool in_use() { + return lock_cnt > 0 || !ready; + } + + bool evictable() { + return is_leaf() && !in_use(); + } + + void lock() { + assert(lock_cnt >= 0); + lock_cnt++; + } + + void unlock() { + assert(lock_cnt > 0); + lock_cnt--; + } + + void set_ready(bool ready) { + this->ready = ready; + } + + bool is_ready() { + return ready; + } + + CRadixNode *split(int prefix_length); + std::deque *shrink(int length); + void merge_child(); +}; + +class CMatchResult { +public: + int num_ready_matched_blocks; + int num_matched_blocks; + int last_node_matched_length; + + CRadixNode *last_ready_node; + CRadixNode *last_node; + std::vector *physical_blocks; + + CMatchResult(int _num_ready_matched_blocks, int _num_matched_blocks, int _last_node_matched_length, + CRadixNode *_last_ready_node, CRadixNode *_last_node, std::vector *blocks) + : num_ready_matched_blocks(_num_ready_matched_blocks), num_matched_blocks(_num_matched_blocks), + last_node_matched_length(_last_node_matched_length), last_ready_node(_last_ready_node), + last_node(_last_node), physical_blocks(blocks) { + } + + ~CMatchResult() { + delete physical_blocks; + }; +}; + +class CRadixTreeIndex { +private: + CRadixNode *root; + std::list node_list; + std::list leaf_list; + + int max_num_blocks; + int tokens_per_block; + int node_count; + +public: + CRadixTreeIndex(int tokens_per_block, int max_num_blocks = 1000000) { + this->tokens_per_block = tokens_per_block; + this->max_num_blocks = max_num_blocks; + this->node_count = 0; + + root = new CRadixNode(this, true, 0); + node_list.push_back(root); + } + + ~CRadixTreeIndex() { + leaf_list.clear(); + + while 
(node_list.size()) { + auto node = node_list.front(); + node->set_parent(nullptr); + node_list.pop_front(); + delete node; + } + + if (node_count) { + std::cerr << "CRadix Node count" << node_count << std::endl; + } + } + + void reset() { + leaf_list.clear(); + + while (node_list.size()) { + auto node = node_list.front(); + node->set_parent(nullptr); + node_list.pop_front(); + delete node; + } + + root = new CRadixNode(this, true, 0); + node_list.push_back(root); + } + + bool is_root(CRadixNode *node) { + return node == root; + } + + CRadixNode *get_root() { + return root; + } + + void remove_node(CRadixNode *node) { + assert(node != root); + assert(node->get_parent() == nullptr); + + node_list.remove(node); + delete node; + } + + void remove_leaf(CRadixNode *node) { + assert(node != root); + assert(node->get_leaf_state()); + + if (node->get_leaf_state() == false) { + return; + } + + leaf_list.remove(node); + node->set_leaf_state(false); + } + + void add_node(CRadixNode *node) { + assert(node != nullptr); + assert(node->get_parent() != nullptr); + node_list.push_back(node); + } + + void add_leaf(CRadixNode *node) { + assert(node != nullptr); + assert(node->get_leaf_state() == false); + + if (node->get_leaf_state() == true) { + return; + } + + leaf_list.push_back(node); + node->set_leaf_state(true); + } + + void lock(CRadixNode *node) { + node->lock(); + } + + void unlock(CRadixNode *node) { + node->unlock(); + } + + bool is_empty() { + return node_list.size() == 1; + } + + void inc_node_count() { + node_count++; + } + + void dec_node_count() { + node_count--; + } + + void set_ready(CRadixNode *node, bool ready = true, int ready_length = -1) { + node->set_ready(ready); + if (ready_length > 0) { + ready_length -= node->size(); + while (ready_length > 0) { + assert(node->get_parent() != nullptr); + node = node->get_parent(); + ready_length -= node->size(); + node->set_ready(true); + } + assert(ready_length == 0); + } + } + + int total_node_num() { + return node_list.size() - 1; + } + + int total_cached_blocks() { + auto total_blocks = 0; + + for (auto it = node_list.begin(); it != node_list.end(); it++) { + total_blocks += (*it)->size(); + } + return total_blocks; + } + + int total_ready_blocks() { + auto total_blocks = 0; + for (auto it = node_list.begin(); it != node_list.end(); it++) { + if ((*it)->is_ready()) { + total_blocks += (*it)->size(); + } + } + return total_blocks; + } + + int total_unready_blocks() { + return total_cached_blocks() - total_ready_blocks(); + } + + int evict(torch::Tensor &evicted_blocks, int num_evicted); + std::shared_ptr match_prefix(torch::Tensor &block_hashes, + int num_blocks, bool update_cache_info = true); + CRadixNode *insert(torch::Tensor &physical_block_ids, torch::Tensor &block_hashes, int num_blocks, + int num_insert_blocks, bool ready = true, CRadixNode *node = nullptr, int num_matched_blocks = -1, + int last_node_matched_length = -1); +}; + +} // namespace flexkv diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 669163cc83..3cf178a9d7 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -18,8 +18,10 @@ from functools import partial from queue import Queue from typing import List, Tuple, Optional, Dict, Callable +from dataclasses import dataclass import torch +from flexkv.c_ext import CRadixNode, CRadixTreeIndex, CMatchResult from flexkv.cache.mempool import Mempool from flexkv.cache.radixtree import RadixTreeIndex, RadixNode, MatchResult @@ -33,6 +35,108 @@ DeviceType, TransferOpGraph, 
TransferOp, TransferType, TransferDescriptor ) +@dataclass +class MatchResultAccel: + num_ready_matched_blocks: int = 0 + num_matched_blocks: int = 0 + last_ready_node: Optional['CRadixNode'] = None + last_node: Optional['CRadixNode'] = None + last_node_matched_length: int = 0 + physical_blocks: torch.Tensor = torch.empty(0, dtype=torch.int64) + + def __post_init__(self) -> None: + assert self.physical_blocks.ndim == 1 + assert self.physical_blocks.dtype == torch.int64 + +class CacheEngineAccel: + def __init__(self, + device_type: DeviceType, + num_total_blocks: int, + tokens_per_block: int): + if not isinstance(device_type, DeviceType): + raise InvalidConfigError(f"Unknown device type: {device_type}") + if num_total_blocks <= 0: + raise InvalidConfigError(f"Invalid num_total_blocks: {num_total_blocks}") + if tokens_per_block <= 0 or (tokens_per_block & (tokens_per_block - 1)) != 0: + raise InvalidConfigError(f"Invalid tokens_per_block: {tokens_per_block}, " + f"tokens_per_block must be a power of 2") + + self.device_type = device_type + + self.index = CRadixTreeIndex(tokens_per_block, num_total_blocks) + + self.mempool = Mempool(num_total_blocks=num_total_blocks) + + self.tokens_per_block = tokens_per_block + self.num_total_blocks = num_total_blocks + + def reset(self) -> None: + self.index.reset() + self.mempool.reset() + + def match(self, sequence_meta: SequenceMeta) -> MatchResultAccel: + sequence_meta.gen_hashes() + match_result = self.index.match_prefix(torch.from_numpy(sequence_meta.block_hashes).to(torch.int64), + sequence_meta.num_blocks, True) + return MatchResultAccel(match_result.num_ready_matched_blocks, match_result.num_matched_blocks, + match_result.last_ready_node, match_result.last_node, + match_result.last_node_matched_length, + torch.tensor(match_result.physical_blocks, dtype=torch.int64)) + + def insert(self, + sequence_meta: SequenceMeta, + physical_block_ids: torch.Tensor, + num_insert_blocks: int = -1, + is_ready: bool = True, + match_result: Optional[MatchResultAccel] = None) -> Optional[CRadixNode]: + sequence_meta.gen_hashes() + if match_result is None: + return self.index.insert(physical_block_ids, + torch.from_numpy(sequence_meta.block_hashes).to(torch.int64), + sequence_meta.num_blocks, + num_insert_blocks, + is_ready) + else: + return self.index.insert(physical_block_ids, + torch.from_numpy(sequence_meta.block_hashes).to(torch.int64), + sequence_meta.num_blocks, + num_insert_blocks, + is_ready, + match_result.last_node, + match_result.num_matched_blocks, + match_result.last_node_matched_length) + + def lock_node(self, node: CRadixNode) -> None: + self.index.lock(node) + + def cleanup(self, node: CRadixNode, cleanup_length: int) -> None: + self.index.unlock(node) + self.index.set_ready(node, True, cleanup_length) + + def take(self, + num_required_blocks: int, + protected_node: Optional[CRadixNode] = None, + strict: bool = True) -> torch.Tensor: + if num_required_blocks > self.mempool.num_free_blocks: + if protected_node is not None: + self.index.lock(protected_node) + target_blocks = torch.zeros(num_required_blocks - self.mempool.num_free_blocks, dtype=torch.int64) + num_evicted = self.index.evict(target_blocks, num_required_blocks - self.mempool.num_free_blocks) + if num_evicted != num_required_blocks - self.mempool.num_free_blocks: + target_blocks.resize_(num_evicted) + self.mempool.recycle_blocks(target_blocks) + + if protected_node is not None: + self.index.unlock(protected_node) + if strict and num_required_blocks > self.mempool.num_free_blocks: + raise 
NotEnoughSpaceError("Not enough free blocks to take, ", + required=num_required_blocks, + available=self.mempool.num_free_blocks) + num_allocated_blocks = min(num_required_blocks, self.mempool.num_free_blocks) + return self.mempool.allocate_blocks(num_allocated_blocks) + + def recycle(self, physical_blocks: torch.Tensor) -> None: + self.mempool.recycle_blocks(physical_blocks) class CacheEngine: def __init__(self, @@ -119,17 +223,32 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig): self.cache_engines = {} if cache_config.enable_cpu: - self.cpu_cache_engine = CacheEngine(DeviceType.CPU, + if cache_config.index_accel: + self.cpu_cache_engine = CacheEngineAccel(DeviceType.CPU, + cache_config.num_cpu_blocks, + cache_config.tokens_per_block) + else: + self.cpu_cache_engine = CacheEngine(DeviceType.CPU, cache_config.num_cpu_blocks, cache_config.tokens_per_block) self.cache_engines[DeviceType.CPU] = self.cpu_cache_engine if cache_config.enable_ssd: - self.ssd_cache_engine = CacheEngine(DeviceType.SSD, + if cache_config.index_accel: + self.ssd_cache_engine = CacheEngineAccel(DeviceType.SSD, + cache_config.num_ssd_blocks, + cache_config.tokens_per_block) + else: + self.ssd_cache_engine = CacheEngine(DeviceType.SSD, cache_config.num_ssd_blocks, cache_config.tokens_per_block) self.cache_engines[DeviceType.SSD] = self.ssd_cache_engine if cache_config.enable_remote: - self.remote_cache_engine = CacheEngine(DeviceType.REMOTE, + if cache_config.index_accel: + self.remote_cache_engine = CacheEngineAccel(DeviceType.REMOTE, + cache_config.num_remote_blocks, + cache_config.tokens_per_block) + else: + self.remote_cache_engine = CacheEngine(DeviceType.REMOTE, cache_config.num_remote_blocks, cache_config.tokens_per_block) self.cache_engines[DeviceType.REMOTE] = self.remote_cache_engine @@ -429,7 +548,10 @@ def _get_impl_local(self, assert self.cache_config.enable_cpu assert self.cpu_cache_engine is not None - cpu_matched_result, ssd_matched_result = self.match_local(sequence_meta) + if self.cache_config.index_accel: + cpu_matched_result, ssd_matched_result = self.match_local_accel(sequence_meta) + else: + cpu_matched_result, ssd_matched_result = self.match_local(sequence_meta) # tailor the blocks to assure: # the blocks are needed by the mask & the blocks are ready @@ -639,7 +761,10 @@ def _put_impl_global(self, assert self.cpu_cache_engine is not None assert self.remote_cache_engine is not None - cpu_matched_result, ssd_matched_result, remote_matched_result = self.match_all(sequence_meta) + if self.cache_config.index_accel: + cpu_matched_result, ssd_matched_result, remote_matched_result = self.match_all_accel(sequence_meta) + else: + cpu_matched_result, ssd_matched_result, remote_matched_result = self.match_all(sequence_meta) cpu_matched_blocks = cpu_matched_result.physical_blocks[ :cpu_matched_result.num_matched_blocks][block_mask_start:block_mask_end] ssd_matched_blocks = ssd_matched_result.physical_blocks[ @@ -805,7 +930,10 @@ def _put_impl_local(self, assert self.cpu_cache_engine is not None # assert self.ssd_cache_engine is not None - cpu_matched_result, ssd_matched_result = self.match_local(sequence_meta) + if self.cache_config.index_accel: + cpu_matched_result, ssd_matched_result = self.match_local_accel(sequence_meta) + else: + cpu_matched_result, ssd_matched_result = self.match_local(sequence_meta) cpu_matched_blocks = cpu_matched_result.physical_blocks[ :cpu_matched_result.num_matched_blocks][block_mask_start:block_mask_end] ssd_matched_blocks = 
ssd_matched_result.physical_blocks[ @@ -924,6 +1052,16 @@ def _transfer_callback(self, assert self.remote_cache_engine is not None self.remote_cache_engine.recycle(buffer_to_free[DeviceType.REMOTE]) + def match_local_accel(self, sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel]: + cpu_matched_result = MatchResultAccel() + ssd_matched_result = MatchResultAccel() + if self.cpu_cache_engine: + cpu_matched_result = self.cpu_cache_engine.match(sequence_meta) + if self.ssd_cache_engine: + ssd_matched_result = self.ssd_cache_engine.match(sequence_meta) + + return cpu_matched_result, ssd_matched_result + def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResult]: cpu_matched_result = MatchResult() ssd_matched_result = MatchResult() @@ -934,6 +1072,20 @@ def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchRe return cpu_matched_result, ssd_matched_result + def match_all_accel(self, sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel, MatchResultAccel]: + cpu_matched_result = MatchResultAccel() + ssd_matched_result = MatchResultAccel() + remote_matched_result = MatchResultAccel() + # TODO: avoid redundant match? + if self.cpu_cache_engine: + cpu_matched_result = self.cpu_cache_engine.match(sequence_meta) + if self.ssd_cache_engine: + ssd_matched_result = self.ssd_cache_engine.match(sequence_meta) + if self.remote_cache_engine: + remote_matched_result = self.remote_cache_engine.match(sequence_meta) + + return cpu_matched_result, ssd_matched_result, remote_matched_result + def match_all(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResult, MatchResult]: cpu_matched_result = MatchResult() ssd_matched_result = MatchResult() diff --git a/flexkv/common/config.py b/flexkv/common/config.py index a022292eb8..9148edf1d9 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -32,6 +32,7 @@ class CacheConfig: enable_remote: bool = False use_gds: bool = False use_pinned_memory: bool = False + index_accel: bool = False # kv cache layout configs gpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE diff --git a/setup.py b/setup.py index f69e28b923..67997949b2 100755 --- a/setup.py +++ b/setup.py @@ -25,12 +25,14 @@ "csrc/hash.cpp", "csrc/tp_transfer_thread_group.cpp", "csrc/transfer_ssd.cpp", + "csrc/radix_tree.cpp", ] hpp_sources = [ "csrc/cache_utils.h", "csrc/tp_transfer_thread_group.h", "csrc/transfer_ssd.h", + "csrc/radix_tree.h", ] extra_link_args = ["-lcuda", "-lxxhash", "-lpthread", "-lrt", "-luring"] diff --git a/tests/test_cache_engine_accel.py b/tests/test_cache_engine_accel.py new file mode 100644 index 0000000000..500b1ac7e9 --- /dev/null +++ b/tests/test_cache_engine_accel.py @@ -0,0 +1,231 @@ +import random + +import pytest +import torch + +from flexkv.cache.mempool import Mempool +from flexkv.cache.cache_engine import CacheEngineAccel +from flexkv.common.transfer import DeviceType +from flexkv.common.exceptions import InvalidConfigError, NotEnoughSpaceError +from flexkv.common.block import SequenceMeta + +@pytest.fixture +def cache_engine(request: pytest.FixtureRequest) -> CacheEngineAccel: + param = request.param if hasattr(request, 'param') else {} + default_config_kwargs = { + 'device_type': DeviceType.CPU, + 'num_total_blocks': 64, + 'tokens_per_block': 4, + } + default_config_kwargs.update(param) + return CacheEngineAccel(**default_config_kwargs) + +@pytest.mark.parametrize( + "config, should_raise", + [ + ({'num_total_blocks': 64, 'tokens_per_block': 4, 
'device_type': DeviceType.CPU}, False), + ({'num_total_blocks': 0, 'tokens_per_block': 4, 'device_type': DeviceType.GPU}, True), + ({'num_total_blocks': 64, 'tokens_per_block': 0, 'device_type': DeviceType.SSD}, True), + ({'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': 'Unknown'}, True), + ({'num_total_blocks': 64, 'tokens_per_block': 3, 'device_type': DeviceType.CPU}, True), + ] +) +def test_config_init(config: dict, should_raise: bool): + if should_raise: + with pytest.raises(InvalidConfigError) as e: + CacheEngineAccel(**config) + else: + engine = CacheEngineAccel(**config) + assert isinstance(engine, CacheEngineAccel) + +def test_mempool(): + mempool = Mempool(num_total_blocks=64) + assert mempool.num_free_blocks == 64 + block_ids = mempool.allocate_blocks(16) + assert isinstance(block_ids, torch.Tensor) + assert block_ids.dtype == torch.int64 + assert block_ids.shape == (16,) + assert mempool.num_free_blocks == 48 + mempool.recycle_blocks(block_ids) + assert mempool.num_free_blocks == 64 + + block_ids = torch.cat([mempool.allocate_blocks(16), + mempool.allocate_blocks(16), + mempool.allocate_blocks(16), + mempool.allocate_blocks(16)]) + assert mempool.num_free_blocks == 0 + + with pytest.raises(NotEnoughSpaceError): + mempool.allocate_blocks(1) + + mempool.recycle_blocks(block_ids) + assert mempool.num_free_blocks == 64 + + empty_blocks = mempool.allocate_blocks(0) + assert empty_blocks.shape == (0, ) + assert empty_blocks.dtype == torch.int64 + assert mempool.num_free_blocks == 64 + + with pytest.raises(ValueError): + mempool.allocate_blocks(-1) + + mempool.recycle_blocks(torch.tensor([], dtype=torch.int64)) + assert mempool.num_free_blocks == 64 + + with pytest.raises(ValueError): + mempool.recycle_blocks(torch.tensor([1, 2, 3], dtype=torch.int32)) + with pytest.raises(ValueError): + mempool.recycle_blocks(torch.tensor([1, 2, 3], dtype=torch.int64)) + with pytest.raises(ValueError): + mempool.recycle_blocks(torch.tensor([[1, 2, 3]], dtype=torch.int64)) + +def test_reset(cache_engine: CacheEngineAccel): + cache_engine.reset() + assert cache_engine.index.is_empty() + assert cache_engine.mempool.num_used_blocks == 0 + +@pytest.mark.parametrize( + "cache_engine", + [ + {'num_total_blocks': 10000000, 'tokens_per_block': 1, 'device_type': DeviceType.CPU}, + {'num_total_blocks': 10000000, 'tokens_per_block': 16, 'device_type': DeviceType.CPU}, + ], + indirect=True +) +@pytest.mark.parametrize( + "num_insert", + [100], +) +@pytest.mark.parametrize( + "seq_len", + [1, 10, 16, 32, 10000], +) +def test_match_and_insert(cache_engine: CacheEngineAccel, num_insert: int, seq_len: int): + base_token_ids = torch.randint(0, 10000, (seq_len, ), dtype=torch.int64) + base_num_blocks = seq_len // cache_engine.tokens_per_block + cache_engine.insert(SequenceMeta(token_ids=base_token_ids, + tokens_per_block=cache_engine.tokens_per_block), + torch.arange(base_num_blocks, dtype=torch.int64), + is_ready=True) + cur_cached_blocks = base_num_blocks + for i in range(num_insert): + prefix_ratio = random.random() + prefix_len = int(len(base_token_ids)*prefix_ratio) + num_prefix_blocks = prefix_len // cache_engine.tokens_per_block + token_ids = torch.cat([base_token_ids[:prefix_len], + torch.randint(10000 + i * seq_len, + 10000 + (i+1) * seq_len, + (seq_len-prefix_len, ), + dtype=torch.int64)]) + insert_sequence_meta = SequenceMeta(token_ids=token_ids, + tokens_per_block=cache_engine.tokens_per_block) + match_result = cache_engine.match(insert_sequence_meta) + assert 
match_result.num_ready_matched_blocks == num_prefix_blocks + assert match_result.num_matched_blocks == num_prefix_blocks + + num_insert_blocks = insert_sequence_meta.num_blocks - num_prefix_blocks + cache_engine.insert(insert_sequence_meta, + torch.arange(num_insert_blocks, dtype=torch.int64), + is_ready=True, + match_result=match_result) + cur_cached_blocks += num_insert_blocks + assert cache_engine.index.total_cached_blocks() == cur_cached_blocks + + match_result = cache_engine.match(insert_sequence_meta) + assert match_result.num_matched_blocks == insert_sequence_meta.num_blocks + assert match_result.num_ready_matched_blocks == insert_sequence_meta.num_blocks + +@pytest.mark.parametrize( + "cache_engine", + [ + {'num_total_blocks': 100, 'tokens_per_block': 16, 'device_type': DeviceType.CPU}, + ], + indirect=True +) +def test_take_and_recycle(cache_engine: CacheEngineAccel): + num_total_blocks = cache_engine.num_total_blocks + tokens_per_block = cache_engine.tokens_per_block + seq_blocks = 10 + token_ids = torch.randint(0, 10000, (seq_blocks * tokens_per_block, ), dtype=torch.int64) + sequence_meta = SequenceMeta(token_ids=token_ids, + tokens_per_block=tokens_per_block) + physical_blocks = cache_engine.take(seq_blocks) + radixnode = cache_engine.insert(sequence_meta, physical_blocks, is_ready=True) + assert cache_engine.index.total_cached_blocks() == seq_blocks + + empty_blocks = cache_engine.take(0) + assert empty_blocks.shape == (0, ) + assert empty_blocks.dtype == torch.int64 + + with pytest.raises(ValueError): + cache_engine.take(-1) + with pytest.raises(NotEnoughSpaceError): + cache_engine.take(num_total_blocks, protected_node=radixnode, strict=True) + + physical_blocks2 = cache_engine.take(num_total_blocks, protected_node=radixnode, strict=False) + assert physical_blocks2.shape == (num_total_blocks - seq_blocks, ) + assert physical_blocks2.dtype == torch.int64 + + cache_engine.recycle(physical_blocks2) + + cache_engine.lock_node(radixnode) + with pytest.raises(NotEnoughSpaceError): + cache_engine.take(num_total_blocks, protected_node=radixnode, strict=True) + cache_engine.cleanup(radixnode, radixnode.size()) + + physical_blocks = cache_engine.take(num_total_blocks, protected_node=None, strict=True) + assert physical_blocks.shape == (num_total_blocks, ) + assert cache_engine.index.total_cached_blocks() == 0 + +@pytest.mark.parametrize( + "cache_engine", + [ + {'num_total_blocks': 100, 'tokens_per_block': 1, 'device_type': DeviceType.CPU}, + ], + indirect=True +) +def test_cleanup(cache_engine: CacheEngineAccel): + if cache_engine.tokens_per_block != 1: + pytest.skip("tokens_per_block != 1") + tokens_per_block = cache_engine.tokens_per_block + token_ids_list = [torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64), + torch.tensor([0, 1, 2, 3, 17, 15, 19, 20], dtype=torch.int64), + torch.tensor([0, 23, 22, 21], dtype=torch.int64)] + sequence_meta_list = [SequenceMeta(token_ids=token_ids, + tokens_per_block=tokens_per_block) + for token_ids in token_ids_list] + num_insert_blocks0 = sequence_meta_list[0].num_blocks + radixnode0 = cache_engine.insert(sequence_meta_list[0], + torch.arange(num_insert_blocks0, dtype=torch.int64), + is_ready=False) + cache_engine.lock_node(radixnode0) + radixnode0_size = radixnode0.size() + match_result = cache_engine.match(sequence_meta_list[1]) + num_insert_blocks1 = sequence_meta_list[1].num_blocks - match_result.num_matched_blocks + radixnode1 = cache_engine.insert(sequence_meta_list[1], + torch.arange(num_insert_blocks1, dtype=torch.int64), + 
match_result=match_result, + is_ready=False) + cache_engine.lock_node(radixnode1) + radixnode1_size = radixnode1.size() + match_result = cache_engine.match(sequence_meta_list[2]) + num_insert_blocks2 = sequence_meta_list[2].num_blocks - match_result.num_matched_blocks + radixnode2 = cache_engine.insert(sequence_meta_list[2], + torch.arange(num_insert_blocks2, dtype=torch.int64), + match_result=match_result, + is_ready=False) + cache_engine.lock_node(radixnode2) + radixnode2_size = radixnode2.size() + total_insert_blocks = num_insert_blocks0 + num_insert_blocks1 + num_insert_blocks2 + assert cache_engine.index.total_cached_blocks() == total_insert_blocks + assert cache_engine.index.total_unready_blocks() == total_insert_blocks + assert cache_engine.index.total_ready_blocks() == 0 + + cache_engine.cleanup(radixnode2, radixnode2_size) + assert cache_engine.index.total_ready_blocks() == num_insert_blocks2 + + cache_engine.cleanup(radixnode1, radixnode1_size) + assert cache_engine.index.total_ready_blocks() == num_insert_blocks1 + num_insert_blocks2 + + cache_engine.cleanup(radixnode0, radixnode0_size) + assert cache_engine.index.total_ready_blocks() == num_insert_blocks0 + num_insert_blocks1 + num_insert_blocks2 From ef0720bd49b4d2a817e797858cb8720bab94623f Mon Sep 17 00:00:00 2001 From: moritzxu Date: Wed, 20 Aug 2025 14:51:13 +0800 Subject: [PATCH 03/42] sync kernel launch --- csrc/tp_transfer_thread_group.cpp | 80 ++++++++++++++++++++++++------- csrc/tp_transfer_thread_group.h | 14 +++++- 2 files changed, 76 insertions(+), 18 deletions(-) diff --git a/csrc/tp_transfer_thread_group.cpp b/csrc/tp_transfer_thread_group.cpp index 617ac8abd2..8c6f2d620f 100644 --- a/csrc/tp_transfer_thread_group.cpp +++ b/csrc/tp_transfer_thread_group.cpp @@ -24,6 +24,11 @@ TPTransferThreadGroup::TPTransferThreadGroup( int num_gpus, const std::vector> &gpu_blocks, torch::Tensor &cpu_blocks, int dp_group_id) { num_gpus_ = num_gpus; + + queues_.resize(num_gpus_); + mtxs_ = std::vector(num_gpus_); + cvs_ = std::vector(num_gpus_); + int num_layers = gpu_blocks[0].size(); cudaMallocHost((void **)&gpu_blocks_, num_gpus_ * num_layers * sizeof(void *)); @@ -41,9 +46,46 @@ TPTransferThreadGroup::TPTransferThreadGroup( cudaSetDevice(dp_group_id * num_gpus_ + i); cudaStreamCreate(&streams_[i]); } + // create the thread pool + stop_pool_=false; + for (int i = 0; i < num_gpus_; ++i) { + threads_.emplace_back([this, i]() { + int device_id = dp_group_id_ * num_gpus_ + i; + cudaSetDevice(device_id); // only once + + while (true) { + Task task; + { + std::unique_lock lk(mtxs_[i]); + cvs_[i].wait(lk, [&]{ return stop_pool_ || !queues_[i].empty(); }); + if (stop_pool_ && queues_[i].empty()) return; + + task = std::move(queues_[i].front()); + queues_[i].pop(); + } + task(); // + } + }); + } + } -TPTransferThreadGroup::~TPTransferThreadGroup() {} +TPTransferThreadGroup::~TPTransferThreadGroup() { + stop_pool_ = true; + for (auto& cv : cvs_) cv.notify_all(); + for (auto& t : threads_) if (t.joinable()) t.join(); +} + +std::future TPTransferThreadGroup::enqueue_for_gpu(int gpu_idx, Task task) { + auto pkg = std::make_shared>(std::move(task)); + auto fut = pkg->get_future(); + { + std::lock_guard lk(mtxs_[gpu_idx]); + queues_[gpu_idx].emplace([pkg]{ (*pkg)(); }); + } + cvs_[gpu_idx].notify_one(); + return fut; +} void TPTransferThreadGroup::tp_group_transfer( const torch::Tensor &gpu_block_id_tensor, @@ -60,11 +102,15 @@ void TPTransferThreadGroup::tp_group_transfer( std::atomic failed{false}; std::string error_msg; - threads_.clear(); 
- threads_.reserve(num_gpus_); + // threads_.clear(); + // threads_.reserve(num_gpus_); - for (int i = 0; i < num_gpus_; ++i) { - threads_.emplace_back([&, i]() { + // Barrier sync_point(num_gpus_); + std::vector> futures; + futures.reserve(num_gpus_); + + for (int i=0; i #include #include - +#include +#include +#include +#include namespace flexkv { - class TPTransferThreadGroup { public: TPTransferThreadGroup( @@ -48,12 +50,20 @@ class TPTransferThreadGroup { const int layer_granularity, const bool is_mla); private: + using Task = std::function; + std::future enqueue_for_gpu(int gpu_idx, Task task); + int num_gpus_; int dp_group_id_; void **gpu_blocks_; void *cpu_blocks_; std::vector threads_; std::vector streams_; + + std::vector> queues_; + std::vector mtxs_; + std::vector cvs_; + std::atomic stop_pool_; }; } // namespace flexkv From 0290841dce65ae9b036a23d733cf94e47e814934 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Wed, 27 Aug 2025 14:26:13 +0800 Subject: [PATCH 04/42] kvmanager refactor (#73) * add KVCacheEngineClient APIs * basic implementation for KVCacheEngineClient * initial transfer manager * init transfer handle * init kv engine * refactor kvmanager * update kvmanager * some refactor * kv response * add benchmark * serialize graph * fix bugs * ready check * update * rename * rename benchmark * use numpy instead of tensor * small fix * remove transfer descriptor * rename to kvmanager * update api * add gpu-kvcache-verifier, draft * update * create a new tp-worker process and create gpu blocks for verification * rename * the test_kvmanager works now * fix virtual op initialize * fix verifier bug when tp > 1 and mla enabled * fix * remove task id && some fix * only create one h2d op * pass slotmapping for launch * quick fix --------- Co-authored-by: linhu-nv Co-authored-by: Fei Liang --- benchmarks/benchmark_kvmanager.py | 273 ----------- benchmarks/benchmark_single_batch.py | 179 +++++++ benchmarks/benchmark_workers.py | 14 +- flexkv/cache/cache_engine.py | 301 +++++------- flexkv/cache/mempool.py | 16 +- flexkv/cache/radixtree.py | 29 +- flexkv/cache/transfer_pattern.py | 167 ++----- flexkv/common/block.py | 10 +- flexkv/common/hash_utils.py | 23 +- flexkv/common/request.py | 20 + flexkv/common/transfer.py | 112 ++--- flexkv/kvmanager.py | 668 ++++++--------------------- flexkv/kvtask.py | 504 ++++++++++++++++++++ flexkv/server/client.py | 115 +++-- flexkv/server/request.py | 42 +- flexkv/server/server.py | 553 ++++++---------------- flexkv/transfer/transfer_engine.py | 21 +- flexkv/transfer/worker.py | 19 +- flexkv/transfer_manager.py | 344 ++++++++++++++ tests/replay_from_tracer.py | 8 +- tests/test_kvmanager.py | 214 +++++++-- tests/test_utils.py | 325 ++++++++++++- 22 files changed, 2199 insertions(+), 1758 deletions(-) delete mode 100644 benchmarks/benchmark_kvmanager.py create mode 100644 benchmarks/benchmark_single_batch.py create mode 100644 flexkv/kvtask.py create mode 100644 flexkv/transfer_manager.py diff --git a/benchmarks/benchmark_kvmanager.py b/benchmarks/benchmark_kvmanager.py deleted file mode 100644 index c1cedb9519..0000000000 --- a/benchmarks/benchmark_kvmanager.py +++ /dev/null @@ -1,273 +0,0 @@ -import os -import tempfile -from multiprocessing import Process -import argparse -import json -import time -from dataclasses import dataclass - -import torch - -from flexkv.server.client import KVDPClient, KVTPClient -from flexkv.server.server import KVServer, SchedulerServer -from flexkv.common.config import ModelConfig, CacheConfig -from flexkv.common.storage 
import KVCacheLayoutType, KVCacheLayout -from flexkv.common.debug import flexkv_logger -from utils import load_config - -flexkv_logger.set_level("INFO") - - -@dataclass -class BenchmarkConfig: - num_layers_to_transfer: int - batch_size: int - sequence_length: int - cache_ratio: float - -def run_server(model_config, cache_config, server_recv_port): - """Run server process""" - kvserver = KVServer(model_config, cache_config, server_recv_port) - kvserver.run() - time.sleep(10) - -def run_tp_client(dp_client_id, tp_rank, server_recv_port, model_config, cache_config): - """Run tp_client process""" - device_id = tp_rank + dp_client_id * model_config.tp_size - tp_client = KVTPClient(server_recv_port, dp_client_id, device_id, tp_rank) - - num_gpu_blocks = cache_config.num_gpu_blocks - - gpu_kv_layout = KVCacheLayout( - type=cache_config.gpu_kv_layout_type, - num_layer=model_config.num_layers, - num_block=num_gpu_blocks, - tokens_per_block=cache_config.tokens_per_block, - num_head=model_config.num_kv_heads, - head_size=model_config.head_size, - is_mla=model_config.use_mla, - ) - - # Create GPU blocks for this tp_rank in the tp_client process - gpu_blocks_for_tp = [] - for _ in range(model_config.num_layers): - gpu_blocks_for_tp.append( - torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) - ) - tp_client.register_to_server(gpu_blocks_for_tp, gpu_kv_layout) - # Keep the process running - while True: - time.sleep(1) - -def shutdown_tp_client(tp_client_processes): - for tp_process in tp_client_processes: - if tp_process.is_alive(): - tp_process.terminate() - tp_process.join(timeout=5) - if tp_process.is_alive(): - print(f"Force killing tp_client process {tp_process.pid}") - tp_process.kill() - tp_process.join(timeout=2) - -class FlexkvWrapper: - def __init__(self, model_config, cache_config, server_recv_port): - self.model_config = model_config - self.cache_config = cache_config - self.server_recv_port = server_recv_port - - self.use_scheduler_server = model_config.dp_size == 1 - if self.use_scheduler_server: - self.launch_scheduler_server() - else: - self.launch_server() - - def launch_server(self): - def server_process(): - kvserver = KVServer(self.model_config, self.cache_config, self.server_recv_port) - kvserver.run() - time.sleep(10) - self.server_process = Process( - target=server_process, - daemon=False - ) - self.server_process.start() - time.sleep(5) - self.dp_client = KVDPClient(self.server_recv_port, self.model_config) - - def launch_scheduler_server(self): - self.scheduler_server = SchedulerServer(self.model_config, self.cache_config, self.server_recv_port) - self.scheduler_server.start_server_thread() - time.sleep(10) - - @property - def dp_client_id(self): - if self.use_scheduler_server: - return 0 - else: - return self.dp_client.dp_client_id - - def put_async(self, token_ids, slot_mapping, token_mask=None): - if self.use_scheduler_server: - return self.scheduler_server.put_async(token_ids, slot_mapping, token_mask) - else: - return self.dp_client.put_async(token_ids, slot_mapping, token_mask) - - def get_async(self, token_ids, slot_mapping, token_mask=None): - if self.use_scheduler_server: - return self.scheduler_server.get_async(token_ids, slot_mapping, token_mask) - else: - return self.dp_client.get_async(token_ids, slot_mapping, token_mask) - - def wait(self, request_ids): - if self.use_scheduler_server: - return self.scheduler_server.wait(request_ids) - else: - return self.dp_client.wait(request_ids) - - def try_wait(self, request_ids): - if 
self.use_scheduler_server: - return self.scheduler_server.try_wait(request_ids) - else: - return self.dp_client.try_wait(request_ids) - - def check_running(self): - if self.use_scheduler_server: - return self.scheduler_server.check_running() - else: - return self.dp_client.check_running() - - def shutdown(self): - if not self.use_scheduler_server: - try: - # Send a shutdown request to the server - self.dp_client.shutdown() - # Wait a bit for graceful shutdown - time.sleep(3) - except Exception as e: - print(f"Error sending shutdown request: {e}") - if self.server_process.is_alive(): - self.server_process.terminate() - self.server_process.join(timeout=10) - if self.server_process.is_alive(): - print(f"Force killing server process {self.server_process.pid}") - self.server_process.kill() - self.server_process.join(timeout=5) - if self.server_recv_port.startswith('ipc://'): - temp_file = self.server_recv_port[6:] # Remove 'ipc://' prefix - try: - if os.path.exists(temp_file): - os.unlink(temp_file) - except Exception as e: - print(f"Error cleaning up temporary file: {e}") - else: - self.scheduler_server.shutdown() - -def benchmark_kvmanager(model_config, cache_config, benchmark_config, server_recv_port): - if model_config.tp_size * model_config.dp_size > torch.cuda.device_count(): - raise ValueError(f"tp_size {model_config.tp_size} * dp_size {model_config.dp_size} is greater than " - f"the number of available GPUs {torch.cuda.device_count()}") - print(f"{model_config = }") - print(f"{cache_config = }") - print(f"{benchmark_config = }") - flexkv_wrapper = FlexkvWrapper(model_config, cache_config, server_recv_port) - - tp_client_processes = [] - - sequence_length = benchmark_config.sequence_length - batch_size = benchmark_config.batch_size - num_required_gpu_blocks = sequence_length * batch_size // cache_config.tokens_per_block - cache_config.num_gpu_blocks = num_required_gpu_blocks - print(f"allocate {num_required_gpu_blocks} gpu blocks for benchmark") - for tp_rank in range(model_config.tp_size): - tp_client_process = Process( - target=run_tp_client, - args=(flexkv_wrapper.dp_client_id, tp_rank, server_recv_port, - model_config, cache_config), - daemon=True - ) - tp_client_process.start() - tp_client_processes.append(tp_client_process) - time.sleep(5) - - batch_sequence_tensor = [] - batch_slot_mapping = [] - cache_length = int(sequence_length * benchmark_config.cache_ratio) - - # generate requests - for i in range(batch_size): - batch_sequence_tensor.append(torch.randint(0, 100000, (sequence_length, ), dtype=torch.int64)) - batch_slot_mapping.append(torch.arange(i * sequence_length, (i+1) * sequence_length, dtype=torch.int64)) - - while not flexkv_wrapper.check_running(): - time.sleep(0.1) - print("waiting for flexkv wrapper to be ready") - # benchmark put - start_time = time.time() - put_ids = [] - if benchmark_config.cache_ratio > 0: - for i in range(batch_size): - put_ids.append(flexkv_wrapper.put_async(batch_sequence_tensor[i][:cache_length], - batch_slot_mapping[i][:cache_length], - token_mask=None)) - put_result = flexkv_wrapper.wait(put_ids) - end_time = time.time() - time.sleep(1) - elapsed_time_put = end_time - start_time - put_tokens = 0 - for _, return_mask in put_result.items(): - put_tokens += return_mask.sum().item() - transfer_data_size_GB = put_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 - transfer_bandwidth_put = transfer_data_size_GB / elapsed_time_put - print(f"put {put_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " - f"time: 
{elapsed_time_put*1000:.2f}ms, bandwidth: {transfer_bandwidth_put:.2f} GB/s") - - #benchmark get - start_time = time.time() - get_ids = [] - for i in range(batch_size): - get_ids.append(flexkv_wrapper.get_async(batch_sequence_tensor[i], - batch_slot_mapping[i], - token_mask=None)) - get_result = flexkv_wrapper.wait(get_ids) - end_time = time.time() - elapsed_time_get = end_time - start_time - cached_tokens = 0 - all_tokens = 0 - for _, return_mask in get_result.items(): - cached_tokens += return_mask.sum().item() - all_tokens += len(return_mask) - transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 - transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get - print(f"get {cached_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " - f"cache_ratio: {cached_tokens * 100 / all_tokens:.2f}%, " - f"time: {elapsed_time_get*1000:.2f}ms, bandwidth: {transfer_bandwidth_get:.2f} GB/s") - - shutdown_tp_client(tp_client_processes) - flexkv_wrapper.shutdown() - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, default="benchmarks/example_config.json") - # benchmark config - parser.add_argument("--num-layers", type=int, default=-1) - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--sequence-length", type=int, default=1024) - parser.add_argument("--cache-ratio", type=float, default=1) - return parser.parse_args() - -if __name__ == "__main__": - args = parse_args() - benchmark_config = BenchmarkConfig( - num_layers_to_transfer=args.num_layers, - batch_size=args.batch_size, - sequence_length=args.sequence_length, - cache_ratio=args.cache_ratio - ) - model_config, cache_config = load_config(args.config) - #cache_config.num_cpu_blocks = 8192 - 2048 - # pad sequence length to divisible by tokens_per_block - benchmark_config.sequence_length = \ - ((benchmark_config.sequence_length - 1) // cache_config.tokens_per_block + 1) * cache_config.tokens_per_block - server_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" - benchmark_kvmanager(model_config, cache_config, benchmark_config, server_recv_port) diff --git a/benchmarks/benchmark_single_batch.py b/benchmarks/benchmark_single_batch.py new file mode 100644 index 0000000000..b58c3b8e71 --- /dev/null +++ b/benchmarks/benchmark_single_batch.py @@ -0,0 +1,179 @@ +import tempfile +from multiprocessing import Process +import argparse +import time +from dataclasses import dataclass + +import torch + +from flexkv.server.client import KVTPClient +from flexkv.common.storage import KVCacheLayout +from flexkv.common.debug import flexkv_logger +from utils import load_config +from flexkv.kvmanager import KVManager +from flexkv.kvtask import KVResponseStatus + +flexkv_logger.set_level("INFO") + + +@dataclass +class BenchmarkConfig: + num_layers_to_transfer: int + batch_size: int + sequence_length: int + cache_ratio: float + +def run_tp_client(dp_client_id, tp_rank, server_recv_port, model_config, cache_config): + """Run tp_client process""" + device_id = tp_rank + dp_client_id * model_config.tp_size + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id, tp_rank) + + num_gpu_blocks = cache_config.num_gpu_blocks + + gpu_kv_layout = KVCacheLayout( + type=cache_config.gpu_kv_layout_type, + num_layer=model_config.num_layers, + num_block=num_gpu_blocks, + tokens_per_block=cache_config.tokens_per_block, + num_head=model_config.num_kv_heads, + head_size=model_config.head_size, + is_mla=model_config.use_mla, + ) 
+ + # Create GPU blocks for this tp_rank in the tp_client process + gpu_blocks_for_tp = [] + for _ in range(model_config.num_layers): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) + ) + tp_client.register_to_server(gpu_blocks_for_tp, gpu_kv_layout) + # Keep the process running + while True: + time.sleep(1) + +def shutdown_tp_client(tp_client_processes): + for tp_process in tp_client_processes: + if tp_process.is_alive(): + tp_process.terminate() + tp_process.join(timeout=5) + if tp_process.is_alive(): + print(f"Force killing tp_client process {tp_process.pid}") + tp_process.kill() + tp_process.join(timeout=2) + +def benchmark_flexkv(model_config, cache_config, benchmark_config, gpu_register_port, server_recv_port): + if model_config.tp_size * model_config.dp_size > torch.cuda.device_count(): + raise ValueError(f"tp_size {model_config.tp_size} * dp_size {model_config.dp_size} is greater than " + f"the number of available GPUs {torch.cuda.device_count()}") + print(f"{benchmark_config = }") + kvmanager = KVManager(model_config, cache_config, gpu_register_port, server_recv_port) + kvmanager.start() + + tp_client_processes = [] + + sequence_length = benchmark_config.sequence_length + batch_size = benchmark_config.batch_size + num_required_gpu_blocks = sequence_length * batch_size // cache_config.tokens_per_block + cache_config.num_gpu_blocks = num_required_gpu_blocks + print(f"allocate {num_required_gpu_blocks} gpu blocks for benchmark") + for tp_rank in range(model_config.tp_size): + tp_client_process = Process( + target=run_tp_client, + args=(0, tp_rank, gpu_register_port, + model_config, cache_config), + daemon=True + ) + tp_client_process.start() + tp_client_processes.append(tp_client_process) + + while not kvmanager.is_ready(): + time.sleep(1) + flexkv_logger.info("waiting for flexkv to be ready") + + batch_sequence_tensor = [] + batch_slot_mapping = [] + cache_length = int(sequence_length * benchmark_config.cache_ratio) + + # generate requests + for i in range(batch_size): + batch_sequence_tensor.append(torch.randint(0, 100000, (sequence_length, ), dtype=torch.int64)) + batch_slot_mapping.append(torch.arange(i * sequence_length, (i+1) * sequence_length, dtype=torch.int64)) + + # benchmark put + start_time = time.time() + batch_put_ids = [] + if benchmark_config.cache_ratio > 0: + for i in range(batch_size): + task_id = kvmanager.put_async(batch_sequence_tensor[i][:cache_length], + batch_slot_mapping[i][:cache_length], + token_mask=None) + batch_put_ids.append(task_id) + put_result = kvmanager.wait(batch_put_ids, completely=True) + end_time = time.time() + + elapsed_time_put = end_time - start_time + put_tokens = 0 + for _, response in put_result.items(): + if response.status == KVResponseStatus.SUCCESS: + put_tokens += response.return_mask.sum().item() + transfer_data_size_GB = put_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 + transfer_bandwidth_put = transfer_data_size_GB / elapsed_time_put + print(f"put {put_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"time: {elapsed_time_put*1000:.2f}ms, bandwidth: {transfer_bandwidth_put:.2f} GB/s") + + all_tokens = 0 + start_time = time.time() + batch_get_ids = [] + for i in range(batch_size): + all_tokens += len(batch_sequence_tensor[i]) + task_id, _ = kvmanager.get_match(batch_sequence_tensor[i], + token_mask=None) + batch_get_ids.append(task_id) + get_match_time = time.time() - start_time + kvmanager.launch(batch_get_ids, 
batch_slot_mapping) + get_result = kvmanager.wait(batch_get_ids) + elapsed_time_get = time.time() - start_time + cached_tokens = 0 + for _, response in get_result.items(): + if response.status == KVResponseStatus.SUCCESS: + cached_tokens += response.return_mask.sum().item() + transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 + transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get + print(f"get {cached_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"cache_ratio: {cached_tokens * 100 / all_tokens:.2f}%, " + f"match time: {get_match_time*1000:.2f}ms, " + f"e2e time: {elapsed_time_get*1000:.2f}ms, " + f"bandwidth: {transfer_bandwidth_get:.2f} GB/s") + + shutdown_tp_client(tp_client_processes) + kvmanager.shutdown() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="benchmarks/example_config.json") + # benchmark config + parser.add_argument("--num-layers", type=int, default=-1) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--sequence-length", type=int, default=1024) + parser.add_argument("--cache-ratio", type=float, default=1) + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + benchmark_config = BenchmarkConfig( + num_layers_to_transfer=args.num_layers, + batch_size=args.batch_size, + sequence_length=args.sequence_length, + cache_ratio=args.cache_ratio + ) + model_config, cache_config = load_config(args.config) + #cache_config.num_cpu_blocks = 8192 - 2048 + # pad sequence length to divisible by tokens_per_block + benchmark_config.sequence_length = \ + ((benchmark_config.sequence_length - 1) // cache_config.tokens_per_block + 1) * cache_config.tokens_per_block + import uuid + gpu_register_port = f"ipc:///tmp/flexkv_gpu_{uuid.uuid4().hex[:8]}" + server_recv_port = f"ipc:///tmp/flexkv_srv_{uuid.uuid4().hex[:8]}" + + benchmark_flexkv(model_config, cache_config, benchmark_config, gpu_register_port, server_recv_port) diff --git a/benchmarks/benchmark_workers.py b/benchmarks/benchmark_workers.py index f9c50f8bc7..c02fb6f925 100644 --- a/benchmarks/benchmark_workers.py +++ b/benchmarks/benchmark_workers.py @@ -9,7 +9,7 @@ import torch -from flexkv.common.transfer import TransferOp, TransferType, TransferDescriptor +from flexkv.common.transfer import TransferOp, TransferType from flexkv.transfer.worker import GPUCPUTransferWorker, CPUSSDDiskTransferWorker, WorkerHandle, tpGPUCPUTransferWorker from flexkv.storage.allocator import CPUAllocator, GPUAllocator, SSDAllocator from flexkv.common.storage import KVCacheLayoutType, KVCacheLayout @@ -212,12 +212,8 @@ def bench_worker(args): transfer_type=transfer_type, layer_id=0, layer_granularity=num_layers_to_transfer, - src_descriptor=TransferDescriptor( - physical_block_ids=block_ids, - ), - dst_descriptor=TransferDescriptor( - physical_block_ids=block_ids, - ), + src_block_ids=block_ids, + dst_block_ids=block_ids, graph_id=0, dp_id=0, successors=[], @@ -226,8 +222,8 @@ def bench_worker(args): if transfer_type == TransferType.DISK2H: tmp_op = copy.deepcopy(transfer_op) tmp_op.transfer_type = TransferType.H2DISK - tmp_op.src_descriptor = transfer_op.dst_descriptor - tmp_op.dst_descriptor = transfer_op.src_descriptor + tmp_op.src_block_ids = transfer_op.dst_block_ids + tmp_op.dst_block_ids = transfer_op.src_block_ids launch_transfer(worker_handle, finished_ops_queue, tmp_op) for _ in range(warmup_round): launch_transfer(worker_handle, finished_ops_queue, 
transfer_op) diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 3cf178a9d7..b2c76b795d 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -20,21 +20,20 @@ from typing import List, Tuple, Optional, Dict, Callable from dataclasses import dataclass +import numpy as np import torch from flexkv.c_ext import CRadixNode, CRadixTreeIndex, CMatchResult from flexkv.cache.mempool import Mempool from flexkv.cache.radixtree import RadixTreeIndex, RadixNode, MatchResult -from flexkv.cache.transfer_pattern import ( - convert_read_graph_to_layer_wise_graph, add_virtal_op_for_mutiple_finished_ops -) +from flexkv.cache.transfer_pattern import add_virtal_op_for_mutiple_finished_ops from flexkv.common.block import SequenceMeta from flexkv.common.config import CacheConfig, ModelConfig from flexkv.common.exceptions import InvalidConfigError, NotEnoughSpaceError from flexkv.common.transfer import ( - DeviceType, TransferOpGraph, TransferOp, TransferType, TransferDescriptor + DeviceType, TransferOpGraph, TransferOp, TransferType ) - +from flexkv.common.debug import flexkv_logger @dataclass class MatchResultAccel: num_ready_matched_blocks: int = 0 @@ -171,7 +170,7 @@ def match(self, sequence_meta: SequenceMeta) -> MatchResult: def insert(self, sequence_meta: SequenceMeta, - physical_block_ids: torch.Tensor, + physical_block_ids: np.ndarray, num_insert_blocks: int = -1, is_ready: bool = True, match_result: Optional[MatchResult] = None) -> Optional[RadixNode]: @@ -191,7 +190,7 @@ def cleanup(self, node: RadixNode, cleanup_length: int) -> None: def take(self, num_required_blocks: int, protected_node: Optional[RadixNode] = None, - strict: bool = True) -> torch.Tensor: + strict: bool = True) -> np.ndarray: if num_required_blocks > self.mempool.num_free_blocks: if protected_node is not None: self.index.lock(protected_node) @@ -207,7 +206,7 @@ def take(self, num_allocated_blocks = min(num_required_blocks, self.mempool.num_free_blocks) return self.mempool.allocate_blocks(num_allocated_blocks) - def recycle(self, physical_blocks: torch.Tensor) -> None: + def recycle(self, physical_blocks: np.ndarray) -> None: self.mempool.recycle_blocks(physical_blocks) class GlobalCacheEngine: @@ -268,25 +267,34 @@ def reset(self) -> None: def get(self, request_id: int, - token_ids: torch.Tensor, - token_mask: torch.Tensor, - slot_mapping: torch.Tensor, + token_ids: np.ndarray, + token_mask: np.ndarray, + slot_mapping: np.ndarray, layer_num: int = -1, layer_granularity: int = -1, - dp_id: int = 0) -> Tuple[TransferOpGraph, torch.Tensor, Callable, List[int]]: + dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, int]: self._check_input(token_ids, token_mask, slot_mapping) + if layer_num == -1: layer_num = self.model_config.num_layers if layer_granularity == -1: layer_granularity = layer_num + if layer_num != layer_granularity: + flexkv_logger.error(f"Layerwise transfer is not supported yet, " + f"layer_num: {layer_num}, layer_granularity: {layer_granularity}") + raise NotImplementedError(f"Layerwise transfer is not supported yet, " + f"layer_num: {layer_num}, layer_granularity: {layer_granularity}") + # ignore the last incomplete block aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block aligned_token_ids = token_ids[:aligned_length] token_mask[aligned_length:] = False + block_start_idx, block_end_idx = self._get_block_range(token_mask) assert block_end_idx == aligned_length // self.tokens_per_block - gpu_block_mapping = 
self._slot_to_block_mapping(slot_mapping)[:block_end_idx-block_start_idx] + gpu_block_ids = self.slot_mapping_to_block_ids(slot_mapping, + self.tokens_per_block)[:block_end_idx-block_start_idx] sequence_meta = SequenceMeta(token_ids=aligned_token_ids, tokens_per_block=self.cache_config.tokens_per_block) @@ -298,7 +306,7 @@ def get(self, sequence_meta, block_start_idx, block_end_idx, - gpu_block_mapping, + gpu_block_ids, layer_num ) else: @@ -308,24 +316,24 @@ def get(self, sequence_meta, block_start_idx, block_end_idx, - gpu_block_mapping, + gpu_block_ids, layer_num ) - transfer_graph, finished_ops_ids = add_virtal_op_for_mutiple_finished_ops( + transfer_graph, task_end_op_id = add_virtal_op_for_mutiple_finished_ops( transfer_graph, finished_ops_ids ) - return_mask = torch.zeros_like(token_mask) + return_mask = np.zeros_like(token_mask, dtype=np.bool_) return_mask[block_start_idx* self.tokens_per_block: (block_start_idx + num_gpu_blocks_to_transfer) * self.tokens_per_block] = True - if layer_num // layer_granularity != 1: - transfer_graph, finished_ops_ids = convert_read_graph_to_layer_wise_graph(transfer_graph=transfer_graph, - finished_ops_ids=finished_ops_ids, - layer_num=layer_num, - layer_granularity=layer_granularity) + # if layer_num // layer_granularity != 1: + # transfer_graph, finished_ops_ids = convert_read_graph_to_layer_wise_graph(transfer_graph=transfer_graph, + # finished_ops_ids=finished_ops_ids, + # layer_num=layer_num, + # layer_granularity=layer_granularity) transfer_graph.bind_to_dp_group(dp_id) for device_type in node_to_unlock: @@ -335,14 +343,14 @@ def get(self, node_to_unlock=node_to_unlock, buffer_to_free=buffer_to_free) - return transfer_graph, return_mask, callback, finished_ops_ids + return transfer_graph, return_mask, callback, task_end_op_id def _get_impl_global(self, request_id: int, sequence_meta: SequenceMeta, block_mask_start: int, block_mask_end: int, - gpu_block_mapping: torch.Tensor, + gpu_block_ids: np.ndarray, layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int]: """ transfer pattern: @@ -371,7 +379,7 @@ def _get_impl_global(self, #early return if no blocks to transfer if fragment123_num_blocks == 0: return self._empty_get_return(request_id) - assert fragment123_num_blocks <= len(gpu_block_mapping) + assert fragment123_num_blocks <= len(gpu_block_ids) transfer_graph = TransferOpGraph() finished_ops_ids = [] @@ -382,7 +390,7 @@ def _get_impl_global(self, fragment3_num_blocks = max(len(remote_matched_blocks) - fragment12_num_blocks, 0) fragment23_num_blocks = fragment2_num_blocks + fragment3_num_blocks - fragment123_gpu_blocks = gpu_block_mapping[:fragment123_num_blocks] + fragment123_gpu_blocks = gpu_block_ids[:fragment123_num_blocks] fragment123_cpu_blocks = cpu_matched_blocks fragment2_ssd_blocks = ssd_matched_blocks[-fragment2_num_blocks:] fragment3_remote_blocks = remote_matched_blocks[-fragment3_num_blocks:] @@ -390,7 +398,7 @@ def _get_impl_global(self, cpu_node_to_unlock = cpu_matched_result.last_ready_node ssd_node_to_unlock = ssd_matched_result.last_ready_node remote_node_to_unlock = remote_matched_result.last_ready_node - cpu_blocks_to_free = torch.tensor([], dtype=torch.int64) + cpu_blocks_to_free = np.array([], dtype=np.int64) if fragment23_num_blocks > 0: num_extra_required_blocks = fragment23_num_blocks @@ -402,7 +410,7 @@ def _get_impl_global(self, if len(fragment23_cpu_blocks) < num_extra_required_blocks: self.cpu_cache_engine.recycle(fragment23_cpu_blocks) return self._empty_get_return(request_id) - 
fragment123_cpu_blocks = torch.cat([fragment123_cpu_blocks, fragment23_cpu_blocks]) + fragment123_cpu_blocks = np.concatenate([fragment123_cpu_blocks, fragment23_cpu_blocks]) # we only insert the buffer blocks to cpu cache engine only: # 1. the cpu cache engine satisfies prefix cache after insertion # 2. the sequence is all ready blocks @@ -421,14 +429,8 @@ def _get_impl_global(self, op_disk2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.DISK2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment2_ssd_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment123_cpu_blocks[fragment1_num_blocks:fragment12_num_blocks] - ), + src_block_ids = fragment2_ssd_blocks, + dst_block_ids = fragment123_cpu_blocks[fragment1_num_blocks:fragment12_num_blocks], layer_id = 0, layer_granularity = layer_num ) @@ -439,14 +441,8 @@ def _get_impl_global(self, op_remote2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.REMOTE2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.REMOTE, - physical_block_ids=fragment3_remote_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment123_cpu_blocks[-fragment3_num_blocks:], - ), + src_block_ids = fragment3_remote_blocks, + dst_block_ids = fragment123_cpu_blocks[-fragment3_num_blocks:], layer_id = 0, layer_granularity = layer_num ) @@ -472,14 +468,8 @@ def _get_impl_global(self, op_h2disk = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment123_cpu_blocks[-fragment3_num_blocks:], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment3_ssd_blocks, - ), + src_block_ids = fragment123_cpu_blocks[-fragment3_num_blocks:], + dst_block_ids = fragment3_ssd_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -496,15 +486,8 @@ def _get_impl_global(self, op_h2d = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment123_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment123_gpu_blocks, - device_id = 0 - ), + src_block_ids = fragment123_cpu_blocks, + dst_block_ids = fragment123_gpu_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -533,7 +516,7 @@ def _get_impl_local(self, sequence_meta: SequenceMeta, block_mask_start: int, block_mask_end: int, - gpu_block_mapping: torch.Tensor, + gpu_block_ids: np.ndarray, layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int]: """ transfer pattern: @@ -567,12 +550,12 @@ def _get_impl_local(self, #early return if no blocks to transfer if fragment12_num_blocks == 0: return self._empty_get_return(request_id) - assert fragment12_num_blocks <= len(gpu_block_mapping) + assert fragment12_num_blocks <= len(gpu_block_ids) transfer_graph = TransferOpGraph() finished_ops_ids = [] - fragment12_gpu_blocks = gpu_block_mapping[:fragment12_num_blocks] + fragment12_gpu_blocks = gpu_block_ids[:fragment12_num_blocks] fragment2_ssd_blocks = ssd_matched_blocks[-fragment2_num_blocks:] fragment1_cpu_blocks = cpu_matched_blocks[:fragment1_num_blocks] @@ -580,7 +563,9 @@ def _get_impl_local(self, ssd_node_to_unlock = 
ssd_matched_result.last_ready_node # prepare cpu blocks to transfer - cpu_blocks_to_free = torch.tensor([], dtype=torch.int64) + cpu_blocks_to_free = np.array([], dtype=np.int64) + op_disk2h = None + fragment2_cpu_blocks = None if fragment2_num_blocks > 0: fragment2_cpu_blocks = self.cpu_cache_engine.take( num_required_blocks=fragment2_num_blocks, @@ -596,39 +581,12 @@ def _get_impl_local(self, op_disk2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.DISK2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment2_ssd_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment2_cpu_blocks, - ), + src_block_ids = fragment2_ssd_blocks, + dst_block_ids = fragment2_cpu_blocks, layer_id = 0, layer_granularity = layer_num ) transfer_graph.add_transfer_op(op_disk2h) - - op_h2d_frag2 = TransferOp( - graph_id = transfer_graph.graph_id, - transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment2_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment12_gpu_blocks[-fragment2_num_blocks:], - device_id = 0 - ), - layer_id = 0, - layer_granularity = layer_num - ) - transfer_graph.add_transfer_op(op_h2d_frag2) - - transfer_graph.add_dependency(op_h2d_frag2.op_id, op_disk2h.op_id) - finished_ops_ids.append(op_h2d_frag2.op_id) - # we only insert the buffer blocks to cpu cache engine only: # 1. the cpu cache engine satisfies prefix cache after insertion # 2. the sequence is all ready blocks @@ -642,23 +600,22 @@ def _get_impl_local(self, match_result=cpu_matched_result) else: cpu_blocks_to_free = fragment2_cpu_blocks - op_h2d_frag1 = TransferOp( + if fragment2_cpu_blocks is not None: + fragment12_cpu_blocks = np.concatenate([fragment1_cpu_blocks, fragment2_cpu_blocks]) + else: + fragment12_cpu_blocks = fragment1_cpu_blocks + op_h2d = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment1_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment12_gpu_blocks[:fragment1_num_blocks], - device_id = 0 - ), + src_block_ids = fragment12_cpu_blocks, + dst_block_ids = fragment12_gpu_blocks, layer_id = 0, layer_granularity = layer_num ) - transfer_graph.add_transfer_op(op_h2d_frag1) - finished_ops_ids.append(op_h2d_frag1.op_id) + transfer_graph.add_transfer_op(op_h2d) + if op_disk2h is not None: + transfer_graph.add_dependency(op_h2d.op_id, op_disk2h.op_id) + finished_ops_ids.append(op_h2d.op_id) node_to_unlock = {} if cpu_node_to_unlock is not None: @@ -671,11 +628,11 @@ def _get_impl_local(self, def put(self, request_id: int, - token_ids: torch.Tensor, - token_mask: torch.Tensor, - slot_mapping: torch.Tensor, + token_ids: np.ndarray, + token_mask: np.ndarray, + slot_mapping: np.ndarray, layer_num : int = -1, - dp_id: int = 0) -> Tuple[TransferOpGraph, torch.Tensor, Callable, List[int]]: + dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, int]: self._check_input(token_ids, token_mask, slot_mapping) if layer_num == -1: @@ -689,7 +646,8 @@ def put(self, # the mask should has a prefix of True assert block_start_idx == 0 - gpu_block_mapping = self._slot_to_block_mapping(slot_mapping)[:block_end_idx-block_start_idx] + gpu_block_ids = 
self.slot_mapping_to_block_ids(slot_mapping, + self.tokens_per_block)[:block_end_idx-block_start_idx] sequence_meta = SequenceMeta(token_ids=aligned_token_ids, tokens_per_block=self.cache_config.tokens_per_block) @@ -702,7 +660,7 @@ def put(self, sequence_meta, block_start_idx, block_end_idx, - gpu_block_mapping, + gpu_block_ids, layer_num ) else: @@ -713,16 +671,16 @@ def put(self, sequence_meta, block_start_idx, block_end_idx, - gpu_block_mapping, + gpu_block_ids, layer_num ) - transfer_graph, finished_ops_ids = add_virtal_op_for_mutiple_finished_ops( + transfer_graph, task_end_op_id = add_virtal_op_for_mutiple_finished_ops( transfer_graph, finished_ops_ids ) - return_mask = torch.zeros_like(token_mask) + return_mask = np.zeros_like(token_mask, dtype=np.bool_) return_mask[(block_start_idx + skipped_gpu_blocks)* self.tokens_per_block: (block_start_idx + skipped_gpu_blocks + num_gpu_blocks_to_transfer) * self.tokens_per_block] = True transfer_graph.bind_to_dp_group(dp_id) @@ -734,14 +692,14 @@ def put(self, node_to_unlock=node_to_unlock, buffer_to_free=buffer_to_free) - return transfer_graph, return_mask, callback, finished_ops_ids + return transfer_graph, return_mask, callback, task_end_op_id def _put_impl_global(self, request_id: int, sequence_meta: SequenceMeta, block_mask_start: int, block_mask_end: int, - gpu_block_mapping: torch.Tensor, + gpu_block_ids: np.ndarray, layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]: """ transfer pattern: @@ -773,15 +731,15 @@ def _put_impl_global(self, :remote_matched_result.num_matched_blocks][block_mask_start:block_mask_end] num_skipped_blocks = len(cpu_matched_blocks) - fragment12_num_blocks = len(gpu_block_mapping) - num_skipped_blocks + fragment12_num_blocks = len(gpu_block_ids) - num_skipped_blocks if fragment12_num_blocks == 0: return self._empty_put_return(request_id) - fragment2_num_blocks = len(gpu_block_mapping) - len(ssd_matched_blocks) + fragment2_num_blocks = len(gpu_block_ids) - len(ssd_matched_blocks) if not self.cache_config.enable_ssd: fragment2_num_blocks = 0 - fragment3_num_blocks = len(gpu_block_mapping) - len(remote_matched_blocks) + fragment3_num_blocks = len(gpu_block_ids) - len(remote_matched_blocks) - fragment12_gpu_blocks = gpu_block_mapping[num_skipped_blocks:] + fragment12_gpu_blocks = gpu_block_ids[num_skipped_blocks:] fragment12_cpu_blocks = self.cpu_cache_engine.take( num_required_blocks=fragment12_num_blocks, @@ -803,7 +761,7 @@ def _put_impl_global(self, else: self.ssd_cache_engine.recycle(fragment2_ssd_blocks) else: - fragment2_ssd_blocks = torch.tensor([], dtype=torch.int64) + fragment2_ssd_blocks = np.array([], dtype=np.int64) put_to_remote = False if fragment3_num_blocks > 0: fragment3_remote_blocks = self.remote_cache_engine.take( @@ -816,7 +774,7 @@ def _put_impl_global(self, else: self.remote_cache_engine.recycle(fragment3_remote_blocks) else: - fragment3_remote_blocks = torch.tensor([], dtype=torch.int64) + fragment3_remote_blocks = np.array([], dtype=np.int64) transfer_graph = TransferOpGraph() finished_ops_ids = [] @@ -824,14 +782,8 @@ def _put_impl_global(self, op_d2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.D2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment12_gpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment12_cpu_blocks, - ), + src_block_ids = fragment12_gpu_blocks, + dst_block_ids = fragment12_cpu_blocks, layer_id = 0, 
layer_granularity = layer_num ) @@ -843,14 +795,8 @@ def _put_impl_global(self, op_h2disk = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment2_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment2_ssd_blocks, - ), + src_block_ids = fragment2_cpu_blocks, + dst_block_ids = fragment2_ssd_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -861,21 +807,15 @@ def _put_impl_global(self, if put_to_remote: if fragment3_num_blocks > fragment12_num_blocks: extra_num_cpu_blocks = fragment3_num_blocks - fragment12_num_blocks - fragment3_cpu_blocks = torch.cat([fragment12_cpu_blocks, + fragment3_cpu_blocks = np.concatenate([fragment12_cpu_blocks, cpu_matched_blocks[-extra_num_cpu_blocks:]]) else: fragment3_cpu_blocks = fragment12_cpu_blocks[-fragment3_num_blocks:] op_h2remote = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2REMOTE, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment3_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.REMOTE, - physical_block_ids=fragment3_remote_blocks, - ), + src_block_ids = fragment3_cpu_blocks, + dst_block_ids = fragment3_remote_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -914,7 +854,7 @@ def _put_impl_local(self, sequence_meta: SequenceMeta, block_mask_start: int, block_mask_end: int, - gpu_block_mapping: torch.Tensor, + gpu_block_ids: np.ndarray, layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]: """ transfer pattern: @@ -940,14 +880,14 @@ def _put_impl_local(self, :ssd_matched_result.num_matched_blocks][block_mask_start:block_mask_end] num_skipped_blocks = len(cpu_matched_blocks) - fragment12_num_blocks = len(gpu_block_mapping) - num_skipped_blocks + fragment12_num_blocks = len(gpu_block_ids) - num_skipped_blocks if fragment12_num_blocks == 0: return self._empty_put_return(request_id) - fragment2_num_blocks = len(gpu_block_mapping) - len(ssd_matched_blocks) + fragment2_num_blocks = len(gpu_block_ids) - len(ssd_matched_blocks) if not self.cache_config.enable_ssd: fragment2_num_blocks = 0 - fragment12_gpu_blocks = gpu_block_mapping[num_skipped_blocks:] + fragment12_gpu_blocks = gpu_block_ids[num_skipped_blocks:] fragment12_cpu_blocks = self.cpu_cache_engine.take( num_required_blocks=fragment12_num_blocks, @@ -961,7 +901,7 @@ def _put_impl_local(self, strict=False ) else: - fragment2_ssd_blocks = torch.tensor([], dtype=torch.int64) + fragment2_ssd_blocks = np.array([], dtype=np.int64) if len(fragment12_cpu_blocks) < fragment12_num_blocks or \ len(fragment2_ssd_blocks) < fragment2_num_blocks: self.cpu_cache_engine.recycle(fragment12_cpu_blocks) @@ -975,14 +915,8 @@ def _put_impl_local(self, op_d2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.D2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment12_gpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment12_cpu_blocks, - ), + src_block_ids = fragment12_gpu_blocks, + dst_block_ids = fragment12_cpu_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -994,14 +928,8 @@ def _put_impl_local(self, op_h2disk = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2DISK, - src_descriptor = 
TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment2_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment2_ssd_blocks, - ), + src_block_ids = fragment2_cpu_blocks, + dst_block_ids = fragment2_ssd_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -1031,7 +959,7 @@ def _put_impl_local(self, def _transfer_callback(self, node_to_unlock: Dict[DeviceType, Tuple[RadixNode, int]], - buffer_to_free: Optional[Dict[DeviceType, torch.Tensor]] = None) -> None: + buffer_to_free: Optional[Dict[DeviceType, np.ndarray]] = None) -> None: if DeviceType.CPU in node_to_unlock: assert self.cpu_cache_engine is not None self.cpu_cache_engine.cleanup(node_to_unlock[DeviceType.CPU][0], node_to_unlock[DeviceType.CPU][1]) @@ -1072,7 +1000,8 @@ def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchRe return cpu_matched_result, ssd_matched_result - def match_all_accel(self, sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel, MatchResultAccel]: + def match_all_accel(self, + sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel, MatchResultAccel]: cpu_matched_result = MatchResultAccel() ssd_matched_result = MatchResultAccel() remote_matched_result = MatchResultAccel() @@ -1101,25 +1030,29 @@ def match_all(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResu return cpu_matched_result, ssd_matched_result, remote_matched_result def _check_input(self, - token_ids: torch.Tensor, - token_mask: torch.Tensor, - slot_mapping: torch.Tensor) -> None: + token_ids: np.ndarray, + token_mask: np.ndarray, + slot_mapping: np.ndarray) -> None: + assert token_ids.dtype == np.int64 + # assert token_mask.dtype == np.bool_, f"token_mask.dtype={token_mask.dtype}" + assert slot_mapping.dtype == np.int64 assert token_ids.ndim == 1 assert token_mask.ndim == 1 assert slot_mapping.ndim == 1 - assert len(token_ids) == len(token_mask), f"len(token_ids)={len(token_ids)}, len(token_mask)={len(token_mask)}" - assert len(slot_mapping) == token_mask.sum().item(), f"len(slot_mapping)={len(slot_mapping)}, token_mask.sum().item()={token_mask.sum().item()}" + assert token_ids.size == token_mask.size, f"token_ids.size={token_ids.size}, token_mask.size={token_mask.size}" + assert slot_mapping.size == token_mask.sum(), \ + f"slot_mapping.size={slot_mapping.size}, token_mask.sum()={token_mask.sum()}" - def _slot_to_block_mapping(self, - slot_mapping: torch.Tensor) -> torch.Tensor: - block_mapping: torch.Tensor = slot_mapping[::self.tokens_per_block] // self.tokens_per_block - return block_mapping + @staticmethod + def slot_mapping_to_block_ids(slot_mapping: np.ndarray, tokens_per_block: int) -> np.ndarray: + block_ids: np.ndarray = slot_mapping[::tokens_per_block] // tokens_per_block + return block_ids def _get_block_range(self, - token_mask: torch.Tensor) -> Tuple[int, int]: - mask_idx = torch.where(token_mask)[0] + token_mask: np.ndarray) -> Tuple[int, int]: + mask_idx = np.where(token_mask)[0] if len(mask_idx) == 0: return 0, 0 - start_idx = int(mask_idx[0].item() // self.tokens_per_block) - end_idx = int(mask_idx[-1].item() // self.tokens_per_block) + start_idx = mask_idx[0].item() // self.tokens_per_block + end_idx = mask_idx[-1].item() // self.tokens_per_block return start_idx, end_idx + 1 diff --git a/flexkv/cache/mempool.py b/flexkv/cache/mempool.py index d0e6848a30..1a99a5cff6 100644 --- a/flexkv/cache/mempool.py +++ b/flexkv/cache/mempool.py @@ -1,7 +1,7 @@ from collections 
import deque from typing import List -import torch +import numpy as np from flexkv.common.exceptions import NotEnoughSpaceError @@ -14,18 +14,18 @@ def __init__( assert num_total_blocks > 0 self.num_total_blocks = num_total_blocks - self._free_mask = torch.ones(self.num_total_blocks, dtype=torch.bool) + self._free_mask = np.ones(self.num_total_blocks, dtype=np.bool_) self._num_free = num_total_blocks - self._free_ids = self._free_mask.nonzero().squeeze(1) + self._free_ids = self._free_mask.nonzero()[0] self._free_ids_offset = 0 def reset(self) -> None: self._free_mask.fill_(True) self._num_free = self.num_total_blocks - self._free_ids = self._free_mask.nonzero().squeeze(1) + self._free_ids = self._free_mask.nonzero()[0] self._free_ids_offset = 0 - def allocate_blocks(self, num: int) -> torch.Tensor: + def allocate_blocks(self, num: int) -> np.ndarray: if num < 0: raise ValueError(f"num must be greater than 0, but got {num}") if num > self._num_free: @@ -41,8 +41,8 @@ def allocate_blocks(self, num: int) -> torch.Tensor: self._num_free -= num return free_ids - def recycle_blocks(self, block_ids: torch.Tensor) -> None: - if block_ids.ndim != 1 or block_ids.dtype != torch.int64: + def recycle_blocks(self, block_ids: np.ndarray) -> None: + if block_ids.ndim != 1 or block_ids.dtype != np.int64: raise ValueError("block_ids must be a 1D tensor of int64") if self._free_mask[block_ids].any(): free_ids = block_ids[self._free_mask[block_ids]] @@ -51,7 +51,7 @@ def recycle_blocks(self, block_ids: torch.Tensor) -> None: self._num_free += len(block_ids) def _update_free_ids(self) -> None: - self._free_ids = self._free_mask.nonzero().squeeze(1) + self._free_ids = self._free_mask.nonzero()[0] self._free_ids_offset = 0 @property diff --git a/flexkv/cache/radixtree.py b/flexkv/cache/radixtree.py index 818e778a63..a9f69c6a35 100644 --- a/flexkv/cache/radixtree.py +++ b/flexkv/cache/radixtree.py @@ -32,11 +32,14 @@ class MatchResult: last_ready_node: Optional['RadixNode'] = None last_node: Optional['RadixNode'] = None last_node_matched_length: int = 0 - physical_blocks: torch.Tensor = torch.empty(0, dtype=torch.int64) + physical_blocks: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) def __post_init__(self) -> None: assert self.physical_blocks.ndim == 1 - assert self.physical_blocks.dtype == torch.int64 + assert self.physical_blocks.dtype == np.int64 + + def is_empty(self) -> bool: + return self.num_matched_blocks == 0 @dataclass class RadixNode: @@ -192,7 +195,7 @@ def match_prefix(self, last_ready_node=last_ready_node, last_node=current_node, last_node_matched_length=last_node_matched_length, - physical_blocks=torch.from_numpy(physical_blocks).to(torch.int64)) + physical_blocks=physical_blocks) def num_matched_blocks(self, sequence: SequenceMeta) -> int: @@ -201,7 +204,7 @@ def num_matched_blocks(self, def insert(self, sequence_meta: SequenceMeta, - physical_block_ids: torch.Tensor, + physical_block_ids: np.ndarray, num_insert_blocks: int = -1, is_ready: bool = True, match_result: Optional[MatchResult] = None) -> Optional[RadixNode]: @@ -210,7 +213,7 @@ def insert(self, assert 0 <= num_insert_blocks <= sequence_meta.num_blocks assert physical_block_ids.ndim == 1 - assert physical_block_ids.dtype == torch.int64 + assert physical_block_ids.dtype == np.int64 sequence_meta.gen_hashes() if match_result is None: @@ -232,7 +235,7 @@ def insert(self, new_node = RadixNode( block_hashes=sequence_meta.block_hashes[num_matched_blocks:num_insert_blocks], - 
physical_blocks=physical_block_ids.numpy(), + physical_blocks=physical_block_ids, is_ready=is_ready, lock_cnt=0, last_access_time=time.time() @@ -255,7 +258,7 @@ def insert(self, return new_node - def evict(self, num_evicted: int) -> torch.Tensor: + def evict(self, num_evicted: int) -> np.ndarray: candidates = [] for node in self.leaf_nodes.values(): if node.evictable(): @@ -277,7 +280,7 @@ def evict(self, num_evicted: int) -> torch.Tensor: physical_blocks = node.physical_blocks node.parent = None evicted_blocks = np.concatenate([evicted_blocks, physical_blocks]) - return torch.from_numpy(evicted_blocks).to(torch.int64) + return evicted_blocks def lock(self, node: RadixNode) -> None: if node.lock_cnt < 0: @@ -340,22 +343,22 @@ def total_unready_blocks(self) -> int: index = RadixTreeIndex(tokens_per_block=tokens_per_block) print(f"init index, tokens_per_block = {tokens_per_block}") - token_ids1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) - token_ids2 = torch.tensor([1, 2, 3, 4, 15, 16, 17, 18]) + token_ids1 = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int64) + token_ids2 = np.array([1, 2, 3, 4, 15, 16, 17, 18], dtype=np.int64) seq1 = SequenceMeta(token_ids=token_ids1, tokens_per_block=tokens_per_block) seq2 = SequenceMeta(token_ids=token_ids2, tokens_per_block=tokens_per_block) - index.insert(seq1, torch.tensor([0, 1, 2, 3], dtype=torch.int64), is_ready=True) + index.insert(seq1, np.array([0, 1, 2, 3], dtype=np.int64), is_ready=True) print(f"insert seq1 = {seq1.token_ids}, " f"total cached blocks = {index.total_cached_blocks()}") seq2_matched_blocks = index.num_matched_blocks(seq2) assert seq2_matched_blocks == 2 - index.insert(seq2, torch.tensor([8, 9], dtype=torch.int64), is_ready=True) + index.insert(seq2, np.array([8, 9], dtype=np.int64), is_ready=True) print(f"insert seq2 = {seq2.token_ids}, " f"total cached blocks = {index.total_cached_blocks()}") - seq3 = SequenceMeta(token_ids=torch.tensor([1,2,3,4,0,0]), + seq3 = SequenceMeta(token_ids=np.array([1,2,3,4,0,0], dtype=np.int64), tokens_per_block=tokens_per_block) match_result = index.num_matched_blocks(seq3) print(f"match {seq3.token_ids}, num cached blocks: {match_result}") diff --git a/flexkv/cache/transfer_pattern.py b/flexkv/cache/transfer_pattern.py index 1246d248e8..0a434bde04 100644 --- a/flexkv/cache/transfer_pattern.py +++ b/flexkv/cache/transfer_pattern.py @@ -1,28 +1,33 @@ from typing import List, Optional, Tuple +import numpy as np import torch -from flexkv.common.transfer import DeviceType, TransferType -from flexkv.common.transfer import TransferOp, TransferOpGraph, TransferDescriptor +from flexkv.common.transfer import TransferType +from flexkv.common.transfer import TransferOp, TransferOpGraph def add_virtal_op_for_mutiple_finished_ops( graph: TransferOpGraph, finished_ops_ids: List[int] -)->Tuple[TransferOpGraph, List[int]]: - if len(finished_ops_ids) <= 1: - return graph, finished_ops_ids +)->Tuple[TransferOpGraph, int]: + if len(finished_ops_ids) == 0: + return graph, -1 + elif len(finished_ops_ids) == 1: + return graph, finished_ops_ids[0] else: op = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.VIRTUAL, + src_block_ids = np.array([], dtype=np.int64), + dst_block_ids = np.array([], dtype=np.int64), layer_id = -1, layer_granularity = -1, ) graph.add_transfer_op(op) for op_id in finished_ops_ids: graph.add_dependency(op.op_id, op_id) - return graph, [op.op_id] + return graph, op.op_id def create_read_graph_cpu_storage( gpu_blocks: torch.Tensor, @@ -49,15 +54,8 @@ def create_read_graph_cpu_storage( 
op = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks, - device_id = gpu_device_id - ), + src_block_ids = cpu_blocks, + dst_block_ids = gpu_blocks, layer_id = 0, layer_granularity = layer_num, ) @@ -69,14 +67,8 @@ def create_read_graph_cpu_storage( op1 = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.DISK2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=ssd_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(ssd_blocks):] - ), + src_block_ids = ssd_blocks, + dst_block_ids = cpu_blocks[-len(ssd_blocks):], layer_id = 0, layer_granularity = layer_num ) @@ -84,15 +76,8 @@ def create_read_graph_cpu_storage( op2 = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(ssd_blocks):] - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks[-len(ssd_blocks):], - device_id = gpu_device_id - ), + src_block_ids = cpu_blocks[-len(ssd_blocks):], + dst_block_ids = gpu_blocks[-len(ssd_blocks):], layer_id = 0, layer_granularity = layer_num ) @@ -102,15 +87,8 @@ def create_read_graph_cpu_storage( op3 = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[:len(cpu_blocks) - len(ssd_blocks)] - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks[:len(cpu_blocks) - len(ssd_blocks)], - device_id = gpu_device_id - ), + src_block_ids = cpu_blocks[:len(cpu_blocks) - len(ssd_blocks)], + dst_block_ids = gpu_blocks[:len(cpu_blocks) - len(ssd_blocks)], layer_id = 0, layer_granularity = layer_num ) @@ -121,14 +99,8 @@ def create_read_graph_cpu_storage( op1 = TransferOp( graph_id = graph.graph_id, transfer_type=TransferType.DISK2H, - src_descriptor=TransferDescriptor( - device_type=DeviceType.SSD, - physical_block_ids=ssd_blocks, - ), - dst_descriptor=TransferDescriptor( - device_type=DeviceType.CPU, - physical_block_ids=cpu_blocks, - ), + src_block_ids=ssd_blocks, + dst_block_ids=cpu_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -136,14 +108,8 @@ def create_read_graph_cpu_storage( op2 = TransferOp( graph_id = graph.graph_id, transfer_type=TransferType.H2D, - src_descriptor=TransferDescriptor( - device_type=DeviceType.CPU, - physical_block_ids=cpu_blocks, - ), - dst_descriptor=TransferDescriptor( - device_type=DeviceType.GPU, - physical_block_ids=gpu_blocks, - ), + src_block_ids=cpu_blocks, + dst_block_ids=gpu_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -191,14 +157,8 @@ def create_read_graph_cpu_ssd_remote( op_r2h = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.REMOTE2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.REMOTE, - physical_block_ids=remote_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(remote_blocks):], - ), + src_block_ids = remote_blocks, + dst_block_ids = cpu_blocks[-len(remote_blocks):], layer_id = 0, layer_granularity = layer_num ) @@ -206,15 +166,8 @@ 
def create_read_graph_cpu_ssd_remote( op_h2d = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(remote_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks[-len(remote_blocks):], - device_id = gpu_device_id - ), + src_block_ids = cpu_blocks[-len(remote_blocks):], + dst_block_ids = gpu_blocks[-len(remote_blocks):], layer_id = 0, layer_granularity = layer_num ) @@ -224,14 +177,8 @@ def create_read_graph_cpu_ssd_remote( op_h2disk = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(remote_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=ssd_blocks[-len(remote_blocks):], - ), + src_block_ids = cpu_blocks[-len(remote_blocks):], + dst_block_ids = ssd_blocks[-len(remote_blocks):], layer_id = 0, layer_granularity = layer_num ) @@ -265,15 +212,8 @@ def create_write_graph_cpu_storage( op_d2h = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.D2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks, - device_id = gpu_device_id - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(gpu_blocks):], - ), + src_block_ids = gpu_blocks, + dst_block_ids = cpu_blocks[-len(gpu_blocks):], layer_id = 0, layer_granularity = layer_num ) @@ -284,14 +224,8 @@ def create_write_graph_cpu_storage( op_h2disk = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(ssd_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=ssd_blocks, - ), + src_block_ids = cpu_blocks[-len(ssd_blocks):], + dst_block_ids = ssd_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -318,15 +252,8 @@ def create_write_graph_cpu_ssd_remote( op_d2h = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.D2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks, - device_id = gpu_device_id - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(gpu_blocks):], - ), + src_block_ids = gpu_blocks, + dst_block_ids = cpu_blocks[-len(gpu_blocks):], layer_id = 0, layer_granularity = layer_num ) @@ -335,14 +262,8 @@ def create_write_graph_cpu_ssd_remote( op_h2disk = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(gpu_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=ssd_blocks, - ), + src_block_ids = cpu_blocks[-len(gpu_blocks):], + dst_block_ids = ssd_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -352,14 +273,8 @@ def create_write_graph_cpu_ssd_remote( op_h2remote = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.H2REMOTE, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(remote_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.REMOTE, 
- physical_block_ids=remote_blocks, - ), + src_block_ids = cpu_blocks[-len(remote_blocks):], + dst_block_ids = remote_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -394,8 +309,8 @@ def convert_read_graph_to_layer_wise_graph( new_op = TransferOp( graph_id=new_graph.graph_id, transfer_type=op.transfer_type, - src_descriptor=op.src_descriptor, - dst_descriptor=op.dst_descriptor, + src_block_ids=op.src_block_ids, + dst_block_ids=op.dst_block_ids, layer_id=i * layer_granularity, layer_granularity=layer_granularity, # Inherit these fields directly diff --git a/flexkv/common/block.py b/flexkv/common/block.py index d1ab24687a..7e2cc5d8b2 100644 --- a/flexkv/common/block.py +++ b/flexkv/common/block.py @@ -5,13 +5,13 @@ import numpy as np import torch -from flexkv.common.hash_utils import HashType, gen_hashes, get_hash_size, hash_tensor +from flexkv.common.hash_utils import HashType, gen_hashes, get_hash_size, hash_array @dataclass class SequenceMeta: - token_ids: torch.Tensor + token_ids: np.ndarray tokens_per_block: int @@ -19,7 +19,7 @@ class SequenceMeta: _has_hashes: bool = False - def __init__(self, token_ids: torch.Tensor, tokens_per_block: int): + def __init__(self, token_ids: np.ndarray, tokens_per_block: int): self.token_ids = token_ids self.tokens_per_block = tokens_per_block @@ -44,13 +44,13 @@ def get_hash(self, block_id: int) -> Optional[HashType]: if self._has_hashes: return HashType(int(self.block_hashes[block_id].item())) else: - return hash_tensor(self.token_ids[:(block_id+1)*self.tokens_per_block]) + return hash_array(self.token_ids[:(block_id+1)*self.tokens_per_block]) def gen_hashes(self) -> None: if self._has_hashes: return assert self.token_ids.ndim == 1 - self.block_hashes = gen_hashes(self.token_ids, self.tokens_per_block).numpy() + self.block_hashes = gen_hashes(self.token_ids, self.tokens_per_block) assert self.block_hashes.ndim == 1 assert self.block_hashes.size == self.num_blocks assert self.block_hashes.itemsize == get_hash_size() diff --git a/flexkv/common/hash_utils.py b/flexkv/common/hash_utils.py index 6f8ec9fc96..8acdc49aa6 100644 --- a/flexkv/common/hash_utils.py +++ b/flexkv/common/hash_utils.py @@ -1,6 +1,7 @@ import time from typing import NewType, Optional +import numpy as np import torch from flexkv import c_ext @@ -18,32 +19,32 @@ def __init__(self) -> None: def reset(self) -> None: self.hasher.reset() - def update(self, tensor: torch.Tensor) -> None: - self.hasher.update(tensor) + def update(self, array: np.ndarray) -> None: + self.hasher.update(array) def digest(self) -> HashType: return HashType(self.hasher.digest()) -def hash_tensor(tensor: torch.Tensor) -> HashType: +def hash_array(array: np.ndarray) -> HashType: hasher = Hasher() - hasher.update(tensor) + hasher.update(array) return HashType(hasher.digest()) -def gen_hashes(token_ids: torch.Tensor, tokens_per_block: int, hasher: Optional[Hasher] = None) -> torch.Tensor: - block_hashes = torch.zeros(token_ids.numel() // tokens_per_block, dtype=torch.uint64) +def gen_hashes(token_ids: np.ndarray, tokens_per_block: int, hasher: Optional[Hasher] = None) -> np.ndarray: + block_hashes = np.zeros(token_ids.size // tokens_per_block, dtype=np.uint64) if hasher is None: hasher = Hasher() - c_ext.gen_hashes(hasher.hasher, token_ids, tokens_per_block, block_hashes) + c_ext.gen_hashes(hasher.hasher, torch.from_numpy(token_ids), tokens_per_block, torch.from_numpy(block_hashes)) return block_hashes if __name__ == "__main__": - torch.manual_seed(0) - token_ids = torch.randint(0, 10000, (32000, ), 
dtype=torch.int64) + np.random.seed(0) + token_ids = np.random.randint(0, 10000, (32000, ), dtype=np.int64) print(f"token ids length: {token_ids.shape[0]}") start = time.time() - result = hash_tensor(token_ids) + result = hash_array(token_ids) end = time.time() - print(f"tensor hash: {result}, time: {end - start}s") + print(f"array hash: {result}, time: {end - start}s") start = time.time() result2 = gen_hashes(token_ids, 16) end = time.time() diff --git a/flexkv/common/request.py b/flexkv/common/request.py index 1c871ea821..ef1c6ca009 100644 --- a/flexkv/common/request.py +++ b/flexkv/common/request.py @@ -1,7 +1,9 @@ from dataclasses import dataclass from enum import Enum +from typing import Callable, List, Optional import torch +import numpy as np class KVRequestType(Enum): @@ -18,3 +20,21 @@ class KVRequest: slot_mapping: torch.Tensor layer_granularity: int = -1 dp_id: int = 0 + + +class KVResponseStatus(Enum): + SUCCESS = "success" + NOTFOUND = "not_found" + UNREADY = "unready" + TIMEOUT = "timeout" + CANCELLED = "cancelled" + FAILED = "failed" + +@dataclass +class KVResponse: + status: KVResponseStatus + task_id: int + return_mask: Optional[np.ndarray] + + def get_mask(self) -> torch.Tensor: + return torch.from_numpy(self.return_mask) if self.return_mask is not None else None diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 342f0678c0..ee64ce99f8 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -3,7 +3,7 @@ from enum import Enum from typing import ClassVar, List, Set, Dict -import torch +import numpy as np class DeviceType(Enum): @@ -31,16 +31,6 @@ class PartitionBlockType(Enum): ROUND_ROBIN = 0 SEQUENTIAL = 1 -@dataclass -class TransferDescriptor: - device_type: DeviceType = DeviceType.CPU - device_id: int = 0 - physical_block_ids: torch.Tensor = torch.tensor([], dtype=torch.int64) - - def __post_init__(self) -> None: - assert self.physical_block_ids.ndim == 1 - assert self.physical_block_ids.dtype == torch.int64 - class TransferOpStatus(Enum): PENDING = 0 RUNNING = 1 @@ -54,10 +44,10 @@ class TransferOp: op_id: int = field(init=False) graph_id: int transfer_type: TransferType - layer_id: int - layer_granularity: int - src_descriptor: TransferDescriptor = field(default_factory=TransferDescriptor) - dst_descriptor: TransferDescriptor = field(default_factory=TransferDescriptor) + src_block_ids: np.ndarray + dst_block_ids: np.ndarray + layer_id: int = 0 + layer_granularity: int = -1 # this will change dynamically as transfer ops executed predecessors: Set[int] = field(default_factory=set) # this will keep the full info @@ -67,8 +57,8 @@ class TransferOp: def __post_init__(self) -> None: if self.transfer_type != TransferType.VIRTUAL and \ - len(self.src_descriptor.physical_block_ids) != len(self.dst_descriptor.physical_block_ids): - raise ValueError("src_descriptor and dst_descriptor must have the same number of physical blocks") + self.src_block_ids.size != self.dst_block_ids.size: + raise ValueError("src_block_ids and dst_block_ids must have the same number of physical blocks") with TransferOp._lock: self.op_id = TransferOp._next_op_id TransferOp._next_op_id += 1 @@ -79,13 +69,14 @@ class TransferOpGraph: _lock = threading.Lock() def __init__(self) -> None: - self.graph_id = self._get_next_graph_id() + self.graph_id = self._get_graph_id() self._op_map: Dict[int, TransferOp] = {} self._ready_ops: Set[int] = set() self._trigger_ops: Set[int] = set() + self._gpu_transfer_op_id: int = -1 @classmethod - def _get_next_graph_id(cls) -> int: + 
def _get_graph_id(cls) -> int: with cls._lock: graph_id = cls._next_graph_id cls._next_graph_id += 1 @@ -115,6 +106,14 @@ def trigger_op(self, op_id: int) -> None: def add_transfer_op(self, op: TransferOp) -> None: op.graph_id = self.graph_id self._op_map[op.op_id] = op + if op.transfer_type == TransferType.H2D or \ + op.transfer_type == TransferType.D2H or \ + op.transfer_type == TransferType.D2DISK or \ + op.transfer_type == TransferType.DISK2D: + if self._gpu_transfer_op_id == -1: + self._gpu_transfer_op_id = op.op_id + else: + raise ValueError("Only one GPU transfer op is allowed") self._ready_ops.add(op.op_id) def add_dependency(self, successor_op_id: int, predecessor_op_id: int) -> None: @@ -130,8 +129,6 @@ def mark_completed(self, op_id: int) -> None: assert self._op_map[op_id].status == TransferOpStatus.RUNNING self._op_map[op_id].status = TransferOpStatus.COMPLETED my_successors = self._op_map[op_id].successors - if len(my_successors) == 0: - return for successor_id in my_successors: self._op_map[successor_id].predecessors.remove(op_id) @@ -164,6 +161,16 @@ def all_transfer_ops_completed(self) -> bool: return all(op.status == TransferOpStatus.COMPLETED for op in self._op_map.values()) + def set_gpu_blocks(self, gpu_blocks: np.ndarray) -> None: + transfer_type = self._op_map[self._gpu_transfer_op_id].transfer_type + op = self._op_map[self._gpu_transfer_op_id] + if transfer_type.name.endswith("2D"): + op.dst_block_ids = gpu_blocks + else: + op.src_block_ids = gpu_blocks + assert op.src_block_ids.size == op.dst_block_ids.size, \ + f"src_block_ids.size={op.src_block_ids.size}, dst_block_ids.size={op.dst_block_ids.size}" + @property def num_ops(self) -> int: return len(self._op_map) @@ -172,69 +179,6 @@ def bind_to_dp_group(self, dp_id: int) -> None: for op in self._op_map.values(): op.dp_id = dp_id - def print_op_map(self) -> None: - """Print transfer op graph in a visual format showing dependencies. - - Example output: - Transfer Graph 5: - ├── Op 1 (H2D) [Completed] - │ └── No successors - ├── Op 2 (D2H) [Pending] - │ └── Followed by: 1 - └── Op 3 (DISK2H) [Pending] - └── Followed by: 1, 2 - """ - print(f"Transfer Graph {self.graph_id}:") - - # get all op ids and sort them - op_ids = sorted(self._op_map.keys()) - - for i, op_id in enumerate(op_ids): - op = self._op_map[op_id] - is_last = (i == len(op_ids) - 1) - - # draw the tree structure branch - prefix = "└── " if is_last else "├── " - - # get the op status - status = "[Completed]" if op.status == TransferOpStatus.COMPLETED else "[Pending]" - - # print the op info - print(f"{prefix}Op {op_id} ({op.transfer_type.name}) {status}") - - if op.transfer_type == TransferType.VIRTUAL: - continue - # print the dependency info - dep_prefix = " " if is_last else "│ " - if not op.successors: - print(f"{dep_prefix}└── No successors") - else: - deps_str = ", ".join(str(dep) for dep in sorted(op.successors)) - print(f"{dep_prefix}└── Followed by: {deps_str}") - - # print the transfer details - src_info = f"From: {op.src_descriptor.device_type.name}:{op.src_descriptor.device_id}" - dst_info = f"To: {op.dst_descriptor.device_type.name}:{op.dst_descriptor.device_id}" - print(f"{dep_prefix} └── {src_info} -> {dst_info}") - - print(f"{dep_prefix} └── layers: {op.layer_id} - {op.layer_id + op.layer_granularity}") - - # if there are physical block ids, also print them - if len(op.src_descriptor.physical_block_ids) > 0: - blocks = op.src_descriptor.physical_block_ids.tolist() - if len(blocks) > 3: - blocks_str = f"{blocks[:3]}... 
({len(blocks)} blocks)" - else: - blocks_str = str(blocks) - print(f"{dep_prefix} └── Src Blocks: {blocks_str}") - if len(op.dst_descriptor.physical_block_ids) > 0: - blocks = op.dst_descriptor.physical_block_ids.tolist() - if len(blocks) > 3: - blocks_str = f"{blocks[:3]}... ({len(blocks)} blocks)" - else: - blocks_str = str(blocks) - print(f"{dep_prefix} └── Dst Blocks: {blocks_str}") - def get_nvtx_default_color() -> int: return 0xD3D3D3 diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index fd75750530..8724ce6829 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -12,566 +12,192 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import multiprocessing as mp -import threading + +from typing import Optional, Tuple, List, Dict, Union, Iterable import time -from dataclasses import dataclass -from queue import Queue -from typing import Dict, Any, Optional -from typing import List, Callable, Union -import nvtx +import numpy as np import torch -from expiring_dict import ExpiringDict -from flexkv.cache.cache_engine import GlobalCacheEngine, TransferOpGraph -from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.server.client import KVDPClient +from flexkv.server.server import KVServer, DPClient +from flexkv.kvtask import KVTaskEngine, KVResponse +from flexkv.common.config import ModelConfig, CacheConfig from flexkv.common.debug import flexkv_logger -from flexkv.common.memory_handle import TensorSharedHandle -from flexkv.common.request import KVRequestType, KVRequest -from flexkv.common.transfer import DeviceType, get_nvtx_range_color, get_nvtx_default_color -from flexkv.common.storage import KVCacheLayout -from flexkv.common.exceptions import LogicError -from flexkv.common.tracer import FlexKVTracer -from flexkv.storage.storage_engine import StorageEngine -from flexkv.transfer.transfer_engine import TransferEngine - - -@dataclass -class RequestTracker: - task_id: int - task_type: KVRequestType - return_mask: torch.Tensor - callback: Optional[Callable] - task_end_ops_ids: List[int] - task_end_ops_status: List[bool] - task_finished: bool = False class KVManager: def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, - gpu_layout: Optional[KVCacheLayout] = None, - gpu_blocks: Optional[Dict[int, List[TensorSharedHandle]]] = None): - - flexkv_logger.info(f"Initializing kvmanager...\nmodel_config: {model_config}\ncache_config: {cache_config}") - - mp.set_start_method('spawn', force=True) - self.init_nvtx_range = nvtx.push_range("Initialize kvmanager", color=get_nvtx_default_color()) - - if not cache_config.enable_cpu: - raise ValueError("enable_cpu must be True") - if cache_config.enable_remote and not cache_config.enable_ssd: - raise ValueError("enable_ssd must be True if enable_remote is True") - if not cache_config.enable_cpu and not cache_config.use_gds: - raise ValueError("use_gds must be True if enable_cpu is False") - self.cache_config = cache_config + gpu_register_port: Optional[str] = None, + server_recv_port: Optional[str] = None): + flexkv_logger.info(f"{model_config = }") + flexkv_logger.info(f"{cache_config = }") self.model_config = model_config + self.cache_config = cache_config + self.gpu_register_port = gpu_register_port + self.server_recv_port = server_recv_port + self.server_client_mode = model_config.dp_size > 1 # True #just for test + flexkv_logger.info(f"server_client_mode: 
{self.server_client_mode}") + if self.server_client_mode: + self.server_handle = KVServer.create_server(model_config, cache_config, gpu_register_port, server_recv_port) + self.dp_client = KVDPClient(self.server_recv_port, self.model_config) + else: + self.server_handle = None + self.kv_task_engine = KVTaskEngine(model_config, cache_config, gpu_register_port) - self._verify_Model_Cache_config(model_config, cache_config) - self.cache_engine = GlobalCacheEngine(cache_config, model_config) - self.storage_engine = StorageEngine(self.model_config, self.cache_config) - - # Initialize tracer - self.tracer = FlexKVTracer(cache_config) - - # Record configuration in tracer - if gpu_layout is not None: - self.tracer.trace_config(model_config, cache_config, gpu_layout) - - - self.transfer_engine: Optional[TransferEngine] = None - self.gpu_layout: Optional[KVCacheLayout] = gpu_layout - - self.running = False - self.requests_tracker: ExpiringDict[int, RequestTracker] = ExpiringDict(1800) # 30 minutes - self.graph_to_request: Dict[int, int] = {} - self.taskid_to_nvtx_range: Dict[int, Any] = {} - self.graphid_to_nvtx_range: Dict[int, Any] = {} - - self._task_id_counter = 0 - self.task_queue: Queue[KVRequest] = Queue() - - if gpu_blocks is None: - gpu_blocks = {} - - self.num_gpus = len(gpu_blocks) - self.all_gpu_blocks: Dict[int, List[TensorSharedHandle]] = gpu_blocks - - self.lock = threading.Lock() - - if self.num_gpus == self.model_config.tp_size * self.model_config.dp_size: - self._init_transfer_engine() - - # Note that for now only after all the gpu blocks are added, we can initialize the transfer engine - def _init_transfer_engine(self) -> None: - assert self.gpu_layout is not None - assert len(self.all_gpu_blocks) == self.model_config.tp_size * self.model_config.dp_size - for device_id, gpu_blocks_wrapper in self.all_gpu_blocks.items(): - self.storage_engine.register_gpu_blocks(gpu_blocks_wrapper, - self.gpu_layout, - device_id, - dtype=self.model_config.dtype) - self.gpu_handles = [ - self.storage_engine.get_storage_handle(DeviceType.GPU, i) - for i in range(self.model_config.tp_size * self.model_config.dp_size) - ] - cpu_handle = self.storage_engine.get_storage_handle(DeviceType.CPU) if self.cache_config.enable_cpu else None - ssd_handle = self.storage_engine.get_storage_handle(DeviceType.SSD) if self.cache_config.enable_ssd else None - remote_handle = ( - self.storage_engine.get_storage_handle(DeviceType.REMOTE) - if self.cache_config.enable_remote - else None - ) - self.transfer_engine = TransferEngine(self.gpu_handles, - self.model_config, - self.cache_config, - cpu_handle, - ssd_handle, - remote_handle) - - nvtx.pop_range(self.init_nvtx_range) - - - def is_ready(self) -> bool: - return self.transfer_engine is not None - - def is_running(self) -> bool: - return self.running + #def _launch_server(self) -> None: + # self.server = KVServer(self.model_config, self.cache_config, self.server_recv_port) + # self.server.run() + # time.sleep(10) + # self.dp_client = DPClient(self.server_recv_port, self.model_config) def start(self) -> None: - if self.running: - flexkv_logger.warning("kvmanager is already running") - return - if not self.is_ready(): - raise ValueError("transfer engine is not ready, please add all gpu blocks first") - if self.transfer_engine is not None: - self.transfer_engine.start() - self.running = True - else: - raise ValueError("transfer engine is not initialized, please call start() after all gpu blocks are added") - - self._worker_thread = threading.Thread(target=self._worker_loop) 
- self._worker_thread.start() - flexkv_logger.info("KVManager fully started and running") + if not self.server_client_mode: + self.kv_task_engine.start() + # for server client mode, we need to do nothing, because the start is actually called + # when the server is created - # the gpu_blocks of multiple gpus can be added post initialization. - # the transfer engine will be initialized after we have all the intended gpu handles. - def register_single_gpu_blocks( - self, - gpu_handles: List[TensorSharedHandle], - gpu_layout: KVCacheLayout, - dp_client_id: int = 0, - tp_rank: int = 0, - ) -> None: - if self.transfer_engine is not None: - raise ValueError("we have already get all gpu blocks") - if self.gpu_layout is None: - self.gpu_layout = gpu_layout - self.tracer.trace_config(self.model_config, self.cache_config, self.gpu_layout) + def is_ready(self) -> bool: + if self.server_client_mode: + return self.server_handle is not None and self.server_handle.ready_event.is_set() else: - assert self.gpu_layout == gpu_layout - self.all_gpu_blocks[tp_rank + dp_client_id * self.model_config.tp_size] = gpu_handles - self.num_gpus += 1 - if self.num_gpus == self.model_config.tp_size * self.model_config.dp_size: - self._init_transfer_engine() - - def _worker_loop(self) -> None: - assert self.transfer_engine is not None - while self.running: - # deal with completed requests from the cache engine - if not self.task_queue.empty(): - request = self.task_queue.get() - if request.request_type == KVRequestType.SHUTDOWN: - self.shutdown() - break - elif request.request_type == KVRequestType.GET: - nvtx.push_range(f"cache_engine.get request_id: {request.request_id}", - color=get_nvtx_default_color()) - graph, return_mask, callback, task_end_ops_ids = self.cache_engine.get(request.request_id, - request.token_ids, - request.token_mask, - request.slot_mapping, - self.model_config.num_layers, - request.layer_granularity, - request.dp_id) - elif request.request_type == KVRequestType.PUT: - nvtx.push_range(f"cache_engine.put request_id: {request.request_id}", - color=get_nvtx_default_color()) - graph, return_mask, callback, task_end_ops_ids = self.cache_engine.put(request.request_id, - request.token_ids, - request.token_mask, - request.slot_mapping, - self.model_config.num_layers, - request.dp_id) - else: - raise ValueError(f"Unknown request type: {request.request_type}") - nvtx.pop_range() - if graph.num_ops == 0: #early return - flexkv_logger.info(f"no transfer: " - f"request_id = {request.request_id}, request_type = {request.request_type}") - layer_op_num = self.model_config.num_layers // request.layer_granularity \ - if request.request_type == KVRequestType.GET else 1 - self.requests_tracker[request.request_id] = RequestTracker(task_id=request.request_id, - task_type=request.request_type, - return_mask=return_mask, - callback=None, - task_end_ops_ids=[-1]*layer_op_num, - task_end_ops_status=[True]*layer_op_num, - task_finished=True) - else: - self.graph_to_request[graph.graph_id] = request.request_id - self.graphid_to_nvtx_range[graph.graph_id] = nvtx.start_range( - f"request id: {request.request_id}, " - f"graph id: {graph.graph_id}", - color=get_nvtx_range_color(graph.graph_id)) - self.requests_tracker[request.request_id] = RequestTracker(task_id=request.request_id, - task_type=request.request_type, - return_mask=return_mask, - callback=callback, - task_end_ops_ids=task_end_ops_ids, - task_end_ops_status=len(task_end_ops_ids)*[False], - task_finished=False) - self.transfer_engine.submit_transfer_graph(graph) - 
results = self.transfer_engine.get_completed_graphs_and_ops(timeout=0.001) - for completed_graph_id, completed_op_id in results: - request_id = self.graph_to_request[completed_graph_id] - request_tracker = self.requests_tracker[request_id] - if completed_op_id == -1: - if request_tracker.callback: - request_tracker.callback() - nvtx.end_range(self.graphid_to_nvtx_range[completed_graph_id]) - self.graphid_to_nvtx_range.pop(completed_graph_id) - self.graph_to_request.pop(completed_graph_id) - nvtx.end_range(self.taskid_to_nvtx_range[request_tracker.task_id]) - self.taskid_to_nvtx_range.pop(request_tracker.task_id) - request_tracker.task_finished = True - elif completed_op_id in request_tracker.task_end_ops_ids: - request_tracker.task_end_ops_status[request_tracker.task_end_ops_ids.index(completed_op_id)] = True - self.requests_tracker[request_id] = request_tracker - time.sleep(0.0001) - - def _get_task_id(self) -> int: - with self.lock: - old_value = self._task_id_counter - self._task_id_counter += 1 - return old_value - - def __del__(self) -> None: - if hasattr(self, 'tracer'): - self.tracer.flush() - if self.running: - self.shutdown() + return self.kv_task_engine.is_ready() def shutdown(self) -> None: - self.running = False - # Flush tracer before shutdown - if hasattr(self, 'tracer'): - self.tracer.flush() - flexkv_logger.info("kvmanager shutdown") - self.task_queue.put(KVRequest( - request_type=KVRequestType.SHUTDOWN, - request_id=-1, - token_ids=torch.empty(0), - token_mask=torch.empty(0), - slot_mapping=torch.empty(0), - )) - self._worker_thread.join() - if self.transfer_engine is not None: - self.transfer_engine.shutdown() + if self.server_client_mode: + if self.server_handle is not None: + self.server_handle.shutdown() + else: + flexkv_logger.error("Shutdown server failed, server is not created") + else: + self.kv_task_engine.shutdown() def get_async(self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor] = None, + token_ids: Union[torch.Tensor, np.ndarray], + slot_mapping: Union[torch.Tensor, np.ndarray], + token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, layer_granularity: int = -1, dp_id: int = 0, - task_id: int = -1) -> int: - if not self.running: - raise ValueError("kvmanager is not running, please call start() first") - if token_mask is None: - token_mask = torch.ones_like(token_ids) - if layer_granularity == -1: - layer_granularity = self.model_config.num_layers - if task_id == -1: - task_id = self._get_task_id() - # Trace the request - self.tracer.trace_request( - request_type="GET", - request_id=task_id, - token_ids=token_ids, - slot_mapping=slot_mapping, - token_mask=token_mask, - layer_granularity=layer_granularity, - dp_id=dp_id - ) - nvtx.mark(f"GET request_id: {task_id}") - self.taskid_to_nvtx_range[task_id] = nvtx.start_range(f"GET request_id: {task_id}", - color=get_nvtx_default_color()) - self.task_queue.put(KVRequest( - request_type=KVRequestType.GET, - request_id=task_id, - token_ids=token_ids, - token_mask=token_mask, - slot_mapping=slot_mapping, - layer_granularity=layer_granularity, - dp_id=dp_id, - )) - self.requests_tracker[task_id] = RequestTracker(task_id=task_id, - task_type=KVRequestType.GET, - return_mask=torch.empty(0), - callback=None, - task_end_ops_ids=[], - task_end_ops_status=[], - task_finished=False) + ) -> int: + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.numpy() + if isinstance(slot_mapping, torch.Tensor): + slot_mapping = slot_mapping.numpy() + if 
isinstance(token_mask, torch.Tensor): + token_mask = token_mask.numpy() + if self.server_client_mode: + task_id = self.dp_client.get_async(token_ids, + slot_mapping, + token_mask, + layer_granularity, + dp_id) + else: + task_id, _ = self.kv_task_engine.get_async(token_ids, + slot_mapping, + token_mask, + layer_granularity, + dp_id) return task_id + def get_match(self, + token_ids: Union[torch.Tensor, np.ndarray], + token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, + layer_granularity: int = -1, + dp_id: int = 0, + ) -> Tuple[int, np.ndarray]: + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.numpy() + if isinstance(token_mask, torch.Tensor): + token_mask = token_mask.numpy() + if self.server_client_mode: + task_id, mask = self.dp_client.get_match(token_ids, + token_mask, + layer_granularity, + dp_id) + else: + task_id, mask = self.kv_task_engine.get_match(token_ids, + token_mask, + layer_granularity, + dp_id) + mask = torch.from_numpy(mask) if mask is not None else None + return task_id, mask + def put_async(self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor] = None, + token_ids: Union[torch.Tensor, np.ndarray], + slot_mapping: Union[torch.Tensor, np.ndarray], + token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, dp_id: int = 0, - task_id: int = -1) -> int: - if not self.running: - raise ValueError("kvmanager is not running, please call start() first") - if token_mask is None: - token_mask = torch.ones_like(token_ids) - if task_id == -1: - task_id = self._get_task_id() - # Trace the request - self.tracer.trace_request( - request_type="PUT", - request_id=task_id, - token_ids=token_ids, - slot_mapping=slot_mapping, - token_mask=token_mask, - dp_id=dp_id - ) - nvtx.mark(f"PUT request_id: {task_id}") - self.taskid_to_nvtx_range[task_id] = nvtx.start_range(f"PUT request_id: {task_id}", - color=get_nvtx_default_color()) - self.task_queue.put(KVRequest( - request_type=KVRequestType.PUT, - request_id=task_id, - token_ids=token_ids, - token_mask=token_mask, - slot_mapping=slot_mapping, - dp_id=dp_id, - )) - self.requests_tracker[task_id] = RequestTracker(task_id=task_id, - task_type=KVRequestType.PUT, - return_mask=torch.empty(0), - callback=None, - task_end_ops_ids=[], - task_end_ops_status=[], - task_finished=False) + ) -> int: + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.numpy() + if isinstance(slot_mapping, torch.Tensor): + slot_mapping = slot_mapping.numpy() + if isinstance(token_mask, torch.Tensor): + token_mask = token_mask.numpy() + if self.server_client_mode: + task_id = self.dp_client.put_async(token_ids, slot_mapping, token_mask, dp_id) + else: + task_id, _ = self.kv_task_engine.put_async(token_ids, slot_mapping, token_mask, dp_id) return task_id - # wait for the key op to be finished - def wait(self, task_ids: Union[int, List[int]], timeout: float = 20.0) -> Dict[int, torch.Tensor]: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="wait", - task_ids=task_ids, - ) - nvtx.mark(f"wait task_ids: {task_ids}") + def put_match(self, + token_ids: Union[torch.Tensor, np.ndarray], + token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, + dp_id: int = 0, + ) -> Tuple[int, np.ndarray]: + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.numpy() + if isinstance(token_mask, torch.Tensor): + token_mask = token_mask.numpy() + if self.server_client_mode: + task_id, mask = self.dp_client.put_match(token_ids, token_mask, dp_id) + else: + task_id, mask = 
self.kv_task_engine.put_match(token_ids, token_mask, dp_id) + mask = torch.from_numpy(mask) if mask is not None else None + return task_id, mask + + def launch(self, + task_ids: Union[int, List[int]], + slot_mappings: Union[np.ndarray, List[np.ndarray]]) -> None: if isinstance(task_ids, int): task_ids = [task_ids] - num_completed_tasks = 0 - num_tasks = len(task_ids) - return_masks = {} - start_time = time.time() - while num_completed_tasks < num_tasks: - finished_task_ids = [] + if not isinstance(slot_mappings, List): + slot_mappings = [slot_mappings] + if isinstance(slot_mappings[0], torch.Tensor): + slot_mappings = [slot_mapping.numpy() for slot_mapping in slot_mappings] + if self.server_client_mode: for task_id in task_ids: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return_masks[task_id] = torch.empty(0, dtype=torch.bool) #if not found in tracker, the return mask is an empty tensor - num_completed_tasks += 1 - finished_task_ids.append(task_id) - continue - task_tracker = self.requests_tracker[task_id] - if len(task_tracker.return_mask) == 0: #NOT READY - continue - if all(task_tracker.task_end_ops_status): - num_completed_tasks += 1 - return_masks[task_id] = task_tracker.return_mask - finished_task_ids.append(task_id) - task_ids = [task_id for task_id in task_ids if task_id not in finished_task_ids] - if time.time() - start_time > timeout: - flexkv_logger.warning(f"wait task_ids: {task_ids} timeout, has to return now") - for task_id in task_ids: - return_masks[task_id] = torch.empty(0, dtype=torch.bool) # return mask of timeout task is also an empty tensor - nvtx.mark(f"wait task_ids: {task_ids} timeout") - return return_masks - time.sleep(0.0001) - nvtx.mark(f"wait task_ids: {task_ids} done") - return return_masks + self.dp_client.launch_task(task_id, slot_mappings) + else: + self.kv_task_engine.launch_transfer(task_ids, slot_mappings) - # wait for the whole task to be finished, including the key op and all other ops - # this function is mainly designed for testing to avoid the frequency of writing is too high to use up memory blocks - def wait_for_graph_finished(self, - task_ids: Union[int, List[int]], - timeout: float = 20.0) -> Dict[int, torch.Tensor]: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="wait_for_graph_finished", - task_ids=task_ids, - ) - nvtx.mark(f"wait task_ids: {task_ids}") + def cancel_task(self, task_ids: Union[int, List[int]]) -> None: if isinstance(task_ids, int): task_ids = [task_ids] - num_completed_tasks = 0 - return_masks = {} - start_time = time.time() - while num_completed_tasks < len(task_ids): - finished_task_ids = [] + if self.server_client_mode: + for task_id in task_ids: + self.dp_client.cancel_task(task_id) + else: for task_id in task_ids: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return_masks[task_id] = torch.empty(0) #if not found in tracker, the return mask is an empty tensor - num_completed_tasks += 1 - finished_task_ids.append(task_id) - continue - task_tracker = self.requests_tracker[task_id] - if task_tracker.task_finished: - num_completed_tasks += 1 - return_masks[task_id] = task_tracker.return_mask - finished_task_ids.append(task_id) - task_ids = [task_id for task_id in task_ids if task_id not in finished_task_ids] - if time.time() - start_time > timeout: - flexkv_logger.warning(f"wait task_ids: {task_ids} timeout, has to return now") - for task_id in task_ids: - 
return_masks[task_id] = torch.empty(0) # return mask of timeout task is also an empty tensor - nvtx.mark(f"wait task_ids: {task_ids} timeout") - return return_masks - time.sleep(0.0001) - nvtx.mark(f"wait task_ids: {task_ids} done") - return return_masks + self.kv_task_engine.cancel_task(task_id) - # the try_wait api is used for server-client mode: - # server process running the kvmanager should NOT be blocked by any single client - def try_wait(self, task_ids: Union[int, List[int]]) -> Dict[int, torch.Tensor]: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="try_wait", - task_ids=task_ids, - ) - return_masks: Dict[int, torch.Tensor] = {} + def wait(self, + task_ids: Union[int, List[int]], + timeout: float = 20.0, + completely: bool = False) -> Dict[int, KVResponse]: if isinstance(task_ids, int): task_ids = [task_ids] - for task_id in task_ids: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return_masks[task_id] = torch.empty(0) #if not found in tracker, the return mask is an empty tensor - continue - task_tracker = self.requests_tracker[task_id] - if len(task_tracker.task_end_ops_ids) == 0: - mask = None - elif all(task_tracker.task_end_ops_status): - mask = task_tracker.return_mask - return_masks[task_id] = mask - else: - mask = None - - return return_masks - - def wait_at_layer_group(self, task_id: int, layer_group_id: int, timeout: float = 20.0) -> torch.Tensor: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="wait_at_layer_group", - task_ids=task_id, - layer_group_id=layer_group_id - ) - nvtx.mark(f"wait task_id: {task_id}, layer_group_id: {layer_group_id}") - start_time = time.time() - while True: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return torch.empty(0) #if not found in tracker, the return mask is an empty tensor - task_tracker = self.requests_tracker[task_id] - if len(task_tracker.task_end_ops_ids) == 0: - continue - if task_tracker.task_end_ops_status[layer_group_id]: - return task_tracker.return_mask - if time.time() - start_time > timeout: - flexkv_logger.warning(f"wait task_id: {task_id}, layer_group_id: {layer_group_id} " - f"timeout, has to return now") - return torch.empty(0) # return mask of timeout task is an empty tensor - time.sleep(0.0001) - - # nvtx.mark(f"wait_at_layer_group task_id: {task_id}, layer_group_id: {layer_group_id} done") - # return return_mask + if self.server_client_mode: + return self.dp_client.wait(task_ids, timeout, completely) + else: + return self.kv_task_engine.wait(task_ids, timeout, completely) - def try_wait_at_layer_group(self, - task_ids: Union[int, List[int]], - layer_group_id: int) -> Dict[int, torch.Tensor]: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="try_wait_at_layer_group", - task_ids=task_ids, - layer_group_id=layer_group_id, - ) - return_masks: Dict[int, torch.Tensor] = {} + def try_wait(self, task_ids: Union[int, List[int]]) -> Dict[int, KVResponse]: if isinstance(task_ids, int): task_ids = [task_ids] - for task_id in task_ids: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return_masks[task_id] = torch.empty(0) #if not found in tracker, the return mask is an empty tensor - continue - task_tracker = self.requests_tracker[task_id] - if len(task_tracker.task_end_ops_ids) == 0: - mask = torch.empty(0) - elif 
task_tracker.task_end_ops_status[layer_group_id]: - mask = task_tracker.return_mask - else: - mask = torch.empty(0) - return_masks[task_id] = mask - return return_masks - - def _verify_Model_Cache_config(self, - model_config: ModelConfig, - cache_config: CacheConfig): - if cache_config.enable_remote: - if cache_config.remote_cache_path is None: - - if cache_config.remote_file_prefix is None: - raise ValueError("remote_file_prefix must be provided when remote_cache_path is None") - - if cache_config.remote_file_num is None or cache_config.remote_file_num <= 0: - raise ValueError("remote_file_num must be a positive integer") - - cache_config.remote_cache_path = [ - f"{cache_config.remote_file_prefix}_{i}" - for i in range(cache_config.remote_file_num) - ] - - if cache_config.remote_cache_size_mode == "block_num": - if cache_config.num_remote_blocks is None: - raise ValueError("num_remote_blocks must not None if use block_num model") - elif cache_config.remote_cache_size_mode == "file_size": - if cache_config.remote_file_size is None: - raise ValueError("remote_file_size must not None if use file_size model") - if model_config.use_mla: - kv_size = ( - model_config.num_layers - * cache_config.tokens_per_block - * model_config.num_kv_heads - * model_config.head_size - * model_config.dtype.itemsize - ) - else: - kv_size = ( - model_config.num_layers - * 2 - * cache_config.tokens_per_block - * model_config.num_kv_heads - * model_config.head_size - * model_config.dtype.itemsize - ) - cache_config.num_remote_blocks = cache_config.remote_file_size // kv_size * cache_config.remote_file_num - - else: - raise ValueError("remote_cache_size_mode must block_num or file_size model") + if self.server_client_mode: + return self.dp_client.try_wait(task_ids) + else: + return self.kv_task_engine.try_wait(task_ids) diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py new file mode 100644 index 0000000000..625d98c78e --- /dev/null +++ b/flexkv/kvtask.py @@ -0,0 +1,504 @@ +import time +from typing import Dict, Optional, List, Union, Tuple +import threading +from enum import Enum +from dataclasses import dataclass +from typing import Callable + + +from expiring_dict import ExpiringDict +import nvtx +import torch +import numpy as np + +from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.debug import flexkv_logger +from flexkv.common.transfer import TransferOpGraph +from flexkv.common.tracer import FlexKVTracer +from flexkv.cache.cache_engine import GlobalCacheEngine +from flexkv.transfer_manager import TransferManagerHandle +from flexkv.common.request import KVResponseStatus, KVResponse + +class TaskStatus(Enum): + # slot mapping is not ready + UNREADY = "unready" + # waiting for the task to be launched + READY = "ready" + # in transfer + RUNNING = "running" + # transfer completed + COMPLETED = "completed" + # transfer cancelled + CANCELLED = "cancelled" + # transfer failed + FAILED = "failed" + +class TaskType(Enum): + GET = "get" + PUT = "put" + +@dataclass +class KVTask: + # task descriptor + task_id: int + task_type: TaskType + task_end_op_id: int + task_end_op_finished: bool + status: TaskStatus + + # params + token_ids: np.ndarray + slot_mapping: np.ndarray + token_mask: Optional[np.ndarray] + dp_id: int + + # cache engine return + graph: TransferOpGraph + return_mask: np.ndarray + callback: Optional[Callable] + + def is_completed(self) -> bool: + return self.status in [TaskStatus.COMPLETED, TaskStatus.CANCELLED, TaskStatus.FAILED] + +def task_status_to_response_status(task_status: 
TaskStatus) -> KVResponseStatus: + return { + TaskStatus.COMPLETED: KVResponseStatus.SUCCESS, + TaskStatus.CANCELLED: KVResponseStatus.CANCELLED, + TaskStatus.FAILED: KVResponseStatus.FAILED, + }[task_status] + +class KVTaskManager: + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: Optional[str] = None, + use_separate_process: bool = True, + ): + if not cache_config.enable_cpu: + raise ValueError("enable_cpu must be True") + if cache_config.enable_remote and not cache_config.enable_ssd: + raise ValueError("enable_ssd must be True if enable_remote is True") + if not cache_config.enable_cpu and not cache_config.use_gds: + raise ValueError("use_gds must be True if enable_cpu is False") + self.cache_config = cache_config + self.model_config = model_config + self._check_config(model_config, cache_config) + + self.cache_engine = GlobalCacheEngine(cache_config, model_config) + + self.transfer_handle = TransferManagerHandle( + self.model_config, + self.cache_config, + use_separate_process=use_separate_process, + gpu_register_port=gpu_register_port + ) + + self.tasks: ExpiringDict[int, KVTask] = ExpiringDict(max_age_seconds=1800, max_len=100000) # 30 minutes + self.graph_to_task: Dict[int, int] = {} + + self.task_id_counter = 0 + + self.task_id_lock = threading.Lock() + + def start(self) -> None: + self.transfer_handle.start() + + def is_ready(self) -> bool: + return self.transfer_handle.is_ready() + + def __del__(self) -> None: + self.shutdown() + + def shutdown(self) -> None: + self.transfer_handle.shutdown() + + def create_get_task(self, + task_id: int, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray] = None, + layer_granularity: int = -1, + dp_id: int = 0, + is_fake_slot_mapping: bool = False, + ) -> None: + if task_id in self.tasks: + raise ValueError(f"Task ID {task_id} already exists") + graph, return_mask, callback, task_end_op_id = self.cache_engine.get(task_id, + token_ids, + token_mask, + slot_mapping, + self.model_config.num_layers, + layer_granularity, + dp_id) + self.tasks[task_id] = KVTask( + task_id=task_id, + task_type=TaskType.GET, + task_end_op_id=task_end_op_id, + task_end_op_finished=False, + status=TaskStatus.UNREADY if is_fake_slot_mapping else TaskStatus.READY, + token_ids=token_ids, + slot_mapping=slot_mapping, + token_mask=token_mask, + dp_id=dp_id, + graph=graph, + return_mask=return_mask, + callback=callback) + + self.graph_to_task[graph.graph_id] = task_id + + def create_put_task(self, + task_id: int, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray] = None, + dp_id: int = 0, + is_fake_slot_mapping: bool = False, + ) -> None: + if task_id in self.tasks: + raise ValueError(f"Task ID {task_id} already exists") + graph, return_mask, callback, task_end_op_id = self.cache_engine.put(task_id, + token_ids, + token_mask, + slot_mapping, + self.model_config.num_layers, + dp_id) + self.tasks[task_id] = KVTask( + task_id=task_id, + task_type=TaskType.PUT, + task_end_op_id=task_end_op_id, + task_end_op_finished=False, + status=TaskStatus.UNREADY if is_fake_slot_mapping else TaskStatus.READY, + token_ids=token_ids, + slot_mapping=slot_mapping, + token_mask=token_mask, + dp_id=dp_id, + graph=graph, + return_mask=return_mask, + callback=callback) + self.graph_to_task[graph.graph_id] = task_id + + def launch_task(self, task_id: int) -> None: + task = self.tasks[task_id] + if task.is_completed(): + return + if task.status != TaskStatus.READY: + raise 
ValueError(f"Task {task_id} status is {task.status}, cannot launch") + transfer_graph = task.graph + task.status = TaskStatus.RUNNING + if transfer_graph.num_ops > 0: + self.transfer_handle.submit(transfer_graph) + + def update_tasks(self, timeout: float = 0.001) -> None: + completed_ops = self._get_completed_ops(timeout) + for completed_graph_id, completed_op_id in completed_ops: + if completed_graph_id not in self.graph_to_task: + continue + task_id = self.graph_to_task[completed_graph_id] + task = self.tasks[task_id] + if completed_op_id == -1: # the graph is totally finished + self._mark_completed(task_id) + elif completed_op_id == task.task_end_op_id: + self.tasks[task_id].task_end_op_finished = True + + def cancel_task(self, task_id: int) -> None: + task = self.tasks[task_id] + if task.is_completed(): + flexkv_logger.warning(f"Task {task_id} is already completed, cannot cancel") + return + if task.status == TaskStatus.RUNNING: + flexkv_logger.warning(f"Task {task_id} is running, cannot cancel") + return + if task.status == TaskStatus.CANCELLED: + flexkv_logger.warning(f"Task {task_id} is already cancelled, cannot cancel") + return + task.status = TaskStatus.CANCELLED + self.graph_to_task.pop(task.graph.graph_id, None) + + def check_completed(self, task_id: int, completely: bool = False) -> bool: + self._process_empty_graph(task_id) + task = self.tasks[task_id] + if completely: + return task.is_completed() + return task.is_completed() or task.task_end_op_finished + + def set_slot_mappings(self, + task_ids: List[int], + slot_mappings: List[np.ndarray]) -> None: + for task_id, slot_mapping in zip(task_ids, slot_mappings): + self._set_slot_mapping_impl(task_id, slot_mapping) + + def _set_slot_mapping_impl(self, task_id: int, slot_mapping: np.ndarray) -> None: + task = self.tasks[task_id] + if task.status != TaskStatus.UNREADY: + return + graph_ids = self.cache_engine.slot_mapping_to_block_ids(slot_mapping[task.return_mask.astype(np.bool_)], + self.cache_config.tokens_per_block) + task.graph.set_gpu_blocks(graph_ids) + task.status = TaskStatus.READY + + def _gen_task_id(self) -> int: + with self.task_id_lock: + old_value = self.task_id_counter + self.task_id_counter += 1 + return old_value + + def _mark_completed(self, task_id: int) -> None: + task = self.tasks[task_id] + if task.is_completed(): + return + if task.callback: + task.callback() + task.status = TaskStatus.COMPLETED + task.task_end_op_finished = True + self.graph_to_task.pop(task.graph.graph_id) + + def _process_empty_graph(self, task_id: int) -> None: + task = self.tasks[task_id] + if task.graph.num_ops == 0: + self._mark_completed(task_id) + + def _get_completed_ops(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + return self.transfer_handle.wait(timeout) + + def _check_config(self, model_config: ModelConfig, cache_config: CacheConfig) -> None: + if cache_config.enable_remote: + if cache_config.remote_cache_path is None: + + if cache_config.remote_file_prefix is None: + raise ValueError("remote_file_prefix must be provided when remote_cache_path is None") + + if cache_config.remote_file_num is None or cache_config.remote_file_num <= 0: + raise ValueError("remote_file_num must be a positive integer") + + cache_config.remote_cache_path = [ + f"{cache_config.remote_file_prefix}_{i}" + for i in range(cache_config.remote_file_num) + ] + + if cache_config.remote_cache_size_mode == "block_num": + if cache_config.num_remote_blocks is None: + raise ValueError("num_remote_blocks must not None if use block_num model") 
+                elif cache_config.remote_cache_size_mode == "file_size":
+                    if cache_config.remote_file_size is None:
+                        raise ValueError("remote_file_size must not be None when using file_size mode")
+                    if model_config.use_mla:
+                        kv_size = (
+                            model_config.num_layers
+                            * cache_config.tokens_per_block
+                            * model_config.num_kv_heads
+                            * model_config.head_size
+                            * model_config.dtype.itemsize
+                        )
+                    else:
+                        kv_size = (
+                            model_config.num_layers
+                            * 2
+                            * cache_config.tokens_per_block
+                            * model_config.num_kv_heads
+                            * model_config.head_size
+                            * model_config.dtype.itemsize
+                        )
+                    cache_config.num_remote_blocks = cache_config.remote_file_size // kv_size * cache_config.remote_file_num
+
+                else:
+                    raise ValueError("remote_cache_size_mode must be either block_num or file_size")
+
+
+class KVTaskEngine(KVTaskManager):
+    def __init__(self,
+                 model_config: ModelConfig,
+                 cache_config: CacheConfig,
+                 gpu_register_port: Optional[str] = None,
+                 use_separate_process: bool = True,
+                 ):
+        super().__init__(model_config, cache_config, gpu_register_port, use_separate_process)
+        self.tracer = FlexKVTracer(cache_config)
+
+    def get_async(self,
+                  token_ids: np.ndarray,
+                  slot_mapping: np.ndarray,
+                  token_mask: Optional[np.ndarray] = None,
+                  layer_granularity: int = -1,
+                  dp_id: int = 0,
+                  task_id: int = -1) -> Tuple[int, np.ndarray]:
+        task_id, return_mask = self._get_match_impl(token_ids,
+                                                    slot_mapping,
+                                                    is_fake_slot_mapping=False,
+                                                    token_mask=token_mask,
+                                                    layer_granularity=layer_granularity,
+                                                    dp_id=dp_id,
+                                                    task_id=task_id)
+        self.launch_task(task_id)
+        return task_id, return_mask
+
+    def put_async(self,
+                  token_ids: np.ndarray,
+                  slot_mapping: np.ndarray,
+                  token_mask: Optional[np.ndarray] = None,
+                  dp_id: int = 0,
+                  task_id: int = -1) -> Tuple[int, np.ndarray]:
+        task_id, return_mask = self._put_match_impl(token_ids,
+                                                    slot_mapping,
+                                                    is_fake_slot_mapping=False,
+                                                    token_mask=token_mask,
+                                                    dp_id=dp_id,
+                                                    task_id=task_id)
+        self.launch_task(task_id)
+        return task_id, return_mask
+
+    def _wait_impl(self,
+                   task_ids: List[int],
+                   timeout: float = 20.0,
+                   completely: bool = False) -> Dict[int, KVResponse]:
+        return_responses = {}
+        start_time = time.time()
+        is_timeout = timeout == 0.0
+
+        self.update_tasks(timeout=0)
+
+        for task_id in task_ids:
+            while True:
+                if task_id not in self.tasks:
+                    flexkv_logger.error(f"task_id {task_id} not submitted into flexKV")
+                    return_responses[task_id] = KVResponse(
+                        status=KVResponseStatus.NOTFOUND,
+                        task_id=task_id,
+                        return_mask=None
+                    )
+                    break
+                elif self.tasks[task_id].status == TaskStatus.UNREADY:
+                    flexkv_logger.warning(f"task_id {task_id} is unready")
+                    return_responses[task_id] = KVResponse(
+                        status=KVResponseStatus.UNREADY,
+                        task_id=task_id,
+                        return_mask=None
+                    )
+                    break
+                elif self.check_completed(task_id, completely=completely):
+                    self.tasks[task_id].status = TaskStatus.COMPLETED # TODO is this correct?
+ return_responses[task_id] = KVResponse( + status=task_status_to_response_status(self.tasks[task_id].status), + task_id=task_id, + return_mask=self.tasks[task_id].return_mask + ) + break + elif is_timeout: + return_responses[task_id] = KVResponse( + status=KVResponseStatus.TIMEOUT, + task_id=task_id, + return_mask=None + ) + break + else: + if time.time() - start_time > timeout: + is_timeout = True + break + self.update_tasks(timeout=0.001) + return return_responses + + def try_wait(self, task_ids: Union[int, List[int]]) -> Dict[int, KVResponse]: + if isinstance(task_ids, int): + task_ids = [task_ids] + return_responses = self._wait_impl(task_ids, + timeout=0.0, + completely=False) + return return_responses + + def wait(self, + task_ids: Union[int, List[int]], + timeout: float = 20.0, + completely: bool = False) -> Dict[int, KVResponse]: + nvtx.mark(f"wait task_ids: {task_ids}") + if isinstance(task_ids, int): + task_ids = [task_ids] + return_responses = self._wait_impl(task_ids, timeout, completely=completely) + nvtx.mark(f"wait task_ids: {task_ids} done") + return return_responses + + def get_match(self, + token_ids: np.ndarray, + token_mask: Optional[np.ndarray] = None, + layer_granularity: int = -1, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + fake_slot_mapping = np.zeros_like(token_ids) + return self._get_match_impl(token_ids, + fake_slot_mapping, + is_fake_slot_mapping=True, + token_mask=token_mask, + layer_granularity=layer_granularity, + dp_id=dp_id, + task_id=task_id) + + def _get_match_impl(self, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + is_fake_slot_mapping: bool = False, + token_mask: Optional[np.ndarray] = None, + layer_granularity: int = -1, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + if token_mask is None: + token_mask = np.ones_like(token_ids) + if layer_granularity == -1: + layer_granularity = self.model_config.num_layers + if task_id == -1: + task_id = self._gen_task_id() + nvtx.mark(f"GET task_id: {task_id}") + self.create_get_task(task_id, + token_ids, + slot_mapping, + token_mask, + layer_granularity, + dp_id, + is_fake_slot_mapping=is_fake_slot_mapping) + self._process_empty_graph(task_id) + return task_id, self.tasks[task_id].return_mask + + def put_match(self, + token_ids: np.ndarray, + token_mask: Optional[np.ndarray] = None, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + fake_slot_mapping = np.zeros_like(token_ids) + return self._put_match_impl(token_ids, + fake_slot_mapping, + is_fake_slot_mapping=True, + token_mask=token_mask, + dp_id=dp_id, + task_id=task_id) + + def _put_match_impl(self, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + is_fake_slot_mapping: bool = False, + token_mask: Optional[np.ndarray] = None, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + if token_mask is None: + token_mask = np.ones_like(token_ids) + if task_id == -1: + task_id = self._gen_task_id() + nvtx.mark(f"PUT task_id: {task_id}") + self.create_put_task(task_id, + token_ids, + slot_mapping, + token_mask, + dp_id, + is_fake_slot_mapping=is_fake_slot_mapping) + self._process_empty_graph(task_id) + return task_id, self.tasks[task_id].return_mask + + def launch_transfer(self, + task_ids: List[int], + slot_mappings: List[np.ndarray]) -> None: + assert isinstance(slot_mappings[0], np.ndarray) + self.set_slot_mappings(task_ids, slot_mappings) + for task_id in task_ids: + self.launch_task(task_id) + + def cancel(self, task_ids: Union[int, List[int]]) -> None: + if 
isinstance(task_ids, int): + task_ids = [task_ids] + for task_id in task_ids: + self.cancel_task(task_id) diff --git a/flexkv/server/client.py b/flexkv/server/client.py index 257c3f2fa3..a6db160934 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -2,16 +2,18 @@ from multiprocessing import Lock, Queue from multiprocessing.connection import Connection from queue import Queue as ThreadQueue -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Callable import tempfile import torch import zmq +import numpy as np -from flexkv.common.config import ModelConfig +from flexkv.common.config import ModelConfig, CacheConfig from flexkv.common.debug import flexkv_logger from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout +from flexkv.common.request import KVResponseStatus, KVResponse from flexkv.server.utils import get_zmq_socket from flexkv.server.request import ( RegisterDPClientRequest, @@ -19,6 +21,10 @@ IsReadyRequest, PutRequest, GetRequest, + PutMatchRequest, + GetMatchRequest, + LaunchTaskRequest, + CancelTaskRequest, WaitRequest, TryWaitRequest, CheckRunningRequest, @@ -26,7 +32,6 @@ Response ) - class KVDPClient: def __init__( self, @@ -38,6 +43,7 @@ def __init__( self.send_to_server = get_zmq_socket( context, zmq.SocketType.PUSH, server_recv_port, False ) + # is this ok when there are multiple dp clients? client_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=True).name}" self.recv_from_server = get_zmq_socket( context, zmq.SocketType.PULL, client_recv_port, True @@ -67,7 +73,7 @@ def register_to_server( self.send_to_server.send_pyobj(register_req) # blocking response: Response = self.recv_from_server.recv_pyobj() - if response.success: + if response.error_msg is None: flexkv_logger.info(f"DP client registered successfully! DP client id: {response.dp_client_id}") return response.dp_client_id else: @@ -80,36 +86,47 @@ def is_ready( req = IsReadyRequest(self.dp_client_id) self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() - if response.success: - return response.is_ready - else: - flexkv_logger.error(f"is_ready failed: {response.error_msg}") - raise - + return response.is_ready + def put_async( self, token_ids: torch.Tensor, slot_mapping: torch.Tensor, token_mask: Optional[torch.Tensor], - ) -> Optional[int]: - # start_time = time.time() + ) -> int: req = PutRequest(self.dp_client_id, token_ids.numpy(), slot_mapping.numpy(), token_mask.numpy() if token_mask is not None else None, self._get_task_id()) self.send_to_server.send_pyobj(req) - # end_time = time.time() - # print(f"[dpclient] put_async task: {req.task_id} created. 
time: {(end_time - start_time)*1000:.2f}ms") return req.task_id + def put_match( + self, + token_ids: torch.Tensor, + slot_mapping: torch.Tensor, + token_mask: Optional[torch.Tensor], + ) -> Optional[Tuple[int, np.ndarray]]: + req = PutMatchRequest(self.dp_client_id, + token_ids.numpy(), + slot_mapping.numpy(), + token_mask.numpy() if token_mask is not None else None, + self._get_task_id()) + self.send_to_server.send_pyobj(req) + response: Response = self.recv_from_server.recv_pyobj() + if response.error_msg is None: + return response.task_id, response.mask + else: + flexkv_logger.error(f"put_match failed, error_msg: {response.error_msg}") + return None + def get_async( self, token_ids: torch.Tensor, slot_mapping: torch.Tensor, token_mask: Optional[torch.Tensor], - ) -> Optional[int]: - # start_time = time.time() + ) -> int: req = GetRequest(self.dp_client_id, token_ids.numpy(), slot_mapping.numpy(), @@ -117,24 +134,56 @@ def get_async( self._get_task_id()) self.send_to_server.send_pyobj(req) - # end_time = time.time() - # print(f"[dpclient] get_async task: {req.task_id} created. time: {(end_time - start_time)*1000:.2f}ms") return req.task_id + def get_match( + self, + token_ids: torch.Tensor, + slot_mapping: torch.Tensor, + token_mask: Optional[torch.Tensor], + ) -> Optional[Tuple[int, np.ndarray]]: + req = GetMatchRequest(self.dp_client_id, + token_ids.numpy(), + slot_mapping.numpy(), + token_mask.numpy() if token_mask is not None else None, + self._get_task_id()) + self.send_to_server.send_pyobj(req) + response: Response = self.recv_from_server.recv_pyobj() + if response.error_msg is None: + return req.task_id, response.mask + else: + flexkv_logger.error(f"get_match failed, error_msg: {response.error_msg}") + return None + + def launch_task( + self, + task_ids: List[int], + ) -> None: + req = LaunchTaskRequest(self.dp_client_id, task_ids) + self.send_to_server.send_pyobj(req) + + def cancel_task( + self, + task_ids: List[int], + ) -> None: + req = CancelTaskRequest(self.dp_client_id, task_ids) + self.send_to_server.send_pyobj(req) + def wait( self, wait_task_ids: List[int], wait_timeout: float = 20.0, - ) -> Optional[Dict[int, torch.Tensor]]: - req = WaitRequest(self.dp_client_id, None, wait_task_ids, wait_timeout) + completely: bool = False, + ) -> Optional[Dict[int, KVResponse]]: + req = WaitRequest(self.dp_client_id, None, wait_task_ids, wait_timeout, completely) self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() - if response.masks is not None: - response.masks = {k: torch.from_numpy(v) for k, v in response.masks.items()} - if response.success: - # flexkv_logger.info(f"wait tasks: {wait_task_ids} finished.") - return response.masks + if response.status is not None: + for k, v in response.status.items(): + if v.status != KVResponseStatus.SUCCESS: + flexkv_logger.error(f"wait task {k} failed: {v.status}") + return response.status else: flexkv_logger.error(f"wait tasks: {wait_task_ids} in DP {self.dp_client_id} failed.") return None @@ -142,26 +191,26 @@ def wait( def try_wait( self, try_wait_task_ids: List[int], - ) -> Optional[Dict[int, torch.Tensor]]: + ) -> Optional[Dict[int, KVResponse]]: req = TryWaitRequest(self.dp_client_id, None, try_wait_task_ids) self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() if response.masks is not None: - response.masks = {k: torch.from_numpy(v) for k, v in response.masks.items()} - if response.success: - # flexkv_logger.info(f"try_wait tasks: {try_wait_task_ids} 
finished.") + for k, v in response.masks.items(): + if v.status != KVResponseStatus.SUCCESS: + flexkv_logger.error(f"try_wait task {k} failed: {v.status}") return response.masks else: flexkv_logger.error(f"try_wait tasks: {try_wait_task_ids} in DP {self.dp_client_id} failed.") return None - + """ def check_running(self) -> bool: req = CheckRunningRequest(self.dp_client_id) self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() return response.running - + """ def shutdown(self) -> None: req = ShutdownRequest(self.dp_client_id) self.send_to_server.send_pyobj(req) @@ -224,7 +273,7 @@ def register_to_server( self.send_to_server.send_pyobj(register_req) # blocking response: Response = self.recv_from_server.recv_pyobj() - if response.success: + if response.error_msg is None: flexkv_logger.info(f"TP client of DP client {self.dp_client_id} registered successfully!") else: flexkv_logger.error( @@ -232,8 +281,6 @@ def register_to_server( ) raise - - if __name__ == "__main__": num_layers = 32 num_kv_heads = 8 diff --git a/flexkv/server/request.py b/flexkv/server/request.py index 532feae047..8e740adca0 100644 --- a/flexkv/server/request.py +++ b/flexkv/server/request.py @@ -7,6 +7,7 @@ from flexkv.common.config import ModelConfig from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout +from flexkv.common.request import KVResponseStatus @dataclass @@ -45,12 +46,39 @@ class GetRequest: token_mask: Optional[np.ndarray] task_id: int = -1 +@dataclass +class PutMatchRequest: + dp_client_id: int + token_ids: np.ndarray + slot_mapping: np.ndarray + token_mask: Optional[np.ndarray] + task_id: int = -1 + +@dataclass +class GetMatchRequest: + dp_client_id: int + token_ids: np.ndarray + slot_mapping: np.ndarray + token_mask: Optional[np.ndarray] + task_id: int = -1 + +@dataclass +class LaunchTaskRequest: + dp_client_id: int + task_ids: List[int] + +@dataclass +class CancelTaskRequest: + dp_client_id: int + task_ids: List[int] + @dataclass class WaitRequest: dp_client_id: int tp_rank: Optional[int] wait_task_ids: List[int] wait_timeout: float = 20.0 + completely: bool = False # Used for async put/get @dataclass @@ -62,13 +90,17 @@ class TryWaitRequest: @dataclass class Response: - dp_client_id: int + dp_client_id: int = -1 task_id: Optional[int] = None - masks: Optional[Dict[int, np.ndarray]] = None - success: bool = True - running: bool = False - error_msg: str = "" + mask: Optional[Dict[int, np.ndarray]] = None + status: Optional[Dict[int, KVResponseStatus]] = None is_ready: bool = False + error_msg: Optional[str] = None + + @property + def success(self) -> bool: + return self.status is not None and \ + all(self.status[task_id] == KVResponseStatus.SUCCESS for task_id in self.status.keys()) @dataclass diff --git a/flexkv/server/server.py b/flexkv/server/server.py index 8ea5b91faa..ae3f5d0af0 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -7,12 +7,15 @@ import time import threading from threading import Lock +import multiprocessing as mp +import socket +import os from flexkv.common.config import CacheConfig, ModelConfig from flexkv.common.debug import flexkv_logger from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType -from flexkv.kvmanager import KVManager +from flexkv.kvtask import KVTaskEngine from flexkv.server.utils import get_zmq_socket from flexkv.server.request import ( RegisterDPClientRequest, @@ -20,6 +23,10 @@ 
IsReadyRequest, PutRequest, GetRequest, + PutMatchRequest, + GetMatchRequest, + LaunchTaskRequest, + CancelTaskRequest, WaitRequest, TryWaitRequest, Response, @@ -29,6 +36,49 @@ import contextlib +def _is_port_in_use(port_or_endpoint: str) -> bool: + """ + check if the port or IPC endpoint is in use by another process + + Args: + port_or_endpoint: port number or IPC endpoint string (e.g. "ipc:///tmp/xxx" or "5555") + + Returns: + bool: True if the port/endpoint is in use, False if it is free + """ + try: + if port_or_endpoint.startswith("ipc://"): + # IPC endpoint: check if the file exists + ipc_path = port_or_endpoint[6:] # remove "ipc://" prefix + return os.path.exists(ipc_path) + elif port_or_endpoint.startswith("tcp://"): + # TCP endpoint: parse host and port + tcp_part = port_or_endpoint[6:] # remove "tcp://" prefix + if ':' in tcp_part: + host, port_str = tcp_part.rsplit(':', 1) + port = int(port_str) + else: + host = "localhost" + port = int(tcp_part) + + # try to connect to the port to check if it is in use + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(1) + result = sock.connect_ex((host, port)) + sock.close() + return result == 0 + else: + # assume it is a pure port number + port = int(port_or_endpoint) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(1) + result = sock.connect_ex(("localhost", port)) + sock.close() + return result == 0 + except (ValueError, OSError): + # if parsing fails or connection fails, assume the port is free + return False +""" class TPClient: def __init__( self, @@ -39,7 +89,7 @@ def __init__( self.tp_rank = tp_rank self.device_id = device_id self.send_to_client = send_to_client - +""" class DPClient: def __init__( @@ -50,12 +100,11 @@ def __init__( ): self.client_id = client_id self.tp_size = tp_size - self.tp_client_dict: Dict[int, TPClient] = {} self.send_to_client = send_to_client self.is_ready: bool = False - +""" def register_tp_client( self, context: zmq.Context, @@ -82,7 +131,7 @@ def register_tp_client( self.is_ready = True flexkv_logger.info(f"All the TP clients in DP client: {self.client_id} has registered. " f"DP client: {self.client_id} is ready!") - +""" class ClientManager: def __init__( @@ -115,7 +164,7 @@ def register_dp_client( flexkv_logger.info(f"DP client {client_id} registered successfully") return client_id - + """ def register_tp_client( self, context: zmq.Context, @@ -129,7 +178,7 @@ def register_tp_client( raise self.client_dict[dp_client_id].register_tp_client( context, client_recv_port, tp_rank, device_id) - + """ def delete_dp_client(self, client_id: int) -> None: if client_id not in self.client_dict: flexkv_logger.error(f"DP client: {client_id} dosen't exist. 
Delete failed.") @@ -150,34 +199,81 @@ def is_dp_client_ready(self, dp_client_id: int) -> bool: return self.client_dict[dp_client_id].is_ready return False +class KVServerHandle: + def __init__(self, process: mp.Process, ready_event: mp.Event): + self.process = process + self.ready_event = ready_event + + def shutdown(self) -> None: + self.process.join(timeout=5) + if self.process.is_alive(): + flexkv_logger.info("force terminate the server process") + self.process.terminate() + self.process.join() + + def __del__(self) -> None: + if self.process.is_alive(): + self.shutdown() class KVServer: def __init__( self, model_config: ModelConfig, cache_config: CacheConfig, - server_recv_port: Optional[str] = None, + gpu_register_port: str, + server_recv_port: str ): # Init inter-process communication self.context = zmq.Context(2) - if server_recv_port is None: - server_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" self.recv_from_client = get_zmq_socket( self.context, zmq.SocketType.PULL, server_recv_port, True) self.client_manager = ClientManager(max_num_dp_client=model_config.dp_size) - self.kvmanager = KVManager(model_config, cache_config) - - if self.kvmanager.is_ready(): - flexkv_logger.info("KVManager is ready, starting with worker initialization...") - self.kvmanager.start() + self.kv_task_engine = KVTaskEngine(model_config, cache_config, gpu_register_port, False) + self.kv_task_engine.start() + self._is_ready = True self.req_counter = 0 flexkv_logger.info(f"Server Initialized! [Recv Port]: {server_recv_port}") - # self._running = True + self._running = False + def is_ready(self) -> bool: + return self._is_ready + + @staticmethod + def _server_process(model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str, + server_recv_port: str, + ready_event: mp.Event) -> None: + + server = KVServer(model_config, cache_config, gpu_register_port, server_recv_port) + ready_event.set() + server.run() + + @classmethod + def create_server(cls, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str, + server_recv_port: Optional[str] = None) -> 'KVServerHandle': + if server_recv_port is None: + server_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" + #if _is_port_in_use(server_recv_port): + # flexkv_logger.info(f"port {server_recv_port} is in use, skip starting new kvserver") + # return None + #else: + # flexkv_logger.info(f"port {server_recv_port} is free, starting new kvserver") + mp.set_start_method("spawn") + ready_event = mp.Event() + process = mp.Process(target=cls._server_process, + args=(model_config, cache_config, gpu_register_port, server_recv_port, ready_event)) + process.start() + flexkv_logger.info(f"KVServer process started, PID: {process.pid}") + + return KVServerHandle(process, ready_event) def run(self) -> None: """Main server loop""" @@ -187,7 +283,7 @@ def run(self) -> None: self._running = True while self._running: try: - flexkv_logger.info("start wait for req") + flexkv_logger.info("start waiting for req") req = self.recv_from_client.recv_pyobj() flexkv_logger.info(f"recv req: {type(req)}") @@ -203,41 +299,16 @@ def run(self) -> None: result_zmq = self.client_manager.get_zmq(client_id) result_zmq.send_pyobj(response) - - elif isinstance(req, RegisterTPClientRequest): - self.client_manager.register_tp_client( - self.context, - req.dp_client_id, - req.client_recv_port, - req.tp_rank, - req.device_id, - ) - - # register GPU Memory - self.kvmanager.register_single_gpu_blocks(req.handles, - 
req.gpu_layout, - req.dp_client_id, - req.tp_rank) - - response = Response(req.dp_client_id) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id, req.tp_rank) - result_zmq.send_pyobj(response) - - if self.kvmanager.is_ready(): - flexkv_logger.info("All TP clients registered, starting KVManager...") - self.kvmanager.start() - elif isinstance(req, IsReadyRequest): - is_ready = self.kvmanager.is_ready() + is_ready = self.kv_task_engine.is_ready() response = Response(req.dp_client_id, is_ready=is_ready) result_zmq = self.client_manager.get_zmq( req.dp_client_id) result_zmq.send_pyobj(response) elif isinstance(req, GetRequest): - assert self.client_manager.is_dp_client_ready(req.dp_client_id) - req_id = self.kvmanager.get_async( + #assert self.client_manager.is_dp_client_ready(req.dp_client_id) + req_id = self.kv_task_engine.get_async( token_ids=torch.from_numpy(req.token_ids), slot_mapping=torch.from_numpy(req.slot_mapping), token_mask=torch.from_numpy(req.token_mask) if req.token_mask is not None else None, @@ -245,50 +316,56 @@ def run(self) -> None: dp_id=req.dp_client_id, task_id=req.task_id, ) - if req.task_id == -1: - response = Response(req.dp_client_id, req_id) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) elif isinstance(req, PutRequest): - assert self.client_manager.is_dp_client_ready(req.dp_client_id) - req_id = self.kvmanager.put_async( + #assert self.client_manager.is_dp_client_ready(req.dp_client_id) + req_id = self.kv_task_engine.put_async( token_ids=torch.from_numpy(req.token_ids), slot_mapping=torch.from_numpy(req.slot_mapping), token_mask=torch.from_numpy(req.token_mask) if req.token_mask is not None else None, dp_id=req.dp_client_id, task_id=req.task_id, ) - if req.task_id == -1: - response = Response(req.dp_client_id, req_id) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) + + elif isinstance(req, GetMatchRequest): + #assert self.client_manager.is_dp_client_ready(req.dp_client_id) + req_id, mask = self.kv_task_engine.get_match( + token_ids=torch.from_numpy(req.token_ids), + slot_mapping=torch.from_numpy(req.slot_mapping), + token_mask=torch.from_numpy(req.token_mask) if req.token_mask is not None else None, + ) + response = Response(req.dp_client_id, task_id=req_id, mask=mask) + result_zmq = self.client_manager.get_zmq( + req.dp_client_id) + result_zmq.send_pyobj(response) + + elif isinstance(req, PutMatchRequest): + #assert self.client_manager.is_dp_client_ready(req.dp_client_id) + req_id, mask = self.kv_task_engine.put_match( + token_ids=torch.from_numpy(req.token_ids), + slot_mapping=torch.from_numpy(req.slot_mapping), + token_mask=torch.from_numpy(req.token_mask) if req.token_mask is not None else None, + ) + response = Response(req.dp_client_id, task_id=req_id, mask=mask) + result_zmq = self.client_manager.get_zmq( + req.dp_client_id) + result_zmq.send_pyobj(response) elif isinstance(req, WaitRequest): - # TODO: support TP client wait - masks = self.kvmanager.wait( + kv_responses = self.kv_task_engine.wait( req.wait_task_ids, timeout=req.wait_timeout, ) - if masks is not None: - # Convert to numpy arrays for serialization - masks = {k: v.numpy() if isinstance(v, torch.Tensor) else v for k, v in masks.items()} - response = Response(req.dp_client_id, masks=masks) + response = Response(req.dp_client_id, status=kv_responses) result_zmq = self.client_manager.get_zmq( req.dp_client_id) result_zmq.send_pyobj(response) elif isinstance(req, TryWaitRequest): - # TODO: 
support TP client try_wait - masks = self.kvmanager.try_wait( + kv_responses = self.kv_task_engine.try_wait( req.try_wait_task_ids, ) - if masks is not None: - # Convert to numpy arrays for serialization - masks = {k: v.numpy() if isinstance(v, torch.Tensor) else v for k, v in masks.items()} - response = Response(req.dp_client_id, masks=masks) + response = Response(req.dp_client_id, status=kv_responses) result_zmq = self.client_manager.get_zmq( req.dp_client_id) result_zmq.send_pyobj(response) @@ -302,12 +379,7 @@ def run(self) -> None: result_zmq = self.client_manager.get_zmq(req.dp_client_id) result_zmq.send_pyobj(response) break - - elif isinstance(req, CheckRunningRequest): - response = Response(req.dp_client_id, success=True, running=self.kvmanager.is_running()) - result_zmq = self.client_manager.get_zmq(req.dp_client_id) - result_zmq.send_pyobj(response) - + else: raise TypeError(f"Unregonized RequestType: {type(req)}") @@ -332,337 +404,6 @@ def _verify_model_config( def __del__(self) -> None: self.kvmanager.shutdown() - -class SchedulerServer: - """ - Scheduler server that merges the functionality of KVServer and KVDPClient. - Note that this class is ONLY FOR CASES WHEN DP_SIZE = 1. - - This class can: - 1. Directly call KVManager methods to avoid inter-process communication latency - 2. Accept registration requests from TPClient - 3. Provide the same interface as KVDPClient (put_async, get_async, wait, try_wait) - """ - - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - server_recv_port: Optional[str] = None, - ): - self.model_config = model_config - self.cache_config = cache_config - - # Initialize KVManager (similar to KVServer) - self.kvmanager = KVManager(model_config, cache_config) - - # Start KVManager if it's ready (e.g., when no TP clients are needed) - if self.kvmanager.is_ready(): - try: - self.kvmanager.start() - flexkv_logger.info("KVManager started during initialization") - except Exception as e: - flexkv_logger.warning(f"KVManager start failed during initialization: {e}") - - # For TPClient compatibility, we need a server to receive TPClient registration requests - self.context = zmq.Context(2) - if server_recv_port is None: - server_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" - - self.server_recv_port = server_recv_port - self.recv_from_client = get_zmq_socket( - self.context, zmq.SocketType.PULL, server_recv_port, True) - - # Manage TP clients - self.tp_size = model_config.tp_size - self.tp_client_dict: Dict[int, TPClient] = {} - self.is_ready: bool = False - - # DP client related - self.dp_client_id = 0 # Fixed to 0 because we merged scheduler and server - self._task_id_range = (self.dp_client_id * 10000000, (self.dp_client_id + 1) * 10000000) - self._task_id_counter = self._task_id_range[0] - self._task_id_lock = Lock() - - # Server thread control - self._running = False - self._server_thread = None - - flexkv_logger.info(f"SchedulerServer Initialized! 
[Recv Port]: {server_recv_port}") - - def _get_task_id(self) -> int: - """Generate unique task ID""" - with self._task_id_lock: - old_value = self._task_id_counter - self._task_id_counter += 1 - if self._task_id_counter >= self._task_id_range[1]: - self._task_id_counter = self._task_id_range[0] - return old_value - - def start_server_thread(self) -> None: - """Start background server thread to handle TPClient requests""" - if self._server_thread is not None and self._server_thread.is_alive(): - flexkv_logger.warning("Server thread is already running") - return - - self._running = True - self._server_thread = threading.Thread(target=self._server_loop, daemon=True) - self._server_thread.start() - flexkv_logger.info("SchedulerServer background thread started") - - def _server_loop(self) -> None: - """Background server loop to handle requests from TPClient""" - while self._running: - try: - # Set non-blocking receive to allow checking _running status - try: - req = self.recv_from_client.recv_pyobj(zmq.NOBLOCK) - except zmq.Again: - time.sleep(0.001) # Brief sleep to avoid busy waiting - continue - - flexkv_logger.info(f"SchedulerServer received request: {type(req)}") - - if isinstance(req, RegisterTPClientRequest): - self._handle_tp_registration(req) - elif isinstance(req, ShutdownRequest): - flexkv_logger.info("Received shutdown request from TP client") - response = Response(req.dp_client_id, success=True) - # Since we don't know which TP client sent the shutdown request, - # we send response to all registered TP clients - self._running = False - for tp_client in self.tp_client_dict.values(): - tp_client.send_to_client.send_pyobj(response) - break - else: - flexkv_logger.error(f"Unrecognized RequestType in SchedulerServer: {type(req)}") - - except zmq.ZMQError as e: - if e.errno == zmq.ETERM: - break # Context terminated - flexkv_logger.error(f"ZMQ Error in SchedulerServer: {e}", exc_info=True) - except Exception as e: - flexkv_logger.error(f"Error in SchedulerServer: {e}", exc_info=True) - time.sleep(0.0001) - - flexkv_logger.info("SchedulerServer background thread stopped") - - def _handle_tp_registration(self, req: RegisterTPClientRequest) -> None: - """Handle TP Client registration request""" - tp_rank = req.tp_rank - - if tp_rank in self.tp_client_dict: - flexkv_logger.error(f"TP rank: {tp_rank} has already registered.") - response = Response(req.dp_client_id, success=False, - error_msg=f"TP rank {tp_rank} already registered") - elif tp_rank >= self.tp_size: - flexkv_logger.error(f"TP rank: {tp_rank} is larger than TP size: {self.tp_size}.") - response = Response(req.dp_client_id, success=False, - error_msg=f"TP rank {tp_rank} exceeds TP size {self.tp_size}") - else: - try: - # Create connection to TP client - send_to_client = get_zmq_socket( - self.context, zmq.SocketType.PUSH, req.client_recv_port, False - ) - - self.tp_client_dict[tp_rank] = TPClient(send_to_client, tp_rank, req.device_id) - - # Register GPU Memory to KVManager - self.kvmanager.register_single_gpu_blocks( - req.handles, - req.gpu_layout, - self.dp_client_id, # Use fixed dp_client_id = 0 - req.tp_rank - ) - - flexkv_logger.info(f"TP rank: {tp_rank} registered successfully.") - - # Check if all TP clients have registered - if len(self.tp_client_dict) == self.tp_size: - self.is_ready = True - # Always start kvmanager when all TP clients are registered - try: - flexkv_logger.info("All TP clients registered, starting KVManager...") - self.kvmanager.start() - flexkv_logger.info("KVManager started successfully") - except 
Exception as e: - flexkv_logger.warning(f"KVManager start failed or already started: {e}") - flexkv_logger.info("All TP clients registered. SchedulerServer is ready!") - - response = Response(req.dp_client_id, success=True) - - except Exception as e: - flexkv_logger.error(f"Failed to register TP client {tp_rank}: {e}") - response = Response(req.dp_client_id, success=False, error_msg=str(e)) - - # Send response to TP client - if tp_rank in self.tp_client_dict: - self.tp_client_dict[tp_rank].send_to_client.send_pyobj(response) - - def put_async( - self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor] = None, - ) -> Optional[int]: - """ - Asynchronous PUT operation, directly calling KVManager (no network communication required) - - Args: - token_ids: Token IDs tensor - slot_mapping: Slot mapping tensor - token_mask: Optional token mask tensor - - Returns: - Task ID if successful, None otherwise - """ - start_time = time.time() - - if not self.is_ready: - flexkv_logger.error("SchedulerServer is not ready (not all TP clients registered)") - return None - - try: - task_id = self._get_task_id() - req_id = self.kvmanager.put_async( - token_ids=token_ids, - slot_mapping=slot_mapping, - token_mask=token_mask, - dp_id=self.dp_client_id, - task_id=task_id, - ) - - end_time = time.time() - flexkv_logger.info(f"[SchedulerServer] put_async task: {task_id} created. " - f"time: {(end_time - start_time)*1000:.2f}ms") - return task_id - - except Exception as e: - flexkv_logger.error(f"put_async failed: {e}") - return None - - def get_async( - self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor] = None, - ) -> Optional[int]: - """ - Asynchronous GET operation, directly calling KVManager (no network communication required) - - Args: - token_ids: Token IDs tensor - slot_mapping: Slot mapping tensor - token_mask: Optional token mask tensor - - Returns: - Task ID if successful, None otherwise - """ - start_time = time.time() - - if not self.is_ready: - flexkv_logger.error("SchedulerServer is not ready (not all TP clients registered)") - return None - - try: - task_id = self._get_task_id() - req_id = self.kvmanager.get_async( - token_ids=token_ids, - slot_mapping=slot_mapping, - token_mask=token_mask, - layer_granularity=-1, - dp_id=self.dp_client_id, - task_id=task_id, - ) - - end_time = time.time() - flexkv_logger.info(f"[SchedulerServer] get_async task: {task_id} created. 
" - f"time: {(end_time - start_time)*1000:.2f}ms") - return task_id - - except Exception as e: - flexkv_logger.error(f"get_async failed: {e}") - return None - - def wait( - self, - wait_task_ids: List[int], - wait_timeout: float = 20.0, - ) -> Optional[Dict[int, torch.Tensor]]: - """ - Wait for specified tasks to complete, directly calling KVManager (no network communication required) - - Args: - wait_task_ids: List of task IDs to wait for - - Returns: - Dictionary mapping task IDs to result masks, None if failed - """ - try: - masks = self.kvmanager.wait(wait_task_ids, timeout=wait_timeout) - flexkv_logger.info(f"[SchedulerServer] wait tasks: {wait_task_ids} finished.") - return masks - - except Exception as e: - flexkv_logger.error(f"wait failed: {e}") - return None - - def try_wait( - self, - try_wait_task_ids: List[int], - ) -> Optional[Dict[int, torch.Tensor]]: - """ - Non-blocking wait for specified tasks, directly calling KVManager (no network communication required) - - Args: - try_wait_task_ids: List of task IDs to try waiting for - - Returns: - Dictionary mapping task IDs to result masks, None if not ready or failed - """ - try: - masks = self.kvmanager.try_wait(try_wait_task_ids) - if masks is not None: - flexkv_logger.info(f"[SchedulerServer] try_wait tasks: {try_wait_task_ids} finished.") - return masks - - except Exception as e: - flexkv_logger.error(f"try_wait failed: {e}") - return None - - def check_running(self) -> bool: - return self.kvmanager.is_running() - - def shutdown(self) -> None: - """Shutdown SchedulerServer""" - flexkv_logger.info("Shutting down SchedulerServer...") - - # Stop server thread - self._running = False - if self._server_thread is not None and self._server_thread.is_alive(): - self._server_thread.join(timeout=5.0) - - # Shutdown KVManager - if hasattr(self, 'kvmanager'): - self.kvmanager.shutdown() - - # Close ZMQ context - #if hasattr(self, 'context'): - # self.context.term() - - flexkv_logger.info("SchedulerServer shutdown complete") - - def get_server_port(self) -> str: - """Get server receive port for TPClient to use""" - return self.server_recv_port - - def __del__(self) -> None: - """Destructor""" - with contextlib.suppress(Exception): - self.shutdown() - - if __name__ == "__main__": import torch num_layers = 32 diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 30516e6c3f..2406098dfc 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -88,7 +88,6 @@ def _init_workers(self) -> None: if self.tp_size == 1: self.gpucpu_workers: List[WorkerHandle] = [ GPUCPUTransferWorker.create_worker( - worker_id=i, finished_ops_queue=self.finished_ops_queue, gpu_blocks=self.gpu_handles[i].get_tensor_handle_list(), cpu_blocks=self._cpu_handle.get_tensor(), @@ -106,7 +105,6 @@ def _init_workers(self) -> None: else: self.gpucpu_workers = [ tpGPUCPUTransferWorker.create_worker( - worker_id=i, finished_ops_queue=self.finished_ops_queue, gpu_blocks=[self.gpu_handles[j].get_tensor_handle_list() \ for j in range(i * self.tp_size, (i + 1) * self.tp_size)], @@ -128,7 +126,6 @@ def _init_workers(self) -> None: if self._ssd_handle is not None and self._cpu_handle is not None: self.cpussd_read_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( - worker_id=10, finished_ops_queue=self.finished_ops_queue, cpu_blocks=self._cpu_handle.get_tensor(), ssd_files=self._ssd_handle.get_file_list(), @@ -139,7 +136,6 @@ def _init_workers(self) -> None: cache_config=self._cache_config, ) 
self.cpussd_write_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( - worker_id=11, finished_ops_queue=self.finished_ops_queue, cpu_blocks=self._cpu_handle.get_tensor(), ssd_files=self._ssd_handle.get_file_list(), @@ -153,7 +149,6 @@ def _init_workers(self) -> None: self._worker_map[TransferType.DISK2H] = self.cpussd_read_worker if self._remote_handle is not None and self._cpu_handle is not None: self.remotecpu_read_worker: WorkerHandle = CPURemoteTransferWorker.create_worker( - worker_id=20, finished_ops_queue=self.finished_ops_queue, cpu_blocks=self._cpu_handle.get_tensor(), remote_file=self._remote_handle.get_file_list(), @@ -163,7 +158,6 @@ def _init_workers(self) -> None: remote_config_custom=self._remote_handle.remote_config_custom, ) self.remotecpu_write_worker: WorkerHandle = CPURemoteTransferWorker.create_worker( - worker_id=21, finished_ops_queue=self.finished_ops_queue, cpu_blocks=self._cpu_handle.get_tensor(), remote_file=self._remote_handle.get_file_list(), @@ -177,18 +171,19 @@ def _init_workers(self) -> None: if len(self._worker_map) == 0: raise ValueError("No workers initialized, please check the config") # Wait for all workers to ready - for worker in self._worker_map.values(): + for transfer_type, worker in self._worker_map.items(): if isinstance(worker, List): for w in worker: - w.ready_event.wait(timeout=60) + flexkv_logger.info(f"waiting for {transfer_type.name} worker {w.worker_id} to ready") + w.ready_event.wait() + flexkv_logger.info(f"{transfer_type.name} worker {w.worker_id} is ready") else: - flexkv_logger.info(f"waiting for worker {worker} to ready") - worker.ready_event.wait(timeout=60) - flexkv_logger.info(f"worker {worker} is ready") + flexkv_logger.info(f"waiting for {transfer_type.name} worker {worker.worker_id} to ready") + worker.ready_event.wait() + flexkv_logger.info(f"{transfer_type.name} worker {worker.worker_id} is ready") # Start scheduler thread self._running = True self._scheduler_thread = threading.Thread(target=self._scheduler_loop) - flexkv_logger.info("TransferEngine initialized and running") self._scheduler_thread.start() def start(self) -> None: @@ -286,6 +281,8 @@ def get_completed_graphs_and_ops(self, timeout: Optional[float] = None) -> List[ def shutdown(self) -> None: """Shutdown the transfer engine""" try: + if not self._running: + return self._running = False self._scheduler_thread.join(timeout=5) diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index 4d72a0fb93..84a15a3cdf 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -61,13 +61,16 @@ def __init__(self, transfer_op: TransferOp): self.transfer_op_id = transfer_op.op_id self.transfer_graph_id = transfer_op.graph_id self.transfer_type = transfer_op.transfer_type - self.src_block_ids = transfer_op.src_descriptor.physical_block_ids.numpy() - self.dst_block_ids = transfer_op.dst_descriptor.physical_block_ids.numpy() + self.src_block_ids = transfer_op.src_block_ids + self.dst_block_ids = transfer_op.dst_block_ids self.layer_id = transfer_op.layer_id self.layer_granularity = transfer_op.layer_granularity # self.successors = list(transfer_op.successors) # for nvtx class TransferWorkerBase(ABC): + _worker_id_counter = 0 + _worker_id_lock = threading.Lock() + def __init__(self, worker_id: int, transfer_conn: Connection, # receive end of pipe @@ -76,6 +79,13 @@ def __init__(self, self.transfer_conn = transfer_conn # receive end of pipe self.finished_ops_queue: MPQueue[int] = finished_ops_queue + @classmethod + def _get_worker_id(cls) -> 
int: + with cls._worker_id_lock: + worker_id = cls._worker_id_counter + cls._worker_id_counter += 1 + return worker_id + def _get_layer_ptrs(self, layer_blocks: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: if isinstance(layer_blocks, torch.Tensor): layer_blocks = [layer_blocks] @@ -90,10 +100,11 @@ def _get_layer_ptrs(self, layer_blocks: Union[List[torch.Tensor], torch.Tensor]) return layer_ptrs @classmethod - def create_worker(cls, worker_id: int, finished_ops_queue: MPQueue, *args: Any, **kwargs: Any) -> 'WorkerHandle': + def create_worker(cls, finished_ops_queue: MPQueue, *args: Any, **kwargs: Any) -> 'WorkerHandle': """Generic worker creation template method""" parent_conn, child_conn = MPPipe() # create pipe ready_event = mp.Event() + worker_id = cls._get_worker_id() process = mp.Process( target=cls._worker_process, @@ -180,7 +191,7 @@ class WorkerHandle: """handle for worker process""" def __init__(self, worker_id: int, transfer_conn: Connection, process: mp.Process, ready_event: Any): self.worker_id = worker_id - self.transfer_conn = transfer_conn # send end of pipe + self.transfer_conn = transfer_conn self.process = process self.ready_event = ready_event diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py new file mode 100644 index 0000000000..d6e5be00f6 --- /dev/null +++ b/flexkv/transfer_manager.py @@ -0,0 +1,344 @@ +import multiprocessing as mp +import time +import queue +from queue import Queue +from typing import Dict, Optional, List, Tuple +from abc import ABC, abstractmethod +from multiprocessing import Process, Pipe, Event +import zmq +import tempfile +import threading +import numpy as np + +from flexkv.common.transfer import TransferOpGraph +from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.debug import flexkv_logger +from flexkv.common.memory_handle import TensorSharedHandle +from flexkv.common.transfer import DeviceType +from flexkv.common.storage import KVCacheLayout +from flexkv.storage.storage_engine import StorageEngine +from flexkv.transfer.transfer_engine import TransferEngine +from flexkv.server.utils import get_zmq_socket +from flexkv.server.request import RegisterTPClientRequest, Response + + +class TransferManager: + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str): + self.model_config = model_config + self.cache_config = cache_config + self.gpu_register_port = gpu_register_port + + self.gpu_layout: Optional[KVCacheLayout] = None + self.all_gpu_blocks: Dict[int, List[TensorSharedHandle]] = {} # device_id -> gpu_blocks + + self.context = zmq.Context(2) + self.recv_from_client = get_zmq_socket( + self.context, zmq.SocketType.PULL, gpu_register_port, True) + self.client_dict: Dict[int, zmq.Socket] = {} + + self.transfer_engine: Optional[TransferEngine] = None + self.storage_engine = StorageEngine(self.model_config, self.cache_config) + + def _handle_gpu_blocks_registration(self, req: RegisterTPClientRequest) -> None: + device_id = req.device_id + + if device_id in self.all_gpu_blocks: + flexkv_logger.error(f"GPU {device_id} has already registered.") + response = Response(req.dp_client_id, success=False, + error_msg=f"GPU {device_id} already registered") + elif device_id >= self.model_config.tp_size * self.model_config.dp_size: + flexkv_logger.error(f"GPU {device_id} is larger than TP size: " + f"{self.model_config.tp_size * self.model_config.dp_size}.") + response = Response(req.dp_client_id, success=False, + error_msg=f"GPU {device_id} exceeds TP size " 
+ f"{self.model_config.tp_size * self.model_config.dp_size}") + else: + try: + response = Response(req.dp_client_id) + send_to_client = get_zmq_socket( + self.context, zmq.SocketType.PUSH, req.client_recv_port, False) + send_to_client.send_pyobj(response) + self.client_dict[device_id] = send_to_client + + self.all_gpu_blocks[device_id] = req.handles + if self.gpu_layout is None: + self.gpu_layout = req.gpu_layout + elif self.gpu_layout != req.gpu_layout: + flexkv_logger.error(f"GPU {device_id} has different GPU layout: " + f"{self.gpu_layout} != {req.gpu_layout}") + raise ValueError(f"GPU {device_id} has different GPU layout: " + f"{self.gpu_layout} != {req.gpu_layout}") + flexkv_logger.info(f"GPU {device_id} registered successfully") + except Exception as e: + flexkv_logger.error(f"Failed to register GPU {device_id}: {e}") + response = Response(req.dp_client_id, success=False, + error_msg=f"Failed to register GPU {device_id}: {e}") + + if device_id in self.client_dict: + self.client_dict[device_id].send_pyobj(response) + + def _register_gpu_blocks_via_socket(self) -> None: + try: + flexkv_logger.info(f"GPU tensor registration server started on port {self.gpu_register_port}") + + expected_gpus = self.model_config.tp_size * self.model_config.dp_size + + while len(self.all_gpu_blocks) < expected_gpus: + try: + req = self.recv_from_client.recv_pyobj(zmq.NOBLOCK) + except zmq.Again: + time.sleep(0.001) + continue + + if isinstance(req, RegisterTPClientRequest): + flexkv_logger.info(f"Received GPU blocks registration request: {type(req)}") + self._handle_gpu_blocks_registration(req) + else: + flexkv_logger.error(f"Unrecognized RequestType in SchedulerServer: {type(req)}") + + flexkv_logger.info(f"All {expected_gpus} GPUs registered successfully") + + except Exception as e: + flexkv_logger.error(f"Error in GPU registration server: {e}") + raise + finally: + pass + # TODO: fix the socket close issue + # self.recv_from_client.close() + # self.context.term() + + def initialize_transfer_engine(self) -> None: + self._register_gpu_blocks_via_socket() + + assert self.gpu_layout is not None + assert len(self.all_gpu_blocks) == self.model_config.tp_size * self.model_config.dp_size + for device_id, gpu_blocks_wrapper in self.all_gpu_blocks.items(): + self.storage_engine.register_gpu_blocks(gpu_blocks_wrapper, + self.gpu_layout, + device_id, + dtype=self.model_config.dtype) + self.gpu_handles = [ + self.storage_engine.get_storage_handle(DeviceType.GPU, i) + for i in range(self.model_config.tp_size * self.model_config.dp_size) + ] + cpu_handle = self.storage_engine.get_storage_handle(DeviceType.CPU) \ + if self.cache_config.enable_cpu else None + ssd_handle = self.storage_engine.get_storage_handle(DeviceType.SSD) \ + if self.cache_config.enable_ssd else None + remote_handle = ( + self.storage_engine.get_storage_handle(DeviceType.REMOTE) \ + if self.cache_config.enable_remote \ + else None + ) + self.transfer_engine = TransferEngine(gpu_handles=self.gpu_handles, + model_config=self.model_config, + cache_config=self.cache_config, + cpu_handle=cpu_handle, + ssd_handle=ssd_handle, + remote_handle=remote_handle) + + def submit(self, transfer_graph: TransferOpGraph) -> None: + self.transfer_engine.submit_transfer_graph(transfer_graph) + + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + return self.transfer_engine.get_completed_graphs_and_ops(timeout) + + def start(self) -> None: + self.transfer_engine.start() + + def shutdown(self) -> None: + self.transfer_engine.shutdown() + + +class 
TransferManagerHandleBase(ABC): + @abstractmethod + def start(self) -> None: + pass + + @abstractmethod + def is_ready(self) -> bool: + pass + + @abstractmethod + def submit(self, transfer_graph: TransferOpGraph) -> None: + pass + + @abstractmethod + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def shutdown(self) -> None: + pass + + +class TransferManagerIntraProcessHandle(TransferManagerHandleBase): + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str): + self.transfer_manager = TransferManager(model_config, cache_config, gpu_register_port) + self._is_ready = False + + def start(self) -> None: + self.transfer_manager.initialize_transfer_engine() + self.transfer_manager.start() + self._is_ready = True + + def is_ready(self) -> bool: + return self._is_ready + + def submit(self, transfer_graph: TransferOpGraph) -> None: + self.transfer_manager.submit(transfer_graph) + + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + return self.transfer_manager.wait(timeout) + + def shutdown(self) -> None: + self.transfer_manager.shutdown() + + +class TransferManagerInterProcessHandle(TransferManagerHandleBase): + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str): + mp.set_start_method('spawn', force=True) + + self.model_config = model_config + self.cache_config = cache_config + self.gpu_register_port = gpu_register_port + + self.command_parent_conn, self.command_child_conn = Pipe() + self.result_parent_conn, self.result_child_conn = Pipe() + + self.process: Optional[Process] = None + self.ready_event = Event() + + self._completed_results: List[Tuple[int, int]] = [] + + def _start_process(self) -> None: + if self.process is not None and self.process.is_alive(): + return + + self.process = Process( + target=self._process_worker, + args=(self.model_config, + self.cache_config, + self.command_child_conn, + self.result_child_conn, + self.gpu_register_port, + self.ready_event), + daemon=False + ) + self.process.start() + + def _process_worker(self, + model_config: ModelConfig, + cache_config: CacheConfig, + command_conn, + result_conn, + gpu_register_port: str, + ready_event) -> None: + try: + transfer_manager = TransferManager(model_config, cache_config, gpu_register_port) + transfer_manager.initialize_transfer_engine() + transfer_manager.start() + ready_event.set() + while True: + try: + if command_conn.poll(timeout=0.0001): + request = command_conn.recv() + request_type = request.get('type') + if request_type == 'submit': + transfer_manager.submit(request['transfer_graph']) + else: + flexkv_logger.error(f"Unrecognized request type: {request_type}") + try: + finished_ops = transfer_manager.wait(0.0001) + if finished_ops: + result_conn.send(finished_ops) + except queue.Empty: + pass + except Exception as e: + flexkv_logger.error(f"Error in transfer manager process: {e}") + + except Exception as e: + flexkv_logger.error(f"Failed to initialize transfer manager process: {e}") + finally: + command_conn.close() + result_conn.close() + + def start(self) -> None: + self._start_process() + + def is_ready(self) -> bool: + return self.ready_event.is_set() + + def submit(self, transfer_graph: TransferOpGraph) -> None: + self.command_parent_conn.send({ + 'type': 'submit', + 'transfer_graph': transfer_graph + }) + + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + finished_ops: List[Tuple[int, int]] = [] + try: 
+ if self.result_parent_conn.poll(timeout=timeout): + finished_ops += self.result_parent_conn.recv() + while self.result_parent_conn.poll(): + finished_ops += self.result_parent_conn.recv() + except EOFError: + pass + + return finished_ops + + def shutdown(self) -> None: + if self.process is not None: + self.process.terminate() + self.process.join(timeout=5.0) + if self.process.is_alive(): + self.process.kill() + self.process.join() + + self.command_parent_conn.close() + self.result_parent_conn.close() + + def __del__(self): + self.shutdown() + + +class TransferManagerHandle: + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + use_separate_process: bool = True, + gpu_register_port: Optional[str] = None): + if gpu_register_port is None: + gpu_register_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" + if use_separate_process: + self._handle: TransferManagerHandleBase = TransferManagerInterProcessHandle( + model_config, cache_config, gpu_register_port + ) + else: + self._handle: TransferManagerHandleBase = TransferManagerIntraProcessHandle( + model_config, cache_config, gpu_register_port + ) + + def start(self) -> None: + self._handle.start() + + def is_ready(self) -> bool: + return self._handle.is_ready() + + def submit(self, transfer_graph: TransferOpGraph) -> None: + self._handle.submit(transfer_graph) + + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + return self._handle.wait(timeout) + + def shutdown(self) -> None: + self._handle.shutdown() diff --git a/tests/replay_from_tracer.py b/tests/replay_from_tracer.py index c9de75b720..3ddc0ce810 100644 --- a/tests/replay_from_tracer.py +++ b/tests/replay_from_tracer.py @@ -24,7 +24,7 @@ from flexkv.common.config import CacheConfig, ModelConfig from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType from flexkv.common.memory_handle import TensorSharedHandle -from flexkv.kvmanager import KVManager +from flexkv.kvtask import KVTaskEngine class FlexKVReplayEngine: @@ -204,7 +204,7 @@ def create_kvmanager(self,): ) # Create KVManager - self.kvmanager = KVManager( + self.kvmanager = KVTaskEngine( model_config=self.model_config, cache_config=self.cache_config, gpu_layout=self.gpu_layout, @@ -274,10 +274,6 @@ def replay_wait_event(self, event: Dict[str, Any]): result = self.kvmanager.wait_for_graph_finished(task_ids) elif wait_type == "try_wait": result = self.kvmanager.try_wait(task_ids) - elif wait_type == "wait_at_layer_group": - result = self.kvmanager.wait_at_layer_group(task_ids[0], layer_group_id) - elif wait_type == "try_wait_at_layer_group": - result = self.kvmanager.try_wait_at_layer_group(task_ids, layer_group_id) else: raise ValueError(f"Unknown wait type: {wait_type}") successed_elements = [] diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index c01bee7aaa..f105f95857 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -4,23 +4,72 @@ import pytest import torch +import multiprocessing as mp +from multiprocessing import Process, Pipe from flexkv.common.config import ModelConfig, CacheConfig from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.request import KVResponseStatus +from flexkv.kvtask import KVTaskEngine from flexkv.kvmanager import KVManager +from flexkv.common.memory_handle import TensorSharedHandle +from flexkv.server.client import KVTPClient +from flexkv.common.debug import flexkv_logger # Import utilities from test_utils from test_utils import ( DEFAULT_MODEL_CONFIG, 
DEFAULT_CACHE_CONFIG, DEFAULT_TEST_CONFIG, generate_request_pair, verify_data, block_ids_2_slot_mapping, generate_gpu_blocks_with_ground_truth, skip_if_insufficient_gpus, - create_kvmanager_with_mode, create_gpu_kv_layout + create_gpu_kv_layout, GPUKVCacheVerifier ) +def run_tp_client(dp_client_id, tp_rank, server_recv_port, model_config, cache_config, num_gpu_blocks, child_conn): + """Run tp_client process""" + try: + device_id = tp_rank + dp_client_id * model_config.tp_size + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id, tp_rank) + + gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks) + + # Create GPU blocks for this tp_rank in the tp_client process + gpu_blocks_for_tp = [] + for _ in range(model_config.num_layers): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) + ) + tp_client.register_to_server(gpu_blocks_for_tp, gpu_kv_layout) + + # Send GPU blocks back to main process via pipe if connection provided + if child_conn is not None: + print(f"[TP Client {tp_rank}] Converting {len(gpu_blocks_for_tp)} GPU blocks to TensorSharedHandle") + shared_gpu_blocks = [TensorSharedHandle(tensor) for tensor in gpu_blocks_for_tp] + child_conn.send(shared_gpu_blocks) + print(f"[TP Client {tp_rank}] Sent GPU blocks to main process via pipe") + child_conn.close() + + # Keep the process running + while True: + time.sleep(1) + except Exception as e: + print(f"[TP Client {tp_rank}] Error occurred: {e}") + if child_conn is not None: + child_conn.send(None) + child_conn.close() + +def shutdown_tp_client(tp_client_processes): + for tp_process in tp_client_processes: + if tp_process.is_alive(): + tp_process.terminate() + tp_process.join(timeout=5) + if tp_process.is_alive(): + print(f"Force killing tp_client process {tp_process.pid}") + tp_process.kill() + tp_process.join(timeout=2) @pytest.mark.parametrize("model_config", [ {'tp_size': 1, 'dp_size': 1}, - {'tp_size': 2, 'dp_size': 2}, + {'tp_size': 2, 'dp_size': 2}, {'dtype': torch.float32}, {'use_mla': True}, {'tp_size': 4, 'dp_size': 1, 'use_mla': True}, @@ -34,20 +83,15 @@ 'num_ssd_blocks': 256, 'num_remote_blocks': 512, 'ssd_cache_iouring_entries': 512}, ], indirect=True) @pytest.mark.parametrize("test_config", [ - {'num_gpu_blocks': 512, 'requests_per_block': 16, 'initial_write_ratio': 0.4, 'use_server_client': False}, - {'num_gpu_blocks': 512, 'requests_per_block': 16, 'initial_write_ratio': 0.4, 'use_server_client': True}, + {'num_gpu_blocks': 512, 'requests_per_block': 16, 'initial_write_ratio': 0.4}, ], indirect=True) @pytest.mark.parametrize("flex_kv_layout_type", [ KVCacheLayoutType.LAYERWISE, KVCacheLayoutType.BLOCKWISE, ]) def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type): - num_layers = model_config.num_layers - num_kv_heads = model_config.num_kv_heads - head_size = model_config.head_size tp_size = model_config.tp_size dp_size = model_config.dp_size - use_mla = model_config.use_mla tokens_per_block = cache_config.tokens_per_block num_cpu_blocks = cache_config.num_cpu_blocks @@ -64,7 +108,6 @@ def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type) num_gpu_blocks = test_config["num_gpu_blocks"] block_per_request = test_config['requests_per_block'] initial_write_ratio = test_config['initial_write_ratio'] - use_server_client = test_config.get('use_server_client', False) num_requests = num_gpu_blocks // block_per_request @@ -74,50 +117,97 @@ def test_kvmanager(model_config, 
cache_config, test_config, flex_kv_layout_type) if enable_remote: pytest.skip("skip because enable_remote is not supported") - if use_server_client and dp_size > 1: - pytest.skip("skip because server-client mode is not supported for dp_size > 1 IN THIS TEST SCRIPT now") + if dp_size > 1: + #note that for now only dp_size=1 is supported + pytest.skip("skip because server-client mode is not ready for dp_size > 1") + + import uuid + gpu_register_port = f"ipc:///tmp/flexkv_gpu_{uuid.uuid4().hex[:8]}" + server_recv_port = f"ipc:///tmp/flexkv_srv_{uuid.uuid4().hex[:8]}" + kvmanager = KVManager(model_config, cache_config, gpu_register_port, server_recv_port) + kvmanager.start() + + # Create pipes for each tp_client to send GPU blocks back + pipe_connections = [] + tp_client_processes = [] + + for tp_rank in range(tp_size): + parent_conn, child_conn = Pipe() + pipe_connections.append(parent_conn) - if use_server_client: - # In server-client mode, GPU blocks are created in tp_client processes - # We only need the layout for initialization + tp_client_process = Process( + target=run_tp_client, + args=(0, tp_rank, gpu_register_port, model_config, cache_config, num_gpu_blocks, child_conn), + daemon=True + ) + tp_client_processes.append(tp_client_process) + tp_client_process.start() + + # Collect GPU blocks from all tp_client processes + print(f"[Main Process] Waiting to receive GPU blocks from {tp_size} TP client processes...") + all_gpu_blocks = [] + + for tp_rank, parent_conn in enumerate(pipe_connections): + try: + shared_gpu_blocks = parent_conn.recv() + if shared_gpu_blocks is not None: + all_gpu_blocks.append(shared_gpu_blocks) + print(f"[Main Process] Received GPU blocks from TP client {tp_rank}") + else: + print(f"[Main Process] TP client {tp_rank} failed to create GPU blocks") + parent_conn.close() + except Exception as e: + print(f"[Main Process] Error receiving from TP client {tp_rank}: {e}") + + # Create GPUKVCacheVerifier with collected GPU blocks + if all_gpu_blocks and len(all_gpu_blocks) == tp_size: + print(f"[Main Process] Creating GPUKVCacheVerifier with GPU blocks from {len(all_gpu_blocks)} TP clients") + + # Get gpu_kv_layout from cache_config for GPUKVCacheVerifier gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks) - gpu_blocks = None # Not used in server-client mode - dp_wise_gpu_blocks_gt = None # Not used in server-client mode + + gpu_kv_verifier = GPUKVCacheVerifier( + shared_gpu_blocks=all_gpu_blocks, + gpu_kv_layout=gpu_kv_layout, + tp_size=model_config.tp_size, + tokens_per_block=cache_config.tokens_per_block, + dtype=model_config.dtype + ) + print("[Main Process] GPUKVCacheVerifier created successfully") else: - # In direct mode, create GPU blocks in current process - gpu_blocks, dp_wise_gpu_blocks_gt, gpu_kv_layout = \ - generate_gpu_blocks_with_ground_truth(model_config, cache_config, test_config) + print(f"[Main Process] Failed to collect GPU blocks from all TP clients. 
" + f"Got {len(all_gpu_blocks)} out of {tp_size}") + gpu_kv_verifier = None - kvmanager = create_kvmanager_with_mode(model_config, cache_config, gpu_kv_layout, gpu_blocks, use_server_client) + while not kvmanager.is_ready(): + time.sleep(1) + flexkv_logger.info("waiting for flexkv to be ready") - # put this after KVManager() num_remote_blocks = cache_config.num_remote_blocks - assert kvmanager.is_ready() - kvmanager.start() request_pairs = [generate_request_pair(i, block_per_request, num_gpu_blocks, tokens_per_block, dp_size) for i in range(num_requests)] initial_write_num = int(num_requests * initial_write_ratio) print("writing initial data...") + put_ids = [] for token_ids, block_ids, dp_id in request_pairs[:initial_write_num]: + if gpu_kv_verifier is not None: + gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) write_request = kvmanager.put_async( token_ids=token_ids, slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), + token_mask=None, dp_id=dp_id, ) - kvmanager.wait_for_graph_finished(write_request) - if not use_server_client: - # In direct mode, update GPU blocks for verification - for gpu in range(dp_id * tp_size, (dp_id + 1) * tp_size): - for i in range(num_layers): - gpu_blocks[gpu][i][:, block_ids, :, :, :] = 0 + kvmanager.wait([write_request], completely=True) #corner case: input token length for put is less than tokens_per_block write_request = kvmanager.put_async( token_ids=torch.randint(0, 100, size=(8,), dtype=torch.int64), slot_mapping=block_ids_2_slot_mapping(torch.arange(0,1, dtype=torch.int64), tokens_per_block, actual_length=8), + token_mask=None, dp_id=0, ) - kvmanager.wait_for_graph_finished(write_request) + kvmanager.wait([write_request], completely=True) #corner case: input token length is long enough, but the mask is less than tokens_per_block #my_mask = torch.zeros(16, dtype=torch.bool) #my_mask[0:8] = True @@ -134,44 +224,78 @@ def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type) total_cache_miss = 0 running_get_requests = [] running_put_requests = [] + req_id2block_ids = {} + req_id2token_ids = {} + flexkv_id2req_id = {} start_time = time.time() print(f"the initial {initial_write_num} write done,performing mixed read/write...") for i in range(initial_write_num, num_requests): print(f"performing mixed read/write {i} / {num_requests} ...") read_idx = i - initial_write_num token_ids, block_ids, dp_id = request_pairs[read_idx] - request_id = kvmanager.get_async( + slot_mapping = block_ids_2_slot_mapping(block_ids, tokens_per_block) + print(f"token_ids: {token_ids}, block_ids: {block_ids}, dp_id: {dp_id}, slot_mapping: {slot_mapping}") + request_id, _ = kvmanager.get_match( token_ids=token_ids, - slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), layer_granularity=-1, + token_mask=None, dp_id=dp_id, ) + kvmanager.launch(request_id, slot_mapping) + flexkv_id2req_id[request_id] = read_idx running_get_requests.append(request_id) + req_id2block_ids[request_id] = block_ids + req_id2token_ids[request_id] = token_ids token_ids, block_ids, dp_id = request_pairs[i] + if gpu_kv_verifier is not None: + gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) request_id = kvmanager.put_async( token_ids=token_ids, slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), + token_mask=None, dp_id=dp_id, ) + flexkv_id2req_id[request_id] = i + print(f"write flexkv request_id {request_id} to req_id {i}") running_put_requests.append(request_id) min_block_num = min(num_cpu_blocks, num_gpu_blocks) if 
(len(running_get_requests) + len(running_put_requests) >= min_block_num // block_per_request - 2 or i % initial_write_num == initial_write_num - 1 or i == num_requests - 1): if len(running_put_requests) > 0: - kvmanager.wait_for_graph_finished(running_put_requests) + kvmanager.wait(running_put_requests, completely=True) if len(running_get_requests) > 0: - return_masks = kvmanager.wait(running_get_requests) - for return_mask in return_masks.values(): - total_cache_hit += return_mask.sum() - total_cache_miss += len(return_mask) - return_mask.sum() + return_results = kvmanager.wait(running_get_requests, completely=True) + if gpu_kv_verifier is not None: + for req_id, kvresponse in return_results.items(): + assert kvresponse.status == KVResponseStatus.SUCCESS + valid_fetched_tokens = kvresponse.return_mask.sum().item() // \ + tokens_per_block * tokens_per_block + token_ids = req_id2token_ids[req_id] + block_ids = req_id2block_ids[req_id] + assert gpu_kv_verifier.verify_kv_blocks( + token_ids[:valid_fetched_tokens], + block_ids[:valid_fetched_tokens//tokens_per_block]) + for kvresponse in return_results.values(): + assert kvresponse.status == KVResponseStatus.SUCCESS + total_cache_hit += kvresponse.return_mask.sum().item() + total_cache_miss += len(kvresponse.return_mask) - kvresponse.return_mask.sum().item() running_get_requests = [] running_put_requests = [] if len(running_get_requests) > 0: - kvmanager.wait(running_get_requests) + return_results = kvmanager.wait(running_get_requests, completely=True) + if gpu_kv_verifier is not None: + for req_id, kvresponse in return_results.items(): + assert kvresponse.status == KVResponseStatus.SUCCESS + valid_fetched_tokens = kvresponse.return_mask.sum().item() // tokens_per_block * tokens_per_block + token_ids = req_id2token_ids[req_id] + block_ids = req_id2block_ids[req_id] + assert gpu_kv_verifier.verify_kv_blocks( + token_ids[:valid_fetched_tokens], + block_ids[:valid_fetched_tokens//tokens_per_block]) running_get_requests = [] if len(running_put_requests) > 0: - kvmanager.wait_for_graph_finished(running_put_requests) + kvmanager.wait(running_put_requests, completely=True) running_put_requests = [] print("mixed read/write done") end_time = time.time() @@ -182,12 +306,12 @@ def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type) enable_ssd and num_ssd_blocks >= num_gpu_blocks or \ enable_remote and num_remote_blocks >= num_gpu_blocks: assert total_cache_miss == 0 + shutdown_tp_client(tp_client_processes) kvmanager.shutdown() - if total_cache_miss == 0 and not use_server_client: - # Only verify data in direct mode - verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, num_kv_heads, tp_size, dp_size, num_layers, use_mla) + # Only verify data in direct mode + # verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, num_kv_heads, tp_size, dp_size, num_layers, use_mla) + if total_cache_miss == 0: + return elif total_cache_miss > 0: - print(f"verify skipped, because of total_cache_miss={total_cache_miss}>0") - elif use_server_client: - print("verify skipped in server-client mode (verification happens in tp_client processes)") + print(f"verify skipped, because of total_cache_miss={total_cache_miss} > 0") diff --git a/tests/test_utils.py b/tests/test_utils.py index a7e2f166c5..2530f0291a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,14 +1,16 @@ import time import os import shutil -from typing import List, Dict, Tuple -from multiprocessing import Process - +from typing import List, Dict, Tuple, Optional, Union +from 
multiprocessing import Process, Pipe, Queue +import pickle +import multiprocessing as mp import pytest import torch from flexkv.common.config import ModelConfig, CacheConfig from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.memory_handle import TensorSharedHandle # Default configurations @@ -292,7 +294,7 @@ def __init__(self, model_config, cache_config, gpu_kv_layout, gpu_blocks): tp_client_process = Process( target=KVManagerServerClient._run_tp_client, args=(self.dp_client.dp_client_id, tp_rank, device_id, self.server_recv_port, - model_config.num_layers, str(model_config.dtype), + model_config.num_layers, str(model_config.dtype), list(gpu_kv_layout.kv_shape[1:]), model_config.use_mla), daemon=True ) @@ -449,11 +451,310 @@ def shutdown(self): print("KVManagerServerClient shutdown complete") -def create_kvmanager_with_mode(model_config, cache_config, gpu_kv_layout, gpu_blocks, use_server_client=False): - """Create KVManager with optional server-client mode""" - if use_server_client: - print("Using server-client mode") - return KVManagerServerClient(model_config, cache_config, gpu_kv_layout, gpu_blocks) - else: - from flexkv.kvmanager import KVManager - return KVManager(model_config, cache_config, gpu_kv_layout, gpu_blocks) +class GPUKVCacheVerifier: + def __init__(self, + shared_gpu_blocks: Union[List[torch.Tensor], List[TensorSharedHandle], List[List[TensorSharedHandle]]], + gpu_kv_layout: KVCacheLayout, + tp_size: int, + tokens_per_block: int, + dtype: torch.dtype)->None: + self.gpu_kv_layout = gpu_kv_layout + self.num_layers = gpu_kv_layout.num_layer + # we have to map the exported gpu blocks into the virtual space of current process + if isinstance(shared_gpu_blocks[0], torch.Tensor): + self.gpu_blocks = shared_gpu_blocks + elif isinstance(shared_gpu_blocks[0], TensorSharedHandle): + self.gpu_blocks = [wrapper.get_tensor() for wrapper in shared_gpu_blocks] + else: + imported_gpu_blocks = [] + for handles_in_one_gpu in shared_gpu_blocks: + blocks_in_one_gpu = [] + for handle in handles_in_one_gpu: + blocks_in_one_gpu.append(handle.get_tensor()) + imported_gpu_blocks.append(blocks_in_one_gpu) + self.gpu_blocks = imported_gpu_blocks + self.gpu_block_num = gpu_kv_layout.num_block + self.tp_size = tp_size + self.is_mla = gpu_kv_layout.is_mla + self.tokens_per_block = tokens_per_block + self.dtype = dtype + + + def hash_all_values(self, layer_id, kv_id, token_ids, head_id): + base_hash = hash((layer_id, kv_id, head_id)) + + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.tolist() + + token_hash = 0 + prime = 31 + for i, token_id in enumerate(token_ids): + token_hash += (token_id * (prime ** i)) % (2**31 - 1) + + combined_hash = (base_hash + token_hash) % (2**31 - 1) + + normalized_value = (combined_hash % 1000000) / 1000000.0 + + return torch.tensor(normalized_value, dtype=self.dtype).item() + + def fill_gpu_blocks(self, token_ids, block_ids): + assert len(token_ids) == len(block_ids) * self.tokens_per_block + + # Ensure token_ids is in tensor format + if not isinstance(token_ids, torch.Tensor): + token_ids = torch.tensor(token_ids, dtype=torch.int64) + if not isinstance(block_ids, torch.Tensor): + block_ids = torch.tensor(block_ids, dtype=torch.int64) + + for layer_id in range(self.num_layers): + kv_num = 2 if not self.is_mla else 1 + for kv_id in range(kv_num): + for tp_id in range(self.tp_size): + if isinstance(self.gpu_blocks[0], list): + # multiple gpu:gpu_blocks[tp_id][layer_id] + gpu_tensor = self.gpu_blocks[tp_id][layer_id] + else: 
+ # single gpu:gpu_blocks[layer_id] + gpu_tensor = self.gpu_blocks[layer_id] + + for head_id in range(self.gpu_kv_layout.num_head): + actual_head_id = tp_id * self.gpu_kv_layout.num_head + head_id if not self.is_mla else head_id + + for block_idx, block_id in enumerate(block_ids): + start_token_idx = block_idx * self.tokens_per_block + end_token_idx = start_token_idx + self.tokens_per_block + hash_value = self.hash_all_values(layer_id, + kv_id, + token_ids[start_token_idx:end_token_idx], + actual_head_id) + # GPU tensor dim:[kv_dim, num_block, tokens_per_block, num_head, head_size] + gpu_tensor[kv_id, block_id, :, head_id, :] = hash_value + + def verify_kv_blocks(self, token_ids, block_ids)->bool: + assert len(token_ids) == len(block_ids) * self.tokens_per_block + + if not isinstance(token_ids, torch.Tensor): + token_ids = torch.tensor(token_ids, dtype=torch.int64) + if not isinstance(block_ids, torch.Tensor): + block_ids = torch.tensor(block_ids, dtype=torch.int64) + + verification_passed = True + errors = [] + + for layer_id in range(self.num_layers): + kv_num = 2 if not self.is_mla else 1 + for kv_id in range(kv_num): + for tp_id in range(self.tp_size): + if isinstance(self.gpu_blocks[0], list): + gpu_tensor = self.gpu_blocks[tp_id][layer_id] + else: + gpu_tensor = self.gpu_blocks[layer_id] + + for head_id in range(self.gpu_kv_layout.num_head): + actual_head_id = tp_id * self.gpu_kv_layout.num_head + head_id if not self.is_mla else head_id + for block_idx, block_id in enumerate(block_ids): + start_token_idx = block_idx * self.tokens_per_block + end_token_idx = start_token_idx + self.tokens_per_block + expected_hash_value = self.hash_all_values(layer_id, kv_id, + token_ids[start_token_idx:end_token_idx], + actual_head_id) + + actual_values = gpu_tensor[kv_id, block_id, :, head_id, :] + + if not torch.allclose(actual_values, + torch.full_like(actual_values, expected_hash_value), + rtol=1e-5, atol=1e-6): + verification_passed = False + errors.append( + f"Mismatch at layer={layer_id}, kv={kv_id}, tp={tp_id}, " + f"head={head_id}, block={block_id}: " + f"expected={expected_hash_value}, got={actual_values[0, 0].item()}" + ) + + if not verification_passed: + print(f"Verification failed with {len(errors)} errors:") + for error in errors[:10]: + print(f" {error}") + if len(errors) > 10: + print(f" ... 
and {len(errors) - 10} more errors") + else: + print("KV blocks verification passed!") + + return verification_passed + + +def gpu_blocks_worker_process(conn, model_config, cache_config, gpu_kv_layout): + try: + print(f"[Worker Process {os.getpid()}] Starting to create GPU blocks...") + + # Create GPU blocks in subprocess + gpu_blocks = [] + for layer_id in range(model_config.num_layers): + # LAYERWISE format: [kv_dim, num_block, tokens_per_block, num_head, head_size] + kv_dim = 2 if not model_config.use_mla else 1 + gpu_tensor = torch.zeros( + kv_dim, + gpu_kv_layout.num_block, + gpu_kv_layout.tokens_per_block, + gpu_kv_layout.num_head, + gpu_kv_layout.head_size, + dtype=model_config.dtype, + device='cuda:0' if torch.cuda.is_available() else 'cpu' + ) + gpu_blocks.append(gpu_tensor) + + print(f"[Worker Process {os.getpid()}] Successfully created {len(gpu_blocks)} GPU blocks") + + # Convert to TensorSharedHandle + shared_gpu_blocks = [TensorSharedHandle(tensor) for tensor in gpu_blocks] + print(f"[Worker Process {os.getpid()}] Successfully converted to {len(shared_gpu_blocks)} TensorSharedHandles") + + # Send to main process via pipe + conn.send(shared_gpu_blocks) + print(f"[Worker Process {os.getpid()}] Sent TensorSharedHandle list to main process via pipe") + + #while True: + # time.sleep(1) + conn.close() + + except Exception as e: + print(f"[Worker Process {os.getpid()}] Error occurred: {e}") + conn.send(None) + conn.close() + + +# Usage examples +def example_usage_gpu_kv_cache_verifier(): + """Demonstrates three ways to initialize GPUKVCacheVerifier""" + import torch + from flexkv.common.config import ModelConfig, CacheConfig + from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType + from flexkv.common.memory_handle import TensorSharedHandle + + # Create example configurations + model_config = ModelConfig( + num_layers=2, + num_kv_heads=8, + head_size=64, + use_mla=False, + dtype=torch.float16, + tp_size=1, + dp_size=1 + ) + + cache_config = CacheConfig( + tokens_per_block=16 + ) + + # Create GPU KV layout + gpu_kv_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERWISE, + num_layer=model_config.num_layers, + num_block=64, # Assume 64 blocks + tokens_per_block=cache_config.tokens_per_block, + num_head=model_config.num_kv_heads, + head_size=model_config.head_size, + is_mla=model_config.use_mla + ) + + # Create mock GPU blocks + gpu_blocks = [] + for layer_id in range(model_config.num_layers): + # LAYERWISE format: [kv_dim, num_block, tokens_per_block, num_head, head_size] + kv_dim = 2 if not model_config.use_mla else 1 + gpu_tensor = torch.zeros( + kv_dim, + gpu_kv_layout.num_block, + gpu_kv_layout.tokens_per_block, + gpu_kv_layout.num_head, + gpu_kv_layout.head_size, + dtype=model_config.dtype, + device='cuda:0' if torch.cuda.is_available() else 'cpu' + ) + gpu_blocks.append(gpu_tensor) + + print("=== Method 1: Direct Tensor List ===") + verifier1 = GPUKVCacheVerifier( + shared_gpu_blocks=gpu_blocks, # Pass tensor list directly + gpu_kv_layout=gpu_kv_layout, + tp_size=model_config.tp_size, + tokens_per_block=cache_config.tokens_per_block, + dtype=model_config.dtype, + ) + + print("=== Method 2: Using TensorSharedHandle (Multi-process version) ===") + mp.set_start_method('spawn') + # Create pipe for inter-process communication + parent_conn, child_conn = Pipe() + print(f"[Main Process {os.getpid()}] Successfully created pipe connection") + + # Start worker process to create GPU blocks and TensorSharedHandle + worker_process = Process( + target=gpu_blocks_worker_process, 
+ args=(child_conn, model_config, cache_config, gpu_kv_layout) + ) + + print(f"[Main Process {os.getpid()}] Starting worker process...") + worker_process.start() + + # Wait to receive TensorSharedHandle created by worker process + print(f"[Main Process {os.getpid()}] Waiting to receive results from worker process...") + shared_gpu_blocks = parent_conn.recv() + + # Wait for worker process to complete + + + if shared_gpu_blocks is None: + raise RuntimeError("Worker process failed to create GPU blocks") + + print(f"[Main Process {os.getpid()}] Successfully received {len(shared_gpu_blocks)} TensorSharedHandles") + verifier2 = GPUKVCacheVerifier( + shared_gpu_blocks=shared_gpu_blocks, + gpu_kv_layout=gpu_kv_layout, + tp_size=model_config.tp_size, + tokens_per_block=cache_config.tokens_per_block, + dtype=model_config.dtype, + ) + + # Prepare test data - Note: now hash is calculated per block + token_ids = torch.randint(0, 1000, (32,), dtype=torch.int64) # 32 tokens (2 blocks) + block_ids = torch.tensor([0, 1], dtype=torch.int64) # Use blocks 0 and 1 + + print(f"Token IDs shape: {token_ids.shape}") + print(f"Block IDs: {block_ids}") + print(f"Tokens per block: {cache_config.tokens_per_block}") + + # Test method 1 + print("\n=== Testing Method 1 (Direct Tensor) ===") + print("Starting to fill GPU blocks...") + verifier1.fill_gpu_blocks(token_ids, block_ids) + print("Filling completed!") + + print("Starting data verification...") + is_valid1 = verifier1.verify_kv_blocks(token_ids, block_ids) + print(f"Verification result: {'PASSED' if is_valid1 else 'FAILED'}") + + # Test method 2 + print("\n=== Testing Method 2 (SharedHandle) ===") + print("Starting to fill GPU blocks...") + verifier2.fill_gpu_blocks(token_ids, block_ids) + print("Filling completed!") + + print("Starting data verification...") + is_valid2 = verifier2.verify_kv_blocks(token_ids, block_ids) + print(f"Verification result: {'PASSED' if is_valid2 else 'FAILED'}") + + # Demonstrate hash calculation changes: now each block has independent hash values + print("\n=== Hash Calculation Demo ===") + for block_idx, block_id in enumerate(block_ids): + start_idx = block_idx * cache_config.tokens_per_block + end_idx = start_idx + cache_config.tokens_per_block + block_tokens = token_ids[start_idx:end_idx] + hash_value = verifier1.hash_all_values(0, 0, block_tokens, 0) + print(f"Block {block_id} tokens: {block_tokens.tolist()[:5]}... 
-> hash: {hash_value:.6f}") + worker_process.join() + parent_conn.close() + return verifier1, token_ids, block_ids + +if __name__ == "__main__": + example_usage_gpu_kv_cache_verifier() From 875a99a47d4b4597bbb658e42a9c61225f6ced98 Mon Sep 17 00:00:00 2001 From: lilgao Date: Fri, 22 Aug 2025 10:45:25 +0800 Subject: [PATCH 05/42] feat: add support release wheel (#77) * feat: add support release wheel Signed-off-by: lilgao * fix copilot review for ci Signed-off-by: lilgao --------- Signed-off-by: lilgao Co-authored-by: lilgao --- .github/workflows/publish.yml | 73 ++++++++++++++++++++ .github/workflows/scripts/cuda-install.sh | 24 +++++++ .github/workflows/scripts/env.sh | 21 ++++++ .github/workflows/scripts/pytorch-install.sh | 16 +++++ build.sh | 15 ++-- setup.py | 5 +- 6 files changed, 145 insertions(+), 9 deletions(-) create mode 100644 .github/workflows/publish.yml create mode 100755 .github/workflows/scripts/cuda-install.sh create mode 100755 .github/workflows/scripts/env.sh create mode 100755 .github/workflows/scripts/pytorch-install.sh diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000000..eb5bc9b75f --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,73 @@ +# This workflow will upload a Python Package to Release asset +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions +# Copied from vLLM github actions https://github.com/vllm-project/vllm/blob/main/.github/workflows/publish.yml +name: flexkv ci + +on: + pull_request: + branches: [ "main", "feat/lilgao/ci"] + push: + branches: [ "main", "feat/lilgao/ci"] + +# Needed to create wheel and upload assets +permissions: + contents: write + +jobs: + build: + name: Build Wheel + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: ['ubuntu-22.04'] + python-version: ['3.10'] + pytorch-version: ['2.6.0'] + cuda-version: ['12.4'] + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Linux Env + if: ${{ runner.os == 'Linux' }} + run: | + bash -x .github/workflows/scripts/env.sh + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + + - name: Install CUDA ${{ matrix.cuda-version }} + run: | + bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} + + - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} + run: | + bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} + + - name: Build wheel + shell: bash + env: + TORCH_CUDA_ARCH_LIST: "8.9 9.0+PTX" + MAX_JOBS: 4 + run: | + ./build.sh --release + + - name: Get Date and Time + run: | + echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + echo "time=$(date +'%H-%M-%S')" >> $GITHUB_ENV + + - name: Upload to cos + uses: shallwefootball/s3-upload-action@master + with: + aws_key_id: ${{ secrets.COS_SECRET_ID }} + aws_secret_access_key: ${{ secrets.COS_SECRET_KEY }} + aws_bucket: ${{ secrets.COS_BUCKET }} + endpoint: ${{ secrets.COS_ENDPOINT }} + source_dir: dist + destination_dir: flexkv/${{ env.date }}/${{ env.time }} diff --git a/.github/workflows/scripts/cuda-install.sh b/.github/workflows/scripts/cuda-install.sh new file mode 100755 index 0000000000..3e4d7c8b7d --- /dev/null +++ b/.github/workflows/scripts/cuda-install.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copied from vLLM 
github actions https://github.com/vllm-project/vllm/blob/main/.github/workflows/scripts/cuda-install.sh + +# Replace '.' with '-' ex: 11.8 -> 11-8 +cuda_version=$(echo "$1" | tr "." "-") +# Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004 +OS=$(echo "$2" | tr -d ".\-") + +# Installs CUDA +wget -nv "https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb" +sudo dpkg -i cuda-keyring_1.1-1_all.deb +rm cuda-keyring_1.1-1_all.deb +sudo apt -qq update +sudo apt -y install "cuda-${cuda_version}" "cuda-nvcc-${cuda_version}" "cuda-libraries-dev-${cuda_version}" +sudo apt clean + +# Test nvcc +PATH=/usr/local/cuda-$1/bin:${PATH} +nvcc --version + +# Log gcc, g++, c++ versions +gcc --version +g++ --version +c++ --version diff --git a/.github/workflows/scripts/env.sh b/.github/workflows/scripts/env.sh new file mode 100755 index 0000000000..299f281236 --- /dev/null +++ b/.github/workflows/scripts/env.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copied from vLLM github actions https://github.com/vllm-project/vllm/blob/main/.github/workflows/scripts/env.sh + +# This file installs common linux environment tools + +export LANG=C.UTF-8 + +sudo apt-get update && \ +sudo apt-get install -y --no-install-recommends \ + software-properties-common + +sudo apt-get install -y --no-install-recommends \ + build-essential \ + liburing-dev \ + git \ + cmake + +# Remove github bloat files to free up disk space +sudo rm -rf "/usr/local/share/boost" +sudo rm -rf "$AGENT_TOOLSDIRECTORY" +sudo rm -rf "/usr/share/dotnet" diff --git a/.github/workflows/scripts/pytorch-install.sh b/.github/workflows/scripts/pytorch-install.sh new file mode 100755 index 0000000000..559043d412 --- /dev/null +++ b/.github/workflows/scripts/pytorch-install.sh @@ -0,0 +1,16 @@ +#!/bin/bash +# Copied from vLLM github actions https://github.com/vllm-project/vllm/blob/main/.github/workflows/scripts/pytorch-install.sh + +python_executable=python$1 +pytorch_version=$2 +cuda_version=$3 + +# Install torch +$python_executable -m pip install numpy ninja cython wheel typing typing-extensions dataclasses setuptools && conda clean -ya +$python_executable -m pip install torch=="${pytorch_version}+cu${cuda_version//./}" --extra-index-url "https://download.pytorch.org/whl/cu${cuda_version//./}" + +# Print version information +$python_executable --version +$python_executable -c "import torch; print('PyTorch:', torch.__version__)" +$python_executable -c "import torch; print('CUDA:', torch.version.cuda)" +$python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" diff --git a/build.sh b/build.sh index 8b34f27df4..247dd62405 100755 --- a/build.sh +++ b/build.sh @@ -40,13 +40,6 @@ echo "=== Setting BUILD_LIB_PATH to $BUILD_LIB_PATH ===" cd .. -echo "=== Installing package with pip ===" -if [ "$BUILD_TYPE" = "debug" ]; then - FLEXKV_DEBUG=1 pip install --no-build-isolation -e . -else - FLEXKV_DEBUG=0 pip install --no-build-isolation -e . -fi - # Set LD_LIBRARY_PATH for immediate use export LD_LIBRARY_PATH=$BUILD_LIB_PATH:$LD_LIBRARY_PATH echo "Added $BUILD_LIB_PATH to LD_LIBRARY_PATH for current session" @@ -69,3 +62,11 @@ fi echo "=== Build and installation completed successfully in ${BUILD_TYPE} mode ===" echo "You can now run tests directly without setting LD_LIBRARY_PATH manually" + +if [ "$BUILD_TYPE" = "debug" ]; then + FLEXKV_DEBUG=1 pip install --no-build-isolation -e . 
+elif [ "$BUILD_TYPE" = "release" ]; then + FLEXKV_DEBUG=0 python setup.py bdist_wheel -v +else + FLEXKV_DEBUG=0 pip install --no-build-isolation -e . +fi diff --git a/setup.py b/setup.py index 67997949b2..5926bac9aa 100755 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ from torch.utils import cpp_extension -build_dir = os.path.abspath("build") +build_dir = "build" os.makedirs(build_dir, exist_ok=True) # Check if we're in debug mode using environment variable @@ -130,13 +130,14 @@ def copy_shared_libraries(self): version="0.1.0", packages=find_packages(exclude=("benchmarks", "csrc", "examples", "tests")), package_data={ - "flexkv": ["lib/*.so", "lib/*.so.*"], + "flexkv": ["*.so", "lib/*.so", "lib/*.so.*"], }, include_package_data=True, install_requires=install_requires, ext_modules=ext_modules, # Now contains both C++ and Cython modules as needed cmdclass={ "build_ext": CustomBuildExt.with_options( + include_dirs=os.path.join(build_dir, "include"), # Include directory for xxhash no_python_abi_suffix=True, build_temp=os.path.join(build_dir, "temp"), # Temporary build files ) From 22fc69e7502a22999b4b4a91e7775d652a449bca Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Fri, 22 Aug 2025 15:38:14 +0800 Subject: [PATCH 06/42] update unit tests for new version (#79) * update test cache engine * update test cache engine accel * remove some tests --- flexkv/cache/mempool.py | 2 +- flexkv/cache/transfer_pattern.py | 253 ------------------- tests/conftest.py | 6 - tests/test_cache_engine.py | 50 ++-- tests/test_cache_engine_accel.py | 48 ++-- tests/test_transfer_engine.py | 411 ------------------------------- 6 files changed, 50 insertions(+), 720 deletions(-) delete mode 100644 tests/test_transfer_engine.py diff --git a/flexkv/cache/mempool.py b/flexkv/cache/mempool.py index 1a99a5cff6..d00077181f 100644 --- a/flexkv/cache/mempool.py +++ b/flexkv/cache/mempool.py @@ -20,7 +20,7 @@ def __init__( self._free_ids_offset = 0 def reset(self) -> None: - self._free_mask.fill_(True) + self._free_mask.fill(True) self._num_free = self.num_total_blocks self._free_ids = self._free_mask.nonzero()[0] self._free_ids_offset = 0 diff --git a/flexkv/cache/transfer_pattern.py b/flexkv/cache/transfer_pattern.py index 0a434bde04..fcf207408b 100644 --- a/flexkv/cache/transfer_pattern.py +++ b/flexkv/cache/transfer_pattern.py @@ -29,259 +29,6 @@ def add_virtal_op_for_mutiple_finished_ops( graph.add_dependency(op.op_id, op_id) return graph, op.op_id -def create_read_graph_cpu_storage( - gpu_blocks: torch.Tensor, - cpu_blocks: torch.Tensor, - ssd_blocks: torch.Tensor, - gpu_device_id: int = 0, - layer_num: int = 1, - graph: Optional[TransferOpGraph] = None, -)->Tuple[TransferOpGraph, List[int]]: - """ - Create a read transfer graph with (REMOTE_STORAGE / SSD)->CPU->GPU operations - ssd_blocks: the blocks of ssd that are used as a lower-level storage backend, - including ssd or remote storage. This can be empty, which means cpu-only kvcache. 
- Returns: - graph: TransferOpGraph - ops_to_be_tracked: List[int]: a list of transfer ops that can indicate - the completion of some key operations - """ - assert len(gpu_blocks) == len(cpu_blocks) - if graph is None: - graph = TransferOpGraph() - assert len(gpu_blocks) > 0 - if len(ssd_blocks) == 0: - op = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2D, - src_block_ids = cpu_blocks, - dst_block_ids = gpu_blocks, - layer_id = 0, - layer_granularity = layer_num, - ) - graph.add_transfer_op(op) - return graph, [op.op_id] - elif len(ssd_blocks) < len(cpu_blocks): - task_end_ops_ids = [] - if len(ssd_blocks) > 0: - op1 = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.DISK2H, - src_block_ids = ssd_blocks, - dst_block_ids = cpu_blocks[-len(ssd_blocks):], - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op1) - op2 = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2D, - src_block_ids = cpu_blocks[-len(ssd_blocks):], - dst_block_ids = gpu_blocks[-len(ssd_blocks):], - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op2) - graph.add_dependency(op2.op_id, op1.op_id) - task_end_ops_ids.append(op2.op_id) - op3 = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2D, - src_block_ids = cpu_blocks[:len(cpu_blocks) - len(ssd_blocks)], - dst_block_ids = gpu_blocks[:len(cpu_blocks) - len(ssd_blocks)], - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op3) - task_end_ops_ids.append(op3.op_id) - return graph, task_end_ops_ids - else: - op1 = TransferOp( - graph_id = graph.graph_id, - transfer_type=TransferType.DISK2H, - src_block_ids=ssd_blocks, - dst_block_ids=cpu_blocks, - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op1) - op2 = TransferOp( - graph_id = graph.graph_id, - transfer_type=TransferType.H2D, - src_block_ids=cpu_blocks, - dst_block_ids=gpu_blocks, - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op2) - graph.add_dependency(op2.op_id, op1.op_id) - return graph, [op2.op_id] - -def create_read_graph_cpu_ssd_remote( - gpu_blocks: torch.Tensor, - cpu_blocks: torch.Tensor, - ssd_blocks: torch.Tensor, - remote_blocks: torch.Tensor, - gpu_device_id: int = 0, - layer_num: int = 1, - write_back_to_ssd: bool = True, -)->Tuple[TransferOpGraph, List[int]]: - """ - Create a read transfer graph with (REMOTE_STORAGE + SSD)->CPU->GPU operations - Returns: - graph: TransferOpGraph - finished_ops_ids: List[int]: a list of transfer ops that can indicate - the completion of each layer or each layer for each tp rank - """ - graph = TransferOpGraph() - finished_ops_ids: List[int] = [] - if len(remote_blocks) == 0: - graph, finished_ops_ids = create_read_graph_cpu_storage(gpu_blocks=gpu_blocks, - cpu_blocks=cpu_blocks, - ssd_blocks=ssd_blocks, - gpu_device_id=gpu_device_id, - layer_num=layer_num, - graph=graph) - if len(finished_ops_ids) > 0: - graph, finished_ops_ids = add_virtal_op_for_mutiple_finished_ops(graph, finished_ops_ids) - assert len(finished_ops_ids) > 0 - return graph, finished_ops_ids - else: - if len(remote_blocks) < len(gpu_blocks): - graph, finished_ops_ids = create_read_graph_cpu_storage(gpu_blocks=gpu_blocks[:-len(remote_blocks)], - cpu_blocks=cpu_blocks[:-len(remote_blocks)], - ssd_blocks=ssd_blocks[:-len(remote_blocks)], - gpu_device_id=gpu_device_id, - layer_num=layer_num, - graph=graph) - op_r2h = TransferOp( - graph_id = graph.graph_id, - transfer_type = 
TransferType.REMOTE2H, - src_block_ids = remote_blocks, - dst_block_ids = cpu_blocks[-len(remote_blocks):], - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_r2h) - op_h2d = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2D, - src_block_ids = cpu_blocks[-len(remote_blocks):], - dst_block_ids = gpu_blocks[-len(remote_blocks):], - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_h2d) - graph.add_dependency(op_h2d.op_id, op_r2h.op_id) - if write_back_to_ssd: - op_h2disk = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2DISK, - src_block_ids = cpu_blocks[-len(remote_blocks):], - dst_block_ids = ssd_blocks[-len(remote_blocks):], - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_h2disk) - graph.add_dependency(op_h2disk.op_id, op_r2h.op_id) - finished_ops_ids.append(op_h2d.op_id) - if len(finished_ops_ids) > 0: - graph, finished_ops_ids = add_virtal_op_for_mutiple_finished_ops(graph, finished_ops_ids) - return graph, finished_ops_ids - -def create_write_graph_cpu_storage( - gpu_blocks: torch.Tensor, - cpu_blocks: torch.Tensor, - ssd_blocks: torch.Tensor, - gpu_device_id: int = 0, - layer_num: int = 1, - graph: Optional[TransferOpGraph] = None, -)->Tuple[TransferOpGraph, List[int]]: - """ - Create a write transfer graph with CPU->REMOTE_STORAGE / SSD operations - ssd_blocks: the blocks of ssd that are used as a lower-level storage backend, - including ssd or remote storage. This can be empty, which means cpu-only kvcache. - Write op granularity is larger: gpu->cpu is put into the same op. - Returns: - graph: TransferOpGraph - layer_wise_ops: List[int]: a list of transfer ops that can indicate - the completion of each layer or each layer for each tp rank - """ - if graph is None: - graph = TransferOpGraph() - op_d2h = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.D2H, - src_block_ids = gpu_blocks, - dst_block_ids = cpu_blocks[-len(gpu_blocks):], - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_d2h) - if len(ssd_blocks) == 0: - return graph, [op_d2h.op_id] - else: - op_h2disk = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2DISK, - src_block_ids = cpu_blocks[-len(ssd_blocks):], - dst_block_ids = ssd_blocks, - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_h2disk) - graph.add_dependency(op_h2disk.op_id, op_d2h.op_id) - return graph, [op_d2h.op_id] - -def create_write_graph_cpu_ssd_remote( - gpu_blocks: torch.Tensor, - cpu_blocks: torch.Tensor, - ssd_blocks: torch.Tensor, - remote_blocks: torch.Tensor, - gpu_device_id: int = 0, - layer_num: int = 1, -)->Tuple[TransferOpGraph, List[int]]: - """ - Create a write transfer graph with CPU->REMOTE_STORAGE + SSD operations - Returns: - graph: TransferOpGraph - layer_wise_ops: List[int]: a list of transfer ops that can indicate - the completion of each layer or each layer for each tp rank - """ - graph = TransferOpGraph() - op_d2h = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.D2H, - src_block_ids = gpu_blocks, - dst_block_ids = cpu_blocks[-len(gpu_blocks):], - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_d2h) - if len(ssd_blocks) != 0: - op_h2disk = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2DISK, - src_block_ids = cpu_blocks[-len(gpu_blocks):], - dst_block_ids = ssd_blocks, - layer_id = 0, - layer_granularity = layer_num - 
) - graph.add_transfer_op(op_h2disk) - graph.add_dependency(op_h2disk.op_id, op_d2h.op_id) - if len(remote_blocks) != 0: - op_h2remote = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2REMOTE, - src_block_ids = cpu_blocks[-len(remote_blocks):], - dst_block_ids = remote_blocks, - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_h2remote) - graph.add_dependency(op_h2remote.op_id, op_d2h.op_id) - return graph, [op_d2h.op_id] - def convert_read_graph_to_layer_wise_graph( transfer_graph: TransferOpGraph, finished_ops_ids: List[int], diff --git a/tests/conftest.py b/tests/conftest.py index e04cf78e50..2aba477a6f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,9 +4,3 @@ """ # Import fixtures from test_utils so pytest can discover them from test_utils import model_config, cache_config, test_config - -import multiprocessing as mp - -# Set the start method for multiprocessing to 'spawn' -# This ensures consistent behavior across different platforms -mp.set_start_method("spawn", force=True) diff --git a/tests/test_cache_engine.py b/tests/test_cache_engine.py index 133dfb470c..d76a6fe05c 100644 --- a/tests/test_cache_engine.py +++ b/tests/test_cache_engine.py @@ -1,7 +1,7 @@ import random import pytest -import torch +import numpy as np from flexkv.cache.mempool import Mempool from flexkv.cache.cache_engine import CacheEngine @@ -42,14 +42,14 @@ def test_mempool(): mempool = Mempool(num_total_blocks=64) assert mempool.num_free_blocks == 64 block_ids = mempool.allocate_blocks(16) - assert isinstance(block_ids, torch.Tensor) - assert block_ids.dtype == torch.int64 + assert isinstance(block_ids, np.ndarray) + assert block_ids.dtype == np.int64 assert block_ids.shape == (16,) assert mempool.num_free_blocks == 48 mempool.recycle_blocks(block_ids) assert mempool.num_free_blocks == 64 - block_ids = torch.cat([mempool.allocate_blocks(16), + block_ids = np.concatenate([mempool.allocate_blocks(16), mempool.allocate_blocks(16), mempool.allocate_blocks(16), mempool.allocate_blocks(16)]) @@ -63,21 +63,21 @@ def test_mempool(): empty_blocks = mempool.allocate_blocks(0) assert empty_blocks.shape == (0, ) - assert empty_blocks.dtype == torch.int64 + assert empty_blocks.dtype == np.int64 assert mempool.num_free_blocks == 64 with pytest.raises(ValueError): mempool.allocate_blocks(-1) - mempool.recycle_blocks(torch.tensor([], dtype=torch.int64)) + mempool.recycle_blocks(np.array([], dtype=np.int64)) assert mempool.num_free_blocks == 64 with pytest.raises(ValueError): - mempool.recycle_blocks(torch.tensor([1, 2, 3], dtype=torch.int32)) + mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int32)) with pytest.raises(ValueError): - mempool.recycle_blocks(torch.tensor([1, 2, 3], dtype=torch.int64)) + mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int64)) with pytest.raises(ValueError): - mempool.recycle_blocks(torch.tensor([[1, 2, 3]], dtype=torch.int64)) + mempool.recycle_blocks(np.array([[1, 2, 3]], dtype=np.int64)) def test_reset(cache_engine: CacheEngine): cache_engine.reset() @@ -101,22 +101,22 @@ def test_reset(cache_engine: CacheEngine): [1, 10, 16, 32, 10000], ) def test_match_and_insert(cache_engine: CacheEngine, num_insert: int, seq_len: int): - base_token_ids = torch.randint(0, 10000, (seq_len, ), dtype=torch.int64) + base_token_ids = np.random.randint(0, 10000, (seq_len, ), dtype=np.int64) base_num_blocks = seq_len // cache_engine.tokens_per_block cache_engine.insert(SequenceMeta(token_ids=base_token_ids, 
tokens_per_block=cache_engine.tokens_per_block), - torch.arange(base_num_blocks, dtype=torch.int64), + np.arange(base_num_blocks, dtype=np.int64), is_ready=True) cur_cached_blocks = base_num_blocks for i in range(num_insert): prefix_ratio = random.random() prefix_len = int(len(base_token_ids)*prefix_ratio) num_prefix_blocks = prefix_len // cache_engine.tokens_per_block - token_ids = torch.cat([base_token_ids[:prefix_len], - torch.randint(10000 + i * seq_len, + token_ids = np.concatenate([base_token_ids[:prefix_len], + np.random.randint(10000 + i * seq_len, 10000 + (i+1) * seq_len, (seq_len-prefix_len, ), - dtype=torch.int64)]) + dtype=np.int64)]) insert_sequence_meta = SequenceMeta(token_ids=token_ids, tokens_per_block=cache_engine.tokens_per_block) match_result = cache_engine.match(insert_sequence_meta) @@ -125,11 +125,11 @@ def test_match_and_insert(cache_engine: CacheEngine, num_insert: int, seq_len: i assert match_result.last_ready_node is not None assert match_result.last_node is not None assert match_result.physical_blocks.shape == (num_prefix_blocks, ) - assert match_result.physical_blocks.dtype == torch.int64 + assert match_result.physical_blocks.dtype == np.int64 num_insert_blocks = insert_sequence_meta.num_blocks - num_prefix_blocks cache_engine.insert(insert_sequence_meta, - torch.arange(num_insert_blocks, dtype=torch.int64), + np.arange(num_insert_blocks, dtype=np.int64), is_ready=True, match_result=match_result) cur_cached_blocks += num_insert_blocks @@ -150,7 +150,7 @@ def test_take_and_recycle(cache_engine: CacheEngine): num_total_blocks = cache_engine.num_total_blocks tokens_per_block = cache_engine.tokens_per_block seq_blocks = 10 - token_ids = torch.randint(0, 10000, (seq_blocks * tokens_per_block, ), dtype=torch.int64) + token_ids = np.random.randint(0, 10000, (seq_blocks * tokens_per_block, ), dtype=np.int64) sequence_meta = SequenceMeta(token_ids=token_ids, tokens_per_block=tokens_per_block) physical_blocks = cache_engine.take(seq_blocks) @@ -159,7 +159,7 @@ def test_take_and_recycle(cache_engine: CacheEngine): empty_blocks = cache_engine.take(0) assert empty_blocks.shape == (0, ) - assert empty_blocks.dtype == torch.int64 + assert empty_blocks.dtype == np.int64 with pytest.raises(ValueError): cache_engine.take(-1) @@ -168,7 +168,7 @@ def test_take_and_recycle(cache_engine: CacheEngine): physical_blocks2 = cache_engine.take(num_total_blocks, protected_node=radixnode, strict=False) assert physical_blocks2.shape == (num_total_blocks - seq_blocks, ) - assert physical_blocks2.dtype == torch.int64 + assert physical_blocks2.dtype == np.int64 cache_engine.recycle(physical_blocks2) @@ -193,22 +193,22 @@ def test_cleanup(cache_engine: CacheEngine): if cache_engine.tokens_per_block != 1: pytest.skip("tokens_per_block != 1") tokens_per_block = cache_engine.tokens_per_block - token_ids_list = [torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64), - torch.tensor([0, 1, 2, 3, 17, 15, 19, 20], dtype=torch.int64), - torch.tensor([0, 23, 22, 21], dtype=torch.int64)] + token_ids_list = [np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64), + np.array([0, 1, 2, 3, 17, 15, 19, 20], dtype=np.int64), + np.array([0, 23, 22, 21], dtype=np.int64)] sequence_meta_list = [SequenceMeta(token_ids=token_ids, tokens_per_block=tokens_per_block) for token_ids in token_ids_list] num_insert_blocks0 = sequence_meta_list[0].num_blocks radixnode0 = cache_engine.insert(sequence_meta_list[0], - torch.arange(num_insert_blocks0, dtype=torch.int64), + np.arange(num_insert_blocks0, dtype=np.int64), 
is_ready=False) cache_engine.lock_node(radixnode0) radixnode0_size = radixnode0.size() match_result = cache_engine.match(sequence_meta_list[1]) num_insert_blocks1 = sequence_meta_list[1].num_blocks - match_result.num_matched_blocks radixnode1 = cache_engine.insert(sequence_meta_list[1], - torch.arange(num_insert_blocks1, dtype=torch.int64), + np.arange(num_insert_blocks1, dtype=np.int64), match_result=match_result, is_ready=False) cache_engine.lock_node(radixnode1) @@ -216,7 +216,7 @@ def test_cleanup(cache_engine: CacheEngine): match_result = cache_engine.match(sequence_meta_list[2]) num_insert_blocks2 = sequence_meta_list[2].num_blocks - match_result.num_matched_blocks radixnode2 = cache_engine.insert(sequence_meta_list[2], - torch.arange(num_insert_blocks2, dtype=torch.int64), + np.arange(num_insert_blocks2, dtype=np.int64), match_result=match_result, is_ready=False) cache_engine.lock_node(radixnode2) diff --git a/tests/test_cache_engine_accel.py b/tests/test_cache_engine_accel.py index 500b1ac7e9..464aa31d34 100644 --- a/tests/test_cache_engine_accel.py +++ b/tests/test_cache_engine_accel.py @@ -1,7 +1,7 @@ import random import pytest -import torch +import numpy as np from flexkv.cache.mempool import Mempool from flexkv.cache.cache_engine import CacheEngineAccel @@ -42,14 +42,14 @@ def test_mempool(): mempool = Mempool(num_total_blocks=64) assert mempool.num_free_blocks == 64 block_ids = mempool.allocate_blocks(16) - assert isinstance(block_ids, torch.Tensor) - assert block_ids.dtype == torch.int64 + assert isinstance(block_ids, np.ndarray) + assert block_ids.dtype == np.int64 assert block_ids.shape == (16,) assert mempool.num_free_blocks == 48 mempool.recycle_blocks(block_ids) assert mempool.num_free_blocks == 64 - block_ids = torch.cat([mempool.allocate_blocks(16), + block_ids = np.concatenate([mempool.allocate_blocks(16), mempool.allocate_blocks(16), mempool.allocate_blocks(16), mempool.allocate_blocks(16)]) @@ -63,21 +63,21 @@ def test_mempool(): empty_blocks = mempool.allocate_blocks(0) assert empty_blocks.shape == (0, ) - assert empty_blocks.dtype == torch.int64 + assert empty_blocks.dtype == np.int64 assert mempool.num_free_blocks == 64 with pytest.raises(ValueError): mempool.allocate_blocks(-1) - mempool.recycle_blocks(torch.tensor([], dtype=torch.int64)) + mempool.recycle_blocks(np.array([], dtype=np.int64)) assert mempool.num_free_blocks == 64 with pytest.raises(ValueError): - mempool.recycle_blocks(torch.tensor([1, 2, 3], dtype=torch.int32)) + mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int32)) with pytest.raises(ValueError): - mempool.recycle_blocks(torch.tensor([1, 2, 3], dtype=torch.int64)) + mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int64)) with pytest.raises(ValueError): - mempool.recycle_blocks(torch.tensor([[1, 2, 3]], dtype=torch.int64)) + mempool.recycle_blocks(np.array([[1, 2, 3]], dtype=np.int64)) def test_reset(cache_engine: CacheEngineAccel): cache_engine.reset() @@ -101,22 +101,22 @@ def test_reset(cache_engine: CacheEngineAccel): [1, 10, 16, 32, 10000], ) def test_match_and_insert(cache_engine: CacheEngineAccel, num_insert: int, seq_len: int): - base_token_ids = torch.randint(0, 10000, (seq_len, ), dtype=torch.int64) + base_token_ids = np.random.randint(0, 10000, (seq_len, ), dtype=np.int64) base_num_blocks = seq_len // cache_engine.tokens_per_block cache_engine.insert(SequenceMeta(token_ids=base_token_ids, tokens_per_block=cache_engine.tokens_per_block), - torch.arange(base_num_blocks, dtype=torch.int64), + np.arange(base_num_blocks, 
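The repeated assertions in this hunk all encode one convention: after the migration, block IDs are exchanged as 1-D `np.int64` arrays instead of torch tensors. A minimal sketch of that contract, using a hypothetical `MempoolSketch` stand-in rather than the real `flexkv.cache.mempool.Mempool`:

```python
import numpy as np

# Hypothetical stand-in that only illustrates the dtype/shape contract
# (1-D int64 arrays), not the real allocation semantics of Mempool.
class MempoolSketch:
    def __init__(self, num_total_blocks: int):
        self._free = list(range(num_total_blocks))

    @property
    def num_free_blocks(self) -> int:
        return len(self._free)

    def allocate_blocks(self, n: int) -> np.ndarray:
        if n < 0:
            raise ValueError("cannot allocate a negative number of blocks")
        taken = [self._free.pop() for _ in range(min(n, len(self._free)))]
        return np.array(taken, dtype=np.int64)   # always 1-D int64, even when empty

    def recycle_blocks(self, block_ids: np.ndarray) -> None:
        if block_ids.ndim != 1 or block_ids.dtype != np.int64:
            raise ValueError("block_ids must be a 1-D int64 array")
        self._free.extend(int(b) for b in block_ids)

pool = MempoolSketch(64)
ids = pool.allocate_blocks(16)
assert isinstance(ids, np.ndarray) and ids.dtype == np.int64 and ids.shape == (16,)
pool.recycle_blocks(ids)
assert pool.num_free_blocks == 64
```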
dtype=np.int64), is_ready=True) cur_cached_blocks = base_num_blocks for i in range(num_insert): prefix_ratio = random.random() prefix_len = int(len(base_token_ids)*prefix_ratio) num_prefix_blocks = prefix_len // cache_engine.tokens_per_block - token_ids = torch.cat([base_token_ids[:prefix_len], - torch.randint(10000 + i * seq_len, + token_ids = np.concatenate([base_token_ids[:prefix_len], + np.random.randint(10000 + i * seq_len, 10000 + (i+1) * seq_len, (seq_len-prefix_len, ), - dtype=torch.int64)]) + dtype=np.int64)]) insert_sequence_meta = SequenceMeta(token_ids=token_ids, tokens_per_block=cache_engine.tokens_per_block) match_result = cache_engine.match(insert_sequence_meta) @@ -125,7 +125,7 @@ def test_match_and_insert(cache_engine: CacheEngineAccel, num_insert: int, seq_l num_insert_blocks = insert_sequence_meta.num_blocks - num_prefix_blocks cache_engine.insert(insert_sequence_meta, - torch.arange(num_insert_blocks, dtype=torch.int64), + np.arange(num_insert_blocks, dtype=np.int64), is_ready=True, match_result=match_result) cur_cached_blocks += num_insert_blocks @@ -146,7 +146,7 @@ def test_take_and_recycle(cache_engine: CacheEngineAccel): num_total_blocks = cache_engine.num_total_blocks tokens_per_block = cache_engine.tokens_per_block seq_blocks = 10 - token_ids = torch.randint(0, 10000, (seq_blocks * tokens_per_block, ), dtype=torch.int64) + token_ids = np.random.randint(0, 10000, (seq_blocks * tokens_per_block, ), dtype=np.int64) sequence_meta = SequenceMeta(token_ids=token_ids, tokens_per_block=tokens_per_block) physical_blocks = cache_engine.take(seq_blocks) @@ -155,7 +155,7 @@ def test_take_and_recycle(cache_engine: CacheEngineAccel): empty_blocks = cache_engine.take(0) assert empty_blocks.shape == (0, ) - assert empty_blocks.dtype == torch.int64 + assert empty_blocks.dtype == np.int64 with pytest.raises(ValueError): cache_engine.take(-1) @@ -164,7 +164,7 @@ def test_take_and_recycle(cache_engine: CacheEngineAccel): physical_blocks2 = cache_engine.take(num_total_blocks, protected_node=radixnode, strict=False) assert physical_blocks2.shape == (num_total_blocks - seq_blocks, ) - assert physical_blocks2.dtype == torch.int64 + assert physical_blocks2.dtype == np.int64 cache_engine.recycle(physical_blocks2) @@ -188,22 +188,22 @@ def test_cleanup(cache_engine: CacheEngineAccel): if cache_engine.tokens_per_block != 1: pytest.skip("tokens_per_block != 1") tokens_per_block = cache_engine.tokens_per_block - token_ids_list = [torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64), - torch.tensor([0, 1, 2, 3, 17, 15, 19, 20], dtype=torch.int64), - torch.tensor([0, 23, 22, 21], dtype=torch.int64)] + token_ids_list = [np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64), + np.array([0, 1, 2, 3, 17, 15, 19, 20], dtype=np.int64), + np.array([0, 23, 22, 21], dtype=np.int64)] sequence_meta_list = [SequenceMeta(token_ids=token_ids, tokens_per_block=tokens_per_block) for token_ids in token_ids_list] num_insert_blocks0 = sequence_meta_list[0].num_blocks radixnode0 = cache_engine.insert(sequence_meta_list[0], - torch.arange(num_insert_blocks0, dtype=torch.int64), + np.arange(num_insert_blocks0, dtype=np.int64), is_ready=False) cache_engine.lock_node(radixnode0) radixnode0_size = radixnode0.size() match_result = cache_engine.match(sequence_meta_list[1]) num_insert_blocks1 = sequence_meta_list[1].num_blocks - match_result.num_matched_blocks radixnode1 = cache_engine.insert(sequence_meta_list[1], - torch.arange(num_insert_blocks1, dtype=torch.int64), + np.arange(num_insert_blocks1, 
dtype=np.int64), match_result=match_result, is_ready=False) cache_engine.lock_node(radixnode1) @@ -211,7 +211,7 @@ def test_cleanup(cache_engine: CacheEngineAccel): match_result = cache_engine.match(sequence_meta_list[2]) num_insert_blocks2 = sequence_meta_list[2].num_blocks - match_result.num_matched_blocks radixnode2 = cache_engine.insert(sequence_meta_list[2], - torch.arange(num_insert_blocks2, dtype=torch.int64), + np.arange(num_insert_blocks2, dtype=np.int64), match_result=match_result, is_ready=False) cache_engine.lock_node(radixnode2) diff --git a/tests/test_transfer_engine.py b/tests/test_transfer_engine.py deleted file mode 100644 index 73ee107864..0000000000 --- a/tests/test_transfer_engine.py +++ /dev/null @@ -1,411 +0,0 @@ -""" -Transfer Engine Unit Tests - -This module contains comprehensive unit tests for the TransferEngine component, -which handles data transfers between different storage tiers (GPU, CPU, SSD). - -Test Functions Overview: -1. test_gpu_cpu_round_trip: Tests round-trip data transfers between GPU and CPU - - Parameterized by: tp_size, dp_size, num_gpu_blocks, transfer_block_num - - Validates data consistency after GPU->CPU->GPU transfers - -2. test_ssd_round_trip: Tests round-trip data transfers involving SSD storage - - Parameterized by: num_gpu_blocks, transfer_block_num, enable_ssd_cache - - Validates data consistency after GPU->CPU->SSD->CPU->GPU transfers - -3. test_concurrent_mixed_transfers: Tests multiple concurrent read/write transfers - - Parameterized by: num_concurrent_transfers, blocks_per_transfer, include_ssd - - Validates correctness of mixed read/write transfer graphs running concurrently - -usage example: - python -m pytest tests/test_transfer_engine.py::test_gpu_cpu_round_trip -v --tb=short -Each test validates both transfer completion and data correctness to ensure -the TransferEngine maintains data integrity across all transfer operations. -""" - -import os -import time -import tempfile -from typing import List, Dict, Tuple -import multiprocessing as mp -from contextlib import suppress - -import pytest -import torch - -from flexkv.cache.transfer_pattern import ( - create_read_graph_cpu_storage, - create_write_graph_cpu_storage, -) -from flexkv.common.config import ModelConfig, CacheConfig -from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType -from flexkv.common.transfer import DeviceType -from flexkv.storage.storage_engine import StorageEngine -from flexkv.transfer.transfer_engine import TransferEngine - -# Import utilities from test_utils -from test_utils import ( - wait_for_transfer_completion, - skip_if_no_cuda, - skip_if_insufficient_gpus, - generate_gpu_blocks_with_ground_truth, - verify_data -) - -@pytest.mark.parametrize("tp_size,dp_size", [(1, 1), (2, 1), (2, 2)]) -@pytest.mark.parametrize("num_gpu_blocks", [128]) -@pytest.mark.parametrize("transfer_block_num", [16]) -@pytest.mark.parametrize("use_mla", [False, True]) -@pytest.mark.parametrize("underlying_layout_type", [KVCacheLayoutType.LAYERWISE, KVCacheLayoutType.BLOCKWISE]) -def test_gpu_cpu_round_trip(model_config, - cache_config, - test_config, - tp_size, - dp_size, - num_gpu_blocks, - transfer_block_num, - use_mla, - underlying_layout_type): - """ - Test round-trip data transfers between GPU and CPU - - This test validates: - 1. GPU -> CPU transfer correctness - 2. CPU -> GPU transfer correctness - 3. 
Round-trip data consistency (GPU -> CPU -> GPU) - - Parameterized by: - - tp_size, dp_size: Tensor and data parallelism configurations - - num_gpu_blocks: Number of GPU blocks to test with - - transfer_block_num: Number of blocks to transfer in each operation - """ - total_gpus = tp_size * dp_size - skip_if_insufficient_gpus(total_gpus) - - if transfer_block_num > num_gpu_blocks: - pytest.skip(f"transfer_block_num ({transfer_block_num}) > num_gpu_blocks ({num_gpu_blocks})") - - # Update model config - model_config.use_mla = use_mla - model_config.tp_size = tp_size - model_config.dp_size = dp_size - - # Create a copy of test_config to avoid modifying the fixture - test_config_copy = test_config.copy() - test_config_copy['num_gpu_blocks'] = num_gpu_blocks - - cache_config.cpu_kv_layout_type = underlying_layout_type - # Setup configurations - cache_config.enable_ssd = False - - gpu_blocks, dp_wise_gpu_blocks_gt, gpu_kv_layout = \ - generate_gpu_blocks_with_ground_truth(model_config, cache_config, test_config_copy) - # Setup storage engine and transfer engine - storage_engine = StorageEngine(model_config, cache_config) - for gpu_id, gpu_block in gpu_blocks.items(): - storage_engine.register_gpu_blocks(gpu_block, gpu_kv_layout, device_id=gpu_id, dtype=model_config.dtype) - gpu_handles = [storage_engine.get_storage_handle(DeviceType.GPU, i) for i in range(total_gpus)] - cpu_handle = storage_engine.get_storage_handle(DeviceType.CPU) - - transfer_engine = TransferEngine( - gpu_handles=gpu_handles, - model_config=model_config, - cache_config=cache_config, - cpu_handle=cpu_handle - ) - transfer_engine.start() - - # Test each DP group separately - for dp_id in range(dp_size): - gpu_block_ids = torch.arange(0, transfer_block_num, dtype=torch.int64) - cpu_block_ids = torch.arange(dp_id * transfer_block_num, (dp_id + 1) * transfer_block_num, dtype=torch.int64) - - # Step 1: GPU -> CPU transfer - write_graph, _ = create_write_graph_cpu_storage( - gpu_blocks=gpu_block_ids, - cpu_blocks=cpu_block_ids, - ssd_blocks=torch.tensor([], dtype=torch.int64), - gpu_device_id=dp_id * tp_size, - layer_num=model_config.num_layers - ) - write_graph.bind_to_dp_group(dp_id) - transfer_engine.submit_transfer_graph(write_graph) - - # Wait for write completion - assert wait_for_transfer_completion(transfer_engine, [write_graph.graph_id]), \ - f"GPU->CPU transfer failed for DP group {dp_id}" - - # Clear GPU blocks for read test - for tp_id in range(tp_size): - global_gpu_id = dp_id * tp_size + tp_id - for layer_id in range(model_config.num_layers): - gpu_blocks[global_gpu_id][layer_id][:,gpu_block_ids].zero_() - - # Step 2: CPU -> GPU transfer - read_graph, _ = create_read_graph_cpu_storage( - gpu_blocks=gpu_block_ids, - cpu_blocks=cpu_block_ids, - ssd_blocks=torch.tensor([], dtype=torch.int64), - gpu_device_id=dp_id * tp_size, - layer_num=model_config.num_layers - ) - read_graph.bind_to_dp_group(dp_id) - transfer_engine.submit_transfer_graph(read_graph) - # Wait for read completion - assert wait_for_transfer_completion(transfer_engine, [read_graph.graph_id]), \ - f"CPU->GPU transfer failed for DP group {dp_id}" - - verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, model_config.num_kv_heads, - model_config.tp_size, model_config.dp_size, model_config.num_layers, model_config.use_mla) - # Cleanup - transfer_engine.shutdown() - - -@pytest.mark.parametrize("num_gpu_blocks", [64, 128]) -@pytest.mark.parametrize("transfer_block_num", [8, 16]) -@pytest.mark.parametrize("use_mla", [True, False]) 
-@pytest.mark.parametrize("iouring_entries", [0, 512]) -def test_ssd_round_trip(model_config, - cache_config, - test_config, - num_gpu_blocks, - transfer_block_num, - use_mla, - iouring_entries): - """ - Test round-trip data transfers involving SSD storage - - This test validates: - 1. GPU -> CPU -> SSD transfer chain - 2. SSD -> CPU -> GPU transfer chain - 3. Full round-trip data consistency - - Parameterized by: - - num_gpu_blocks: Number of GPU blocks to test with - - transfer_block_num: Number of blocks to transfer - """ - skip_if_no_cuda() - - if transfer_block_num > num_gpu_blocks: - pytest.skip(f"transfer_block_num ({transfer_block_num}) > num_gpu_blocks ({num_gpu_blocks})") - - # Setup configurations - cache_config.enable_ssd = True - cache_config.ssd_cache_iouring_entries = iouring_entries - model_config.use_mla = use_mla - - # Create a copy of test_config to avoid modifying the fixture - test_config_copy = test_config.copy() - test_config_copy['num_gpu_blocks'] = num_gpu_blocks - - gpu_blocks, dp_wise_gpu_blocks_gt, gpu_kv_layout = \ - generate_gpu_blocks_with_ground_truth(model_config, cache_config, test_config_copy) - if (model_config.tp_size * model_config.dp_size) > 1: - pytest.skip("SSD transfer test is not supported for multi-GPU") - - # Setup storage engine and transfer engine - storage_engine = StorageEngine(model_config, cache_config) - for gpu_id, gpu_block in gpu_blocks.items(): - storage_engine.register_gpu_blocks(gpu_block, gpu_kv_layout, device_id=gpu_id, dtype=model_config.dtype) - gpu_handles = [storage_engine.get_storage_handle(DeviceType.GPU, i) - for i in range(model_config.tp_size * model_config.dp_size)] - cpu_handle = storage_engine.get_storage_handle(DeviceType.CPU) - ssd_handle = storage_engine.get_storage_handle(DeviceType.SSD) - - transfer_engine = TransferEngine( - gpu_handles=gpu_handles, - model_config=model_config, - cache_config=cache_config, - cpu_handle=cpu_handle, - ssd_handle=ssd_handle - ) - transfer_engine.start() - # Prepare transfer block IDs - gpu_block_ids = torch.arange(0, transfer_block_num, dtype=torch.int64) - cpu_block_ids = torch.arange(0, transfer_block_num, dtype=torch.int64) - ssd_block_ids = torch.arange(0, transfer_block_num, dtype=torch.int64) - - # Step 1: GPU -> CPU -> SSD write - write_graph, _ = create_write_graph_cpu_storage( - gpu_blocks=gpu_block_ids, - cpu_blocks=cpu_block_ids, - ssd_blocks=ssd_block_ids, - gpu_device_id=0, - layer_num=model_config.num_layers - ) - write_graph.bind_to_dp_group(0) - transfer_engine.submit_transfer_graph(write_graph) - - # Wait for write completion - assert wait_for_transfer_completion(transfer_engine, [write_graph.graph_id]), \ - "GPU->CPU->SSD write transfer failed" - - # Clear GPU blocks for read test - for gpu_id in range(model_config.tp_size * model_config.dp_size): - for layer_id in range(model_config.num_layers): - gpu_blocks[gpu_id][layer_id][:,gpu_block_ids].zero_() - - # Step 2: SSD -> CPU -> GPU read - read_graph, _ = create_read_graph_cpu_storage( - gpu_blocks=gpu_block_ids, - cpu_blocks=cpu_block_ids, - ssd_blocks=ssd_block_ids, - gpu_device_id=0, - layer_num=model_config.num_layers - ) - read_graph.bind_to_dp_group(0) - transfer_engine.submit_transfer_graph(read_graph) - - # Wait for read completion - assert wait_for_transfer_completion(transfer_engine, [read_graph.graph_id]), \ - "SSD->CPU->GPU read transfer failed" - - verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, model_config.num_kv_heads, - model_config.tp_size, model_config.dp_size, model_config.num_layers, 
model_config.use_mla) - - # Cleanup - transfer_engine.shutdown() - - -@pytest.mark.parametrize("num_concurrent_transfers", [4]) -@pytest.mark.parametrize("blocks_per_transfer", [16]) -@pytest.mark.parametrize("include_ssd", [True, False]) -@pytest.mark.parametrize("use_mla", [True, False]) -@pytest.mark.parametrize("iouring_entries", [0, 512]) -def test_concurrent_mixed_transfers(model_config, - cache_config, - test_config, - num_concurrent_transfers, - blocks_per_transfer, - include_ssd, - use_mla, - iouring_entries): - """ - Test multiple concurrent read/write transfers - - This test validates: - 1. Multiple write transfers running concurrently - 2. Multiple read transfers running concurrently - 3. Mixed read/write transfers running concurrently - 4. Data correctness across all concurrent operations - - Parameterized by: - - num_concurrent_transfers: Number of concurrent transfer graphs - - blocks_per_transfer: Number of blocks per transfer operation - - include_ssd: Whether to include SSD in transfer operations - """ - model_config.use_mla = use_mla - skip_if_no_cuda() - - if (model_config.tp_size * model_config.dp_size) > 1: - pytest.skip("Concurrent transfer test is not supported for multi-GPU") - - total_blocks_needed = num_concurrent_transfers * blocks_per_transfer * 2 # For both read and write - num_gpu_blocks = max(128, total_blocks_needed) - - cache_config.num_cpu_blocks = num_gpu_blocks - cache_config.num_ssd_blocks = num_gpu_blocks - cache_config.ssd_cache_iouring_entries = iouring_entries - - # Setup configurations - cache_config.enable_ssd = include_ssd - - # Create a copy of test_config to avoid modifying the fixture - test_config_copy = test_config.copy() - test_config_copy['num_gpu_blocks'] = num_gpu_blocks - - gpu_blocks, dp_wise_gpu_blocks_gt, gpu_kv_layout = \ - generate_gpu_blocks_with_ground_truth(model_config, cache_config, test_config_copy) - - # Setup storage engine and transfer engine - storage_engine = StorageEngine(model_config, cache_config) - for gpu_id, gpu_block in gpu_blocks.items(): - storage_engine.register_gpu_blocks(gpu_block, gpu_kv_layout, device_id=gpu_id, dtype=model_config.dtype) - gpu_handles = [storage_engine.get_storage_handle(DeviceType.GPU, i) - for i in range(model_config.tp_size * model_config.dp_size)] - cpu_handle = storage_engine.get_storage_handle(DeviceType.CPU) - ssd_handle = storage_engine.get_storage_handle(DeviceType.SSD) if include_ssd else None - - transfer_engine = TransferEngine( - gpu_handles=gpu_handles, - model_config=model_config, - cache_config=cache_config, - cpu_handle=cpu_handle, - ssd_handle=ssd_handle - ) - - transfer_engine.start() - # Create concurrent write transfers - write_graphs = [] - - for i in range(num_concurrent_transfers): - start_block = i * blocks_per_transfer - end_block = start_block + blocks_per_transfer - - gpu_block_ids = torch.arange(start_block, end_block, dtype=torch.int64) - cpu_block_ids = torch.arange(start_block, end_block, dtype=torch.int64) - ssd_block_ids = torch.arange(start_block, end_block, dtype=torch.int64) \ - if include_ssd else torch.tensor([], dtype=torch.int64) - - - write_graph, _ = create_write_graph_cpu_storage( - gpu_blocks=gpu_block_ids, - cpu_blocks=cpu_block_ids, - ssd_blocks=ssd_block_ids, - gpu_device_id=0, - layer_num=model_config.num_layers - ) - write_graph.bind_to_dp_group(0) - write_graphs.append(write_graph) - - # Submit all write transfers - for graph in write_graphs: - transfer_engine.submit_transfer_graph(graph) - - # Wait for all writes to complete - 
write_graph_ids = [graph.graph_id for graph in write_graphs] - assert wait_for_transfer_completion(transfer_engine, write_graph_ids, max_wait_time=20.0), \ - "Concurrent write transfers failed to complete" - - # Clear GPU blocks for read test - for gpu_id in range(model_config.tp_size * model_config.dp_size): - gpu_block_ids = torch.arange(0, (num_concurrent_transfers + 1) * blocks_per_transfer, dtype=torch.int64) - for layer_id in range(model_config.num_layers): - gpu_blocks[gpu_id][layer_id][:,gpu_block_ids].zero_() - - # Create concurrent read transfers (using different GPU blocks) - read_graphs = [] - - for i in range(num_concurrent_transfers): - gpu_block_ids = torch.arange(i * blocks_per_transfer, (i + 1) * blocks_per_transfer, dtype=torch.int64) - cpu_block_ids = torch.arange(i * blocks_per_transfer, (i + 1) * blocks_per_transfer, dtype=torch.int64) - ssd_block_ids = torch.arange(i * blocks_per_transfer, (i + 1) * blocks_per_transfer, dtype=torch.int64) \ - if include_ssd else torch.tensor([], dtype=torch.int64) - - read_graph, _ = create_read_graph_cpu_storage( - gpu_blocks=gpu_block_ids, - cpu_blocks=cpu_block_ids, - ssd_blocks=ssd_block_ids, - gpu_device_id=0, - layer_num=model_config.num_layers - ) - read_graph.bind_to_dp_group(0) - read_graphs.append(read_graph) - - # Submit all read transfers - for graph in read_graphs: - transfer_engine.submit_transfer_graph(graph) - - # Wait for all reads to complete - read_graph_ids = [graph.graph_id for graph in read_graphs] - assert wait_for_transfer_completion(transfer_engine, read_graph_ids, max_wait_time=20.0), \ - "Concurrent read transfers failed to complete" - - verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, model_config.num_kv_heads, - model_config.tp_size, model_config.dp_size, model_config.num_layers, model_config.use_mla) - - # Cleanup - transfer_engine.shutdown() - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) From 42247c2a2a6e88f587455be6ef123abfccf80fd1 Mon Sep 17 00:00:00 2001 From: charliecgxu Date: Fri, 22 Aug 2025 22:36:48 +0800 Subject: [PATCH 07/42] enable profile in release build Signed-off-by: charliecgxu --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 5926bac9aa..1c080a00be 100755 --- a/setup.py +++ b/setup.py @@ -86,6 +86,7 @@ "boundscheck": False, "wraparound": False, "initializedcheck": False, + "profile": True, }, build_dir=build_dir, # Direct Cython to use the build directory ) From 502e9aa0de66a809f483437bbd36f3f74dbf7b9c Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Fri, 22 Aug 2025 01:09:08 -0700 Subject: [PATCH 08/42] rename functions --- flexkv/kvmanager.py | 13 +++++-------- flexkv/kvtask.py | 22 +++++++++++----------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index 8724ce6829..52bd97d019 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -168,20 +168,17 @@ def launch(self, if isinstance(slot_mappings[0], torch.Tensor): slot_mappings = [slot_mapping.numpy() for slot_mapping in slot_mappings] if self.server_client_mode: - for task_id in task_ids: - self.dp_client.launch_task(task_id, slot_mappings) + self.dp_client.launch_tasks(task_ids, slot_mappings) else: - self.kv_task_engine.launch_transfer(task_ids, slot_mappings) + self.kv_task_engine.launch_tasks(task_ids, slot_mappings) - def cancel_task(self, task_ids: Union[int, List[int]]) -> None: + def cancel(self, task_ids: Union[int, List[int]]) -> None: if isinstance(task_ids, int): task_ids = [task_ids] if 
self.server_client_mode: - for task_id in task_ids: - self.dp_client.cancel_task(task_id) + self.dp_client.cancel_tasks(task_ids) else: - for task_id in task_ids: - self.kv_task_engine.cancel_task(task_id) + self.kv_task_engine.cancel_tasks(task_ids) def wait(self, task_ids: Union[int, List[int]], diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index 625d98c78e..86360e1920 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -177,7 +177,7 @@ def create_put_task(self, callback=callback) self.graph_to_task[graph.graph_id] = task_id - def launch_task(self, task_id: int) -> None: + def _launch_task(self, task_id: int) -> None: task = self.tasks[task_id] if task.is_completed(): return @@ -188,7 +188,7 @@ def launch_task(self, task_id: int) -> None: if transfer_graph.num_ops > 0: self.transfer_handle.submit(transfer_graph) - def update_tasks(self, timeout: float = 0.001) -> None: + def _update_tasks(self, timeout: float = 0.001) -> None: completed_ops = self._get_completed_ops(timeout) for completed_graph_id, completed_op_id in completed_ops: if completed_graph_id not in self.graph_to_task: @@ -200,7 +200,7 @@ def update_tasks(self, timeout: float = 0.001) -> None: elif completed_op_id == task.task_end_op_id: self.tasks[task_id].task_end_op_finished = True - def cancel_task(self, task_id: int) -> None: + def _cancel_task(self, task_id: int) -> None: task = self.tasks[task_id] if task.is_completed(): flexkv_logger.warning(f"Task {task_id} is already completed, cannot cancel") @@ -328,7 +328,7 @@ def get_async(self, layer_granularity=layer_granularity, dp_id=dp_id, task_id=task_id) - self.launch_task(task_id) + self._launch_task(task_id) return task_id, return_mask def put_async(self, @@ -343,7 +343,7 @@ def put_async(self, token_mask=token_mask, dp_id=dp_id, task_id=task_id) - self.launch_task(task_id) + self._launch_task(task_id) return task_id, return_mask def _wait_impl(self, @@ -354,7 +354,7 @@ def _wait_impl(self, start_time = time.time() is_timeout = timeout == 0.0 - self.update_tasks(timeout=0) + self._update_tasks(timeout=0) for task_id in task_ids: while True: @@ -393,7 +393,7 @@ def _wait_impl(self, if time.time() - start_time > timeout: is_timeout = True break - self.update_tasks(timeout=0.001) + self._update_tasks(timeout=0.001) return return_responses def try_wait(self, task_ids: Union[int, List[int]]) -> Dict[int, KVResponse]: @@ -489,16 +489,16 @@ def _put_match_impl(self, self._process_empty_graph(task_id) return task_id, self.tasks[task_id].return_mask - def launch_transfer(self, + def launch_tasks(self, task_ids: List[int], slot_mappings: List[np.ndarray]) -> None: assert isinstance(slot_mappings[0], np.ndarray) self.set_slot_mappings(task_ids, slot_mappings) for task_id in task_ids: - self.launch_task(task_id) + self._launch_task(task_id) - def cancel(self, task_ids: Union[int, List[int]]) -> None: + def cancel_tasks(self, task_ids: Union[int, List[int]]) -> None: if isinstance(task_ids, int): task_ids = [task_ids] for task_id in task_ids: - self.cancel_task(task_id) + self._cancel_task(task_id) From ee91ca7b12d73c2cf4b8d72ae77a2a41f0ba2065 Mon Sep 17 00:00:00 2001 From: leolingli Date: Thu, 21 Aug 2025 19:45:51 +0800 Subject: [PATCH 09/42] add evict_ratio in cache config, default is 0 evict number is max( int(mempool.num_total_blocks*evict_ratio), former evict number ) --- benchmarks/example_config.json | 3 ++- flexkv/cache/cache_engine.py | 36 ++++++++++++++++++++++------------ flexkv/common/config.py | 3 +++ flexkv/common/tracer.py | 1 + 4 files changed, 30 
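The rename in [PATCH 08/42] turns the per-task loops into single batched calls (`launch_tasks`/`cancel_tasks` on the task engine and client, `launch`/`cancel` on the manager). A short usage sketch under that assumption; error handling and the real slot-mapping construction are omitted:

```python
import numpy as np

def launch_and_wait(kvmanager, task_ids, slot_mappings, timeout=20.0):
    """Hedged usage sketch of the batched API after this rename.

    `kvmanager` is assumed to be an initialized KVManager; launch(), wait()
    and cancel() all accept lists of task IDs in a single call now.
    """
    assert isinstance(slot_mappings[0], np.ndarray)
    kvmanager.launch(task_ids, slot_mappings)          # one batched call, no per-task loop
    responses = kvmanager.wait(task_ids, timeout=timeout, completely=True)
    unfinished = [tid for tid in task_ids if tid not in responses]
    if unfinished:
        kvmanager.cancel(unfinished)                   # cancel is batched as well
    return responses
```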
insertions(+), 13 deletions(-) diff --git a/benchmarks/example_config.json b/benchmarks/example_config.json index 630bdded0b..5ae4237df4 100644 --- a/benchmarks/example_config.json +++ b/benchmarks/example_config.json @@ -40,6 +40,7 @@ "trace_file_path": "./flexkv_trace.log", "trace_max_file_size_mb": 100, "trace_max_files": 5, - "trace_flush_interval_ms": 1000 + "trace_flush_interval_ms": 1000, + "evict_ratio": 0.05 } } diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index b2c76b795d..3ecade2076 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -51,7 +51,8 @@ class CacheEngineAccel: def __init__(self, device_type: DeviceType, num_total_blocks: int, - tokens_per_block: int): + tokens_per_block: int, + evict_ratio: float): if not isinstance(device_type, DeviceType): raise InvalidConfigError(f"Unknown device type: {device_type}") if num_total_blocks <= 0: @@ -68,6 +69,7 @@ def __init__(self, self.tokens_per_block = tokens_per_block self.num_total_blocks = num_total_blocks + self.evict_ratio = evict_ratio def reset(self) -> None: self.index.reset() @@ -119,9 +121,10 @@ def take(self, if num_required_blocks > self.mempool.num_free_blocks: if protected_node is not None: self.index.lock(protected_node) - target_blocks = torch.zeros(num_required_blocks - self.mempool.num_free_blocks, dtype=torch.int64) - num_evicted = self.index.evict(target_blocks, num_required_blocks - self.mempool.num_free_blocks) - if num_evicted != num_required_blocks - self.mempool.num_free_blocks: + evict_block_num = max(num_required_blocks - self.mempool.num_free_blocks, int(self.mempool.num_total_blocks * self.evict_ratio)) + target_blocks = torch.zeros(evict_block_num, dtype=torch.int64) + num_evicted = self.index.evict(target_blocks, evict_block_num) + if num_evicted != evict_block_num: target_blocks.resize_(num_evicted) self.mempool.recycle_blocks(target_blocks) @@ -141,7 +144,8 @@ class CacheEngine: def __init__(self, device_type: DeviceType, num_total_blocks: int, - tokens_per_block: int): + tokens_per_block: int, + evict_ratio: float): if not isinstance(device_type, DeviceType): raise InvalidConfigError(f"Unknown device type: {device_type}") if num_total_blocks <= 0: @@ -158,6 +162,7 @@ def __init__(self, self.tokens_per_block = tokens_per_block self.num_total_blocks = num_total_blocks + self.evict_ratio = evict_ratio def reset(self) -> None: self.index.reset() @@ -194,8 +199,9 @@ def take(self, if num_required_blocks > self.mempool.num_free_blocks: if protected_node is not None: self.index.lock(protected_node) + evict_block_num = max(num_required_blocks - self.mempool.num_free_blocks, int(self.mempool.num_total_blocks * self.evict_ratio)) self.mempool.recycle_blocks( - self.index.evict(num_required_blocks - self.mempool.num_free_blocks) + self.index.evict(evict_block_num) ) if protected_node is not None: self.index.unlock(protected_node) @@ -225,31 +231,37 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig): if cache_config.index_accel: self.cpu_cache_engine = CacheEngineAccel(DeviceType.CPU, cache_config.num_cpu_blocks, - cache_config.tokens_per_block) + cache_config.tokens_per_block, + cache_config.evict_ratio) else: self.cpu_cache_engine = CacheEngine(DeviceType.CPU, cache_config.num_cpu_blocks, - cache_config.tokens_per_block) + cache_config.tokens_per_block, + cache_config.evict_ratio) self.cache_engines[DeviceType.CPU] = self.cpu_cache_engine if cache_config.enable_ssd: if cache_config.index_accel: self.ssd_cache_engine = 
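The sizing rule in this commit reduces to a single expression: evict at least the shortfall, but never less than `evict_ratio` of the whole pool. A self-contained sketch of that calculation, mirroring the names used in `take()` above (the surrounding index/mempool machinery is left out):

```python
def compute_evict_block_num(num_required_blocks: int,
                            num_free_blocks: int,
                            num_total_blocks: int,
                            evict_ratio: float) -> int:
    """How many blocks take() frees when the pool runs short (sketch).

    Matches max(shortfall, int(num_total_blocks * evict_ratio)); with the
    default evict_ratio = 0.0 this degenerates to the old behaviour of
    evicting exactly the shortfall.
    """
    shortfall = num_required_blocks - num_free_blocks
    return max(shortfall, int(num_total_blocks * evict_ratio))

# Example: 48 blocks needed, 40 free, a 1024-block pool, evict_ratio = 0.05
# -> the shortfall is 8, but 5% of the pool is 51, so 51 blocks are evicted.
assert compute_evict_block_num(48, 40, 1024, 0.05) == 51
```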
CacheEngineAccel(DeviceType.SSD, cache_config.num_ssd_blocks, - cache_config.tokens_per_block) + cache_config.tokens_per_block, + cache_config.evict_ratio) else: self.ssd_cache_engine = CacheEngine(DeviceType.SSD, cache_config.num_ssd_blocks, - cache_config.tokens_per_block) + cache_config.tokens_per_block, + cache_config.evict_ratio) self.cache_engines[DeviceType.SSD] = self.ssd_cache_engine if cache_config.enable_remote: if cache_config.index_accel: self.remote_cache_engine = CacheEngineAccel(DeviceType.REMOTE, cache_config.num_remote_blocks, - cache_config.tokens_per_block) + cache_config.tokens_per_block, + cache_config.evict_ratio) else: self.remote_cache_engine = CacheEngine(DeviceType.REMOTE, cache_config.num_remote_blocks, - cache_config.tokens_per_block) + cache_config.tokens_per_block, + cache_config.evict_ratio) self.cache_engines[DeviceType.REMOTE] = self.remote_cache_engine self._empty_get_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, int]] = \ diff --git a/flexkv/common/config.py b/flexkv/common/config.py index 9148edf1d9..2dc3cc8376 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -71,3 +71,6 @@ class CacheConfig: trace_max_file_size_mb: int = 100 trace_max_files: int = 5 trace_flush_interval_ms: int = 1000 + + #evict ratio + evict_ratio: float = 0.0 diff --git a/flexkv/common/tracer.py b/flexkv/common/tracer.py index 45178e85b4..92668ae3f8 100644 --- a/flexkv/common/tracer.py +++ b/flexkv/common/tracer.py @@ -134,6 +134,7 @@ def trace_config(self, model_config, cache_config, gpu_layout=None): "ssd_cache_iouring_flags": cache_config.ssd_cache_iouring_flags, "remote_cache_path": cache_config.remote_cache_path, "remote_config_custom": cache_config.remote_config_custom, + "evict_ratio": cache_config.evict_ratio, } # Convert gpu_layout to dict if provided From d5ecff6f653b988a7c5cf70de70428151e7c04b4 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Mon, 25 Aug 2025 16:46:15 +0800 Subject: [PATCH 10/42] update benchmark worker (#82) * update benchmark worker * status map * enable nvtx * fix default config * clear cpu to test ssd cache --- benchmarks/benchmark_single_batch.py | 15 ++++++++++-- benchmarks/benchmark_workers.py | 25 ++++++++----------- benchmarks/example_config.json | 8 +++---- flexkv/common/transfer.py | 2 ++ flexkv/kvmanager.py | 8 +++++++ flexkv/kvtask.py | 36 ++++++++++++++++++---------- setup.py | 2 +- tests/test_kvmanager.py | 1 - tests/test_utils.py | 6 ++--- 9 files changed, 64 insertions(+), 39 deletions(-) diff --git a/benchmarks/benchmark_single_batch.py b/benchmarks/benchmark_single_batch.py index b58c3b8e71..fc77263926 100644 --- a/benchmarks/benchmark_single_batch.py +++ b/benchmarks/benchmark_single_batch.py @@ -9,6 +9,7 @@ from flexkv.server.client import KVTPClient from flexkv.common.storage import KVCacheLayout from flexkv.common.debug import flexkv_logger +from flexkv.common.config import ModelConfig, CacheConfig from utils import load_config from flexkv.kvmanager import KVManager from flexkv.kvtask import KVResponseStatus @@ -22,6 +23,7 @@ class BenchmarkConfig: batch_size: int sequence_length: int cache_ratio: float + clear_cpu_cache: bool def run_tp_client(dp_client_id, tp_rank, server_recv_port, model_config, cache_config): """Run tp_client process""" @@ -61,7 +63,11 @@ def shutdown_tp_client(tp_client_processes): tp_process.kill() tp_process.join(timeout=2) -def benchmark_flexkv(model_config, cache_config, benchmark_config, gpu_register_port, server_recv_port): +def benchmark_flexkv(model_config: 
ModelConfig, + cache_config: CacheConfig, + benchmark_config: BenchmarkConfig, + gpu_register_port: str, + server_recv_port: str): if model_config.tp_size * model_config.dp_size > torch.cuda.device_count(): raise ValueError(f"tp_size {model_config.tp_size} * dp_size {model_config.dp_size} is greater than " f"the number of available GPUs {torch.cuda.device_count()}") @@ -111,6 +117,9 @@ def benchmark_flexkv(model_config, cache_config, benchmark_config, gpu_register_ put_result = kvmanager.wait(batch_put_ids, completely=True) end_time = time.time() + if benchmark_config.clear_cpu_cache: + kvmanager._clear_cpu_cache() + elapsed_time_put = end_time - start_time put_tokens = 0 for _, response in put_result.items(): @@ -157,6 +166,7 @@ def parse_args(): parser.add_argument("--batch-size", type=int, default=1) parser.add_argument("--sequence-length", type=int, default=1024) parser.add_argument("--cache-ratio", type=float, default=1) + parser.add_argument("--clear-cpu-cache", action="store_true") return parser.parse_args() if __name__ == "__main__": @@ -165,7 +175,8 @@ def parse_args(): num_layers_to_transfer=args.num_layers, batch_size=args.batch_size, sequence_length=args.sequence_length, - cache_ratio=args.cache_ratio + cache_ratio=args.cache_ratio, + clear_cpu_cache=args.clear_cpu_cache ) model_config, cache_config = load_config(args.config) #cache_config.num_cpu_blocks = 8192 - 2048 diff --git a/benchmarks/benchmark_workers.py b/benchmarks/benchmark_workers.py index c02fb6f925..8c1c2182e3 100644 --- a/benchmarks/benchmark_workers.py +++ b/benchmarks/benchmark_workers.py @@ -17,7 +17,7 @@ from flexkv.common.debug import flexkv_logger -flexkv_logger.set_level("OFF") +# flexkv_logger.set_level("OFF") @dataclass class BenchmarkConfig: @@ -50,11 +50,10 @@ def make_configs(args: dict) -> Tuple[ModelConfig, CacheConfig, BenchmarkConfig] def create_cpu_gpu_worker( model_config: ModelConfig, - cache_config: CacheConfig, - bench_config: BenchmarkConfig) -> Tuple[WorkerHandle, mp.Queue]: + cache_config: CacheConfig) -> Tuple[WorkerHandle, mp.Queue]: mp.set_start_method('spawn', force=True) cpu_layout = KVCacheLayout( - type=KVCacheLayoutType.LAYERWISE, + type=KVCacheLayoutType(cache_config.cpu_kv_layout_type), num_layer=model_config.num_layers, num_block=cache_config.num_cpu_blocks, tokens_per_block=cache_config.tokens_per_block, @@ -85,7 +84,6 @@ def create_cpu_gpu_worker( finished_ops_queue = mp.Queue() if model_config.tp_size == 1: worker_handle = GPUCPUTransferWorker.create_worker( - worker_id=0, finished_ops_queue=finished_ops_queue, gpu_blocks=gpu_handles[0].get_tensor_handle_list(), cpu_blocks=cpu_handle.get_tensor(), @@ -100,7 +98,6 @@ def create_cpu_gpu_worker( ) else: worker_handle = tpGPUCPUTransferWorker.create_worker( - worker_id=0, finished_ops_queue=finished_ops_queue, gpu_blocks=[handle.get_tensor_handle_list() for handle in gpu_handles], cpu_blocks=cpu_handle.get_tensor(), @@ -121,11 +118,10 @@ def create_cpu_gpu_worker( def create_cpu_ssd_worker( model_config: ModelConfig, - cache_config: CacheConfig, - bench_config: BenchmarkConfig) -> Tuple[WorkerHandle, mp.Queue]: + cache_config: CacheConfig) -> Tuple[WorkerHandle, mp.Queue]: mp.set_start_method('spawn', force=True) cpu_layout = KVCacheLayout( - type=KVCacheLayoutType.LAYERWISE, + type=KVCacheLayoutType(cache_config.cpu_kv_layout_type), num_layer=model_config.num_layers, num_block=cache_config.num_cpu_blocks, tokens_per_block=cache_config.tokens_per_block, @@ -133,7 +129,7 @@ def create_cpu_ssd_worker( 
head_size=model_config.head_size, ) ssd_layout = KVCacheLayout( - type=KVCacheLayoutType.LAYERWISE, + type=KVCacheLayoutType(cache_config.ssd_kv_layout_type), num_layer=model_config.num_layers, num_block=cache_config.num_ssd_blocks, tokens_per_block=cache_config.tokens_per_block, @@ -153,7 +149,6 @@ def create_cpu_ssd_worker( ) finished_ops_queue = mp.Queue() worker_handle = CPUSSDDiskTransferWorker.create_worker( - worker_id=10, finished_ops_queue=finished_ops_queue, cpu_blocks=cpu_handle.get_tensor(), ssd_files=ssd_handle.get_file_list(), @@ -195,18 +190,18 @@ def bench_worker(args): shuffle_ids = bench_config.shuffle_ids if transfer_type == TransferType.H2D or transfer_type == TransferType.D2H: - worker_handle, finished_ops_queue = create_cpu_gpu_worker(model_config, cache_config, bench_config) + worker_handle, finished_ops_queue = create_cpu_gpu_worker(model_config, cache_config) elif transfer_type == TransferType.H2DISK or transfer_type == TransferType.DISK2H: - worker_handle, finished_ops_queue = create_cpu_ssd_worker(model_config, cache_config, bench_config) + worker_handle, finished_ops_queue = create_cpu_ssd_worker(model_config, cache_config) else: raise ValueError(f"Unsupported transfer type: {transfer_type} for benchmark, " f"currently only support {TransferType.H2D.name}, {TransferType.D2H.name}, " f"{TransferType.H2DISK.name}, {TransferType.DISK2H.name}") if shuffle_ids: - block_ids = torch.randperm(num_blocks_to_transfer) + block_ids = torch.randperm(num_blocks_to_transfer).numpy() else: - block_ids = torch.arange(num_blocks_to_transfer) + block_ids = torch.arange(num_blocks_to_transfer).numpy() transfer_op = TransferOp( transfer_type=transfer_type, diff --git a/benchmarks/example_config.json b/benchmarks/example_config.json index 5ae4237df4..6f35b7a12c 100644 --- a/benchmarks/example_config.json +++ b/benchmarks/example_config.json @@ -16,9 +16,9 @@ "use_gds": false, "use_pinned_memory": true, "gpu_kv_layout_type": "LAYERWISE", - "cpu_kv_layout_type": "LAYERWISE", - "ssd_kv_layout_type": "LAYERWISE", - "remote_kv_layout_type": "LAYERWISE", + "cpu_kv_layout_type": "BLOCKWISE", + "ssd_kv_layout_type": "BLOCKWISE", + "remote_kv_layout_type": "BLOCKWISE", "num_cpu_blocks": 2048, "num_ssd_blocks": 4096, "num_remote_blocks": null, @@ -28,7 +28,7 @@ "transfer_sms_d2h": 8, "max_blocks_per_file": 32000, "ssd_cache_dir": "./ssd_cache1/", - "ssd_cache_iouring_entries": 512, + "ssd_cache_iouring_entries": 32, "ssd_cache_iouring_flags": 0, "remote_cache_size_mode": "file_size", "remote_file_size": null, diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index ee64ce99f8..35061e2eba 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -62,6 +62,8 @@ def __post_init__(self) -> None: with TransferOp._lock: self.op_id = TransferOp._next_op_id TransferOp._next_op_id += 1 + assert self.src_block_ids.dtype == np.int64 + assert self.dst_block_ids.dtype == np.int64 class TransferOpGraph: diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index 52bd97d019..5d5d567801 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -198,3 +198,11 @@ def try_wait(self, task_ids: Union[int, List[int]]) -> Dict[int, KVResponse]: return self.dp_client.try_wait(task_ids) else: return self.kv_task_engine.try_wait(task_ids) + + # Only for testing + def _clear_cpu_cache(self) -> None: + if self.server_client_mode: + flexkv_logger.error("clear_cache is not supported in server client mode") + return + else: + self.kv_task_engine._clear_cpu_cache() diff --git 
a/flexkv/kvtask.py b/flexkv/kvtask.py index 86360e1920..c96a55127e 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -13,7 +13,7 @@ from flexkv.common.config import CacheConfig, ModelConfig from flexkv.common.debug import flexkv_logger -from flexkv.common.transfer import TransferOpGraph +from flexkv.common.transfer import TransferOpGraph, get_nvtx_default_color from flexkv.common.tracer import FlexKVTracer from flexkv.cache.cache_engine import GlobalCacheEngine from flexkv.transfer_manager import TransferManagerHandle @@ -60,12 +60,14 @@ class KVTask: def is_completed(self) -> bool: return self.status in [TaskStatus.COMPLETED, TaskStatus.CANCELLED, TaskStatus.FAILED] -def task_status_to_response_status(task_status: TaskStatus) -> KVResponseStatus: - return { - TaskStatus.COMPLETED: KVResponseStatus.SUCCESS, - TaskStatus.CANCELLED: KVResponseStatus.CANCELLED, - TaskStatus.FAILED: KVResponseStatus.FAILED, - }[task_status] +TASK_STATUS_TO_RESPONSE_STATUS = { + TaskStatus.COMPLETED: KVResponseStatus.SUCCESS, + TaskStatus.CANCELLED: KVResponseStatus.CANCELLED, + TaskStatus.FAILED: KVResponseStatus.FAILED, +} + +def convert_to_response_status(task_status: TaskStatus) -> KVResponseStatus: + return TASK_STATUS_TO_RESPONSE_STATUS[task_status] class KVTaskManager: def __init__(self, @@ -97,9 +99,10 @@ def __init__(self, self.graph_to_task: Dict[int, int] = {} self.task_id_counter = 0 - self.task_id_lock = threading.Lock() + self.running_tasks: int = 0 + def start(self) -> None: self.transfer_handle.start() @@ -185,6 +188,7 @@ def _launch_task(self, task_id: int) -> None: raise ValueError(f"Task {task_id} status is {task.status}, cannot launch") transfer_graph = task.graph task.status = TaskStatus.RUNNING + nvtx.mark(f"launch task: task_id={task_id}, graph_id={transfer_graph.graph_id}") if transfer_graph.num_ops > 0: self.transfer_handle.submit(transfer_graph) @@ -377,7 +381,7 @@ def _wait_impl(self, elif self.check_completed(task_id, completely=completely): self.tasks[task_id].status = TaskStatus.COMPLETED # TODO is this correct? 
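The `nvtx.mark`/`push_range`/`pop_range` calls added in this hunk (and the ones that follow for `wait`, `get_match` and `put_match`) are the usual pattern for exposing these phases in an Nsight Systems timeline. A tiny self-contained sketch of that discipline; the range names here are arbitrary:

```python
import time
import nvtx   # the `nvtx` pip package, assumed available as in the hunk above

def wait_sketch(task_ids):
    # Open a named range so this phase shows up as one span in the profiler,
    # and close it even if the body raises.
    nvtx.push_range(f"wait task_ids: {task_ids}")
    try:
        time.sleep(0.001)          # placeholder for the real polling loop
    finally:
        nvtx.pop_range()

nvtx.mark("single point event")    # mark() tags an instant rather than a span
wait_sketch([1, 2, 3])
```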
return_responses[task_id] = KVResponse( - status=task_status_to_response_status(self.tasks[task_id].status), + status=convert_to_response_status(self.tasks[task_id].status), task_id=task_id, return_mask=self.tasks[task_id].return_mask ) @@ -399,6 +403,7 @@ def _wait_impl(self, def try_wait(self, task_ids: Union[int, List[int]]) -> Dict[int, KVResponse]: if isinstance(task_ids, int): task_ids = [task_ids] + nvtx.mark(f"try_wait task_ids: {task_ids}") return_responses = self._wait_impl(task_ids, timeout=0.0, completely=False) @@ -408,11 +413,11 @@ def wait(self, task_ids: Union[int, List[int]], timeout: float = 20.0, completely: bool = False) -> Dict[int, KVResponse]: - nvtx.mark(f"wait task_ids: {task_ids}") if isinstance(task_ids, int): task_ids = [task_ids] + nvtx.push_range(f"wait task_ids: {task_ids}", color=get_nvtx_default_color()) return_responses = self._wait_impl(task_ids, timeout, completely=completely) - nvtx.mark(f"wait task_ids: {task_ids} done") + nvtx.pop_range() return return_responses def get_match(self, @@ -444,7 +449,7 @@ def _get_match_impl(self, layer_granularity = self.model_config.num_layers if task_id == -1: task_id = self._gen_task_id() - nvtx.mark(f"GET task_id: {task_id}") + nvtx.push_range(f"get match: task_id={task_id}", color=get_nvtx_default_color()) self.create_get_task(task_id, token_ids, slot_mapping, @@ -453,6 +458,7 @@ def _get_match_impl(self, dp_id, is_fake_slot_mapping=is_fake_slot_mapping) self._process_empty_graph(task_id) + nvtx.pop_range() return task_id, self.tasks[task_id].return_mask def put_match(self, @@ -479,7 +485,7 @@ def _put_match_impl(self, token_mask = np.ones_like(token_ids) if task_id == -1: task_id = self._gen_task_id() - nvtx.mark(f"PUT task_id: {task_id}") + nvtx.push_range(f"put match: task_id={task_id}", color=get_nvtx_default_color()) self.create_put_task(task_id, token_ids, slot_mapping, @@ -487,6 +493,7 @@ def _put_match_impl(self, dp_id, is_fake_slot_mapping=is_fake_slot_mapping) self._process_empty_graph(task_id) + nvtx.pop_range() return task_id, self.tasks[task_id].return_mask def launch_tasks(self, @@ -502,3 +509,6 @@ def cancel_tasks(self, task_ids: Union[int, List[int]]) -> None: task_ids = [task_ids] for task_id in task_ids: self._cancel_task(task_id) + + def _clear_cpu_cache(self) -> None: + self.cache_engine.cpu_cache_engine.reset() diff --git a/setup.py b/setup.py index 1c080a00be..a0b549e443 100755 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ from torch.utils import cpp_extension -build_dir = "build" +build_dir = os.path.abspath("build") os.makedirs(build_dir, exist_ok=True) # Check if we're in debug mode using environment variable diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index f105f95857..b997e2eb21 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -234,7 +234,6 @@ def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type) read_idx = i - initial_write_num token_ids, block_ids, dp_id = request_pairs[read_idx] slot_mapping = block_ids_2_slot_mapping(block_ids, tokens_per_block) - print(f"token_ids: {token_ids}, block_ids: {block_ids}, dp_id: {dp_id}, slot_mapping: {slot_mapping}") request_id, _ = kvmanager.get_match( token_ids=token_ids, layer_granularity=-1, diff --git a/tests/test_utils.py b/tests/test_utils.py index 2530f0291a..d3e3cb4fbb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -49,9 +49,9 @@ "pcfs_parent_nodeid": 144115188075855883 # Using transfer engine value for consistency }, 'use_ce_transfer_h2d': False, - 
'use_ce_transfer_d2h': True, - 'transfer_sms_h2d': 4, - 'transfer_sms_d2h': 4, + 'use_ce_transfer_d2h': False, + 'transfer_sms_h2d': 8, + 'transfer_sms_d2h': 8, } DEFAULT_TEST_CONFIG = { From 419156e98fbf83132277c6e9718e9d2862b16de6 Mon Sep 17 00:00:00 2001 From: charliecgxu Date: Mon, 25 Aug 2025 21:16:12 +0800 Subject: [PATCH 11/42] fix broken cpp radix tree support for cache engine (#84) * adjust index accel to new cache engine data struct Signed-off-by: charliecgxu * fix broken tests for cache engine Signed-off-by: charliecgxu --------- Signed-off-by: charliecgxu --- flexkv/cache/cache_engine.py | 15 +++++++-------- tests/test_cache_engine.py | 11 ++++++----- tests/test_cache_engine_accel.py | 11 ++++++----- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 3ecade2076..2a581de9b3 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -18,7 +18,7 @@ from functools import partial from queue import Queue from typing import List, Tuple, Optional, Dict, Callable -from dataclasses import dataclass +from dataclasses import dataclass, field import numpy as np import torch @@ -41,11 +41,10 @@ class MatchResultAccel: last_ready_node: Optional['CRadixNode'] = None last_node: Optional['CRadixNode'] = None last_node_matched_length: int = 0 - physical_blocks: torch.Tensor = torch.empty(0, dtype=torch.int64) + physical_blocks: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) def __post_init__(self) -> None: assert self.physical_blocks.ndim == 1 - assert self.physical_blocks.dtype == torch.int64 class CacheEngineAccel: def __init__(self, @@ -82,7 +81,7 @@ def match(self, sequence_meta: SequenceMeta) -> MatchResultAccel: return MatchResultAccel(match_result.num_ready_matched_blocks, match_result.num_matched_blocks, match_result.last_ready_node, match_result.last_node, match_result.last_node_matched_length, - torch.tensor(match_result.physical_blocks, dtype=torch.int64)) + torch.tensor(match_result.physical_blocks, dtype=torch.int64).numpy()) def insert(self, sequence_meta: SequenceMeta, @@ -92,13 +91,13 @@ def insert(self, match_result: Optional[MatchResultAccel] = None) -> Optional[CRadixNode]: sequence_meta.gen_hashes() if match_result is None: - return self.index.insert(physical_block_ids, + return self.index.insert(torch.from_numpy(physical_block_ids).to(torch.int64), torch.from_numpy(sequence_meta.block_hashes).to(torch.int64), sequence_meta.num_blocks, num_insert_blocks, is_ready) else: - return self.index.insert(physical_block_ids, + return self.index.insert(torch.from_numpy(physical_block_ids).to(torch.int64), torch.from_numpy(sequence_meta.block_hashes).to(torch.int64), sequence_meta.num_blocks, num_insert_blocks, @@ -126,7 +125,7 @@ def take(self, num_evicted = self.index.evict(target_blocks, evict_block_num) if num_evicted != evict_block_num: target_blocks.resize_(num_evicted) - self.mempool.recycle_blocks(target_blocks) + self.mempool.recycle_blocks(target_blocks.numpy()) if protected_node is not None: self.index.unlock(protected_node) @@ -137,7 +136,7 @@ def take(self, num_allocated_blocks = min(num_required_blocks, self.mempool.num_free_blocks) return self.mempool.allocate_blocks(num_allocated_blocks) - def recycle(self, physical_blocks: torch.Tensor) -> None: + def recycle(self, physical_blocks: np.ndarray) -> None: self.mempool.recycle_blocks(physical_blocks) class CacheEngine: diff --git a/tests/test_cache_engine.py b/tests/test_cache_engine.py index 
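One side effect of the `field(default_factory=...)` change in `MatchResultAccel` above is worth spelling out: each instance now gets its own empty array, whereas the old class-level `torch.empty(0, ...)` default was a single object shared by every instance. A minimal illustration with a hypothetical dataclass:

```python
from dataclasses import dataclass, field
import numpy as np

@dataclass
class MatchSketch:
    """Illustration of the default_factory pattern above (hypothetical class)."""
    num_matched_blocks: int = 0
    # default_factory builds a fresh empty int64 array per instance; a plain
    # class-level array default would be one object shared by all instances.
    physical_blocks: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64))

    def __post_init__(self) -> None:
        assert self.physical_blocks.ndim == 1

a, b = MatchSketch(), MatchSketch()
assert a.physical_blocks is not b.physical_blocks
```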
d76a6fe05c..e224d4305b 100644 --- a/tests/test_cache_engine.py +++ b/tests/test_cache_engine.py @@ -16,6 +16,7 @@ def cache_engine(request: pytest.FixtureRequest) -> CacheEngine: 'device_type': DeviceType.CPU, 'num_total_blocks': 64, 'tokens_per_block': 4, + 'evict_ratio': 0.05, } default_config_kwargs.update(param) return CacheEngine(**default_config_kwargs) @@ -23,11 +24,11 @@ def cache_engine(request: pytest.FixtureRequest) -> CacheEngine: @pytest.mark.parametrize( "config, should_raise", [ - ({'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': DeviceType.CPU}, False), - ({'num_total_blocks': 0, 'tokens_per_block': 4, 'device_type': DeviceType.GPU}, True), - ({'num_total_blocks': 64, 'tokens_per_block': 0, 'device_type': DeviceType.SSD}, True), - ({'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': 'Unknown'}, True), - ({'num_total_blocks': 64, 'tokens_per_block': 3, 'device_type': DeviceType.CPU}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': DeviceType.CPU}, False), + ({'evict_ratio': 0.05, 'num_total_blocks': 0, 'tokens_per_block': 4, 'device_type': DeviceType.GPU}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 0, 'device_type': DeviceType.SSD}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': 'Unknown'}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 3, 'device_type': DeviceType.CPU}, True), ] ) def test_config_init(config: dict, should_raise: bool): diff --git a/tests/test_cache_engine_accel.py b/tests/test_cache_engine_accel.py index 464aa31d34..15fef43ec3 100644 --- a/tests/test_cache_engine_accel.py +++ b/tests/test_cache_engine_accel.py @@ -16,6 +16,7 @@ def cache_engine(request: pytest.FixtureRequest) -> CacheEngineAccel: 'device_type': DeviceType.CPU, 'num_total_blocks': 64, 'tokens_per_block': 4, + 'evict_ratio': 0.05, } default_config_kwargs.update(param) return CacheEngineAccel(**default_config_kwargs) @@ -23,11 +24,11 @@ def cache_engine(request: pytest.FixtureRequest) -> CacheEngineAccel: @pytest.mark.parametrize( "config, should_raise", [ - ({'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': DeviceType.CPU}, False), - ({'num_total_blocks': 0, 'tokens_per_block': 4, 'device_type': DeviceType.GPU}, True), - ({'num_total_blocks': 64, 'tokens_per_block': 0, 'device_type': DeviceType.SSD}, True), - ({'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': 'Unknown'}, True), - ({'num_total_blocks': 64, 'tokens_per_block': 3, 'device_type': DeviceType.CPU}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': DeviceType.CPU}, False), + ({'evict_ratio': 0.05, 'num_total_blocks': 0, 'tokens_per_block': 4, 'device_type': DeviceType.GPU}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 0, 'device_type': DeviceType.SSD}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': 'Unknown'}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 3, 'device_type': DeviceType.CPU}, True), ] ) def test_config_init(config: dict, should_raise: bool): From 3581cdadbef7761e760c825fb995e66732663fd3 Mon Sep 17 00:00:00 2001 From: lilgao Date: Mon, 25 Aug 2025 20:42:18 +0800 Subject: [PATCH 12/42] ci: trigger on main and dev Signed-off-by: lilgao --- .github/workflows/publish.yml | 4 ++-- build.sh | 4 ++-- setup.py | 6 ++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff 
--git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index eb5bc9b75f..6c2eea8d02 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -5,9 +5,9 @@ name: flexkv ci on: pull_request: - branches: [ "main", "feat/lilgao/ci"] + branches: [ "main", "dev"] push: - branches: [ "main", "feat/lilgao/ci"] + branches: [ "main", "dev"] # Needed to create wheel and upload assets permissions: diff --git a/build.sh b/build.sh index 247dd62405..b976fec7d9 100755 --- a/build.sh +++ b/build.sh @@ -64,9 +64,9 @@ echo "=== Build and installation completed successfully in ${BUILD_TYPE} mode == echo "You can now run tests directly without setting LD_LIBRARY_PATH manually" if [ "$BUILD_TYPE" = "debug" ]; then - FLEXKV_DEBUG=1 pip install --no-build-isolation -e . + FLEXKV_DEBUG=1 pip install -v --no-build-isolation -e . elif [ "$BUILD_TYPE" = "release" ]; then FLEXKV_DEBUG=0 python setup.py bdist_wheel -v else - FLEXKV_DEBUG=0 pip install --no-build-isolation -e . + FLEXKV_DEBUG=0 pip install -v --no-build-isolation -e . fi diff --git a/setup.py b/setup.py index a0b549e443..d11ffc5143 100755 --- a/setup.py +++ b/setup.py @@ -2,13 +2,13 @@ import shutil import sys -from Cython.Build import cythonize + from setuptools import find_packages, setup from setuptools.command.build_ext import build_ext from torch.utils import cpp_extension -build_dir = os.path.abspath("build") +build_dir = "build" os.makedirs(build_dir, exist_ok=True) # Check if we're in debug mode using environment variable @@ -78,6 +78,8 @@ "flexkv/**/benchmark_*.py", "flexkv/benchmark/**/*.py", "flexkv/benchmark/test_kvmanager.py"] + # Import cython when debug is turned off. + from Cython.Build import cythonize cythonized_modules = cythonize( python_files, exclude=excluded_files, From 08571115aa3c28d3b6d7ad7c8d42204f31592b68 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Mon, 25 Aug 2025 22:21:49 -0700 Subject: [PATCH 13/42] fix direct io --- benchmarks/benchmark_single_batch.py | 3 +- benchmarks/example_config.json | 3 +- csrc/transfer_ssd.h | 42 +++++++++++----------------- tests/test_utils.py | 2 +- 4 files changed, 22 insertions(+), 28 deletions(-) diff --git a/benchmarks/benchmark_single_batch.py b/benchmarks/benchmark_single_batch.py index fc77263926..ca9e59c60a 100644 --- a/benchmarks/benchmark_single_batch.py +++ b/benchmarks/benchmark_single_batch.py @@ -93,8 +93,9 @@ def benchmark_flexkv(model_config: ModelConfig, tp_client_processes.append(tp_client_process) while not kvmanager.is_ready(): - time.sleep(1) + time.sleep(3) flexkv_logger.info("waiting for flexkv to be ready") + flexkv_logger.info("flexkv is ready") batch_sequence_tensor = [] batch_slot_mapping = [] diff --git a/benchmarks/example_config.json b/benchmarks/example_config.json index 6f35b7a12c..4a710f41ca 100644 --- a/benchmarks/example_config.json +++ b/benchmarks/example_config.json @@ -41,6 +41,7 @@ "trace_max_file_size_mb": 100, "trace_max_files": 5, "trace_flush_interval_ms": 1000, - "evict_ratio": 0.05 + "evict_ratio": 0.05, + "index_accel": true } } diff --git a/csrc/transfer_ssd.h b/csrc/transfer_ssd.h index f795f4de54..9d6f9ba57e 100644 --- a/csrc/transfer_ssd.h +++ b/csrc/transfer_ssd.h @@ -173,11 +173,10 @@ class IOUring { class SSDIOCTX { public: - SSDIOCTX(std::map>& ssd_files, - int num_devices, int iouring_entries, int iouring_flags) - : iouring(iouring_entries, iouring_flags), - fds_buffer_io(num_devices), - fds_direct_io(num_devices) { + SSDIOCTX(std::map> &ssd_files, int num_devices, + int iouring_entries, int 
iouring_flags) + : iouring(iouring_entries, iouring_flags), fds_buffer_io(num_devices), + fds_direct_io(num_devices) { int i, j, fd_buffer_io, fd_direct_io; @@ -190,7 +189,8 @@ class SSDIOCTX { fd_direct_io = open(ssd_files[i][j].c_str(), O_RDWR | O_DIRECT); if (fd_buffer_io < 0 || fd_direct_io < 0) { - std::cerr << "open file failed, path = " << ssd_files[i][j] << std::endl; + std::cerr << "open file failed, path = " << ssd_files[i][j] + << std::endl; throw std::runtime_error("Failed to open file"); } else { posix_fadvise(fd_buffer_io, 0, 0, POSIX_FADV_SEQUENTIAL); @@ -221,20 +221,14 @@ class SSDIOCTX { } } - int get_num_devices() { - return num_devices; - } + int get_num_devices() { return num_devices; } - int get_num_files_per_device() { - return num_files_per_device; - } + int get_num_files_per_device() { return num_files_per_device; } - IOUring &get_iouring() { - return iouring; - } + IOUring &get_iouring() { return iouring; } std::vector> &get_fds(bool is_read, bool is_direct) { - if (is_read && is_direct) { + if (is_direct) { return fds_direct_io; } else { return fds_buffer_io; @@ -250,15 +244,13 @@ class SSDIOCTX { std::vector> fds_direct_io; }; - void transfer_kv_blocks_ssd( - SSDIOCTX &ioctx, - const torch::Tensor &cpu_layer_id_list, int64_t cpu_tensor_ptr, - const torch::Tensor &ssd_block_ids, const torch::Tensor &cpu_block_ids, - int64_t cpu_layer_stride_in_bytes, int64_t cpu_kv_stride_in_bytes, - int64_t ssd_layer_stride_in_bytes, int64_t ssd_kv_stride_in_bytes, - int64_t chunk_size_in_bytes, int64_t block_stride_in_bytes, bool is_read, - int num_blocks_per_file, int round_robin = 1, - int num_threads_per_device = 16, bool is_mla = false); + SSDIOCTX &ioctx, const torch::Tensor &cpu_layer_id_list, + int64_t cpu_tensor_ptr, const torch::Tensor &ssd_block_ids, + const torch::Tensor &cpu_block_ids, int64_t cpu_layer_stride_in_bytes, + int64_t cpu_kv_stride_in_bytes, int64_t ssd_layer_stride_in_bytes, + int64_t ssd_kv_stride_in_bytes, int64_t chunk_size_in_bytes, + int64_t block_stride_in_bytes, bool is_read, int num_blocks_per_file, + int round_robin = 1, int num_threads_per_device = 16, bool is_mla = false); } // namespace flexkv diff --git a/tests/test_utils.py b/tests/test_utils.py index d3e3cb4fbb..c31fb068eb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -40,7 +40,7 @@ 'enable_trace': False, 'use_pinned_memory': False, 'ssd_cache_dir': ["./ssd_cache", "./ssd_cache2/"], - 'ssd_cache_iouring_entries': 0, + 'ssd_cache_iouring_entries': 32, 'remote_cache_path': ["remote_cache1", "remote_cache2"], 'remote_config_custom': { "pcfs_fsid": "f_l91fz6", From 10071258c62eed7c3a7d25aff75aafd0e5e61f8e Mon Sep 17 00:00:00 2001 From: linhu-nv Date: Wed, 27 Aug 2025 17:47:38 +0800 Subject: [PATCH 14/42] quickfix for return type of reduce_tensor --- flexkv/common/memory_handle.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flexkv/common/memory_handle.py b/flexkv/common/memory_handle.py index d1e2838a87..5013308b40 100644 --- a/flexkv/common/memory_handle.py +++ b/flexkv/common/memory_handle.py @@ -3,7 +3,7 @@ import os import pickle import time -from typing import List, Callable, Any, Optional, Tuple, Union +from typing import Callable, Any, Optional, Tuple, Union from dataclasses import dataclass import torch @@ -16,7 +16,7 @@ @dataclass class TensorSharedHandle: rebuild_func: Callable - rebuild_args: List[Any] + rebuild_args: Tuple[Any] device: torch.device def __init__(self, tensor: torch.Tensor): @@ -29,7 +29,7 @@ def get_tensor(self) -> 
torch.Tensor: return tensor @staticmethod - def _export_tensor_handle(tensor: torch.Tensor) -> Tuple[Callable, List[Any], torch.device]: + def _export_tensor_handle(tensor: torch.Tensor) -> Tuple[Callable, Tuple[Any], torch.device]: device = tensor.device rebuild_func, rebuild_args = reductions.reduce_tensor(tensor) @@ -37,7 +37,7 @@ def _export_tensor_handle(tensor: torch.Tensor) -> Tuple[Callable, List[Any], to return rebuild_func, rebuild_args, device @staticmethod - def _import_tensor_handle(rebuild_func: Callable, rebuild_args: List[Any], device: torch.device) -> torch.Tensor: + def _import_tensor_handle(rebuild_func: Callable, rebuild_args: Tuple[Any], device: torch.device) -> torch.Tensor: try: tensor = rebuild_func(*rebuild_args) From 32e7905ae764219e4164fd6067e515e43b117440 Mon Sep 17 00:00:00 2001 From: zuogan Date: Thu, 28 Aug 2025 19:18:34 +0800 Subject: [PATCH 15/42] fix bug --- .gitignore | 3 +++ flexkv/kvmanager.py | 12 +++++++++--- flexkv/kvtask.py | 26 +++++++++++++++----------- flexkv/server/request.py | 1 - 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index fd0f58caa6..03727c4ed6 100644 --- a/.gitignore +++ b/.gitignore @@ -70,3 +70,6 @@ cover/ # mypy .mypy_cache/ + +# VSCode +.vscode/ diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index 5d5d567801..2154218cae 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -41,6 +41,7 @@ def __init__(self, self.server_client_mode = model_config.dp_size > 1 # True #just for test flexkv_logger.info(f"server_client_mode: {self.server_client_mode}") if self.server_client_mode: + # TODO: server should only be created once but kvmanager will init in every dp rank. self.server_handle = KVServer.create_server(model_config, cache_config, gpu_register_port, server_recv_port) self.dp_client = KVDPClient(self.server_recv_port, self.model_config) else: @@ -52,6 +53,13 @@ def __init__(self, # self.server.run() # time.sleep(10) # self.dp_client = DPClient(self.server_recv_port, self.model_config) + + @property + def dp_client_id(self) -> int: + if self.server_client_mode: + return self.dp_client.dp_client_id + else: + return 0 def start(self) -> None: if not self.server_client_mode: @@ -121,7 +129,6 @@ def get_match(self, token_mask, layer_granularity, dp_id) - mask = torch.from_numpy(mask) if mask is not None else None return task_id, mask def put_async(self, @@ -155,12 +162,11 @@ def put_match(self, task_id, mask = self.dp_client.put_match(token_ids, token_mask, dp_id) else: task_id, mask = self.kv_task_engine.put_match(token_ids, token_mask, dp_id) - mask = torch.from_numpy(mask) if mask is not None else None return task_id, mask def launch(self, task_ids: Union[int, List[int]], - slot_mappings: Union[np.ndarray, List[np.ndarray]]) -> None: + slot_mappings: Union[np.ndarray, List[np.ndarray], torch.Tensor, List[torch.Tensor]]) -> None: if isinstance(task_ids, int): task_ids = [task_ids] if not isinstance(slot_mappings, List): diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index c96a55127e..dddcd6b2e3 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -235,7 +235,7 @@ def _set_slot_mapping_impl(self, task_id: int, slot_mapping: np.ndarray) -> None task = self.tasks[task_id] if task.status != TaskStatus.UNREADY: return - graph_ids = self.cache_engine.slot_mapping_to_block_ids(slot_mapping[task.return_mask.astype(np.bool_)], + graph_ids = self.cache_engine.slot_mapping_to_block_ids(slot_mapping, self.cache_config.tokens_per_block) task.graph.set_gpu_blocks(graph_ids) task.status = 
TaskStatus.READY @@ -353,7 +353,9 @@ def put_async(self, def _wait_impl(self, task_ids: List[int], timeout: float = 20.0, - completely: bool = False) -> Dict[int, KVResponse]: + completely: bool = False, + only_return_finished: bool = False, + ) -> Dict[int, KVResponse]: return_responses = {} start_time = time.time() is_timeout = timeout == 0.0 @@ -386,18 +388,18 @@ def _wait_impl(self, return_mask=self.tasks[task_id].return_mask ) break - elif is_timeout: + elif only_return_finished: + break + elif time.time() - start_time > timeout: + is_timeout = True + if is_timeout: return_responses[task_id] = KVResponse( status=KVResponseStatus.TIMEOUT, task_id=task_id, return_mask=None ) break - else: - if time.time() - start_time > timeout: - is_timeout = True - break - self._update_tasks(timeout=0.001) + self._update_tasks(timeout=0.001) return return_responses def try_wait(self, task_ids: Union[int, List[int]]) -> Dict[int, KVResponse]: @@ -405,8 +407,8 @@ def try_wait(self, task_ids: Union[int, List[int]]) -> Dict[int, KVResponse]: task_ids = [task_ids] nvtx.mark(f"try_wait task_ids: {task_ids}") return_responses = self._wait_impl(task_ids, - timeout=0.0, - completely=False) + completely=False, + only_return_finished=True) return return_responses def wait(self, @@ -426,7 +428,9 @@ def get_match(self, layer_granularity: int = -1, dp_id: int = 0, task_id: int = -1) -> Tuple[int, np.ndarray]: - fake_slot_mapping = np.zeros_like(token_ids) + if token_mask is None: + token_mask = np.ones_like(token_ids, dtype=bool) + fake_slot_mapping = np.zeros_like(token_ids[token_mask]) return self._get_match_impl(token_ids, fake_slot_mapping, is_fake_slot_mapping=True, diff --git a/flexkv/server/request.py b/flexkv/server/request.py index 8e740adca0..6f048e5515 100644 --- a/flexkv/server/request.py +++ b/flexkv/server/request.py @@ -2,7 +2,6 @@ from typing import Dict, List, Optional, Tuple import numpy as np -import torch from flexkv.common.config import ModelConfig from flexkv.common.memory_handle import TensorSharedHandle From f082e43cff3c235ba91a8e068edce252163144bf Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Thu, 28 Aug 2025 20:38:24 -0700 Subject: [PATCH 16/42] fix status bug --- flexkv/kvtask.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index dddcd6b2e3..af01e00761 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -381,7 +381,6 @@ def _wait_impl(self, ) break elif self.check_completed(task_id, completely=completely): - self.tasks[task_id].status = TaskStatus.COMPLETED # TODO is this correct? 
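A note on the reworked wait path above (an illustrative sketch, not part of this patch): try_wait now calls _wait_impl with only_return_finished=True, so it reports whatever tasks are already finished without blocking, while wait keeps polling each task until the per-call timeout expires. The same control flow with generic stand-in names rather than the flexkv API:

import time

def wait_for(task_ids, is_done, timeout=20.0, only_return_finished=False):
    # is_done: callable(task_id) -> bool; returns a task_id -> status mapping
    results, start = {}, time.time()
    for tid in task_ids:
        while True:
            if is_done(tid):
                results[tid] = "FINISHED"
                break
            if only_return_finished:           # try_wait(): never block
                break
            if time.time() - start > timeout:  # wait(): give up after the timeout
                results[tid] = "TIMEOUT"
                break
            time.sleep(0.001)                  # cf. self._update_tasks(timeout=0.001)
    return results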
return_responses[task_id] = KVResponse( status=convert_to_response_status(self.tasks[task_id].status), task_id=task_id, From 98902c79c1d283661e9ca8349ded529579f67a53 Mon Sep 17 00:00:00 2001 From: moritzxu Date: Wed, 27 Aug 2025 15:32:21 +0800 Subject: [PATCH 17/42] Using ring buffer in transfer engine to manage the src and dst block ids instead of using pin memory function inner the launch kernel for reducing the bubble --- flexkv/cache/cache_engine.py | 11 ++- flexkv/common/config.py | 1 + flexkv/common/ring_buffer.py | 152 +++++++++++++++++++++++++++++ flexkv/common/transfer.py | 5 +- flexkv/transfer/transfer_engine.py | 9 ++ flexkv/transfer/worker.py | 44 ++++++--- 6 files changed, 205 insertions(+), 17 deletions(-) create mode 100644 flexkv/common/ring_buffer.py diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 2a581de9b3..7d280c7ead 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -21,6 +21,7 @@ from dataclasses import dataclass, field import numpy as np +import nvtx import torch from flexkv.c_ext import CRadixNode, CRadixTreeIndex, CMatchResult @@ -991,6 +992,7 @@ def _transfer_callback(self, assert self.remote_cache_engine is not None self.remote_cache_engine.recycle(buffer_to_free[DeviceType.REMOTE]) + @nvtx.annotate("Match Prefix Accel", color="yellow") def match_local_accel(self, sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel]: cpu_matched_result = MatchResultAccel() ssd_matched_result = MatchResultAccel() @@ -1000,7 +1002,8 @@ def match_local_accel(self, sequence_meta: SequenceMeta) -> Tuple[MatchResultAcc ssd_matched_result = self.ssd_cache_engine.match(sequence_meta) return cpu_matched_result, ssd_matched_result - + + @nvtx.annotate("Match Prefix", color="yellow") def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResult]: cpu_matched_result = MatchResult() ssd_matched_result = MatchResult() @@ -1010,7 +1013,8 @@ def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchRe ssd_matched_result = self.ssd_cache_engine.match(sequence_meta) return cpu_matched_result, ssd_matched_result - + + @nvtx.annotate("Match All Prefix accel", color="yellow") def match_all_accel(self, sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel, MatchResultAccel]: cpu_matched_result = MatchResultAccel() @@ -1025,7 +1029,8 @@ def match_all_accel(self, remote_matched_result = self.remote_cache_engine.match(sequence_meta) return cpu_matched_result, ssd_matched_result, remote_matched_result - + + @nvtx.annotate("Match All Prefix", color="yellow") def match_all(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResult, MatchResult]: cpu_matched_result = MatchResult() ssd_matched_result = MatchResult() diff --git a/flexkv/common/config.py b/flexkv/common/config.py index 2dc3cc8376..fbcf465727 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -14,6 +14,7 @@ class ModelConfig: head_size: int use_mla: bool = False dtype: torch.dtype = torch.bfloat16 + max_req_tokens = 163840 # parallel configs tp_size: int = 1 diff --git a/flexkv/common/ring_buffer.py b/flexkv/common/ring_buffer.py new file mode 100644 index 0000000000..7ecd3906bd --- /dev/null +++ b/flexkv/common/ring_buffer.py @@ -0,0 +1,152 @@ +import torch +import threading +import time +import random + +from collections import OrderedDict,deque +import numpy as np +from flexkv.common.transfer import TransferOp +from flexkv.common.debug import flexkv_logger + +class 
PinnedMemoryRing: + def __init__(self, max_task_num: int, max_block_num: int, dtype = np.int64): + self.max_task_num = max_task_num + self.max_block_num = max_block_num + self.dtype = dtype + self.time_out = 1 ## waiting time for get free slot (1s) + # create the buffer tensor + self.src_buffer_o = torch.empty((self.max_task_num, self.max_block_num), dtype = torch.int64) + self.dst_buffer_o = torch.empty((self.max_task_num, self.max_block_num), dtype = torch.int64) + # move tensor to share memory + self.src_buffer = self.src_buffer_o.share_memory_() + self.dst_buffer = self.dst_buffer_o.share_memory_() + + flexkv_logger.info(f"[PinnedMemoryRing] block ids src_buffer data_ptr: {self.src_buffer.data_ptr()}") + flexkv_logger.info(f"[PinnedMemoryRing] block ids dst_buffer data_ptr: {self.dst_buffer.data_ptr()}") + + self.op_slot_map = OrderedDict() ## {op_id : ring buffer slot} + self.slot_in_use = [False]*max_task_num + self.free_slots = deque(range(max_task_num)) + + self.valid_length = [0]*max_task_num + + self.lock = threading.Lock() + self.condition = threading.Condition(self.lock) + + def allocate_and_write(self, op_id: int, op: TransferOp): + """ + Allocating a slot for the op and copy src block ids and dst block ids to the buffer. + Params: + op_id: the id index of the op + op: the actual op object, which contains the block ids + Returns: + slot: the slot which is assigned to the current op + num_blocks: the valid number of blocks in the current op + """ + # firstly, determine whether the length of block ids exceeds the limit + num_blocks = op.src_descriptor.physical_block_ids.size(0) + if num_blocks > self.max_block_num: + raise ValueError(f"block_ids too large: {num_blocks} > {self.max_block_num}") + + assert op.src_descriptor.physical_block_ids.size(0) == op.dst_descriptor.physical_block_ids.size(0), \ + f"the number of src block ids ({op.src_descriptor.physical_block_ids.size(0)}) is not eaqual to" \ + f"the number of dst block ids ({op.dst_descriptor.physical_block_ids.size(0)})" + + # get the slot of empty buffer + with self.condition: + while not self.free_slots: + if not self.condition.wait(timeout=self.time_out): + raise TimeoutError("Timeout waiting for a free slot in the ring buffer") + + slot = self.free_slots.popleft() # O(1) + + # update status managers + self.slot_in_use[slot] = True + self.op_slot_map[op_id] = slot + self.valid_length[slot] = num_blocks + # print("----> ring buffer src blocks: ", op.src_descriptor.physical_block_ids) + # print("----> ring buffer dst blocks: ", op.dst_descriptor.physical_block_ids) + + # do copy + self.src_buffer[slot, :num_blocks] = op.src_descriptor.physical_block_ids + self.dst_buffer[slot, :num_blocks] = op.dst_descriptor.physical_block_ids + + # set the rest value of this buffer to -1 + if num_blocks < self.max_block_num: + self.src_buffer[slot, num_blocks:] = -1 # + self.dst_buffer[slot, num_blocks:] = -1 # + + return slot, num_blocks + + def mark_free(self, op_id: int): + """ + Free the relevant resources of corresponding op, called when op transfer completed. 
+ Input: + op_id: the index of current op + Output: + None + """ + with self.condition: + if op_id not in self.op_slot_map: + raise KeyError(f"Task {op_id} not found in buffer") + + slot = self.op_slot_map[op_id] + if not self.slot_in_use[slot]: + raise RuntimeError(f"Slot {slot} is already free, double free detected!") + + self.slot_in_use[slot] = False + self.valid_length[slot] = 0 + self.free_slots.append(slot) + del self.op_slot_map[op_id] + + self.condition.notify() + + def get_src_block_ids(self, slot: int): + if slot < 0 or slot >= self.max_task_num: + raise IndexError(f"Invalid slot index {slot}") + return self.src_buffer[slot, :self.valid_length[slot]] + + def get_dst_block_ids(self, slot: int): + if slot < 0 or slot >= self.max_task_num: + raise IndexError(f"Invalid slot index {slot}") + return self.dst_buffer[slot, :self.valid_length[slot]] + + def get_src_buffer(self): + return self.src_buffer + + def get_dst_buffer(self): + return self.dst_buffer + + def get_buffer_size(self): + return self.max_task_num, self.max_block_num + + def status(self): + """ + Current status logger + """ + with self.lock: + used = sum(self.slot_in_use) + free = self.max_task_num - used + return {"used_slots": used, + "free_slots": free, + "capacity": self.max_task_num} + + +def producer(manager, task_id, data): + try: + print(f"Producer {task_id} trying to allocate...") + slot = manager.allocate_and_write(task_id, data) + print(f"Producer {task_id} got slot {slot}") + + time.sleep(random.uniform(0.1, 2.0)) + + manager.mark_free(task_id) + print(f"Producer {task_id} released slot {slot}") + except Exception as e: + print(f"Producer {task_id} encountered an error: {e}") + +if __name__ == "__main__": + manager = PinnedMemoryRing(4, 10) + + + \ No newline at end of file diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 35061e2eba..056d09b71d 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -54,7 +54,10 @@ class TransferOp: successors: Set[int] = field(default_factory=set) status: TransferOpStatus = TransferOpStatus.PENDING dp_id: int = 0 - + # used for get block ids inner worker process + slot_id: int = -1 + valid_block_num: int =0 + def __post_init__(self) -> None: if self.transfer_type != TransferType.VIRTUAL and \ self.src_block_ids.size != self.dst_block_ids.size: diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 2406098dfc..43ed0eea45 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -37,6 +37,7 @@ tpGPUCPUTransferWorker, ) from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.ring_buffer import PinnedMemoryRing class TransferEngine: @@ -71,6 +72,8 @@ def __init__(self, self._remote_handle = remote_handle self._cache_config = cache_config + self.pin_buffer = PinnedMemoryRing(500, self.model_config.max_req_tokens // self.cache_config.tokens_per_block) + self.op_id_to_nvtx_range: Dict[int, str] = {} self.dp_size = model_config.dp_size @@ -114,6 +117,8 @@ def _init_workers(self) -> None: dtype=self.gpu_handles[i].dtype, tp_group_size=self.tp_size, dp_group_id=i, + src_buffer_tensor = self.pin_buffer.get_src_buffer(), + dst_buffer_tensor = self.pin_buffer.get_dst_buffer(), use_ce_transfer_h2d=self.cache_config.use_ce_transfer_h2d, use_ce_transfer_d2h=self.cache_config.use_ce_transfer_d2h, transfer_sms_h2d=self.cache_config.transfer_sms_h2d, @@ -206,6 +211,7 @@ def _scheduler_loop(self) -> None: while True: try: op_id = 
self.finished_ops_queue.get_nowait() + self.pin_buffer.mark_free(op_id) ## release the slot for ring buffer op = self.op_id_to_op[op_id] self.completed_queue.put((op.graph_id, op.op_id)) finished_ops.append(op) @@ -224,6 +230,9 @@ def _scheduler_loop(self) -> None: self.completed_queue.put((op.graph_id, op.op_id)) else: self.op_id_to_op[op.op_id] = op + slot, valid_block_num = self.pin_buffer.allocate_and_write(op.op_id, op) + op.slot_id = slot + op.valid_block_num = valid_block_num self._assign_op_to_worker(op) # Handle completed graphs for graph_id in completed_graph_ids: diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index 84a15a3cdf..60462529ff 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -1,5 +1,5 @@ import copy -import multiprocessing as mp +import torch.multiprocessing as mp import threading import time from abc import ABC, abstractmethod @@ -55,6 +55,8 @@ class WorkerTransferOp: dst_block_ids: np.ndarray layer_id: int layer_granularity: int + slot_id: int + valid_block_num: int # successors: List[int] def __init__(self, transfer_op: TransferOp): @@ -65,6 +67,8 @@ def __init__(self, transfer_op: TransferOp): self.dst_block_ids = transfer_op.dst_block_ids self.layer_id = transfer_op.layer_id self.layer_granularity = transfer_op.layer_granularity + self.slot_id = transfer_op.slot_id + self.valid_block_num = transfer_op.valid_block_num # self.successors = list(transfer_op.successors) # for nvtx class TransferWorkerBase(ABC): @@ -364,6 +368,8 @@ def __init__(self, dtype: torch.dtype, tp_group_size: int, dp_group_id: int, + src_buffer_tensor: torch.Tensor, + dst_buffer_tensor: torch.Tensor, use_ce_transfer_h2d: bool = False, use_ce_transfer_d2h: bool = False, transfer_sms_h2d: int = 8, @@ -410,33 +416,42 @@ def __init__(self, self.tp_transfer_thread_group = TPTransferThreadGroup(self.num_gpus, self.gpu_blocks, cpu_blocks, dp_group_id) + # get and pin the shared tensor for better transfer + flexkv_logger.info(f"[worker] src data ptr: {src_buffer_tensor.data_ptr()}") + flexkv_logger.info(f"[worker] dst data ptr: {dst_buffer_tensor.data_ptr()}") + self.src_pin_tensor = src_buffer_tensor + self.dst_pin_tensor = dst_buffer_tensor + + cudaHostRegister(self.src_pin_tensor) + cudaHostRegister(self.dst_pin_tensor) + def _transfer_impl(self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, **kwargs: Any, )->None: - assert src_block_ids.dtype == np.int64 - assert dst_block_ids.dtype == np.int64 + assert src_block_ids.dtype == torch.int64 + assert dst_block_ids.dtype == torch.int64 assert len(src_block_ids) == len(dst_block_ids) if transfer_type == TransferType.H2D: - gpu_block_ids = dst_block_ids - cpu_block_ids = src_block_ids + gpu_block_id_list = dst_block_ids + cpu_block_id_list = src_block_ids use_ce_transfer = self.use_ce_transfer_h2d transfer_sms = self.transfer_sms_h2d elif transfer_type == TransferType.D2H: - gpu_block_ids = src_block_ids - cpu_block_ids = dst_block_ids + gpu_block_id_list = src_block_ids + cpu_block_id_list = dst_block_ids use_ce_transfer = self.use_ce_transfer_d2h transfer_sms = self.transfer_sms_d2h else: raise ValueError(f"Invalid transfer type: {transfer_type} for tpGPUCPUTransferWorker") - gpu_block_id_list = torch.from_numpy(gpu_block_ids).to(dtype=torch.int64).pin_memory() - cpu_block_id_list = torch.from_numpy(cpu_block_ids).to(dtype=torch.int64).pin_memory() + # 
gpu_block_id_list = torch.from_numpy(gpu_block_ids).to(dtype=torch.int64).pin_memory() + # cpu_block_id_list = torch.from_numpy(cpu_block_ids).to(dtype=torch.int64).pin_memory() assert len(gpu_block_id_list) == len(cpu_block_id_list) @@ -470,10 +485,13 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: if layer_granularity == -1: layer_granularity = self.num_layers + slot_id = transfer_op.slot_id + valid_block_num = transfer_op.valid_block_num + start_time = time.time() self._transfer_impl( - transfer_op.src_block_ids, - transfer_op.dst_block_ids, + self.src_pin_tensor[slot_id, :valid_block_num], + self.dst_pin_tensor[slot_id, :valid_block_num], transfer_op.transfer_type, layer_id, layer_granularity, From c1f70fdef4f476cb7fe8a577eae282a7cac105a8 Mon Sep 17 00:00:00 2001 From: moritzxu Date: Thu, 28 Aug 2025 21:00:53 +0800 Subject: [PATCH 18/42] refine ring_buffer and apply it to all workers --- flexkv/common/ring_buffer.py | 32 +++--- flexkv/transfer/transfer_engine.py | 23 +++- flexkv/transfer/worker.py | 176 ++++++++++++++++------------- 3 files changed, 135 insertions(+), 96 deletions(-) diff --git a/flexkv/common/ring_buffer.py b/flexkv/common/ring_buffer.py index 7ecd3906bd..b0f3fb2d04 100644 --- a/flexkv/common/ring_buffer.py +++ b/flexkv/common/ring_buffer.py @@ -13,7 +13,7 @@ def __init__(self, max_task_num: int, max_block_num: int, dtype = np.int64): self.max_task_num = max_task_num self.max_block_num = max_block_num self.dtype = dtype - self.time_out = 1 ## waiting time for get free slot (1s) + self.time_out = 0.001 ## waiting time for get free slot (1 ms) # create the buffer tensor self.src_buffer_o = torch.empty((self.max_task_num, self.max_block_num), dtype = torch.int64) self.dst_buffer_o = torch.empty((self.max_task_num, self.max_block_num), dtype = torch.int64) @@ -21,8 +21,8 @@ def __init__(self, max_task_num: int, max_block_num: int, dtype = np.int64): self.src_buffer = self.src_buffer_o.share_memory_() self.dst_buffer = self.dst_buffer_o.share_memory_() - flexkv_logger.info(f"[PinnedMemoryRing] block ids src_buffer data_ptr: {self.src_buffer.data_ptr()}") - flexkv_logger.info(f"[PinnedMemoryRing] block ids dst_buffer data_ptr: {self.dst_buffer.data_ptr()}") + flexkv_logger.info(f"[PinnedMemoryRing] block ids src_buffer data_ptr: {self.src_buffer.storage().data_ptr()}") + flexkv_logger.info(f"[PinnedMemoryRing] block ids dst_buffer data_ptr: {self.dst_buffer.storage().data_ptr()}") self.op_slot_map = OrderedDict() ## {op_id : ring buffer slot} self.slot_in_use = [False]*max_task_num @@ -44,38 +44,42 @@ def allocate_and_write(self, op_id: int, op: TransferOp): num_blocks: the valid number of blocks in the current op """ # firstly, determine whether the length of block ids exceeds the limit - num_blocks = op.src_descriptor.physical_block_ids.size(0) + num_blocks = len(op.src_block_ids) if num_blocks > self.max_block_num: raise ValueError(f"block_ids too large: {num_blocks} > {self.max_block_num}") - assert op.src_descriptor.physical_block_ids.size(0) == op.dst_descriptor.physical_block_ids.size(0), \ - f"the number of src block ids ({op.src_descriptor.physical_block_ids.size(0)}) is not eaqual to" \ - f"the number of dst block ids ({op.dst_descriptor.physical_block_ids.size(0)})" + assert len(op.src_block_ids) == len(op.dst_block_ids), \ + f"the number of src block ids ({len(op.src_block_ids)}) is not eaqual to" \ + f"the number of dst block ids ({len(op.dst_block_ids)})" # get the slot of empty buffer with self.condition: while not self.free_slots: if not 
self.condition.wait(timeout=self.time_out): - raise TimeoutError("Timeout waiting for a free slot in the ring buffer") - + flexkv_logger.info("No empty slot in PinnedMemoryRing, transfer the block ids") + op.slot_id = -1 + op.valid_block_num = num_blocks + return -1, num_blocks + slot = self.free_slots.popleft() # O(1) # update status managers self.slot_in_use[slot] = True self.op_slot_map[op_id] = slot self.valid_length[slot] = num_blocks - # print("----> ring buffer src blocks: ", op.src_descriptor.physical_block_ids) - # print("----> ring buffer dst blocks: ", op.dst_descriptor.physical_block_ids) - + # do copy - self.src_buffer[slot, :num_blocks] = op.src_descriptor.physical_block_ids - self.dst_buffer[slot, :num_blocks] = op.dst_descriptor.physical_block_ids + self.src_buffer[slot, :num_blocks] = torch.from_numpy(op.src_block_ids).to(torch.int64) + self.dst_buffer[slot, :num_blocks] = torch.from_numpy(op.dst_block_ids).to(torch.int64) # set the rest value of this buffer to -1 if num_blocks < self.max_block_num: self.src_buffer[slot, num_blocks:] = -1 # self.dst_buffer[slot, num_blocks:] = -1 # + # update slot id and valid_block_num of current op + op.slot_id = slot + op.valid_block_num = num_blocks return slot, num_blocks def mark_free(self, op_id: int): diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 43ed0eea45..8c6d7eb61a 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -92,6 +92,8 @@ def _init_workers(self) -> None: self.gpucpu_workers: List[WorkerHandle] = [ GPUCPUTransferWorker.create_worker( finished_ops_queue=self.finished_ops_queue, + src_buffer_tensor = self.pin_buffer.get_src_buffer(), + dst_buffer_tensor = self.pin_buffer.get_dst_buffer(), gpu_blocks=self.gpu_handles[i].get_tensor_handle_list(), cpu_blocks=self._cpu_handle.get_tensor(), gpu_kv_layout=self.gpu_handles[i].kv_layout, @@ -109,6 +111,8 @@ def _init_workers(self) -> None: self.gpucpu_workers = [ tpGPUCPUTransferWorker.create_worker( finished_ops_queue=self.finished_ops_queue, + src_buffer_tensor = self.pin_buffer.get_src_buffer(), + dst_buffer_tensor = self.pin_buffer.get_dst_buffer(), gpu_blocks=[self.gpu_handles[j].get_tensor_handle_list() \ for j in range(i * self.tp_size, (i + 1) * self.tp_size)], cpu_blocks=self._cpu_handle.get_tensor(), @@ -117,8 +121,6 @@ def _init_workers(self) -> None: dtype=self.gpu_handles[i].dtype, tp_group_size=self.tp_size, dp_group_id=i, - src_buffer_tensor = self.pin_buffer.get_src_buffer(), - dst_buffer_tensor = self.pin_buffer.get_dst_buffer(), use_ce_transfer_h2d=self.cache_config.use_ce_transfer_h2d, use_ce_transfer_d2h=self.cache_config.use_ce_transfer_d2h, transfer_sms_h2d=self.cache_config.transfer_sms_h2d, @@ -132,6 +134,8 @@ def _init_workers(self) -> None: if self._ssd_handle is not None and self._cpu_handle is not None: self.cpussd_read_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( finished_ops_queue=self.finished_ops_queue, + src_buffer_tensor = self.pin_buffer.get_src_buffer(), + dst_buffer_tensor = self.pin_buffer.get_dst_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), ssd_files=self._ssd_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -142,6 +146,8 @@ def _init_workers(self) -> None: ) self.cpussd_write_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( finished_ops_queue=self.finished_ops_queue, + src_buffer_tensor = self.pin_buffer.get_src_buffer(), + dst_buffer_tensor = self.pin_buffer.get_dst_buffer(), 
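The src_buffer_tensor / dst_buffer_tensor arguments threaded through every create_worker call above are torch tensors placed in shared memory, so each spawned worker process sees the same storage and can page-lock it once on its side. A minimal, self-contained sketch of that hand-off (an assumed pattern shown for illustration, not flexkv code):

import torch
import torch.multiprocessing as mp

def _worker(buf: torch.Tensor) -> None:
    # same underlying storage as the parent's tensor; the real workers also
    # pin it here (cudaHostRegister) so H2D/D2H copies can read it directly
    buf[0, 0] = 123

if __name__ == "__main__":
    shared = torch.empty((4, 8), dtype=torch.int64).share_memory_()
    p = mp.Process(target=_worker, args=(shared,))
    p.start()
    p.join()
    assert shared[0, 0].item() == 123  # the child's write is visible to the parent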
cpu_blocks=self._cpu_handle.get_tensor(), ssd_files=self._ssd_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -155,6 +161,8 @@ def _init_workers(self) -> None: if self._remote_handle is not None and self._cpu_handle is not None: self.remotecpu_read_worker: WorkerHandle = CPURemoteTransferWorker.create_worker( finished_ops_queue=self.finished_ops_queue, + src_buffer_tensor = self.pin_buffer.get_src_buffer(), + dst_buffer_tensor = self.pin_buffer.get_dst_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), remote_file=self._remote_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -164,6 +172,8 @@ def _init_workers(self) -> None: ) self.remotecpu_write_worker: WorkerHandle = CPURemoteTransferWorker.create_worker( finished_ops_queue=self.finished_ops_queue, + src_buffer_tensor = self.pin_buffer.get_src_buffer(), + dst_buffer_tensor = self.pin_buffer.get_dst_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), remote_file=self._remote_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -211,8 +221,10 @@ def _scheduler_loop(self) -> None: while True: try: op_id = self.finished_ops_queue.get_nowait() - self.pin_buffer.mark_free(op_id) ## release the slot for ring buffer op = self.op_id_to_op[op_id] + # release the slot for ring buffer, only when slot_id not equal to -1 + if op.slot_id != -1: + self.pin_buffer.mark_free(op_id) self.completed_queue.put((op.graph_id, op.op_id)) finished_ops.append(op) del self.op_id_to_op[op_id] @@ -230,9 +242,8 @@ def _scheduler_loop(self) -> None: self.completed_queue.put((op.graph_id, op.op_id)) else: self.op_id_to_op[op.op_id] = op - slot, valid_block_num = self.pin_buffer.allocate_and_write(op.op_id, op) - op.slot_id = slot - op.valid_block_num = valid_block_num + # copy block ids into buffer and update slot id info + self.pin_buffer.allocate_and_write(op.op_id, op) self._assign_op_to_worker(op) # Handle completed graphs for graph_id in completed_graph_ids: diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index 60462529ff..86bb735464 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -4,7 +4,7 @@ import time from abc import ABC, abstractmethod from dataclasses import dataclass -from multiprocessing import Queue as MPQueue, Pipe as MPPipe +from torch.multiprocessing import Queue as MPQueue, Pipe as MPPipe from multiprocessing.connection import Connection from threading import Thread from typing import List, Any, Dict, Union, Optional @@ -51,24 +51,28 @@ class WorkerTransferOp: transfer_op_id: int transfer_graph_id: int transfer_type: TransferType - src_block_ids: np.ndarray - dst_block_ids: np.ndarray layer_id: int layer_granularity: int slot_id: int valid_block_num: int + src_block_ids: np.ndarray + dst_block_ids: np.ndarray # successors: List[int] def __init__(self, transfer_op: TransferOp): self.transfer_op_id = transfer_op.op_id self.transfer_graph_id = transfer_op.graph_id self.transfer_type = transfer_op.transfer_type - self.src_block_ids = transfer_op.src_block_ids - self.dst_block_ids = transfer_op.dst_block_ids self.layer_id = transfer_op.layer_id self.layer_granularity = transfer_op.layer_granularity self.slot_id = transfer_op.slot_id self.valid_block_num = transfer_op.valid_block_num + if self.slot_id == -1: + self.src_block_ids = transfer_op.src_block_ids + self.dst_block_ids = transfer_op.dst_block_ids + else: + self.src_block_ids = np.empty(0) + self.dst_block_ids = np.empty(0) # self.successors = list(transfer_op.successors) # for nvtx class 
TransferWorkerBase(ABC): @@ -78,11 +82,20 @@ class TransferWorkerBase(ABC): def __init__(self, worker_id: int, transfer_conn: Connection, # receive end of pipe - finished_ops_queue: MPQueue): + finished_ops_queue: MPQueue, + src_buffer_tensor: torch.Tensor, + dst_buffer_tensor: torch.Tensor): self.worker_id = worker_id self.transfer_conn = transfer_conn # receive end of pipe self.finished_ops_queue: MPQueue[int] = finished_ops_queue + flexkv_logger.info(f"[TransferWorkerBase] src data ptr: {src_buffer_tensor.storage().data_ptr()}") + flexkv_logger.info(f"[TransferWorkerBase] dst data ptr: {dst_buffer_tensor.storage().data_ptr()}") + self.src_shared_tensor = src_buffer_tensor + self.dst_shared_tensor = dst_buffer_tensor + cudaHostRegister(self.src_shared_tensor) + cudaHostRegister(self.dst_shared_tensor) + @classmethod def _get_worker_id(cls) -> int: with cls._worker_id_lock: @@ -104,7 +117,8 @@ def _get_layer_ptrs(self, layer_blocks: Union[List[torch.Tensor], torch.Tensor]) return layer_ptrs @classmethod - def create_worker(cls, finished_ops_queue: MPQueue, *args: Any, **kwargs: Any) -> 'WorkerHandle': + def create_worker(cls, finished_ops_queue: MPQueue, src_buffer_tensor: torch.Tensor, + dst_buffer_tensor: torch.Tensor, *args: Any, **kwargs: Any) -> 'WorkerHandle': """Generic worker creation template method""" parent_conn, child_conn = MPPipe() # create pipe ready_event = mp.Event() @@ -112,7 +126,8 @@ def create_worker(cls, finished_ops_queue: MPQueue, *args: Any, **kwargs: Any) - process = mp.Process( target=cls._worker_process, - args=(worker_id, child_conn, finished_ops_queue, ready_event, *args), + args=(worker_id, child_conn, finished_ops_queue, src_buffer_tensor, + dst_buffer_tensor, ready_event, *args), kwargs=kwargs, daemon=True ) @@ -122,16 +137,18 @@ def create_worker(cls, finished_ops_queue: MPQueue, *args: Any, **kwargs: Any) - @classmethod def _worker_process(cls, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, + src_buffer_tensor: torch.Tensor, dst_buffer_tensor: torch.Tensor, ready_event: Any, *args: Any, **kwargs: Any) -> None: - worker = cls(worker_id, transfer_conn, finished_ops_queue, *args, **kwargs) + worker = cls(worker_id, transfer_conn, finished_ops_queue, src_buffer_tensor, + dst_buffer_tensor, *args, **kwargs) ready_event.set() worker.run() @abstractmethod def _transfer_impl( self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, @@ -139,6 +156,20 @@ def _transfer_impl( ) -> None: pass + + def get_transfer_block_ids(self, transfer_op: WorkerTransferOp) ->tuple[torch.Tensor, torch.Tensor]: + slot_id = transfer_op.slot_id + valid_block_num = transfer_op.valid_block_num + if slot_id == -1: + assert len(transfer_op.src_block_ids) == valid_block_num and len(transfer_op.dst_block_ids) == valid_block_num + src_block_ids = torch.from_numpy(transfer_op.src_block_ids).to(dtype=torch.int64).pin_memory() + dst_block_ids = torch.from_numpy(transfer_op.src_block_ids).to(dtype=torch.int64).pin_memory() + else: + src_block_ids = self.src_shared_tensor[slot_id, :valid_block_num] + dst_block_ids = self.dst_shared_tensor[slot_id, :valid_block_num] + + return src_block_ids, dst_block_ids + def _log_transfer_performance(self, transfer_op: WorkerTransferOp, transfer_size: int, @@ -174,7 +205,7 @@ def run(self) -> None: try: nvtx.push_range(f"launch {op.transfer_type.name} op_id: {op.transfer_op_id}, " f"graph_id: 
{op.transfer_graph_id}, " - f"num_blocks: {len(op.src_block_ids)}", + f"num_blocks: {op.valid_block_num}", color=get_nvtx_range_color(op.transfer_graph_id)) self.launch_transfer(op) nvtx.pop_range() @@ -224,6 +255,8 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, + src_buffer_tensor: torch.Tensor, + dst_buffer_tensor: torch.Tensor, gpu_blocks: List[TensorSharedHandle], cpu_blocks: torch.Tensor, gpu_kv_layout: KVCacheLayout, @@ -235,7 +268,7 @@ def __init__(self, transfer_sms_h2d: int = 8, transfer_sms_d2h: int = 8) -> None: # initialize worker in a new process - super().__init__(worker_id, transfer_conn, finished_ops_queue) + super().__init__(worker_id, transfer_conn, finished_ops_queue, src_buffer_tensor, dst_buffer_tensor) # Register CPU tensors with CUDA cudaHostRegister(cpu_blocks) self.gpu_blocks = [wrapper.get_tensor() for wrapper in gpu_blocks] @@ -273,33 +306,30 @@ def __init__(self, def _transfer_impl( self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, **kwargs: Any, ) -> None: - assert src_block_ids.dtype == np.int64 - assert dst_block_ids.dtype == np.int64 + assert src_block_ids.dtype == torch.int64 + assert dst_block_ids.dtype == torch.int64 assert len(src_block_ids) == len(dst_block_ids) if transfer_type == TransferType.H2D: - gpu_block_ids = dst_block_ids - cpu_block_ids = src_block_ids + gpu_block_id_list = dst_block_ids + cpu_block_id_list = src_block_ids use_ce_transfer = self.use_ce_transfer_h2d transfer_sms = self.transfer_sms_h2d elif transfer_type == TransferType.D2H: - gpu_block_ids = src_block_ids - cpu_block_ids = dst_block_ids + gpu_block_id_list = src_block_ids + cpu_block_id_list = dst_block_ids use_ce_transfer = self.use_ce_transfer_d2h transfer_sms = self.transfer_sms_d2h else: raise ValueError(f"Invalid transfer type: {transfer_type} for GPUCPUTransferWorker") - gpu_block_id_list = torch.from_numpy(gpu_block_ids).to(dtype=torch.int64).pin_memory() - cpu_block_id_list = torch.from_numpy(cpu_block_ids).to(dtype=torch.int64).pin_memory() - assert len(gpu_block_id_list) == len(cpu_block_id_list) if len(gpu_block_id_list) == 0: @@ -335,11 +365,13 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: if layer_granularity == -1: layer_granularity = self.num_layers + src_block_ids, dst_block_ids = self.get_transfer_block_ids(transfer_op) + with torch.cuda.stream(self.transfer_stream): start_time = time.time() self._transfer_impl( - transfer_op.src_block_ids, - transfer_op.dst_block_ids, + src_block_ids, + dst_block_ids, transfer_op.transfer_type, layer_id, layer_granularity, @@ -347,7 +379,7 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: end_time = time.time() kv_dim = 2 if not self.is_mla else 1 - transfer_size = self.chunk_size_in_bytes * layer_granularity * len(transfer_op.src_block_ids) * kv_dim + transfer_size = self.chunk_size_in_bytes * layer_granularity * transfer_op.valid_block_num * kv_dim self._log_transfer_performance( transfer_op, @@ -361,6 +393,8 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, + src_buffer_tensor: torch.Tensor, + dst_buffer_tensor: torch.Tensor, gpu_blocks: List[List[TensorSharedHandle]], cpu_blocks: torch.Tensor, gpu_kv_layout: KVCacheLayout, @@ -368,14 +402,12 @@ def __init__(self, dtype: torch.dtype, tp_group_size: int, dp_group_id: int, - src_buffer_tensor: 
torch.Tensor, - dst_buffer_tensor: torch.Tensor, use_ce_transfer_h2d: bool = False, use_ce_transfer_d2h: bool = False, transfer_sms_h2d: int = 8, transfer_sms_d2h: int = 8): - super().__init__(worker_id, transfer_conn, finished_ops_queue) + super().__init__(worker_id, transfer_conn, finished_ops_queue, src_buffer_tensor, dst_buffer_tensor) assert len(gpu_blocks) == tp_group_size # Handle tensor import for multi-process case imported_gpu_blocks = [] @@ -416,14 +448,6 @@ def __init__(self, self.tp_transfer_thread_group = TPTransferThreadGroup(self.num_gpus, self.gpu_blocks, cpu_blocks, dp_group_id) - # get and pin the shared tensor for better transfer - flexkv_logger.info(f"[worker] src data ptr: {src_buffer_tensor.data_ptr()}") - flexkv_logger.info(f"[worker] dst data ptr: {dst_buffer_tensor.data_ptr()}") - self.src_pin_tensor = src_buffer_tensor - self.dst_pin_tensor = dst_buffer_tensor - - cudaHostRegister(self.src_pin_tensor) - cudaHostRegister(self.dst_pin_tensor) def _transfer_impl(self, src_block_ids: torch.Tensor, @@ -450,8 +474,6 @@ def _transfer_impl(self, else: raise ValueError(f"Invalid transfer type: {transfer_type} for tpGPUCPUTransferWorker") - # gpu_block_id_list = torch.from_numpy(gpu_block_ids).to(dtype=torch.int64).pin_memory() - # cpu_block_id_list = torch.from_numpy(cpu_block_ids).to(dtype=torch.int64).pin_memory() assert len(gpu_block_id_list) == len(cpu_block_id_list) @@ -485,13 +507,12 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: if layer_granularity == -1: layer_granularity = self.num_layers - slot_id = transfer_op.slot_id - valid_block_num = transfer_op.valid_block_num + src_block_ids, dst_block_ids = self.get_transfer_block_ids(transfer_op) start_time = time.time() self._transfer_impl( - self.src_pin_tensor[slot_id, :valid_block_num], - self.dst_pin_tensor[slot_id, :valid_block_num], + src_block_ids, + dst_block_ids, transfer_op.transfer_type, layer_id, layer_granularity, @@ -499,7 +520,7 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: end_time = time.time() kv_dim = 2 if not self.is_mla else 1 - transfer_size = self.cpu_chunk_size_in_bytes * layer_granularity * len(transfer_op.src_block_ids) * kv_dim + transfer_size = self.cpu_chunk_size_in_bytes * layer_granularity * transfer_op.valid_block_num * kv_dim self._log_transfer_performance( transfer_op, @@ -513,6 +534,8 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, + src_buffer_tensor: torch.Tensor, + dst_buffer_tensor: torch.Tensor, cpu_blocks: torch.Tensor, ssd_files: Dict[int, List[str]], # ssd_device_id -> file_paths cpu_kv_layout: KVCacheLayout, @@ -520,7 +543,7 @@ def __init__(self, dtype: torch.dtype, num_blocks_per_file: int, cache_config: CacheConfig): - super().__init__(worker_id, transfer_conn, finished_ops_queue) + super().__init__(worker_id, transfer_conn, finished_ops_queue, src_buffer_tensor, dst_buffer_tensor) self.ssd_files = ssd_files self.num_blocks_per_file = num_blocks_per_file self.num_files = sum(len(file_list) for file_list in ssd_files.values()) @@ -557,30 +580,26 @@ def __init__(self, def _transfer_impl( self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, **kwargs: Any, ) -> None: - assert src_block_ids.dtype == np.int64 - assert dst_block_ids.dtype == np.int64 + assert src_block_ids.dtype == torch.int64 + assert dst_block_ids.dtype == torch.int64 assert 
len(src_block_ids) == len(dst_block_ids) if transfer_type == TransferType.H2DISK: - ssd_block_ids = dst_block_ids - cpu_block_ids = src_block_ids + ssd_block_id_list = dst_block_ids + cpu_block_id_list = src_block_ids elif transfer_type == TransferType.DISK2H: - ssd_block_ids = src_block_ids - cpu_block_ids = dst_block_ids + ssd_block_id_list = src_block_ids + cpu_block_id_list = dst_block_ids else: raise ValueError(f"Invalid transfer type: {transfer_type} for CPUSSDDiskTransferWorker") - # this means partial read hit cpu and other hit ssd - # or partial write hit ssd and none hit cpu - ssd_block_id_list = torch.from_numpy(ssd_block_ids).to(dtype=torch.int64) - cpu_block_id_list = torch.from_numpy(cpu_block_ids).to(dtype=torch.int64) layer_id_list = torch.arange(layer_id, layer_id + layer_granularity, dtype=torch.int32) @@ -610,10 +629,13 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: layer_id = 0 if layer_granularity == -1: layer_granularity = self.num_layers + + src_block_ids , dst_block_ids = self.get_transfer_block_ids(transfer_op) + start_time = time.time() self._transfer_impl( - transfer_op.src_block_ids, - transfer_op.dst_block_ids, + src_block_ids, + dst_block_ids, transfer_op.transfer_type, transfer_op.layer_id, transfer_op.layer_granularity, @@ -621,7 +643,7 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: end_time = time.time() kv_dim = 2 if not self.is_mla else 1 - transfer_size = self.chunk_size_in_bytes * layer_granularity * len(transfer_op.src_block_ids) * kv_dim + transfer_size = self.chunk_size_in_bytes * layer_granularity * transfer_op.valid_block_num * kv_dim self._log_transfer_performance( transfer_op, @@ -635,6 +657,8 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, + src_buffer_tensor: torch.Tensor, + dst_buffer_tensor: torch.Tensor, cpu_blocks: List[torch.Tensor], remote_file: List[str], cpu_kv_layout: KVCacheLayout, @@ -643,7 +667,7 @@ def __init__(self, remote_config_custom: Dict[str, Any]): if transfer_kv_blocks_remote is None: raise RuntimeError("transfer_kv_blocks_remote not available, please build with FLEXKV_ENABLE_CFS=1") - super().__init__(worker_id, transfer_conn, finished_ops_queue) + super().__init__(worker_id, transfer_conn, finished_ops_queue, src_buffer_tensor, dst_buffer_tensor) self.cpu_layer_ptrs = self._get_layer_ptrs(cpu_blocks) self.remote_files = remote_file @@ -716,15 +740,15 @@ def __init__(self, def _transfer_impl( self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, **kwargs: Any ) -> None: - assert dst_block_ids.dtype == np.int64 - assert src_block_ids.dtype == np.int64 + assert src_block_ids.dtype == torch.int64 + assert dst_block_ids.dtype == torch.int64 assert len(src_block_ids) == len(dst_block_ids) if layer_id == -1: @@ -736,17 +760,14 @@ def _transfer_impl( # or partial write hit remote and none hit cpu if transfer_type == TransferType.H2REMOTE: - remote_block_ids = dst_block_ids - cpu_block_ids = src_block_ids + remote_block_id_list = dst_block_ids + cpu_block_id_list = src_block_ids elif transfer_type == TransferType.REMOTE2H: - remote_block_ids = src_block_ids - cpu_block_ids = dst_block_ids + remote_block_id_list = src_block_ids + cpu_block_id_list = dst_block_ids else: raise ValueError(f"Invalid transfer type: {transfer_type} for CPUSSDDiskTransferWorker") - remote_block_id_list = 
torch.from_numpy(remote_block_ids).pin_memory().to(dtype=torch.int64) - cpu_block_id_list = torch.from_numpy(cpu_block_ids).pin_memory().to(dtype=torch.int64) - layer_id_list = torch.arange(layer_id, layer_id + layer_granularity, dtype=torch.int32) transfer_kv_blocks_remote( file_nodeid_list=self.file_nodeid_list, @@ -777,10 +798,13 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: layer_id = 0 if layer_granularity == -1: layer_granularity = self.num_layers + + src_block_ids, dst_block_ids = self.get_transfer_block_ids(transfer_op) + start_time = time.time() self._transfer_impl( - transfer_op.src_block_ids, - transfer_op.dst_block_ids, + src_block_ids, + dst_block_ids, transfer_op.transfer_type, transfer_op.layer_id, transfer_op.layer_granularity, @@ -788,7 +812,7 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: end_time = time.time() kv_dim = 2 if not self.is_mla else 1 - transfer_size = self.chunk_size_in_bytes * layer_granularity * len(transfer_op.src_block_ids) * kv_dim + transfer_size = self.chunk_size_in_bytes * layer_granularity * transfer_op.valid_block_num * kv_dim self._log_transfer_performance( transfer_op, From 0ca63040e4157db4463f6be31841236fba8f9f71 Mon Sep 17 00:00:00 2001 From: moritzxu Date: Thu, 28 Aug 2025 21:04:08 +0800 Subject: [PATCH 19/42] rename PinnedMemoryRing to SharedMemoryRing --- flexkv/common/ring_buffer.py | 10 +++++----- flexkv/transfer/transfer_engine.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/flexkv/common/ring_buffer.py b/flexkv/common/ring_buffer.py index b0f3fb2d04..731abad8b8 100644 --- a/flexkv/common/ring_buffer.py +++ b/flexkv/common/ring_buffer.py @@ -8,7 +8,7 @@ from flexkv.common.transfer import TransferOp from flexkv.common.debug import flexkv_logger -class PinnedMemoryRing: +class SharedMemoryRing: def __init__(self, max_task_num: int, max_block_num: int, dtype = np.int64): self.max_task_num = max_task_num self.max_block_num = max_block_num @@ -21,8 +21,8 @@ def __init__(self, max_task_num: int, max_block_num: int, dtype = np.int64): self.src_buffer = self.src_buffer_o.share_memory_() self.dst_buffer = self.dst_buffer_o.share_memory_() - flexkv_logger.info(f"[PinnedMemoryRing] block ids src_buffer data_ptr: {self.src_buffer.storage().data_ptr()}") - flexkv_logger.info(f"[PinnedMemoryRing] block ids dst_buffer data_ptr: {self.dst_buffer.storage().data_ptr()}") + flexkv_logger.info(f"[SharedMemoryRing] block ids src_buffer data_ptr: {self.src_buffer.storage().data_ptr()}") + flexkv_logger.info(f"[SharedMemoryRing] block ids dst_buffer data_ptr: {self.dst_buffer.storage().data_ptr()}") self.op_slot_map = OrderedDict() ## {op_id : ring buffer slot} self.slot_in_use = [False]*max_task_num @@ -56,7 +56,7 @@ def allocate_and_write(self, op_id: int, op: TransferOp): with self.condition: while not self.free_slots: if not self.condition.wait(timeout=self.time_out): - flexkv_logger.info("No empty slot in PinnedMemoryRing, transfer the block ids") + flexkv_logger.info("No empty slot in SharedMemoryRing, transfer the block ids") op.slot_id = -1 op.valid_block_num = num_blocks return -1, num_blocks @@ -150,7 +150,7 @@ def producer(manager, task_id, data): print(f"Producer {task_id} encountered an error: {e}") if __name__ == "__main__": - manager = PinnedMemoryRing(4, 10) + manager = SharedMemoryRing(4, 10) \ No newline at end of file diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 8c6d7eb61a..f845b2b9bd 100644 --- 
a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -37,7 +37,7 @@ tpGPUCPUTransferWorker, ) from flexkv.common.config import CacheConfig, ModelConfig -from flexkv.common.ring_buffer import PinnedMemoryRing +from flexkv.common.ring_buffer import SharedMemoryRing class TransferEngine: @@ -72,7 +72,7 @@ def __init__(self, self._remote_handle = remote_handle self._cache_config = cache_config - self.pin_buffer = PinnedMemoryRing(500, self.model_config.max_req_tokens // self.cache_config.tokens_per_block) + self.pin_buffer = SharedMemoryRing(500, self.model_config.max_req_tokens // self.cache_config.tokens_per_block) self.op_id_to_nvtx_range: Dict[int, str] = {} From 277b6a39ea737ef50dbec4a8c2729dc07efda5d9 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Thu, 28 Aug 2025 21:46:53 -0700 Subject: [PATCH 20/42] allow to exceed the max_block_num --- flexkv/common/ring_buffer.py | 51 ++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/flexkv/common/ring_buffer.py b/flexkv/common/ring_buffer.py index 731abad8b8..33e6ca2ce4 100644 --- a/flexkv/common/ring_buffer.py +++ b/flexkv/common/ring_buffer.py @@ -20,7 +20,7 @@ def __init__(self, max_task_num: int, max_block_num: int, dtype = np.int64): # move tensor to share memory self.src_buffer = self.src_buffer_o.share_memory_() self.dst_buffer = self.dst_buffer_o.share_memory_() - + flexkv_logger.info(f"[SharedMemoryRing] block ids src_buffer data_ptr: {self.src_buffer.storage().data_ptr()}") flexkv_logger.info(f"[SharedMemoryRing] block ids dst_buffer data_ptr: {self.dst_buffer.storage().data_ptr()}") @@ -29,10 +29,10 @@ def __init__(self, max_task_num: int, max_block_num: int, dtype = np.int64): self.free_slots = deque(range(max_task_num)) self.valid_length = [0]*max_task_num - + self.lock = threading.Lock() self.condition = threading.Condition(self.lock) - + def allocate_and_write(self, op_id: int, op: TransferOp): """ Allocating a slot for the op and copy src block ids and dst block ids to the buffer. 
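The slot lifecycle that allocate_and_write and mark_free implement above boils down to a fixed pool of rows plus a free-list, with a no-wait fallback: when no slot is free, or the op is too large, -1 is returned and the caller ships the block ids inline instead. A compact self-contained sketch of that scheme (simplified for illustration, not the patch's code):

from collections import deque
import numpy as np
import torch

class SlotPool:
    def __init__(self, num_slots: int, max_blocks: int) -> None:
        self.buf = torch.empty((num_slots, max_blocks), dtype=torch.int64).share_memory_()
        self.free = deque(range(num_slots))
        self.used = {}                      # op_id -> slot

    def allocate_and_write(self, op_id: int, block_ids: np.ndarray) -> int:
        n = len(block_ids)
        if n > self.buf.shape[1] or not self.free:
            return -1                       # fall back: caller sends block_ids inline
        slot = self.free.popleft()
        self.buf[slot, :n] = torch.from_numpy(block_ids).to(torch.int64)
        self.used[op_id] = slot
        return slot

    def mark_free(self, op_id: int) -> None:
        self.free.append(self.used.pop(op_id))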
@@ -46,8 +46,10 @@ def allocate_and_write(self, op_id: int, op: TransferOp): # firstly, determine whether the length of block ids exceeds the limit num_blocks = len(op.src_block_ids) if num_blocks > self.max_block_num: - raise ValueError(f"block_ids too large: {num_blocks} > {self.max_block_num}") - + flexkv_logger.warning(f"block_ids too large: {num_blocks} > {self.max_block_num}, " + f"please increase the max_block_num") + return -1, num_blocks + assert len(op.src_block_ids) == len(op.dst_block_ids), \ f"the number of src block ids ({len(op.src_block_ids)}) is not eaqual to" \ f"the number of dst block ids ({len(op.dst_block_ids)})" @@ -60,9 +62,9 @@ def allocate_and_write(self, op_id: int, op: TransferOp): op.slot_id = -1 op.valid_block_num = num_blocks return -1, num_blocks - - slot = self.free_slots.popleft() # O(1) - + + slot = self.free_slots.popleft() # O(1) + # update status managers self.slot_in_use[slot] = True self.op_slot_map[op_id] = slot @@ -71,21 +73,21 @@ def allocate_and_write(self, op_id: int, op: TransferOp): # do copy self.src_buffer[slot, :num_blocks] = torch.from_numpy(op.src_block_ids).to(torch.int64) self.dst_buffer[slot, :num_blocks] = torch.from_numpy(op.dst_block_ids).to(torch.int64) - + # set the rest value of this buffer to -1 if num_blocks < self.max_block_num: - self.src_buffer[slot, num_blocks:] = -1 # - self.dst_buffer[slot, num_blocks:] = -1 # + self.src_buffer[slot, num_blocks:] = -1 # + self.dst_buffer[slot, num_blocks:] = -1 # # update slot id and valid_block_num of current op op.slot_id = slot op.valid_block_num = num_blocks return slot, num_blocks - + def mark_free(self, op_id: int): """ Free the relevant resources of corresponding op, called when op transfer completed. - Input: + Input: op_id: the index of current op Output: None @@ -93,37 +95,37 @@ def mark_free(self, op_id: int): with self.condition: if op_id not in self.op_slot_map: raise KeyError(f"Task {op_id} not found in buffer") - + slot = self.op_slot_map[op_id] if not self.slot_in_use[slot]: raise RuntimeError(f"Slot {slot} is already free, double free detected!") - + self.slot_in_use[slot] = False self.valid_length[slot] = 0 self.free_slots.append(slot) del self.op_slot_map[op_id] - + self.condition.notify() - + def get_src_block_ids(self, slot: int): if slot < 0 or slot >= self.max_task_num: raise IndexError(f"Invalid slot index {slot}") - return self.src_buffer[slot, :self.valid_length[slot]] - + return self.src_buffer[slot, :self.valid_length[slot]] + def get_dst_block_ids(self, slot: int): if slot < 0 or slot >= self.max_task_num: raise IndexError(f"Invalid slot index {slot}") - return self.dst_buffer[slot, :self.valid_length[slot]] + return self.dst_buffer[slot, :self.valid_length[slot]] def get_src_buffer(self): return self.src_buffer - + def get_dst_buffer(self): return self.dst_buffer - + def get_buffer_size(self): return self.max_task_num, self.max_block_num - + def status(self): """ Current status logger @@ -151,6 +153,3 @@ def producer(manager, task_id, data): if __name__ == "__main__": manager = SharedMemoryRing(4, 10) - - - \ No newline at end of file From d1aff964e5e548f8713a1405a4e1a96289ccdf10 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Fri, 29 Aug 2025 02:50:37 -0700 Subject: [PATCH 21/42] refactor: use hash to allocate buffer && no wait for free slot --- flexkv/common/hash_utils.py | 24 +++-- flexkv/common/ring_buffer.py | 154 ++++++++++------------------- flexkv/common/transfer.py | 8 +- flexkv/transfer/transfer_engine.py | 38 +++---- flexkv/transfer/worker.py | 88 
+++++++++-------- 5 files changed, 141 insertions(+), 171 deletions(-) diff --git a/flexkv/common/hash_utils.py b/flexkv/common/hash_utils.py index 8acdc49aa6..32f0055f48 100644 --- a/flexkv/common/hash_utils.py +++ b/flexkv/common/hash_utils.py @@ -20,15 +20,17 @@ def reset(self) -> None: self.hasher.reset() def update(self, array: np.ndarray) -> None: - self.hasher.update(array) + self.hasher.update(torch.from_numpy(array)) def digest(self) -> HashType: return HashType(self.hasher.digest()) +_HASHER = Hasher() + def hash_array(array: np.ndarray) -> HashType: - hasher = Hasher() - hasher.update(array) - return HashType(hasher.digest()) + _HASHER.reset() + _HASHER.update(array) + return HashType(_HASHER.digest()) def gen_hashes(token_ids: np.ndarray, tokens_per_block: int, hasher: Optional[Hasher] = None) -> np.ndarray: block_hashes = np.zeros(token_ids.size // tokens_per_block, dtype=np.uint64) @@ -39,13 +41,15 @@ def gen_hashes(token_ids: np.ndarray, tokens_per_block: int, hasher: Optional[Ha if __name__ == "__main__": np.random.seed(0) - token_ids = np.random.randint(0, 10000, (32000, ), dtype=np.int64) + token_ids = np.random.randint(0, 10000, (1000, ), dtype=np.int64) print(f"token ids length: {token_ids.shape[0]}") - start = time.time() result = hash_array(token_ids) - end = time.time() - print(f"array hash: {result}, time: {end - start}s") start = time.time() - result2 = gen_hashes(token_ids, 16) + for i in range(1): + result = hash_array(token_ids) end = time.time() - print(f"block hashes: {result2}, time: {end - start}s") + print(f"array hash: {result}, average time: {(end - start)*1000/5}ms") + # start = time.time() + # result2 = gen_hashes(token_ids, 16) + # end = time.time() + # print(f"block hashes: {result2}, time: {(end - start)*1000}ms") diff --git a/flexkv/common/ring_buffer.py b/flexkv/common/ring_buffer.py index 33e6ca2ce4..3f67d612ba 100644 --- a/flexkv/common/ring_buffer.py +++ b/flexkv/common/ring_buffer.py @@ -7,84 +7,69 @@ import numpy as np from flexkv.common.transfer import TransferOp from flexkv.common.debug import flexkv_logger +from flexkv.common.hash_utils import hash_array -class SharedMemoryRing: - def __init__(self, max_task_num: int, max_block_num: int, dtype = np.int64): - self.max_task_num = max_task_num + +class SharedOpPool: + def __init__(self, max_op_num: int, max_block_num: int, dtype = np.int64): + self.max_op_num = max_op_num self.max_block_num = max_block_num self.dtype = dtype - self.time_out = 0.001 ## waiting time for get free slot (1 ms) # create the buffer tensor - self.src_buffer_o = torch.empty((self.max_task_num, self.max_block_num), dtype = torch.int64) - self.dst_buffer_o = torch.empty((self.max_task_num, self.max_block_num), dtype = torch.int64) + self.buffer_o = torch.empty((self.max_op_num, self.max_block_num), dtype = torch.int64) # move tensor to share memory - self.src_buffer = self.src_buffer_o.share_memory_() - self.dst_buffer = self.dst_buffer_o.share_memory_() + self.buffer = self.buffer_o.share_memory_() - flexkv_logger.info(f"[SharedMemoryRing] block ids src_buffer data_ptr: {self.src_buffer.storage().data_ptr()}") - flexkv_logger.info(f"[SharedMemoryRing] block ids dst_buffer data_ptr: {self.dst_buffer.storage().data_ptr()}") + flexkv_logger.info(f"[SharedOpPool] block ids buffer data_ptr: {self.buffer.storage().data_ptr()}") - self.op_slot_map = OrderedDict() ## {op_id : ring buffer slot} - self.slot_in_use = [False]*max_task_num - self.free_slots = deque(range(max_task_num)) + self.free_slots = deque(range(max_op_num)) + 
self.slot_map = dict() # {slot_hash: slot_id} - self.valid_length = [0]*max_task_num + self.slot_ref_count = np.zeros(max_op_num, dtype=np.int32) + self.slot_hashes = [0]*max_op_num self.lock = threading.Lock() - self.condition = threading.Condition(self.lock) - def allocate_and_write(self, op_id: int, op: TransferOp): + def allocate_slot(self, block_ids: np.ndarray): """ - Allocating a slot for the op and copy src block ids and dst block ids to the buffer. + Allocating a slot for the given block ids Params: - op_id: the id index of the op - op: the actual op object, which contains the block ids + block_ids: the block ids of src address or dst address Returns: - slot: the slot which is assigned to the current op - num_blocks: the valid number of blocks in the current op + slot_id: the slot which is assigned to the given block ids, -1 if failed """ # firstly, determine whether the length of block ids exceeds the limit - num_blocks = len(op.src_block_ids) - if num_blocks > self.max_block_num: - flexkv_logger.warning(f"block_ids too large: {num_blocks} > {self.max_block_num}, " - f"please increase the max_block_num") - return -1, num_blocks + num_blocks = block_ids.size + if num_blocks > self.max_block_num or num_blocks == 0: + return -1 - assert len(op.src_block_ids) == len(op.dst_block_ids), \ - f"the number of src block ids ({len(op.src_block_ids)}) is not eaqual to" \ - f"the number of dst block ids ({len(op.dst_block_ids)})" + slot_hash = hash_array(block_ids) + reuse = False # get the slot of empty buffer - with self.condition: - while not self.free_slots: - if not self.condition.wait(timeout=self.time_out): - flexkv_logger.info("No empty slot in SharedMemoryRing, transfer the block ids") - op.slot_id = -1 - op.valid_block_num = num_blocks - return -1, num_blocks + with self.lock: + if slot_hash in self.slot_map: + slot_id = self.slot_map[slot_hash] + reuse = True + else: + if not self.free_slots: + flexkv_logger.info("No empty slot in SharedOpPool") + return -1 - slot = self.free_slots.popleft() # O(1) + slot_id = self.free_slots.popleft() + self.slot_map[slot_hash] = slot_id # update status managers - self.slot_in_use[slot] = True - self.op_slot_map[op_id] = slot - self.valid_length[slot] = num_blocks + self.slot_ref_count[slot_id] += 1 + self.slot_hashes[slot_id] = slot_hash # do copy - self.src_buffer[slot, :num_blocks] = torch.from_numpy(op.src_block_ids).to(torch.int64) - self.dst_buffer[slot, :num_blocks] = torch.from_numpy(op.dst_block_ids).to(torch.int64) + if not reuse: + self.buffer[slot_id, :num_blocks] = torch.from_numpy(block_ids).to(torch.int64) - # set the rest value of this buffer to -1 - if num_blocks < self.max_block_num: - self.src_buffer[slot, num_blocks:] = -1 # - self.dst_buffer[slot, num_blocks:] = -1 # + return slot_id - # update slot id and valid_block_num of current op - op.slot_id = slot - op.valid_block_num = num_blocks - return slot, num_blocks - - def mark_free(self, op_id: int): + def free_slot(self, slot_id: int): """ Free the relevant resources of corresponding op, called when op transfer completed. 
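For reference, the slot-reuse pattern this hunk introduces can be summarized in a minimal, standalone sketch. The class and method names below are hypothetical, plain numpy and Python's built-in hash() stand in for the shared torch tensor and hash_array(), and no shared-memory or multiprocessing details are modeled; it is an illustration of the idea, not the SharedOpPool implementation itself.

```python
# Minimal sketch of a hash-keyed, reference-counted slot pool
# (illustrative only; SharedOpPool in the hunk above is the real implementation).
import threading
from collections import deque

import numpy as np


class RefCountedSlotPool:
    def __init__(self, num_slots: int, max_blocks: int):
        self.buffer = np.full((num_slots, max_blocks), -1, dtype=np.int64)
        self.free_slots = deque(range(num_slots))
        self.slot_of_hash: dict[int, int] = {}    # digest -> slot_id
        self.hash_of_slot = [None] * num_slots
        self.ref_count = [0] * num_slots
        self.lock = threading.Lock()

    def allocate(self, block_ids: np.ndarray) -> int:
        """Return a slot holding block_ids, reusing an existing one if possible."""
        if block_ids.size == 0 or block_ids.size > self.buffer.shape[1]:
            return -1
        digest = hash(block_ids.tobytes())         # stand-in for hash_array()
        with self.lock:
            slot = self.slot_of_hash.get(digest)
            if slot is None:
                if not self.free_slots:
                    return -1                      # caller sends block ids inline instead
                slot = self.free_slots.popleft()
                self.slot_of_hash[digest] = slot
                self.hash_of_slot[slot] = digest
                self.buffer[slot, :block_ids.size] = block_ids
            self.ref_count[slot] += 1
            return slot

    def free(self, slot: int) -> None:
        """Drop one reference; recycle the slot once nobody uses it."""
        with self.lock:
            self.ref_count[slot] -= 1
            assert self.ref_count[slot] >= 0, "double free detected"
            if self.ref_count[slot] == 0:
                del self.slot_of_hash[self.hash_of_slot[slot]]
                self.hash_of_slot[slot] = None
                self.free_slots.append(slot)
```

Because allocation is keyed by the hash of the block-id array, an op whose source and destination block ids happen to be identical maps both to the same slot, and the reference count keeps that slot alive until both sides have been freed.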
Input: @@ -92,64 +77,33 @@ def mark_free(self, op_id: int): Output: None """ - with self.condition: - if op_id not in self.op_slot_map: - raise KeyError(f"Task {op_id} not found in buffer") - - slot = self.op_slot_map[op_id] - if not self.slot_in_use[slot]: - raise RuntimeError(f"Slot {slot} is already free, double free detected!") - - self.slot_in_use[slot] = False - self.valid_length[slot] = 0 - self.free_slots.append(slot) - del self.op_slot_map[op_id] - - self.condition.notify() - - def get_src_block_ids(self, slot: int): - if slot < 0 or slot >= self.max_task_num: - raise IndexError(f"Invalid slot index {slot}") - return self.src_buffer[slot, :self.valid_length[slot]] - - def get_dst_block_ids(self, slot: int): - if slot < 0 or slot >= self.max_task_num: - raise IndexError(f"Invalid slot index {slot}") - return self.dst_buffer[slot, :self.valid_length[slot]] - - def get_src_buffer(self): - return self.src_buffer - - def get_dst_buffer(self): - return self.dst_buffer + with self.lock: + slot_hash = self.slot_hashes[slot_id] + if slot_hash not in self.slot_map: + raise RuntimeError(f"Slot {slot_id} is not in use, double free detected!") + self.slot_ref_count[slot_id] -= 1 + assert self.slot_ref_count[slot_id] >= 0, f"Slot {slot_id} ref count is negative" + if self.slot_ref_count[slot_id] == 0: + self.free_slots.append(slot_id) + del self.slot_map[slot_hash] + + def get_buffer(self): + return self.buffer def get_buffer_size(self): - return self.max_task_num, self.max_block_num + return self.max_op_num, self.max_block_num def status(self): """ Current status logger """ with self.lock: - used = sum(self.slot_in_use) - free = self.max_task_num - used + used = len(self.slot_map) + free = self.max_op_num - used return {"used_slots": used, "free_slots": free, - "capacity": self.max_task_num} - - -def producer(manager, task_id, data): - try: - print(f"Producer {task_id} trying to allocate...") - slot = manager.allocate_and_write(task_id, data) - print(f"Producer {task_id} got slot {slot}") - - time.sleep(random.uniform(0.1, 2.0)) + "capacity": self.max_op_num} - manager.mark_free(task_id) - print(f"Producer {task_id} released slot {slot}") - except Exception as e: - print(f"Producer {task_id} encountered an error: {e}") if __name__ == "__main__": - manager = SharedMemoryRing(4, 10) + manager = SharedOpPool(4, 10) diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 056d09b71d..91229b7834 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -55,9 +55,10 @@ class TransferOp: status: TransferOpStatus = TransferOpStatus.PENDING dp_id: int = 0 # used for get block ids inner worker process - slot_id: int = -1 - valid_block_num: int =0 - + src_slot_id: int = -1 + dst_slot_id: int = -1 + valid_block_num: int = 0 + def __post_init__(self) -> None: if self.transfer_type != TransferType.VIRTUAL and \ self.src_block_ids.size != self.dst_block_ids.size: @@ -67,6 +68,7 @@ def __post_init__(self) -> None: TransferOp._next_op_id += 1 assert self.src_block_ids.dtype == np.int64 assert self.dst_block_ids.dtype == np.int64 + self.valid_block_num = self.src_block_ids.size class TransferOpGraph: diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index f845b2b9bd..10f4690ade 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -37,9 +37,19 @@ tpGPUCPUTransferWorker, ) from flexkv.common.config import CacheConfig, ModelConfig -from flexkv.common.ring_buffer import SharedMemoryRing +from 
flexkv.common.ring_buffer import SharedOpPool +def register_op_to_buffer(op: TransferOp, pin_buffer: SharedOpPool) -> None: + op.src_slot_id = pin_buffer.allocate_slot(op.src_block_ids) + op.dst_slot_id = pin_buffer.allocate_slot(op.dst_block_ids) + +def free_op_from_buffer(op: TransferOp, pin_buffer: SharedOpPool) -> None: + if op.src_slot_id != -1: + pin_buffer.free_slot(op.src_slot_id) + if op.dst_slot_id != -1: + pin_buffer.free_slot(op.dst_slot_id) + class TransferEngine: def __init__(self, gpu_handles: List[StorageHandle], @@ -72,7 +82,7 @@ def __init__(self, self._remote_handle = remote_handle self._cache_config = cache_config - self.pin_buffer = SharedMemoryRing(500, self.model_config.max_req_tokens // self.cache_config.tokens_per_block) + self.pin_buffer = SharedOpPool(2048, self.model_config.max_req_tokens // self.cache_config.tokens_per_block) self.op_id_to_nvtx_range: Dict[int, str] = {} @@ -92,8 +102,7 @@ def _init_workers(self) -> None: self.gpucpu_workers: List[WorkerHandle] = [ GPUCPUTransferWorker.create_worker( finished_ops_queue=self.finished_ops_queue, - src_buffer_tensor = self.pin_buffer.get_src_buffer(), - dst_buffer_tensor = self.pin_buffer.get_dst_buffer(), + op_buffer_tensor = self.pin_buffer.get_buffer(), gpu_blocks=self.gpu_handles[i].get_tensor_handle_list(), cpu_blocks=self._cpu_handle.get_tensor(), gpu_kv_layout=self.gpu_handles[i].kv_layout, @@ -111,8 +120,7 @@ def _init_workers(self) -> None: self.gpucpu_workers = [ tpGPUCPUTransferWorker.create_worker( finished_ops_queue=self.finished_ops_queue, - src_buffer_tensor = self.pin_buffer.get_src_buffer(), - dst_buffer_tensor = self.pin_buffer.get_dst_buffer(), + op_buffer_tensor = self.pin_buffer.get_buffer(), gpu_blocks=[self.gpu_handles[j].get_tensor_handle_list() \ for j in range(i * self.tp_size, (i + 1) * self.tp_size)], cpu_blocks=self._cpu_handle.get_tensor(), @@ -134,8 +142,7 @@ def _init_workers(self) -> None: if self._ssd_handle is not None and self._cpu_handle is not None: self.cpussd_read_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( finished_ops_queue=self.finished_ops_queue, - src_buffer_tensor = self.pin_buffer.get_src_buffer(), - dst_buffer_tensor = self.pin_buffer.get_dst_buffer(), + op_buffer_tensor = self.pin_buffer.get_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), ssd_files=self._ssd_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -146,8 +153,7 @@ def _init_workers(self) -> None: ) self.cpussd_write_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( finished_ops_queue=self.finished_ops_queue, - src_buffer_tensor = self.pin_buffer.get_src_buffer(), - dst_buffer_tensor = self.pin_buffer.get_dst_buffer(), + op_buffer_tensor = self.pin_buffer.get_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), ssd_files=self._ssd_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -161,8 +167,7 @@ def _init_workers(self) -> None: if self._remote_handle is not None and self._cpu_handle is not None: self.remotecpu_read_worker: WorkerHandle = CPURemoteTransferWorker.create_worker( finished_ops_queue=self.finished_ops_queue, - src_buffer_tensor = self.pin_buffer.get_src_buffer(), - dst_buffer_tensor = self.pin_buffer.get_dst_buffer(), + op_buffer_tensor = self.pin_buffer.get_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), remote_file=self._remote_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -172,8 +177,7 @@ def _init_workers(self) -> None: ) self.remotecpu_write_worker: WorkerHandle = 
CPURemoteTransferWorker.create_worker( finished_ops_queue=self.finished_ops_queue, - src_buffer_tensor = self.pin_buffer.get_src_buffer(), - dst_buffer_tensor = self.pin_buffer.get_dst_buffer(), + op_buffer_tensor = self.pin_buffer.get_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), remote_file=self._remote_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -222,9 +226,7 @@ def _scheduler_loop(self) -> None: try: op_id = self.finished_ops_queue.get_nowait() op = self.op_id_to_op[op_id] - # release the slot for ring buffer, only when slot_id not equal to -1 - if op.slot_id != -1: - self.pin_buffer.mark_free(op_id) + free_op_from_buffer(op, self.pin_buffer) self.completed_queue.put((op.graph_id, op.op_id)) finished_ops.append(op) del self.op_id_to_op[op_id] @@ -243,7 +245,7 @@ def _scheduler_loop(self) -> None: else: self.op_id_to_op[op.op_id] = op # copy block ids into buffer and update slot id info - self.pin_buffer.allocate_and_write(op.op_id, op) + register_op_to_buffer(op, self.pin_buffer) self._assign_op_to_worker(op) # Handle completed graphs for graph_id in completed_graph_ids: diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index 86bb735464..35cd31bc46 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -53,7 +53,8 @@ class WorkerTransferOp: transfer_type: TransferType layer_id: int layer_granularity: int - slot_id: int + src_slot_id: int + dst_slot_id: int valid_block_num: int src_block_ids: np.ndarray dst_block_ids: np.ndarray @@ -65,9 +66,10 @@ def __init__(self, transfer_op: TransferOp): self.transfer_type = transfer_op.transfer_type self.layer_id = transfer_op.layer_id self.layer_granularity = transfer_op.layer_granularity - self.slot_id = transfer_op.slot_id + self.src_slot_id = transfer_op.src_slot_id + self.dst_slot_id = transfer_op.dst_slot_id self.valid_block_num = transfer_op.valid_block_num - if self.slot_id == -1: + if self.src_slot_id == -1: self.src_block_ids = transfer_op.src_block_ids self.dst_block_ids = transfer_op.dst_block_ids else: @@ -83,18 +85,14 @@ def __init__(self, worker_id: int, transfer_conn: Connection, # receive end of pipe finished_ops_queue: MPQueue, - src_buffer_tensor: torch.Tensor, - dst_buffer_tensor: torch.Tensor): + op_buffer_tensor: torch.Tensor): self.worker_id = worker_id self.transfer_conn = transfer_conn # receive end of pipe self.finished_ops_queue: MPQueue[int] = finished_ops_queue - flexkv_logger.info(f"[TransferWorkerBase] src data ptr: {src_buffer_tensor.storage().data_ptr()}") - flexkv_logger.info(f"[TransferWorkerBase] dst data ptr: {dst_buffer_tensor.storage().data_ptr()}") - self.src_shared_tensor = src_buffer_tensor - self.dst_shared_tensor = dst_buffer_tensor - cudaHostRegister(self.src_shared_tensor) - cudaHostRegister(self.dst_shared_tensor) + flexkv_logger.info(f"[TransferWorkerBase] op buffer data ptr: {op_buffer_tensor.storage().data_ptr()}") + self.op_buffer_tensor = op_buffer_tensor + cudaHostRegister(self.op_buffer_tensor) @classmethod def _get_worker_id(cls) -> int: @@ -117,8 +115,10 @@ def _get_layer_ptrs(self, layer_blocks: Union[List[torch.Tensor], torch.Tensor]) return layer_ptrs @classmethod - def create_worker(cls, finished_ops_queue: MPQueue, src_buffer_tensor: torch.Tensor, - dst_buffer_tensor: torch.Tensor, *args: Any, **kwargs: Any) -> 'WorkerHandle': + def create_worker(cls, + finished_ops_queue: MPQueue, + op_buffer_tensor: torch.Tensor, + *args: Any, **kwargs: Any) -> 'WorkerHandle': """Generic worker creation template method""" parent_conn, 
child_conn = MPPipe() # create pipe ready_event = mp.Event() @@ -126,8 +126,7 @@ def create_worker(cls, finished_ops_queue: MPQueue, src_buffer_tensor: torch.Ten process = mp.Process( target=cls._worker_process, - args=(worker_id, child_conn, finished_ops_queue, src_buffer_tensor, - dst_buffer_tensor, ready_event, *args), + args=(worker_id, child_conn, finished_ops_queue, op_buffer_tensor, ready_event, *args), kwargs=kwargs, daemon=True ) @@ -137,10 +136,8 @@ def create_worker(cls, finished_ops_queue: MPQueue, src_buffer_tensor: torch.Ten @classmethod def _worker_process(cls, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, - src_buffer_tensor: torch.Tensor, dst_buffer_tensor: torch.Tensor, - ready_event: Any, *args: Any, **kwargs: Any) -> None: - worker = cls(worker_id, transfer_conn, finished_ops_queue, src_buffer_tensor, - dst_buffer_tensor, *args, **kwargs) + op_buffer_tensor: torch.Tensor, ready_event: Any, *args: Any, **kwargs: Any) -> None: + worker = cls(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor, *args, **kwargs) ready_event.set() worker.run() @@ -156,17 +153,32 @@ def _transfer_impl( ) -> None: pass - - def get_transfer_block_ids(self, transfer_op: WorkerTransferOp) ->tuple[torch.Tensor, torch.Tensor]: - slot_id = transfer_op.slot_id + def get_transfer_block_ids(self, + transfer_op: WorkerTransferOp, + pinned: bool = True) ->tuple[torch.Tensor, torch.Tensor]: + """ + Get transfer block ids from op buffer tensor or directly from op + Args: + transfer_op: WorkerTransferOp + pinned: whether to pin the block ids tensor + Returns: + tuple[torch.Tensor, torch.Tensor]: src_block_ids and dst_block_ids + """ + src_slot_id = transfer_op.src_slot_id + dst_slot_id = transfer_op.dst_slot_id valid_block_num = transfer_op.valid_block_num - if slot_id == -1: - assert len(transfer_op.src_block_ids) == valid_block_num and len(transfer_op.dst_block_ids) == valid_block_num - src_block_ids = torch.from_numpy(transfer_op.src_block_ids).to(dtype=torch.int64).pin_memory() - dst_block_ids = torch.from_numpy(transfer_op.src_block_ids).to(dtype=torch.int64).pin_memory() + if src_slot_id == -1: + src_block_ids = torch.from_numpy(transfer_op.src_block_ids).to(dtype=torch.int64) + if pinned: + src_block_ids = src_block_ids.pin_memory() else: - src_block_ids = self.src_shared_tensor[slot_id, :valid_block_num] - dst_block_ids = self.dst_shared_tensor[slot_id, :valid_block_num] + src_block_ids = self.op_buffer_tensor[src_slot_id, :valid_block_num] + if dst_slot_id == -1: + dst_block_ids = torch.from_numpy(transfer_op.dst_block_ids).to(dtype=torch.int64) + if pinned: + dst_block_ids = dst_block_ids.pin_memory() + else: + dst_block_ids = self.op_buffer_tensor[dst_slot_id, :valid_block_num] return src_block_ids, dst_block_ids @@ -255,8 +267,7 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, - src_buffer_tensor: torch.Tensor, - dst_buffer_tensor: torch.Tensor, + op_buffer_tensor: torch.Tensor, gpu_blocks: List[TensorSharedHandle], cpu_blocks: torch.Tensor, gpu_kv_layout: KVCacheLayout, @@ -268,7 +279,7 @@ def __init__(self, transfer_sms_h2d: int = 8, transfer_sms_d2h: int = 8) -> None: # initialize worker in a new process - super().__init__(worker_id, transfer_conn, finished_ops_queue, src_buffer_tensor, dst_buffer_tensor) + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) # Register CPU tensors with CUDA cudaHostRegister(cpu_blocks) self.gpu_blocks = [wrapper.get_tensor() for wrapper in 
gpu_blocks] @@ -393,8 +404,7 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, - src_buffer_tensor: torch.Tensor, - dst_buffer_tensor: torch.Tensor, + op_buffer_tensor: torch.Tensor, gpu_blocks: List[List[TensorSharedHandle]], cpu_blocks: torch.Tensor, gpu_kv_layout: KVCacheLayout, @@ -407,7 +417,7 @@ def __init__(self, transfer_sms_h2d: int = 8, transfer_sms_d2h: int = 8): - super().__init__(worker_id, transfer_conn, finished_ops_queue, src_buffer_tensor, dst_buffer_tensor) + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) assert len(gpu_blocks) == tp_group_size # Handle tensor import for multi-process case imported_gpu_blocks = [] @@ -534,8 +544,7 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, - src_buffer_tensor: torch.Tensor, - dst_buffer_tensor: torch.Tensor, + op_buffer_tensor: torch.Tensor, cpu_blocks: torch.Tensor, ssd_files: Dict[int, List[str]], # ssd_device_id -> file_paths cpu_kv_layout: KVCacheLayout, @@ -543,7 +552,7 @@ def __init__(self, dtype: torch.dtype, num_blocks_per_file: int, cache_config: CacheConfig): - super().__init__(worker_id, transfer_conn, finished_ops_queue, src_buffer_tensor, dst_buffer_tensor) + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) self.ssd_files = ssd_files self.num_blocks_per_file = num_blocks_per_file self.num_files = sum(len(file_list) for file_list in ssd_files.values()) @@ -657,8 +666,7 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, - src_buffer_tensor: torch.Tensor, - dst_buffer_tensor: torch.Tensor, + op_buffer_tensor: torch.Tensor, cpu_blocks: List[torch.Tensor], remote_file: List[str], cpu_kv_layout: KVCacheLayout, @@ -667,7 +675,7 @@ def __init__(self, remote_config_custom: Dict[str, Any]): if transfer_kv_blocks_remote is None: raise RuntimeError("transfer_kv_blocks_remote not available, please build with FLEXKV_ENABLE_CFS=1") - super().__init__(worker_id, transfer_conn, finished_ops_queue, src_buffer_tensor, dst_buffer_tensor) + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) self.cpu_layer_ptrs = self._get_layer_ptrs(cpu_blocks) self.remote_files = remote_file From fa50901828944c47892757d95715e8ee30427cc9 Mon Sep 17 00:00:00 2001 From: linhu-nv Date: Tue, 2 Sep 2025 20:34:07 +0800 Subject: [PATCH 22/42] allow different tp ranks have different num_gpu_blocks --- flexkv/transfer_manager.py | 14 ++++---------- tests/test_kvmanager.py | 2 +- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py index d6e5be00f6..1eff8ecfe6 100644 --- a/flexkv/transfer_manager.py +++ b/flexkv/transfer_manager.py @@ -31,7 +31,7 @@ def __init__(self, self.cache_config = cache_config self.gpu_register_port = gpu_register_port - self.gpu_layout: Optional[KVCacheLayout] = None + self.all_gpu_layouts: Dict[int, KVCacheLayout] = {} self.all_gpu_blocks: Dict[int, List[TensorSharedHandle]] = {} # device_id -> gpu_blocks self.context = zmq.Context(2) @@ -64,13 +64,7 @@ def _handle_gpu_blocks_registration(self, req: RegisterTPClientRequest) -> None: self.client_dict[device_id] = send_to_client self.all_gpu_blocks[device_id] = req.handles - if self.gpu_layout is None: - self.gpu_layout = req.gpu_layout - elif self.gpu_layout != req.gpu_layout: - flexkv_logger.error(f"GPU {device_id} has different GPU layout: " - f"{self.gpu_layout} != {req.gpu_layout}") - raise 
ValueError(f"GPU {device_id} has different GPU layout: " - f"{self.gpu_layout} != {req.gpu_layout}") + self.all_gpu_layouts[device_id] = req.gpu_layout flexkv_logger.info(f"GPU {device_id} registered successfully") except Exception as e: flexkv_logger.error(f"Failed to register GPU {device_id}: {e}") @@ -113,11 +107,11 @@ def _register_gpu_blocks_via_socket(self) -> None: def initialize_transfer_engine(self) -> None: self._register_gpu_blocks_via_socket() - assert self.gpu_layout is not None + assert len(self.all_gpu_layouts) == self.model_config.tp_size * self.model_config.dp_size assert len(self.all_gpu_blocks) == self.model_config.tp_size * self.model_config.dp_size for device_id, gpu_blocks_wrapper in self.all_gpu_blocks.items(): self.storage_engine.register_gpu_blocks(gpu_blocks_wrapper, - self.gpu_layout, + self.all_gpu_layouts[device_id], device_id, dtype=self.model_config.dtype) self.gpu_handles = [ diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index b997e2eb21..16ac51517e 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -137,7 +137,7 @@ def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type) tp_client_process = Process( target=run_tp_client, - args=(0, tp_rank, gpu_register_port, model_config, cache_config, num_gpu_blocks, child_conn), + args=(0, tp_rank, gpu_register_port, model_config, cache_config, num_gpu_blocks + tp_rank, child_conn), daemon=True ) tp_client_processes.append(tp_client_process) From 82c6e2ff3aebe12ae97b81ed47db6ced5dc3c098 Mon Sep 17 00:00:00 2001 From: linhu-nv Date: Wed, 3 Sep 2025 13:47:38 +0800 Subject: [PATCH 23/42] fix --- csrc/bindings.cpp | 6 ++---- csrc/tp_transfer_thread_group.cpp | 18 ++++++++++------- csrc/tp_transfer_thread_group.h | 13 +++++++++---- flexkv/transfer/transfer_engine.py | 2 +- flexkv/transfer/worker.py | 31 +++++++++++++++++------------- 5 files changed, 41 insertions(+), 29 deletions(-) diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 2918cad482..c03ad74c40 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -144,12 +144,10 @@ PYBIND11_MODULE(c_ext, m) { py::class_(m, "TPTransferThreadGroup") .def(py::init> &, - torch::Tensor &, int>()) + torch::Tensor &, int, torch::Tensor &, torch::Tensor &, torch::Tensor &>()) .def("tp_group_transfer", &flexkv::TPTransferThreadGroup::tp_group_transfer, - py::arg("gpu_block_id_tensor"), py::arg("gpu_kv_stride_in_bytes"), - py::arg("gpu_block_stride_in_bytes"), - py::arg("gpu_chunk_size_in_bytes"), py::arg("cpu_block_id_tensor"), + py::arg("gpu_block_id_tensor"), py::arg("cpu_block_id_tensor"), py::arg("cpu_kv_stride_in_bytes"), py::arg("cpu_layer_stride_in_bytes"), py::arg("cpu_block_stride_in_bytes"), diff --git a/csrc/tp_transfer_thread_group.cpp b/csrc/tp_transfer_thread_group.cpp index 8c6f2d620f..f0198395fa 100644 --- a/csrc/tp_transfer_thread_group.cpp +++ b/csrc/tp_transfer_thread_group.cpp @@ -22,8 +22,15 @@ namespace flexkv { TPTransferThreadGroup::TPTransferThreadGroup( int num_gpus, const std::vector> &gpu_blocks, - torch::Tensor &cpu_blocks, int dp_group_id) { + torch::Tensor &cpu_blocks, int dp_group_id, + torch::Tensor &gpu_kv_strides_tensor, + torch::Tensor &gpu_block_strides_tensor, + torch::Tensor &gpu_chunk_sizes_tensor) { + num_gpus_ = num_gpus; + gpu_kv_strides_in_bytes_ = static_cast(gpu_kv_strides_tensor.data_ptr()); + gpu_block_strides_in_bytes_ = static_cast(gpu_block_strides_tensor.data_ptr()); + gpu_chunk_sizes_in_bytes_ = static_cast(gpu_chunk_sizes_tensor.data_ptr()); 
queues_.resize(num_gpus_); mtxs_ = std::vector(num_gpus_); @@ -89,9 +96,6 @@ std::future TPTransferThreadGroup::enqueue_for_gpu(int gpu_idx, Task task) void TPTransferThreadGroup::tp_group_transfer( const torch::Tensor &gpu_block_id_tensor, - const int64_t gpu_kv_stride_in_bytes, - const int64_t gpu_block_stride_in_bytes, - const int64_t gpu_chunk_size_in_bytes, const torch::Tensor &cpu_block_id_tensor, const int64_t cpu_kv_stride_in_bytes, const int64_t cpu_layer_stride_in_bytes, @@ -123,14 +127,14 @@ void TPTransferThreadGroup::tp_group_transfer( static_cast(gpu_blocks_ + i * num_layers + layer_id); void *cpu_ptr = cpu_blocks_; int64_t cpu_startoff_inside_chunks = - is_mla ? 0 : i * gpu_chunk_size_in_bytes; + is_mla ? 0 : i * gpu_chunk_sizes_in_bytes_[i]; flexkv::transfer_kv_blocks( num_blocks, layer_id, layer_granularity, gpu_block_ids, - gpu_layer_ptrs, gpu_kv_stride_in_bytes, gpu_block_stride_in_bytes, + gpu_layer_ptrs, gpu_kv_strides_in_bytes_[i], gpu_block_strides_in_bytes_[i], cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, - cpu_startoff_inside_chunks, gpu_chunk_size_in_bytes, streams_[i], + cpu_startoff_inside_chunks, gpu_chunk_sizes_in_bytes_[i], streams_[i], transfer_sms, is_host_to_device, use_ce_transfer, is_mla ); diff --git a/csrc/tp_transfer_thread_group.h b/csrc/tp_transfer_thread_group.h index 0e034cf305..3d57e569c5 100644 --- a/csrc/tp_transfer_thread_group.h +++ b/csrc/tp_transfer_thread_group.h @@ -33,13 +33,13 @@ class TPTransferThreadGroup { public: TPTransferThreadGroup( int num_gpus, const std::vector> &gpu_blocks, - torch::Tensor &cpu_blocks, int dp_group_id); + torch::Tensor &cpu_blocks, int dp_group_id, + torch::Tensor &gpu_kv_strides_tensor, + torch::Tensor &gpu_block_strides_tensor, + torch::Tensor &gpu_chunk_sizes_tensor); ~TPTransferThreadGroup(); void tp_group_transfer(const torch::Tensor &gpu_block_id_tensor, - const int64_t gpu_kv_stride_in_bytes, - const int64_t gpu_block_stride_in_bytes, - const int64_t gpu_chunk_size_in_bytes, const torch::Tensor &cpu_block_id_tensor, const int64_t cpu_kv_stride_in_bytes, const int64_t cpu_layer_stride_in_bytes, @@ -57,6 +57,11 @@ class TPTransferThreadGroup { int dp_group_id_; void **gpu_blocks_; void *cpu_blocks_; + + int64_t *gpu_kv_strides_in_bytes_; + int64_t *gpu_block_strides_in_bytes_; + int64_t *gpu_chunk_sizes_in_bytes_; + std::vector threads_; std::vector streams_; diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 10f4690ade..1b4036cbe3 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -124,7 +124,7 @@ def _init_workers(self) -> None: gpu_blocks=[self.gpu_handles[j].get_tensor_handle_list() \ for j in range(i * self.tp_size, (i + 1) * self.tp_size)], cpu_blocks=self._cpu_handle.get_tensor(), - gpu_kv_layout=self.gpu_handles[i].kv_layout, + gpu_kv_layouts=[self.gpu_handles[i].kv_layout for i in range(i * self.tp_size, (i + 1) * self.tp_size)], cpu_kv_layout=self._cpu_handle.kv_layout, dtype=self.gpu_handles[i].dtype, tp_group_size=self.tp_size, diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index 35cd31bc46..de53b5acdf 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -407,7 +407,7 @@ def __init__(self, op_buffer_tensor: torch.Tensor, gpu_blocks: List[List[TensorSharedHandle]], cpu_blocks: torch.Tensor, - gpu_kv_layout: KVCacheLayout, + gpu_kv_layouts: List[KVCacheLayout], cpu_kv_layout: KVCacheLayout, dtype: torch.dtype, 
tp_group_size: int, @@ -428,7 +428,7 @@ def __init__(self, imported_gpu_blocks.append(blocks_in_one_gpu) self.gpu_blocks = imported_gpu_blocks self.dtype = dtype - self.is_mla = gpu_kv_layout.is_mla + self.is_mla = gpu_kv_layouts[0].is_mla self.num_gpus = len(self.gpu_blocks) self.tp_group_size = tp_group_size @@ -436,19 +436,22 @@ def __init__(self, cudaHostRegister(cpu_blocks) - self.num_layers = gpu_kv_layout.num_layer - gpu_kv_layout_per_layer = gpu_kv_layout.div_layer(self.num_layers) + self.num_layers = gpu_kv_layouts[0].num_layer + gpu_kv_layouts_per_layer = [gpu_kv_layout.div_layer(self.num_layers) for gpu_kv_layout in gpu_kv_layouts] - self.gpu_chunk_size_in_bytes = gpu_kv_layout_per_layer.get_chunk_size() * self.dtype.itemsize - self.gpu_kv_stride_in_bytes = gpu_kv_layout_per_layer.get_kv_stride() * self.dtype.itemsize - self.gpu_block_stride_in_bytes = gpu_kv_layout_per_layer.get_block_stride() * self.dtype.itemsize + self.gpu_chunk_sizes_in_bytes = [gpu_kv_layout_per_layer.get_chunk_size() * self.dtype.itemsize \ + for gpu_kv_layout_per_layer in gpu_kv_layouts_per_layer] + self.gpu_kv_strides_in_bytes = [gpu_kv_layout_per_layer.get_kv_stride() * self.dtype.itemsize \ + for gpu_kv_layout_per_layer in gpu_kv_layouts_per_layer] + self.gpu_block_strides_in_bytes = [gpu_kv_layout_per_layer.get_block_stride() * self.dtype.itemsize \ + for gpu_kv_layout_per_layer in gpu_kv_layouts_per_layer] self.cpu_chunk_size_in_bytes = cpu_kv_layout.get_chunk_size() * self.dtype.itemsize self.cpu_layer_stride_in_bytes = cpu_kv_layout.get_layer_stride() * self.dtype.itemsize self.cpu_kv_stride_in_bytes = cpu_kv_layout.get_kv_stride() * self.dtype.itemsize self.cpu_block_stride_in_bytes = cpu_kv_layout.get_block_stride() * self.dtype.itemsize - if not gpu_kv_layout.type == KVCacheLayoutType.LAYERWISE: + if not gpu_kv_layouts[0].type == KVCacheLayoutType.LAYERWISE: raise ValueError("Only layerwise layout is supported for GPU") self.transfer_sms_h2d = transfer_sms_h2d @@ -456,7 +459,12 @@ def __init__(self, self.use_ce_transfer_h2d = use_ce_transfer_h2d self.use_ce_transfer_d2h = use_ce_transfer_d2h - self.tp_transfer_thread_group = TPTransferThreadGroup(self.num_gpus, self.gpu_blocks, cpu_blocks, dp_group_id) + gpu_kv_strides_tensor = torch.tensor(self.gpu_kv_strides_in_bytes, dtype=torch.int64) + gpu_block_strides_tensor = torch.tensor(self.gpu_block_strides_in_bytes, dtype=torch.int64) + gpu_chunk_sizes_tensor = torch.tensor(self.gpu_chunk_sizes_in_bytes, dtype=torch.int64) + + self.tp_transfer_thread_group = TPTransferThreadGroup(self.num_gpus, self.gpu_blocks, cpu_blocks, dp_group_id, + gpu_kv_strides_tensor, gpu_block_strides_tensor, gpu_chunk_sizes_tensor) def _transfer_impl(self, @@ -489,12 +497,9 @@ def _transfer_impl(self, if len(gpu_block_id_list) == 0: return - + self.tp_transfer_thread_group.tp_group_transfer( gpu_block_id_list, - self.gpu_kv_stride_in_bytes, - self.gpu_block_stride_in_bytes, - self.gpu_chunk_size_in_bytes, cpu_block_id_list, self.cpu_kv_stride_in_bytes, self.cpu_layer_stride_in_bytes, From 828d36f31e77f42f46e8382deb687b9e25eae9c7 Mon Sep 17 00:00:00 2001 From: linhu-nv Date: Wed, 3 Sep 2025 15:55:43 +0800 Subject: [PATCH 24/42] create arrays of gpu_block infos in c++ to avoid invalid ptrs --- csrc/tp_transfer_thread_group.cpp | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/csrc/tp_transfer_thread_group.cpp b/csrc/tp_transfer_thread_group.cpp index f0198395fa..d72f8915db 100644 --- a/csrc/tp_transfer_thread_group.cpp +++ 
b/csrc/tp_transfer_thread_group.cpp @@ -28,9 +28,20 @@ TPTransferThreadGroup::TPTransferThreadGroup( torch::Tensor &gpu_chunk_sizes_tensor) { num_gpus_ = num_gpus; - gpu_kv_strides_in_bytes_ = static_cast(gpu_kv_strides_tensor.data_ptr()); - gpu_block_strides_in_bytes_ = static_cast(gpu_block_strides_tensor.data_ptr()); - gpu_chunk_sizes_in_bytes_ = static_cast(gpu_chunk_sizes_tensor.data_ptr()); + + gpu_kv_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_block_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_chunk_sizes_in_bytes_ = new int64_t[num_gpus]; + + int64_t* kv_strides_ptr = gpu_kv_strides_tensor.data_ptr(); + int64_t* block_strides_ptr = gpu_block_strides_tensor.data_ptr(); + int64_t* chunk_sizes_ptr = gpu_chunk_sizes_tensor.data_ptr(); + + for (int i = 0; i < num_gpus; i++) { + gpu_kv_strides_in_bytes_[i] = kv_strides_ptr[i]; + gpu_block_strides_in_bytes_[i] = block_strides_ptr[i]; + gpu_chunk_sizes_in_bytes_[i] = chunk_sizes_ptr[i]; + } queues_.resize(num_gpus_); mtxs_ = std::vector(num_gpus_); @@ -81,6 +92,10 @@ TPTransferThreadGroup::~TPTransferThreadGroup() { stop_pool_ = true; for (auto& cv : cvs_) cv.notify_all(); for (auto& t : threads_) if (t.joinable()) t.join(); + + delete[] gpu_kv_strides_in_bytes_; + delete[] gpu_block_strides_in_bytes_; + delete[] gpu_chunk_sizes_in_bytes_; } std::future TPTransferThreadGroup::enqueue_for_gpu(int gpu_idx, Task task) { From f27ef1821aa80b7477aa3a2f4a50051f711d1526 Mon Sep 17 00:00:00 2001 From: zuogan Date: Wed, 3 Sep 2025 04:40:30 -0700 Subject: [PATCH 25/42] vllm v0.10.1.1 adapter --- flexkv/common/debug.py | 8 +- flexkv/integration/__init__.py | 0 flexkv/integration/config.py | 68 ++ flexkv/integration/stats.py | 100 +++ flexkv/integration/utils.py | 5 + .../vllm/0001-add-flexkv-connector.patch | 472 +++++++++++ flexkv/integration/vllm/README.md | 44 ++ flexkv/integration/vllm/__init__.py | 0 flexkv/integration/vllm/vllm_v1_adapter.py | 745 ++++++++++++++++++ flexkv/server/client.py | 22 +- flexkv/server/request.py | 2 - flexkv/transfer_manager.py | 15 - 12 files changed, 1444 insertions(+), 37 deletions(-) create mode 100644 flexkv/integration/__init__.py create mode 100644 flexkv/integration/config.py create mode 100644 flexkv/integration/stats.py create mode 100644 flexkv/integration/utils.py create mode 100644 flexkv/integration/vllm/0001-add-flexkv-connector.patch create mode 100644 flexkv/integration/vllm/README.md create mode 100644 flexkv/integration/vllm/__init__.py create mode 100644 flexkv/integration/vllm/vllm_v1_adapter.py diff --git a/flexkv/common/debug.py b/flexkv/common/debug.py index 6c52931db6..0f79cf869b 100644 --- a/flexkv/common/debug.py +++ b/flexkv/common/debug.py @@ -6,13 +6,19 @@ from typing import Optional, Callable, Any +FLEXKV_LOGGING_PREFIX = os.getenv("FLEXKV_LOGGING_PREFIX", "FLEXKV") +_FORMAT = (f"[{FLEXKV_LOGGING_PREFIX}] %(levelname)s %(asctime)s.%(msecs)03d " + " %(message)s") +_DATE_FORMAT = "%m-%d %H:%M:%S" + class FlexkvLogger: def __init__(self, debug_level: str = "INFO"): self.enabled = False self.logger = logging.getLogger("FLEXKV") formatter = logging.Formatter( - "%(asctime)s - %(levelname)s - %(message)s" + fmt=_FORMAT, + datefmt=_DATE_FORMAT, ) console_handler = logging.StreamHandler(sys.stdout) diff --git a/flexkv/integration/__init__.py b/flexkv/integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py new file mode 100644 index 0000000000..76f27f5b34 --- /dev/null +++ 
b/flexkv/integration/config.py @@ -0,0 +1,68 @@ + +import json +import os +import torch +import tempfile +from typing import TYPE_CHECKING +from dataclasses import dataclass, field + +from flexkv.common.debug import flexkv_logger + +if TYPE_CHECKING: + from vllm.v1.kv_cache_interface import KVCacheConfig, FullAttentionSpec + from vllm.config import VllmConfig + + +logger = flexkv_logger + +@dataclass +class FlexKVConfig: + #base config + server_recv_port: str + + # cache config + cache_config: dict = field(default_factory=dict) + + # model config + block_size: int = None + num_layers: int = None + num_kv_heads: int = None + head_size: int = None + dtype: torch.dtype = None + use_mla: bool = False + tp_size: int = 1 + + # log config + num_log_interval_requests: int = 200 + + @classmethod + def from_env(cls) -> 'FlexKVConfig': + config_file_path = os.getenv('FLEXKV_CONFIG_PATH', None) + logger.info(f"{config_file_path=}") + if config_file_path is None: + return cls(enable_flexkv=False, + server_recv_port="") + + assert config_file_path.endswith(".json"), "flexkv config must be a json file." + + with open(config_file_path, 'r') as f: + config_dict: dict = json.load(f) + logger.info(f"FlexKV Config Dict: {config_dict}") + + return cls( + server_recv_port=config_dict.get("server_recv_port", f"ipc:///tmp/flexkv_test"), + cache_config=config_dict.get("cache_config", {}), + num_log_interval_requests=config_dict.get("num_log_interval_requests", 200), + ) + + def post_init_from_vllm_config( + self, + vllm_config: "VllmConfig", + ): + self.num_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config) + self.block_size = vllm_config.cache_config.block_size + self.num_kv_heads = vllm_config.model_config.get_total_num_kv_heads() + self.head_size = vllm_config.model_config.get_head_size() + self.dtype = vllm_config.model_config.dtype + self.use_mla = vllm_config.model_config.is_deepseek_mla + self.tp_size = vllm_config.parallel_config.tensor_parallel_size \ No newline at end of file diff --git a/flexkv/integration/stats.py b/flexkv/integration/stats.py new file mode 100644 index 0000000000..3f4d1a70f9 --- /dev/null +++ b/flexkv/integration/stats.py @@ -0,0 +1,100 @@ +import time +from dataclasses import dataclass +from collections import deque + +from flexkv.common.debug import flexkv_logger + +logger = flexkv_logger + + +@dataclass +class FlexKVStats: + num_log_interval_requests: int + + # get info + num_get_requests: int = 0 + num_get_query_tokens: int = 0 + num_gpu_matched_tokens: int = 0 + num_flexkv_matched_tokens: int = 0 + + # put info + num_put_requests: int = 0 + num_put_query_tokens: int = 0 + num_put_unmatched_tokens: int = 0 + + num_failed_requests: int = 0 + + @property + def tatal_num_requests(self) -> int: + return self.num_get_requests + self.num_put_requests + + @property + def get_gpu_match_ratio(self) -> float: + if self.num_get_query_tokens == 0: + return 0.0 + return self.num_gpu_matched_tokens / self.num_get_query_tokens + + @property + def get_flexkv_match_ratio(self) -> float: + if self.num_get_query_tokens == 0: + return 0.0 + return self.num_flexkv_matched_tokens / self.num_get_query_tokens + + @property + def get_put_token_ratio(self) -> float: + if self.num_put_unmatched_tokens == 0: + return 0.0 + return self.num_flexkv_matched_tokens / self.num_put_unmatched_tokens + + def record_get( + self, + num_prompt_tokens: int, + num_gpu_matched_tokens: int, + num_flexkv_matched_tokens: int, + ): + self.num_get_requests += 1 + self.num_get_query_tokens += 
num_prompt_tokens + self.num_gpu_matched_tokens += num_gpu_matched_tokens + self.num_flexkv_matched_tokens += num_flexkv_matched_tokens + if self.num_get_requests == self.num_log_interval_requests: + self.log() + self.clear() + + def record_put( + self, + num_all_tokens: int, + num_unmatched_tokens: int, + ): + self.num_put_requests += 1 + self.num_put_query_tokens += num_all_tokens + self.num_put_unmatched_tokens += num_unmatched_tokens + + def record_faild( + self, + num_failed_requests: int + ): + self.num_failed_requests += num_failed_requests + + def clear(self): + self.num_get_requests = 0 + self.num_get_query_tokens = 0 + self.num_gpu_matched_tokens = 0 + self.num_flexkv_matched_tokens = 0 + self.num_put_requests = 0 + self.num_put_query_tokens = 0 + self.num_put_unmatched_tokens = 0 + self.num_failed_requests = 0 + + def log(self): + if self.num_put_unmatched_tokens == 0: + get_put_token_ratio_str = "Nan" + else: + get_put_token_ratio_str = f"{self.get_put_token_ratio*100:.2f}%" + logger.info( + f"[FlexKV] Metric of Recent {self.num_log_interval_requests} Requests: " + f"Num Failed Request: {self.num_failed_requests}, " + f"Num Get Query Tokens: {self.num_get_query_tokens}, " + f"GPU Hit Ratio: {self.get_gpu_match_ratio*100:.2f}%, " + f"FlexKV Hit Ratio: {self.get_flexkv_match_ratio*100:.2f}%, " + f"Get/Put Token Ratio: {get_put_token_ratio_str}.") + \ No newline at end of file diff --git a/flexkv/integration/utils.py b/flexkv/integration/utils.py new file mode 100644 index 0000000000..9b107b7c23 --- /dev/null +++ b/flexkv/integration/utils.py @@ -0,0 +1,5 @@ + + +def cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return -(a // -b) \ No newline at end of file diff --git a/flexkv/integration/vllm/0001-add-flexkv-connector.patch b/flexkv/integration/vllm/0001-add-flexkv-connector.patch new file mode 100644 index 0000000000..fc0a558d03 --- /dev/null +++ b/flexkv/integration/vllm/0001-add-flexkv-connector.patch @@ -0,0 +1,472 @@ +From a434b67b8097990f20d8c020a8c713b10dd3d5b0 Mon Sep 17 00:00:00 2001 +From: zuogan +Date: Wed, 3 Sep 2025 05:11:50 -0700 +Subject: [PATCH] add flexkv connector + +--- + .../prefix_caching_flexkv.py | 163 +++++++++++++++ + .../kv_transfer/kv_connector/factory.py | 5 + + .../kv_connector/v1/flexkv_connector.py | 191 ++++++++++++++++++ + vllm/v1/core/sched/scheduler.py | 13 +- + .../worker/kv_connector_model_runner_mixin.py | 6 +- + 5 files changed, 373 insertions(+), 5 deletions(-) + create mode 100644 examples/offline_inference/prefix_caching_flexkv.py + create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/flexkv_connector.py + +diff --git a/examples/offline_inference/prefix_caching_flexkv.py b/examples/offline_inference/prefix_caching_flexkv.py +new file mode 100644 +index 000000000..4cfe2ef7f +--- /dev/null ++++ b/examples/offline_inference/prefix_caching_flexkv.py +@@ -0,0 +1,163 @@ ++# SPDX-License-Identifier: Apache-2.0 ++import os ++import time ++import json ++ ++from vllm import LLM, SamplingParams ++from vllm.distributed import cleanup_dist_env_and_memory ++ ++# NOTE: This is just a running example. 
For benchmarking purpose, ++# please see benchmarks/benchmark_prefix_caching.py ++ ++ ++flexkv_config = { ++ "server_recv_port": "ipc:///tmp/flexkv_test", ++ "cache_config": { ++ "enable_cpu": True, ++ "num_cpu_blocks": 10240, ++ "use_pinned_memory": True ++ }, ++ "num_log_interval_requests": 200 ++} ++flexkv_config_path = "./flexkv_config.json" ++with open(flexkv_config_path, 'w') as f: ++ json.dump(flexkv_config, f) ++os.environ["FLEXKV_CONFIG_PATH"] = flexkv_config_path ++ ++ ++# Common prefix. ++prefix = ( ++ "You are an expert school principal, skilled in effectively managing " ++ "faculty and staff. Draft 10-15 questions for a potential first grade " ++ "Head Teacher for my K-12, all-girls', independent school that emphasizes " ++ "community, joyful discovery, and life-long learning. The candidate is " ++ "coming in for a first-round panel interview for a 8th grade Math " ++ "teaching role. They have 5 years of previous teaching experience " ++ "as an assistant teacher at a co-ed, public school with experience " ++ "in middle school math teaching. Based on these information, fulfill " ++ "the following paragraph: ") ++ ++# Sample prompts. ++prompts = [ ++ "Hello, my name is", ++ "The president of the United States is", ++ "The capital of France is", ++ "The future of AI is", ++] ++ ++generating_prompts = [prefix + prompt for prompt in prompts] ++ ++# Create a sampling params object. ++sampling_params = SamplingParams(temperature=0.0) ++ ++kv_transfer_config = { ++ "kv_connector": "FlexKVConnectorV1", ++ "kv_role": "kv_both", ++} ++# model_path = "/data0/models/facebook/opt-125m" ++model_path = "/data0/models/Qwen3/Qwen3-32B" ++tp_size = 8 ++gpu_memory_utilization = 0.4 ++ ++ ++ ++def main(): ++ # Create an LLM without prefix caching as a baseline. ++ regular_llm = LLM(model=model_path, ++ enable_prefix_caching=False, ++ gpu_memory_utilization=gpu_memory_utilization, ++ tensor_parallel_size=tp_size ++ ) ++ ++ print("Results without `enable_prefix_caching`") ++ ++ # ruff: noqa: E501 ++ # Generate texts from the prompts. The output is a list of RequestOutput objects ++ # that contain the prompt, generated text, and other information. ++ outputs = regular_llm.generate(generating_prompts, sampling_params) ++ ++ regular_generated_texts = [] ++ # Print the outputs. ++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ regular_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Destroy the LLM object and free up the GPU memory. ++ del regular_llm ++ cleanup_dist_env_and_memory() ++ ++ # return ++ ++ # Create an LLM with prefix caching enabled. ++ prefix_cached_llm = LLM(model=model_path, ++ enable_prefix_caching=True, ++ gpu_memory_utilization=gpu_memory_utilization, ++ tensor_parallel_size=tp_size, ++ kv_transfer_config=kv_transfer_config, ++ ) ++ ++ # Warmup so that the shared prompt's KV cache is computed. ++ prefix_cached_llm.generate(generating_prompts[0], sampling_params) ++ ++ # wait for offload kv task finished. ++ time.sleep(2) ++ ++ # Generate with prefix caching. ++ outputs = prefix_cached_llm.generate(generating_prompts, sampling_params) ++ ++ print("Results with `enable_prefix_caching`") ++ ++ cached_generated_texts = [] ++ # Print the outputs. You should see the same outputs as before. 
++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ cached_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Compare the results and display the speedup ++ generated_same = all([ ++ regular_generated_texts[i] == cached_generated_texts[i] ++ for i in range(len(prompts)) ++ ]) ++ print(f"Generated answers are the same: {generated_same}") ++ ++ # wait for offload kv task finished. ++ time.sleep(2) ++ ++ # reset prefix cache to use flexkv ++ prefix_cached_llm.reset_prefix_cache() ++ ++ # Generate with prefix caching. ++ outputs = prefix_cached_llm.generate(generating_prompts, sampling_params) ++ ++ print("Results with `flexkv`") ++ ++ flexkv_generated_texts = [] ++ # Print the outputs. You should see the same outputs as before. ++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ flexkv_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Compare the results and display the speedup ++ generated_same = all([ ++ regular_generated_texts[i] == flexkv_generated_texts[i] ++ for i in range(len(prompts)) ++ ]) ++ print(f"Generated answers are the same: {generated_same}") ++ ++ ++ ++if __name__ == "__main__": ++ main() ++ # pass +\ No newline at end of file +diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py +index 584fc1d65..db1cfe36b 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/factory.py ++++ b/vllm/distributed/kv_transfer/kv_connector/factory.py +@@ -105,3 +105,8 @@ KVConnectorFactory.register_connector( + "MultiConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", + "MultiConnector") ++ ++KVConnectorFactory.register_connector( ++ "FlexKVConnectorV1", ++ "vllm.distributed.kv_transfer.kv_connector.v1.flexkv_connector", ++ "FlexKVConnectorV1") +diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/flexkv_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/flexkv_connector.py +new file mode 100644 +index 000000000..bdfa9f321 +--- /dev/null ++++ b/vllm/distributed/kv_transfer/kv_connector/v1/flexkv_connector.py +@@ -0,0 +1,191 @@ ++# SPDX-License-Identifier: Apache-2.0 ++# SPDX-FileCopyrightText: Copyright contributors to the vLLM project ++from typing import TYPE_CHECKING, Any, Optional ++ ++import torch ++from flexkv.integration.vllm.vllm_v1_adapter import FlexKVConnectorV1Impl ++ ++from vllm.config import VllmConfig ++from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ++ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) ++from vllm.logger import init_logger ++from vllm.v1.core.sched.output import SchedulerOutput ++from vllm.v1.outputs import KVConnectorOutput ++ ++if TYPE_CHECKING: ++ from vllm.attention.backends.abstract import AttentionMetadata ++ from vllm.forward_context import ForwardContext ++ from vllm.v1.core.kv_cache_manager import KVCacheBlocks ++ from vllm.v1.request import Request ++ ++logger = init_logger(__name__) ++ ++ ++class FlexKVConnectorV1(KVConnectorBase_V1): ++ ++ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): ++ super().__init__(vllm_config=vllm_config, role=role) ++ self._flexkv_connector = FlexKVConnectorV1Impl(vllm_config, role) ++ ++ def shutdown(self): ++ self._flexkv_connector.shutdown() ++ ++ # 
============================== ++ # Worker-side methods ++ # ============================== ++ def start_load_kv(self, forward_context: "ForwardContext", ++ **kwargs) -> None: ++ """ ++ Start loading the KV cache from the connector to vLLM's paged ++ KV buffer. This is called from the forward context before the ++ forward pass to enable async loading during model execution. ++ ++ Args: ++ forward_context (ForwardContext): the forward context. ++ **kwargs: additional arguments for the load operation ++ ++ Note: ++ The number of elements in kv_caches and layer_names should be ++ the same. ++ ++ """ ++ self._flexkv_connector.start_load_kv(forward_context, **kwargs) ++ ++ def wait_for_layer_load(self, layer_name: str) -> None: ++ """ ++ Block until the KV for a specific layer is loaded into vLLM's ++ paged buffer. This is called from within attention layer to ensure ++ async copying from start_load_kv is complete. ++ ++ This interface will be useful for layer-by-layer pipelining. ++ ++ Args: ++ layer_name: the name of that layer ++ """ ++ self._flexkv_connector.wait_for_layer_load(layer_name) ++ ++ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, ++ attn_metadata: "AttentionMetadata", **kwargs) -> None: ++ """ ++ Start saving the a layer of KV cache from vLLM's paged buffer ++ to the connector. This is called from within attention layer to ++ enable async copying during execution. ++ ++ Args: ++ layer_name (str): the name of the layer. ++ kv_layer (torch.Tensor): the paged KV buffer of the current ++ layer in vLLM. ++ attn_metadata (AttentionMetadata): the attention metadata. ++ **kwargs: additional arguments for the save operation. ++ """ ++ self._flexkv_connector.save_kv_layer(layer_name, kv_layer, attn_metadata, ++ **kwargs) ++ ++ def wait_for_save(self): ++ """ ++ Block until all the save operations is done. This is called ++ as the forward context exits to ensure that the async saving ++ from save_kv_layer is complete before finishing the forward. ++ ++ This prevents overwrites of paged KV buffer before saving done. ++ """ ++ self._flexkv_connector.wait_for_save() ++ ++ def get_finished( ++ self, finished_req_ids: set[str] ++ ) -> tuple[Optional[set[str]], Optional[set[str]]]: ++ """ ++ Notifies worker-side connector ids of requests that have ++ finished generating tokens. ++ ++ Returns: ++ ids of requests that have finished asynchronous transfer ++ (requests that previously returned True from request_finished()), ++ tuple of (sending/saving ids, recving/loading ids). ++ The finished saves/sends req ids must belong to a set provided in a ++ call to this method (this call or a prior one). ++ """ ++ return self._flexkv_connector.get_finished(finished_req_ids) ++ ++ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ++ """ ++ Initialize with the KV caches. Useful for pre-registering the ++ KV Caches in the KVConnector (e.g. for NIXL). ++ ++ Args: kv_caches: ++ dictionary of layer names, kv cache ++ """ ++ self._flexkv_connector.register_kv_caches(kv_caches) ++ ++ # ============================== ++ # Scheduler-side methods ++ # ============================== ++ def get_num_new_matched_tokens( ++ self, ++ request: "Request", ++ num_computed_tokens: int, ++ ) -> tuple[int, bool]: ++ """ ++ Get number of new tokens that can be loaded from the ++ external KV cache beyond the num_computed_tokens. ++ ++ Args: ++ request (Request): the request object. 
++ num_computed_tokens (int): the number of locally ++ computed tokens for this request ++ ++ Returns: ++ the number of tokens that can be loaded from the ++ external KV cache beyond what is already computed. ++ """ ++ return self._flexkv_connector.get_num_new_matched_tokens( ++ request, num_computed_tokens) ++ ++ def update_state_after_alloc(self, request: "Request", ++ blocks: "KVCacheBlocks", ++ num_external_tokens: int): ++ """ ++ Update KVConnector state after block allocation. ++ """ ++ self._flexkv_connector.update_state_after_alloc(request, blocks, ++ num_external_tokens) ++ ++ def build_connector_meta( ++ self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: ++ """ ++ Build the connector metadata for this step. ++ ++ This function should NOT modify fields in the scheduler_output. ++ Also, calling this function will reset the state of the connector. ++ ++ Args: ++ scheduler_output (SchedulerOutput): the scheduler output object. ++ """ ++ return self._flexkv_connector.build_connector_meta(scheduler_output) ++ ++ def update_connector_output(self, connector_output: KVConnectorOutput): ++ """ ++ Update KVConnector state from worker-side connectors output. ++ ++ Args: ++ connector_output (KVConnectorOutput): the worker-side ++ connectors output. ++ """ ++ self._flexkv_connector.update_connector_output(connector_output) ++ ++ def request_finished( ++ self, ++ request: "Request", ++ block_ids: list[int], ++ ) -> tuple[bool, Optional[dict[str, Any]]]: ++ """ ++ Called when a request has finished, before its blocks are freed. ++ ++ Returns: ++ True if the request is being saved/sent asynchronously and blocks ++ should not be freed until the request_id is returned from ++ get_finished(). ++ Optional KVTransferParams to be included in the request outputs ++ returned by the engine. ++ """ ++ return self._flexkv_connector.request_finished(request, block_ids) +diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py +index 981023409..a6c8fac38 100644 +--- a/vllm/v1/core/sched/scheduler.py ++++ b/vllm/v1/core/sched/scheduler.py +@@ -118,6 +118,7 @@ class Scheduler(SchedulerInterface): + + # KV Connector: requests in process of async KV loading or recving + self.finished_recving_kv_req_ids: set[str] = set() ++ self.sending_kv_reqs: dict[str, Request] = {} + + # Encoder-related. + # Calculate encoder cache size if applicable +@@ -1029,7 +1030,8 @@ class Scheduler(SchedulerInterface): + + if not delay_free_blocks: + self._free_blocks(request) +- ++ else: ++ self.sending_kv_reqs[request.request_id] = request + return kv_xfer_params + + def _free_blocks(self, request: Request): +@@ -1041,7 +1043,7 @@ class Scheduler(SchedulerInterface): + return len(self.waiting) + len(self.running) + + def has_finished_requests(self) -> bool: +- return len(self.finished_req_ids) > 0 ++ return len(self.finished_req_ids) > 0 or len(self.sending_kv_reqs) > 0 + + def reset_prefix_cache(self) -> bool: + return self.kv_cache_manager.reset_prefix_cache() +@@ -1082,6 +1084,8 @@ class Scheduler(SchedulerInterface): + def shutdown(self) -> None: + if self.kv_event_publisher: + self.kv_event_publisher.shutdown() ++ if self.connector and hasattr(self.connector, "shutdown"): ++ self.connector.shutdown() + + ######################################################################## + # KV Connector Related Methods +@@ -1149,6 +1153,10 @@ class Scheduler(SchedulerInterface): + scheduler the request during the next step. 
+         """
+ 
++        # avoid busy checking
++        if len(self.running) == 0:
++            time.sleep(0.01)
++
+         if self.connector is not None:
+             self.connector.update_connector_output(kv_connector_output)
+ 
+@@ -1158,4 +1166,5 @@
+             self.finished_recving_kv_req_ids.add(req_id)
+         for req_id in (kv_connector_output.finished_sending or ()):
+             logger.debug("Finished sending KV transfer for request %s", req_id)
++            del self.sending_kv_reqs[req_id]
+             self._free_blocks(self.requests[req_id])
+diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py
+index a03ebe35d..8e4460957 100644
+--- a/vllm/v1/worker/kv_connector_model_runner_mixin.py
++++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py
+@@ -66,9 +66,9 @@ class KVConnectorModelRunnerMixin:
+                 scheduler_output, wait_for_save=False) as kv_connector_output:
+             pass
+ 
+-        if (not kv_connector_output.finished_sending
+-                and not kv_connector_output.finished_recving):
+-            return EMPTY_MODEL_RUNNER_OUTPUT
++        # if (not kv_connector_output.finished_sending
++        #         and not kv_connector_output.finished_recving):
++        #     return EMPTY_MODEL_RUNNER_OUTPUT
+ 
+         output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+         output.kv_connector_output = kv_connector_output
+-- 
+2.34.1
+
diff --git a/flexkv/integration/vllm/README.md b/flexkv/integration/vllm/README.md
new file mode 100644
index 0000000000..136f8b8682
--- /dev/null
+++ b/flexkv/integration/vllm/README.md
@@ -0,0 +1,44 @@
+Use FlexKV with vLLM v0.10.1.1
+
+1. apply patch
+```bash
+cd vllm
+git apply 0001-add-flexkv-connector.patch
+```
+
+2. offline test
+```bash
+python examples/offline_inference/prefix_caching_flexkv.py
+```
+
+3. online serving
+```bash
+# generate config
+cat <<EOF > ./flexkv_config.json
+{
+    "server_recv_port": "ipc:///tmp/flexkv_test",
+    "cache_config": {
+        "enable_cpu": true,
+        "num_cpu_blocks": 10240,
+        "use_pinned_memory": true
+    },
+    "num_log_interval_requests": 200
+}
+EOF
+export FLEXKV_CONFIG_PATH="./flexkv_config.json"
+
+VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \
+    --tensor-parallel-size 8 \
+    --trust-remote-code \
+    --port 30001 \
+    --max-num-seqs 128 \
+    --max-num-batched-tokens 8192 \
+    --max_model_len 8192 \
+    --max-seq-len-to-capture 8192 \
+    --gpu-memory-utilization 0.8 \
+    --enable-chunked-prefill \
+    --enable-prefix-caching \
+    --kv-transfer-config \
+    '{"kv_connector":"FlexKVConnectorV1","kv_role":"kv_both"}'
+
+```
\ No newline at end of file
diff --git a/flexkv/integration/vllm/__init__.py b/flexkv/integration/vllm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/flexkv/integration/vllm/vllm_v1_adapter.py b/flexkv/integration/vllm/vllm_v1_adapter.py
new file mode 100644
index 0000000000..951259474b
--- /dev/null
+++ b/flexkv/integration/vllm/vllm_v1_adapter.py
@@ -0,0 +1,745 @@
+import os
+import time
+from typing import TYPE_CHECKING, Optional, Literal, Any
+from dataclasses import dataclass, field
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch
+
+from flexkv.kvmanager import KVManager
+from flexkv.server.client import KVTPClient
+from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType
+from flexkv.common.config import ModelConfig, CacheConfig
+from flexkv.common.request import KVResponseStatus
+from flexkv.common.debug import flexkv_logger
+from flexkv.integration.stats import FlexKVStats
+from flexkv.integration.utils import cdiv
+from flexkv.integration.config import FlexKVConfig
+
+# vllm
+from
vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata, KVConnectorRole) + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + from vllm.v1.outputs import KVConnectorOutput + + +logger = flexkv_logger + + +@dataclass +class FlexKVResponse: + task_id: int + task_type: Literal["get", "put"] + request: "Request" + success: bool + + +@dataclass +class FlexKVTask(ABC): + task_id: int = 0 + request: "Request" = 0 + + # slot mapping + slot_mapping: Optional[np.ndarray] = None + + # timer + match_start_time: float = 0 + match_end_time: float = 0 + task_launch_time: float = 0 + task_finished_time: float = 0 + + @property + def match_cost(self) -> float: + return (self.match_end_time - self.match_start_time) + + @property + def task_execute_cost(self) -> float: + return (self.task_finished_time - self.task_launch_time) + + @property + @abstractmethod + def task_type(self) -> str: + ... + + def __str__(self): + return (f"FlexKVTask(task_id={self.task_id}, " + f"request={self.request.request_id}, " + f"match_cost {self.match_cost*1000:.2f} ms, " + f"task execute cost {self.task_execute_cost*1000:.2f} ms)") + + +@dataclass(kw_only=True) +class FlexKVGetTask(FlexKVTask): + num_computed_tokens: int + num_new_matched_tokens: int + + @property + def task_type(self) -> str: + return "get" + + def __str__(self): + return (f"FlexKVGetTask(task_id={self.task_id}, " + f"request={self.request.request_id}, " + f"num_computed_tokens={self.num_computed_tokens}, " + f"num_new_matched_tokens={self.num_new_matched_tokens}, " + f"match_cost {self.match_cost*1000:.2f} ms, " + f"task execute cost {self.task_execute_cost*1000:.2f} ms)") + + +@dataclass(kw_only=True) +class FlexKVPutTask(FlexKVTask): + num_matched_tokens: int + num_unmatched_tokens: int + + @property + def task_type(self) -> str: + return "put" + + def __str__(self): + return (f"FlexKVPutTask(task_id={self.task_id}, " + f"request={self.request.request_id}, " + f"num_matched_tokens={self.num_matched_tokens}, " + f"num_unmatched_tokens={self.num_unmatched_tokens}, " + f"match_cost {self.match_cost*1000:.2f} ms, " + f"task execute cost {self.task_execute_cost*1000:.2f} ms)") + + +class FlexKVSchedulerConnector: + def __init__( + self, + flexkv_config: FlexKVConfig + ): + logger.info(f"Start init FlexKVSchedulerConnector with {flexkv_config}") + self.flexkv_config = flexkv_config + self.server_recv_port = flexkv_config.server_recv_port + self.tp_size = flexkv_config.tp_size + self.block_size = flexkv_config.block_size + self.model_config = ModelConfig( + num_layers=flexkv_config.num_layers, + num_kv_heads=flexkv_config.num_kv_heads, + head_size=flexkv_config.head_size, + use_mla=flexkv_config.use_mla, + dtype=flexkv_config.dtype, + tp_size=flexkv_config.tp_size, + ) + if "tokens_per_block" in flexkv_config.cache_config: + assert flexkv_config.cache_config.pop("tokens_per_block") == flexkv_config.block_size + self.cache_config = CacheConfig( + tokens_per_block=flexkv_config.block_size, + **flexkv_config.cache_config, + ) + self.flexkv_manager = KVManager(model_config=self.model_config, + cache_config=self.cache_config, + gpu_register_port=flexkv_config.server_recv_port) + self.flexkv_manager.start() + # self.dp_client = KVDPClient(self.server_recv_port, 
self.model_config)
+
+        # request_id -> task_id
+        self.req_id_to_task_dict: dict[str, int] = {}
+        # launched but unfinished tasks
+        self.get_tasks: dict[int, FlexKVGetTask] = {}
+        self.put_tasks: dict[int, FlexKVPutTask] = {}
+        # unlaunched tasks
+        self.tasks_to_launch: dict[int, FlexKVTask] = {}
+        self.tasks_to_cancel: dict[int, FlexKVTask] = {}
+
+        self.flexkv_stats = FlexKVStats(flexkv_config.num_log_interval_requests)
+
+        while not self.is_ready():
+            logger.info(f"Waiting for flexkv init...")
+            time.sleep(5)
+
+        logger.info(f"Finish init FlexKVSchedulerConnector")
+
+    def is_ready(
+        self,
+    ) -> bool:
+        "Check whether flexkv is ready."
+        return self.flexkv_manager.is_ready()
+
+    def shutdown(self) -> None:
+        self.flexkv_manager.shutdown()
+
+    @property
+    def dp_client_id(self) -> int:
+        return self.flexkv_manager.dp_client_id
+
+    ####################
+    #### Get Method ####
+    ####################
+
+    def get_num_new_matched_tokens(
+        self,
+        request: "Request",
+        num_computed_tokens: int,
+    ) -> tuple[int, bool]:
+        """
+        Args:
+            request: Request to get.
+            num_computed_tokens: Number of prefix tokens that have already been
+                computed and therefore do not need to be transferred from flexkv.
+
+        Returns:
+            tuple[int, bool]: A tuple containing the number of new matched tokens
+                              and whether it is necessary to get the new matched
+                              blocks from flexkv.
+        """
+        task_id, num_new_matched_tokens = self._get_match(request=request,
+                                                          num_computed_tokens=num_computed_tokens)
+        self.flexkv_stats.record_get(num_prompt_tokens=request.num_prompt_tokens,
+                                     num_gpu_matched_tokens=num_computed_tokens,
+                                     num_flexkv_matched_tokens=num_new_matched_tokens)
+
+        if not self._need_to_get(num_prompt_tokens=request.num_prompt_tokens,
+                                 num_computed_tokens=num_computed_tokens,
+                                 num_new_matched_tokens=num_new_matched_tokens):
+            return 0, False
+
+        return num_new_matched_tokens, True
+
+
+    def _get_match(
+        self,
+        request: "Request",
+        num_computed_tokens: int = 0,
+    ) -> tuple[int, int]:
+        """
+        Args:
+            request: Request to get.
+            num_computed_tokens: Number of prefix tokens that have already been
+                computed and therefore do not need to be transferred from flexkv.
+
+        Returns:
+            tuple[int, int]: A tuple containing two integer values representing
+                             the task_id and number of new matched tokens.
+        """
+        match_start_time = time.perf_counter()
+        num_tokens_to_get = (cdiv(request.num_prompt_tokens, self.block_size)-1)*self.block_size
+        token_ids = request.prompt_token_ids[:num_tokens_to_get]
+
+        assert num_computed_tokens <= num_tokens_to_get
+        assert num_computed_tokens % self.block_size == 0
+
+        if num_tokens_to_get == num_computed_tokens:
+            return -1, 0
+
+        np_token_ids = np.array(token_ids)
+        np_token_mask = np.ones_like(np_token_ids, dtype=bool)
+        np_token_mask[:num_computed_tokens] = False
+        task_id, matched_mask = self.flexkv_manager.get_match(token_ids=np_token_ids,
+                                                              token_mask=np_token_mask)
+        num_new_matched_tokens = matched_mask.sum().item()
+
+        # Auto-cancel if update_state_after_alloc() is not called
+        match_end_time = time.perf_counter()
+        logger.debug(f"Get match cost {(match_end_time-match_start_time)*1000:.2f} ms.")
+        if num_new_matched_tokens > 0:
+            self.req_id_to_task_dict[request.request_id] = task_id
+            self.tasks_to_cancel[task_id] = FlexKVGetTask(task_id=task_id,
+                                                          request=request,
+                                                          num_computed_tokens=num_computed_tokens,
+                                                          num_new_matched_tokens=num_new_matched_tokens,
+                                                          match_start_time=match_start_time,
+                                                          match_end_time=match_end_time)
+
+            logger.debug(f"FlexKV create get task: {self.tasks_to_cancel[task_id]}")
+
+        return task_id, num_new_matched_tokens
+
+    def _need_to_get(
+        self,
+        num_prompt_tokens: int,
+        num_computed_tokens: int,
+        num_new_matched_tokens: int,
+    ) -> bool:
+        """
+        Determine whether it is necessary to get the new matched blocks from flexkv.
+        """
+        return num_new_matched_tokens > 0
+
+    def update_state_after_alloc(
+        self,
+        request: "Request",
+        blocks: "KVCacheBlocks",
+        num_new_matched_tokens: int,
+    ) -> None:
+        """
+        Compute slot mapping and prepare to launch task.
+        Only call after get_num_new_matched_tokens().
+
+        Args:
+            request: Request to get.
+            blocks: All blocks of the request.
+            num_new_matched_tokens: Number of new matched tokens returned by
+                get_num_new_matched_tokens().
+
+        Returns:
+            None.
+        """
+        if num_new_matched_tokens == 0:
+            return
+        # prepare to launch task
+        task_id = self.req_id_to_task_dict[request.request_id]
+        task: FlexKVGetTask = self.tasks_to_cancel.pop(task_id)
+        self.tasks_to_launch[task_id] = task
+
+        # compute slot_mapping
+        num_computed_blocks = task.num_computed_tokens // self.block_size
+        num_blocks_to_get = num_new_matched_tokens // self.block_size
+        all_block_ids = blocks.get_block_ids()[0]
+        block_ids_to_get = all_block_ids[num_computed_blocks:num_computed_blocks+num_blocks_to_get]
+        task.slot_mapping = np.array(block_ids_to_get).repeat(self.block_size)*self.block_size
+
+    def wait_for_all_get_tasks(self) -> list[FlexKVResponse]:
+        """
+        Blocking wait for all get tasks.
+
+        Returns:
+            list[FlexKVResponse]: Responses of all get tasks.
+        """
+        return self._blocking_waiting_for_tasks(self.get_tasks)
+
+    ####################
+    #### Put Method ####
+    ####################
+
+    def request_finished(
+        self,
+        request: "Request",
+        block_ids: list[int],
+    ) -> bool:
+        """
+        Args:
+            request: Request to put.
+            block_ids: All block IDs of the request.
+
+        Returns:
+            bool: whether there is an unfinished task for this request.
+ """ + # Task not finished, can't free blocks + if request.request_id in self.req_id_to_task_dict: + return True + + # Abnormal finished, don't put + if not (request.is_finished() and request.get_finished_reason() < 2): + return False + + task_id, num_matched_tokens, num_unmatched_tokens = self._put_match(request=request) + + self.flexkv_stats.record_put(num_all_tokens=request.num_tokens, + num_unmatched_tokens=num_unmatched_tokens) + + if not self._need_to_put(num_all_tokens=request.num_tokens, + num_matched_tokens=num_matched_tokens, + num_unmatched_tokens=num_unmatched_tokens): + return False + + # prepare to launch task + task: FlexKVPutTask = self.tasks_to_cancel.pop(task_id) + self.tasks_to_launch[task_id] = task + + # compute slot mapping + # num_blocks_to_put = (num_matched_tokens+num_unmatched_tokens) // self.block_size + num_matched_blocks = num_matched_tokens // self.block_size + num_unmatched_tokens = num_unmatched_tokens // self.block_size + block_ids_to_put = block_ids[num_matched_blocks:num_matched_blocks+num_unmatched_tokens] + task.slot_mapping = np.array(block_ids_to_put).repeat(self.block_size)*self.block_size + + return True + + def _put_match( + self, + request: "Request" + ) -> tuple[int, int, int]: + """ + Args: + request: Request to put. + + Returns: + tuple[int, int, int]: A tuple containing three integer values representing + the task_id, number of matched tokens and number of unmatched tokens. + """ + match_start_time = time.perf_counter() + num_tokens_to_put = (cdiv(request.num_tokens, self.block_size)-1)*self.block_size + token_ids = request.all_token_ids[:num_tokens_to_put] + + if num_tokens_to_put == 0: + return -1, 0, 0 + + np_token_ids = np.array(token_ids) + task_id, unmatched_mask = self.flexkv_manager.put_match(token_ids=np_token_ids) + + num_unmatched_tokens = unmatched_mask.sum().item() + num_matched_tokens = num_tokens_to_put - num_unmatched_tokens + + # Auto cancel if not need to put. + match_end_time = time.perf_counter() + logger.debug(f"Put match cost {(match_end_time-match_start_time)*1000:.2f} ms.") + + if num_unmatched_tokens > 0: + self.req_id_to_task_dict[request.request_id] = task_id + self.tasks_to_cancel[task_id] = FlexKVPutTask(task_id=task_id, + request=request, + num_matched_tokens=num_matched_tokens, + num_unmatched_tokens=num_unmatched_tokens, + match_start_time=match_start_time, + match_end_time=match_end_time) + logger.debug(f"FlexKV create put task: {self.tasks_to_cancel[task_id]}") + + return task_id, num_matched_tokens, num_unmatched_tokens + + def _need_to_put( + self, + num_all_tokens: int, + num_matched_tokens: int, + num_unmatched_tokens: int, + ) -> bool: + """ + Determine whether it is necessary to put the unmatched blocks from flexkv. + """ + return num_unmatched_tokens > 0 + + def wait_for_all_put_tasks(self) -> list[FlexKVResponse]: + """ + Blocking wait for all put tasks. + + Returns: + list[FlexKVResponse]: Responses of all put tasks. + """ + return self._blocking_waiting_for_tasks(self.put_tasks) + + ####################### + #### Common Method #### + ####################### + + def cancel_tasks(self) -> None: + """ + Cancel tasks in self.cancel_tasks. + Call before launch_tasks() to delete req_id in self.req_id_to_task_dict + """ + # TODO: check if this method is inproc. 
+ if len(self.tasks_to_cancel) == 0: + return + for task in self.tasks_to_cancel.values(): + del self.req_id_to_task_dict[task.request.request_id] + logger.info(f"FlexKV Cancel task: {task}") + self.flexkv_manager.cancel(task_ids=list(self.tasks_to_cancel.keys())) + self.tasks_to_cancel.clear() + + def launch_tasks(self) -> None: + """ + Launch tasks in self.unlaunched_tasks + """ + if len(self.tasks_to_launch) == 0: + return + task_launch_time = time.perf_counter() + task_ids: list[int] = [] + slot_mappings: list[np.ndarray] = [] + + for task_id, task in self.tasks_to_launch.items(): + logger.info(f"FlexKV Launch task: {task}") + task.task_launch_time = task_launch_time + task_ids.append(task_id) + slot_mappings.append(task.slot_mapping) + if isinstance(task, FlexKVGetTask): + self.get_tasks[task_id] = task + else: + self.put_tasks[task_id] = task + self.flexkv_manager.launch(task_ids=task_ids, + slot_mappings=slot_mappings) + self.tasks_to_launch.clear() + + def query_finished_task(self) -> tuple[set[str], set[str]]: + """ + Get response of finished task. + + Returns: + list[FlexKVResponse]: Responses of finished tasks. + """ + if len(self.req_id_to_task_dict) == 0: + return set(), set() + logger.debug(f"unfinished task: {self.req_id_to_task_dict}") + task_ids = list(self.get_tasks.keys()) + list(self.put_tasks.keys()) + responses_from_manager = self.flexkv_manager.try_wait(task_ids) + task_finished_time = time.perf_counter() + # responses_to_return: list[FlexKVResponse] = [] + finished_sending = set() + finished_recving = set() + num_failed_tasks = 0 + for task_id, response in responses_from_manager.items(): + success = (response.status == KVResponseStatus.SUCCESS) + if task_id in self.get_tasks: + task = self.get_tasks.pop(task_id) + finished_recving.add(task.request.request_id) + else: + task = self.put_tasks.pop(task_id) + finished_sending.add(task.request.request_id) + del self.req_id_to_task_dict[task.request.request_id] + task.task_finished_time = task_finished_time + if success: + logger.info(f"{task} finished successfully.") + else: + logger.error(f"{task} failed, status: {response.status}.") + num_failed_tasks += 1 + # responses_to_return.append(FlexKVResponse(task_id=task_id, task_type=task.task_type, + # request=task.request, success=success)) + self.flexkv_stats.record_faild(num_failed_requests=num_failed_tasks) + return finished_sending, finished_recving + + def _blocking_waiting_for_tasks(self, task_dict: dict[int, FlexKVTask]) -> list[FlexKVResponse]: + """ + Blocking wait for tasks in task_dict. + + Returns: + list[FlexKVResponse]: Responses of all tasks in task_dict. 
+ """ + if len(task_dict) == 0: + return [] + + task_ids = list(task_dict.keys()) + response_from_manager = self.flexkv_manager.wait(task_ids=task_ids) + task_finished_time = time.perf_counter() + responses_to_return: list[FlexKVResponse] = [] + for task_id, response in response_from_manager.items(): + success = (response.status == KVResponseStatus.SUCCESS) + task = task_dict.pop(task_id) + task.task_finished_time = task_finished_time + if success: + logger.info(f"{task} finished successfully.") + else: + logger.error(f"{task} failed, status: {response.status}.") + responses_to_return.append(FlexKVResponse(task_id=task_id, task_type=task.task_type, + request=task.request, success=success)) + return responses_to_return + + +class FlexKVWorkerConnector: + def __init__( + self, + flexkv_config: FlexKVConfig, + ): + current_device_id = torch.cuda.current_device() + self.flexkv_config = flexkv_config + logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.server_recv_port}") + self.tp_client = KVTPClient(flexkv_config.server_recv_port, 0, current_device_id) + logger.info(f"Finish init FlexKVWorkerConnector") + + def register_to_server(self, kv_caches: dict[str, torch.Tensor]): + logger.info(f"Start register kv_caches") + gpu_blocks = list(kv_caches.values()) + num_layer = len(kv_caches) + if self.flexkv_config.use_mla: + assert gpu_blocks[0].ndim == 3, ( + f"expect kv cached tensor has 3 dim but get shape={gpu_blocks[0].shape}.") + num_blocks = gpu_blocks[0].shape[0] + block_size = gpu_blocks[0].shape[1] + num_kv_heads = 1 + head_size = gpu_blocks[0].shape[2] + else: + assert gpu_blocks[0].ndim == 5, ( + f"expect kv cached tensor has 5 dim but get shape={gpu_blocks[0].shape}.") + num_blocks = gpu_blocks[0].shape[1] + block_size = gpu_blocks[0].shape[2] + num_kv_heads = gpu_blocks[0].shape[3] + head_size = gpu_blocks[0].shape[4] + gpu_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERWISE, + num_layer=num_layer, + num_block=num_blocks, + tokens_per_block=block_size, + num_head=num_kv_heads, + head_size=head_size, + is_mla=self.flexkv_config.use_mla, + ) + self.tp_client.register_to_server(gpu_blocks, gpu_layout) + logger.info(f"Finish register kv_caches") + + +class FlexKVConnectorV1Impl: + def __init__(self, vllm_config: "VllmConfig", role: "KVConnectorRole"): + self.role = role + flexkv_config = FlexKVConfig.from_env() + flexkv_config.post_init_from_vllm_config(vllm_config) + + if role == KVConnectorRole.SCHEDULER: + self.connector = FlexKVSchedulerConnector(flexkv_config) + elif role == KVConnectorRole.WORKER: + self.connector = FlexKVWorkerConnector(flexkv_config) + else: + raise ValueError(f"Unrecognized KVConnectorRole: {role}.") + + def shutdown(self): + if self.role == KVConnectorRole.SCHEDULER: + self.connector.shutdown() + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + pass + + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. 
This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + pass + + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + pass + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous transfer + (requests that previously returned True from request_finished()), + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + return None, None + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the + KV Caches in the KVConnector (e.g. for NIXL). + + Args: kv_caches: + dictionary of layer names, kv cache + """ + self.connector.register_to_server(kv_caches) + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + return self.connector.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + """ + self.connector.update_state_after_alloc(request, blocks, num_external_tokens) + + def build_connector_meta( + self, scheduler_output: "SchedulerOutput") -> "KVConnectorMetadata": + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + self.connector.cancel_tasks() + self.connector.launch_tasks() + return KVConnectorMetadata() + + def update_connector_output(self, connector_output: "KVConnectorOutput"): + """ + Update KVConnector state from worker-side connectors output. 
+ + Args: + connector_output (KVConnectorOutput): the worker-side + connectors output. + """ + + finished_sending, finished_recving = self.connector.query_finished_task() + connector_output.finished_sending = finished_sending + connector_output.finished_recving = finished_recving + + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + return self.connector.request_finished(request, block_ids), None \ No newline at end of file diff --git a/flexkv/server/client.py b/flexkv/server/client.py index a6db160934..1155f99630 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -227,23 +227,17 @@ def __init__( server_recv_port: str, dp_client_id: int, device_id: int, - tp_rank: int, ): # Init inter-process communication context = zmq.Context(2) self.send_to_server = get_zmq_socket( context, zmq.SocketType.PUSH, server_recv_port, False ) - self.client_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" - self.recv_from_server = get_zmq_socket( - context, zmq.SocketType.PULL, self.client_recv_port, True - ) self.dp_client_id = dp_client_id self.device_id = device_id - self.tp_rank = tp_rank - flexkv_logger.info(f"KVTPClient {tp_rank} of KVDPClient {self.dp_client_id} Initialized!") + flexkv_logger.info(f"KVTPClient {device_id} of KVDPClient {self.dp_client_id} Initialized!") def register_to_server( self, @@ -263,23 +257,13 @@ def register_to_server( register_req = RegisterTPClientRequest( self.dp_client_id, - self.tp_rank, self.device_id, - self.client_recv_port, handles, kv_layout ) - self.send_to_server.send_pyobj(register_req) - # blocking - response: Response = self.recv_from_server.recv_pyobj() - if response.error_msg is None: - flexkv_logger.info(f"TP client of DP client {self.dp_client_id} registered successfully!") - else: - flexkv_logger.error( - f"TP client of DP client {self.dp_client_id} registeration fialed: {response.error_msg}" - ) - raise + self.send_to_server.send_pyobj(register_req, flags=zmq.NOBLOCK) + if __name__ == "__main__": num_layers = 32 diff --git a/flexkv/server/request.py b/flexkv/server/request.py index 6f048e5515..1b8dade803 100644 --- a/flexkv/server/request.py +++ b/flexkv/server/request.py @@ -18,9 +18,7 @@ class RegisterDPClientRequest: @dataclass class RegisterTPClientRequest: dp_client_id: int - tp_rank: int device_id: int - client_recv_port: str handles: List[TensorSharedHandle] gpu_layout: KVCacheLayout diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py index 1eff8ecfe6..4567014aaa 100644 --- a/flexkv/transfer_manager.py +++ b/flexkv/transfer_manager.py @@ -37,7 +37,6 @@ def __init__(self, self.context = zmq.Context(2) self.recv_from_client = get_zmq_socket( self.context, zmq.SocketType.PULL, gpu_register_port, True) - self.client_dict: Dict[int, zmq.Socket] = {} self.transfer_engine: Optional[TransferEngine] = None self.storage_engine = StorageEngine(self.model_config, self.cache_config) @@ -47,32 +46,18 @@ def _handle_gpu_blocks_registration(self, req: RegisterTPClientRequest) -> None: if device_id in self.all_gpu_blocks: flexkv_logger.error(f"GPU {device_id} has already registered.") - response = 
Response(req.dp_client_id, success=False, - error_msg=f"GPU {device_id} already registered") elif device_id >= self.model_config.tp_size * self.model_config.dp_size: flexkv_logger.error(f"GPU {device_id} is larger than TP size: " f"{self.model_config.tp_size * self.model_config.dp_size}.") - response = Response(req.dp_client_id, success=False, - error_msg=f"GPU {device_id} exceeds TP size " - f"{self.model_config.tp_size * self.model_config.dp_size}") else: try: - response = Response(req.dp_client_id) - send_to_client = get_zmq_socket( - self.context, zmq.SocketType.PUSH, req.client_recv_port, False) - send_to_client.send_pyobj(response) - self.client_dict[device_id] = send_to_client self.all_gpu_blocks[device_id] = req.handles self.all_gpu_layouts[device_id] = req.gpu_layout flexkv_logger.info(f"GPU {device_id} registered successfully") except Exception as e: flexkv_logger.error(f"Failed to register GPU {device_id}: {e}") - response = Response(req.dp_client_id, success=False, - error_msg=f"Failed to register GPU {device_id}: {e}") - if device_id in self.client_dict: - self.client_dict[device_id].send_pyobj(response) def _register_gpu_blocks_via_socket(self) -> None: try: From a77191c9c478bf22c6f5588d3421ac9cf316c0b3 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Thu, 4 Sep 2025 01:15:45 -0700 Subject: [PATCH 26/42] fix bug --- benchmarks/benchmark_single_batch.py | 2 +- csrc/tp_transfer_thread_group.cpp | 2 ++ examples/scheduler_server_example.py | 2 +- flexkv/transfer/worker.py | 1 - flexkv/transfer_manager.py | 1 - tests/test_kvmanager.py | 5 ++--- tests/test_utils.py | 2 +- 7 files changed, 7 insertions(+), 8 deletions(-) diff --git a/benchmarks/benchmark_single_batch.py b/benchmarks/benchmark_single_batch.py index ca9e59c60a..397030becd 100644 --- a/benchmarks/benchmark_single_batch.py +++ b/benchmarks/benchmark_single_batch.py @@ -28,7 +28,7 @@ class BenchmarkConfig: def run_tp_client(dp_client_id, tp_rank, server_recv_port, model_config, cache_config): """Run tp_client process""" device_id = tp_rank + dp_client_id * model_config.tp_size - tp_client = KVTPClient(server_recv_port, dp_client_id, device_id, tp_rank) + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) num_gpu_blocks = cache_config.num_gpu_blocks diff --git a/csrc/tp_transfer_thread_group.cpp b/csrc/tp_transfer_thread_group.cpp index d72f8915db..06cb45c4fe 100644 --- a/csrc/tp_transfer_thread_group.cpp +++ b/csrc/tp_transfer_thread_group.cpp @@ -92,6 +92,8 @@ TPTransferThreadGroup::~TPTransferThreadGroup() { stop_pool_ = true; for (auto& cv : cvs_) cv.notify_all(); for (auto& t : threads_) if (t.joinable()) t.join(); + + cudaFreeHost(gpu_blocks_); delete[] gpu_kv_strides_in_bytes_; delete[] gpu_block_strides_in_bytes_; diff --git a/examples/scheduler_server_example.py b/examples/scheduler_server_example.py index 1aae7ec298..29826afc9a 100644 --- a/examples/scheduler_server_example.py +++ b/examples/scheduler_server_example.py @@ -28,7 +28,7 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, mo # Clear cache torch.cuda.empty_cache() - tp_client = KVTPClient(server_recv_port, dp_client_id, device_id, tp_rank) + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) # Create GPU blocks for this TP client gpu_blocks = [] diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index de53b5acdf..2dbbc1be86 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -90,7 +90,6 @@ def __init__(self, self.transfer_conn = transfer_conn # receive 
end of pipe self.finished_ops_queue: MPQueue[int] = finished_ops_queue - flexkv_logger.info(f"[TransferWorkerBase] op buffer data ptr: {op_buffer_tensor.storage().data_ptr()}") self.op_buffer_tensor = op_buffer_tensor cudaHostRegister(self.op_buffer_tensor) diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py index 4567014aaa..fab65f345c 100644 --- a/flexkv/transfer_manager.py +++ b/flexkv/transfer_manager.py @@ -51,7 +51,6 @@ def _handle_gpu_blocks_registration(self, req: RegisterTPClientRequest) -> None: f"{self.model_config.tp_size * self.model_config.dp_size}.") else: try: - self.all_gpu_blocks[device_id] = req.handles self.all_gpu_layouts[device_id] = req.gpu_layout flexkv_logger.info(f"GPU {device_id} registered successfully") diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index 16ac51517e..f2c48a53c5 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -28,7 +28,7 @@ def run_tp_client(dp_client_id, tp_rank, server_recv_port, model_config, cache_c """Run tp_client process""" try: device_id = tp_rank + dp_client_id * model_config.tp_size - tp_client = KVTPClient(server_recv_port, dp_client_id, device_id, tp_rank) + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks) @@ -52,7 +52,6 @@ def run_tp_client(dp_client_id, tp_rank, server_recv_port, model_config, cache_c while True: time.sleep(1) except Exception as e: - print(f"[TP Client {tp_rank}] Error occurred: {e}") if child_conn is not None: child_conn.send(None) child_conn.close() @@ -69,7 +68,7 @@ def shutdown_tp_client(tp_client_processes): @pytest.mark.parametrize("model_config", [ {'tp_size': 1, 'dp_size': 1}, - {'tp_size': 2, 'dp_size': 2}, + {'tp_size': 2, 'dp_size': 2}, {'dtype': torch.float32}, {'use_mla': True}, {'tp_size': 4, 'dp_size': 1, 'use_mla': True}, diff --git a/tests/test_utils.py b/tests/test_utils.py index c31fb068eb..ba1392eabc 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -342,7 +342,7 @@ def _run_tp_client(dp_client_id, tp_rank, device_id, server_recv_port, num_layer from flexkv.server.client import KVTPClient from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType - tp_client = KVTPClient(server_recv_port, dp_client_id, device_id, tp_rank) + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) # Convert dtype string back to torch dtype if dtype_str == "torch.float16": dtype = torch.float16 From 1e50623a228535eadb650f28a637db37afc8f832 Mon Sep 17 00:00:00 2001 From: linhu-nv <141609318+linhu-nv@users.noreply.github.com> Date: Fri, 5 Sep 2025 10:31:30 +0800 Subject: [PATCH 27/42] server-client mode works now (#92) --- flexkv/kvmanager.py | 58 +++--- flexkv/kvtask.py | 1 + flexkv/server/client.py | 107 +++++------ flexkv/server/request.py | 10 +- flexkv/server/server.py | 374 +++++++++++++++++---------------------- tests/test_kvmanager.py | 2 +- 6 files changed, 243 insertions(+), 309 deletions(-) diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index 2154218cae..8e85ee9959 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -31,54 +31,52 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, gpu_register_port: Optional[str] = None, - server_recv_port: Optional[str] = None): + server_recv_port: Optional[str] = None, + dp_client_id: int = 0): flexkv_logger.info(f"{model_config = }") flexkv_logger.info(f"{cache_config = }") self.model_config = model_config self.cache_config = cache_config 
self.gpu_register_port = gpu_register_port self.server_recv_port = server_recv_port - self.server_client_mode = model_config.dp_size > 1 # True #just for test + self.server_client_mode = model_config.dp_size > 1 + self.dp_client_id = dp_client_id flexkv_logger.info(f"server_client_mode: {self.server_client_mode}") if self.server_client_mode: - # TODO: server should only be created once but kvmanager will init in every dp rank. - self.server_handle = KVServer.create_server(model_config, cache_config, gpu_register_port, server_recv_port) - self.dp_client = KVDPClient(self.server_recv_port, self.model_config) + # server should only be created once but kvmanager will init in every dp rank. + if dp_client_id == 0: + self.server_handle = KVServer.create_server(model_config, + cache_config, + gpu_register_port, + server_recv_port) + + else: + self.server_handle = None + self.dp_client = KVDPClient(self.server_recv_port, self.model_config, dp_client_id) else: self.server_handle = None self.kv_task_engine = KVTaskEngine(model_config, cache_config, gpu_register_port) - - #def _launch_server(self) -> None: - # self.server = KVServer(self.model_config, self.cache_config, self.server_recv_port) - # self.server.run() - # time.sleep(10) - # self.dp_client = DPClient(self.server_recv_port, self.model_config) @property - def dp_client_id(self) -> int: - if self.server_client_mode: - return self.dp_client.dp_client_id - else: - return 0 + def dpclient_id(self) -> int: + return self.dp_client_id def start(self) -> None: if not self.server_client_mode: self.kv_task_engine.start() - # for server client mode, we need to do nothing, because the start is actually called - # when the server is created + else: + # send the start request to the server + self.dp_client.start_server_and_register() def is_ready(self) -> bool: if self.server_client_mode: - return self.server_handle is not None and self.server_handle.ready_event.is_set() + return self.dp_client.is_ready() else: return self.kv_task_engine.is_ready() def shutdown(self) -> None: if self.server_client_mode: - if self.server_handle is not None: - self.server_handle.shutdown() - else: - flexkv_logger.error("Shutdown server failed, server is not created") + self.dp_client.shutdown() else: self.kv_task_engine.shutdown() @@ -97,10 +95,9 @@ def get_async(self, token_mask = token_mask.numpy() if self.server_client_mode: task_id = self.dp_client.get_async(token_ids, - slot_mapping, - token_mask, - layer_granularity, - dp_id) + slot_mapping, + token_mask, + layer_granularity) else: task_id, _ = self.kv_task_engine.get_async(token_ids, slot_mapping, @@ -122,8 +119,7 @@ def get_match(self, if self.server_client_mode: task_id, mask = self.dp_client.get_match(token_ids, token_mask, - layer_granularity, - dp_id) + layer_granularity) else: task_id, mask = self.kv_task_engine.get_match(token_ids, token_mask, @@ -144,7 +140,7 @@ def put_async(self, if isinstance(token_mask, torch.Tensor): token_mask = token_mask.numpy() if self.server_client_mode: - task_id = self.dp_client.put_async(token_ids, slot_mapping, token_mask, dp_id) + task_id = self.dp_client.put_async(token_ids, slot_mapping, token_mask) else: task_id, _ = self.kv_task_engine.put_async(token_ids, slot_mapping, token_mask, dp_id) return task_id @@ -159,7 +155,7 @@ def put_match(self, if isinstance(token_mask, torch.Tensor): token_mask = token_mask.numpy() if self.server_client_mode: - task_id, mask = self.dp_client.put_match(token_ids, token_mask, dp_id) + task_id, mask = self.dp_client.put_match(token_ids, 
token_mask) else: task_id, mask = self.kv_task_engine.put_match(token_ids, token_mask, dp_id) return task_id, mask diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index af01e00761..3d661a66bf 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -64,6 +64,7 @@ def is_completed(self) -> bool: TaskStatus.COMPLETED: KVResponseStatus.SUCCESS, TaskStatus.CANCELLED: KVResponseStatus.CANCELLED, TaskStatus.FAILED: KVResponseStatus.FAILED, + TaskStatus.RUNNING: KVResponseStatus.SUCCESS, # for early return: still running, but success } def convert_to_response_status(task_status: TaskStatus) -> KVResponseStatus: diff --git a/flexkv/server/client.py b/flexkv/server/client.py index 1155f99630..1643af98a3 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -28,6 +28,7 @@ WaitRequest, TryWaitRequest, CheckRunningRequest, + StartRequest, ShutdownRequest, Response ) @@ -37,18 +38,19 @@ def __init__( self, server_recv_port: str, model_config: ModelConfig, + dp_client_id: int, ): # Init inter-process communication context = zmq.Context(2) self.send_to_server = get_zmq_socket( context, zmq.SocketType.PUSH, server_recv_port, False ) - # is this ok when there are multiple dp clients? - client_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=True).name}" + self.client_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=True).name}" self.recv_from_server = get_zmq_socket( - context, zmq.SocketType.PULL, client_recv_port, True + context, zmq.SocketType.PULL, self.client_recv_port, True ) - self.dp_client_id = self.register_to_server(model_config, client_recv_port) + self.dp_client_id = dp_client_id + self.model_config = model_config self._task_id_range = (self.dp_client_id * 10000000, (self.dp_client_id + 1) * 10000000) self._task_id_counter = self._task_id_range[0] @@ -63,22 +65,20 @@ def _get_task_id(self) -> int: self._task_id_counter = self._task_id_range[0] return old_value + def start_server_and_register(self) -> None: + #start server and register + req = StartRequest(self.dp_client_id) + self.send_to_server.send_pyobj(req) + self.register_to_server(self.model_config, self.client_recv_port) + def register_to_server( self, model_config: ModelConfig, client_recv_port: str, - ) -> int: - register_req = RegisterDPClientRequest(model_config, client_recv_port) - + ) -> None: + register_req = RegisterDPClientRequest(self.dp_client_id, model_config, client_recv_port) self.send_to_server.send_pyobj(register_req) - # blocking - response: Response = self.recv_from_server.recv_pyobj() - if response.error_msg is None: - flexkv_logger.info(f"DP client registered successfully! 
DP client id: {response.dp_client_id}") - return response.dp_client_id - else: - flexkv_logger.error(f"DP client registeration fialed: {response.error_msg}") - raise + flexkv_logger.info(f"DP client {self.dp_client_id} registered to server request sent!") def is_ready( self, @@ -90,28 +90,26 @@ def is_ready( def put_async( self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor], + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray], ) -> int: req = PutRequest(self.dp_client_id, - token_ids.numpy(), - slot_mapping.numpy(), - token_mask.numpy() if token_mask is not None else None, + token_ids, + slot_mapping, + token_mask if token_mask is not None else None, self._get_task_id()) self.send_to_server.send_pyobj(req) return req.task_id def put_match( self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor], + token_ids: np.ndarray, + token_mask: Optional[np.ndarray], ) -> Optional[Tuple[int, np.ndarray]]: req = PutMatchRequest(self.dp_client_id, - token_ids.numpy(), - slot_mapping.numpy(), - token_mask.numpy() if token_mask is not None else None, + token_ids, + token_mask if token_mask is not None else None, self._get_task_id()) self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() @@ -123,29 +121,30 @@ def put_match( def get_async( self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor], + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray], + layer_granularity: int, ) -> int: req = GetRequest(self.dp_client_id, - token_ids.numpy(), - slot_mapping.numpy(), - token_mask.numpy() if token_mask is not None else None, - self._get_task_id()) - + token_ids, + slot_mapping, + token_mask if token_mask is not None else None, + self._get_task_id(), + layer_granularity) self.send_to_server.send_pyobj(req) return req.task_id def get_match( self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor], + token_ids: np.ndarray, + token_mask: Optional[np.ndarray], + layer_granularity: int, ) -> Optional[Tuple[int, np.ndarray]]: req = GetMatchRequest(self.dp_client_id, - token_ids.numpy(), - slot_mapping.numpy(), - token_mask.numpy() if token_mask is not None else None, + token_ids, + token_mask if token_mask is not None else None, + layer_granularity, self._get_task_id()) self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() @@ -155,11 +154,12 @@ def get_match( flexkv_logger.error(f"get_match failed, error_msg: {response.error_msg}") return None - def launch_task( + def launch_tasks( self, task_ids: List[int], + slot_mappings: List[np.ndarray], ) -> None: - req = LaunchTaskRequest(self.dp_client_id, task_ids) + req = LaunchTaskRequest(self.dp_client_id, task_ids, slot_mappings) self.send_to_server.send_pyobj(req) def cancel_task( @@ -176,7 +176,6 @@ def wait( completely: bool = False, ) -> Optional[Dict[int, KVResponse]]: req = WaitRequest(self.dp_client_id, None, wait_task_ids, wait_timeout, completely) - self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() if response.status is not None: @@ -196,30 +195,18 @@ def try_wait( self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() - if response.masks is not None: - for k, v in response.masks.items(): + if response.status is not None: + for k, v in response.status.items(): if 
v.status != KVResponseStatus.SUCCESS: flexkv_logger.error(f"try_wait task {k} failed: {v.status}") - return response.masks + return response.status else: flexkv_logger.error(f"try_wait tasks: {try_wait_task_ids} in DP {self.dp_client_id} failed.") return None - """ - def check_running(self) -> bool: - req = CheckRunningRequest(self.dp_client_id) - self.send_to_server.send_pyobj(req) - response: Response = self.recv_from_server.recv_pyobj() - return response.running - """ + def shutdown(self) -> None: req = ShutdownRequest(self.dp_client_id) self.send_to_server.send_pyobj(req) - response: Response = self.recv_from_server.recv_pyobj() - if response.success: - flexkv_logger.info(f"DP client {self.dp_client_id} shutdown successfully.") - else: - flexkv_logger.error(f"DP client {self.dp_client_id} shutdown failed.") - raise class KVTPClient: def __init__( diff --git a/flexkv/server/request.py b/flexkv/server/request.py index 1b8dade803..f6df970b17 100644 --- a/flexkv/server/request.py +++ b/flexkv/server/request.py @@ -11,6 +11,7 @@ @dataclass class RegisterDPClientRequest: + dp_client_id: int model_config: ModelConfig client_recv_port: str @@ -42,12 +43,12 @@ class GetRequest: slot_mapping: np.ndarray token_mask: Optional[np.ndarray] task_id: int = -1 + layer_granularity: int = -1 @dataclass class PutMatchRequest: dp_client_id: int token_ids: np.ndarray - slot_mapping: np.ndarray token_mask: Optional[np.ndarray] task_id: int = -1 @@ -55,14 +56,15 @@ class PutMatchRequest: class GetMatchRequest: dp_client_id: int token_ids: np.ndarray - slot_mapping: np.ndarray token_mask: Optional[np.ndarray] + layer_granularity: int task_id: int = -1 @dataclass class LaunchTaskRequest: dp_client_id: int task_ids: List[int] + slot_mappings: List[np.ndarray] @dataclass class CancelTaskRequest: @@ -99,12 +101,14 @@ def success(self) -> bool: return self.status is not None and \ all(self.status[task_id] == KVResponseStatus.SUCCESS for task_id in self.status.keys()) +@dataclass +class StartRequest: + dp_client_id: int @dataclass class ShutdownRequest: dp_client_id: int - @dataclass class CheckRunningRequest: dp_client_id: int \ No newline at end of file diff --git a/flexkv/server/server.py b/flexkv/server/server.py index ae3f5d0af0..daf25ff7ca 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -30,67 +30,12 @@ WaitRequest, TryWaitRequest, Response, + StartRequest, ShutdownRequest, CheckRunningRequest, ) import contextlib - -def _is_port_in_use(port_or_endpoint: str) -> bool: - """ - check if the port or IPC endpoint is in use by another process - - Args: - port_or_endpoint: port number or IPC endpoint string (e.g. 
"ipc:///tmp/xxx" or "5555") - - Returns: - bool: True if the port/endpoint is in use, False if it is free - """ - try: - if port_or_endpoint.startswith("ipc://"): - # IPC endpoint: check if the file exists - ipc_path = port_or_endpoint[6:] # remove "ipc://" prefix - return os.path.exists(ipc_path) - elif port_or_endpoint.startswith("tcp://"): - # TCP endpoint: parse host and port - tcp_part = port_or_endpoint[6:] # remove "tcp://" prefix - if ':' in tcp_part: - host, port_str = tcp_part.rsplit(':', 1) - port = int(port_str) - else: - host = "localhost" - port = int(tcp_part) - - # try to connect to the port to check if it is in use - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(1) - result = sock.connect_ex((host, port)) - sock.close() - return result == 0 - else: - # assume it is a pure port number - port = int(port_or_endpoint) - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(1) - result = sock.connect_ex(("localhost", port)) - sock.close() - return result == 0 - except (ValueError, OSError): - # if parsing fails or connection fails, assume the port is free - return False -""" -class TPClient: - def __init__( - self, - send_to_client: zmq.Socket, - tp_rank: int = 0, - device_id: int = 0, - ): - self.tp_rank = tp_rank - self.device_id = device_id - self.send_to_client = send_to_client -""" - class DPClient: def __init__( self, @@ -104,34 +49,6 @@ def __init__( self.send_to_client = send_to_client self.is_ready: bool = False -""" - def register_tp_client( - self, - context: zmq.Context, - client_recv_port: str, - tp_rank: int = 0, - device_id: int = 0, - ) -> None: - if tp_rank in self.tp_client_dict: - flexkv_logger.error(f"TP rank: {tp_rank} in DP client: {self.client_id} has already registered.") - raise - if tp_rank >= self.tp_size: - flexkv_logger.error(f"TP rank: {tp_rank} is larger than TP size of DP client: {self.client_id}.") - raise - - send_to_client = get_zmq_socket( - context, zmq.SocketType.PUSH, client_recv_port, False - ) - - self.tp_client_dict[tp_rank] = TPClient(send_to_client, tp_rank, device_id) - - flexkv_logger.info(f"TP rank: {tp_rank} in DP client: {self.client_id} registered successfully.") - - if len(self.tp_client_dict) == self.tp_size: - self.is_ready = True - flexkv_logger.info(f"All the TP clients in DP client: {self.client_id} has registered. " - f"DP client: {self.client_id} is ready!") -""" class ClientManager: def __init__( @@ -164,21 +81,7 @@ def register_dp_client( flexkv_logger.info(f"DP client {client_id} registered successfully") return client_id - """ - def register_tp_client( - self, - context: zmq.Context, - dp_client_id: int, - client_recv_port: str, - tp_rank: int, - device_id: int - ) -> None: - if dp_client_id not in self.client_dict: - flexkv_logger.error(f"DP client: {dp_client_id} has not registered.") - raise - self.client_dict[dp_client_id].register_tp_client( - context, client_recv_port, tp_rank, device_id) - """ + def delete_dp_client(self, client_id: int) -> None: if client_id not in self.client_dict: flexkv_logger.error(f"DP client: {client_id} dosen't exist. 
Delete failed.") @@ -200,9 +103,8 @@ def is_dp_client_ready(self, dp_client_id: int) -> bool: return False class KVServerHandle: - def __init__(self, process: mp.Process, ready_event: mp.Event): + def __init__(self, process: mp.Process): self.process = process - self.ready_event = ready_event def shutdown(self) -> None: self.process.join(timeout=5) @@ -231,26 +133,41 @@ def __init__( self.client_manager = ClientManager(max_num_dp_client=model_config.dp_size) self.kv_task_engine = KVTaskEngine(model_config, cache_config, gpu_register_port, False) - self.kv_task_engine.start() - self._is_ready = True self.req_counter = 0 - - flexkv_logger.info(f"Server Initialized! [Recv Port]: {server_recv_port}") + self._is_ready = False self._running = False + + # Request handler dispatch table + self.request_handlers = { + StartRequest: self._handle_start_request, + RegisterDPClientRequest: self._handle_register_dp_client_request, + IsReadyRequest: self._handle_is_ready_request, + GetRequest: self._handle_get_request, + PutRequest: self._handle_put_request, + GetMatchRequest: self._handle_get_match_request, + PutMatchRequest: self._handle_put_match_request, + WaitRequest: self._handle_wait_request, + LaunchTaskRequest: self._handle_launch_task_request, + CancelTaskRequest: self._handle_cancel_task_request, + TryWaitRequest: self._handle_try_wait_request, + ShutdownRequest: self._handle_shutdown_request, + } def is_ready(self) -> bool: return self._is_ready + def start_server(self) -> None: + self.kv_task_engine.start() + self._is_ready = True + @staticmethod def _server_process(model_config: ModelConfig, cache_config: CacheConfig, gpu_register_port: str, - server_recv_port: str, - ready_event: mp.Event) -> None: + server_recv_port: str) -> None: server = KVServer(model_config, cache_config, gpu_register_port, server_recv_port) - ready_event.set() server.run() @classmethod @@ -259,27 +176,36 @@ def create_server(cls, cache_config: CacheConfig, gpu_register_port: str, server_recv_port: Optional[str] = None) -> 'KVServerHandle': - if server_recv_port is None: - server_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" - #if _is_port_in_use(server_recv_port): - # flexkv_logger.info(f"port {server_recv_port} is in use, skip starting new kvserver") - # return None - #else: - # flexkv_logger.info(f"port {server_recv_port} is free, starting new kvserver") - mp.set_start_method("spawn") - ready_event = mp.Event() + #if server_recv_port is None: + # server_recv_port = f"ipc:///tmp/flexkv_srv_{uuid.uuid4().hex[:8]}" #TODO unify this + + # Set spawn method for CUDA compatibility + try: + mp.set_start_method("spawn") + except RuntimeError: + # If already set, just continue + pass process = mp.Process(target=cls._server_process, - args=(model_config, cache_config, gpu_register_port, server_recv_port, ready_event)) + args=(model_config, cache_config, gpu_register_port, server_recv_port)) process.start() flexkv_logger.info(f"KVServer process started, PID: {process.pid}") - return KVServerHandle(process, ready_event) + return KVServerHandle(process) def run(self) -> None: """Main server loop""" # TODO: handle error and return error response # TODO: support check finish + flexkv_logger.info("Servering waiting to be started") + req = self.recv_from_client.recv_pyobj() + if isinstance(req, StartRequest): + flexkv_logger.info(f"Received start request from DP client {req.dp_client_id}, " + f"Starting server...") + self.start_server() + else: + raise TypeError(f"Received RequestType: {type(req)} from DP 
client " + f"{req.dp_client_id} before the start request") self._running = True while self._running: try: @@ -287,101 +213,19 @@ def run(self) -> None: req = self.recv_from_client.recv_pyobj() flexkv_logger.info(f"recv req: {type(req)}") - # register dp client - if isinstance(req, RegisterDPClientRequest): - self._verify_model_config(req.model_config) - client_id = self.client_manager.register_dp_client( - self.context, - req.client_recv_port, - req.model_config.tp_size - ) - response = Response(client_id) - result_zmq = self.client_manager.get_zmq(client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, IsReadyRequest): - is_ready = self.kv_task_engine.is_ready() - response = Response(req.dp_client_id, is_ready=is_ready) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, GetRequest): - #assert self.client_manager.is_dp_client_ready(req.dp_client_id) - req_id = self.kv_task_engine.get_async( - token_ids=torch.from_numpy(req.token_ids), - slot_mapping=torch.from_numpy(req.slot_mapping), - token_mask=torch.from_numpy(req.token_mask) if req.token_mask is not None else None, - layer_granularity=-1, - dp_id=req.dp_client_id, - task_id=req.task_id, - ) - - elif isinstance(req, PutRequest): - #assert self.client_manager.is_dp_client_ready(req.dp_client_id) - req_id = self.kv_task_engine.put_async( - token_ids=torch.from_numpy(req.token_ids), - slot_mapping=torch.from_numpy(req.slot_mapping), - token_mask=torch.from_numpy(req.token_mask) if req.token_mask is not None else None, - dp_id=req.dp_client_id, - task_id=req.task_id, - ) - - elif isinstance(req, GetMatchRequest): - #assert self.client_manager.is_dp_client_ready(req.dp_client_id) - req_id, mask = self.kv_task_engine.get_match( - token_ids=torch.from_numpy(req.token_ids), - slot_mapping=torch.from_numpy(req.slot_mapping), - token_mask=torch.from_numpy(req.token_mask) if req.token_mask is not None else None, - ) - response = Response(req.dp_client_id, task_id=req_id, mask=mask) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, PutMatchRequest): - #assert self.client_manager.is_dp_client_ready(req.dp_client_id) - req_id, mask = self.kv_task_engine.put_match( - token_ids=torch.from_numpy(req.token_ids), - slot_mapping=torch.from_numpy(req.slot_mapping), - token_mask=torch.from_numpy(req.token_mask) if req.token_mask is not None else None, - ) - response = Response(req.dp_client_id, task_id=req_id, mask=mask) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, WaitRequest): - kv_responses = self.kv_task_engine.wait( - req.wait_task_ids, - timeout=req.wait_timeout, - ) - response = Response(req.dp_client_id, status=kv_responses) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, TryWaitRequest): - kv_responses = self.kv_task_engine.try_wait( - req.try_wait_task_ids, - ) - response = Response(req.dp_client_id, status=kv_responses) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, ShutdownRequest): - flexkv_logger.info(f"Received shutdown request from DP client {req.dp_client_id}") - # Gracefully shutdown the server - self._running = False - # Send response back to client - response = Response(req.dp_client_id, success=True) - result_zmq = 
self.client_manager.get_zmq(req.dp_client_id) - result_zmq.send_pyobj(response) + # Use dispatch table for request handling + req_type = type(req) + handler = self.request_handlers.get(req_type) + + if handler is None: + raise TypeError(f"Unrecognized RequestType: {req_type}") + + # Call the corresponding handler method + handler(req) + + # If the request is a shutdown request, exit the loop + if req_type == ShutdownRequest: break - - else: - raise TypeError(f"Unregonized RequestType: {type(req)}") except zmq.ZMQError as e: flexkv_logger.error(f"ZMQ Error: {e}", exc_info=True) @@ -401,8 +245,110 @@ def _verify_model_config( # TODO return True + # Request Handler Methods + + def _handle_start_request(self, req: StartRequest) -> None: + """Handle start request""" + flexkv_logger.info(f"Received start request from DP client {req.dp_client_id}") + + def _handle_register_dp_client_request(self, req: RegisterDPClientRequest) -> None: + """Handle DP client registration request""" + self._verify_model_config(req.model_config) + client_id = self.client_manager.register_dp_client( + self.context, + req.client_recv_port, + req.model_config.tp_size + ) + flexkv_logger.info(f"DP client {client_id} registered successfully") + + def _handle_is_ready_request(self, req: IsReadyRequest) -> None: + """Handle ready state check request""" + is_ready = self.kv_task_engine.is_ready() + response = Response(req.dp_client_id, is_ready=is_ready) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) + + def _handle_get_request(self, req: GetRequest) -> None: + """Handle Get request""" + req_id = self.kv_task_engine.get_async( + task_id=req.task_id, + token_ids=req.token_ids, + slot_mapping=req.slot_mapping, + token_mask=req.token_mask, + layer_granularity=req.layer_granularity, + dp_id=req.dp_client_id, + ) + + def _handle_put_request(self, req: PutRequest) -> None: + """Handle Put request""" + req_id = self.kv_task_engine.put_async( + token_ids=req.token_ids, + slot_mapping=req.slot_mapping, + token_mask=req.token_mask, + dp_id=req.dp_client_id, + task_id=req.task_id, + ) + + def _handle_get_match_request(self, req: GetMatchRequest) -> None: + """Handle GetMatch request""" + req_id, mask = self.kv_task_engine.get_match( + token_ids=req.token_ids, + token_mask=req.token_mask, + layer_granularity=req.layer_granularity, + dp_id=req.dp_client_id, + task_id=req.task_id, + ) + response = Response(req.dp_client_id, task_id=req_id, mask=mask) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) + + def _handle_put_match_request(self, req: PutMatchRequest) -> None: + """Handle PutMatch request""" + req_id, mask = self.kv_task_engine.put_match( + token_ids=req.token_ids, + token_mask=req.token_mask, + dp_id=req.dp_client_id, + task_id=req.task_id, + ) + response = Response(req.dp_client_id, task_id=req_id, mask=mask) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) + + def _handle_launch_task_request(self, req: LaunchTaskRequest) -> None: + """Handle LaunchTask request""" + self.kv_task_engine.launch_tasks(req.task_ids, req.slot_mappings) + + def _handle_cancel_task_request(self, req: CancelTaskRequest) -> None: + """Handle CancelTask request""" + self.kv_task_engine.cancel_tasks(req.task_ids) + + def _handle_wait_request(self, req: WaitRequest) -> None: + """Handle Wait request""" + kv_responses = self.kv_task_engine.wait( + req.wait_task_ids, + timeout=req.wait_timeout, + 
completely=req.completely, + ) + response = Response(req.dp_client_id, status=kv_responses) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) + + def _handle_try_wait_request(self, req: TryWaitRequest) -> None: + """Handle TryWait request""" + kv_responses = self.kv_task_engine.try_wait( + req.try_wait_task_ids, + ) + response = Response(req.dp_client_id, status=kv_responses) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) + + def _handle_shutdown_request(self, req: ShutdownRequest) -> None: + """Handle shutdown request""" + flexkv_logger.info(f"Received shutdown request from DP client {req.dp_client_id}") + self._running = False + def __del__(self) -> None: - self.kvmanager.shutdown() + self.kv_task_engine.shutdown() if __name__ == "__main__": import torch diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index f2c48a53c5..f8e74d45e2 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -79,7 +79,7 @@ def shutdown_tp_client(tp_client_processes): {'enable_cpu': True, 'enable_ssd': True, 'enable_remote': False, 'ssd_cache_iouring_entries': 512}, {'enable_cpu': True, 'enable_ssd': True, 'enable_remote': True, 'num_ssd_blocks': 256, 'num_remote_blocks': 512}, {'enable_cpu': True, 'enable_ssd': True, 'enable_remote': True, - 'num_ssd_blocks': 256, 'num_remote_blocks': 512, 'ssd_cache_iouring_entries': 512}, + 'num_ssd_blocks': 256, 'num_remote_blocks': 512, 'ssd_cache_iouring_entries': 512}, ], indirect=True) @pytest.mark.parametrize("test_config", [ {'num_gpu_blocks': 512, 'requests_per_block': 16, 'initial_write_ratio': 0.4}, From 38c83cef2be2ecc4e495296fa10c475cc8fbbfd2 Mon Sep 17 00:00:00 2001 From: leolingli Date: Fri, 5 Sep 2025 15:13:35 +0800 Subject: [PATCH 28/42] [docs] change vllm adapter README --- README.md | 18 +--- README_zh.md | 16 +--- docs/vllm_adapter/README_en.md | 85 +++++++++++++++++++ docs/vllm_adapter/README_zh.md | 84 ++++++++++++++++++ .../vllm_0_10_1_1-flexkv-connector.patch | 0 .../flexkv_vllm_0_10_0.patch | 0 .../flexkv_vllm_0_8_4.patch | 0 flexkv/integration/vllm/README.md | 44 ---------- 8 files changed, 172 insertions(+), 75 deletions(-) create mode 100644 docs/vllm_adapter/README_en.md create mode 100644 docs/vllm_adapter/README_zh.md rename flexkv/integration/vllm/0001-add-flexkv-connector.patch => examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch (100%) rename examples/{vllm_adaption => vllm_adaption_legacy}/flexkv_vllm_0_10_0.patch (100%) rename examples/{vllm_adaption => vllm_adaption_legacy}/flexkv_vllm_0_8_4.patch (100%) delete mode 100644 flexkv/integration/vllm/README.md diff --git a/README.md b/README.md index 56875811e3..98159bafb9 100644 --- a/README.md +++ b/README.md @@ -14,23 +14,9 @@ FlexKV is released under the **Apache-2.0 License**. See the [LICENSE](LICENSE) ./build.sh ``` -### Use FlexKV with vLLM (v0.8.4) +### Use FlexKV with vLLM -Apply the patch `examples/vllm_adaption/flexkv_vllm_0_8_4.patch` to vLLM 0.8.4, then start FlexKV, vLLM, and the benchmark script: - -```bash -# Start FlexKV as server -bash benchmarks/flexkv_benchmark/run_flexkv_server.sh - -# Start vLLM as client -bash benchmarks/flexkv_benchmark/serving_vllm.sh - -# Start benchmark -bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh -``` -Apply the patch `examples/vllm_adaption/flexkv_vllm_0_10_0.patch` to vLLM 0.10.0, and use the same testing method as above. - -> **Note**: The current script is only compatible with the `main` branch. 
Support for the latest features in the `dev` branch is under development. +See [docs/vllm_adapter/README_en.md](docs/vllm_adapter/README_en.md) ## Design Architecture diff --git a/README_zh.md b/README_zh.md index 8223a5d9c0..f95d5c5e50 100644 --- a/README_zh.md +++ b/README_zh.md @@ -16,21 +16,7 @@ FlexKV 采用 **Apache-2.0 开源协议**,详细信息请参见 [LICENSE](LICE ### 以 vLLM 为例使用 FlexKV -在 vLLM 0.8.4 版本中应用patch `examples/vllm_adaption/flexkv_vllm_0_8_4.patch`,分别启动 FlexKV、vLLM 和测试脚本: - -```bash -# 启动 FlexKV 作为服务端 -bash benchmarks/flexkv_benchmark/run_flexkv_server.sh - -# 启动 vLLM 作为客户端 -bash benchmarks/flexkv_benchmark/serving_vllm.sh - -# 启动性能测试 -bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh -``` -在 vLLM 0.10.0 版本中应用patch `examples/vllm_adaption/flexkv_vllm_0_10_0.patch`,测试方法同上。 - -> **注意**:当前脚本仅适配 `main` 分支。`dev` 分支的最新特性支持脚本正在开发中。 +见[docs/vllm_adapter/README_zh.md](docs/vllm_adapter/README_zh.md) ## 设计框架 diff --git a/docs/vllm_adapter/README_en.md b/docs/vllm_adapter/README_en.md new file mode 100644 index 0000000000..acc2f36de2 --- /dev/null +++ b/docs/vllm_adapter/README_en.md @@ -0,0 +1,85 @@ +# Using FlexKV in vLLM + +## Current Version vs. Legacy Version +In commit [`0290841dce65ae9b036a23d733cf94e47e814934`](https://github.com/taco-project/FlexKV/commit/0290841dce65ae9b036a23d733cf94e47e814934), we introduced a major update: +**FlexKV has transitioned from a client-server architecture to a library function that inference acceleration engines (such as vLLM) can directly invoke**, reducing inter-process communication overhead. + +This change involves significant API adjustments. Therefore, please note: + +- **Version >= `0.0.2`**: Use the **current version API**; the vLLM patch is located in `examples/vllm_adaption/`. +- **Version == `0.0.1`**: Supports the **legacy version API**; the vLLM patch is located in `examples/vllm_adaption_legacy/`. + +--- + +## Current Version (>= 0.0.2) + +### Supported Versions +- FlexKV >= `0.0.2` +- vLLM versions >= `0.8.5` can generally follow this version for adaptation + +### Example +We provide an adaptation example based on **vLLM 0.10.1.1**: + +1. apply patch +```bash +# FLEXKV_DIR/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +git apply examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +``` + +2. offline test +```bash +# VLLM_DIR/examples/offline_inference/prefix_caching_flexkv.py +python examples/offline_inference/prefix_caching_flexkv.py +``` + +3. 
online serving
+```bash
+# generate config
+cat <<EOF > ./flexkv_config.json
+{
+    "server_recv_port": "ipc:///tmp/flexkv_test",
+    "cache_config": {
+        "enable_cpu": true,
+        "num_cpu_blocks": 10240,
+        "use_pinned_memory": true
+    },
+    "num_log_interval_requests": 200
+}
+EOF
+export FLEXKV_CONFIG_PATH="./flexkv_config.json"
+
+VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \
+    --tensor-parallel-size 8 \
+    --trust-remote-code \
+    --port 30001 \
+    --max-num-seqs 128 \
+    --max-num-batched-tokens 8192 \
+    --max_model_len 8192 \
+    --max-seq-len-to-capture 8192 \
+    --gpu-memory-utilization 0.8 \
+    --enable-chunked-prefill \
+    --enable-prefix-caching \
+    --kv-transfer-config \
+    '{"kv_connector":"FlexKVConnectorV1","kv_role":"kv_both"}'
+
+```
+
+## Legacy Version (<= 0.0.1) – Not Recommended for Current Use
+
+### Supported Versions
+- FlexKV <= `0.0.1`
+
+### Example
+Apply the patch `examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch` to vLLM 0.8.4, then start FlexKV, vLLM, and the benchmark script:
+
+```bash
+# Start FlexKV as server
+bash benchmarks/flexkv_benchmark/run_flexkv_server.sh
+
+# Start vLLM as client
+bash benchmarks/flexkv_benchmark/serving_vllm.sh
+
+# Start benchmark
+bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh
+```
+Apply the patch `examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch` to vLLM 0.10.0, and use the same testing method as above.
diff --git a/docs/vllm_adapter/README_zh.md b/docs/vllm_adapter/README_zh.md
new file mode 100644
index 0000000000..f13815db1e
--- /dev/null
+++ b/docs/vllm_adapter/README_zh.md
@@ -0,0 +1,84 @@
+# 在 vLLM 中使用 FlexKV
+
+## 当前版本与 Legacy 版本说明
+在 commit [`0290841dce65ae9b036a23d733cf94e47e814934`](https://github.com/taco-project/FlexKV/commit/0290841dce65ae9b036a23d733cf94e47e814934),我们更新了一个重要功能:
+ **FlexKV 从 client-server 模式,变为推理加速引擎(如 vLLM)可直接调用的库函数**,以减少进程间消息传递的开销。
+这一变更引发了较大的 API 调整。因此,请注意:
+
+- **版本 >= `0.0.2`**:应使用 **当前版本 API**,vLLM patch位于 `examples/vllm_adaption/`。
+- **版本 == `0.0.1`**:仅支持 **Legacy 版本 API**, vLLM patch位于`examples/vllm_adaption_legacy/`。
+
+---
+
+## 当前版本(>= 0.0.2)
+
+### 适用版本
+- FlexKV >= `0.0.2`
+- vLLM 原则上>= `0.8.5`版本均可参考示例代码进行修改
+
+### 示例
+我们提供了基于 **vLLM 0.10.1.1** 的适配示例:
+
+1. apply patch
+```bash
+# FLEXKV_DIR/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch
+git apply examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch
+```
+
+2. offline test
+```bash
+# VLLM_DIR/examples/offline_inference/prefix_caching_flexkv.py
+python examples/offline_inference/prefix_caching_flexkv.py
+```
+
+3. 
online serving
+```bash
+# generate config
+cat <<EOF > ./flexkv_config.json
+{
+    "server_recv_port": "ipc:///tmp/flexkv_test",
+    "cache_config": {
+        "enable_cpu": true,
+        "num_cpu_blocks": 10240,
+        "use_pinned_memory": true
+    },
+    "num_log_interval_requests": 200
+}
+EOF
+export FLEXKV_CONFIG_PATH="./flexkv_config.json"
+
+VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \
+    --tensor-parallel-size 8 \
+    --trust-remote-code \
+    --port 30001 \
+    --max-num-seqs 128 \
+    --max-num-batched-tokens 8192 \
+    --max_model_len 8192 \
+    --max-seq-len-to-capture 8192 \
+    --gpu-memory-utilization 0.8 \
+    --enable-chunked-prefill \
+    --enable-prefix-caching \
+    --kv-transfer-config \
+    '{"kv_connector":"FlexKVConnectorV1","kv_role":"kv_both"}'
+
+```
+
+## Legacy版本(<= 0.0.1),目前的版本尽量不要使用
+
+### 适用版本
+- FlexKV <= `0.0.1`
+
+### 示例
+在 vLLM 0.8.4 版本中应用patch `examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch`,分别启动 FlexKV、vLLM 和测试脚本:
+
+```bash
+# 启动 FlexKV 作为服务端
+bash benchmarks/flexkv_benchmark/run_flexkv_server.sh
+
+# 启动 vLLM 作为客户端
+bash benchmarks/flexkv_benchmark/serving_vllm.sh
+
+# 启动性能测试
+bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh
+```
+在 vLLM 0.10.0 版本中应用patch `examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch`,测试方法同上。
\ No newline at end of file
diff --git a/flexkv/integration/vllm/0001-add-flexkv-connector.patch b/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch
similarity index 100%
rename from flexkv/integration/vllm/0001-add-flexkv-connector.patch
rename to examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch
diff --git a/examples/vllm_adaption/flexkv_vllm_0_10_0.patch b/examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch
similarity index 100%
rename from examples/vllm_adaption/flexkv_vllm_0_10_0.patch
rename to examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch
diff --git a/examples/vllm_adaption/flexkv_vllm_0_8_4.patch b/examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch
similarity index 100%
rename from examples/vllm_adaption/flexkv_vllm_0_8_4.patch
rename to examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch
diff --git a/flexkv/integration/vllm/README.md b/flexkv/integration/vllm/README.md
deleted file mode 100644
index 136f8b8682..0000000000
--- a/flexkv/integration/vllm/README.md
+++ /dev/null
@@ -1,44 +0,0 @@
-Use flexkv on vllm v0.10.1.1
-
-1. apply patch
-```bash
-cd vllm
-git apply 0001-add-flexkv-connector.patch
-```
-
-2. offline test
-```bash
-python examples/offline_inference/prefix_caching_flexkv.py
-```
-
-3. 
online serving
-```bash
-# generate config
-cat <<EOF > ./flexkv_config.json
-{
-    "server_recv_port": "ipc:///tmp/flexkv_test",
-    "cache_config": {
-        "enable_cpu": true,
-        "num_cpu_blocks": 10240,
-        "use_pinned_memory": true
-    },
-    "num_log_interval_requests": 200
-}
-EOF
-export FLEXKV_CONFIG_PATH="./flexkv_config.json"
-
-VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \
-    --tensor-parallel-size 8 \
-    --trust-remote-code \
-    --port 30001 \
-    --max-num-seqs 128 \
-    --max-num-batched-tokens 8192 \
-    --max_model_len 8192 \
-    --max-seq-len-to-capture 8192 \
-    --gpu-memory-utilization 0.8 \
-    --enable-chunked-prefill \
-    --enable-prefix-caching \
-    --kv-transfer-config \
-    '{"kv_connector":"FlexKVConnectorV1","kv_role":"kv_both"}'
-
-```
\ No newline at end of file
From f17d6a8b7228e77d6175fe85e510fe887e5977c1 Mon Sep 17 00:00:00 2001
From: leolingli
Date: Fri, 5 Sep 2025 15:45:51 +0800
Subject: [PATCH 29/42] [docs] add stable branch introduce

---
 README.md    | 1 +
 README_zh.md | 1 +
 2 files changed, 2 insertions(+)

diff --git a/README.md b/README.md
index 98159bafb9..ed78dbca43 100644
--- a/README.md
+++ b/README.md
@@ -74,6 +74,7 @@ FlexKV performs:
 - The main branch is the stable branch, which maintains already tested commits. Please pull from main branch if you need stable code.
 - The dev branch is the development branch, which contains newer features. Please branch from and merge into dev if you need new features or are developing new functionality.
 - The bugfix branch is for bug fixes, maintaining urgent bugs that need immediate resolution or documentation that requires prompt updates. If you need to fix a bug or update documentation urgently, please branch from and merge into the bugfix branch.
+- The stable branch refers to the previous main branch state, intended only for rollback or extremely conservative use cases (e.g., production deployment). Its use is discouraged.
 
 ## Roadmap
diff --git a/README_zh.md b/README_zh.md
index f95d5c5e50..0618a83220 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -74,6 +74,7 @@ FlexKV 在处理 *get* 请求时:
 - main 为稳定分支,维护已经测试过的commit。需要稳定的代码请从此分支拉取。
 - dev 为开发分支,维护较新特性。需要新特性和开发新特性请从此分支拉取和合入。
 - bugfix 为bug分支,维护需要立即解决的bug或需要立即更新的文档。需要解决bug和立即更新的文档请从此分支拉取和合入。
+- stable 为上一个版本的main分支位置,仅用于回滚以及极其保守的情况使用(如产品化)。不鼓励使用此版本。
 
 ## Roadmap
From b80fe94ff4c5dc8ce367a7dbf7797a5be3569abb Mon Sep 17 00:00:00 2001
From: leolingli
Date: Fri, 5 Sep 2025 16:02:55 +0800
Subject: [PATCH 30/42] [doc] add CONTRIBUTING.md

---
 CONTRIBUTING.md | 13 +++++++++++++
 1 file changed, 13 insertions(+)
 create mode 100644 CONTRIBUTING.md

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000000..301a6fe36a
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,13 @@
+# Contributing to FlexKV
+
+Thank you for your interest in contributing to FlexKV!
+
+## PR Title and Classification
+Use a prefixed PR title to indicate the type of changes. Please use one of the following:
+
+- `[bugfix]` for bugfixes
+- `[feature]` for new features
+- `[test]` for test cases
+- `[ci/build]` for build or continuous integration improvements
+- `[doc]` for documentation fixes
+- `[misc]` for PRs that do not fit the above categories. Please use this sparingly. 
\ No newline at end of file From ca070af523382557a1e20a5be27bd24df69c5e80 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Sun, 7 Sep 2025 20:44:42 -0700 Subject: [PATCH 31/42] [bugfix] fix incorrect num_tokens_to_get/put --- flexkv/integration/vllm/vllm_v1_adapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flexkv/integration/vllm/vllm_v1_adapter.py b/flexkv/integration/vllm/vllm_v1_adapter.py index 951259474b..fbc82b3de0 100644 --- a/flexkv/integration/vllm/vllm_v1_adapter.py +++ b/flexkv/integration/vllm/vllm_v1_adapter.py @@ -222,7 +222,7 @@ def _get_match( the task_id and number of new matched tokens. """ match_start_time = time.perf_counter() - num_tokens_to_get = (cdiv(request.num_prompt_tokens, self.block_size)-1)*self.block_size + num_tokens_to_get = (cdiv(request.num_prompt_tokens+1, self.block_size)-1)*self.block_size token_ids = request.prompt_token_ids[:num_tokens_to_get] assert num_computed_tokens <= num_tokens_to_get @@ -368,7 +368,7 @@ def _put_match( the task_id, number of matched tokens and number of unmatched tokens. """ match_start_time = time.perf_counter() - num_tokens_to_put = (cdiv(request.num_tokens, self.block_size)-1)*self.block_size + num_tokens_to_put = (cdiv(request.num_tokens+1, self.block_size)-1)*self.block_size token_ids = request.all_token_ids[:num_tokens_to_put] if num_tokens_to_put == 0: From d79e4c15d3ce1612ed6756cf87e81c28e36dff00 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Mon, 8 Sep 2025 00:32:37 -0700 Subject: [PATCH 32/42] [bugfix] fix incorrect num_tokens_to_get && format code --- flexkv/integration/vllm/vllm_v1_adapter.py | 174 ++++++++++----------- 1 file changed, 87 insertions(+), 87 deletions(-) diff --git a/flexkv/integration/vllm/vllm_v1_adapter.py b/flexkv/integration/vllm/vllm_v1_adapter.py index 951259474b..c0a046de99 100644 --- a/flexkv/integration/vllm/vllm_v1_adapter.py +++ b/flexkv/integration/vllm/vllm_v1_adapter.py @@ -46,29 +46,29 @@ class FlexKVResponse: class FlexKVTask(ABC): task_id: int = 0 request: "Request" = 0 - + # slot mapping slot_mapping: Optional[np.ndarray] = None - + # timer match_start_time: float = 0 match_end_time: float = 0 task_launch_time: float = 0 task_finished_time: float = 0 - + @property def match_cost(self) -> float: return (self.match_end_time - self.match_start_time) - + @property def task_execute_cost(self) -> float: return (self.task_finished_time - self.task_launch_time) - + @property @abstractmethod def task_type(self) -> str: ... 
- + def __str__(self): return (f"FlexKVTask(task_id={self.task_id}, " f"request={self.request.request_id}, " @@ -80,11 +80,11 @@ def __str__(self): class FlexKVGetTask(FlexKVTask): num_computed_tokens: int num_new_matched_tokens: int - + @property def task_type(self) -> str: return "get" - + def __str__(self): return (f"FlexKVGetTask(task_id={self.task_id}, " f"request={self.request.request_id}, " @@ -93,16 +93,16 @@ def __str__(self): f"match_cost {self.match_cost*1000:.2f} ms, " f"task execute cost {self.task_execute_cost*1000:.2f} ms)") - + @dataclass(kw_only=True) class FlexKVPutTask(FlexKVTask): num_matched_tokens: int num_unmatched_tokens: int - + @property def task_type(self) -> str: return "put" - + def __str__(self): return (f"FlexKVPutTask(task_id={self.task_id}, " f"request={self.request.request_id}, " @@ -110,7 +110,7 @@ def __str__(self): f"num_unmatched_tokens={self.num_unmatched_tokens}, " f"match_cost {self.match_cost*1000:.2f} ms, " f"task execute cost {self.task_execute_cost*1000:.2f} ms)") - + class FlexKVSchedulerConnector: def __init__( @@ -136,12 +136,12 @@ def __init__( tokens_per_block=flexkv_config.block_size, **flexkv_config.cache_config, ) - self.flexkv_manager = KVManager(model_config=self.model_config, + self.flexkv_manager = KVManager(model_config=self.model_config, cache_config=self.cache_config, gpu_register_port=flexkv_config.server_recv_port) self.flexkv_manager.start() # self.dp_client = KVDPClient(self.server_recv_port, self.model_config) - + # request_id -> task_id self.req_id_to_task_dict: dict[str, int] = {} # launched but unfinished tasks @@ -150,32 +150,32 @@ def __init__( # unlaunched tasks self.tasks_to_launch: dict[int, FlexKVTask] = {} self.tasks_to_cancel: dict[int, FlexKVTask] = {} - + self.flexkv_stats = FlexKVStats(flexkv_config.num_log_interval_requests) while not self.is_ready(): - logger.info(f"Waiting for flexkv init...") + logger.info("Waiting for flexkv init...") time.sleep(5) - logger.info(f"Finish init FlexKVSchedulerConnector") - + logger.info("Finish init FlexKVSchedulerConnector") + def is_ready( self, ) -> bool: " Ask flexkv is ready " return self.flexkv_manager.is_ready() - + def shutdown(self) -> None: self.flexkv_manager.shutdown() - + @property def dp_client_id(self) -> int: return self.flexkv_manager.dp_client_id - + #################### #### Get Method #### - #################### - + #################### + def get_num_new_matched_tokens( self, request: "Request", @@ -188,11 +188,11 @@ def get_num_new_matched_tokens( which means not need to transfer from flexkv. Returns: - tuple[int, bool]: A tuple containing two integer values representing the - number of new matched tokens and whether it is necessary + tuple[int, bool]: A tuple containing two integer values representing the + number of new matched tokens and whether it is necessary to get the new matched blocks from flexkv. """ - task_id, num_new_matched_tokens = self._get_match(request=request, + task_id, num_new_matched_tokens = self._get_match(request=request, num_computed_tokens=num_computed_tokens) self.flexkv_stats.record_get(num_prompt_tokens=request.num_prompt_tokens, num_gpu_matched_tokens=num_computed_tokens, @@ -202,10 +202,10 @@ def get_num_new_matched_tokens( num_computed_tokens=num_computed_tokens, num_new_matched_tokens=num_new_matched_tokens): return 0, False - + return num_new_matched_tokens, True - - + + def _get_match( self, request: "Request", @@ -222,22 +222,22 @@ def _get_match( the task_id and number of new matched tokens. 
""" match_start_time = time.perf_counter() - num_tokens_to_get = (cdiv(request.num_prompt_tokens, self.block_size)-1)*self.block_size + num_tokens_to_get = (cdiv(request.num_prompt_tokens+1, self.block_size)-1)*self.block_size token_ids = request.prompt_token_ids[:num_tokens_to_get] - + assert num_computed_tokens <= num_tokens_to_get assert num_computed_tokens % self.block_size == 0 - + if num_tokens_to_get == num_computed_tokens: return -1, 0 - + np_token_ids = np.array(token_ids) np_token_mask = np.ones_like(np_token_ids, dtype=bool) np_token_mask[:num_computed_tokens] = False task_id, matched_mask = self.flexkv_manager.get_match(token_ids=np_token_ids, token_mask=np_token_mask) num_new_matched_tokens = matched_mask.sum().item() - + # Auto cancel if not call update_state_after_alloc() match_end_time = time.perf_counter() logger.debug(f"Get match cost {(match_end_time-match_start_time)*1000:.2f} ms.") @@ -249,11 +249,11 @@ def _get_match( num_new_matched_tokens=num_new_matched_tokens, match_start_time=match_start_time, match_end_time=match_end_time) - + logger.debug(f"FlexKV create get task: {self.tasks_to_cancel[task_id]}") - + return task_id, num_new_matched_tokens - + def _need_to_get( self, num_prompt_tokens: int, @@ -264,21 +264,21 @@ def _need_to_get( Determine whether it is necessary to get the new matched blocks from flexkv. """ return num_new_matched_tokens > 0 - + def update_state_after_alloc( self, request: "Request", - blocks: "KVCacheBlocks", + blocks: "KVCacheBlocks", num_new_matched_tokens: int, ) -> None: """ Compute slot mapping and prepare to launch task. Only call after get_num_new_matched_tokens(). - + Args: request: Request to get. blocks: All blocks of the request. - num_new_matched_tokens: Number of new matched tokens returned by + num_new_matched_tokens: Number of new matched tokens returned by get_num_new_matched_tokens(). Returns: @@ -290,27 +290,27 @@ def update_state_after_alloc( task_id = self.req_id_to_task_dict[request.request_id] task: FlexKVGetTask = self.tasks_to_cancel.pop(task_id) self.tasks_to_launch[task_id] = task - + # compute slot_mapping num_computed_blocks = task.num_computed_tokens // self.block_size num_blocks_to_get = num_new_matched_tokens // self.block_size all_block_ids = blocks.get_block_ids()[0] block_ids_to_get = all_block_ids[num_computed_blocks:num_computed_blocks+num_blocks_to_get] task.slot_mapping = np.array(block_ids_to_get).repeat(self.block_size)*self.block_size - + def wait_for_all_get_tasks(self) -> list[FlexKVResponse]: """ Blocking wait for all get tasks. - + Returns: list[FlexKVResponse]: Responses of all get tasks. 
""" return self._blocking_waiting_for_tasks(self.get_tasks) - + #################### #### Put Method #### #################### - + def request_finished( self, request: "Request", @@ -327,34 +327,34 @@ def request_finished( # Task not finished, can't free blocks if request.request_id in self.req_id_to_task_dict: return True - + # Abnormal finished, don't put if not (request.is_finished() and request.get_finished_reason() < 2): return False - + task_id, num_matched_tokens, num_unmatched_tokens = self._put_match(request=request) - + self.flexkv_stats.record_put(num_all_tokens=request.num_tokens, num_unmatched_tokens=num_unmatched_tokens) - + if not self._need_to_put(num_all_tokens=request.num_tokens, num_matched_tokens=num_matched_tokens, num_unmatched_tokens=num_unmatched_tokens): return False - + # prepare to launch task task: FlexKVPutTask = self.tasks_to_cancel.pop(task_id) self.tasks_to_launch[task_id] = task - + # compute slot mapping # num_blocks_to_put = (num_matched_tokens+num_unmatched_tokens) // self.block_size num_matched_blocks = num_matched_tokens // self.block_size num_unmatched_tokens = num_unmatched_tokens // self.block_size block_ids_to_put = block_ids[num_matched_blocks:num_matched_blocks+num_unmatched_tokens] task.slot_mapping = np.array(block_ids_to_put).repeat(self.block_size)*self.block_size - + return True - + def _put_match( self, request: "Request" @@ -373,17 +373,17 @@ def _put_match( if num_tokens_to_put == 0: return -1, 0, 0 - + np_token_ids = np.array(token_ids) task_id, unmatched_mask = self.flexkv_manager.put_match(token_ids=np_token_ids) - + num_unmatched_tokens = unmatched_mask.sum().item() num_matched_tokens = num_tokens_to_put - num_unmatched_tokens - + # Auto cancel if not need to put. match_end_time = time.perf_counter() logger.debug(f"Put match cost {(match_end_time-match_start_time)*1000:.2f} ms.") - + if num_unmatched_tokens > 0: self.req_id_to_task_dict[request.request_id] = task_id self.tasks_to_cancel[task_id] = FlexKVPutTask(task_id=task_id, @@ -393,9 +393,9 @@ def _put_match( match_start_time=match_start_time, match_end_time=match_end_time) logger.debug(f"FlexKV create put task: {self.tasks_to_cancel[task_id]}") - + return task_id, num_matched_tokens, num_unmatched_tokens - + def _need_to_put( self, num_all_tokens: int, @@ -406,23 +406,23 @@ def _need_to_put( Determine whether it is necessary to put the unmatched blocks from flexkv. """ return num_unmatched_tokens > 0 - + def wait_for_all_put_tasks(self) -> list[FlexKVResponse]: """ Blocking wait for all put tasks. - + Returns: list[FlexKVResponse]: Responses of all put tasks. """ return self._blocking_waiting_for_tasks(self.put_tasks) - + ####################### #### Common Method #### ####################### - + def cancel_tasks(self) -> None: """ - Cancel tasks in self.cancel_tasks. + Cancel tasks in self.cancel_tasks. Call before launch_tasks() to delete req_id in self.req_id_to_task_dict """ # TODO: check if this method is inproc. 
@@ -433,7 +433,7 @@ def cancel_tasks(self) -> None: logger.info(f"FlexKV Cancel task: {task}") self.flexkv_manager.cancel(task_ids=list(self.tasks_to_cancel.keys())) self.tasks_to_cancel.clear() - + def launch_tasks(self) -> None: """ Launch tasks in self.unlaunched_tasks @@ -443,7 +443,7 @@ def launch_tasks(self) -> None: task_launch_time = time.perf_counter() task_ids: list[int] = [] slot_mappings: list[np.ndarray] = [] - + for task_id, task in self.tasks_to_launch.items(): logger.info(f"FlexKV Launch task: {task}") task.task_launch_time = task_launch_time @@ -456,11 +456,11 @@ def launch_tasks(self) -> None: self.flexkv_manager.launch(task_ids=task_ids, slot_mappings=slot_mappings) self.tasks_to_launch.clear() - + def query_finished_task(self) -> tuple[set[str], set[str]]: """ Get response of finished task. - + Returns: list[FlexKVResponse]: Responses of finished tasks. """ @@ -493,17 +493,17 @@ def query_finished_task(self) -> tuple[set[str], set[str]]: # request=task.request, success=success)) self.flexkv_stats.record_faild(num_failed_requests=num_failed_tasks) return finished_sending, finished_recving - + def _blocking_waiting_for_tasks(self, task_dict: dict[int, FlexKVTask]) -> list[FlexKVResponse]: """ Blocking wait for tasks in task_dict. - + Returns: list[FlexKVResponse]: Responses of all tasks in task_dict. """ if len(task_dict) == 0: return [] - + task_ids = list(task_dict.keys()) response_from_manager = self.flexkv_manager.wait(task_ids=task_ids) task_finished_time = time.perf_counter() @@ -516,11 +516,11 @@ def _blocking_waiting_for_tasks(self, task_dict: dict[int, FlexKVTask]) -> list[ logger.info(f"{task} finished successfully.") else: logger.error(f"{task} failed, status: {response.status}.") - responses_to_return.append(FlexKVResponse(task_id=task_id, task_type=task.task_type, + responses_to_return.append(FlexKVResponse(task_id=task_id, task_type=task.task_type, request=task.request, success=success)) return responses_to_return - - + + class FlexKVWorkerConnector: def __init__( self, @@ -530,10 +530,10 @@ def __init__( self.flexkv_config = flexkv_config logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.server_recv_port}") self.tp_client = KVTPClient(flexkv_config.server_recv_port, 0, current_device_id) - logger.info(f"Finish init FlexKVWorkerConnector") + logger.info("Finish init FlexKVWorkerConnector") def register_to_server(self, kv_caches: dict[str, torch.Tensor]): - logger.info(f"Start register kv_caches") + logger.info("Start register kv_caches") gpu_blocks = list(kv_caches.values()) num_layer = len(kv_caches) if self.flexkv_config.use_mla: @@ -560,9 +560,9 @@ def register_to_server(self, kv_caches: dict[str, torch.Tensor]): is_mla=self.flexkv_config.use_mla, ) self.tp_client.register_to_server(gpu_blocks, gpu_layout) - logger.info(f"Finish register kv_caches") + logger.info("Finish register kv_caches") + - class FlexKVConnectorV1Impl: def __init__(self, vllm_config: "VllmConfig", role: "KVConnectorRole"): self.role = role @@ -595,9 +595,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs: additional arguments for the load operation Note: - The number of elements in kv_caches and layer_names should be + The number of elements in kv_caches and layer_names should be the same. - + """ pass @@ -606,7 +606,7 @@ def wait_for_layer_load(self, layer_name: str) -> None: Block until the KV for a specific layer is loaded into vLLM's paged buffer. 
This is called from within attention layer to ensure async copying from start_load_kv is complete. - + This interface will be useful for layer-by-layer pipelining. Args: @@ -617,13 +617,13 @@ def wait_for_layer_load(self, layer_name: str) -> None: def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs) -> None: """ - Start saving the a layer of KV cache from vLLM's paged buffer + Start saving the a layer of KV cache from vLLM's paged buffer to the connector. This is called from within attention layer to enable async copying during execution. Args: layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current + kv_layer (torch.Tensor): the paged KV buffer of the current layer in vLLM. attn_metadata (AttentionMetadata): the attention metadata. **kwargs: additional arguments for the save operation. @@ -677,14 +677,14 @@ def get_num_new_matched_tokens( """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. - + Args: request (Request): the request object. num_computed_tokens (int): the number of locally computed tokens for this request Returns: - the number of tokens that can be loaded from the + the number of tokens that can be loaded from the external KV cache beyond what is already computed. """ return self.connector.get_num_new_matched_tokens( @@ -742,4 +742,4 @@ def request_finished( Optional KVTransferParams to be included in the request outputs returned by the engine. """ - return self.connector.request_finished(request, block_ids), None \ No newline at end of file + return self.connector.request_finished(request, block_ids), None From 96591df44a55a1a53c5f2afdf868f39d3e1a24f4 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Mon, 8 Sep 2025 00:33:17 -0700 Subject: [PATCH 33/42] [bugfix] fix build issue --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d11ffc5143..078650644e 100755 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ extra_link_args = ["-lcuda", "-lxxhash", "-lpthread", "-lrt", "-luring"] extra_compile_args = ["-std=c++17"] -include_dirs = [os.path.join(build_dir, "include")] +include_dirs = [os.path.abspath(os.path.join(build_dir, "include"))] # Add rpath to find libraries at runtime lib_dir = os.path.join(build_dir, "lib") From c05db2c00d632e777af265c08e63d20345ac2713 Mon Sep 17 00:00:00 2001 From: zuogan Date: Tue, 9 Sep 2025 16:33:23 +0800 Subject: [PATCH 34/42] fix vllm connector bug --- flexkv/cache/cache_engine.py | 6 +++--- flexkv/common/debug.py | 18 +++++++++++------- flexkv/integration/vllm/vllm_v1_adapter.py | 11 ++++++----- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 7d280c7ead..943e739353 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -301,7 +301,7 @@ def get(self, # ignore the last incomplete block aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block aligned_token_ids = token_ids[:aligned_length] - token_mask[aligned_length:] = False + token_mask = token_mask[:aligned_length] block_start_idx, block_end_idx = self._get_block_range(token_mask) assert block_end_idx == aligned_length // self.tokens_per_block @@ -652,7 +652,7 @@ def put(self, # ignore the last incomplete block aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block aligned_token_ids = token_ids[:aligned_length] - 
token_mask[aligned_length:] = False + token_mask = token_mask[:aligned_length] block_start_idx, block_end_idx = self._get_block_range(token_mask) # the mask should has a prefix of True @@ -1068,7 +1068,7 @@ def _get_block_range(self, token_mask: np.ndarray) -> Tuple[int, int]: mask_idx = np.where(token_mask)[0] if len(mask_idx) == 0: - return 0, 0 + return len(token_mask), len(token_mask) start_idx = mask_idx[0].item() // self.tokens_per_block end_idx = mask_idx[-1].item() // self.tokens_per_block return start_idx, end_idx + 1 diff --git a/flexkv/common/debug.py b/flexkv/common/debug.py index 0f79cf869b..a522c5549a 100644 --- a/flexkv/common/debug.py +++ b/flexkv/common/debug.py @@ -16,14 +16,18 @@ def __init__(self, debug_level: str = "INFO"): self.enabled = False self.logger = logging.getLogger("FLEXKV") - formatter = logging.Formatter( - fmt=_FORMAT, - datefmt=_DATE_FORMAT, + has_console_handler = any( + isinstance(handler, logging.StreamHandler) + for handler in self.logger.handlers ) - - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setFormatter(formatter) - self.logger.addHandler(console_handler) + if not has_console_handler: + formatter = logging.Formatter( + fmt=_FORMAT, + datefmt=_DATE_FORMAT, + ) + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + self.logger.addHandler(console_handler) self.set_level(debug_level) diff --git a/flexkv/integration/vllm/vllm_v1_adapter.py b/flexkv/integration/vllm/vllm_v1_adapter.py index 5c5f6ed27c..c129e44563 100644 --- a/flexkv/integration/vllm/vllm_v1_adapter.py +++ b/flexkv/integration/vllm/vllm_v1_adapter.py @@ -194,11 +194,11 @@ def get_num_new_matched_tokens( """ task_id, num_new_matched_tokens = self._get_match(request=request, num_computed_tokens=num_computed_tokens) - self.flexkv_stats.record_get(num_prompt_tokens=request.num_prompt_tokens, + self.flexkv_stats.record_get(num_prompt_tokens=request.num_tokens, num_gpu_matched_tokens=num_computed_tokens, num_flexkv_matched_tokens=num_new_matched_tokens) - if not self._need_to_get(num_prompt_tokens=request.num_prompt_tokens, + if not self._need_to_get(num_prompt_tokens=request.num_tokens, num_computed_tokens=num_computed_tokens, num_new_matched_tokens=num_new_matched_tokens): return 0, False @@ -222,10 +222,11 @@ def _get_match( the task_id and number of new matched tokens. 
""" match_start_time = time.perf_counter() - num_tokens_to_get = (cdiv(request.num_prompt_tokens+1, self.block_size)-1)*self.block_size - token_ids = request.prompt_token_ids[:num_tokens_to_get] + num_tokens_to_get = (request.num_tokens//self.block_size)*self.block_size + token_ids = request.all_token_ids[:num_tokens_to_get] - assert num_computed_tokens <= num_tokens_to_get + assert num_computed_tokens <= num_tokens_to_get, ( + f"{num_computed_tokens=} must less equal to {num_tokens_to_get=}") assert num_computed_tokens % self.block_size == 0 if num_tokens_to_get == num_computed_tokens: From a03978de9c35a99c9eaf4ced283788ebd2cef913 Mon Sep 17 00:00:00 2001 From: zuogan Date: Tue, 9 Sep 2025 17:23:59 +0800 Subject: [PATCH 35/42] further_fix --- flexkv/cache/cache_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 943e739353..e113aeb69b 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -1068,7 +1068,7 @@ def _get_block_range(self, token_mask: np.ndarray) -> Tuple[int, int]: mask_idx = np.where(token_mask)[0] if len(mask_idx) == 0: - return len(token_mask), len(token_mask) + return len(token_mask)//self.tokens_per_block, len(token_mask)//self.tokens_per_block start_idx = mask_idx[0].item() // self.tokens_per_block end_idx = mask_idx[-1].item() // self.tokens_per_block return start_idx, end_idx + 1 From 15de1ce4c224de307047ddad64afeac6149d39da Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Tue, 9 Sep 2025 20:59:13 -0700 Subject: [PATCH 36/42] modify default config --- benchmarks/example_config.json | 1 - docs/vllm_adapter/README_en.md | 1 - docs/vllm_adapter/README_zh.md | 3 +- examples/run_server.py | 25 ++++++------- examples/scheduler_server_example.py | 55 ++++++++++++++-------------- flexkv/common/config.py | 9 ++--- flexkv/common/tracer.py | 1 - flexkv/server/server.py | 32 +++++++--------- flexkv/storage/allocator.py | 3 +- flexkv/storage/storage_engine.py | 1 - tests/replay_from_tracer.py | 1 - tests/test_utils.py | 1 - 12 files changed, 59 insertions(+), 74 deletions(-) diff --git a/benchmarks/example_config.json b/benchmarks/example_config.json index 4a710f41ca..d4854557c3 100644 --- a/benchmarks/example_config.json +++ b/benchmarks/example_config.json @@ -14,7 +14,6 @@ "enable_remote": false, "tokens_per_block": 16, "use_gds": false, - "use_pinned_memory": true, "gpu_kv_layout_type": "LAYERWISE", "cpu_kv_layout_type": "BLOCKWISE", "ssd_kv_layout_type": "BLOCKWISE", diff --git a/docs/vllm_adapter/README_en.md b/docs/vllm_adapter/README_en.md index acc2f36de2..79a38f62ab 100644 --- a/docs/vllm_adapter/README_en.md +++ b/docs/vllm_adapter/README_en.md @@ -41,7 +41,6 @@ cat < ./flexkv_config.json "cache_config": { "enable_cpu": true, "num_cpu_blocks": 10240, - "use_pinned_memory": true }, "num_log_interval_requests": 200 } diff --git a/docs/vllm_adapter/README_zh.md b/docs/vllm_adapter/README_zh.md index f13815db1e..81e291b5cc 100644 --- a/docs/vllm_adapter/README_zh.md +++ b/docs/vllm_adapter/README_zh.md @@ -40,7 +40,6 @@ cat < ./flexkv_config.json "cache_config": { "enable_cpu": true, "num_cpu_blocks": 10240, - "use_pinned_memory": true }, "num_log_interval_requests": 200 } @@ -81,4 +80,4 @@ bash benchmarks/flexkv_benchmark/serving_vllm.sh # 启动性能测试 bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh ``` -在 vLLM 0.10.0 版本中应用patch `examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch`,测试方法同上。 \ No newline at end of file +在 vLLM 0.10.0 版本中应用patch 
`examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch`,测试方法同上。 diff --git a/examples/run_server.py b/examples/run_server.py index d5b6a182ec..48b24ecad1 100644 --- a/examples/run_server.py +++ b/examples/run_server.py @@ -12,16 +12,16 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - + # NAME - parser.add_argument("--enable-cpu", - action=argparse.BooleanOptionalAction, + parser.add_argument("--enable-cpu", + action=argparse.BooleanOptionalAction, default=True) - parser.add_argument("--enable-ssd", - action=argparse.BooleanOptionalAction, + parser.add_argument("--enable-ssd", + action=argparse.BooleanOptionalAction, default=False,) - parser.add_argument("--enable-remote", - action=argparse.BooleanOptionalAction, + parser.add_argument("--enable-remote", + action=argparse.BooleanOptionalAction, default=False,) parser.add_argument("--model-path", type=str, help="model path", default="") parser.add_argument("--tp-size", type=int, default=1) @@ -54,7 +54,7 @@ def parse_args() -> argparse.Namespace: if __name__ == "__main__": args = parse_args() hf_config = AutoConfig.from_pretrained(args.model_path) - + num_layers=hf_config.num_hidden_layers if hasattr(hf_config, 'num_key_value_heads'): num_kv_heads=hf_config.num_key_value_heads @@ -65,7 +65,7 @@ def parse_args() -> argparse.Namespace: head_size=(hf_config.head_dim if hasattr(hf_config, 'head_dim') else hf_config.hidden_size//hf_config.num_attention_heads) use_mla=hf_config.architectures[0].startswith("Deepseek") - + # TODO: different model config may have different attribute name model_config = ModelConfig( num_layers=num_layers, @@ -76,14 +76,13 @@ def parse_args() -> argparse.Namespace: dp_size=args.dp_size, dtype=hf_config.torch_dtype ) - + cache_config = CacheConfig( enable_cpu=args.enable_cpu, enable_ssd=args.enable_ssd, enable_remote=args.enable_remote, use_gds=False, enable_trace=False, - use_pinned_memory=False, ssd_cache_iouring_entries=512, tokens_per_block=args.block_size, num_cpu_blocks=args.num_cpu_blocks, @@ -93,6 +92,6 @@ def parse_args() -> argparse.Namespace: remote_cache_size_mode=args.remote_cache_size_mode, remote_cache_path=args.remote_cache_path, ) - + kvserver = KVServer(model_config, cache_config, args.server_recv_port) - kvserver.run() \ No newline at end of file + kvserver.run() diff --git a/examples/scheduler_server_example.py b/examples/scheduler_server_example.py index 29826afc9a..059cc467aa 100644 --- a/examples/scheduler_server_example.py +++ b/examples/scheduler_server_example.py @@ -16,9 +16,9 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, model_config, gpu_kv_layout): """Run TP client process""" from flexkv.server.client import KVTPClient - + print(f"Starting TP client: dp_client_id={dp_client_id}, tp_rank={tp_rank}, device_id={device_id}") - + try: # Set CUDA device for this process if torch.cuda.is_available(): @@ -27,7 +27,7 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, mo torch.cuda.init() # Clear cache torch.cuda.empty_cache() - + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) # Create GPU blocks for this TP client @@ -51,7 +51,7 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, mo # Keep TP client running while True: time.sleep(1) - + except Exception as e: print(f"TP client {tp_rank} error: {e}") import traceback @@ -84,7 +84,6 @@ def main(): enable_ssd=False, enable_remote=False, use_gds=False, - use_pinned_memory=True, tokens_per_block=tokens_per_block, 
num_cpu_blocks=num_cpu_blocks, ) @@ -106,14 +105,14 @@ def main(): cache_config=cache_config, server_recv_port="ipc:///tmp/scheduler_server_example" # TPClient connects to this port ) - + # Start background server thread to handle TPClient registration scheduler_server.start_server_thread() - - print(f"SchedulerServer started!") + + print("SchedulerServer started!") print(f"TPClient can connect to: {scheduler_server.get_server_port()}") print("Starting TP client processes...") - + # Start TP client processes tp_client_processes = [] for tp_rank in range(tp_size): @@ -123,7 +122,7 @@ def main(): if device_id >= available_gpus: device_id = device_id % available_gpus print(f"Warning: Using GPU {device_id} for TP rank {tp_rank} (not enough GPUs)") - + tp_client_process = Process( target=run_tp_client_process, args=(0, tp_rank, device_id, scheduler_server.get_server_port(), model_config, gpu_kv_layout), @@ -134,32 +133,32 @@ def main(): print(f"Started TP client process for rank {tp_rank} on device {device_id}") print("Waiting for all TP clients to register...") - + time.sleep(5) - + # Now we can directly use scheduler_server without network communication # Example: Create some test data (following benchmark_kvmanager.py pattern) batch_size = 4 seq_len = 128 - + print("\n=== Generating test data ===") # Generate separate sequences for each request (correct approach) batch_token_ids = [] batch_slot_mappings = [] batch_token_masks = [] - + for i in range(batch_size): # Each sequence is independent (seq_len,) shape token_ids = torch.randint(0, 1000, (seq_len,)) slot_mapping = torch.arange(i * seq_len, (i + 1) * seq_len) token_mask = torch.ones(seq_len, dtype=torch.bool) - + batch_token_ids.append(token_ids) batch_slot_mappings.append(slot_mapping) batch_token_masks.append(token_mask) - + print(f"Generated {batch_size} sequences, each with {seq_len} tokens") - + print("\n=== Executing PUT Operations ===") # PUT operations - each sequence processed separately start_time = time.time() @@ -173,7 +172,7 @@ def main(): if task_id: put_task_ids.append(task_id) print(f"PUT task {task_id} created for sequence {i}") - + put_time = (time.time() - start_time) * 1000 print(f"Created {len(put_task_ids)} PUT tasks, time: {put_time:.2f}ms") time.sleep(2) @@ -190,10 +189,10 @@ def main(): if task_id: get_task_ids.append(task_id) print(f"GET task {task_id} created for sequence {i}") - + get_time = (time.time() - start_time) * 1000 print(f"Created {len(get_task_ids)} GET tasks, time: {get_time:.2f}ms") - + print("\n=== Waiting for All Tasks to Complete ===") # Wait for all tasks to complete - can wait for multiple tasks at once all_task_ids = put_task_ids + get_task_ids @@ -202,7 +201,7 @@ def main(): masks = scheduler_server.wait(all_task_ids) wait_time = (time.time() - start_time) * 1000 print(f"All {len(all_task_ids)} tasks completed, time: {wait_time:.2f}ms") - + # Analyze results if masks: total_tokens = 0 @@ -211,7 +210,7 @@ def main(): tokens = mask.sum().item() if hasattr(mask, 'sum') else len(mask) total_tokens += tokens print(f"Task {task_id}: {tokens} tokens processed") - + print("\n=== Trying Non-blocking Wait ===") # Create a few more tasks and try non-blocking wait extra_task_ids = [] @@ -223,7 +222,7 @@ def main(): ) if task_id: extra_task_ids.append(task_id) - + if extra_task_ids: # Immediately try to wait (might not be completed yet) masks = scheduler_server.try_wait(extra_task_ids) @@ -233,15 +232,15 @@ def main(): print(f"Tasks {extra_task_ids} not ready yet, will wait...") masks = 
scheduler_server.wait(extra_task_ids) print(f"Tasks {extra_task_ids} completed after wait") - + print("\n✅ All operations completed successfully!") - - + + # Clean up resources print("\n=== Shutting down SchedulerServer ===") scheduler_server.shutdown() print("SchedulerServer has been shut down") - + # Terminate TP client processes print("Terminating TP client processes...") for i, process in enumerate(tp_client_processes): @@ -253,4 +252,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/flexkv/common/config.py b/flexkv/common/config.py index fbcf465727..d20d7518dd 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -32,14 +32,13 @@ class CacheConfig: enable_ssd: bool = False enable_remote: bool = False use_gds: bool = False - use_pinned_memory: bool = False index_accel: bool = False # kv cache layout configs gpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - cpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - ssd_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - remote_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE + cpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE + ssd_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE + remote_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE # mempool capacity configs num_cpu_blocks: int = 1000000 @@ -72,6 +71,6 @@ class CacheConfig: trace_max_file_size_mb: int = 100 trace_max_files: int = 5 trace_flush_interval_ms: int = 1000 - + #evict ratio evict_ratio: float = 0.0 diff --git a/flexkv/common/tracer.py b/flexkv/common/tracer.py index 92668ae3f8..dff6b1ff3a 100644 --- a/flexkv/common/tracer.py +++ b/flexkv/common/tracer.py @@ -121,7 +121,6 @@ def trace_config(self, model_config, cache_config, gpu_layout=None): "ssd_kv_layout_type": str(cache_config.ssd_kv_layout_type), "remote_kv_layout_type": str(cache_config.remote_kv_layout_type), "use_gds": cache_config.use_gds, - "use_pinned_memory": cache_config.use_pinned_memory, "remote_cache_size_mode": cache_config.remote_cache_size_mode, "num_cpu_blocks": cache_config.num_cpu_blocks, "num_ssd_blocks": cache_config.num_ssd_blocks, diff --git a/flexkv/server/server.py b/flexkv/server/server.py index daf25ff7ca..1849c1e304 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -81,7 +81,7 @@ def register_dp_client( flexkv_logger.info(f"DP client {client_id} registered successfully") return client_id - + def delete_dp_client(self, client_id: int) -> None: if client_id not in self.client_dict: flexkv_logger.error(f"DP client: {client_id} dosen't exist. 
Delete failed.") @@ -105,7 +105,7 @@ def is_dp_client_ready(self, dp_client_id: int) -> bool: class KVServerHandle: def __init__(self, process: mp.Process): self.process = process - + def shutdown(self) -> None: self.process.join(timeout=5) if self.process.is_alive(): @@ -137,7 +137,7 @@ def __init__( self.req_counter = 0 self._is_ready = False self._running = False - + # Request handler dispatch table self.request_handlers = { StartRequest: self._handle_start_request, @@ -162,14 +162,14 @@ def start_server(self) -> None: self._is_ready = True @staticmethod - def _server_process(model_config: ModelConfig, + def _server_process(model_config: ModelConfig, cache_config: CacheConfig, gpu_register_port: str, server_recv_port: str) -> None: - + server = KVServer(model_config, cache_config, gpu_register_port, server_recv_port) server.run() - + @classmethod def create_server(cls, model_config: ModelConfig, @@ -178,18 +178,15 @@ def create_server(cls, server_recv_port: Optional[str] = None) -> 'KVServerHandle': #if server_recv_port is None: # server_recv_port = f"ipc:///tmp/flexkv_srv_{uuid.uuid4().hex[:8]}" #TODO unify this - + # Set spawn method for CUDA compatibility - try: + with contextlib.suppress(RuntimeError): mp.set_start_method("spawn") - except RuntimeError: - # If already set, just continue - pass process = mp.Process(target=cls._server_process, args=(model_config, cache_config, gpu_register_port, server_recv_port)) process.start() flexkv_logger.info(f"KVServer process started, PID: {process.pid}") - + return KVServerHandle(process) def run(self) -> None: @@ -216,13 +213,13 @@ def run(self) -> None: # Use dispatch table for request handling req_type = type(req) handler = self.request_handlers.get(req_type) - + if handler is None: raise TypeError(f"Unrecognized RequestType: {req_type}") - + # Call the corresponding handler method handler(req) - + # If the request is a shutdown request, exit the loop if req_type == ShutdownRequest: break @@ -246,7 +243,7 @@ def _verify_model_config( return True # Request Handler Methods - + def _handle_start_request(self, req: StartRequest) -> None: """Handle start request""" flexkv_logger.info(f"Received start request from DP client {req.dp_client_id}") @@ -317,7 +314,7 @@ def _handle_put_match_request(self, req: PutMatchRequest) -> None: def _handle_launch_task_request(self, req: LaunchTaskRequest) -> None: """Handle LaunchTask request""" self.kv_task_engine.launch_tasks(req.task_ids, req.slot_mappings) - + def _handle_cancel_task_request(self, req: CancelTaskRequest) -> None: """Handle CancelTask request""" self.kv_task_engine.cancel_tasks(req.task_ids) @@ -381,7 +378,6 @@ def __del__(self) -> None: enable_ssd=False, enable_remote=False, use_gds=False, - use_pinned_memory=True, tokens_per_block=tokens_per_block, num_cpu_blocks=num_cpu_blocks,) diff --git a/flexkv/storage/allocator.py b/flexkv/storage/allocator.py index 7cd38156e0..ed683e6505 100644 --- a/flexkv/storage/allocator.py +++ b/flexkv/storage/allocator.py @@ -95,7 +95,6 @@ def allocate(cls, layout: KVCacheLayout, dtype: torch.dtype, **kwargs: Any) -> StorageHandle: - pin_memory = kwargs.get("pin_memory", True) total_size = layout.get_total_elements() # although the kv layout may have multiple dimensions, we only have one-dim CPU tensor flexkv_logger.info(f"CPU allocate total_size: {2 * total_size/1024/1024/1024} GB") @@ -103,7 +102,7 @@ def allocate(cls, size=(total_size,), dtype=dtype, device="cpu", - pin_memory=pin_memory, + pin_memory=False, ) return StorageHandle( 
handle_type=AccessHandleType.TENSOR, diff --git a/flexkv/storage/storage_engine.py b/flexkv/storage/storage_engine.py index 0762b0062d..0d48fe6230 100644 --- a/flexkv/storage/storage_engine.py +++ b/flexkv/storage/storage_engine.py @@ -35,7 +35,6 @@ def __init__(self, device_type=DeviceType.CPU, layout=self._cpu_layout, dtype=self._model_config.dtype, - pin_memory=self._cache_config.use_pinned_memory, ) if self._cache_config.enable_ssd: if not self._cache_config.ssd_kv_layout_type == self._cpu_layout.type: diff --git a/tests/replay_from_tracer.py b/tests/replay_from_tracer.py index 3ddc0ce810..fad6a20ea9 100644 --- a/tests/replay_from_tracer.py +++ b/tests/replay_from_tracer.py @@ -113,7 +113,6 @@ def parse_config_event(self, event: Dict[str, Any]): ssd_kv_layout_type=self._parse_layout_type(cache_config_data['ssd_kv_layout_type']), remote_kv_layout_type=self._parse_layout_type(cache_config_data['remote_kv_layout_type']), use_gds=cache_config_data['use_gds'], - use_pinned_memory=False,#cache_config_data['use_pinned_memory'], # for local test remote_cache_size_mode=cache_config_data['remote_cache_size_mode'], num_cpu_blocks=cache_config_data['num_cpu_blocks'], num_ssd_blocks=cache_config_data['num_ssd_blocks'], diff --git a/tests/test_utils.py b/tests/test_utils.py index ba1392eabc..93541b612b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -38,7 +38,6 @@ 'remote_file_prefix': "remote_cache", 'use_gds': False, 'enable_trace': False, - 'use_pinned_memory': False, 'ssd_cache_dir': ["./ssd_cache", "./ssd_cache2/"], 'ssd_cache_iouring_entries': 32, 'remote_cache_path': ["remote_cache1", "remote_cache2"], From 31ab4cf4e1ad0c4a3774a44c08a6c957c0ce67a0 Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Tue, 9 Sep 2025 21:19:59 -0700 Subject: [PATCH 37/42] update vllm patch --- .../vllm_0_10_1_1-flexkv-connector.patch | 69 +++++++------------ 1 file changed, 25 insertions(+), 44 deletions(-) diff --git a/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch b/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch index fc0a558d03..812a1d6e2f 100644 --- a/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +++ b/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch @@ -1,24 +1,9 @@ -From a434b67b8097990f20d8c020a8c713b10dd3d5b0 Mon Sep 17 00:00:00 2001 -From: zuogan -Date: Wed, 3 Sep 2025 05:11:50 -0700 -Subject: [PATCH] add flexkv connector - ---- - .../prefix_caching_flexkv.py | 163 +++++++++++++++ - .../kv_transfer/kv_connector/factory.py | 5 + - .../kv_connector/v1/flexkv_connector.py | 191 ++++++++++++++++++ - vllm/v1/core/sched/scheduler.py | 13 +- - .../worker/kv_connector_model_runner_mixin.py | 6 +- - 5 files changed, 373 insertions(+), 5 deletions(-) - create mode 100644 examples/offline_inference/prefix_caching_flexkv.py - create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/flexkv_connector.py - diff --git a/examples/offline_inference/prefix_caching_flexkv.py b/examples/offline_inference/prefix_caching_flexkv.py new file mode 100644 -index 000000000..4cfe2ef7f +index 000000000..a57328ffd --- /dev/null +++ b/examples/offline_inference/prefix_caching_flexkv.py -@@ -0,0 +1,163 @@ +@@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +import time @@ -36,7 +21,6 @@ index 000000000..4cfe2ef7f + "cache_config": { + "enable_cpu": True, + "num_cpu_blocks": 10240, -+ "use_pinned_memory": True + }, + "num_log_interval_requests": 200 +} @@ -84,7 +68,7 @@ index 000000000..4cfe2ef7f + +def main(): + # Create an LLM without prefix caching 
as a baseline. -+ regular_llm = LLM(model=model_path, ++ regular_llm = LLM(model=model_path, + enable_prefix_caching=False, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tp_size @@ -114,7 +98,7 @@ index 000000000..4cfe2ef7f + # return + + # Create an LLM with prefix caching enabled. -+ prefix_cached_llm = LLM(model=model_path, ++ prefix_cached_llm = LLM(model=model_path, + enable_prefix_caching=True, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tp_size, @@ -124,7 +108,7 @@ index 000000000..4cfe2ef7f + # Warmup so that the shared prompt's KV cache is computed. + prefix_cached_llm.generate(generating_prompts[0], sampling_params) + -+ # wait for offload kv task finished. ++ # wait for offload kv task finished. + time.sleep(2) + + # Generate with prefix caching. @@ -149,7 +133,7 @@ index 000000000..4cfe2ef7f + ]) + print(f"Generated answers are the same: {generated_same}") + -+ # wait for offload kv task finished. ++ # wait for offload kv task finished. + time.sleep(2) + + # reset prefix cache to use flexkv @@ -249,9 +233,9 @@ index 000000000..bdfa9f321 + **kwargs: additional arguments for the load operation + + Note: -+ The number of elements in kv_caches and layer_names should be ++ The number of elements in kv_caches and layer_names should be + the same. -+ ++ + """ + self._flexkv_connector.start_load_kv(forward_context, **kwargs) + @@ -260,7 +244,7 @@ index 000000000..bdfa9f321 + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. -+ ++ + This interface will be useful for layer-by-layer pipelining. + + Args: @@ -271,13 +255,13 @@ index 000000000..bdfa9f321 + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ -+ Start saving the a layer of KV cache from vLLM's paged buffer ++ Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. -+ kv_layer (torch.Tensor): the paged KV buffer of the current ++ kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. @@ -310,7 +294,7 @@ index 000000000..bdfa9f321 + call to this method (this call or a prior one). + """ + return self._flexkv_connector.get_finished(finished_req_ids) -+ ++ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the @@ -332,14 +316,14 @@ index 000000000..bdfa9f321 + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. -+ ++ + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: -+ the number of tokens that can be loaded from the ++ the number of tokens that can be loaded from the + external KV cache beyond what is already computed. 
+ """ + return self._flexkv_connector.get_num_new_matched_tokens( @@ -398,30 +382,30 @@ index 981023409..a6c8fac38 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -118,6 +118,7 @@ class Scheduler(SchedulerInterface): - + # KV Connector: requests in process of async KV loading or recving self.finished_recving_kv_req_ids: set[str] = set() + self.sending_kv_reqs: dict[str, Request] = {} - + # Encoder-related. # Calculate encoder cache size if applicable @@ -1029,7 +1030,8 @@ class Scheduler(SchedulerInterface): - + if not delay_free_blocks: self._free_blocks(request) - + else: + self.sending_kv_reqs[request.request_id] = request return kv_xfer_params - + def _free_blocks(self, request: Request): @@ -1041,7 +1043,7 @@ class Scheduler(SchedulerInterface): return len(self.waiting) + len(self.running) - + def has_finished_requests(self) -> bool: - return len(self.finished_req_ids) > 0 + return len(self.finished_req_ids) > 0 or len(self.sending_kv_reqs) > 0 - + def reset_prefix_cache(self) -> bool: return self.kv_cache_manager.reset_prefix_cache() @@ -1082,6 +1084,8 @@ class Scheduler(SchedulerInterface): @@ -430,20 +414,20 @@ index 981023409..a6c8fac38 100644 self.kv_event_publisher.shutdown() + if self.connector and hasattr(self.connector, "shutdown"): + self.connector.shutdown() - + ######################################################################## # KV Connector Related Methods @@ -1149,6 +1153,10 @@ class Scheduler(SchedulerInterface): scheduler the request during the next step. """ - + + # avoid busy checking + if len(self.running) == 0: + time.sleep(0.01) + if self.connector is not None: self.connector.update_connector_output(kv_connector_output) - + @@ -1158,4 +1166,5 @@ class Scheduler(SchedulerInterface): self.finished_recving_kv_req_ids.add(req_id) for req_id in (kv_connector_output.finished_sending or ()): @@ -457,16 +441,13 @@ index a03ebe35d..8e4460957 100644 @@ -66,9 +66,9 @@ class KVConnectorModelRunnerMixin: scheduler_output, wait_for_save=False) as kv_connector_output: pass - + - if (not kv_connector_output.finished_sending - and not kv_connector_output.finished_recving): - return EMPTY_MODEL_RUNNER_OUTPUT + # if (not kv_connector_output.finished_sending + # and not kv_connector_output.finished_recving): + # return EMPTY_MODEL_RUNNER_OUTPUT - + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) output.kv_connector_output = kv_connector_output --- -2.34.1 - From 3371934919f65f54b6a4e381d9a764f95b0f8c56 Mon Sep 17 00:00:00 2001 From: zuogan Date: Wed, 10 Sep 2025 19:30:27 +0800 Subject: [PATCH 38/42] init xx_kv_layout_type from str --- flexkv/common/config.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/flexkv/common/config.py b/flexkv/common/config.py index d20d7518dd..2805ae606c 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -74,3 +74,14 @@ class CacheConfig: #evict ratio evict_ratio: float = 0.0 + + def __post_init__(self): + layout_fields = ['gpu_kv_layout_type', + 'cpu_kv_layout_type', + 'ssd_kv_layout_type', + 'remote_kv_layout_type'] + for field in layout_fields: + value = getattr(self, field) + if isinstance(value, str): + setattr(self, field, KVCacheLayoutType[value.upper()]) + From 33c09016581d8e4bc47d4143ec13652144a00324 Mon Sep 17 00:00:00 2001 From: leolingli Date: Thu, 11 Sep 2025 15:50:51 +0800 Subject: [PATCH 39/42] [doc] add version --- CHANGELOG.md | 30 ++++++++++++++++++++++++++++++ VERSION | 1 + docs/vllm_adapter/README_en.md | 12 ++++++------ docs/vllm_adapter/README_zh.md | 12 
++++++------ setup.py | 5 ++++- 5 files changed, 47 insertions(+), 13 deletions(-) create mode 100644 CHANGELOG.md create mode 100644 VERSION diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000..0fe668086a --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,30 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [1.0.0] - 2025-09-11 + +### Added +- C++ radix tree for fast matching; set "index_accel": true in cache_config to enable +- sync kernel launch +- a major change that moves the cache engine from server-client mode into a library called directly by the accelerator (e.g., vLLM). + This accelerates get and put when no KVCache is matched. This version includes breaking API changes and is not backward compatible. +- add evict_ratio; set "evict_ratio": 0.05 in cache_config to enable +- reduced the bubble inside kernel launches +- add vLLM 0.10.1.1 adapter + +### Fixed +- cython release package + + +## [0.1.0] - 2025-08-29 + +### Init +- initial version +- add license + diff --git a/VERSION b/VERSION new file mode 100644 index 0000000000..3eefcb9dd5 --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +1.0.0 diff --git a/docs/vllm_adapter/README_en.md b/docs/vllm_adapter/README_en.md index 79a38f62ab..781b3ad3ee 100644 --- a/docs/vllm_adapter/README_en.md +++ b/docs/vllm_adapter/README_en.md @@ -6,15 +6,15 @@ In commit [`0290841dce65ae9b036a23d733cf94e47e814934`](https://github.com/taco-p This change involves significant API adjustments. Therefore, please note: -- **Version >= `0.0.2`**: Use the **current version API**; the vLLM patch is located in `examples/vllm_adaption/`. -- **Version == `0.0.1`**: Supports the **legacy version API**; the vLLM patch is located in `examples/vllm_adaption_legacy/`. +- **Version >= `1.0.0`**: Use the **current version API**; the vLLM patch is located in `examples/vllm_adaption/`. +- **Version == `0.1.0`**: Supports the **legacy version API**; the vLLM patch is located in `examples/vllm_adaption_legacy/`.
--- -## Current Version (>= 0.0.2) +## Current Version (>= 1.0.0) ### Supported Versions -- FlexKV >= `0.0.2` +- FlexKV >= `1.0.0` - vLLM versions >= `0.8.5` can generally follow this version for adaptation ### Example @@ -63,10 +63,10 @@ VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \ ``` -## Legacy Version (<= 0.0.1) – Not Recommended for Current Use +## Legacy Version (<= 0.1.0) – Not Recommended for Current Use ### Supported Versions -- FlexKV <= `0.0.1` +- FlexKV <= `0.1.0` ### Example Apply the patch `examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch` to vLLM 0.8.4, then start FlexKV, vLLM, and the benchmark script: diff --git a/docs/vllm_adapter/README_zh.md b/docs/vllm_adapter/README_zh.md index 81e291b5cc..0e7ce7687e 100644 --- a/docs/vllm_adapter/README_zh.md +++ b/docs/vllm_adapter/README_zh.md @@ -5,15 +5,15 @@ **FlexKV 从 client-server 模式,变为推理加速引擎(如 vLLM)可直接调用的库函数**,以减少进程间消息传递的开销。 这一变更引发了较大的 API 调整。因此,请注意: -- **版本 >= `0.0.2`**:应使用 **当前版本 API**,vLLM patch位于 `examples/vllm_adaption/`。 -- **版本 == `0.0.1`**:仅支持 **Legacy 版本 API**, vLLM patch位于`examples/vllm_adaption_legacy/`。 +- **版本 >= `1.0.0`**:应使用 **当前版本 API**,vLLM patch位于 `examples/vllm_adaption/`。 +- **版本 == `0.1.0`**:仅支持 **Legacy 版本 API**, vLLM patch位于`examples/vllm_adaption_legacy/`。 --- -## 当前版本(>= 0.0.2) +## 当前版本(>= 1.0.0) ### 适用版本 -- FlexKV >= `0.0.2` +- FlexKV >= `1.0.0` - vLLM 原则上>= `0.8.5`版本均可参考示例代码进行修改 ### 示例 @@ -62,10 +62,10 @@ VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \ ``` -## Legacy版本(<= 0.0.1),目前的版本尽量不要使用 +## Legacy版本(<= 0.1.0),目前的版本尽量不要使用 ### 适用版本 -- FlexKV <= `0.0.1` +- FlexKV <= `0.1.0` ### 示例 在 vLLM 0.8.4 版本中应用patch `examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch`,分别启动 FlexKV、vLLM 和测试脚本: diff --git a/setup.py b/setup.py index 078650644e..fcd0f97a34 100755 --- a/setup.py +++ b/setup.py @@ -7,6 +7,9 @@ from setuptools.command.build_ext import build_ext from torch.utils import cpp_extension +def get_version(): + with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f: + return f.read().strip() build_dir = "build" os.makedirs(build_dir, exist_ok=True) @@ -130,7 +133,7 @@ def copy_shared_libraries(self): setup( name="flexkv", description="A global KV-Cache manager for LLM inference", - version="0.1.0", + version=get_version(), packages=find_packages(exclude=("benchmarks", "csrc", "examples", "tests")), package_data={ "flexkv": ["*.so", "lib/*.so", "lib/*.so.*"], From abc346858d40dc10bb2fb768549dd410e5dbe224 Mon Sep 17 00:00:00 2001 From: hsr Date: Mon, 15 Sep 2025 17:08:47 +0800 Subject: [PATCH 40/42] ADD: dynamo+flexkv doc --- docs/dynamo_integration/README_en.md | 151 +++++++++++++++++++++++++++ docs/dynamo_integration/README_zh.md | 151 +++++++++++++++++++++++++++ 2 files changed, 302 insertions(+) create mode 100644 docs/dynamo_integration/README_en.md create mode 100644 docs/dynamo_integration/README_zh.md diff --git a/docs/dynamo_integration/README_en.md b/docs/dynamo_integration/README_en.md new file mode 100644 index 0000000000..bbd6d0db5f --- /dev/null +++ b/docs/dynamo_integration/README_en.md @@ -0,0 +1,151 @@ +# FlexKV and Dynamo Integration Guide + +This document demonstrates how to integrate FlexKV with NVIDIA's [Dynamo](https://github.com/ai-dynamo/dynamo) framework and complete performance testing. + +Dynamo is a framework designed by NVIDIA for large-scale distributed deployment, supporting multiple backend engines including TensorRT-LLM, vLLM, and SGLang. 
The KV Router is an intelligent request routing component that tracks and manages KV caches stored on different workers. It intelligently assigns requests to the most suitable worker based on the overlap between requests and KV cache, as well as the current worker load, thereby reducing expensive KV cache recomputations and improving inference efficiency. This document also explains how to integrate FlexKV into Dynamo when the KV Router is enabled. + +## 1. Environment Setup + +### Dynamo Image + +We use Dynamo 0.4.1 image with vLLM backend, which includes vLLM 0.10.1.1. + +```bash +docker pull nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.4.1 +``` + +### FlexKV Code Preparation + +```bash +git clone -b dev https://github.com/taco-project/FlexKV +``` + +### Install FlexKV + +```bash +apt update && apt install liburing-dev + +cd FlexKV && ./build.sh +``` + +### vLLM Apply Patch + +```bash +# Navigate to FlexKV directory +git apply examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +``` + +### FlexKV Verification + +Please refer to the test scripts in [vLLM online serving](https://github.com/taco-project/FlexKV/blob/dev/docs/vllm_adapter/README_zh.md#%E7%A4%BA%E4%BE%8B). + +## 2. Dynamo Modifications + +### kv_transfer_config + +To integrate with FlexKV, you need to modify the kv_transfer_config in the Dynamo image. Change lines 245-248 in /opt/dynamo/venv/lib/python3.12/site-packages/dynamo/vllm/args.py to: + +```python +kv_transfer_config = KVTransferConfig( + kv_connector="FlexKVConnectorV1", kv_role="kv_both" +) +logger.info("Using FlexKVConnectorV1 configuration") +``` + +### CPU Offloading + +In Dynamo, the KV router updates its KV index by receiving events sent from workers, allowing it to track the KV cache status on each worker. When CPU offloading is enabled in FlexKV, we remove [BlockRemove](https://github.com/vllm-project/vllm/blob/v0.10.1.1/vllm/v1/core/block_pool.py#L221) in vLLM, allowing FlexKV to cache all KV blocks through CPU during the serving process. This ensures that the index maintained by the KV router accurately reflects the actual index in FlexKV. + +## 3. Starting and Verifying Dynamo Services + +### Starting Dynamo + FlexKV + +```bash +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +set -e +trap 'echo Cleaning up...; kill 0' EXIT + +# Start nats and etcd +nats-server -js & + +etcd --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls http://0.0.0.0:2379 --data-dir /tmp/etcd & + +sleep 3 + +# run ingress, set routing mode with --router-mode, options include kv, round-robin, random +python -m dynamo.frontend --router-mode kv --http-port 8000 & + +# Define number of worker nodes +NUM_WORKERS=4 + +# When using multiple workers, ensure FlexKV ports are different to avoid hanging at flexkv init +# Adjust num_cpu_blocks and num_ssd_blocks values according to your server configuration +for i in $(seq 0 $((NUM_WORKERS-1))); do + cat <<EOF > ./flexkv_config_${i}.json +{ + "enable_flexkv": true, + "server_recv_port": "ipc:///tmp/flexkv_${i}_test", + "cache_config": { + "enable_cpu": true, + "enable_ssd": false, + "enable_remote": false, + "use_gds": false, + "enable_trace": false, + "ssd_cache_iouring_entries": 512, + "tokens_per_block": 64, + "num_cpu_blocks": 10240, + "num_ssd_blocks": 256000, + "ssd_cache_dir": "/data/flexkv_ssd/", + "evict_ratio": 0.05, + "index_accel": true + + }, + "num_log_interval_requests": 200 +} +EOF +done + +# Use a loop to start worker nodes +for i in $(seq 0 $((NUM_WORKERS-1))); do + # Calculate GPU device IDs + GPU_START=$((i*2)) + GPU_END=$((i*2+1)) + + if [ $i -lt $((NUM_WORKERS-1)) ]; then + FLEXKV_CONFIG_PATH="./flexkv_config_${i}.json" CUDA_VISIBLE_DEVICES=${GPU_START},${GPU_END} python3 -m dynamo.vllm --model deepseek-ai/DeepSeek-R1-Distill-Llama-70B --tensor_parallel_size 2 --block-size 64 --gpu-memory-utilization 0.9 --max-model-len 100310 & + else + FLEXKV_CONFIG_PATH="./flexkv_config_${i}.json" CUDA_VISIBLE_DEVICES=${GPU_START},${GPU_END} python3 -m dynamo.vllm --model deepseek-ai/DeepSeek-R1-Distill-Llama-70B --tensor_parallel_size 2 --block-size 64 --gpu-memory-utilization 0.9 --max-model-len 100310 + fi +done +``` + +### Verification + +You can verify that the Dynamo service has started correctly with the following command: +```bash +curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{ + "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", + "messages": [ + { + "role": "user", + "content": "Tell me a joke." + } + ], + "stream":false, + "max_tokens": 30 + }' +``` + +## 4. Benchmark + +We use [genai-perf](https://github.com/triton-inference-server/perf_analyzer/tree/main/genai-perf) as our benchmark tool and [mooncake trace](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#-open-source-trace) as our dataset to evaluate the performance of Dynamo + FlexKV. + +Mooncake Trace is an open-source request file saved in jsonl format. It records timestamps of request arrivals, input sequence lengths (ISL), output sequence lengths (OSL), and KV cache-related hash IDs, containing 23,608 requests over a 1-hour period. For our experiment with 4 LLaMA-70B workers, the concurrency in the mooncake trace was too high, so we sampled every 6th request from the trace to build our benchmark dataset. + +genai-perf can send requests according to the timestamps in the trace file and calculate metrics such as TTFT (Time To First Token) and TPOT (Time Per Output Token) for the LLM service. The command is as follows. Please use genai-perf==0.0.13, as newer versions have a bug in timestamp parsing.
+ +```bash +genai-perf profile --model deepseek-ai/DeepSeek-R1-Distill-Llama-70B --tokenizer deepseek-ai/DeepSeek-R1-Distill-Llama-70B --endpoint-type chat --endpoint /v1/chat/completions --streaming --url http://localhost:8000 --input-file payload:mooncake_trace_1_6.jsonl --random-seed 100 -v -H 'Authorization: Bearer NOT USED' -H 'Accept: text/event-stream' -- --stability-percentage 99 +``` \ No newline at end of file diff --git a/docs/dynamo_integration/README_zh.md b/docs/dynamo_integration/README_zh.md new file mode 100644 index 0000000000..3f1d58ce01 --- /dev/null +++ b/docs/dynamo_integration/README_zh.md @@ -0,0 +1,151 @@ +# FlexKV 与 Dynamo 集成指南 + +该文档展示了如何将FlexKV和NVIDIA [Dynamo](https://github.com/ai-dynamo/dynamo) 框架集成,并完成性能测试的步骤。 + +Dynamo是NVIDIA专为大规模分离式部署而设计的框架,支持TensorRT-LLM, vLLM, SGLang等多个后端引擎。其中KV 路由器(KV Router)是一个智能的请求路由组件, 它能够追踪和管理存储在不同worker上的 KV cache,并根据请求与缓存的重叠程度和worker当前负载,智能地将请求分配给最合适的 GPU 节点,从而减少昂贵的 KV 缓存重新计算,提高推理效率。文档也介绍了如何在开启KV Router时,将FlexKV集成进Dynamo。 + +## 1. 环境准备 + +### Dynamo 镜像 + +该文档使用的是后端为vLLM的Dynamo 0.4.1 镜像,内置了vLLM 0.10.1.1。 + +```bash +docker pull nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.4.1 +``` + +### FlexKV代码准备 + +```bash +git clone -b dev https://github.com/taco-project/FlexKV +``` + +### 安装 FlexKV + +```bash +apt update && apt install liburing-dev + +cd FlexKV && ./build.sh +``` + +### vLLM Apply Patch + +```bash +# 进入 FlexKV 目录 +git apply examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +``` + +### FlexKV 验证 + +请参考[vLLM online serving](https://github.com/taco-project/FlexKV/blob/dev/docs/vllm_adapter/README_zh.md#%E7%A4%BA%E4%BE%8B)里的测试脚本。 + + +## 2. Dynamo 配置修改 + +### kv_transfer_config + +为了和FlexKV集成,需要修改Dynamo镜像内的kv_transfer_config。将/opt/dynamo/venv/lib/python3.12/site-packages/dynamo/vllm/args.py 的245-248行修改为: + +```python +kv_transfer_config = KVTransferConfig( + kv_connector="FlexKVConnectorV1", kv_role="kv_both" +) +logger.info("Using FlexKVConnectorV1 configuration") +``` + +### CPU Offloading + +在Dynamo中,KV router通过接收worker发送的event来更新KV index,从而感知每个worker上的KV cache情况。当FlexKV开启CPU offloading时,我们删掉vLLM里[BlockRemove](https://github.com/vllm-project/vllm/blob/v0.10.1.1/vllm/v1/core/block_pool.py#L221),让FlexKV通过CPU能够缓存住所有serving过程中的KV block,这样KV router维护的index就能反映FlexKV的真实index了。 + +## 3. 启动和验证Dynamo服务 + +### 启动Dynamo + FlexKV + +```bash +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +set -e +trap 'echo Cleaning up...; kill 0' EXIT + +# 启动nats和etcd +nats-server -js & + +etcd --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls http://0.0.0.0:2379 --data-dir /tmp/etcd & + +sleep 3 + +# run ingress, 通过--router-mode设置路由方式,可选项为kv, round-robin, random +python -m dynamo.frontend --router-mode kv --http-port 8000 & + +# 定义工作节点数量 +NUM_WORKERS=4 + +# 多个worker时注意FlexKV的端口应不同,否则会卡在flexkv init这一步 +# 请根据服务器的配置,调整num_cpu_blocks和num_ssd_blocks的数值 +for i in $(seq 0 $((NUM_WORKERS-1))); do + cat <<EOF > ./flexkv_config_${i}.json +{ + "enable_flexkv": true, + "server_recv_port": "ipc:///tmp/flexkv_${i}_test", + "cache_config": { + "enable_cpu": true, + "enable_ssd": false, + "enable_remote": false, + "use_gds": false, + "enable_trace": false, + "ssd_cache_iouring_entries": 512, + "tokens_per_block": 64, + "num_cpu_blocks": 10240, + "num_ssd_blocks": 256000, + "ssd_cache_dir": "/data/flexkv_ssd/", + "evict_ratio": 0.05, + "index_accel": true + + }, + "num_log_interval_requests": 200 +} +EOF +done + +# 使用for循环启动工作节点 +for i in $(seq 0 $((NUM_WORKERS-1))); do + # 计算GPU设备ID + GPU_START=$((i*2)) + GPU_END=$((i*2+1)) + + if [ $i -lt $((NUM_WORKERS-1)) ]; then + FLEXKV_CONFIG_PATH="./flexkv_config_${i}.json" CUDA_VISIBLE_DEVICES=${GPU_START},${GPU_END} python3 -m dynamo.vllm --model deepseek-ai/DeepSeek-R1-Distill-Llama-70B --tensor_parallel_size 2 --block-size 64 --gpu-memory-utilization 0.9 --max-model-len 100310 & + else + FLEXKV_CONFIG_PATH="./flexkv_config_${i}.json" CUDA_VISIBLE_DEVICES=${GPU_START},${GPU_END} python3 -m dynamo.vllm --model deepseek-ai/DeepSeek-R1-Distill-Llama-70B --tensor_parallel_size 2 --block-size 64 --gpu-memory-utilization 0.9 --max-model-len 100310 + fi +done +``` + +### 验证 + +可通过如下命令验证Dynamo服务是否正确启动: +```bash +curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{ + "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", + "messages": [ + { + "role": "user", + "content": "Tell me a joke." + } + ], + "stream":false, + "max_tokens": 30 + }' +``` +## 4.
Benchmark + +我们使用[genai-perf](https://github.com/triton-inference-server/perf_analyzer/tree/main/genai-perf)作为benchmark工具、[mooncake trace](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#-open-source-trace)作为数据集来评估Dynamo + FlexKV的性能。 + +Mooncake Trace 是一个开源请求记录文件,以jsonl格式保存。它记录了请求到达的时间戳、输入文本长度、输出文本长度以及与缓存有关的hash id等信息,包含了1小时内的23608个请求。我们的实验资源是4个LLaMA-70B worker,mooncake trace对于该配置来说并发太高了,于是我们从mooncake trace里每6个抽取1个request,构建了用于benchmark的数据集。 + +genai-perf可以根据trace文件里的时间戳来发送请求,统计LLM服务的TTFT、TPOT等指标,命令如下。请使用genai-perf==0.0.13,更新的版本存在解析时间戳的bug。 + +```bash + genai-perf profile --model deepseek-ai/DeepSeek-R1-Distill-Llama-70B --tokenizer deepseek-ai/DeepSeek-R1-Distill-Llama-70B --endpoint-type chat --endpoint /v1/chat/completions --streaming --url http://localhost:8000 --input-file payload:mooncake_trace_1_6.jsonl --random-seed 100 -v -H 'Authorization: Bearer NOT USED' -H 'Accept: text/event-stream' -- --stability-percentage 99 +``` \ No newline at end of file From 4c82b9a3b54738ffdc6f9e3400f2d757d7dd496d Mon Sep 17 00:00:00 2001 From: hsr Date: Mon, 15 Sep 2025 17:39:42 +0800 Subject: [PATCH 41/42] MOD:dynamo doc and main doc --- README.md | 4 ++++ README_zh.md | 4 ++++ docs/dynamo_integration/README_en.md | 8 +++++--- docs/dynamo_integration/README_zh.md | 8 +++++--- 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index ed78dbca43..23a108fc2e 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,10 @@ FlexKV is released under the **Apache-2.0 License**. See the [LICENSE](LICENSE) See [docs/vllm_adapter/README_en.md](docs/vllm_adapter/README_en.md) +### FlexKV Integration with Dynamo + +See [docs/dynamo_integration/README_en.md](docs/dynamo_integration/README_en.md) + ## Design Architecture
diff --git a/README_zh.md b/README_zh.md index 0618a83220..24f522d1f5 100644 --- a/README_zh.md +++ b/README_zh.md @@ -18,6 +18,10 @@ FlexKV 采用 **Apache-2.0 开源协议**,详细信息请参见 [LICENSE](LICE 见[docs/vllm_adapter/README_zh.md](docs/vllm_adapter/README_zh.md) +### FlexKV和Dynamo框架的集成 + +见[docs/dynamo_integration/README_zh.md](docs/dynamo_integration/README_zh.md) + ## 设计框架
diff --git a/docs/dynamo_integration/README_en.md b/docs/dynamo_integration/README_en.md index bbd6d0db5f..1fae4878a0 100644 --- a/docs/dynamo_integration/README_en.md +++ b/docs/dynamo_integration/README_en.md @@ -17,7 +17,7 @@ docker pull nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.4.1 ### FlexKV Code Preparation ```bash -git clone -b dev https://github.com/taco-project/FlexKV +git clone https://github.com/taco-project/FlexKV ``` ### Install FlexKV @@ -31,8 +31,10 @@ cd FlexKV && ./build.sh ### vLLM Apply Patch ```bash -# Navigate to FlexKV directory -git apply examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +# Navigate to vLLM directory +cd /opt/vllm +# apply patch +git apply /your/path/to/FlexKV/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch ``` ### FlexKV Verification diff --git a/docs/dynamo_integration/README_zh.md b/docs/dynamo_integration/README_zh.md index 3f1d58ce01..c33171af69 100644 --- a/docs/dynamo_integration/README_zh.md +++ b/docs/dynamo_integration/README_zh.md @@ -17,7 +17,7 @@ docker pull nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.4.1 ### FlexKV代码准备 ```bash -git clone -b dev https://github.com/taco-project/FlexKV +git clone https://github.com/taco-project/FlexKV ``` ### 安装 FlexKV @@ -31,8 +31,10 @@ cd FlexKV && ./build.sh ### vLLM Apply Patch ```bash -# 进入 FlexKV 目录 -git apply examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +# 进入 vLLM 目录 +cd /opt/vllm +# apply patch +git apply /your/path/to/FlexKV/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch ``` ### FlexKV 验证 From ecc1d5904866b5c692ca06d76f2cd49eba9c9207 Mon Sep 17 00:00:00 2001 From: leolingli Date: Mon, 15 Sep 2025 19:54:42 +0800 Subject: [PATCH 42/42] [doc] add flexkv_config.json introduce --- CONTRIBUTING.md | 2 +- README.md | 8 ++ README_zh.md | 8 ++ docs/dynamo_integration/README_en.md | 4 +- docs/dynamo_integration/README_zh.md | 4 +- docs/flexkv_config_reference/README_en.md | 147 ++++++++++++++++++++++ docs/flexkv_config_reference/README_zh.md | 145 +++++++++++++++++++++ docs/vllm_adapter/README_en.md | 2 + docs/vllm_adapter/README_zh.md | 2 + 9 files changed, 319 insertions(+), 3 deletions(-) create mode 100644 docs/flexkv_config_reference/README_en.md create mode 100644 docs/flexkv_config_reference/README_zh.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 301a6fe36a..a395746aa0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,4 +1,4 @@ -# Contributing to Mooncake +# Contributing to FlexKV Thank you for your interest in contributing to FlexKV! diff --git a/README.md b/README.md index 23a108fc2e..de5bbc5acf 100644 --- a/README.md +++ b/README.md @@ -8,10 +8,18 @@ FlexKV is released under the **Apache-2.0 License**. 
See the [LICENSE](LICENSE) ## How to Use +### Install Dependencies + +```bash +apt install liburing-dev +apt install libxxhash-dev +``` + ### Build FlexKV ```bash ./build.sh +#./build.sh --release for cython package ``` ### Use FlexKV with vLLM diff --git a/README_zh.md b/README_zh.md index 24f522d1f5..1654ff17ae 100644 --- a/README_zh.md +++ b/README_zh.md @@ -8,10 +8,18 @@ FlexKV 采用 **Apache-2.0 开源协议**,详细信息请参见 [LICENSE](LICE ## 如何使用 +### 安装依赖 + +```bash +apt install liburing-dev +apt install libxxhash-dev +``` + ### 编译 FlexKV ```bash ./build.sh +#./build.sh --release for cython package ``` ### 以 vLLM 为例使用 FlexKV diff --git a/docs/dynamo_integration/README_en.md b/docs/dynamo_integration/README_en.md index 1fae4878a0..6f3988e23e 100644 --- a/docs/dynamo_integration/README_en.md +++ b/docs/dynamo_integration/README_en.md @@ -39,7 +39,7 @@ git apply /your/path/to/FlexKV/examples/vllm_adaption/vllm_0_10_1_1-flexkv-conne ### FlexKV Verification -Please refer to the test scripts in [vLLM online serving](https://github.com/taco-project/FlexKV/blob/dev/docs/vllm_adapter/README_zh.md#%E7%A4%BA%E4%BE%8B). +Please refer to the test scripts in [vLLM online serving](../../docs/vllm_adapter/README_zh.md#%E7%A4%BA%E4%BE%8B). ## 2. Dynamo Modifications @@ -123,6 +123,8 @@ for i in $(seq 0 $((NUM_WORKERS-1))); do done ``` +> Note: The `flexkv_config.json` configuration is provided as a simple example only. For full parameter options, please refer to [`docs/flexkv_config_reference/README_en.md`](../../docs/flexkv_config_reference/README_en.md) + ### Verification You can verify that the Dynamo service has started correctly with the following command: diff --git a/docs/dynamo_integration/README_zh.md b/docs/dynamo_integration/README_zh.md index c33171af69..651b0d9aef 100644 --- a/docs/dynamo_integration/README_zh.md +++ b/docs/dynamo_integration/README_zh.md @@ -39,7 +39,7 @@ git apply /your/path/to/FlexKV/examples/vllm_adaption/vllm_0_10_1_1-flexkv-conne ### FlexKV 验证 -请参考[vLLM online serving](https://github.com/taco-project/FlexKV/blob/dev/docs/vllm_adapter/README_zh.md#%E7%A4%BA%E4%BE%8B)里的测试脚本。 +请参考[vLLM online serving](../../docs/vllm_adapter/README_zh.md#%E7%A4%BA%E4%BE%8B)里的测试脚本。 ## 2. Dynamo 配置修改 @@ -124,6 +124,8 @@ for i in $(seq 0 $((NUM_WORKERS-1))); do done ``` +> 注:`flexkv_config.json`配置仅为简单示例,选项请参考[`docs/flexkv_config_reference/README_zh.md`](../../docs/flexkv_config_reference/README_zh.md) + ### 验证 可通过如下命令验证Dynamo服务是否正确启动: diff --git a/docs/flexkv_config_reference/README_en.md b/docs/flexkv_config_reference/README_en.md new file mode 100644 index 0000000000..f91ca77ba9 --- /dev/null +++ b/docs/flexkv_config_reference/README_en.md @@ -0,0 +1,147 @@ +# FlexKV Configuration Guide + +This guide explains how to configure and use the FlexKV online serving configuration file (`flexkv_config.json`), including the meaning of all parameters, recommended values, and typical usage scenarios. 
+ +--- + +## Recommended Configuration + +Below is a production-grade recommended configuration that balances performance and stability: + +```json +{ + "enable_flexkv": true, + "server_recv_port": "ipc:///tmp/flexkv_test", + "cache_config": { + "enable_cpu": true, + "enable_ssd": true, + "enable_remote": false, + "use_gds": false, + "enable_trace": false, + "ssd_cache_iouring_entries": 512, + "tokens_per_block": 64, + "num_cpu_blocks": 233000, + "num_ssd_blocks": 4096000, + "ssd_cache_dir": "/data/flexkv_ssd/", + "evict_ratio": 0.05, + "index_accel": true + }, + "num_log_interval_requests": 2000 +} +``` +- `num_cpu_blocks` and `num_ssd_blocks` represent the total number of blocks in CPU memory and SSD respectively. These values must be configured according to your machine specs and model size. See [Cache Capacity Configuration](#cache-capacity-config) for calculation details. +- `ssd_cache_dir` specifies the directory where SSD-stored KV cache files are saved. + +--- + +## Configuration File Structure Overview + +The FlexKV configuration file is a JSON file, primarily consisting of three parts: + +- `enable_flexkv`: Whether to enable FlexKV (must be set to `true` to take effect). +- `server_recv_port`: The IPC port on which the FlexKV service listens. +- `cache_config`: The core cache configuration object, containing all cache behavior parameters. +- `num_log_interval_requests`: Log statistics interval (outputs performance log every N requests). + +--- + +## Complete `cache_config` Parameter Reference (from [`flexkv/common/config.py`](../../flexkv/common/config.py)) + +### Basic Configuration + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `tokens_per_block` | int | 16 | Number of tokens per KV block. Must match the `block_size` used in the acceleration framework (e.g., vLLM). | +| `enable_cpu` | bool | true | Whether to enable CPU memory as a cache layer. Strongly recommended to enable. | +| `enable_ssd` | bool | false | Whether to enable SSD as a cache layer. Recommended if NVMe SSD is available. | +| `enable_remote` | bool | false | Whether to enable remote cache (e.g., scalable cloud storage). Requires remote cache engine and custom implementation. | +| `use_gds` | bool | false | Whether to use GPU Direct Storage (GDS) to accelerate SSD I/O. Not currently supported. | +| `index_accel` | bool | false | Whether to enable C++ RadixTree. Recommended to enable. | + +--- + +### KV Cache Layout Types (Generally No Need to Modify) + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `gpu_kv_layout_type` | enum | LAYERWISE | Organization of KV cache on GPU (layer-wise or block-wise). Must match vLLM’s layout (currently `LAYERWISE`). | +| `cpu_kv_layout_type` | enum | BLOCKWISE | Organization on CPU. Recommended to use `BLOCKWISE`. Does not need to match vLLM. | +| `ssd_kv_layout_type` | enum | BLOCKWISE | Organization on SSD. Recommended to use `BLOCKWISE`. Does not need to match vLLM. | +| `remote_kv_layout_type` | enum | BLOCKWISE | Organization for remote cache. Must be defined according to remote backend’s layout. | + +> Note: Do not modify layout types unless you have specific performance requirements. + +--- + +### Cache Capacity Configuration + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `num_cpu_blocks` | int | 1000000 | Number of blocks allocated in CPU memory. Adjust based on available RAM. 
| +| `num_ssd_blocks` | int | 10000000 | Number of blocks allocated on SSD. | +| `num_remote_blocks` | int \| None | None | Number of blocks allocated in remote cache. | + +> Note: Block size in all cache levels (CPU/SSD/Remote) matches the GPU block size. Estimate cache capacities based on GPU KV cache memory usage and block count. + +> Note: `block_size = num_layer * _kv_dim * tokens_per_block * num_head * head_size * dtype_size`. + +--- + +### CPU-GPU Transfer Optimization + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `use_ce_transfer_h2d` | bool | false | Whether to use CUDA Copy Engine for Host→Device transfers. Reduces SM usage but may slightly reduce bandwidth. Real-world difference is minimal. | +| `use_ce_transfer_d2h` | bool | false | Whether to use CUDA Copy Engine for Device→Host transfers. | +| `transfer_sms_h2d` | int | 8 | Number of SMs (Streaming Multiprocessors) allocated for H2D transfers. | +| `transfer_sms_d2h` | int | 8 | Number of SMs allocated for D2H transfers. | + +--- + +### SSD Cache Configuration + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `max_blocks_per_file` | int | 32000 | Maximum number of blocks per SSD file. `-1` means unlimited. | +| `ssd_cache_dir` | str \| List[str] | None | **Required.** Path to SSD cache directory, e.g., `"/data/flexkv_ssd/"`. | +| `ssd_cache_iouring_entries` | int | 0 | io_uring queue depth. Recommended: `512` for significantly improved concurrent I/O performance. | +| `ssd_cache_iouring_flags` | int | 0 | io_uring flags. Keep as `0` in most cases. | + +> Note: To maximize bandwidth across multiple SSDs, bind each SSD to a separate directory and specify them as a list: +> `"ssd_cache_dir": ["/data0/flexkv_ssd/", "/data1/flexkv_ssd/"]`. +> KV blocks will be evenly distributed across all SSDs. + +> Note: Setting `ssd_cache_iouring_entries` to `0` disables io_uring. Not recommended. + +--- + +### Remote Cache Configuration (Skip if not enabled) + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `remote_cache_size_mode` | str | "file_size" | Allocate remote cache space by file size or block count. | +| `remote_file_size` | int \| None | None | Size (in bytes) of each remote file. | +| `remote_file_num` | int \| None | None | Number of remote files. | +| `remote_file_prefix` | str \| None | None | Prefix for remote file names. | +| `remote_cache_path` | str \| List[str] | None | Remote cache path (e.g., Redis URL, S3 path). | +| `remote_config_custom` | dict \| None | None | Custom remote cache configurations (e.g., timeout, authentication). | + +--- + +### Tracing and Logging + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `enable_trace` | bool | true | Whether to enable performance tracing. Disable (`false`) in production to reduce overhead. | +| `trace_file_path` | str | "./flexkv_trace.log" | Path to trace log file. | +| `trace_max_file_size_mb` | int | 100 | Maximum size (MB) per trace log file. | +| `trace_max_files` | int | 5 | Maximum number of trace log files to retain. | +| `trace_flush_interval_ms` | int | 1000 | Trace log flush interval (milliseconds). 
| + +--- + +### Cache Eviction Policy + +| Parameter Name | Type | Default | Description | +|----------------|------|---------|-------------| +| `evict_ratio` | float | 0.0 | Ratio of blocks to proactively evict from CPU/SSD per eviction cycle. `0.0` = evict only the minimal necessary blocks (more eviction cycles may impact performance). Recommended: `0.05` (evict 5% of least recently used blocks per cycle). | \ No newline at end of file diff --git a/docs/flexkv_config_reference/README_zh.md b/docs/flexkv_config_reference/README_zh.md new file mode 100644 index 0000000000..1752f844bf --- /dev/null +++ b/docs/flexkv_config_reference/README_zh.md @@ -0,0 +1,145 @@ +# FlexKV 配置使用指南 + +本指南详细说明如何配置和使用 FlexKV 的在线服务配置文件(`flexkv_config.json`),涵盖所有参数的含义、推荐值及典型使用场景。 + +--- + +## 推荐配置方案 + +以下是一个兼顾性能与稳定性的生产级推荐配置: + +```json +{ + "enable_flexkv": true, + "server_recv_port": "ipc:///tmp/flexkv_test", + "cache_config": { + "enable_cpu": true, + "enable_ssd": true, + "enable_remote": false, + "use_gds": false, + "enable_trace": false, + "ssd_cache_iouring_entries": 512, + "tokens_per_block": 64, + "num_cpu_blocks": 233000, + "num_ssd_blocks": 4096000, + "ssd_cache_dir": "/data/flexkv_ssd/", + "evict_ratio": 0.05, + "index_accel": true + }, + "num_log_interval_requests": 2000 +} +``` +- 其中的`num_cpu_blocks`和`num_ssd_blocks`分别代表内存和SSD中block的总数量,需要根据实际机器配置和模型来配置,具体计算方式见下文[缓存容量配置](#cache-capacity-config) +- `ssd_cache_dir`为ssd中KVCache存放的文件目录 + +--- + +## 配置文件结构概览 + +FlexKV 的配置文件是一个 JSON 文件,主要包含三个部分: + +- `enable_flexkv`: 是否启用 FlexKV 功能(必须设为 `true` 才生效) +- `server_recv_port`: FlexKV 服务监听的 IPC 端口 +- `cache_config`: 核心缓存配置对象,包含所有缓存行为参数 +- `num_log_interval_requests`: 日志统计间隔(每处理 N 个请求输出一次性能日志) + +--- + +## cache_config完整参数详解(来自 [`flexkv/common/config.py`](../../flexkv/common/config.py)) + +### 基础配置 + +| 参数名 | 类型 | 默认值 | 说明 | +|--------|------|--------|------| +| `tokens_per_block` | int | 16 | 每个 KV Block 包含的 token 数量。需要与加速框架(如vLLM)中`block_size`保持一致 | +| `enable_cpu` | bool | true | 是否启用 CPU 内存作为缓存层。强烈建议开启。 | +| `enable_ssd` | bool | false | 是否启用 SSD 作为缓存层。如配备 NVMe SSD,建议开启。 | +| `enable_remote` | bool | false | 是否启用远程缓存(如可扩展云存储等)。需要配合远程缓存和自定义的远程缓存引擎使用 | +| `use_gds` | bool | false | 是否使用 GPU Direct Storage(GDS)加速 SSD 读写。目前暂不支持。 | +| `index_accel` | bool | false | 是否启用C++ RadixTree。推荐开启。 | + +--- + +### KV 缓存布局类型(一般无需修改) + +| 参数名 | 类型 | 默认值 | 说明 | +|--------|------|--------|------| +| `gpu_kv_layout_type` | enum | LAYERWISE | GPU 上 KV Cache 的组织方式(按层或按块)。目前vLLM在GPU组织方式为`LAYERWISE`,因此FlexKV的`gpu_kv_layout_type`须与vLLM保持一致 | +| `cpu_kv_layout_type` | enum | BLOCKWISE | CPU 上按块组织, 推荐使用`BLOCKWISE`,不需要与vLLM保持一致 | +| `ssd_kv_layout_type` | enum | BLOCKWISE | SSD 上按块组织, 推荐使用`BLOCKWISE`,不需要与vLLM保持一致 | +| `remote_kv_layout_type` | enum | BLOCKWISE | 远程缓存按块组织, 需要按照remote组织形式定义 | + +> 注:除非有特殊性能需求,否则不建议修改布局类型。 + +--- + +### 缓存容量配置 + +| 参数名 | 类型 | 默认值 | 说明 | +|--------|------|--------|------| +| `num_cpu_blocks` | int | 1000000 | CPU 缓存块数。根据内存大小调整。| +| `num_ssd_blocks` | int | 10000000 | SSD 缓存块数。| +| `num_remote_blocks` | int \| None | None | 远程缓存块数。| + +> 注:FlexKV里的各级缓存的block大小与GPU中的block大小保持一致,可以参考GPU的KVCache显存大小与block数量估算各级缓存中的block数量。 + +> 注:block_size = num_layer * _kv_dim * tokens_per_block * num_head * self.head_size * torch_dtype.size()。 + +--- + +### CPU-GPU 传输优化 + +| 参数名 | 类型 | 默认值 | 说明 | +|--------|------|--------|------| +| `use_ce_transfer_h2d` | bool | false | 是否使用 cuda copy engine 优化 Host→Device 传输,使用CE可以减少GPU SM在传输上的使用,但是传输速度会降低,实际测试差距不大 | +| `use_ce_transfer_d2h` | bool | false | 是否使用 cuda copy engine 优化 
Device→Host 传输 | +| `transfer_sms_h2d` | int | 8 | H2D 传输使用的流处理器数量 | +| `transfer_sms_d2h` | int | 8 | D2H 传输使用的流处理器数量 | + +--- + +### SSD 缓存配置 + +| 参数名 | 类型 | 默认值 | 说明 | +|--------|------|--------|------| +| `max_blocks_per_file` | int | 32000 | 单个 SSD 文件最多包含的 block 数。-1 表示无限制 | +| `ssd_cache_dir` | str \| List[str] | None | SSD 缓存目录路径,**必须设置**,如 `"/data/flexkv_ssd/"` | +| `ssd_cache_iouring_entries` | int | 0 | io_uring 队列深度,推荐设为 `512` 以提升并发 IO 性能,实测比不使用iouring提升较大,推荐使用512 | +| `ssd_cache_iouring_flags` | int | 0 | io_uring 标志位,一般保持 0 | + +> 注:为了充分利用多块SSD的带宽上限,可以将多块SSD绑定至不同目录,并使用如 `"ssd cache dir": ["/data0/flexkv_ssd/", "/data1/flexkv_ssd/"]`方式初始化,SSD KVCache会均匀分布在所有SSD中,充分利用多个SSD带宽。 + +> 注:`ssd_cache_iouring_entries`设置为0即不适用iouring,不推荐设置为0 + +--- + +### 远程缓存配置(不启用时无需配置) + +| 参数名 | 类型 | 默认值 | 说明 | +|--------|------|--------|------| +| `remote_cache_size_mode` | str | "file_size" | 按文件大小或块数分配远程缓存空间 | +| `remote_file_size` | int \| None | None | 单个远程文件大小(字节) | +| `remote_file_num` | int \| None | None | 远程文件数量 | +| `remote_file_prefix` | str \| None | None | 远程文件名前缀 | +| `remote_cache_path` | str \| List[str] | None | 远程缓存路径(如 Redis URL、S3 路径等) | +| `remote_config_custom` | dict \| None | None | 自定义远程缓存配置(如超时、认证等) | + +--- + +### 追踪与日志 + +| 参数名 | 类型 | 默认值 | 说明 | +|--------|------|--------|------| +| `enable_trace` | bool | true | 是否启用性能追踪。生产环境建议关闭(`false`)以减少开销 | +| `trace_file_path` | str | "./flexkv_trace.log" | 追踪日志路径 | +| `trace_max_file_size_mb` | int | 100 | 单个追踪文件最大大小(MB) | +| `trace_max_files` | int | 5 | 最多保留的追踪文件数 | +| `trace_flush_interval_ms` | int | 1000 | 追踪日志刷新间隔(毫秒) | + +--- + +### 缓存淘汰策略 + +| 参数名 | 类型 | 默认值 | 说明 | +|--------|------|--------|------| +| `evict_ratio` | float | 0.0 | cpu,ssd一次evict主动淘汰比例(0.0 = 只淘汰最小的必要的block数量,较多的淘汰次数会影响性能)。建议保持 `0.05`,即每一次淘汰5%的最久未使用的block | diff --git a/docs/vllm_adapter/README_en.md b/docs/vllm_adapter/README_en.md index 781b3ad3ee..972cade803 100644 --- a/docs/vllm_adapter/README_en.md +++ b/docs/vllm_adapter/README_en.md @@ -63,6 +63,8 @@ VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \ ``` +> Note: The `flexkv_config.json` configuration is provided as a simple example only. For full parameter options, please refer to [`docs/flexkv_config_reference/README_en.md`](../../docs/flexkv_config_reference/README_en.md) + ## Legacy Version (<= 0.1.0) – Not Recommended for Current Use ### Supported Versions diff --git a/docs/vllm_adapter/README_zh.md b/docs/vllm_adapter/README_zh.md index 0e7ce7687e..bb9b51c292 100644 --- a/docs/vllm_adapter/README_zh.md +++ b/docs/vllm_adapter/README_zh.md @@ -62,6 +62,8 @@ VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \ ``` +> 注:`flexkv_config.json`配置仅为简单示例,选项请参考[`docs/flexkv_config_reference/README_zh.md`](../../docs/flexkv_config_reference/README_zh.md) + ## Legacy版本(<= 0.1.0),目前的版本尽量不要使用 ### 适用版本