diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000000..6c2eea8d02 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,73 @@ +# This workflow will upload a Python Package to Release asset +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions +# Copied from vLLM github actions https://github.com/vllm-project/vllm/blob/main/.github/workflows/publish.yml +name: flexkv ci + +on: + pull_request: + branches: [ "main", "dev"] + push: + branches: [ "main", "dev"] + +# Needed to create wheel and upload assets +permissions: + contents: write + +jobs: + build: + name: Build Wheel + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: ['ubuntu-22.04'] + python-version: ['3.10'] + pytorch-version: ['2.6.0'] + cuda-version: ['12.4'] + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Linux Env + if: ${{ runner.os == 'Linux' }} + run: | + bash -x .github/workflows/scripts/env.sh + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + + - name: Install CUDA ${{ matrix.cuda-version }} + run: | + bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} + + - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} + run: | + bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} + + - name: Build wheel + shell: bash + env: + TORCH_CUDA_ARCH_LIST: "8.9 9.0+PTX" + MAX_JOBS: 4 + run: | + ./build.sh --release + + - name: Get Date and Time + run: | + echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + echo "time=$(date +'%H-%M-%S')" >> $GITHUB_ENV + + - name: Upload to cos + uses: shallwefootball/s3-upload-action@master + with: + aws_key_id: ${{ secrets.COS_SECRET_ID }} + aws_secret_access_key: ${{ secrets.COS_SECRET_KEY }} + aws_bucket: ${{ secrets.COS_BUCKET }} + endpoint: ${{ secrets.COS_ENDPOINT }} + source_dir: dist + destination_dir: flexkv/${{ env.date }}/${{ env.time }} diff --git a/.github/workflows/scripts/cuda-install.sh b/.github/workflows/scripts/cuda-install.sh new file mode 100755 index 0000000000..3e4d7c8b7d --- /dev/null +++ b/.github/workflows/scripts/cuda-install.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copied from vLLM github actions https://github.com/vllm-project/vllm/blob/main/.github/workflows/scripts/cuda-install.sh + +# Replace '.' with '-' ex: 11.8 -> 11-8 +cuda_version=$(echo "$1" | tr "." "-") +# Removes '-' and '.' 
ex: ubuntu-20.04 -> ubuntu2004 +OS=$(echo "$2" | tr -d ".\-") + +# Installs CUDA +wget -nv "https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb" +sudo dpkg -i cuda-keyring_1.1-1_all.deb +rm cuda-keyring_1.1-1_all.deb +sudo apt -qq update +sudo apt -y install "cuda-${cuda_version}" "cuda-nvcc-${cuda_version}" "cuda-libraries-dev-${cuda_version}" +sudo apt clean + +# Test nvcc +PATH=/usr/local/cuda-$1/bin:${PATH} +nvcc --version + +# Log gcc, g++, c++ versions +gcc --version +g++ --version +c++ --version diff --git a/.github/workflows/scripts/env.sh b/.github/workflows/scripts/env.sh new file mode 100755 index 0000000000..299f281236 --- /dev/null +++ b/.github/workflows/scripts/env.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copied from vLLM github actions https://github.com/vllm-project/vllm/blob/main/.github/workflows/scripts/env.sh + +# This file installs common linux environment tools + +export LANG=C.UTF-8 + +sudo apt-get update && \ +sudo apt-get install -y --no-install-recommends \ + software-properties-common + +sudo apt-get install -y --no-install-recommends \ + build-essential \ + liburing-dev \ + git \ + cmake + +# Remove github bloat files to free up disk space +sudo rm -rf "/usr/local/share/boost" +sudo rm -rf "$AGENT_TOOLSDIRECTORY" +sudo rm -rf "/usr/share/dotnet" diff --git a/.github/workflows/scripts/pytorch-install.sh b/.github/workflows/scripts/pytorch-install.sh new file mode 100755 index 0000000000..559043d412 --- /dev/null +++ b/.github/workflows/scripts/pytorch-install.sh @@ -0,0 +1,16 @@ +#!/bin/bash +# Copied from vLLM github actions https://github.com/vllm-project/vllm/blob/main/.github/workflows/scripts/pytorch-install.sh + +python_executable=python$1 +pytorch_version=$2 +cuda_version=$3 + +# Install torch +$python_executable -m pip install numpy ninja cython wheel typing typing-extensions dataclasses setuptools && conda clean -ya +$python_executable -m pip install torch=="${pytorch_version}+cu${cuda_version//./}" --extra-index-url "https://download.pytorch.org/whl/cu${cuda_version//./}" + +# Print version information +$python_executable --version +$python_executable -c "import torch; print('PyTorch:', torch.__version__)" +$python_executable -c "import torch; print('CUDA:', torch.version.cuda)" +$python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" diff --git a/.gitignore b/.gitignore index fd0f58caa6..03727c4ed6 100644 --- a/.gitignore +++ b/.gitignore @@ -70,3 +70,6 @@ cover/ # mypy .mypy_cache/ + +# VSCode +.vscode/ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000..0fe668086a --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,30 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [1.0.0] - 2025-09-11 + +### Added +- C++ radix tree for fast prefix matching; set "index_accel": true in cache_config to enable it +- synchronous kernel launch +- a major change that moves the cache engine from server-client mode into a library that accelerators (e.g. vLLM) can call directly. + This speeds up get and put when no KVCache is matched. This version includes breaking API changes and is not backward compatible.
+- add evict_ratio, configured via "evict_ratio": 0.05 in cache_config +- reduce the bubble inside the kernel launch +- add vLLM 0.10.1.1 adapter + +### Fixed +- cython release package + + +## [0.1.0] - 2025-08-29 + +### Init +- init version +- add license + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..301a6fe36a --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,13 @@ +# Contributing to FlexKV + +Thank you for your interest in contributing to FlexKV! + +## PR Title and Classification +Use a prefixed PR title to indicate the type of changes. Please use one of the following: + +- `[bugfix]` for bugfixes +- `[feature]` for new features +- `[test]` for test cases +- `[ci/build]` for build or continuous integration improvements +- `[doc]` for documentation fixes +- `[misc]` for PRs that do not fit the above categories. Please use this sparingly. \ No newline at end of file diff --git a/README.md b/README.md index c4388efdf0..ed78dbca43 100644 --- a/README.md +++ b/README.md @@ -14,22 +14,9 @@ FlexKV is released under the **Apache-2.0 License**. See the [LICENSE](LICENSE) ./build.sh ``` -### Use FlexKV with vLLM (v0.8.4) +### Use FlexKV with vLLM -Apply the patch `examples/vllm_adaption/flexkv_vllm_0_8_4.patch` to vLLM 0.8.4, then start FlexKV, vLLM, and the benchmark script: - -```bash -# Start FlexKV as server -bash benchmarks/flexkv_benchmark/run_flexkv_server.sh - -# Start vLLM as client -bash benchmarks/flexkv_benchmark/serving_vllm.sh - -# Start benchmark -bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh -``` - -> **Note**: The current script is only compatible with the `main` branch. Support for the latest features in the `dev` branch is under development. +See [docs/vllm_adapter/README_en.md](docs/vllm_adapter/README_en.md) ## Design Architecture @@ -84,8 +71,10 @@ FlexKV performs: - *put* requests can be called asynchronously; the time to copy data from GPU to CPU memory can overlap with subsequent computation. Data transfers between CPU memory, SSD, and scalable storage are fully handled asynchronously by the TransferEngine and transparent to the main process. ## Branch -- main is the stable branch, maintaining commits that have been tested. -- dev is the development branch, maintaining newer features. +- The main branch is the stable branch, which maintains already tested commits. Please pull from the main branch if you need stable code. +- The dev branch is the development branch, which contains newer features. Please branch from and merge into dev if you need new features or are developing new functionality. +- The bugfix branch is for urgent fixes: bugs that need immediate resolution or documentation that requires prompt updates. If you need to fix a bug or update documentation urgently, please branch from and merge into the bugfix branch. +- The stable branch refers to the previous main branch state, intended only for rollback or extremely conservative use cases (e.g., production deployment). Its use is discouraged.
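For illustration only, a minimal sketch of the contributor workflow implied by the branch policy above; the `origin` remote and the branch names `feature/my-change` / `fix/urgent-issue` are assumptions, not project conventions:

```bash
# Feature work: branch from dev and open the PR against dev (sketch, not project tooling)
git fetch origin
git checkout -b feature/my-change origin/dev
# ...make and commit changes...
git push -u origin feature/my-change

# Urgent bug or doc fix: branch from bugfix and open the PR against bugfix
git checkout -b fix/urgent-issue origin/bugfix
```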
## Roadmap diff --git a/README_zh.md b/README_zh.md index 5ec5c476fb..0618a83220 100644 --- a/README_zh.md +++ b/README_zh.md @@ -16,20 +16,7 @@ FlexKV 采用 **Apache-2.0 开源协议**,详细信息请参见 [LICENSE](LICE ### 以 vLLM 为例使用 FlexKV -在 vLLM 0.8.4 版本中应用patch `examples/vllm_adaption/flexkv_vllm_0_8_4.patch`,分别启动 FlexKV、vLLM 和测试脚本: - -```bash -# 启动 FlexKV 作为服务端 -bash benchmarks/flexkv_benchmark/run_flexkv_server.sh - -# 启动 vLLM 作为客户端 -bash benchmarks/flexkv_benchmark/serving_vllm.sh - -# 启动性能测试 -bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh -``` - -> **注意**:当前脚本仅适配 `main` 分支。`dev` 分支的最新特性支持脚本正在开发中。 +见[docs/vllm_adapter/README_zh.md](docs/vllm_adapter/README_zh.md) ## 设计框架 @@ -84,8 +71,10 @@ FlexKV 在处理 *get* 请求时: - *put*请求可以异步调用,从GPU copy到内存的时间可以与之后的计算重合。内存与SSD以及扩展存储间的传输则完全由TransferEngine之后执行,主进程不感知。 ## Branch -- main 为稳定分支,维护已经测试过的commit。 -- dev 为开发分支,维护较新特性。 +- main 为稳定分支,维护已经测试过的commit。需要稳定的代码请从此分支拉取。 +- dev 为开发分支,维护较新特性。需要新特性和开发新特性请从此分支拉取和合入。 +- bugfix 为bug分支,维护需要立即解决的bug或需要立即更新的文档。需要解决bug和立即更新的文档请从此分支拉取和合入。 +- stable 为上一个版本的main分支位置,仅用于回滚以及极其保守的情况使用(如产品化)。不鼓励使用此版本。 ## Roadmap diff --git a/VERSION b/VERSION new file mode 100644 index 0000000000..3eefcb9dd5 --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +1.0.0 diff --git a/benchmarks/benchmark_kvmanager.py b/benchmarks/benchmark_kvmanager.py deleted file mode 100644 index c1cedb9519..0000000000 --- a/benchmarks/benchmark_kvmanager.py +++ /dev/null @@ -1,273 +0,0 @@ -import os -import tempfile -from multiprocessing import Process -import argparse -import json -import time -from dataclasses import dataclass - -import torch - -from flexkv.server.client import KVDPClient, KVTPClient -from flexkv.server.server import KVServer, SchedulerServer -from flexkv.common.config import ModelConfig, CacheConfig -from flexkv.common.storage import KVCacheLayoutType, KVCacheLayout -from flexkv.common.debug import flexkv_logger -from utils import load_config - -flexkv_logger.set_level("INFO") - - -@dataclass -class BenchmarkConfig: - num_layers_to_transfer: int - batch_size: int - sequence_length: int - cache_ratio: float - -def run_server(model_config, cache_config, server_recv_port): - """Run server process""" - kvserver = KVServer(model_config, cache_config, server_recv_port) - kvserver.run() - time.sleep(10) - -def run_tp_client(dp_client_id, tp_rank, server_recv_port, model_config, cache_config): - """Run tp_client process""" - device_id = tp_rank + dp_client_id * model_config.tp_size - tp_client = KVTPClient(server_recv_port, dp_client_id, device_id, tp_rank) - - num_gpu_blocks = cache_config.num_gpu_blocks - - gpu_kv_layout = KVCacheLayout( - type=cache_config.gpu_kv_layout_type, - num_layer=model_config.num_layers, - num_block=num_gpu_blocks, - tokens_per_block=cache_config.tokens_per_block, - num_head=model_config.num_kv_heads, - head_size=model_config.head_size, - is_mla=model_config.use_mla, - ) - - # Create GPU blocks for this tp_rank in the tp_client process - gpu_blocks_for_tp = [] - for _ in range(model_config.num_layers): - gpu_blocks_for_tp.append( - torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) - ) - tp_client.register_to_server(gpu_blocks_for_tp, gpu_kv_layout) - # Keep the process running - while True: - time.sleep(1) - -def shutdown_tp_client(tp_client_processes): - for tp_process in tp_client_processes: - if tp_process.is_alive(): - tp_process.terminate() - tp_process.join(timeout=5) - if tp_process.is_alive(): - print(f"Force killing tp_client process {tp_process.pid}") - tp_process.kill() - 
tp_process.join(timeout=2) - -class FlexkvWrapper: - def __init__(self, model_config, cache_config, server_recv_port): - self.model_config = model_config - self.cache_config = cache_config - self.server_recv_port = server_recv_port - - self.use_scheduler_server = model_config.dp_size == 1 - if self.use_scheduler_server: - self.launch_scheduler_server() - else: - self.launch_server() - - def launch_server(self): - def server_process(): - kvserver = KVServer(self.model_config, self.cache_config, self.server_recv_port) - kvserver.run() - time.sleep(10) - self.server_process = Process( - target=server_process, - daemon=False - ) - self.server_process.start() - time.sleep(5) - self.dp_client = KVDPClient(self.server_recv_port, self.model_config) - - def launch_scheduler_server(self): - self.scheduler_server = SchedulerServer(self.model_config, self.cache_config, self.server_recv_port) - self.scheduler_server.start_server_thread() - time.sleep(10) - - @property - def dp_client_id(self): - if self.use_scheduler_server: - return 0 - else: - return self.dp_client.dp_client_id - - def put_async(self, token_ids, slot_mapping, token_mask=None): - if self.use_scheduler_server: - return self.scheduler_server.put_async(token_ids, slot_mapping, token_mask) - else: - return self.dp_client.put_async(token_ids, slot_mapping, token_mask) - - def get_async(self, token_ids, slot_mapping, token_mask=None): - if self.use_scheduler_server: - return self.scheduler_server.get_async(token_ids, slot_mapping, token_mask) - else: - return self.dp_client.get_async(token_ids, slot_mapping, token_mask) - - def wait(self, request_ids): - if self.use_scheduler_server: - return self.scheduler_server.wait(request_ids) - else: - return self.dp_client.wait(request_ids) - - def try_wait(self, request_ids): - if self.use_scheduler_server: - return self.scheduler_server.try_wait(request_ids) - else: - return self.dp_client.try_wait(request_ids) - - def check_running(self): - if self.use_scheduler_server: - return self.scheduler_server.check_running() - else: - return self.dp_client.check_running() - - def shutdown(self): - if not self.use_scheduler_server: - try: - # Send a shutdown request to the server - self.dp_client.shutdown() - # Wait a bit for graceful shutdown - time.sleep(3) - except Exception as e: - print(f"Error sending shutdown request: {e}") - if self.server_process.is_alive(): - self.server_process.terminate() - self.server_process.join(timeout=10) - if self.server_process.is_alive(): - print(f"Force killing server process {self.server_process.pid}") - self.server_process.kill() - self.server_process.join(timeout=5) - if self.server_recv_port.startswith('ipc://'): - temp_file = self.server_recv_port[6:] # Remove 'ipc://' prefix - try: - if os.path.exists(temp_file): - os.unlink(temp_file) - except Exception as e: - print(f"Error cleaning up temporary file: {e}") - else: - self.scheduler_server.shutdown() - -def benchmark_kvmanager(model_config, cache_config, benchmark_config, server_recv_port): - if model_config.tp_size * model_config.dp_size > torch.cuda.device_count(): - raise ValueError(f"tp_size {model_config.tp_size} * dp_size {model_config.dp_size} is greater than " - f"the number of available GPUs {torch.cuda.device_count()}") - print(f"{model_config = }") - print(f"{cache_config = }") - print(f"{benchmark_config = }") - flexkv_wrapper = FlexkvWrapper(model_config, cache_config, server_recv_port) - - tp_client_processes = [] - - sequence_length = benchmark_config.sequence_length - batch_size = 
benchmark_config.batch_size - num_required_gpu_blocks = sequence_length * batch_size // cache_config.tokens_per_block - cache_config.num_gpu_blocks = num_required_gpu_blocks - print(f"allocate {num_required_gpu_blocks} gpu blocks for benchmark") - for tp_rank in range(model_config.tp_size): - tp_client_process = Process( - target=run_tp_client, - args=(flexkv_wrapper.dp_client_id, tp_rank, server_recv_port, - model_config, cache_config), - daemon=True - ) - tp_client_process.start() - tp_client_processes.append(tp_client_process) - time.sleep(5) - - batch_sequence_tensor = [] - batch_slot_mapping = [] - cache_length = int(sequence_length * benchmark_config.cache_ratio) - - # generate requests - for i in range(batch_size): - batch_sequence_tensor.append(torch.randint(0, 100000, (sequence_length, ), dtype=torch.int64)) - batch_slot_mapping.append(torch.arange(i * sequence_length, (i+1) * sequence_length, dtype=torch.int64)) - - while not flexkv_wrapper.check_running(): - time.sleep(0.1) - print("waiting for flexkv wrapper to be ready") - # benchmark put - start_time = time.time() - put_ids = [] - if benchmark_config.cache_ratio > 0: - for i in range(batch_size): - put_ids.append(flexkv_wrapper.put_async(batch_sequence_tensor[i][:cache_length], - batch_slot_mapping[i][:cache_length], - token_mask=None)) - put_result = flexkv_wrapper.wait(put_ids) - end_time = time.time() - time.sleep(1) - elapsed_time_put = end_time - start_time - put_tokens = 0 - for _, return_mask in put_result.items(): - put_tokens += return_mask.sum().item() - transfer_data_size_GB = put_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 - transfer_bandwidth_put = transfer_data_size_GB / elapsed_time_put - print(f"put {put_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " - f"time: {elapsed_time_put*1000:.2f}ms, bandwidth: {transfer_bandwidth_put:.2f} GB/s") - - #benchmark get - start_time = time.time() - get_ids = [] - for i in range(batch_size): - get_ids.append(flexkv_wrapper.get_async(batch_sequence_tensor[i], - batch_slot_mapping[i], - token_mask=None)) - get_result = flexkv_wrapper.wait(get_ids) - end_time = time.time() - elapsed_time_get = end_time - start_time - cached_tokens = 0 - all_tokens = 0 - for _, return_mask in get_result.items(): - cached_tokens += return_mask.sum().item() - all_tokens += len(return_mask) - transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 - transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get - print(f"get {cached_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " - f"cache_ratio: {cached_tokens * 100 / all_tokens:.2f}%, " - f"time: {elapsed_time_get*1000:.2f}ms, bandwidth: {transfer_bandwidth_get:.2f} GB/s") - - shutdown_tp_client(tp_client_processes) - flexkv_wrapper.shutdown() - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, default="benchmarks/example_config.json") - # benchmark config - parser.add_argument("--num-layers", type=int, default=-1) - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--sequence-length", type=int, default=1024) - parser.add_argument("--cache-ratio", type=float, default=1) - return parser.parse_args() - -if __name__ == "__main__": - args = parse_args() - benchmark_config = BenchmarkConfig( - num_layers_to_transfer=args.num_layers, - batch_size=args.batch_size, - sequence_length=args.sequence_length, - cache_ratio=args.cache_ratio - ) - model_config, cache_config = 
load_config(args.config) - #cache_config.num_cpu_blocks = 8192 - 2048 - # pad sequence length to divisible by tokens_per_block - benchmark_config.sequence_length = \ - ((benchmark_config.sequence_length - 1) // cache_config.tokens_per_block + 1) * cache_config.tokens_per_block - server_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" - benchmark_kvmanager(model_config, cache_config, benchmark_config, server_recv_port) diff --git a/benchmarks/benchmark_single_batch.py b/benchmarks/benchmark_single_batch.py new file mode 100644 index 0000000000..397030becd --- /dev/null +++ b/benchmarks/benchmark_single_batch.py @@ -0,0 +1,191 @@ +import tempfile +from multiprocessing import Process +import argparse +import time +from dataclasses import dataclass + +import torch + +from flexkv.server.client import KVTPClient +from flexkv.common.storage import KVCacheLayout +from flexkv.common.debug import flexkv_logger +from flexkv.common.config import ModelConfig, CacheConfig +from utils import load_config +from flexkv.kvmanager import KVManager +from flexkv.kvtask import KVResponseStatus + +flexkv_logger.set_level("INFO") + + +@dataclass +class BenchmarkConfig: + num_layers_to_transfer: int + batch_size: int + sequence_length: int + cache_ratio: float + clear_cpu_cache: bool + +def run_tp_client(dp_client_id, tp_rank, server_recv_port, model_config, cache_config): + """Run tp_client process""" + device_id = tp_rank + dp_client_id * model_config.tp_size + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) + + num_gpu_blocks = cache_config.num_gpu_blocks + + gpu_kv_layout = KVCacheLayout( + type=cache_config.gpu_kv_layout_type, + num_layer=model_config.num_layers, + num_block=num_gpu_blocks, + tokens_per_block=cache_config.tokens_per_block, + num_head=model_config.num_kv_heads, + head_size=model_config.head_size, + is_mla=model_config.use_mla, + ) + + # Create GPU blocks for this tp_rank in the tp_client process + gpu_blocks_for_tp = [] + for _ in range(model_config.num_layers): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) + ) + tp_client.register_to_server(gpu_blocks_for_tp, gpu_kv_layout) + # Keep the process running + while True: + time.sleep(1) + +def shutdown_tp_client(tp_client_processes): + for tp_process in tp_client_processes: + if tp_process.is_alive(): + tp_process.terminate() + tp_process.join(timeout=5) + if tp_process.is_alive(): + print(f"Force killing tp_client process {tp_process.pid}") + tp_process.kill() + tp_process.join(timeout=2) + +def benchmark_flexkv(model_config: ModelConfig, + cache_config: CacheConfig, + benchmark_config: BenchmarkConfig, + gpu_register_port: str, + server_recv_port: str): + if model_config.tp_size * model_config.dp_size > torch.cuda.device_count(): + raise ValueError(f"tp_size {model_config.tp_size} * dp_size {model_config.dp_size} is greater than " + f"the number of available GPUs {torch.cuda.device_count()}") + print(f"{benchmark_config = }") + kvmanager = KVManager(model_config, cache_config, gpu_register_port, server_recv_port) + kvmanager.start() + + tp_client_processes = [] + + sequence_length = benchmark_config.sequence_length + batch_size = benchmark_config.batch_size + num_required_gpu_blocks = sequence_length * batch_size // cache_config.tokens_per_block + cache_config.num_gpu_blocks = num_required_gpu_blocks + print(f"allocate {num_required_gpu_blocks} gpu blocks for benchmark") + for tp_rank in range(model_config.tp_size): + 
tp_client_process = Process( + target=run_tp_client, + args=(0, tp_rank, gpu_register_port, + model_config, cache_config), + daemon=True + ) + tp_client_process.start() + tp_client_processes.append(tp_client_process) + + while not kvmanager.is_ready(): + time.sleep(3) + flexkv_logger.info("waiting for flexkv to be ready") + flexkv_logger.info("flexkv is ready") + + batch_sequence_tensor = [] + batch_slot_mapping = [] + cache_length = int(sequence_length * benchmark_config.cache_ratio) + + # generate requests + for i in range(batch_size): + batch_sequence_tensor.append(torch.randint(0, 100000, (sequence_length, ), dtype=torch.int64)) + batch_slot_mapping.append(torch.arange(i * sequence_length, (i+1) * sequence_length, dtype=torch.int64)) + + # benchmark put + start_time = time.time() + batch_put_ids = [] + if benchmark_config.cache_ratio > 0: + for i in range(batch_size): + task_id = kvmanager.put_async(batch_sequence_tensor[i][:cache_length], + batch_slot_mapping[i][:cache_length], + token_mask=None) + batch_put_ids.append(task_id) + put_result = kvmanager.wait(batch_put_ids, completely=True) + end_time = time.time() + + if benchmark_config.clear_cpu_cache: + kvmanager._clear_cpu_cache() + + elapsed_time_put = end_time - start_time + put_tokens = 0 + for _, response in put_result.items(): + if response.status == KVResponseStatus.SUCCESS: + put_tokens += response.return_mask.sum().item() + transfer_data_size_GB = put_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 + transfer_bandwidth_put = transfer_data_size_GB / elapsed_time_put + print(f"put {put_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"time: {elapsed_time_put*1000:.2f}ms, bandwidth: {transfer_bandwidth_put:.2f} GB/s") + + all_tokens = 0 + start_time = time.time() + batch_get_ids = [] + for i in range(batch_size): + all_tokens += len(batch_sequence_tensor[i]) + task_id, _ = kvmanager.get_match(batch_sequence_tensor[i], + token_mask=None) + batch_get_ids.append(task_id) + get_match_time = time.time() - start_time + kvmanager.launch(batch_get_ids, batch_slot_mapping) + get_result = kvmanager.wait(batch_get_ids) + elapsed_time_get = time.time() - start_time + cached_tokens = 0 + for _, response in get_result.items(): + if response.status == KVResponseStatus.SUCCESS: + cached_tokens += response.return_mask.sum().item() + transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 + transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get + print(f"get {cached_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " + f"cache_ratio: {cached_tokens * 100 / all_tokens:.2f}%, " + f"match time: {get_match_time*1000:.2f}ms, " + f"e2e time: {elapsed_time_get*1000:.2f}ms, " + f"bandwidth: {transfer_bandwidth_get:.2f} GB/s") + + shutdown_tp_client(tp_client_processes) + kvmanager.shutdown() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="benchmarks/example_config.json") + # benchmark config + parser.add_argument("--num-layers", type=int, default=-1) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--sequence-length", type=int, default=1024) + parser.add_argument("--cache-ratio", type=float, default=1) + parser.add_argument("--clear-cpu-cache", action="store_true") + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + benchmark_config = BenchmarkConfig( + num_layers_to_transfer=args.num_layers, + batch_size=args.batch_size, + 
sequence_length=args.sequence_length, + cache_ratio=args.cache_ratio, + clear_cpu_cache=args.clear_cpu_cache + ) + model_config, cache_config = load_config(args.config) + #cache_config.num_cpu_blocks = 8192 - 2048 + # pad sequence length to divisible by tokens_per_block + benchmark_config.sequence_length = \ + ((benchmark_config.sequence_length - 1) // cache_config.tokens_per_block + 1) * cache_config.tokens_per_block + import uuid + gpu_register_port = f"ipc:///tmp/flexkv_gpu_{uuid.uuid4().hex[:8]}" + server_recv_port = f"ipc:///tmp/flexkv_srv_{uuid.uuid4().hex[:8]}" + + benchmark_flexkv(model_config, cache_config, benchmark_config, gpu_register_port, server_recv_port) diff --git a/benchmarks/benchmark_workers.py b/benchmarks/benchmark_workers.py index f9c50f8bc7..8c1c2182e3 100644 --- a/benchmarks/benchmark_workers.py +++ b/benchmarks/benchmark_workers.py @@ -9,7 +9,7 @@ import torch -from flexkv.common.transfer import TransferOp, TransferType, TransferDescriptor +from flexkv.common.transfer import TransferOp, TransferType from flexkv.transfer.worker import GPUCPUTransferWorker, CPUSSDDiskTransferWorker, WorkerHandle, tpGPUCPUTransferWorker from flexkv.storage.allocator import CPUAllocator, GPUAllocator, SSDAllocator from flexkv.common.storage import KVCacheLayoutType, KVCacheLayout @@ -17,7 +17,7 @@ from flexkv.common.debug import flexkv_logger -flexkv_logger.set_level("OFF") +# flexkv_logger.set_level("OFF") @dataclass class BenchmarkConfig: @@ -50,11 +50,10 @@ def make_configs(args: dict) -> Tuple[ModelConfig, CacheConfig, BenchmarkConfig] def create_cpu_gpu_worker( model_config: ModelConfig, - cache_config: CacheConfig, - bench_config: BenchmarkConfig) -> Tuple[WorkerHandle, mp.Queue]: + cache_config: CacheConfig) -> Tuple[WorkerHandle, mp.Queue]: mp.set_start_method('spawn', force=True) cpu_layout = KVCacheLayout( - type=KVCacheLayoutType.LAYERWISE, + type=KVCacheLayoutType(cache_config.cpu_kv_layout_type), num_layer=model_config.num_layers, num_block=cache_config.num_cpu_blocks, tokens_per_block=cache_config.tokens_per_block, @@ -85,7 +84,6 @@ def create_cpu_gpu_worker( finished_ops_queue = mp.Queue() if model_config.tp_size == 1: worker_handle = GPUCPUTransferWorker.create_worker( - worker_id=0, finished_ops_queue=finished_ops_queue, gpu_blocks=gpu_handles[0].get_tensor_handle_list(), cpu_blocks=cpu_handle.get_tensor(), @@ -100,7 +98,6 @@ def create_cpu_gpu_worker( ) else: worker_handle = tpGPUCPUTransferWorker.create_worker( - worker_id=0, finished_ops_queue=finished_ops_queue, gpu_blocks=[handle.get_tensor_handle_list() for handle in gpu_handles], cpu_blocks=cpu_handle.get_tensor(), @@ -121,11 +118,10 @@ def create_cpu_gpu_worker( def create_cpu_ssd_worker( model_config: ModelConfig, - cache_config: CacheConfig, - bench_config: BenchmarkConfig) -> Tuple[WorkerHandle, mp.Queue]: + cache_config: CacheConfig) -> Tuple[WorkerHandle, mp.Queue]: mp.set_start_method('spawn', force=True) cpu_layout = KVCacheLayout( - type=KVCacheLayoutType.LAYERWISE, + type=KVCacheLayoutType(cache_config.cpu_kv_layout_type), num_layer=model_config.num_layers, num_block=cache_config.num_cpu_blocks, tokens_per_block=cache_config.tokens_per_block, @@ -133,7 +129,7 @@ def create_cpu_ssd_worker( head_size=model_config.head_size, ) ssd_layout = KVCacheLayout( - type=KVCacheLayoutType.LAYERWISE, + type=KVCacheLayoutType(cache_config.ssd_kv_layout_type), num_layer=model_config.num_layers, num_block=cache_config.num_ssd_blocks, tokens_per_block=cache_config.tokens_per_block, @@ -153,7 +149,6 @@ def 
create_cpu_ssd_worker( ) finished_ops_queue = mp.Queue() worker_handle = CPUSSDDiskTransferWorker.create_worker( - worker_id=10, finished_ops_queue=finished_ops_queue, cpu_blocks=cpu_handle.get_tensor(), ssd_files=ssd_handle.get_file_list(), @@ -195,29 +190,25 @@ def bench_worker(args): shuffle_ids = bench_config.shuffle_ids if transfer_type == TransferType.H2D or transfer_type == TransferType.D2H: - worker_handle, finished_ops_queue = create_cpu_gpu_worker(model_config, cache_config, bench_config) + worker_handle, finished_ops_queue = create_cpu_gpu_worker(model_config, cache_config) elif transfer_type == TransferType.H2DISK or transfer_type == TransferType.DISK2H: - worker_handle, finished_ops_queue = create_cpu_ssd_worker(model_config, cache_config, bench_config) + worker_handle, finished_ops_queue = create_cpu_ssd_worker(model_config, cache_config) else: raise ValueError(f"Unsupported transfer type: {transfer_type} for benchmark, " f"currently only support {TransferType.H2D.name}, {TransferType.D2H.name}, " f"{TransferType.H2DISK.name}, {TransferType.DISK2H.name}") if shuffle_ids: - block_ids = torch.randperm(num_blocks_to_transfer) + block_ids = torch.randperm(num_blocks_to_transfer).numpy() else: - block_ids = torch.arange(num_blocks_to_transfer) + block_ids = torch.arange(num_blocks_to_transfer).numpy() transfer_op = TransferOp( transfer_type=transfer_type, layer_id=0, layer_granularity=num_layers_to_transfer, - src_descriptor=TransferDescriptor( - physical_block_ids=block_ids, - ), - dst_descriptor=TransferDescriptor( - physical_block_ids=block_ids, - ), + src_block_ids=block_ids, + dst_block_ids=block_ids, graph_id=0, dp_id=0, successors=[], @@ -226,8 +217,8 @@ def bench_worker(args): if transfer_type == TransferType.DISK2H: tmp_op = copy.deepcopy(transfer_op) tmp_op.transfer_type = TransferType.H2DISK - tmp_op.src_descriptor = transfer_op.dst_descriptor - tmp_op.dst_descriptor = transfer_op.src_descriptor + tmp_op.src_block_ids = transfer_op.dst_block_ids + tmp_op.dst_block_ids = transfer_op.src_block_ids launch_transfer(worker_handle, finished_ops_queue, tmp_op) for _ in range(warmup_round): launch_transfer(worker_handle, finished_ops_queue, transfer_op) diff --git a/benchmarks/example_config.json b/benchmarks/example_config.json index 630bdded0b..d4854557c3 100644 --- a/benchmarks/example_config.json +++ b/benchmarks/example_config.json @@ -14,11 +14,10 @@ "enable_remote": false, "tokens_per_block": 16, "use_gds": false, - "use_pinned_memory": true, "gpu_kv_layout_type": "LAYERWISE", - "cpu_kv_layout_type": "LAYERWISE", - "ssd_kv_layout_type": "LAYERWISE", - "remote_kv_layout_type": "LAYERWISE", + "cpu_kv_layout_type": "BLOCKWISE", + "ssd_kv_layout_type": "BLOCKWISE", + "remote_kv_layout_type": "BLOCKWISE", "num_cpu_blocks": 2048, "num_ssd_blocks": 4096, "num_remote_blocks": null, @@ -28,7 +27,7 @@ "transfer_sms_d2h": 8, "max_blocks_per_file": 32000, "ssd_cache_dir": "./ssd_cache1/", - "ssd_cache_iouring_entries": 512, + "ssd_cache_iouring_entries": 32, "ssd_cache_iouring_flags": 0, "remote_cache_size_mode": "file_size", "remote_file_size": null, @@ -40,6 +39,8 @@ "trace_file_path": "./flexkv_trace.log", "trace_max_file_size_mb": 100, "trace_max_files": 5, - "trace_flush_interval_ms": 1000 + "trace_flush_interval_ms": 1000, + "evict_ratio": 0.05, + "index_accel": true } } diff --git a/build.sh b/build.sh index 8b34f27df4..b976fec7d9 100755 --- a/build.sh +++ b/build.sh @@ -40,13 +40,6 @@ echo "=== Setting BUILD_LIB_PATH to $BUILD_LIB_PATH ===" cd .. 
-echo "=== Installing package with pip ===" -if [ "$BUILD_TYPE" = "debug" ]; then - FLEXKV_DEBUG=1 pip install --no-build-isolation -e . -else - FLEXKV_DEBUG=0 pip install --no-build-isolation -e . -fi - # Set LD_LIBRARY_PATH for immediate use export LD_LIBRARY_PATH=$BUILD_LIB_PATH:$LD_LIBRARY_PATH echo "Added $BUILD_LIB_PATH to LD_LIBRARY_PATH for current session" @@ -69,3 +62,11 @@ fi echo "=== Build and installation completed successfully in ${BUILD_TYPE} mode ===" echo "You can now run tests directly without setting LD_LIBRARY_PATH manually" + +if [ "$BUILD_TYPE" = "debug" ]; then + FLEXKV_DEBUG=1 pip install -v --no-build-isolation -e . +elif [ "$BUILD_TYPE" = "release" ]; then + FLEXKV_DEBUG=0 python setup.py bdist_wheel -v +else + FLEXKV_DEBUG=0 pip install -v --no-build-isolation -e . +fi diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index a991954be0..c03ad74c40 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -19,6 +19,7 @@ #include "tp_transfer_thread_group.h" #include "transfer.cuh" #include "transfer_ssd.h" +#include "radix_tree.h" namespace py = pybind11; @@ -143,12 +144,10 @@ PYBIND11_MODULE(c_ext, m) { py::class_(m, "TPTransferThreadGroup") .def(py::init> &, - torch::Tensor &, int>()) + torch::Tensor &, int, torch::Tensor &, torch::Tensor &, torch::Tensor &>()) .def("tp_group_transfer", &flexkv::TPTransferThreadGroup::tp_group_transfer, - py::arg("gpu_block_id_tensor"), py::arg("gpu_kv_stride_in_bytes"), - py::arg("gpu_block_stride_in_bytes"), - py::arg("gpu_chunk_size_in_bytes"), py::arg("cpu_block_id_tensor"), + py::arg("gpu_block_id_tensor"), py::arg("cpu_block_id_tensor"), py::arg("cpu_kv_stride_in_bytes"), py::arg("cpu_layer_stride_in_bytes"), py::arg("cpu_block_stride_in_bytes"), @@ -195,4 +194,36 @@ PYBIND11_MODULE(c_ext, m) { "Call Pcfs::write from C++", py::arg("file_nodeid"), py::arg("offset"), py::arg("buffer"), py::arg("size"), py::arg("thread_id")); #endif + + py::class_(m, "CRadixTreeIndex") + .def(py::init()) + .def("is_empty", &flexkv::CRadixTreeIndex::is_empty) + .def("reset", &flexkv::CRadixTreeIndex::reset) + .def("lock", &flexkv::CRadixTreeIndex::lock, py::arg("node")) + .def("unlock", &flexkv::CRadixTreeIndex::unlock, py::arg("node")) + .def("set_ready", &flexkv::CRadixTreeIndex::set_ready, + py::arg("node"), py::arg("ready"), py::arg("ready_length")) + .def("insert", &flexkv::CRadixTreeIndex::insert, py::return_value_policy::reference, + py::arg("physical_block_ids"), py::arg("block_hashes"), py::arg("num_blocks"), + py::arg("num_insert_blocks"), py::arg("ready") = true, py::arg("node") = nullptr, + py::arg("num_matched_blocks") = -1, py::arg("last_node_matched_length") = -1) + .def("evict", &flexkv::CRadixTreeIndex::evict, py::arg("evicted_blocks"), py::arg("num_evicted")) + .def("total_cached_blocks", &flexkv::CRadixTreeIndex::total_cached_blocks) + .def("total_unready_blocks", &flexkv::CRadixTreeIndex::total_unready_blocks) + .def("total_ready_blocks", &flexkv::CRadixTreeIndex::total_ready_blocks) + .def("match_prefix", &flexkv::CRadixTreeIndex::match_prefix, + py::arg("block_hashes"), py::arg("num_blocks"), py::arg("update_cache_info")); + + py::class_(m, "CRadixNode") + .def(py::init()) + .def("size", &flexkv::CRadixNode::size); + + py::class_>(m, "CMatchResult") + .def(py::init *>()) + .def_readonly("last_ready_node", &flexkv::CMatchResult::last_ready_node) + .def_readonly("last_node", &flexkv::CMatchResult::last_node) + .def_readonly("physical_blocks", &flexkv::CMatchResult::physical_blocks) + 
.def_readonly("num_ready_matched_blocks", &flexkv::CMatchResult::num_ready_matched_blocks) + .def_readonly("num_matched_blocks", &flexkv::CMatchResult::num_matched_blocks) + .def_readonly("last_node_matched_length", &flexkv::CMatchResult::last_node_matched_length); } diff --git a/csrc/radix_tree.cpp b/csrc/radix_tree.cpp new file mode 100644 index 0000000000..32dc8ab0e7 --- /dev/null +++ b/csrc/radix_tree.cpp @@ -0,0 +1,268 @@ +#include +#include +#include +#include +#include +#include + +#include "cache_utils.h" +#include "radix_tree.h" + +namespace flexkv { + +CRadixNode::CRadixNode(CRadixTreeIndex *index, bool ready, int lock_cnt) { + assert(index != nullptr); + + this->on_leaf = false; + this->parent = nullptr; + this->index = index; + this->ready = ready; + this->lock_cnt = lock_cnt; + + struct timeval now; + gettimeofday(&now, nullptr); + last_access_time = now.tv_sec * 1000 + now.tv_usec / 10000; + + index->inc_node_count(); +} + +CRadixNode::~CRadixNode() { + assert(parent == nullptr); + + block_hashes.clear(); + physical_blocks.clear(); + children.clear(); + + index->dec_node_count(); +} + +CRadixNode *CRadixNode::split(int prefix_length) { + assert(prefix_length < size()); + assert(prefix_length > 0); + assert(parent != nullptr); + + auto new_node = new CRadixNode(index, is_ready(), 0); + new_node->set_time(get_time()); + new_node->set_parent(parent); + get_index()->add_node(new_node); + + auto &new_block_hashes = new_node->get_block_hashes(); + auto &new_physical_blocks = new_node->get_physical_blocks(); + + new_block_hashes.insert(new_block_hashes.end(), block_hashes.cbegin(), block_hashes.cbegin() + prefix_length); + new_physical_blocks.insert(new_physical_blocks.end(), physical_blocks.cbegin(), physical_blocks.cbegin() + prefix_length); + + block_hashes.erase(block_hashes.begin(), block_hashes.begin() + prefix_length); + physical_blocks.erase(physical_blocks.begin(), physical_blocks.begin() + prefix_length); + + parent->set_child(new_node->get_head_hash(), new_node); + new_node->set_parent(parent); + new_node->set_child(get_head_hash(), this); + + set_parent(new_node); + return new_node; +} + +void CRadixNode::merge_child() { + auto child = children.begin()->second; + + assert(get_num_children() == 1); + assert(child->is_leaf()); + + block_hashes.insert(block_hashes.end(), child->get_block_hashes().cbegin(), + child->get_block_hashes().cend()); + physical_blocks.insert(physical_blocks.end(), child->get_physical_blocks().cbegin(), + child->get_physical_blocks().cend()); + + set_time(std::max(get_time(), child->get_time())); + children.clear(); + + child->clear_parent(); + index->remove_leaf(child); + index->remove_node(child); +} + +std::deque *CRadixNode::shrink(int length) { + assert(length < size()); + assert(length > 0); + assert(is_leaf()); + assert(in_use() == false); + + auto remaining_length = size() - length; + auto shrink_blocks = new std::deque(); + + shrink_blocks->insert(shrink_blocks->end(), physical_blocks.begin() + remaining_length, physical_blocks.end()); + + block_hashes.erase(block_hashes.begin() + remaining_length, block_hashes.end()); + physical_blocks.erase(physical_blocks.begin() + remaining_length, physical_blocks.end()); + + return shrink_blocks; +} + +CRadixNode *CRadixTreeIndex::insert(torch::Tensor &physical_block_ids, + torch::Tensor &block_hashes, int num_blocks, int num_insert_blocks, bool ready, + CRadixNode *last_node, int num_matched_blocks, int last_node_matched_length) { + if (num_insert_blocks == -1) { + num_insert_blocks = num_blocks; + } 
+ assert(num_insert_blocks >= 0); + assert(num_insert_blocks <= num_blocks); + assert(physical_block_ids.ndim() == 1); + + if (last_node == nullptr) { + auto match_result = match_prefix(block_hashes, num_blocks, true); + num_matched_blocks = match_result->num_matched_blocks; + last_node_matched_length = match_result->last_node_matched_length; + last_node = match_result->last_node; + } + + assert(last_node != nullptr); + assert(last_node_matched_length != 0 || is_root(last_node)); + assert(physical_block_ids.size() == num_insert_blocks - num_matched_blocks); + + if (num_matched_blocks >= num_insert_blocks) { + return nullptr; + } + + auto new_node = new CRadixNode(this, ready, 0); + auto &new_block_hashes = new_node->get_block_hashes(); + auto &new_physical_blocks = new_node->get_physical_blocks(); + + auto block_hashes_ptr = block_hashes.data_ptr(); + auto physical_block_ids_ptr = physical_block_ids.data_ptr(); + for (auto i = 0; i + num_matched_blocks < num_insert_blocks; i++) { + new_block_hashes.insert(new_block_hashes.end(), block_hashes_ptr[i+num_matched_blocks]); + new_physical_blocks.insert(new_physical_blocks.end(), physical_block_ids_ptr[i]); + } + + if (last_node_matched_length < last_node->size()) { + last_node->split(last_node_matched_length); + last_node = last_node->get_parent(); + assert(last_node != nullptr); + } + + if (last_node->is_leaf()) { + remove_leaf(last_node); + } + + new_node->set_parent(last_node); + last_node->set_child(new_node->get_head_hash(), new_node); + + add_node(new_node); + add_leaf(new_node); + return new_node; +} + +int CRadixTreeIndex::evict(torch::Tensor &evicted_blocks, int num_evicted) { + int64_t *evicted_blocks_ptr = evicted_blocks.data_ptr(); + int has_evicted = 0; + std::priority_queue, CRadixNode::Compare> candidate; + + for (auto it = leaf_list.begin(); it != leaf_list.end(); it++) { + if ((*it)->evictable()) { + candidate.push(*it); + } + } + + while ((has_evicted < num_evicted) && candidate.size()) { + auto node = candidate.top(); + candidate.pop(); + + if (node->size() > num_evicted - has_evicted) { + auto blocks = node->shrink(num_evicted - has_evicted); + for (auto it = blocks->begin(); it != blocks->end(); it++) { + evicted_blocks_ptr[has_evicted] = *it; + has_evicted++; + } + delete blocks; + } else { + auto parent = node->get_parent(); + auto &blocks = node->get_physical_blocks(); + + assert(parent != nullptr); + parent->remove_child(node->get_head_hash()); + + for (auto it = blocks.begin(); it != blocks.end(); it++) { + evicted_blocks_ptr[has_evicted] = *it; + has_evicted++; + } + + if (parent->is_leaf() && !is_root(parent)) { + add_leaf(parent); + if (parent->evictable()) { + candidate.push(parent); + } + } + + node->clear_parent(); + remove_leaf(node); + remove_node(node); + } + } + return has_evicted; +} + +std::shared_ptr CRadixTreeIndex::match_prefix( + torch::Tensor &block_hashes, int num_blocks, bool update_cache_info) { + auto current_node = root; + auto last_ready_node = root; + auto prefix_blocks_num = 0; + auto ready_prefix_blocks_num = 0; + auto last_node_matched_length = 0; + auto physical_blocks = new std::vector(); + auto block_hashes_ptr = block_hashes.data_ptr(); + HashType child_hash; + + while (prefix_blocks_num < num_blocks) { + if (update_cache_info) { + current_node->update_time(); + } + + child_hash = HashType(block_hashes_ptr[prefix_blocks_num + current_node->size()]); + if (current_node->lookup_child(child_hash)) { + if (current_node->is_ready()) { + last_ready_node = current_node; + ready_prefix_blocks_num 
+= current_node->size(); + } + prefix_blocks_num += current_node->size(); + physical_blocks->insert(physical_blocks->end(), current_node->get_physical_blocks().begin(), + current_node->get_physical_blocks().end()); + current_node = current_node->get_child(child_hash); + } else { + auto matched_length = 0; + if (is_root(current_node) == false) { + auto cmp_length = std::min(current_node->size(), num_blocks - prefix_blocks_num); + auto left = 0; + auto right = cmp_length; + + while (left < right) { + auto mid = (left + right) / 2; + if (current_node->get_hash(mid) == HashType(block_hashes_ptr[prefix_blocks_num+mid])) { + left = mid + 1; + } else { + right = mid; + } + } + matched_length = left; + physical_blocks->insert(physical_blocks->end(), current_node->get_physical_blocks().begin(), + current_node->get_physical_blocks().begin() + matched_length); + } else { + matched_length = 0; + } + + if (current_node->is_ready()) { + last_ready_node = current_node; + ready_prefix_blocks_num += matched_length; + } + + last_node_matched_length = matched_length; + prefix_blocks_num += matched_length; + break; + } + } + + return std::make_shared(prefix_blocks_num, ready_prefix_blocks_num, last_node_matched_length, + last_ready_node, current_node, physical_blocks); +} + +} // namespace flexkv diff --git a/csrc/radix_tree.h b/csrc/radix_tree.h new file mode 100644 index 0000000000..63560a3a8d --- /dev/null +++ b/csrc/radix_tree.h @@ -0,0 +1,346 @@ +#pragma once +#include +#include +#include +#include +#include + +#include "cache_utils.h" + +namespace flexkv { + +class CRadixTreeIndex; + +class CRadixNode { +private: + bool on_leaf; + bool ready; + int lock_cnt; + time_t last_access_time; + + std::deque block_hashes; + std::deque physical_blocks; + std::map children; + + CRadixTreeIndex *index; + CRadixNode *parent; + +public: + CRadixNode(CRadixTreeIndex *index, bool ready, int lock_cnt); + ~CRadixNode(); + + struct Compare { + bool operator() (CRadixNode *a, CRadixNode *b) { + return a->get_time() > b->get_time(); + } + }; + + bool get_leaf_state() { + return on_leaf; + } + + void set_leaf_state(bool on_leaf) { + this->on_leaf = on_leaf; + } + + CRadixTreeIndex *get_index() { + return index; + } + + void set_time(time_t time) { + last_access_time = time; + } + + time_t get_time() { + return last_access_time; + } + + void update_time() { + struct timeval now; + + gettimeofday(&now, nullptr); + last_access_time = now.tv_sec * 1000 + now.tv_usec / 10000; + } + + CRadixNode *get_parent() { + return parent; + } + + void set_parent(CRadixNode *parent) { + this->parent = parent; + } + + void clear_parent() { + this->parent = nullptr; + } + + HashType get_hash(int pos) { + return HashType(block_hashes[pos]); + } + + HashType get_head_hash() { + if (size() > 0) { + return HashType(block_hashes[0]); + } else { + return HashType(0); + } + } + + int size() { + return block_hashes.size(); + } + + int get_num_children() { + return children.size(); + } + + std::deque &get_block_hashes() { + return block_hashes; + } + + std::deque &get_physical_blocks() { + return physical_blocks; + } + + bool lookup_child(HashType hash) { + auto iter = children.find(hash); + if (iter != children.end()) + return true; + else + return false; + } + + CRadixNode *get_child(HashType hash) { + return children.at(hash); + } + + void set_child(HashType hash, CRadixNode *node) { + children[hash] = node; + } + + void remove_child(HashType hash) { + children.erase(hash); + } + + bool is_leaf() { + return get_num_children() == 0; + } + + bool 
in_use() { + return lock_cnt > 0 || !ready; + } + + bool evictable() { + return is_leaf() && !in_use(); + } + + void lock() { + assert(lock_cnt >= 0); + lock_cnt++; + } + + void unlock() { + assert(lock_cnt > 0); + lock_cnt--; + } + + void set_ready(bool ready) { + this->ready = ready; + } + + bool is_ready() { + return ready; + } + + CRadixNode *split(int prefix_length); + std::deque *shrink(int length); + void merge_child(); +}; + +class CMatchResult { +public: + int num_ready_matched_blocks; + int num_matched_blocks; + int last_node_matched_length; + + CRadixNode *last_ready_node; + CRadixNode *last_node; + std::vector *physical_blocks; + + CMatchResult(int _num_ready_matched_blocks, int _num_matched_blocks, int _last_node_matched_length, + CRadixNode *_last_ready_node, CRadixNode *_last_node, std::vector *blocks) + : num_ready_matched_blocks(_num_ready_matched_blocks), num_matched_blocks(_num_matched_blocks), + last_node_matched_length(_last_node_matched_length), last_ready_node(_last_ready_node), + last_node(_last_node), physical_blocks(blocks) { + } + + ~CMatchResult() { + delete physical_blocks; + }; +}; + +class CRadixTreeIndex { +private: + CRadixNode *root; + std::list node_list; + std::list leaf_list; + + int max_num_blocks; + int tokens_per_block; + int node_count; + +public: + CRadixTreeIndex(int tokens_per_block, int max_num_blocks = 1000000) { + this->tokens_per_block = tokens_per_block; + this->max_num_blocks = max_num_blocks; + this->node_count = 0; + + root = new CRadixNode(this, true, 0); + node_list.push_back(root); + } + + ~CRadixTreeIndex() { + leaf_list.clear(); + + while (node_list.size()) { + auto node = node_list.front(); + node->set_parent(nullptr); + node_list.pop_front(); + delete node; + } + + if (node_count) { + std::cerr << "CRadix Node count" << node_count << std::endl; + } + } + + void reset() { + leaf_list.clear(); + + while (node_list.size()) { + auto node = node_list.front(); + node->set_parent(nullptr); + node_list.pop_front(); + delete node; + } + + root = new CRadixNode(this, true, 0); + node_list.push_back(root); + } + + bool is_root(CRadixNode *node) { + return node == root; + } + + CRadixNode *get_root() { + return root; + } + + void remove_node(CRadixNode *node) { + assert(node != root); + assert(node->get_parent() == nullptr); + + node_list.remove(node); + delete node; + } + + void remove_leaf(CRadixNode *node) { + assert(node != root); + assert(node->get_leaf_state()); + + if (node->get_leaf_state() == false) { + return; + } + + leaf_list.remove(node); + node->set_leaf_state(false); + } + + void add_node(CRadixNode *node) { + assert(node != nullptr); + assert(node->get_parent() != nullptr); + node_list.push_back(node); + } + + void add_leaf(CRadixNode *node) { + assert(node != nullptr); + assert(node->get_leaf_state() == false); + + if (node->get_leaf_state() == true) { + return; + } + + leaf_list.push_back(node); + node->set_leaf_state(true); + } + + void lock(CRadixNode *node) { + node->lock(); + } + + void unlock(CRadixNode *node) { + node->unlock(); + } + + bool is_empty() { + return node_list.size() == 1; + } + + void inc_node_count() { + node_count++; + } + + void dec_node_count() { + node_count--; + } + + void set_ready(CRadixNode *node, bool ready = true, int ready_length = -1) { + node->set_ready(ready); + if (ready_length > 0) { + ready_length -= node->size(); + while (ready_length > 0) { + assert(node->get_parent() != nullptr); + node = node->get_parent(); + ready_length -= node->size(); + node->set_ready(true); + } + 
assert(ready_length == 0); + } + } + + int total_node_num() { + return node_list.size() - 1; + } + + int total_cached_blocks() { + auto total_blocks = 0; + + for (auto it = node_list.begin(); it != node_list.end(); it++) { + total_blocks += (*it)->size(); + } + return total_blocks; + } + + int total_ready_blocks() { + auto total_blocks = 0; + for (auto it = node_list.begin(); it != node_list.end(); it++) { + if ((*it)->is_ready()) { + total_blocks += (*it)->size(); + } + } + return total_blocks; + } + + int total_unready_blocks() { + return total_cached_blocks() - total_ready_blocks(); + } + + int evict(torch::Tensor &evicted_blocks, int num_evicted); + std::shared_ptr match_prefix(torch::Tensor &block_hashes, + int num_blocks, bool update_cache_info = true); + CRadixNode *insert(torch::Tensor &physical_block_ids, torch::Tensor &block_hashes, int num_blocks, + int num_insert_blocks, bool ready = true, CRadixNode *node = nullptr, int num_matched_blocks = -1, + int last_node_matched_length = -1); +}; + +} // namespace flexkv diff --git a/csrc/tp_transfer_thread_group.cpp b/csrc/tp_transfer_thread_group.cpp index 617ac8abd2..06cb45c4fe 100644 --- a/csrc/tp_transfer_thread_group.cpp +++ b/csrc/tp_transfer_thread_group.cpp @@ -22,8 +22,31 @@ namespace flexkv { TPTransferThreadGroup::TPTransferThreadGroup( int num_gpus, const std::vector> &gpu_blocks, - torch::Tensor &cpu_blocks, int dp_group_id) { + torch::Tensor &cpu_blocks, int dp_group_id, + torch::Tensor &gpu_kv_strides_tensor, + torch::Tensor &gpu_block_strides_tensor, + torch::Tensor &gpu_chunk_sizes_tensor) { + num_gpus_ = num_gpus; + + gpu_kv_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_block_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_chunk_sizes_in_bytes_ = new int64_t[num_gpus]; + + int64_t* kv_strides_ptr = gpu_kv_strides_tensor.data_ptr(); + int64_t* block_strides_ptr = gpu_block_strides_tensor.data_ptr(); + int64_t* chunk_sizes_ptr = gpu_chunk_sizes_tensor.data_ptr(); + + for (int i = 0; i < num_gpus; i++) { + gpu_kv_strides_in_bytes_[i] = kv_strides_ptr[i]; + gpu_block_strides_in_bytes_[i] = block_strides_ptr[i]; + gpu_chunk_sizes_in_bytes_[i] = chunk_sizes_ptr[i]; + } + + queues_.resize(num_gpus_); + mtxs_ = std::vector(num_gpus_); + cvs_ = std::vector(num_gpus_); + int num_layers = gpu_blocks[0].size(); cudaMallocHost((void **)&gpu_blocks_, num_gpus_ * num_layers * sizeof(void *)); @@ -41,15 +64,55 @@ TPTransferThreadGroup::TPTransferThreadGroup( cudaSetDevice(dp_group_id * num_gpus_ + i); cudaStreamCreate(&streams_[i]); } + // create the thread pool + stop_pool_=false; + for (int i = 0; i < num_gpus_; ++i) { + threads_.emplace_back([this, i]() { + int device_id = dp_group_id_ * num_gpus_ + i; + cudaSetDevice(device_id); // only once + + while (true) { + Task task; + { + std::unique_lock lk(mtxs_[i]); + cvs_[i].wait(lk, [&]{ return stop_pool_ || !queues_[i].empty(); }); + if (stop_pool_ && queues_[i].empty()) return; + + task = std::move(queues_[i].front()); + queues_[i].pop(); + } + task(); // + } + }); + } + +} + +TPTransferThreadGroup::~TPTransferThreadGroup() { + stop_pool_ = true; + for (auto& cv : cvs_) cv.notify_all(); + for (auto& t : threads_) if (t.joinable()) t.join(); + + cudaFreeHost(gpu_blocks_); + + delete[] gpu_kv_strides_in_bytes_; + delete[] gpu_block_strides_in_bytes_; + delete[] gpu_chunk_sizes_in_bytes_; } -TPTransferThreadGroup::~TPTransferThreadGroup() {} +std::future TPTransferThreadGroup::enqueue_for_gpu(int gpu_idx, Task task) { + auto pkg = std::make_shared>(std::move(task)); + auto fut = 
pkg->get_future(); + { + std::lock_guard lk(mtxs_[gpu_idx]); + queues_[gpu_idx].emplace([pkg]{ (*pkg)(); }); + } + cvs_[gpu_idx].notify_one(); + return fut; +} void TPTransferThreadGroup::tp_group_transfer( const torch::Tensor &gpu_block_id_tensor, - const int64_t gpu_kv_stride_in_bytes, - const int64_t gpu_block_stride_in_bytes, - const int64_t gpu_chunk_size_in_bytes, const torch::Tensor &cpu_block_id_tensor, const int64_t cpu_kv_stride_in_bytes, const int64_t cpu_layer_stride_in_bytes, @@ -60,11 +123,15 @@ void TPTransferThreadGroup::tp_group_transfer( std::atomic failed{false}; std::string error_msg; - threads_.clear(); - threads_.reserve(num_gpus_); + // threads_.clear(); + // threads_.reserve(num_gpus_); - for (int i = 0; i < num_gpus_; ++i) { - threads_.emplace_back([&, i]() { + // Barrier sync_point(num_gpus_); + std::vector> futures; + futures.reserve(num_gpus_); + + for (int i=0; i(gpu_blocks_ + i * num_layers + layer_id); void *cpu_ptr = cpu_blocks_; int64_t cpu_startoff_inside_chunks = - is_mla ? 0 : i * gpu_chunk_size_in_bytes; - cudaSetDevice(dp_group_id_ * num_gpus_ + i); + is_mla ? 0 : i * gpu_chunk_sizes_in_bytes_[i]; + flexkv::transfer_kv_blocks( - num_blocks, layer_id, layer_granularity, gpu_block_ids, - gpu_layer_ptrs, gpu_kv_stride_in_bytes, gpu_block_stride_in_bytes, - cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, - cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, - cpu_startoff_inside_chunks, gpu_chunk_size_in_bytes, streams_[i], - transfer_sms, is_host_to_device, use_ce_transfer, is_mla); + num_blocks, layer_id, layer_granularity, gpu_block_ids, + gpu_layer_ptrs, gpu_kv_strides_in_bytes_[i], gpu_block_strides_in_bytes_[i], + cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, + cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, + cpu_startoff_inside_chunks, gpu_chunk_sizes_in_bytes_[i], streams_[i], + transfer_sms, is_host_to_device, use_ce_transfer, is_mla + ); + cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { failed = true; @@ -95,12 +164,12 @@ void TPTransferThreadGroup::tp_group_transfer( failed = true; error_msg = e.what(); } - }); + + })); } - for (auto &t : threads_) { - if (t.joinable()) - t.join(); + for (auto &f : futures){ + f.get(); } if (failed) { diff --git a/csrc/tp_transfer_thread_group.h b/csrc/tp_transfer_thread_group.h index 315c36c541..3d57e569c5 100644 --- a/csrc/tp_transfer_thread_group.h +++ b/csrc/tp_transfer_thread_group.h @@ -24,20 +24,22 @@ #include #include #include - +#include +#include +#include +#include namespace flexkv { - class TPTransferThreadGroup { public: TPTransferThreadGroup( int num_gpus, const std::vector> &gpu_blocks, - torch::Tensor &cpu_blocks, int dp_group_id); + torch::Tensor &cpu_blocks, int dp_group_id, + torch::Tensor &gpu_kv_strides_tensor, + torch::Tensor &gpu_block_strides_tensor, + torch::Tensor &gpu_chunk_sizes_tensor); ~TPTransferThreadGroup(); void tp_group_transfer(const torch::Tensor &gpu_block_id_tensor, - const int64_t gpu_kv_stride_in_bytes, - const int64_t gpu_block_stride_in_bytes, - const int64_t gpu_chunk_size_in_bytes, const torch::Tensor &cpu_block_id_tensor, const int64_t cpu_kv_stride_in_bytes, const int64_t cpu_layer_stride_in_bytes, @@ -48,12 +50,25 @@ class TPTransferThreadGroup { const int layer_granularity, const bool is_mla); private: + using Task = std::function; + std::future enqueue_for_gpu(int gpu_idx, Task task); + int num_gpus_; int dp_group_id_; void **gpu_blocks_; void *cpu_blocks_; + + int64_t *gpu_kv_strides_in_bytes_; + int64_t *gpu_block_strides_in_bytes_; + 
+  int64_t *gpu_chunk_sizes_in_bytes_;
+
  std::vector<std::thread> threads_;
  std::vector<cudaStream_t> streams_;
+
+  std::vector<std::queue<Task>> queues_;
+  std::vector<std::mutex> mtxs_;
+  std::vector<std::condition_variable> cvs_;
+  std::atomic<bool> stop_pool_;
 };
 } // namespace flexkv
diff --git a/csrc/transfer_ssd.h b/csrc/transfer_ssd.h
index f795f4de54..9d6f9ba57e 100644
--- a/csrc/transfer_ssd.h
+++ b/csrc/transfer_ssd.h
@@ -173,11 +173,10 @@ class IOUring {
 class SSDIOCTX {
 public:
-  SSDIOCTX(std::map<int, std::vector<std::string>>& ssd_files,
-           int num_devices, int iouring_entries, int iouring_flags)
-      : iouring(iouring_entries, iouring_flags),
-        fds_buffer_io(num_devices),
-        fds_direct_io(num_devices) {
+  SSDIOCTX(std::map<int, std::vector<std::string>> &ssd_files, int num_devices,
+           int iouring_entries, int iouring_flags)
+      : iouring(iouring_entries, iouring_flags), fds_buffer_io(num_devices),
+        fds_direct_io(num_devices) {
    int i, j, fd_buffer_io, fd_direct_io;
@@ -190,7 +189,8 @@ class SSDIOCTX {
        fd_direct_io = open(ssd_files[i][j].c_str(), O_RDWR | O_DIRECT);
        if (fd_buffer_io < 0 || fd_direct_io < 0) {
-          std::cerr << "open file failed, path = " << ssd_files[i][j] << std::endl;
+          std::cerr << "open file failed, path = " << ssd_files[i][j]
+                    << std::endl;
          throw std::runtime_error("Failed to open file");
        } else {
          posix_fadvise(fd_buffer_io, 0, 0, POSIX_FADV_SEQUENTIAL);
@@ -221,20 +221,14 @@ class SSDIOCTX {
    }
  }
-  int get_num_devices() {
-    return num_devices;
-  }
+  int get_num_devices() { return num_devices; }
-  int get_num_files_per_device() {
-    return num_files_per_device;
-  }
+  int get_num_files_per_device() { return num_files_per_device; }
-  IOUring &get_iouring() {
-    return iouring;
-  }
+  IOUring &get_iouring() { return iouring; }
  std::vector<std::vector<int>> &get_fds(bool is_read, bool is_direct) {
-    if (is_read && is_direct) {
+    if (is_direct) {
      return fds_direct_io;
    } else {
      return fds_buffer_io;
@@ -250,15 +244,13 @@ class SSDIOCTX {
  std::vector<std::vector<int>> fds_direct_io;
 };
-
 void transfer_kv_blocks_ssd(
-    SSDIOCTX &ioctx,
-    const torch::Tensor &cpu_layer_id_list, int64_t cpu_tensor_ptr,
-    const torch::Tensor &ssd_block_ids, const torch::Tensor &cpu_block_ids,
-    int64_t cpu_layer_stride_in_bytes, int64_t cpu_kv_stride_in_bytes,
-    int64_t ssd_layer_stride_in_bytes, int64_t ssd_kv_stride_in_bytes,
-    int64_t chunk_size_in_bytes, int64_t block_stride_in_bytes, bool is_read,
-    int num_blocks_per_file, int round_robin = 1,
-    int num_threads_per_device = 16, bool is_mla = false);
+    SSDIOCTX &ioctx, const torch::Tensor &cpu_layer_id_list,
+    int64_t cpu_tensor_ptr, const torch::Tensor &ssd_block_ids,
+    const torch::Tensor &cpu_block_ids, int64_t cpu_layer_stride_in_bytes,
+    int64_t cpu_kv_stride_in_bytes, int64_t ssd_layer_stride_in_bytes,
+    int64_t ssd_kv_stride_in_bytes, int64_t chunk_size_in_bytes,
+    int64_t block_stride_in_bytes, bool is_read, int num_blocks_per_file,
+    int round_robin = 1, int num_threads_per_device = 16, bool is_mla = false);
 } // namespace flexkv
diff --git a/docs/vllm_adapter/README_en.md b/docs/vllm_adapter/README_en.md
new file mode 100644
index 0000000000..781b3ad3ee
--- /dev/null
+++ b/docs/vllm_adapter/README_en.md
@@ -0,0 +1,84 @@
+# Using FlexKV in vLLM
+
+## Current Version vs. Legacy Version
+In commit [`0290841dce65ae9b036a23d733cf94e47e814934`](https://github.com/taco-project/FlexKV/commit/0290841dce65ae9b036a23d733cf94e47e814934), we introduced a major update:
+**FlexKV has transitioned from a client-server architecture to a library function that inference acceleration engines (such as vLLM) can directly invoke**, reducing inter-process communication overhead.
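For orientation, here is a minimal offline sketch of this library mode, condensed from `examples/offline_inference/prefix_caching_flexkv.py` in this patch set; the model path and block count are placeholder values, and the connector loads FlexKV in-process instead of talking to a separate server.

```python
import json
import os

from vllm import LLM, SamplingParams

# Write a FlexKV config and point the connector at it (values are illustrative).
flexkv_config = {
    "server_recv_port": "ipc:///tmp/flexkv_test",
    "cache_config": {"enable_cpu": True, "num_cpu_blocks": 10240},
    "num_log_interval_requests": 200,
}
with open("./flexkv_config.json", "w") as f:
    json.dump(flexkv_config, f)
os.environ["FLEXKV_CONFIG_PATH"] = "./flexkv_config.json"

# FlexKV is pulled in as a library by the FlexKVConnectorV1 kv-transfer connector;
# no separate FlexKV server process needs to be started.
llm = LLM(
    model="Qwen3/Qwen3-32B",  # placeholder model path
    enable_prefix_caching=True,
    kv_transfer_config={"kv_connector": "FlexKVConnectorV1", "kv_role": "kv_both"},
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0.0))
print(outputs[0].outputs[0].text)
```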
+ +This change involves significant API adjustments. Therefore, please note: + +- **Version >= `1.0.0`**: Use the **current version API**; the vLLM patch is located in `examples/vllm_adaption/`. +- **Version == `0.1.0`**: Supports the **legacy version API**; the vLLM patch is located in `examples/vllm_adaption_legacy/`. + +--- + +## Current Version (>= 1.0.0) + +### Supported Versions +- FlexKV >= `1.0.0` +- vLLM versions >= `0.8.5` can generally follow this version for adaptation + +### Example +We provide an adaptation example based on **vLLM 0.10.1.1**: + +1. apply patch +```bash +# FLEXKV_DIR/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +git apply examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +``` + +2. offline test +```bash +# VLLM_DIR/examples/offline_inference/prefix_caching_flexkv.py +python examples/offline_inference/prefix_caching_flexkv.py +``` + +3. online serving +```bash +# generate config +cat < ./flexkv_config.json +{ + "server_recv_port": "ipc:///tmp/flexkv_test", + "cache_config": { + "enable_cpu": true, + "num_cpu_blocks": 10240, + }, + "num_log_interval_requests": 200 +} +EOF +export FLEXKV_CONFIG_PATH="./flexkv_config.json" + +VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \ + --tensor-parallel-size 8 \ + --trust-remote-code \ + --port 30001 \ + --max-num-seqs 128 \ + --max-num-batched-tokens 8192 \ + --max_model_len 8192 \ + --max-seq-len-to-capture 8192 \ + --gpu-memory-utilization 0.8 \ + --enable-chunked-prefill \ + --enable-prefix-caching \ + --kv-transfer-config \ + '{"kv_connector":"FlexKVConnectorV1","kv_role":"kv_both"}' + +``` + +## Legacy Version (<= 0.1.0) – Not Recommended for Current Use + +### Supported Versions +- FlexKV <= `0.1.0` + +### Example +Apply the patch `examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch` to vLLM 0.8.4, then start FlexKV, vLLM, and the benchmark script: + +```bash +# Start FlexKV as server +bash benchmarks/flexkv_benchmark/run_flexkv_server.sh + +# Start vLLM as client +bash benchmarks/flexkv_benchmark/serving_vllm.sh + +# Start benchmark +bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh +``` +Apply the patch `examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch` to vLLM 0.10.0, and use the same testing method as above. diff --git a/docs/vllm_adapter/README_zh.md b/docs/vllm_adapter/README_zh.md new file mode 100644 index 0000000000..0e7ce7687e --- /dev/null +++ b/docs/vllm_adapter/README_zh.md @@ -0,0 +1,83 @@ +# 在 vLLM 中使用 FlexKV + +## 当前版本与 Legacy 版本说明 +在 commit [`0290841dce65ae9b036a23d733cf94e47e814934`](https://github.com/taco-project/FlexKV/commit/0290841dce65ae9b036a23d733cf94e47e814934),我们更新了一个重要功能: + **FlexKV 从 client-server 模式,变为推理加速引擎(如 vLLM)可直接调用的库函数**,以减少进程间消息传递的开销。 +这一变更引发了较大的 API 调整。因此,请注意: + +- **版本 >= `1.0.0`**:应使用 **当前版本 API**,vLLM patch位于 `examples/vllm_adaption/`。 +- **版本 == `0.1.0`**:仅支持 **Legacy 版本 API**, vLLM patch位于`examples/vllm_adaption_legacy/`。 + +--- + +## 当前版本(>= 1.0.0) + +### 适用版本 +- FlexKV >= `1.0.0` +- vLLM 原则上>= `0.8.5`版本均可参考示例代码进行修改 + +### 示例 +我们提供了基于 **vLLM 0.10.1.1** 的适配示例: + +1. apply patch +```bash +# FLEXKV_DIR/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +git apply examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch +``` + +2. offline test +```bash +# VLLM_DIR/examples/offline_inference/prefix_caching_flexkv.py +python examples/offline_inference/prefix_caching_flexkv.py +``` + +3. 
online serving +```bash +# generate config +cat < ./flexkv_config.json +{ + "server_recv_port": "ipc:///tmp/flexkv_test", + "cache_config": { + "enable_cpu": true, + "num_cpu_blocks": 10240, + }, + "num_log_interval_requests": 200 +} +EOF +export FLEXKV_CONFIG_PATH="./flexkv_config.json" + +VLLM_USE_V1=1 python -m vllm.entrypoints.cli.main serve Qwen3/Qwen3-32B \ + --tensor-parallel-size 8 \ + --trust-remote-code \ + --port 30001 \ + --max-num-seqs 128 \ + --max-num-batched-tokens 8192 \ + --max_model_len 8192 \ + --max-seq-len-to-capture 8192 \ + --gpu-memory-utilization 0.8 \ + --enable-chunked-prefill \ + --enable-prefix-caching \ + --kv-transfer-config \ + '{"kv_connector":"FlexKVConnectorV1","kv_role":"kv_both"}' + +``` + +## Legacy版本(<= 0.1.0),目前的版本尽量不要使用 + +### 适用版本 +- FlexKV <= `0.1.0` + +### 示例 +在 vLLM 0.8.4 版本中应用patch `examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch`,分别启动 FlexKV、vLLM 和测试脚本: + +```bash +# 启动 FlexKV 作为服务端 +bash benchmarks/flexkv_benchmark/run_flexkv_server.sh + +# 启动 vLLM 作为客户端 +bash benchmarks/flexkv_benchmark/serving_vllm.sh + +# 启动性能测试 +bash benchmarks/flexkv_benchmark/multiturn_benchmark.sh +``` +在 vLLM 0.10.0 版本中应用patch `examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch`,测试方法同上。 diff --git a/examples/run_server.py b/examples/run_server.py index d5b6a182ec..48b24ecad1 100644 --- a/examples/run_server.py +++ b/examples/run_server.py @@ -12,16 +12,16 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() - + # NAME - parser.add_argument("--enable-cpu", - action=argparse.BooleanOptionalAction, + parser.add_argument("--enable-cpu", + action=argparse.BooleanOptionalAction, default=True) - parser.add_argument("--enable-ssd", - action=argparse.BooleanOptionalAction, + parser.add_argument("--enable-ssd", + action=argparse.BooleanOptionalAction, default=False,) - parser.add_argument("--enable-remote", - action=argparse.BooleanOptionalAction, + parser.add_argument("--enable-remote", + action=argparse.BooleanOptionalAction, default=False,) parser.add_argument("--model-path", type=str, help="model path", default="") parser.add_argument("--tp-size", type=int, default=1) @@ -54,7 +54,7 @@ def parse_args() -> argparse.Namespace: if __name__ == "__main__": args = parse_args() hf_config = AutoConfig.from_pretrained(args.model_path) - + num_layers=hf_config.num_hidden_layers if hasattr(hf_config, 'num_key_value_heads'): num_kv_heads=hf_config.num_key_value_heads @@ -65,7 +65,7 @@ def parse_args() -> argparse.Namespace: head_size=(hf_config.head_dim if hasattr(hf_config, 'head_dim') else hf_config.hidden_size//hf_config.num_attention_heads) use_mla=hf_config.architectures[0].startswith("Deepseek") - + # TODO: different model config may have different attribute name model_config = ModelConfig( num_layers=num_layers, @@ -76,14 +76,13 @@ def parse_args() -> argparse.Namespace: dp_size=args.dp_size, dtype=hf_config.torch_dtype ) - + cache_config = CacheConfig( enable_cpu=args.enable_cpu, enable_ssd=args.enable_ssd, enable_remote=args.enable_remote, use_gds=False, enable_trace=False, - use_pinned_memory=False, ssd_cache_iouring_entries=512, tokens_per_block=args.block_size, num_cpu_blocks=args.num_cpu_blocks, @@ -93,6 +92,6 @@ def parse_args() -> argparse.Namespace: remote_cache_size_mode=args.remote_cache_size_mode, remote_cache_path=args.remote_cache_path, ) - + kvserver = KVServer(model_config, cache_config, args.server_recv_port) - kvserver.run() \ No newline at end of file + kvserver.run() diff --git a/examples/scheduler_server_example.py 
b/examples/scheduler_server_example.py index 1aae7ec298..059cc467aa 100644 --- a/examples/scheduler_server_example.py +++ b/examples/scheduler_server_example.py @@ -16,9 +16,9 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, model_config, gpu_kv_layout): """Run TP client process""" from flexkv.server.client import KVTPClient - + print(f"Starting TP client: dp_client_id={dp_client_id}, tp_rank={tp_rank}, device_id={device_id}") - + try: # Set CUDA device for this process if torch.cuda.is_available(): @@ -27,8 +27,8 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, mo torch.cuda.init() # Clear cache torch.cuda.empty_cache() - - tp_client = KVTPClient(server_recv_port, dp_client_id, device_id, tp_rank) + + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) # Create GPU blocks for this TP client gpu_blocks = [] @@ -51,7 +51,7 @@ def run_tp_client_process(dp_client_id, tp_rank, device_id, server_recv_port, mo # Keep TP client running while True: time.sleep(1) - + except Exception as e: print(f"TP client {tp_rank} error: {e}") import traceback @@ -84,7 +84,6 @@ def main(): enable_ssd=False, enable_remote=False, use_gds=False, - use_pinned_memory=True, tokens_per_block=tokens_per_block, num_cpu_blocks=num_cpu_blocks, ) @@ -106,14 +105,14 @@ def main(): cache_config=cache_config, server_recv_port="ipc:///tmp/scheduler_server_example" # TPClient connects to this port ) - + # Start background server thread to handle TPClient registration scheduler_server.start_server_thread() - - print(f"SchedulerServer started!") + + print("SchedulerServer started!") print(f"TPClient can connect to: {scheduler_server.get_server_port()}") print("Starting TP client processes...") - + # Start TP client processes tp_client_processes = [] for tp_rank in range(tp_size): @@ -123,7 +122,7 @@ def main(): if device_id >= available_gpus: device_id = device_id % available_gpus print(f"Warning: Using GPU {device_id} for TP rank {tp_rank} (not enough GPUs)") - + tp_client_process = Process( target=run_tp_client_process, args=(0, tp_rank, device_id, scheduler_server.get_server_port(), model_config, gpu_kv_layout), @@ -134,32 +133,32 @@ def main(): print(f"Started TP client process for rank {tp_rank} on device {device_id}") print("Waiting for all TP clients to register...") - + time.sleep(5) - + # Now we can directly use scheduler_server without network communication # Example: Create some test data (following benchmark_kvmanager.py pattern) batch_size = 4 seq_len = 128 - + print("\n=== Generating test data ===") # Generate separate sequences for each request (correct approach) batch_token_ids = [] batch_slot_mappings = [] batch_token_masks = [] - + for i in range(batch_size): # Each sequence is independent (seq_len,) shape token_ids = torch.randint(0, 1000, (seq_len,)) slot_mapping = torch.arange(i * seq_len, (i + 1) * seq_len) token_mask = torch.ones(seq_len, dtype=torch.bool) - + batch_token_ids.append(token_ids) batch_slot_mappings.append(slot_mapping) batch_token_masks.append(token_mask) - + print(f"Generated {batch_size} sequences, each with {seq_len} tokens") - + print("\n=== Executing PUT Operations ===") # PUT operations - each sequence processed separately start_time = time.time() @@ -173,7 +172,7 @@ def main(): if task_id: put_task_ids.append(task_id) print(f"PUT task {task_id} created for sequence {i}") - + put_time = (time.time() - start_time) * 1000 print(f"Created {len(put_task_ids)} PUT tasks, time: {put_time:.2f}ms") time.sleep(2) @@ 
-190,10 +189,10 @@ def main(): if task_id: get_task_ids.append(task_id) print(f"GET task {task_id} created for sequence {i}") - + get_time = (time.time() - start_time) * 1000 print(f"Created {len(get_task_ids)} GET tasks, time: {get_time:.2f}ms") - + print("\n=== Waiting for All Tasks to Complete ===") # Wait for all tasks to complete - can wait for multiple tasks at once all_task_ids = put_task_ids + get_task_ids @@ -202,7 +201,7 @@ def main(): masks = scheduler_server.wait(all_task_ids) wait_time = (time.time() - start_time) * 1000 print(f"All {len(all_task_ids)} tasks completed, time: {wait_time:.2f}ms") - + # Analyze results if masks: total_tokens = 0 @@ -211,7 +210,7 @@ def main(): tokens = mask.sum().item() if hasattr(mask, 'sum') else len(mask) total_tokens += tokens print(f"Task {task_id}: {tokens} tokens processed") - + print("\n=== Trying Non-blocking Wait ===") # Create a few more tasks and try non-blocking wait extra_task_ids = [] @@ -223,7 +222,7 @@ def main(): ) if task_id: extra_task_ids.append(task_id) - + if extra_task_ids: # Immediately try to wait (might not be completed yet) masks = scheduler_server.try_wait(extra_task_ids) @@ -233,15 +232,15 @@ def main(): print(f"Tasks {extra_task_ids} not ready yet, will wait...") masks = scheduler_server.wait(extra_task_ids) print(f"Tasks {extra_task_ids} completed after wait") - + print("\n✅ All operations completed successfully!") - - + + # Clean up resources print("\n=== Shutting down SchedulerServer ===") scheduler_server.shutdown() print("SchedulerServer has been shut down") - + # Terminate TP client processes print("Terminating TP client processes...") for i, process in enumerate(tp_client_processes): @@ -253,4 +252,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch b/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch new file mode 100644 index 0000000000..812a1d6e2f --- /dev/null +++ b/examples/vllm_adaption/vllm_0_10_1_1-flexkv-connector.patch @@ -0,0 +1,453 @@ +diff --git a/examples/offline_inference/prefix_caching_flexkv.py b/examples/offline_inference/prefix_caching_flexkv.py +new file mode 100644 +index 000000000..a57328ffd +--- /dev/null ++++ b/examples/offline_inference/prefix_caching_flexkv.py +@@ -0,0 +1,162 @@ ++# SPDX-License-Identifier: Apache-2.0 ++import os ++import time ++import json ++ ++from vllm import LLM, SamplingParams ++from vllm.distributed import cleanup_dist_env_and_memory ++ ++# NOTE: This is just a running example. For benchmarking purpose, ++# please see benchmarks/benchmark_prefix_caching.py ++ ++ ++flexkv_config = { ++ "server_recv_port": "ipc:///tmp/flexkv_test", ++ "cache_config": { ++ "enable_cpu": True, ++ "num_cpu_blocks": 10240, ++ }, ++ "num_log_interval_requests": 200 ++} ++flexkv_config_path = "./flexkv_config.json" ++with open(flexkv_config_path, 'w') as f: ++ json.dump(flexkv_config, f) ++os.environ["FLEXKV_CONFIG_PATH"] = flexkv_config_path ++ ++ ++# Common prefix. ++prefix = ( ++ "You are an expert school principal, skilled in effectively managing " ++ "faculty and staff. Draft 10-15 questions for a potential first grade " ++ "Head Teacher for my K-12, all-girls', independent school that emphasizes " ++ "community, joyful discovery, and life-long learning. The candidate is " ++ "coming in for a first-round panel interview for a 8th grade Math " ++ "teaching role. 
They have 5 years of previous teaching experience " ++ "as an assistant teacher at a co-ed, public school with experience " ++ "in middle school math teaching. Based on these information, fulfill " ++ "the following paragraph: ") ++ ++# Sample prompts. ++prompts = [ ++ "Hello, my name is", ++ "The president of the United States is", ++ "The capital of France is", ++ "The future of AI is", ++] ++ ++generating_prompts = [prefix + prompt for prompt in prompts] ++ ++# Create a sampling params object. ++sampling_params = SamplingParams(temperature=0.0) ++ ++kv_transfer_config = { ++ "kv_connector": "FlexKVConnectorV1", ++ "kv_role": "kv_both", ++} ++# model_path = "/data0/models/facebook/opt-125m" ++model_path = "/data0/models/Qwen3/Qwen3-32B" ++tp_size = 8 ++gpu_memory_utilization = 0.4 ++ ++ ++ ++def main(): ++ # Create an LLM without prefix caching as a baseline. ++ regular_llm = LLM(model=model_path, ++ enable_prefix_caching=False, ++ gpu_memory_utilization=gpu_memory_utilization, ++ tensor_parallel_size=tp_size ++ ) ++ ++ print("Results without `enable_prefix_caching`") ++ ++ # ruff: noqa: E501 ++ # Generate texts from the prompts. The output is a list of RequestOutput objects ++ # that contain the prompt, generated text, and other information. ++ outputs = regular_llm.generate(generating_prompts, sampling_params) ++ ++ regular_generated_texts = [] ++ # Print the outputs. ++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ regular_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Destroy the LLM object and free up the GPU memory. ++ del regular_llm ++ cleanup_dist_env_and_memory() ++ ++ # return ++ ++ # Create an LLM with prefix caching enabled. ++ prefix_cached_llm = LLM(model=model_path, ++ enable_prefix_caching=True, ++ gpu_memory_utilization=gpu_memory_utilization, ++ tensor_parallel_size=tp_size, ++ kv_transfer_config=kv_transfer_config, ++ ) ++ ++ # Warmup so that the shared prompt's KV cache is computed. ++ prefix_cached_llm.generate(generating_prompts[0], sampling_params) ++ ++ # wait for offload kv task finished. ++ time.sleep(2) ++ ++ # Generate with prefix caching. ++ outputs = prefix_cached_llm.generate(generating_prompts, sampling_params) ++ ++ print("Results with `enable_prefix_caching`") ++ ++ cached_generated_texts = [] ++ # Print the outputs. You should see the same outputs as before. ++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ cached_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Compare the results and display the speedup ++ generated_same = all([ ++ regular_generated_texts[i] == cached_generated_texts[i] ++ for i in range(len(prompts)) ++ ]) ++ print(f"Generated answers are the same: {generated_same}") ++ ++ # wait for offload kv task finished. ++ time.sleep(2) ++ ++ # reset prefix cache to use flexkv ++ prefix_cached_llm.reset_prefix_cache() ++ ++ # Generate with prefix caching. ++ outputs = prefix_cached_llm.generate(generating_prompts, sampling_params) ++ ++ print("Results with `flexkv`") ++ ++ flexkv_generated_texts = [] ++ # Print the outputs. You should see the same outputs as before. 
++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ flexkv_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Compare the results and display the speedup ++ generated_same = all([ ++ regular_generated_texts[i] == flexkv_generated_texts[i] ++ for i in range(len(prompts)) ++ ]) ++ print(f"Generated answers are the same: {generated_same}") ++ ++ ++ ++if __name__ == "__main__": ++ main() ++ # pass +\ No newline at end of file +diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py +index 584fc1d65..db1cfe36b 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/factory.py ++++ b/vllm/distributed/kv_transfer/kv_connector/factory.py +@@ -105,3 +105,8 @@ KVConnectorFactory.register_connector( + "MultiConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", + "MultiConnector") ++ ++KVConnectorFactory.register_connector( ++ "FlexKVConnectorV1", ++ "vllm.distributed.kv_transfer.kv_connector.v1.flexkv_connector", ++ "FlexKVConnectorV1") +diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/flexkv_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/flexkv_connector.py +new file mode 100644 +index 000000000..bdfa9f321 +--- /dev/null ++++ b/vllm/distributed/kv_transfer/kv_connector/v1/flexkv_connector.py +@@ -0,0 +1,191 @@ ++# SPDX-License-Identifier: Apache-2.0 ++# SPDX-FileCopyrightText: Copyright contributors to the vLLM project ++from typing import TYPE_CHECKING, Any, Optional ++ ++import torch ++from flexkv.integration.vllm.vllm_v1_adapter import FlexKVConnectorV1Impl ++ ++from vllm.config import VllmConfig ++from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ++ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) ++from vllm.logger import init_logger ++from vllm.v1.core.sched.output import SchedulerOutput ++from vllm.v1.outputs import KVConnectorOutput ++ ++if TYPE_CHECKING: ++ from vllm.attention.backends.abstract import AttentionMetadata ++ from vllm.forward_context import ForwardContext ++ from vllm.v1.core.kv_cache_manager import KVCacheBlocks ++ from vllm.v1.request import Request ++ ++logger = init_logger(__name__) ++ ++ ++class FlexKVConnectorV1(KVConnectorBase_V1): ++ ++ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): ++ super().__init__(vllm_config=vllm_config, role=role) ++ self._flexkv_connector = FlexKVConnectorV1Impl(vllm_config, role) ++ ++ def shutdown(self): ++ self._flexkv_connector.shutdown() ++ ++ # ============================== ++ # Worker-side methods ++ # ============================== ++ def start_load_kv(self, forward_context: "ForwardContext", ++ **kwargs) -> None: ++ """ ++ Start loading the KV cache from the connector to vLLM's paged ++ KV buffer. This is called from the forward context before the ++ forward pass to enable async loading during model execution. ++ ++ Args: ++ forward_context (ForwardContext): the forward context. ++ **kwargs: additional arguments for the load operation ++ ++ Note: ++ The number of elements in kv_caches and layer_names should be ++ the same. ++ ++ """ ++ self._flexkv_connector.start_load_kv(forward_context, **kwargs) ++ ++ def wait_for_layer_load(self, layer_name: str) -> None: ++ """ ++ Block until the KV for a specific layer is loaded into vLLM's ++ paged buffer. 
This is called from within attention layer to ensure ++ async copying from start_load_kv is complete. ++ ++ This interface will be useful for layer-by-layer pipelining. ++ ++ Args: ++ layer_name: the name of that layer ++ """ ++ self._flexkv_connector.wait_for_layer_load(layer_name) ++ ++ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, ++ attn_metadata: "AttentionMetadata", **kwargs) -> None: ++ """ ++ Start saving the a layer of KV cache from vLLM's paged buffer ++ to the connector. This is called from within attention layer to ++ enable async copying during execution. ++ ++ Args: ++ layer_name (str): the name of the layer. ++ kv_layer (torch.Tensor): the paged KV buffer of the current ++ layer in vLLM. ++ attn_metadata (AttentionMetadata): the attention metadata. ++ **kwargs: additional arguments for the save operation. ++ """ ++ self._flexkv_connector.save_kv_layer(layer_name, kv_layer, attn_metadata, ++ **kwargs) ++ ++ def wait_for_save(self): ++ """ ++ Block until all the save operations is done. This is called ++ as the forward context exits to ensure that the async saving ++ from save_kv_layer is complete before finishing the forward. ++ ++ This prevents overwrites of paged KV buffer before saving done. ++ """ ++ self._flexkv_connector.wait_for_save() ++ ++ def get_finished( ++ self, finished_req_ids: set[str] ++ ) -> tuple[Optional[set[str]], Optional[set[str]]]: ++ """ ++ Notifies worker-side connector ids of requests that have ++ finished generating tokens. ++ ++ Returns: ++ ids of requests that have finished asynchronous transfer ++ (requests that previously returned True from request_finished()), ++ tuple of (sending/saving ids, recving/loading ids). ++ The finished saves/sends req ids must belong to a set provided in a ++ call to this method (this call or a prior one). ++ """ ++ return self._flexkv_connector.get_finished(finished_req_ids) ++ ++ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ++ """ ++ Initialize with the KV caches. Useful for pre-registering the ++ KV Caches in the KVConnector (e.g. for NIXL). ++ ++ Args: kv_caches: ++ dictionary of layer names, kv cache ++ """ ++ self._flexkv_connector.register_kv_caches(kv_caches) ++ ++ # ============================== ++ # Scheduler-side methods ++ # ============================== ++ def get_num_new_matched_tokens( ++ self, ++ request: "Request", ++ num_computed_tokens: int, ++ ) -> tuple[int, bool]: ++ """ ++ Get number of new tokens that can be loaded from the ++ external KV cache beyond the num_computed_tokens. ++ ++ Args: ++ request (Request): the request object. ++ num_computed_tokens (int): the number of locally ++ computed tokens for this request ++ ++ Returns: ++ the number of tokens that can be loaded from the ++ external KV cache beyond what is already computed. ++ """ ++ return self._flexkv_connector.get_num_new_matched_tokens( ++ request, num_computed_tokens) ++ ++ def update_state_after_alloc(self, request: "Request", ++ blocks: "KVCacheBlocks", ++ num_external_tokens: int): ++ """ ++ Update KVConnector state after block allocation. ++ """ ++ self._flexkv_connector.update_state_after_alloc(request, blocks, ++ num_external_tokens) ++ ++ def build_connector_meta( ++ self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: ++ """ ++ Build the connector metadata for this step. ++ ++ This function should NOT modify fields in the scheduler_output. ++ Also, calling this function will reset the state of the connector. 
++ ++ Args: ++ scheduler_output (SchedulerOutput): the scheduler output object. ++ """ ++ return self._flexkv_connector.build_connector_meta(scheduler_output) ++ ++ def update_connector_output(self, connector_output: KVConnectorOutput): ++ """ ++ Update KVConnector state from worker-side connectors output. ++ ++ Args: ++ connector_output (KVConnectorOutput): the worker-side ++ connectors output. ++ """ ++ self._flexkv_connector.update_connector_output(connector_output) ++ ++ def request_finished( ++ self, ++ request: "Request", ++ block_ids: list[int], ++ ) -> tuple[bool, Optional[dict[str, Any]]]: ++ """ ++ Called when a request has finished, before its blocks are freed. ++ ++ Returns: ++ True if the request is being saved/sent asynchronously and blocks ++ should not be freed until the request_id is returned from ++ get_finished(). ++ Optional KVTransferParams to be included in the request outputs ++ returned by the engine. ++ """ ++ return self._flexkv_connector.request_finished(request, block_ids) +diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py +index 981023409..a6c8fac38 100644 +--- a/vllm/v1/core/sched/scheduler.py ++++ b/vllm/v1/core/sched/scheduler.py +@@ -118,6 +118,7 @@ class Scheduler(SchedulerInterface): + + # KV Connector: requests in process of async KV loading or recving + self.finished_recving_kv_req_ids: set[str] = set() ++ self.sending_kv_reqs: dict[str, Request] = {} + + # Encoder-related. + # Calculate encoder cache size if applicable +@@ -1029,7 +1030,8 @@ class Scheduler(SchedulerInterface): + + if not delay_free_blocks: + self._free_blocks(request) +- ++ else: ++ self.sending_kv_reqs[request.request_id] = request + return kv_xfer_params + + def _free_blocks(self, request: Request): +@@ -1041,7 +1043,7 @@ class Scheduler(SchedulerInterface): + return len(self.waiting) + len(self.running) + + def has_finished_requests(self) -> bool: +- return len(self.finished_req_ids) > 0 ++ return len(self.finished_req_ids) > 0 or len(self.sending_kv_reqs) > 0 + + def reset_prefix_cache(self) -> bool: + return self.kv_cache_manager.reset_prefix_cache() +@@ -1082,6 +1084,8 @@ class Scheduler(SchedulerInterface): + def shutdown(self) -> None: + if self.kv_event_publisher: + self.kv_event_publisher.shutdown() ++ if self.connector and hasattr(self.connector, "shutdown"): ++ self.connector.shutdown() + + ######################################################################## + # KV Connector Related Methods +@@ -1149,6 +1153,10 @@ class Scheduler(SchedulerInterface): + scheduler the request during the next step. 
+ """ + ++ # avoid busy checking ++ if len(self.running) == 0: ++ time.sleep(0.01) ++ + if self.connector is not None: + self.connector.update_connector_output(kv_connector_output) + +@@ -1158,4 +1166,5 @@ class Scheduler(SchedulerInterface): + self.finished_recving_kv_req_ids.add(req_id) + for req_id in (kv_connector_output.finished_sending or ()): + logger.debug("Finished sending KV transfer for request %s", req_id) ++ del self.sending_kv_reqs[req_id] + self._free_blocks(self.requests[req_id]) +diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py +index a03ebe35d..8e4460957 100644 +--- a/vllm/v1/worker/kv_connector_model_runner_mixin.py ++++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py +@@ -66,9 +66,9 @@ class KVConnectorModelRunnerMixin: + scheduler_output, wait_for_save=False) as kv_connector_output: + pass + +- if (not kv_connector_output.finished_sending +- and not kv_connector_output.finished_recving): +- return EMPTY_MODEL_RUNNER_OUTPUT ++ # if (not kv_connector_output.finished_sending ++ # and not kv_connector_output.finished_recving): ++ # return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output diff --git a/examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch b/examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch new file mode 100644 index 0000000000..f6349a0ac7 --- /dev/null +++ b/examples/vllm_adaption_legacy/flexkv_vllm_0_10_0.patch @@ -0,0 +1,1224 @@ +diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py +index c7229dbb8..d2325fd3a 100644 +--- a/benchmarks/backend_request_func.py ++++ b/benchmarks/backend_request_func.py +@@ -9,6 +9,7 @@ import time + import traceback + from dataclasses import dataclass, field + from typing import Optional, Union ++import asyncio + + import aiohttp + import huggingface_hub.constants +@@ -23,10 +24,10 @@ AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + + @dataclass + class RequestFuncInput: +- prompt: str ++ prompt: Union[str, list[str]] + api_url: str +- prompt_len: int +- output_len: int ++ prompt_len: Union[int, list[int]] ++ output_len: Union[int, list[int]] + model: str + model_name: Optional[str] = None + logprobs: Optional[int] = None +@@ -555,6 +556,107 @@ async def async_request_openai_audio( + pbar.update(1) + return output + ++async def async_request_openai_chat_completions_multiturns( ++ request_func_input: RequestFuncInput, ++ pbar: Optional[tqdm] = None, ++ turn_interval_time: float = 3.0, ++) -> RequestFuncOutput: ++ api_url = request_func_input.api_url ++ assert api_url.endswith( ++ ("chat/completions", "profile") ++ ), "OpenAI Chat Completions API URL must end with 'chat/completions'." 
++ assert isinstance(request_func_input.prompt, list) ++ assert isinstance(request_func_input.prompt_len, list) ++ assert isinstance(request_func_input.output_len, list) ++ ++ async with aiohttp.ClientSession(trust_env=True, ++ timeout=AIOHTTP_TIMEOUT) as session: ++ payload = { ++ "model": request_func_input.model_name \ ++ if request_func_input.model_name else request_func_input.model, ++ "messages": [ ++ ], ++ "temperature": 0.0, ++ "stream": True, ++ "stream_options": { ++ "include_usage": True, ++ }, ++ } ++ payload["ignore_eos"] = request_func_input.ignore_eos ++ if request_func_input.extra_body: ++ payload.update(request_func_input.extra_body) ++ headers = { ++ "Content-Type": "application/json", ++ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", ++ } ++ ++ output_list = [] ++ for turn_id, prompt in enumerate(request_func_input.prompt): ++ output = RequestFuncOutput() ++ output.prompt_len = request_func_input.prompt_len[turn_id] ++ ++ payload["messages"].append({"role": "user", "content": prompt}) ++ payload["max_tokens"] = request_func_input.output_len[turn_id] ++ ++ generated_text = "" ++ ttft = 0.0 ++ st = time.perf_counter() ++ most_recent_timestamp = st ++ try: ++ async with session.post(url=api_url, json=payload, ++ headers=headers) as response: ++ if response.status == 200: ++ async for chunk_bytes in response.content: ++ chunk_bytes = chunk_bytes.strip() ++ if not chunk_bytes: ++ continue ++ ++ chunk = chunk_bytes.decode("utf-8").removeprefix( ++ "data: ") ++ if chunk != "[DONE]": ++ timestamp = time.perf_counter() ++ data = json.loads(chunk) ++ ++ if choices := data.get("choices"): ++ content = choices[0]["delta"].get("content") ++ # First token ++ if ttft == 0.0: ++ ttft = timestamp - st ++ output.ttft = ttft ++ ++ # Decoding phase ++ else: ++ output.itl.append(timestamp - ++ most_recent_timestamp) ++ ++ generated_text += content or "" ++ elif usage := data.get("usage"): ++ output.output_tokens = usage.get( ++ "completion_tokens") ++ ++ most_recent_timestamp = timestamp ++ ++ output.generated_text = generated_text ++ output.success = True ++ output.latency = most_recent_timestamp - st ++ else: ++ output.error = response.reason or "" ++ output.success = False ++ break ++ except Exception: ++ output.success = False ++ exc_info = sys.exc_info() ++ output.error = "".join(traceback.format_exception(*exc_info)) ++ break ++ payload["messages"].append({"role": "assistant", "content": generated_text}) ++ ++ output_list.append(output) ++ if turn_id != len(request_func_input.prompt) - 1: ++ await asyncio.sleep(turn_interval_time) ++ ++ if pbar: ++ pbar.update(1) ++ return output_list + + def get_model(pretrained_model_name_or_path: str) -> str: + if os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true": +@@ -619,6 +721,7 @@ ASYNC_REQUEST_FUNCS = { + "scalellm": async_request_openai_completions, + "sglang": async_request_openai_completions, + "llama.cpp": async_request_openai_completions, ++ "openai-chat-multiturns": async_request_openai_chat_completions_multiturns, + } + + OPENAI_COMPATIBLE_BACKENDS = [ +diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py +index 1ad6cef7a..9178528d0 100644 +--- a/benchmarks/benchmark_dataset.py ++++ b/benchmarks/benchmark_dataset.py +@@ -49,9 +49,9 @@ class SampleRequest: + Represents a single inference request for benchmarking. 
+ """ + +- prompt: Union[str, Any] +- prompt_len: int +- expected_output_len: int ++ prompt: Union[str, list[str], Any] ++ prompt_len: Union[int, list[int]] ++ expected_output_len: Union[int, list[int]] + multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None + lora_request: Optional[LoRARequest] = None + +@@ -617,6 +617,108 @@ class SonnetDataset(BenchmarkDataset): + ) + return samples + ++ ++# ----------------------------------------------------------------------------- ++# ShareGPT Multiturn Dataset Implementation ++# ----------------------------------------------------------------------------- ++ ++ ++class ShareGPTMultiTurnsDataset(BenchmarkDataset): ++ def __init__(self, min_num_turns: int = 2, **kwargs) -> None: ++ super().__init__(**kwargs) ++ self.load_data(min_num_turns) ++ ++ def load_data(self, min_num_turns: int) -> None: ++ if self.dataset_path is None: ++ raise ValueError("dataset_path must be provided for loading data.") ++ ++ with open(self.dataset_path, encoding="utf-8") as f: ++ self.data = json.load(f) ++ # Filter entries with at least two conversation turns. ++ new_data = [] ++ for entry in self.data: ++ if "conversations" in entry: ++ while len(entry["conversations"]) > 0 and entry["conversations"][0]['from'] != 'human': ++ entry["conversations"].pop(0) ++ if len(entry["conversations"]) % 2 != 0: ++ entry["conversations"].pop(-1) ++ if len(entry["conversations"]) >= 2 * min_num_turns: ++ new_data.append(entry) ++ self.data = new_data ++ random.seed(self.random_seed) ++ random.shuffle(self.data) ++ ++ def sample( ++ self, ++ tokenizer: PreTrainedTokenizerBase, ++ num_requests: int, ++ lora_path: Optional[str] = None, ++ max_loras: Optional[int] = None, ++ output_len: Optional[int] = None, ++ **kwargs, ++ ) -> list: ++ samples: list = [] ++ for entry in self.data: ++ if len(samples) >= num_requests: ++ break ++ ++ prompt_list = [d["value"] for d in entry["conversations"][::2]] ++ completion_list = [d["value"] for d in entry["conversations"][1::2]] ++ # prompt, completion = ( ++ # entry["conversations"][0]["value"], ++ # entry["conversations"][1]["value"], ++ # ) ++ ++ lora_request, tokenizer = self.get_random_lora_request( ++ tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) ++ ++ ++ prompt_ids_list = [] ++ completion_ids_list = [] ++ prompt_len_list = [] ++ new_output_len_list = [] ++ history_len = 0 ++ for turn_id in range(len(prompt_list)): ++ try: ++ prompt_ids = tokenizer(prompt_list[turn_id]).input_ids ++ completion_ids = tokenizer(completion_list[turn_id]).input_ids ++ except: ++ print(entry) ++ raise ++ prompt_len = len(prompt_ids) + history_len ++ new_output_len = len(completion_ids) if output_len is None else output_len ++ if not is_valid_sequence( ++ prompt_len, ++ new_output_len, ++ min_len=4, ++ max_prompt_len=4096, ++ max_total_len=8192, ++ skip_min_output_len_check=output_len ++ is not None): ++ turn_id -= 1 ++ break ++ prompt_ids_list.append(prompt_ids) ++ completion_ids_list.append(completion_ids) ++ prompt_len_list.append(prompt_len) ++ new_output_len_list.append(new_output_len) ++ history_len += prompt_len ++ history_len += new_output_len ++ ++ if turn_id <= 0: ++ continue ++ ++ prompt_list = prompt_list[:turn_id+1] ++ ++ samples.append( ++ SampleRequest( ++ prompt=prompt_list, ++ prompt_len=prompt_len_list, ++ expected_output_len=new_output_len_list, ++ lora_request=lora_request, ++ )) ++ self.maybe_oversample_requests(samples, num_requests) ++ return samples ++ + + # 
----------------------------------------------------------------------------- + # BurstGPT Dataset Implementation +diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py +index c597fb106..74e157927 100644 +--- a/benchmarks/benchmark_serving.py ++++ b/benchmarks/benchmark_serving.py +@@ -71,6 +71,7 @@ from benchmark_dataset import ( + ShareGPTDataset, + SonnetDataset, + VisionArenaDataset, ++ ShareGPTMultiTurnsDataset, + ) + from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json + from vllm.benchmarks.serve import get_request +@@ -142,7 +143,7 @@ def calculate_metrics( + ).input_ids + ) + actual_output_lens.append(output_len) +- total_input += input_requests[i].prompt_len ++ total_input += outputs[i].prompt_len + tpot = 0 + if output_len > 1: + latency_minus_ttft = outputs[i].latency - outputs[i].ttft +@@ -278,6 +279,9 @@ async def benchmark( + ) + + test_output = await request_func(request_func_input=test_input) ++ if backend == "openai-chat-multiturns": ++ print("test_output ", test_output) ++ test_output = test_output[-1] + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " +@@ -394,6 +398,8 @@ async def benchmark( + task = limited_request_func(request_func_input=request_func_input, pbar=pbar) + tasks.append(asyncio.create_task(task)) + outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) ++ if backend == "openai-chat-multiturns": ++ outputs = [o for sub_o in outputs for o in sub_o] + + if profile: + print("Stopping profiler...") +@@ -748,6 +754,15 @@ def main(args: argparse.Namespace): + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + ), ++ "sharegpt_multiturns": ++ lambda: ShareGPTMultiTurnsDataset( ++ min_num_turns=4, ++ random_seed=args.seed, ++ dataset_path=args.dataset_path).sample( ++ tokenizer=tokenizer, ++ num_requests=args.num_prompts, ++ output_len=args.sharegpt_output_len, ++ ), + "burstgpt": lambda: BurstGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample(tokenizer=tokenizer, num_requests=args.num_prompts), +@@ -930,7 +945,7 @@ def create_argument_parser(): + "--dataset-name", + type=str, + default="sharegpt", +- choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], ++ choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "sharegpt_multiturns"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( +diff --git a/benchmarks/flexkv_benchmark/container b/benchmarks/flexkv_benchmark/container +new file mode 100644 +index 000000000..cfc3b5bac +--- /dev/null ++++ b/benchmarks/flexkv_benchmark/container +@@ -0,0 +1,12 @@ ++docker run --shm-size=50g --ipc=host --network host -it --gpus all -v /home/zichengm/FlexKV:/workspace -e GLOO_SOCKET_IFNAME=eno1 --entrypoint /bin/bash vllm/vllm-openai:v0.10.0 ++VLLM_USE_PRECOMPILED=1 pip install -e . 
++apt update && apt install liburing-dev ++> vllm.log 2>&1 & ++export GLOO_SOCKET_IFNAME=eno1 ++nohup bash run_flexkv_server.sh > kvserver.log 2>&1 & ++nohup bash serving_vllm.sh 2 > vllm.log 2>&1 & ++bash multiturn_benchmark.sh ++ ++gdb -q -ex "set pagination off" -ex "set confirm off" -ex "set env PYTHONFAULTHANDLER=1" -ex "handle SIGPIPE noprint nostop pass" -ex "handle SIGBUS stop print" -ex "run" --args python3 examples/run_server.py --model-path Qwen/Qwen3-8B --tp-size 1 --dp-size 1 --block-size 16 --num-cpu-blocks 8192 --server-recv-port ipc:///tmp/tmpe0x8_0gq ++ ++cpu block num = cpu memory size / layer_num / 2 / token_per_block / num_heads / head_size / sizeof(data_type) +\ No newline at end of file +diff --git a/benchmarks/flexkv_benchmark/lmcache_config.yaml b/benchmarks/flexkv_benchmark/lmcache_config.yaml +new file mode 100644 +index 000000000..8016df5b6 +--- /dev/null ++++ b/benchmarks/flexkv_benchmark/lmcache_config.yaml +@@ -0,0 +1,6 @@ ++# Basic configurations ++chunk_size: 16 ++ ++# CPU offloading configurations ++local_cpu: true ++max_local_cpu_size: 32 +\ No newline at end of file +diff --git a/benchmarks/flexkv_benchmark/multiturn_benchmark.sh b/benchmarks/flexkv_benchmark/multiturn_benchmark.sh +new file mode 100644 +index 000000000..4a15ca771 +--- /dev/null ++++ b/benchmarks/flexkv_benchmark/multiturn_benchmark.sh +@@ -0,0 +1,17 @@ ++current_time=$(date +"%Y-%m-%d-%H:%M:%S") ++for workers in 128; do ++ concurrency_multiplier=4 ++ if [ $workers -gt 128 ]; then ++ concurrency_multiplier=2 ++ fi ++ python3 ../benchmark_serving.py \ ++ --backend openai-chat-multiturns \ ++ --model Qwen/Qwen3-8B \ ++ --dataset-name sharegpt_multiturns \ ++ --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \ ++ --num-prompts $((workers*concurrency_multiplier)) \ ++ --max-concurrency $workers \ ++ --host 0.0.0.0 \ ++ --port 12599 \ ++ --endpoint /v1/chat/completions 2>&1 ++done +\ No newline at end of file +diff --git a/benchmarks/flexkv_benchmark/run_flexkv_server.sh b/benchmarks/flexkv_benchmark/run_flexkv_server.sh +new file mode 100644 +index 000000000..56b38fa01 +--- /dev/null ++++ b/benchmarks/flexkv_benchmark/run_flexkv_server.sh +@@ -0,0 +1,15 @@ ++MODEL_PATH=Qwen/Qwen3-8B ++ ++CMD="python3 examples/run_server.py \ ++ --model-path $MODEL_PATH \ ++ --tp-size 1 \ ++ --dp-size 1 \ ++ --block-size 16 \ ++ --num-cpu-blocks 7282 \ ++ --server-recv-port ipc:///tmp/tmpe0x8_0gq \ ++ " ++echo ++echo ++ ++echo $CMD ++eval $CMD +\ No newline at end of file +diff --git a/benchmarks/flexkv_benchmark/serving_vllm.sh b/benchmarks/flexkv_benchmark/serving_vllm.sh +new file mode 100644 +index 000000000..18dee4732 +--- /dev/null ++++ b/benchmarks/flexkv_benchmark/serving_vllm.sh +@@ -0,0 +1,79 @@ ++#!/bin/bash ++# vLLM服务启动脚本 ++# 使用方法: ./serving_vllm.sh ++# type选项: ++# 0: 无前缀缓存 ++# 1: GPU前缀缓存 ++# 2: FlexKV ++# 3: LMCache (需要先配置LMCACHE_CONFIG_FILE环境变量) ++ ++MODEL_PATH=Qwen/Qwen3-8B ++ ++type=${1} ++ ++if [[ $type = 0 ]]; then ++ # no prefix cache ++ prefix_args="--no-enable-prefix-caching" ++ use_lmcache=false ++elif [[ $type = 1 ]]; then ++ # gpu prefix cache ++ prefix_args="" ++ use_lmcache=false ++elif [[ $type = 2 ]]; then ++ # flexkv ++ prefix_args="" ++ export ENABLE_FLEXKV="true" ++ export FLEXKV_SERVER_RECV_PORT="ipc:///tmp/tmpe0x8_0gq" ++ use_lmcache=false ++elif [[ $type = 3 ]]; then ++ # lmcache ++ prefix_args="" ++ use_lmcache=true ++ export LMCACHE_CONFIG_FILE="./lmcache_config.yaml" ++else ++ echo "ERROR: Unknown running type [$type]" ++ exit -1 ++fi ++ ++# nccl envs ++export 
GLOO_SOCKET_IFNAME=eno1 ++export NCCL_SOCKET_IFNAME=eno1 ++export NCCL_IB_GID_INDEX=3 ++export NCCL_IB_DISABLE=0 ++export NCCL_NET_GDR_LEVEL=2 ++export NCCL_IB_QPS_PER_CONNECTION=4 ++export NCCL_IB_TC=160 ++export NCCL_IB_TIMEOUT=22 ++export NCCL_PXN_DISABLE=0 ++ ++if [[ $use_lmcache = true ]]; then ++ # 使用vllm serve命令和LMCache ++ CMD="python3 -m vllm.entrypoints.openai.api_server --model $MODEL_PATH \ ++ --port=12599 \ ++ --tensor-parallel-size=1 \ ++ --data-parallel-size=1 \ ++ --pipeline-parallel-size=1 \ ++ --max-model-len=8192 \ ++ --max-num-seqs=256 \ ++ --gpu-memory-utilization 0.4 \ ++ --max-num-batched-tokens 8192 \ ++ --kv-transfer-config '{\"kv_connector\":\"LMCacheConnectorV1\",\"kv_role\":\"kv_both\"}' \ ++ $prefix_args" ++else ++ # 使用原有的api_server启动方式 ++ CMD="python3 -m vllm.entrypoints.openai.api_server --model $MODEL_PATH \ ++ --port=12599 \ ++ --tensor-parallel-size=1 \ ++ --data-parallel-size=1 \ ++ --pipeline-parallel-size=1 \ ++ --max-model-len=8192 \ ++ --max-num-seqs=256 \ ++ --gpu-memory-utilization 0.4 \ ++ --max-num-batched-tokens 8192 \ ++ $prefix_args" ++fi ++echo ++echo ++ ++echo $CMD ++eval $CMD +\ No newline at end of file +diff --git a/examples/offline_inference/prefix_caching_flexkv.py b/examples/offline_inference/prefix_caching_flexkv.py +new file mode 100644 +index 000000000..6ff17dfca +--- /dev/null ++++ b/examples/offline_inference/prefix_caching_flexkv.py +@@ -0,0 +1,123 @@ ++# SPDX-License-Identifier: Apache-2.0 ++import os ++ ++from vllm import LLM, SamplingParams ++from vllm.distributed import cleanup_dist_env_and_memory ++ ++# NOTE: This is just a running example. For benchmarking purpose, ++# please see benchmarks/benchmark_prefix_caching.py ++ ++os.environ["ENABLE_FLEXKV"] = "true" ++os.environ["FLEXKV_SERVER_RECV_PORT"] = "ipc:///tmp/tmpe0x8_0gq" ++ ++# Common prefix. ++prefix = ( ++ "You are an expert school principal, skilled in effectively managing " ++ "faculty and staff. Draft 10-15 questions for a potential first grade " ++ "Head Teacher for my K-12, all-girls', independent school that emphasizes " ++ "community, joyful discovery, and life-long learning. The candidate is " ++ "coming in for a first-round panel interview for a 8th grade Math " ++ "teaching role. They have 5 years of previous teaching experience " ++ "as an assistant teacher at a co-ed, public school with experience " ++ "in middle school math teaching. Based on these information, fulfill " ++ "the following paragraph: ") ++ ++# Sample prompts. ++prompts = [ ++ "Hello, my name is", ++ "The president of the United States is", ++ "The capital of France is", ++ "The future of AI is", ++] ++ ++generating_prompts = [prefix + prompt for prompt in prompts] ++ ++# Create a sampling params object. ++sampling_params = SamplingParams(temperature=0.0) ++ ++def main(): ++ # Create an LLM without prefix caching as a baseline. ++ regular_llm = LLM(model="facebook/opt-125m", ++ enable_prefix_caching=False, ++ gpu_memory_utilization=0.4) ++ ++ print("Results without `enable_prefix_caching`") ++ ++ # ruff: noqa: E501 ++ # Generate texts from the prompts. The output is a list of RequestOutput objects ++ # that contain the prompt, generated text, and other information. ++ outputs = regular_llm.generate(generating_prompts, sampling_params) ++ ++ regular_generated_texts = [] ++ # Print the outputs. 
++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ regular_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Destroy the LLM object and free up the GPU memory. ++ del regular_llm ++ cleanup_dist_env_and_memory() ++ ++ # Create an LLM with prefix caching enabled. ++ prefix_cached_llm = LLM(model="facebook/opt-125m", ++ enable_prefix_caching=True, ++ gpu_memory_utilization=0.4) ++ ++ # Warmup so that the shared prompt's KV cache is computed. ++ prefix_cached_llm.generate(generating_prompts[0], sampling_params) ++ ++ # Generate with prefix caching. ++ outputs = prefix_cached_llm.generate(generating_prompts, sampling_params) ++ ++ print("Results with `enable_prefix_caching`") ++ ++ cached_generated_texts = [] ++ # Print the outputs. You should see the same outputs as before. ++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ cached_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Compare the results and display the speedup ++ generated_same = all([ ++ regular_generated_texts[i] == cached_generated_texts[i] ++ for i in range(len(prompts)) ++ ]) ++ print(f"Generated answers are the same: {generated_same}") ++ ++ # reset prefix cache to use flexkv ++ prefix_cached_llm.reset_prefix_cache() ++ ++ # Generate with prefix caching. ++ outputs = prefix_cached_llm.generate(generating_prompts, sampling_params) ++ ++ print("Results with `flexkv`") ++ ++ flexkv_generated_texts = [] ++ # Print the outputs. You should see the same outputs as before. ++ print("-" * 50) ++ for output in outputs: ++ prompt = output.prompt ++ generated_text = output.outputs[0].text ++ flexkv_generated_texts.append(generated_text) ++ print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") ++ print("-" * 50) ++ ++ # Compare the results and display the speedup ++ generated_same = all([ ++ regular_generated_texts[i] == flexkv_generated_texts[i] ++ for i in range(len(prompts)) ++ ]) ++ print(f"Generated answers are the same: {generated_same}") ++ ++ ++ ++if __name__ == "__main__": ++ main() +diff --git a/vllm/distributed/flexkv_extension/__init__.py b/vllm/distributed/flexkv_extension/__init__.py +new file mode 100644 +index 000000000..e69de29bb +diff --git a/vllm/distributed/flexkv_extension/client.py b/vllm/distributed/flexkv_extension/client.py +new file mode 100644 +index 000000000..478683fa9 +--- /dev/null ++++ b/vllm/distributed/flexkv_extension/client.py +@@ -0,0 +1,101 @@ ++import torch ++from typing import Optional ++ ++from flexkv.server.client import KVDPClient, KVTPClient ++from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType ++from flexkv.common.config import ModelConfig ++from vllm.distributed.flexkv_extension.config import FlexKVConfig ++from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec ++from vllm.logger import init_logger ++ ++logger = init_logger(__name__) ++ ++ ++class FlexKVDPClient: ++ def __init__( ++ self, ++ flexkv_config: FlexKVConfig ++ ): ++ self.flexkv_config = flexkv_config ++ self.server_recv_port = flexkv_config.server_recv_port ++ self.tp_size = flexkv_config.tp_size ++ self.model_config = ModelConfig( ++ num_layers=flexkv_config.num_layers, ++ num_kv_heads=flexkv_config.num_kv_heads, ++ head_size=flexkv_config.head_size, ++ use_mla=flexkv_config.use_mla, ++ 
dtype=flexkv_config.dtype, ++ tp_size=flexkv_config.tp_size, ++ ) ++ ++ logger.info(f"start init FlexKVDPClient to {self.server_recv_port}") ++ self.dp_client = KVDPClient(self.server_recv_port, self.model_config) ++ logger.info(f"finish init FlexKVDPClient") ++ ++ def put_async( ++ self, ++ token_ids: torch.Tensor, ++ slot_mapping: torch.Tensor, ++ token_mask: Optional[torch.Tensor] = None, ++ ) -> int: ++ " return task_id " ++ return self.dp_client.put_async(token_ids, slot_mapping, token_mask) ++ ++ def get_async( ++ self, ++ token_ids: torch.Tensor, ++ slot_mapping: torch.Tensor, ++ token_mask: Optional[torch.Tensor] = None, ++ ) -> int: ++ " return task_id " ++ return self.dp_client.get_async(token_ids, slot_mapping, token_mask) ++ ++ def wait( ++ self, ++ wait_task_ids: list[int], ++ ) -> dict[int, torch.Tensor]: ++ return self.dp_client.wait(wait_task_ids) ++ ++ def try_wait( ++ self, ++ wait_task_ids: list[int], ++ ) -> dict[int, Optional[torch.Tensor]]: ++ # print("--------------------------------") ++ # print(f"[FlexKVDPClient] About to call dp_client.try_wait with {wait_task_ids}") ++ try: ++ result = self.dp_client.try_wait(wait_task_ids) ++ # print(f"[FlexKVDPClient] dp_client.try_wait returned: {result}") ++ return result ++ except Exception as e: ++ # print(f"[FlexKVDPClient ERROR] Exception calling dp_client.try_wait: {e}") ++ import traceback ++ traceback.print_exc() ++ return {} ++ ++ ++class FlexKVTPClient: ++ def __init__( ++ self, ++ flexkv_config: FlexKVConfig, ++ dp_client_id: int, ++ tp_rank: int, ++ device_id: int, ++ gpu_blocks: list[torch.Tensor], ++ kv_shape: tuple[int], ++ ): ++ logger.info(f"start init FlexKVTPClient to {flexkv_config.server_recv_port}") ++ self.tp_client = KVTPClient(flexkv_config.server_recv_port, dp_client_id, device_id, tp_rank) ++ logger.info(f"finish init FlexKVTPClient") ++ gpu_layout = KVCacheLayout( ++ type=KVCacheLayoutType.LAYERWISE, ++ num_layer=flexkv_config.num_layers, ++ num_block=flexkv_config.num_blocks, ++ tokens_per_block=flexkv_config.block_size, ++ num_head=flexkv_config.num_kv_heads, ++ head_size=flexkv_config.head_size, ++ is_mla=flexkv_config.use_mla, ++ ) ++ logger.info(f"start register FlexKVTPClient") ++ self.tp_client.register_to_server(gpu_blocks, gpu_layout) ++ ++ logger.info(f"finish register FlexKVTPClient") +\ No newline at end of file +diff --git a/vllm/distributed/flexkv_extension/config.py b/vllm/distributed/flexkv_extension/config.py +new file mode 100644 +index 000000000..f2724e712 +--- /dev/null ++++ b/vllm/distributed/flexkv_extension/config.py +@@ -0,0 +1,45 @@ ++from dataclasses import dataclass ++import json ++import os ++import torch ++from vllm.v1.kv_cache_interface import KVCacheConfig, FullAttentionSpec ++from vllm.logger import init_logger ++ ++logger = init_logger(__name__) ++ ++ ++@dataclass ++class FlexKVConfig: ++ enable_flexkv: bool ++ server_recv_port: str ++ num_blocks: int = None ++ block_size: int = None ++ num_layers: int = None ++ num_kv_heads: int = None ++ head_size: int = None ++ dtype: torch.dtype = None ++ use_mla: bool = False ++ tp_size: int = 1 ++ ++ @classmethod ++ def from_env(cls) -> 'FlexKVConfig': ++ enable_flexkv = (os.getenv('ENABLE_FLEXKV', "false").lower() == "true") ++ server_recv_port = os.getenv('FLEXKV_SERVER_RECV_PORT', "") ++ ++ return cls(enable_flexkv=enable_flexkv, ++ server_recv_port=server_recv_port) ++ ++ def post_init( ++ self, ++ kv_cache_config: KVCacheConfig, ++ tp_size: int ++ ): ++ self.num_blocks = kv_cache_config.num_blocks ++ self.num_layers = 
len(kv_cache_config.kv_cache_groups) ++ kv_cache_spec: FullAttentionSpec = kv_cache_config.kv_cache_groups[0].kv_cache_spec ++ self.block_size = kv_cache_spec.block_size ++ self.num_kv_heads = kv_cache_spec.num_kv_heads ++ self.head_size = kv_cache_spec.head_size ++ self.dtype = kv_cache_spec.dtype ++ self.use_mla = kv_cache_spec.use_mla ++ self.tp_size = tp_size +\ No newline at end of file +diff --git a/vllm/logger.py b/vllm/logger.py +index 69aaf4390..fe426f420 100644 +--- a/vllm/logger.py ++++ b/vllm/logger.py +@@ -21,7 +21,7 @@ VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH + VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL + VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX + +-_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s " ++_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s.%(msecs)03d " + "[%(filename)s:%(lineno)d] %(message)s") + _DATE_FORMAT = "%m-%d %H:%M:%S" + +diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py +index 5b0218640..aa590eb6f 100644 +--- a/vllm/v1/core/kv_cache_utils.py ++++ b/vllm/v1/core/kv_cache_utils.py +@@ -87,8 +87,9 @@ class PrefixCachingMetrics: + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 ++ self.aggregated_query_flexkv_hit = 0 + # A deque of (requests, queries, hits) for the most recent requests. +- self.query_queue: deque[tuple[int, int, int]] = deque() ++ self.query_queue: deque[tuple[int, int, int, int]] = deque() + + def observe(self, stats: PrefixCacheStats): + """Observe the prefix caching for a set of requests. +@@ -108,14 +109,15 @@ class PrefixCachingMetrics: + self.reset() + + # Update the metrics. +- self.query_queue.append((stats.requests, stats.queries, stats.hits)) ++ self.query_queue.append((stats.requests, stats.queries, stats.hits, stats.flexkv_hits)) + self.aggregated_requests += stats.requests + self.aggregated_query_total += stats.queries + self.aggregated_query_hit += stats.hits ++ self.aggregated_query_flexkv_hit += stats.flexkv_hits + + # Remove the oldest stats if the number of requests exceeds. 
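A reduced model of the sliding-window aggregation performed by the patched PrefixCachingMetrics may help when reading this hunk; the names and window size below are illustrative, not the vLLM API. Note that the eviction branch that follows discards the flexkv counter of the evicted entry (the trailing underscore), so aggregated_query_flexkv_hit only shrinks on reset(); the sketch decrements all four fields for symmetry.

from collections import deque

window = deque()              # (requests, queries, hits, flexkv_hits) per observation
totals = [0, 0, 0, 0]         # running sums over the window
MAX_RECENT_REQUESTS = 1000    # assumed window size

def observe(requests, queries, hits, flexkv_hits):
    window.append((requests, queries, hits, flexkv_hits))
    for i, v in enumerate((requests, queries, hits, flexkv_hits)):
        totals[i] += v
    while totals[0] > MAX_RECENT_REQUESTS:   # drop the oldest entries past the request budget
        for i, v in enumerate(window.popleft()):
            totals[i] -= v

def flexkv_hit_rate():
    return 0.0 if totals[1] == 0 else totals[3] / totals[1]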
+ if self.aggregated_requests > self.max_recent_requests: +- old_requests, old_queries, old_hits = self.query_queue.popleft() ++ old_requests, old_queries, old_hits, _ = self.query_queue.popleft() + self.aggregated_requests -= old_requests + self.aggregated_query_total -= old_queries + self.aggregated_query_hit -= old_hits +@@ -125,6 +127,7 @@ class PrefixCachingMetrics: + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 ++ self.aggregated_query_flexkv_hit = 0 + self.query_queue.clear() + + @property +@@ -133,6 +136,13 @@ class PrefixCachingMetrics: + if self.aggregated_query_total == 0: + return 0.0 + return self.aggregated_query_hit / self.aggregated_query_total ++ ++ @property ++ def flexkv_hit_rate(self) -> float: ++ """Calculate the hit rate for the past N requests.""" ++ if self.aggregated_query_total == 0: ++ return 0.0 ++ return self.aggregated_query_flexkv_hit / self.aggregated_query_total + + + @dataclass +diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py +index 446f98034..b465c4cf1 100644 +--- a/vllm/v1/core/sched/scheduler.py ++++ b/vllm/v1/core/sched/scheduler.py +@@ -5,6 +5,7 @@ from __future__ import annotations + + import itertools + import time ++import torch + from collections import defaultdict + from collections.abc import Iterable + from typing import Any, Optional, Union +@@ -34,6 +35,9 @@ from vllm.v1.outputs import ModelRunnerOutput + from vllm.v1.request import Request, RequestStatus + from vllm.v1.spec_decode.metrics import SpecDecodingStats + from vllm.v1.structured_output import StructuredOutputManager ++# flexkv ++from vllm.utils import cdiv ++from vllm.distributed.flexkv_extension.config import FlexKVConfig + + logger = init_logger(__name__) + +@@ -162,6 +166,23 @@ class Scheduler(SchedulerInterface): + ) + self.use_pp = self.parallel_config.pipeline_parallel_size > 1 + ++ # flexkv ++ self.enable_flexkv = False ++ self.flexkv_client = None ++ # task_id -> Request ++ self.load_kv_tasks: dict[int, Request] = {} ++ # task_id -> Request ++ self.offload_kv_tasks: dict[int, Request] = {} ++ # request_id -> time info ++ self.flexkv_timer: dict[str, dict[str, float]] = {} ++ ++ ++ def init_flexkv(self, flexkv_config: FlexKVConfig) -> int: ++ self.enable_flexkv = True ++ from vllm.distributed.flexkv_extension.client import FlexKVDPClient ++ self.flexkv_client = FlexKVDPClient(flexkv_config) ++ return self.flexkv_client.dp_client.dp_client_id ++ + def schedule(self) -> SchedulerOutput: + # NOTE(woosuk) on the scheduling algorithm: + # There's no "decoding phase" nor "prefill phase" in the scheduler. +@@ -174,6 +195,13 @@ class Scheduler(SchedulerInterface): + # chunked prefills, prefix caching, speculative decoding, + # and the "jump decoding" optimization in the future. + ++ # flexkv ++ if self.enable_flexkv: ++ # aviod busy loop ++ if self.get_num_unfinished_requests() == 0: ++ time.sleep(0.01) ++ self.check_offload_kv_tasks() ++ + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] +@@ -448,6 +476,27 @@ class Scheduler(SchedulerInterface): + if new_blocks is None: + # The request cannot be scheduled. 
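The flexkv lookup added just below builds a per-token slot_mapping from the freshly allocated block ids, skipping the last (possibly partial) block. A standalone sketch of that mapping with assumed toy values; every token of a block is tagged with block_id * block_size, i.e. the block's base slot:

import torch

block_size = 4
new_block_ids = [7, 2, 9]   # toy ids of the blocks just allocated for the request
slot_mapping = torch.tensor(new_block_ids).repeat_interleave(block_size) * block_size
# tensor([28, 28, 28, 28,  8,  8,  8,  8, 36, 36, 36, 36])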
+ break ++ ++ if self.enable_flexkv and num_new_tokens > self.block_size and request.status == RequestStatus.WAITING: ++ # don't match the last block ++ num_new_blocks_to_get = cdiv(num_new_tokens, self.block_size)-1 ++ num_new_tokens_to_match = num_new_blocks_to_get*self.block_size ++ num_tokens_to_get = num_computed_tokens + num_new_tokens_to_match ++ blocks_ids_to_get = [block.block_id for block in new_blocks.blocks[0][:num_new_blocks_to_get]] ++ slot_mapping = torch.tensor(blocks_ids_to_get).repeat_interleave(self.block_size)*self.block_size ++ token_mask_to_get = torch.ones(num_tokens_to_get, dtype=torch.bool) ++ token_mask_to_get[:num_computed_tokens] = False ++ t_async_get_start = time.monotonic() ++ task_id = self.flexkv_client.get_async( ++ token_ids=torch.tensor(request.all_token_ids[:num_tokens_to_get]), ++ slot_mapping=slot_mapping, ++ token_mask=token_mask_to_get) ++ t_async_get_return = time.monotonic() ++ ++ self.load_kv_tasks[task_id] = request ++ self.flexkv_timer[request.request_id] = {} ++ self.flexkv_timer[request.request_id]['get_async_start'] = t_async_get_start ++ self.flexkv_timer[request.request_id]['get_async_return'] = t_async_get_return + + # KVTransfer: the connector uses this info to determine + # if a load is needed. Note that +@@ -505,6 +554,31 @@ class Scheduler(SchedulerInterface): + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget ++ # batch wait ++ ++ # batch wait ++ if self.enable_flexkv: ++ if len(self.load_kv_tasks) != 0: ++ task_ids = list(self.load_kv_tasks.keys()) ++ print(f"[DEBUG] scheduler wait for {task_ids}") ++ results = self.flexkv_client.wait(task_ids) ++ print(f"[DEBUG] scheduler wait result: {results}") ++ t_async_get_end = time.monotonic() ++ for task_id, task_result in results.items(): ++ request = self.load_kv_tasks.pop(task_id) ++ t_get_async_start = self.flexkv_timer[request.request_id]["get_async_start"] ++ t_get_async_return = self.flexkv_timer[request.request_id]["get_async_return"] ++ match_length = task_result.sum().item() ++ self.flexkv_timer.pop(request.request_id) ++ logger.info( ++ f"[FlexKV] req: {request.request_id}, task: {task_id}, " ++ f"get {match_length} tokens cost {(t_async_get_end-t_get_async_start)*1000:.2f} ms, " ++ f"get_async() api cost {(t_get_async_return-t_get_async_start)*1000:.2f} ms") ++ ++ token_budget += match_length ++ num_scheduled_tokens[request.request_id] -= match_length ++ request.num_computed_tokens += match_length ++ self.kv_cache_manager.prefix_cache_stats.flexkv_hits += (match_length//self.block_size) + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: +@@ -1016,11 +1090,49 @@ class Scheduler(SchedulerInterface): + if self.finished_req_ids_dict is not None: + self.finished_req_ids_dict[request.client_index].add(request_id) + +- if not delay_free_blocks: +- self._free_blocks(request) ++ # flexkv: offload BEFORE freeing blocks to preserve req_to_blocks info ++ if self.enable_flexkv: ++ self._offload_kv(request) ++ else: ++ if not delay_free_blocks: ++ self._free_blocks(request) ++ # else: ++ # self._free_block(request) ++ + + return kv_xfer_params + ++ def _free_block(self, request: Request) -> None: ++ self.kv_cache_manager.free(request) ++ self.kv_cache_manager.free_block_hashes(request) ++ del self.requests[request.request_id] ++ ++ def _offload_kv(self, request: Request): ++ # print(f"single_type_managers: {self.kv_cache_manager.coordinator.single_type_managers}") ++ # 
print(f"req_to_blocks: {self.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks}") ++ req_blocks = self.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks.get(request.request_id, []) ++ req_token_ids = torch.tensor(request.all_token_ids[:-1]) ++ req_block_ids = torch.tensor([block.block_id for block in req_blocks]) ++ ++ # Debug information for empty req_blocks ++ # if len(req_blocks) == 0: ++ # print(f"WARNING: Empty req_blocks for request {request.request_id}") ++ # print(f" request.all_token_ids length: {len(request.all_token_ids)}") ++ # print(f" req_token_ids length: {len(req_token_ids)}") ++ # print(f" req_to_blocks keys: {list(self.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks.keys())}") ++ ++ slot_mapping = req_block_ids.repeat_interleave(self.block_size)[:len(req_token_ids)] * self.block_size ++ ++ # Additional debug info ++ # print(f"FlexKV _offload_kv: req_id={request.request_id}, " ++ # f"blocks={len(req_blocks)}, tokens={len(req_token_ids)}, slots={len(slot_mapping)}") ++ ++ self.flexkv_timer[request.request_id] = {} ++ self.flexkv_timer[request.request_id]["put_async_start"] = time.monotonic() ++ task_id = self.flexkv_client.put_async(token_ids=req_token_ids, slot_mapping=slot_mapping) ++ self.offload_kv_tasks[task_id] = request ++ self.flexkv_timer[request.request_id]["put_async_return"] = time.monotonic() ++ + def _free_blocks(self, request: Request): + assert request.is_finished() + self.kv_cache_manager.free(request) +@@ -1068,7 +1180,27 @@ class Scheduler(SchedulerInterface): + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted_tokens) + return spec_decoding_stats +- ++ ++ def check_offload_kv_tasks(self): ++ if len(self.offload_kv_tasks) == 0: ++ return ++ logger.info(f"check_offload_kv_tasks") ++ task_ids = list(self.offload_kv_tasks.keys()) ++ results = self.flexkv_client.try_wait(task_ids) ++ # logger.info(f"results {results}") ++ t_async_put_end = time.monotonic() ++ for task_id, task_result in results.items(): ++ if task_result is not None: ++ request = self.offload_kv_tasks.pop(task_id) ++ t_put_async_start = self.flexkv_timer[request.request_id]["put_async_start"] ++ t_put_async_return = self.flexkv_timer[request.request_id]["put_async_return"] ++ self.flexkv_timer.pop(request.request_id) ++ logger.info( ++ f"[FlexKV] req: {request.request_id}, task: {task_id}, " ++ f"put {sum(task_result).item()} tokens cost {(t_async_put_end-t_put_async_start)*1000:.2f} ms, " ++ f"put_async() api cost {(t_put_async_return-t_put_async_start)*1000:.2f} ms") ++ self._free_block(request) ++ + def shutdown(self) -> None: + if self.kv_event_publisher: + self.kv_event_publisher.shutdown() +diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py +index 7779b559c..2d17908ea 100644 +--- a/vllm/v1/engine/core.py ++++ b/vllm/v1/engine/core.py +@@ -46,6 +46,8 @@ from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder + from vllm.v1.structured_output import StructuredOutputManager + from vllm.version import __version__ as VLLM_VERSION + ++from vllm.distributed.flexkv_extension.config import FlexKVConfig ++ + logger = init_logger(__name__) + + POLLING_TIMEOUT_S = 2.5 +@@ -118,6 +120,8 @@ class EngineCore: + log_stats=self.log_stats, + ) + ++ self.init_flexkv(vllm_config, kv_cache_config) ++ + # Setup MM Input Mapper. 
+ self.mm_input_cache_server = MirroredProcessingCache( + vllm_config.model_config) +@@ -194,6 +198,23 @@ class EngineCore: + "warmup model) took %.2f seconds"), elapsed) + return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config + ++ ++ def init_flexkv( ++ self, ++ taco_llm_config: VllmConfig, ++ kv_cache_config: KVCacheConfig ++ ): ++ self.scheduler: V1Scheduler ++ if taco_llm_config.cache_config.enable_prefix_caching: ++ flexkv_config = FlexKVConfig.from_env() ++ if flexkv_config.enable_flexkv: ++ flexkv_config.post_init( ++ kv_cache_config=kv_cache_config, ++ tp_size=taco_llm_config.parallel_config.tensor_parallel_size, ++ ) ++ dp_client_id = self.scheduler.init_flexkv(flexkv_config) ++ self.model_executor.init_flexkv(flexkv_config, dp_client_id) ++ + def add_request(self, request: EngineCoreRequest): + """Add request to the scheduler.""" + if pooling_params := request.pooling_params: +diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py +index 50b9634a4..3d7bdd4c8 100644 +--- a/vllm/v1/executor/abstract.py ++++ b/vllm/v1/executor/abstract.py +@@ -15,7 +15,7 @@ from vllm.executor.uniproc_executor import ( # noqa + UniProcExecutor as UniProcExecutorV0) + from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec + from vllm.v1.outputs import ModelRunnerOutput +- ++from vllm.distributed.flexkv_extension.config import FlexKVConfig + FailureCallback = Callable[[], None] + + +@@ -88,6 +88,10 @@ class Executor(ExecutorBase): + args=(scheduler_output, )) + return output[0] + ++ def init_flexkv(self, flexkv_config: FlexKVConfig, dp_client_id: int): ++ self.collective_rpc("init_flexkv", ++ args=(flexkv_config, dp_client_id, )) ++ + @property + def max_concurrent_batches(self) -> int: + return 1 +diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py +index 7f2556bab..e7fb79486 100644 +--- a/vllm/v1/metrics/loggers.py ++++ b/vllm/v1/metrics/loggers.py +@@ -125,7 +125,8 @@ class LoggingStatLogger(StatLoggerBase): + "Avg generation throughput: %.1f tokens/s, " + "Running: %d reqs, Waiting: %d reqs, " + "GPU KV cache usage: %.1f%%, " +- "Prefix cache hit rate: %.1f%%", ++ "Prefix cache hit rate: %.1f%%, " ++ "FlexKV hit rate: %.1f%%", + self.engine_index, + prompt_throughput, + generation_throughput, +@@ -133,6 +134,7 @@ class LoggingStatLogger(StatLoggerBase): + scheduler_stats.num_waiting_reqs, + scheduler_stats.kv_cache_usage * 100, + self.prefix_caching_metrics.hit_rate * 100, ++ self.prefix_caching_metrics.flexkv_hit_rate * 100, + ) + self.spec_decoding_logging.log(log_fn=log_fn) + +diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py +index 1eb10ccb6..1073aa571 100644 +--- a/vllm/v1/metrics/stats.py ++++ b/vllm/v1/metrics/stats.py +@@ -24,7 +24,8 @@ class PrefixCacheStats: + queries: int = 0 + # The number of hits in these requests. 
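For the log line patched above, both percentages use the same denominator (aggregated queries); toy counts below, assuming all three counters are kept in the same unit:

queries = 1024        # prefix-cache queries in the recent window
hits = 512            # served from the GPU prefix cache
flexkv_hits = 256     # additionally served from FlexKV (CPU/SSD/remote tiers)
print(f"Prefix cache hit rate: {hits / queries:.1%}")    # 50.0%
print(f"FlexKV hit rate: {flexkv_hits / queries:.1%}")   # 25.0%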
+ hits: int = 0 +- ++ # flexkv ++ flexkv_hits: int = 0 + + @dataclass + class SchedulerStats: +diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py +index a5bf197ba..d10265d0c 100644 +--- a/vllm/v1/worker/gpu_model_runner.py ++++ b/vllm/v1/worker/gpu_model_runner.py +@@ -2494,6 +2494,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): + ) == 0, "Attention backends are already initialized" + for i, kv_cache_group_spec in enumerate( + kv_cache_config.kv_cache_groups): ++ print("init attn backend ", i) + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + if isinstance(kv_cache_spec, AttentionSpec): + attn_backend_i = get_attn_backend( +diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py +index 522946351..31a3bed13 100644 +--- a/vllm/v1/worker/gpu_worker.py ++++ b/vllm/v1/worker/gpu_worker.py +@@ -18,7 +18,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, + set_custom_all_reduce) + from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, + has_kv_transfer_group) +-from vllm.distributed.parallel_state import get_pp_group, get_tp_group ++from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_tensor_model_parallel_rank + from vllm.logger import init_logger + from vllm.lora.request import LoRARequest + from vllm.model_executor import set_random_seed +@@ -33,6 +33,10 @@ from vllm.v1.utils import report_usage_stats + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + from vllm.v1.worker.worker_base import WorkerBase + ++# flexkv ++from vllm.distributed.flexkv_extension.config import FlexKVConfig ++from vllm.distributed.flexkv_extension.client import FlexKVTPClient ++ + logger = init_logger(__name__) + + if TYPE_CHECKING: +@@ -556,6 +560,23 @@ class Worker(WorkerBase): + max_size=max_size, + ) + ++ def init_flexkv( ++ self, ++ flexkv_config: FlexKVConfig, ++ dp_client_id: int, ++ ) -> None: ++ from vllm.distributed.flexkv_extension.client import FlexKVTPClient ++ layer_kv_shape = self.model_runner.attn_backends[0].get_kv_cache_shape( ++ flexkv_config.num_blocks, flexkv_config.block_size, ++ flexkv_config.num_kv_heads, flexkv_config.head_size) ++ kv_shape = (flexkv_config.num_layers, *layer_kv_shape) ++ self.flexkv_client = FlexKVTPClient(flexkv_config=flexkv_config, ++ dp_client_id=dp_client_id, ++ tp_rank=get_tensor_model_parallel_rank(), ++ device_id=self.device.index, ++ gpu_blocks=self.model_runner.kv_caches, ++ kv_shape=kv_shape) ++ + def save_tensorized_model( + self, + tensorizer_config: "TensorizerConfig", diff --git a/examples/vllm_adaption/flexkv_vllm_0_8_4.patch b/examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch similarity index 100% rename from examples/vllm_adaption/flexkv_vllm_0_8_4.patch rename to examples/vllm_adaption_legacy/flexkv_vllm_0_8_4.patch diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 669163cc83..e113aeb69b 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -18,27 +18,134 @@ from functools import partial from queue import Queue from typing import List, Tuple, Optional, Dict, Callable +from dataclasses import dataclass, field +import numpy as np +import nvtx import torch +from flexkv.c_ext import CRadixNode, CRadixTreeIndex, CMatchResult from flexkv.cache.mempool import Mempool from flexkv.cache.radixtree import RadixTreeIndex, RadixNode, MatchResult -from flexkv.cache.transfer_pattern import ( - convert_read_graph_to_layer_wise_graph, add_virtal_op_for_mutiple_finished_ops -) +from 
flexkv.cache.transfer_pattern import add_virtal_op_for_mutiple_finished_ops from flexkv.common.block import SequenceMeta from flexkv.common.config import CacheConfig, ModelConfig from flexkv.common.exceptions import InvalidConfigError, NotEnoughSpaceError from flexkv.common.transfer import ( - DeviceType, TransferOpGraph, TransferOp, TransferType, TransferDescriptor + DeviceType, TransferOpGraph, TransferOp, TransferType ) +from flexkv.common.debug import flexkv_logger +@dataclass +class MatchResultAccel: + num_ready_matched_blocks: int = 0 + num_matched_blocks: int = 0 + last_ready_node: Optional['CRadixNode'] = None + last_node: Optional['CRadixNode'] = None + last_node_matched_length: int = 0 + physical_blocks: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + + def __post_init__(self) -> None: + assert self.physical_blocks.ndim == 1 + +class CacheEngineAccel: + def __init__(self, + device_type: DeviceType, + num_total_blocks: int, + tokens_per_block: int, + evict_ratio: float): + if not isinstance(device_type, DeviceType): + raise InvalidConfigError(f"Unknown device type: {device_type}") + if num_total_blocks <= 0: + raise InvalidConfigError(f"Invalid num_total_blocks: {num_total_blocks}") + if tokens_per_block <= 0 or (tokens_per_block & (tokens_per_block - 1)) != 0: + raise InvalidConfigError(f"Invalid tokens_per_block: {tokens_per_block}, " + f"tokens_per_block must be a power of 2") + + self.device_type = device_type + + self.index = CRadixTreeIndex(tokens_per_block, num_total_blocks) + + self.mempool = Mempool(num_total_blocks=num_total_blocks) + + self.tokens_per_block = tokens_per_block + self.num_total_blocks = num_total_blocks + self.evict_ratio = evict_ratio + + def reset(self) -> None: + self.index.reset() + self.mempool.reset() + + def match(self, sequence_meta: SequenceMeta) -> MatchResultAccel: + sequence_meta.gen_hashes() + match_result = self.index.match_prefix(torch.from_numpy(sequence_meta.block_hashes).to(torch.int64), + sequence_meta.num_blocks, True) + return MatchResultAccel(match_result.num_ready_matched_blocks, match_result.num_matched_blocks, + match_result.last_ready_node, match_result.last_node, + match_result.last_node_matched_length, + torch.tensor(match_result.physical_blocks, dtype=torch.int64).numpy()) + def insert(self, + sequence_meta: SequenceMeta, + physical_block_ids: torch.Tensor, + num_insert_blocks: int = -1, + is_ready: bool = True, + match_result: Optional[MatchResultAccel] = None) -> Optional[CRadixNode]: + sequence_meta.gen_hashes() + if match_result is None: + return self.index.insert(torch.from_numpy(physical_block_ids).to(torch.int64), + torch.from_numpy(sequence_meta.block_hashes).to(torch.int64), + sequence_meta.num_blocks, + num_insert_blocks, + is_ready) + else: + return self.index.insert(torch.from_numpy(physical_block_ids).to(torch.int64), + torch.from_numpy(sequence_meta.block_hashes).to(torch.int64), + sequence_meta.num_blocks, + num_insert_blocks, + is_ready, + match_result.last_node, + match_result.num_matched_blocks, + match_result.last_node_matched_length) + + def lock_node(self, node: CRadixNode) -> None: + self.index.lock(node) + + def cleanup(self, node: CRadixNode, cleanup_length: int) -> None: + self.index.unlock(node) + self.index.set_ready(node, True, cleanup_length) + + def take(self, + num_required_blocks: int, + protected_node: Optional[CRadixNode] = None, + strict: bool = True) -> torch.Tensor: + if num_required_blocks > self.mempool.num_free_blocks: + if protected_node is not None: + 
self.index.lock(protected_node) + evict_block_num = max(num_required_blocks - self.mempool.num_free_blocks, int(self.mempool.num_total_blocks * self.evict_ratio)) + target_blocks = torch.zeros(evict_block_num, dtype=torch.int64) + num_evicted = self.index.evict(target_blocks, evict_block_num) + if num_evicted != evict_block_num: + target_blocks.resize_(num_evicted) + self.mempool.recycle_blocks(target_blocks.numpy()) + + if protected_node is not None: + self.index.unlock(protected_node) + if strict and num_required_blocks > self.mempool.num_free_blocks: + raise NotEnoughSpaceError("Not enough free blocks to take, ", + required=num_required_blocks, + available=self.mempool.num_free_blocks) + num_allocated_blocks = min(num_required_blocks, self.mempool.num_free_blocks) + return self.mempool.allocate_blocks(num_allocated_blocks) + + def recycle(self, physical_blocks: np.ndarray) -> None: + self.mempool.recycle_blocks(physical_blocks) class CacheEngine: def __init__(self, device_type: DeviceType, num_total_blocks: int, - tokens_per_block: int): + tokens_per_block: int, + evict_ratio: float): if not isinstance(device_type, DeviceType): raise InvalidConfigError(f"Unknown device type: {device_type}") if num_total_blocks <= 0: @@ -55,6 +162,7 @@ def __init__(self, self.tokens_per_block = tokens_per_block self.num_total_blocks = num_total_blocks + self.evict_ratio = evict_ratio def reset(self) -> None: self.index.reset() @@ -67,7 +175,7 @@ def match(self, sequence_meta: SequenceMeta) -> MatchResult: def insert(self, sequence_meta: SequenceMeta, - physical_block_ids: torch.Tensor, + physical_block_ids: np.ndarray, num_insert_blocks: int = -1, is_ready: bool = True, match_result: Optional[MatchResult] = None) -> Optional[RadixNode]: @@ -87,12 +195,13 @@ def cleanup(self, node: RadixNode, cleanup_length: int) -> None: def take(self, num_required_blocks: int, protected_node: Optional[RadixNode] = None, - strict: bool = True) -> torch.Tensor: + strict: bool = True) -> np.ndarray: if num_required_blocks > self.mempool.num_free_blocks: if protected_node is not None: self.index.lock(protected_node) + evict_block_num = max(num_required_blocks - self.mempool.num_free_blocks, int(self.mempool.num_total_blocks * self.evict_ratio)) self.mempool.recycle_blocks( - self.index.evict(num_required_blocks - self.mempool.num_free_blocks) + self.index.evict(evict_block_num) ) if protected_node is not None: self.index.unlock(protected_node) @@ -103,7 +212,7 @@ def take(self, num_allocated_blocks = min(num_required_blocks, self.mempool.num_free_blocks) return self.mempool.allocate_blocks(num_allocated_blocks) - def recycle(self, physical_blocks: torch.Tensor) -> None: + def recycle(self, physical_blocks: np.ndarray) -> None: self.mempool.recycle_blocks(physical_blocks) class GlobalCacheEngine: @@ -119,19 +228,40 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig): self.cache_engines = {} if cache_config.enable_cpu: - self.cpu_cache_engine = CacheEngine(DeviceType.CPU, + if cache_config.index_accel: + self.cpu_cache_engine = CacheEngineAccel(DeviceType.CPU, cache_config.num_cpu_blocks, - cache_config.tokens_per_block) + cache_config.tokens_per_block, + cache_config.evict_ratio) + else: + self.cpu_cache_engine = CacheEngine(DeviceType.CPU, + cache_config.num_cpu_blocks, + cache_config.tokens_per_block, + cache_config.evict_ratio) self.cache_engines[DeviceType.CPU] = self.cpu_cache_engine if cache_config.enable_ssd: - self.ssd_cache_engine = CacheEngine(DeviceType.SSD, + if cache_config.index_accel: + 
self.ssd_cache_engine = CacheEngineAccel(DeviceType.SSD, + cache_config.num_ssd_blocks, + cache_config.tokens_per_block, + cache_config.evict_ratio) + else: + self.ssd_cache_engine = CacheEngine(DeviceType.SSD, cache_config.num_ssd_blocks, - cache_config.tokens_per_block) + cache_config.tokens_per_block, + cache_config.evict_ratio) self.cache_engines[DeviceType.SSD] = self.ssd_cache_engine if cache_config.enable_remote: - self.remote_cache_engine = CacheEngine(DeviceType.REMOTE, + if cache_config.index_accel: + self.remote_cache_engine = CacheEngineAccel(DeviceType.REMOTE, + cache_config.num_remote_blocks, + cache_config.tokens_per_block, + cache_config.evict_ratio) + else: + self.remote_cache_engine = CacheEngine(DeviceType.REMOTE, cache_config.num_remote_blocks, - cache_config.tokens_per_block) + cache_config.tokens_per_block, + cache_config.evict_ratio) self.cache_engines[DeviceType.REMOTE] = self.remote_cache_engine self._empty_get_return: Callable[[int], Tuple[TransferOpGraph, List[int], Dict, Dict, int]] = \ @@ -149,25 +279,34 @@ def reset(self) -> None: def get(self, request_id: int, - token_ids: torch.Tensor, - token_mask: torch.Tensor, - slot_mapping: torch.Tensor, + token_ids: np.ndarray, + token_mask: np.ndarray, + slot_mapping: np.ndarray, layer_num: int = -1, layer_granularity: int = -1, - dp_id: int = 0) -> Tuple[TransferOpGraph, torch.Tensor, Callable, List[int]]: + dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, int]: self._check_input(token_ids, token_mask, slot_mapping) + if layer_num == -1: layer_num = self.model_config.num_layers if layer_granularity == -1: layer_granularity = layer_num + if layer_num != layer_granularity: + flexkv_logger.error(f"Layerwise transfer is not supported yet, " + f"layer_num: {layer_num}, layer_granularity: {layer_granularity}") + raise NotImplementedError(f"Layerwise transfer is not supported yet, " + f"layer_num: {layer_num}, layer_granularity: {layer_granularity}") + # ignore the last incomplete block aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block aligned_token_ids = token_ids[:aligned_length] - token_mask[aligned_length:] = False + token_mask = token_mask[:aligned_length] + block_start_idx, block_end_idx = self._get_block_range(token_mask) assert block_end_idx == aligned_length // self.tokens_per_block - gpu_block_mapping = self._slot_to_block_mapping(slot_mapping)[:block_end_idx-block_start_idx] + gpu_block_ids = self.slot_mapping_to_block_ids(slot_mapping, + self.tokens_per_block)[:block_end_idx-block_start_idx] sequence_meta = SequenceMeta(token_ids=aligned_token_ids, tokens_per_block=self.cache_config.tokens_per_block) @@ -179,7 +318,7 @@ def get(self, sequence_meta, block_start_idx, block_end_idx, - gpu_block_mapping, + gpu_block_ids, layer_num ) else: @@ -189,24 +328,24 @@ def get(self, sequence_meta, block_start_idx, block_end_idx, - gpu_block_mapping, + gpu_block_ids, layer_num ) - transfer_graph, finished_ops_ids = add_virtal_op_for_mutiple_finished_ops( + transfer_graph, task_end_op_id = add_virtal_op_for_mutiple_finished_ops( transfer_graph, finished_ops_ids ) - return_mask = torch.zeros_like(token_mask) + return_mask = np.zeros_like(token_mask, dtype=np.bool_) return_mask[block_start_idx* self.tokens_per_block: (block_start_idx + num_gpu_blocks_to_transfer) * self.tokens_per_block] = True - if layer_num // layer_granularity != 1: - transfer_graph, finished_ops_ids = convert_read_graph_to_layer_wise_graph(transfer_graph=transfer_graph, - finished_ops_ids=finished_ops_ids, 
- layer_num=layer_num, - layer_granularity=layer_granularity) + # if layer_num // layer_granularity != 1: + # transfer_graph, finished_ops_ids = convert_read_graph_to_layer_wise_graph(transfer_graph=transfer_graph, + # finished_ops_ids=finished_ops_ids, + # layer_num=layer_num, + # layer_granularity=layer_granularity) transfer_graph.bind_to_dp_group(dp_id) for device_type in node_to_unlock: @@ -216,14 +355,14 @@ def get(self, node_to_unlock=node_to_unlock, buffer_to_free=buffer_to_free) - return transfer_graph, return_mask, callback, finished_ops_ids + return transfer_graph, return_mask, callback, task_end_op_id def _get_impl_global(self, request_id: int, sequence_meta: SequenceMeta, block_mask_start: int, block_mask_end: int, - gpu_block_mapping: torch.Tensor, + gpu_block_ids: np.ndarray, layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int]: """ transfer pattern: @@ -252,7 +391,7 @@ def _get_impl_global(self, #early return if no blocks to transfer if fragment123_num_blocks == 0: return self._empty_get_return(request_id) - assert fragment123_num_blocks <= len(gpu_block_mapping) + assert fragment123_num_blocks <= len(gpu_block_ids) transfer_graph = TransferOpGraph() finished_ops_ids = [] @@ -263,7 +402,7 @@ def _get_impl_global(self, fragment3_num_blocks = max(len(remote_matched_blocks) - fragment12_num_blocks, 0) fragment23_num_blocks = fragment2_num_blocks + fragment3_num_blocks - fragment123_gpu_blocks = gpu_block_mapping[:fragment123_num_blocks] + fragment123_gpu_blocks = gpu_block_ids[:fragment123_num_blocks] fragment123_cpu_blocks = cpu_matched_blocks fragment2_ssd_blocks = ssd_matched_blocks[-fragment2_num_blocks:] fragment3_remote_blocks = remote_matched_blocks[-fragment3_num_blocks:] @@ -271,7 +410,7 @@ def _get_impl_global(self, cpu_node_to_unlock = cpu_matched_result.last_ready_node ssd_node_to_unlock = ssd_matched_result.last_ready_node remote_node_to_unlock = remote_matched_result.last_ready_node - cpu_blocks_to_free = torch.tensor([], dtype=torch.int64) + cpu_blocks_to_free = np.array([], dtype=np.int64) if fragment23_num_blocks > 0: num_extra_required_blocks = fragment23_num_blocks @@ -283,7 +422,7 @@ def _get_impl_global(self, if len(fragment23_cpu_blocks) < num_extra_required_blocks: self.cpu_cache_engine.recycle(fragment23_cpu_blocks) return self._empty_get_return(request_id) - fragment123_cpu_blocks = torch.cat([fragment123_cpu_blocks, fragment23_cpu_blocks]) + fragment123_cpu_blocks = np.concatenate([fragment123_cpu_blocks, fragment23_cpu_blocks]) # we only insert the buffer blocks to cpu cache engine only: # 1. the cpu cache engine satisfies prefix cache after insertion # 2. 
the sequence is all ready blocks @@ -302,14 +441,8 @@ def _get_impl_global(self, op_disk2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.DISK2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment2_ssd_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment123_cpu_blocks[fragment1_num_blocks:fragment12_num_blocks] - ), + src_block_ids = fragment2_ssd_blocks, + dst_block_ids = fragment123_cpu_blocks[fragment1_num_blocks:fragment12_num_blocks], layer_id = 0, layer_granularity = layer_num ) @@ -320,14 +453,8 @@ def _get_impl_global(self, op_remote2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.REMOTE2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.REMOTE, - physical_block_ids=fragment3_remote_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment123_cpu_blocks[-fragment3_num_blocks:], - ), + src_block_ids = fragment3_remote_blocks, + dst_block_ids = fragment123_cpu_blocks[-fragment3_num_blocks:], layer_id = 0, layer_granularity = layer_num ) @@ -353,14 +480,8 @@ def _get_impl_global(self, op_h2disk = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment123_cpu_blocks[-fragment3_num_blocks:], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment3_ssd_blocks, - ), + src_block_ids = fragment123_cpu_blocks[-fragment3_num_blocks:], + dst_block_ids = fragment3_ssd_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -377,15 +498,8 @@ def _get_impl_global(self, op_h2d = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment123_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment123_gpu_blocks, - device_id = 0 - ), + src_block_ids = fragment123_cpu_blocks, + dst_block_ids = fragment123_gpu_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -414,7 +528,7 @@ def _get_impl_local(self, sequence_meta: SequenceMeta, block_mask_start: int, block_mask_end: int, - gpu_block_mapping: torch.Tensor, + gpu_block_ids: np.ndarray, layer_num: int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int]: """ transfer pattern: @@ -429,7 +543,10 @@ def _get_impl_local(self, assert self.cache_config.enable_cpu assert self.cpu_cache_engine is not None - cpu_matched_result, ssd_matched_result = self.match_local(sequence_meta) + if self.cache_config.index_accel: + cpu_matched_result, ssd_matched_result = self.match_local_accel(sequence_meta) + else: + cpu_matched_result, ssd_matched_result = self.match_local(sequence_meta) # tailor the blocks to assure: # the blocks are needed by the mask & the blocks are ready @@ -445,12 +562,12 @@ def _get_impl_local(self, #early return if no blocks to transfer if fragment12_num_blocks == 0: return self._empty_get_return(request_id) - assert fragment12_num_blocks <= len(gpu_block_mapping) + assert fragment12_num_blocks <= len(gpu_block_ids) transfer_graph = TransferOpGraph() finished_ops_ids = [] - fragment12_gpu_blocks = gpu_block_mapping[:fragment12_num_blocks] + fragment12_gpu_blocks = gpu_block_ids[:fragment12_num_blocks] fragment2_ssd_blocks = 
ssd_matched_blocks[-fragment2_num_blocks:] fragment1_cpu_blocks = cpu_matched_blocks[:fragment1_num_blocks] @@ -458,7 +575,9 @@ def _get_impl_local(self, ssd_node_to_unlock = ssd_matched_result.last_ready_node # prepare cpu blocks to transfer - cpu_blocks_to_free = torch.tensor([], dtype=torch.int64) + cpu_blocks_to_free = np.array([], dtype=np.int64) + op_disk2h = None + fragment2_cpu_blocks = None if fragment2_num_blocks > 0: fragment2_cpu_blocks = self.cpu_cache_engine.take( num_required_blocks=fragment2_num_blocks, @@ -474,39 +593,12 @@ def _get_impl_local(self, op_disk2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.DISK2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment2_ssd_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment2_cpu_blocks, - ), + src_block_ids = fragment2_ssd_blocks, + dst_block_ids = fragment2_cpu_blocks, layer_id = 0, layer_granularity = layer_num ) transfer_graph.add_transfer_op(op_disk2h) - - op_h2d_frag2 = TransferOp( - graph_id = transfer_graph.graph_id, - transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment2_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment12_gpu_blocks[-fragment2_num_blocks:], - device_id = 0 - ), - layer_id = 0, - layer_granularity = layer_num - ) - transfer_graph.add_transfer_op(op_h2d_frag2) - - transfer_graph.add_dependency(op_h2d_frag2.op_id, op_disk2h.op_id) - finished_ops_ids.append(op_h2d_frag2.op_id) - # we only insert the buffer blocks to cpu cache engine only: # 1. the cpu cache engine satisfies prefix cache after insertion # 2. 
the sequence is all ready blocks @@ -520,23 +612,22 @@ def _get_impl_local(self, match_result=cpu_matched_result) else: cpu_blocks_to_free = fragment2_cpu_blocks - op_h2d_frag1 = TransferOp( + if fragment2_cpu_blocks is not None: + fragment12_cpu_blocks = np.concatenate([fragment1_cpu_blocks, fragment2_cpu_blocks]) + else: + fragment12_cpu_blocks = fragment1_cpu_blocks + op_h2d = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment1_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment12_gpu_blocks[:fragment1_num_blocks], - device_id = 0 - ), + src_block_ids = fragment12_cpu_blocks, + dst_block_ids = fragment12_gpu_blocks, layer_id = 0, layer_granularity = layer_num ) - transfer_graph.add_transfer_op(op_h2d_frag1) - finished_ops_ids.append(op_h2d_frag1.op_id) + transfer_graph.add_transfer_op(op_h2d) + if op_disk2h is not None: + transfer_graph.add_dependency(op_h2d.op_id, op_disk2h.op_id) + finished_ops_ids.append(op_h2d.op_id) node_to_unlock = {} if cpu_node_to_unlock is not None: @@ -549,11 +640,11 @@ def _get_impl_local(self, def put(self, request_id: int, - token_ids: torch.Tensor, - token_mask: torch.Tensor, - slot_mapping: torch.Tensor, + token_ids: np.ndarray, + token_mask: np.ndarray, + slot_mapping: np.ndarray, layer_num : int = -1, - dp_id: int = 0) -> Tuple[TransferOpGraph, torch.Tensor, Callable, List[int]]: + dp_id: int = 0) -> Tuple[TransferOpGraph, np.ndarray, Callable, int]: self._check_input(token_ids, token_mask, slot_mapping) if layer_num == -1: @@ -561,13 +652,14 @@ def put(self, # ignore the last incomplete block aligned_length = (token_ids.shape[0] // self.tokens_per_block) * self.tokens_per_block aligned_token_ids = token_ids[:aligned_length] - token_mask[aligned_length:] = False + token_mask = token_mask[:aligned_length] block_start_idx, block_end_idx = self._get_block_range(token_mask) # the mask should has a prefix of True assert block_start_idx == 0 - gpu_block_mapping = self._slot_to_block_mapping(slot_mapping)[:block_end_idx-block_start_idx] + gpu_block_ids = self.slot_mapping_to_block_ids(slot_mapping, + self.tokens_per_block)[:block_end_idx-block_start_idx] sequence_meta = SequenceMeta(token_ids=aligned_token_ids, tokens_per_block=self.cache_config.tokens_per_block) @@ -580,7 +672,7 @@ def put(self, sequence_meta, block_start_idx, block_end_idx, - gpu_block_mapping, + gpu_block_ids, layer_num ) else: @@ -591,16 +683,16 @@ def put(self, sequence_meta, block_start_idx, block_end_idx, - gpu_block_mapping, + gpu_block_ids, layer_num ) - transfer_graph, finished_ops_ids = add_virtal_op_for_mutiple_finished_ops( + transfer_graph, task_end_op_id = add_virtal_op_for_mutiple_finished_ops( transfer_graph, finished_ops_ids ) - return_mask = torch.zeros_like(token_mask) + return_mask = np.zeros_like(token_mask, dtype=np.bool_) return_mask[(block_start_idx + skipped_gpu_blocks)* self.tokens_per_block: (block_start_idx + skipped_gpu_blocks + num_gpu_blocks_to_transfer) * self.tokens_per_block] = True transfer_graph.bind_to_dp_group(dp_id) @@ -612,14 +704,14 @@ def put(self, node_to_unlock=node_to_unlock, buffer_to_free=buffer_to_free) - return transfer_graph, return_mask, callback, finished_ops_ids + return transfer_graph, return_mask, callback, task_end_op_id def _put_impl_global(self, request_id: int, sequence_meta: SequenceMeta, block_mask_start: int, block_mask_end: int, - 
gpu_block_mapping: torch.Tensor, + gpu_block_ids: np.ndarray, layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]: """ transfer pattern: @@ -639,7 +731,10 @@ def _put_impl_global(self, assert self.cpu_cache_engine is not None assert self.remote_cache_engine is not None - cpu_matched_result, ssd_matched_result, remote_matched_result = self.match_all(sequence_meta) + if self.cache_config.index_accel: + cpu_matched_result, ssd_matched_result, remote_matched_result = self.match_all_accel(sequence_meta) + else: + cpu_matched_result, ssd_matched_result, remote_matched_result = self.match_all(sequence_meta) cpu_matched_blocks = cpu_matched_result.physical_blocks[ :cpu_matched_result.num_matched_blocks][block_mask_start:block_mask_end] ssd_matched_blocks = ssd_matched_result.physical_blocks[ @@ -648,15 +743,15 @@ def _put_impl_global(self, :remote_matched_result.num_matched_blocks][block_mask_start:block_mask_end] num_skipped_blocks = len(cpu_matched_blocks) - fragment12_num_blocks = len(gpu_block_mapping) - num_skipped_blocks + fragment12_num_blocks = len(gpu_block_ids) - num_skipped_blocks if fragment12_num_blocks == 0: return self._empty_put_return(request_id) - fragment2_num_blocks = len(gpu_block_mapping) - len(ssd_matched_blocks) + fragment2_num_blocks = len(gpu_block_ids) - len(ssd_matched_blocks) if not self.cache_config.enable_ssd: fragment2_num_blocks = 0 - fragment3_num_blocks = len(gpu_block_mapping) - len(remote_matched_blocks) + fragment3_num_blocks = len(gpu_block_ids) - len(remote_matched_blocks) - fragment12_gpu_blocks = gpu_block_mapping[num_skipped_blocks:] + fragment12_gpu_blocks = gpu_block_ids[num_skipped_blocks:] fragment12_cpu_blocks = self.cpu_cache_engine.take( num_required_blocks=fragment12_num_blocks, @@ -678,7 +773,7 @@ def _put_impl_global(self, else: self.ssd_cache_engine.recycle(fragment2_ssd_blocks) else: - fragment2_ssd_blocks = torch.tensor([], dtype=torch.int64) + fragment2_ssd_blocks = np.array([], dtype=np.int64) put_to_remote = False if fragment3_num_blocks > 0: fragment3_remote_blocks = self.remote_cache_engine.take( @@ -691,7 +786,7 @@ def _put_impl_global(self, else: self.remote_cache_engine.recycle(fragment3_remote_blocks) else: - fragment3_remote_blocks = torch.tensor([], dtype=torch.int64) + fragment3_remote_blocks = np.array([], dtype=np.int64) transfer_graph = TransferOpGraph() finished_ops_ids = [] @@ -699,14 +794,8 @@ def _put_impl_global(self, op_d2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.D2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment12_gpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment12_cpu_blocks, - ), + src_block_ids = fragment12_gpu_blocks, + dst_block_ids = fragment12_cpu_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -718,14 +807,8 @@ def _put_impl_global(self, op_h2disk = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment2_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment2_ssd_blocks, - ), + src_block_ids = fragment2_cpu_blocks, + dst_block_ids = fragment2_ssd_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -736,21 +819,15 @@ def _put_impl_global(self, if put_to_remote: if fragment3_num_blocks > fragment12_num_blocks: 
extra_num_cpu_blocks = fragment3_num_blocks - fragment12_num_blocks - fragment3_cpu_blocks = torch.cat([fragment12_cpu_blocks, + fragment3_cpu_blocks = np.concatenate([fragment12_cpu_blocks, cpu_matched_blocks[-extra_num_cpu_blocks:]]) else: fragment3_cpu_blocks = fragment12_cpu_blocks[-fragment3_num_blocks:] op_h2remote = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2REMOTE, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment3_cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.REMOTE, - physical_block_ids=fragment3_remote_blocks, - ), + src_block_ids = fragment3_cpu_blocks, + dst_block_ids = fragment3_remote_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -789,7 +866,7 @@ def _put_impl_local(self, sequence_meta: SequenceMeta, block_mask_start: int, block_mask_end: int, - gpu_block_mapping: torch.Tensor, + gpu_block_ids: np.ndarray, layer_num : int) -> Tuple[TransferOpGraph, List[int], Dict, Dict, int, int]: """ transfer pattern: @@ -805,21 +882,24 @@ def _put_impl_local(self, assert self.cpu_cache_engine is not None # assert self.ssd_cache_engine is not None - cpu_matched_result, ssd_matched_result = self.match_local(sequence_meta) + if self.cache_config.index_accel: + cpu_matched_result, ssd_matched_result = self.match_local_accel(sequence_meta) + else: + cpu_matched_result, ssd_matched_result = self.match_local(sequence_meta) cpu_matched_blocks = cpu_matched_result.physical_blocks[ :cpu_matched_result.num_matched_blocks][block_mask_start:block_mask_end] ssd_matched_blocks = ssd_matched_result.physical_blocks[ :ssd_matched_result.num_matched_blocks][block_mask_start:block_mask_end] num_skipped_blocks = len(cpu_matched_blocks) - fragment12_num_blocks = len(gpu_block_mapping) - num_skipped_blocks + fragment12_num_blocks = len(gpu_block_ids) - num_skipped_blocks if fragment12_num_blocks == 0: return self._empty_put_return(request_id) - fragment2_num_blocks = len(gpu_block_mapping) - len(ssd_matched_blocks) + fragment2_num_blocks = len(gpu_block_ids) - len(ssd_matched_blocks) if not self.cache_config.enable_ssd: fragment2_num_blocks = 0 - fragment12_gpu_blocks = gpu_block_mapping[num_skipped_blocks:] + fragment12_gpu_blocks = gpu_block_ids[num_skipped_blocks:] fragment12_cpu_blocks = self.cpu_cache_engine.take( num_required_blocks=fragment12_num_blocks, @@ -833,7 +913,7 @@ def _put_impl_local(self, strict=False ) else: - fragment2_ssd_blocks = torch.tensor([], dtype=torch.int64) + fragment2_ssd_blocks = np.array([], dtype=np.int64) if len(fragment12_cpu_blocks) < fragment12_num_blocks or \ len(fragment2_ssd_blocks) < fragment2_num_blocks: self.cpu_cache_engine.recycle(fragment12_cpu_blocks) @@ -847,14 +927,8 @@ def _put_impl_local(self, op_d2h = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.D2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=fragment12_gpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment12_cpu_blocks, - ), + src_block_ids = fragment12_gpu_blocks, + dst_block_ids = fragment12_cpu_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -866,14 +940,8 @@ def _put_impl_local(self, op_h2disk = TransferOp( graph_id = transfer_graph.graph_id, transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=fragment2_cpu_blocks, - ), - dst_descriptor = 
TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=fragment2_ssd_blocks, - ), + src_block_ids = fragment2_cpu_blocks, + dst_block_ids = fragment2_ssd_blocks, layer_id = 0, layer_granularity = layer_num ) @@ -903,7 +971,7 @@ def _put_impl_local(self, def _transfer_callback(self, node_to_unlock: Dict[DeviceType, Tuple[RadixNode, int]], - buffer_to_free: Optional[Dict[DeviceType, torch.Tensor]] = None) -> None: + buffer_to_free: Optional[Dict[DeviceType, np.ndarray]] = None) -> None: if DeviceType.CPU in node_to_unlock: assert self.cpu_cache_engine is not None self.cpu_cache_engine.cleanup(node_to_unlock[DeviceType.CPU][0], node_to_unlock[DeviceType.CPU][1]) @@ -924,6 +992,18 @@ def _transfer_callback(self, assert self.remote_cache_engine is not None self.remote_cache_engine.recycle(buffer_to_free[DeviceType.REMOTE]) + @nvtx.annotate("Match Prefix Accel", color="yellow") + def match_local_accel(self, sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel]: + cpu_matched_result = MatchResultAccel() + ssd_matched_result = MatchResultAccel() + if self.cpu_cache_engine: + cpu_matched_result = self.cpu_cache_engine.match(sequence_meta) + if self.ssd_cache_engine: + ssd_matched_result = self.ssd_cache_engine.match(sequence_meta) + + return cpu_matched_result, ssd_matched_result + + @nvtx.annotate("Match Prefix", color="yellow") def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResult]: cpu_matched_result = MatchResult() ssd_matched_result = MatchResult() @@ -933,7 +1013,24 @@ def match_local(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchRe ssd_matched_result = self.ssd_cache_engine.match(sequence_meta) return cpu_matched_result, ssd_matched_result + + @nvtx.annotate("Match All Prefix accel", color="yellow") + def match_all_accel(self, + sequence_meta: SequenceMeta) -> Tuple[MatchResultAccel, MatchResultAccel, MatchResultAccel]: + cpu_matched_result = MatchResultAccel() + ssd_matched_result = MatchResultAccel() + remote_matched_result = MatchResultAccel() + # TODO: avoid redundant match? 
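The numpy helper slot_mapping_to_block_ids defined further down in this file recovers block ids by sampling every tokens_per_block-th slot; a small worked example with assumed values:

import numpy as np

tokens_per_block = 4
slot_mapping = np.array([8, 9, 10, 11, 20, 21, 22, 23], dtype=np.int64)  # two full blocks
block_ids = slot_mapping[::tokens_per_block] // tokens_per_block
# array([2, 5]): the first slot of each block divided by the block size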
+ if self.cpu_cache_engine: + cpu_matched_result = self.cpu_cache_engine.match(sequence_meta) + if self.ssd_cache_engine: + ssd_matched_result = self.ssd_cache_engine.match(sequence_meta) + if self.remote_cache_engine: + remote_matched_result = self.remote_cache_engine.match(sequence_meta) + return cpu_matched_result, ssd_matched_result, remote_matched_result + + @nvtx.annotate("Match All Prefix", color="yellow") def match_all(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResult, MatchResult]: cpu_matched_result = MatchResult() ssd_matched_result = MatchResult() @@ -949,25 +1046,29 @@ def match_all(self, sequence_meta: SequenceMeta) -> Tuple[MatchResult, MatchResu return cpu_matched_result, ssd_matched_result, remote_matched_result def _check_input(self, - token_ids: torch.Tensor, - token_mask: torch.Tensor, - slot_mapping: torch.Tensor) -> None: + token_ids: np.ndarray, + token_mask: np.ndarray, + slot_mapping: np.ndarray) -> None: + assert token_ids.dtype == np.int64 + # assert token_mask.dtype == np.bool_, f"token_mask.dtype={token_mask.dtype}" + assert slot_mapping.dtype == np.int64 assert token_ids.ndim == 1 assert token_mask.ndim == 1 assert slot_mapping.ndim == 1 - assert len(token_ids) == len(token_mask), f"len(token_ids)={len(token_ids)}, len(token_mask)={len(token_mask)}" - assert len(slot_mapping) == token_mask.sum().item(), f"len(slot_mapping)={len(slot_mapping)}, token_mask.sum().item()={token_mask.sum().item()}" + assert token_ids.size == token_mask.size, f"token_ids.size={token_ids.size}, token_mask.size={token_mask.size}" + assert slot_mapping.size == token_mask.sum(), \ + f"slot_mapping.size={slot_mapping.size}, token_mask.sum()={token_mask.sum()}" - def _slot_to_block_mapping(self, - slot_mapping: torch.Tensor) -> torch.Tensor: - block_mapping: torch.Tensor = slot_mapping[::self.tokens_per_block] // self.tokens_per_block - return block_mapping + @staticmethod + def slot_mapping_to_block_ids(slot_mapping: np.ndarray, tokens_per_block: int) -> np.ndarray: + block_ids: np.ndarray = slot_mapping[::tokens_per_block] // tokens_per_block + return block_ids def _get_block_range(self, - token_mask: torch.Tensor) -> Tuple[int, int]: - mask_idx = torch.where(token_mask)[0] + token_mask: np.ndarray) -> Tuple[int, int]: + mask_idx = np.where(token_mask)[0] if len(mask_idx) == 0: - return 0, 0 - start_idx = int(mask_idx[0].item() // self.tokens_per_block) - end_idx = int(mask_idx[-1].item() // self.tokens_per_block) + return len(token_mask)//self.tokens_per_block, len(token_mask)//self.tokens_per_block + start_idx = mask_idx[0].item() // self.tokens_per_block + end_idx = mask_idx[-1].item() // self.tokens_per_block return start_idx, end_idx + 1 diff --git a/flexkv/cache/mempool.py b/flexkv/cache/mempool.py index d0e6848a30..d00077181f 100644 --- a/flexkv/cache/mempool.py +++ b/flexkv/cache/mempool.py @@ -1,7 +1,7 @@ from collections import deque from typing import List -import torch +import numpy as np from flexkv.common.exceptions import NotEnoughSpaceError @@ -14,18 +14,18 @@ def __init__( assert num_total_blocks > 0 self.num_total_blocks = num_total_blocks - self._free_mask = torch.ones(self.num_total_blocks, dtype=torch.bool) + self._free_mask = np.ones(self.num_total_blocks, dtype=np.bool_) self._num_free = num_total_blocks - self._free_ids = self._free_mask.nonzero().squeeze(1) + self._free_ids = self._free_mask.nonzero()[0] self._free_ids_offset = 0 def reset(self) -> None: - self._free_mask.fill_(True) + self._free_mask.fill(True) self._num_free = 
self.num_total_blocks - self._free_ids = self._free_mask.nonzero().squeeze(1) + self._free_ids = self._free_mask.nonzero()[0] self._free_ids_offset = 0 - def allocate_blocks(self, num: int) -> torch.Tensor: + def allocate_blocks(self, num: int) -> np.ndarray: if num < 0: raise ValueError(f"num must be greater than 0, but got {num}") if num > self._num_free: @@ -41,8 +41,8 @@ def allocate_blocks(self, num: int) -> torch.Tensor: self._num_free -= num return free_ids - def recycle_blocks(self, block_ids: torch.Tensor) -> None: - if block_ids.ndim != 1 or block_ids.dtype != torch.int64: + def recycle_blocks(self, block_ids: np.ndarray) -> None: + if block_ids.ndim != 1 or block_ids.dtype != np.int64: raise ValueError("block_ids must be a 1D tensor of int64") if self._free_mask[block_ids].any(): free_ids = block_ids[self._free_mask[block_ids]] @@ -51,7 +51,7 @@ def recycle_blocks(self, block_ids: torch.Tensor) -> None: self._num_free += len(block_ids) def _update_free_ids(self) -> None: - self._free_ids = self._free_mask.nonzero().squeeze(1) + self._free_ids = self._free_mask.nonzero()[0] self._free_ids_offset = 0 @property diff --git a/flexkv/cache/radixtree.py b/flexkv/cache/radixtree.py index 818e778a63..a9f69c6a35 100644 --- a/flexkv/cache/radixtree.py +++ b/flexkv/cache/radixtree.py @@ -32,11 +32,14 @@ class MatchResult: last_ready_node: Optional['RadixNode'] = None last_node: Optional['RadixNode'] = None last_node_matched_length: int = 0 - physical_blocks: torch.Tensor = torch.empty(0, dtype=torch.int64) + physical_blocks: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) def __post_init__(self) -> None: assert self.physical_blocks.ndim == 1 - assert self.physical_blocks.dtype == torch.int64 + assert self.physical_blocks.dtype == np.int64 + + def is_empty(self) -> bool: + return self.num_matched_blocks == 0 @dataclass class RadixNode: @@ -192,7 +195,7 @@ def match_prefix(self, last_ready_node=last_ready_node, last_node=current_node, last_node_matched_length=last_node_matched_length, - physical_blocks=torch.from_numpy(physical_blocks).to(torch.int64)) + physical_blocks=physical_blocks) def num_matched_blocks(self, sequence: SequenceMeta) -> int: @@ -201,7 +204,7 @@ def num_matched_blocks(self, def insert(self, sequence_meta: SequenceMeta, - physical_block_ids: torch.Tensor, + physical_block_ids: np.ndarray, num_insert_blocks: int = -1, is_ready: bool = True, match_result: Optional[MatchResult] = None) -> Optional[RadixNode]: @@ -210,7 +213,7 @@ def insert(self, assert 0 <= num_insert_blocks <= sequence_meta.num_blocks assert physical_block_ids.ndim == 1 - assert physical_block_ids.dtype == torch.int64 + assert physical_block_ids.dtype == np.int64 sequence_meta.gen_hashes() if match_result is None: @@ -232,7 +235,7 @@ def insert(self, new_node = RadixNode( block_hashes=sequence_meta.block_hashes[num_matched_blocks:num_insert_blocks], - physical_blocks=physical_block_ids.numpy(), + physical_blocks=physical_block_ids, is_ready=is_ready, lock_cnt=0, last_access_time=time.time() @@ -255,7 +258,7 @@ def insert(self, return new_node - def evict(self, num_evicted: int) -> torch.Tensor: + def evict(self, num_evicted: int) -> np.ndarray: candidates = [] for node in self.leaf_nodes.values(): if node.evictable(): @@ -277,7 +280,7 @@ def evict(self, num_evicted: int) -> torch.Tensor: physical_blocks = node.physical_blocks node.parent = None evicted_blocks = np.concatenate([evicted_blocks, physical_blocks]) - return torch.from_numpy(evicted_blocks).to(torch.int64) + return 
evicted_blocks def lock(self, node: RadixNode) -> None: if node.lock_cnt < 0: @@ -340,22 +343,22 @@ def total_unready_blocks(self) -> int: index = RadixTreeIndex(tokens_per_block=tokens_per_block) print(f"init index, tokens_per_block = {tokens_per_block}") - token_ids1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) - token_ids2 = torch.tensor([1, 2, 3, 4, 15, 16, 17, 18]) + token_ids1 = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int64) + token_ids2 = np.array([1, 2, 3, 4, 15, 16, 17, 18], dtype=np.int64) seq1 = SequenceMeta(token_ids=token_ids1, tokens_per_block=tokens_per_block) seq2 = SequenceMeta(token_ids=token_ids2, tokens_per_block=tokens_per_block) - index.insert(seq1, torch.tensor([0, 1, 2, 3], dtype=torch.int64), is_ready=True) + index.insert(seq1, np.array([0, 1, 2, 3], dtype=np.int64), is_ready=True) print(f"insert seq1 = {seq1.token_ids}, " f"total cached blocks = {index.total_cached_blocks()}") seq2_matched_blocks = index.num_matched_blocks(seq2) assert seq2_matched_blocks == 2 - index.insert(seq2, torch.tensor([8, 9], dtype=torch.int64), is_ready=True) + index.insert(seq2, np.array([8, 9], dtype=np.int64), is_ready=True) print(f"insert seq2 = {seq2.token_ids}, " f"total cached blocks = {index.total_cached_blocks()}") - seq3 = SequenceMeta(token_ids=torch.tensor([1,2,3,4,0,0]), + seq3 = SequenceMeta(token_ids=np.array([1,2,3,4,0,0], dtype=np.int64), tokens_per_block=tokens_per_block) match_result = index.num_matched_blocks(seq3) print(f"match {seq3.token_ids}, num cached blocks: {match_result}") diff --git a/flexkv/cache/transfer_pattern.py b/flexkv/cache/transfer_pattern.py index 1246d248e8..fcf207408b 100644 --- a/flexkv/cache/transfer_pattern.py +++ b/flexkv/cache/transfer_pattern.py @@ -1,371 +1,33 @@ from typing import List, Optional, Tuple +import numpy as np import torch -from flexkv.common.transfer import DeviceType, TransferType -from flexkv.common.transfer import TransferOp, TransferOpGraph, TransferDescriptor +from flexkv.common.transfer import TransferType +from flexkv.common.transfer import TransferOp, TransferOpGraph def add_virtal_op_for_mutiple_finished_ops( graph: TransferOpGraph, finished_ops_ids: List[int] -)->Tuple[TransferOpGraph, List[int]]: - if len(finished_ops_ids) <= 1: - return graph, finished_ops_ids +)->Tuple[TransferOpGraph, int]: + if len(finished_ops_ids) == 0: + return graph, -1 + elif len(finished_ops_ids) == 1: + return graph, finished_ops_ids[0] else: op = TransferOp( graph_id = graph.graph_id, transfer_type = TransferType.VIRTUAL, + src_block_ids = np.array([], dtype=np.int64), + dst_block_ids = np.array([], dtype=np.int64), layer_id = -1, layer_granularity = -1, ) graph.add_transfer_op(op) for op_id in finished_ops_ids: graph.add_dependency(op.op_id, op_id) - return graph, [op.op_id] - -def create_read_graph_cpu_storage( - gpu_blocks: torch.Tensor, - cpu_blocks: torch.Tensor, - ssd_blocks: torch.Tensor, - gpu_device_id: int = 0, - layer_num: int = 1, - graph: Optional[TransferOpGraph] = None, -)->Tuple[TransferOpGraph, List[int]]: - """ - Create a read transfer graph with (REMOTE_STORAGE / SSD)->CPU->GPU operations - ssd_blocks: the blocks of ssd that are used as a lower-level storage backend, - including ssd or remote storage. This can be empty, which means cpu-only kvcache. 
- Returns: - graph: TransferOpGraph - ops_to_be_tracked: List[int]: a list of transfer ops that can indicate - the completion of some key operations - """ - assert len(gpu_blocks) == len(cpu_blocks) - if graph is None: - graph = TransferOpGraph() - assert len(gpu_blocks) > 0 - if len(ssd_blocks) == 0: - op = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks, - device_id = gpu_device_id - ), - layer_id = 0, - layer_granularity = layer_num, - ) - graph.add_transfer_op(op) - return graph, [op.op_id] - elif len(ssd_blocks) < len(cpu_blocks): - task_end_ops_ids = [] - if len(ssd_blocks) > 0: - op1 = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.DISK2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=ssd_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(ssd_blocks):] - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op1) - op2 = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(ssd_blocks):] - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks[-len(ssd_blocks):], - device_id = gpu_device_id - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op2) - graph.add_dependency(op2.op_id, op1.op_id) - task_end_ops_ids.append(op2.op_id) - op3 = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[:len(cpu_blocks) - len(ssd_blocks)] - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks[:len(cpu_blocks) - len(ssd_blocks)], - device_id = gpu_device_id - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op3) - task_end_ops_ids.append(op3.op_id) - return graph, task_end_ops_ids - else: - op1 = TransferOp( - graph_id = graph.graph_id, - transfer_type=TransferType.DISK2H, - src_descriptor=TransferDescriptor( - device_type=DeviceType.SSD, - physical_block_ids=ssd_blocks, - ), - dst_descriptor=TransferDescriptor( - device_type=DeviceType.CPU, - physical_block_ids=cpu_blocks, - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op1) - op2 = TransferOp( - graph_id = graph.graph_id, - transfer_type=TransferType.H2D, - src_descriptor=TransferDescriptor( - device_type=DeviceType.CPU, - physical_block_ids=cpu_blocks, - ), - dst_descriptor=TransferDescriptor( - device_type=DeviceType.GPU, - physical_block_ids=gpu_blocks, - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op2) - graph.add_dependency(op2.op_id, op1.op_id) - return graph, [op2.op_id] - -def create_read_graph_cpu_ssd_remote( - gpu_blocks: torch.Tensor, - cpu_blocks: torch.Tensor, - ssd_blocks: torch.Tensor, - remote_blocks: torch.Tensor, - gpu_device_id: int = 0, - layer_num: int = 1, - write_back_to_ssd: bool = True, -)->Tuple[TransferOpGraph, List[int]]: - """ - Create a read transfer graph with (REMOTE_STORAGE + SSD)->CPU->GPU operations - Returns: - graph: TransferOpGraph - 
finished_ops_ids: List[int]: a list of transfer ops that can indicate - the completion of each layer or each layer for each tp rank - """ - graph = TransferOpGraph() - finished_ops_ids: List[int] = [] - if len(remote_blocks) == 0: - graph, finished_ops_ids = create_read_graph_cpu_storage(gpu_blocks=gpu_blocks, - cpu_blocks=cpu_blocks, - ssd_blocks=ssd_blocks, - gpu_device_id=gpu_device_id, - layer_num=layer_num, - graph=graph) - if len(finished_ops_ids) > 0: - graph, finished_ops_ids = add_virtal_op_for_mutiple_finished_ops(graph, finished_ops_ids) - assert len(finished_ops_ids) > 0 - return graph, finished_ops_ids - else: - if len(remote_blocks) < len(gpu_blocks): - graph, finished_ops_ids = create_read_graph_cpu_storage(gpu_blocks=gpu_blocks[:-len(remote_blocks)], - cpu_blocks=cpu_blocks[:-len(remote_blocks)], - ssd_blocks=ssd_blocks[:-len(remote_blocks)], - gpu_device_id=gpu_device_id, - layer_num=layer_num, - graph=graph) - op_r2h = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.REMOTE2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.REMOTE, - physical_block_ids=remote_blocks, - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(remote_blocks):], - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_r2h) - op_h2d = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2D, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(remote_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks[-len(remote_blocks):], - device_id = gpu_device_id - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_h2d) - graph.add_dependency(op_h2d.op_id, op_r2h.op_id) - if write_back_to_ssd: - op_h2disk = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(remote_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=ssd_blocks[-len(remote_blocks):], - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_h2disk) - graph.add_dependency(op_h2disk.op_id, op_r2h.op_id) - finished_ops_ids.append(op_h2d.op_id) - if len(finished_ops_ids) > 0: - graph, finished_ops_ids = add_virtal_op_for_mutiple_finished_ops(graph, finished_ops_ids) - return graph, finished_ops_ids - -def create_write_graph_cpu_storage( - gpu_blocks: torch.Tensor, - cpu_blocks: torch.Tensor, - ssd_blocks: torch.Tensor, - gpu_device_id: int = 0, - layer_num: int = 1, - graph: Optional[TransferOpGraph] = None, -)->Tuple[TransferOpGraph, List[int]]: - """ - Create a write transfer graph with CPU->REMOTE_STORAGE / SSD operations - ssd_blocks: the blocks of ssd that are used as a lower-level storage backend, - including ssd or remote storage. This can be empty, which means cpu-only kvcache. - Write op granularity is larger: gpu->cpu is put into the same op. 
- Returns: - graph: TransferOpGraph - layer_wise_ops: List[int]: a list of transfer ops that can indicate - the completion of each layer or each layer for each tp rank - """ - if graph is None: - graph = TransferOpGraph() - op_d2h = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.D2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks, - device_id = gpu_device_id - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(gpu_blocks):], - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_d2h) - if len(ssd_blocks) == 0: - return graph, [op_d2h.op_id] - else: - op_h2disk = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(ssd_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=ssd_blocks, - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_h2disk) - graph.add_dependency(op_h2disk.op_id, op_d2h.op_id) - return graph, [op_d2h.op_id] - -def create_write_graph_cpu_ssd_remote( - gpu_blocks: torch.Tensor, - cpu_blocks: torch.Tensor, - ssd_blocks: torch.Tensor, - remote_blocks: torch.Tensor, - gpu_device_id: int = 0, - layer_num: int = 1, -)->Tuple[TransferOpGraph, List[int]]: - """ - Create a write transfer graph with CPU->REMOTE_STORAGE + SSD operations - Returns: - graph: TransferOpGraph - layer_wise_ops: List[int]: a list of transfer ops that can indicate - the completion of each layer or each layer for each tp rank - """ - graph = TransferOpGraph() - op_d2h = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.D2H, - src_descriptor = TransferDescriptor( - device_type = DeviceType.GPU, - physical_block_ids=gpu_blocks, - device_id = gpu_device_id - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(gpu_blocks):], - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_d2h) - if len(ssd_blocks) != 0: - op_h2disk = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2DISK, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(gpu_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.SSD, - physical_block_ids=ssd_blocks, - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_h2disk) - graph.add_dependency(op_h2disk.op_id, op_d2h.op_id) - if len(remote_blocks) != 0: - op_h2remote = TransferOp( - graph_id = graph.graph_id, - transfer_type = TransferType.H2REMOTE, - src_descriptor = TransferDescriptor( - device_type = DeviceType.CPU, - physical_block_ids=cpu_blocks[-len(remote_blocks):], - ), - dst_descriptor = TransferDescriptor( - device_type = DeviceType.REMOTE, - physical_block_ids=remote_blocks, - ), - layer_id = 0, - layer_granularity = layer_num - ) - graph.add_transfer_op(op_h2remote) - graph.add_dependency(op_h2remote.op_id, op_d2h.op_id) - return graph, [op_d2h.op_id] + return graph, op.op_id def convert_read_graph_to_layer_wise_graph( transfer_graph: TransferOpGraph, @@ -394,8 +56,8 @@ def convert_read_graph_to_layer_wise_graph( new_op = TransferOp( graph_id=new_graph.graph_id, transfer_type=op.transfer_type, - src_descriptor=op.src_descriptor, - 
dst_descriptor=op.dst_descriptor, + src_block_ids=op.src_block_ids, + dst_block_ids=op.dst_block_ids, layer_id=i * layer_granularity, layer_granularity=layer_granularity, # Inherit these fields directly diff --git a/flexkv/common/block.py b/flexkv/common/block.py index d1ab24687a..7e2cc5d8b2 100644 --- a/flexkv/common/block.py +++ b/flexkv/common/block.py @@ -5,13 +5,13 @@ import numpy as np import torch -from flexkv.common.hash_utils import HashType, gen_hashes, get_hash_size, hash_tensor +from flexkv.common.hash_utils import HashType, gen_hashes, get_hash_size, hash_array @dataclass class SequenceMeta: - token_ids: torch.Tensor + token_ids: np.ndarray tokens_per_block: int @@ -19,7 +19,7 @@ class SequenceMeta: _has_hashes: bool = False - def __init__(self, token_ids: torch.Tensor, tokens_per_block: int): + def __init__(self, token_ids: np.ndarray, tokens_per_block: int): self.token_ids = token_ids self.tokens_per_block = tokens_per_block @@ -44,13 +44,13 @@ def get_hash(self, block_id: int) -> Optional[HashType]: if self._has_hashes: return HashType(int(self.block_hashes[block_id].item())) else: - return hash_tensor(self.token_ids[:(block_id+1)*self.tokens_per_block]) + return hash_array(self.token_ids[:(block_id+1)*self.tokens_per_block]) def gen_hashes(self) -> None: if self._has_hashes: return assert self.token_ids.ndim == 1 - self.block_hashes = gen_hashes(self.token_ids, self.tokens_per_block).numpy() + self.block_hashes = gen_hashes(self.token_ids, self.tokens_per_block) assert self.block_hashes.ndim == 1 assert self.block_hashes.size == self.num_blocks assert self.block_hashes.itemsize == get_hash_size() diff --git a/flexkv/common/config.py b/flexkv/common/config.py index a022292eb8..2805ae606c 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -14,6 +14,7 @@ class ModelConfig: head_size: int use_mla: bool = False dtype: torch.dtype = torch.bfloat16 + max_req_tokens = 163840 # parallel configs tp_size: int = 1 @@ -31,13 +32,13 @@ class CacheConfig: enable_ssd: bool = False enable_remote: bool = False use_gds: bool = False - use_pinned_memory: bool = False + index_accel: bool = False # kv cache layout configs gpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - cpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - ssd_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE - remote_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.LAYERWISE + cpu_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE + ssd_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE + remote_kv_layout_type: KVCacheLayoutType = KVCacheLayoutType.BLOCKWISE # mempool capacity configs num_cpu_blocks: int = 1000000 @@ -70,3 +71,17 @@ class CacheConfig: trace_max_file_size_mb: int = 100 trace_max_files: int = 5 trace_flush_interval_ms: int = 1000 + + #evict ratio + evict_ratio: float = 0.0 + + def __post_init__(self): + layout_fields = ['gpu_kv_layout_type', + 'cpu_kv_layout_type', + 'ssd_kv_layout_type', + 'remote_kv_layout_type'] + for field in layout_fields: + value = getattr(self, field) + if isinstance(value, str): + setattr(self, field, KVCacheLayoutType[value.upper()]) + diff --git a/flexkv/common/debug.py b/flexkv/common/debug.py index 6c52931db6..a522c5549a 100644 --- a/flexkv/common/debug.py +++ b/flexkv/common/debug.py @@ -6,18 +6,28 @@ from typing import Optional, Callable, Any +FLEXKV_LOGGING_PREFIX = os.getenv("FLEXKV_LOGGING_PREFIX", "FLEXKV") +_FORMAT = (f"[{FLEXKV_LOGGING_PREFIX}] %(levelname)s 
%(asctime)s.%(msecs)03d " + " %(message)s") +_DATE_FORMAT = "%m-%d %H:%M:%S" + class FlexkvLogger: def __init__(self, debug_level: str = "INFO"): self.enabled = False self.logger = logging.getLogger("FLEXKV") - formatter = logging.Formatter( - "%(asctime)s - %(levelname)s - %(message)s" + has_console_handler = any( + isinstance(handler, logging.StreamHandler) + for handler in self.logger.handlers ) - - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setFormatter(formatter) - self.logger.addHandler(console_handler) + if not has_console_handler: + formatter = logging.Formatter( + fmt=_FORMAT, + datefmt=_DATE_FORMAT, + ) + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + self.logger.addHandler(console_handler) self.set_level(debug_level) diff --git a/flexkv/common/hash_utils.py b/flexkv/common/hash_utils.py index 6f8ec9fc96..32f0055f48 100644 --- a/flexkv/common/hash_utils.py +++ b/flexkv/common/hash_utils.py @@ -1,6 +1,7 @@ import time from typing import NewType, Optional +import numpy as np import torch from flexkv import c_ext @@ -18,33 +19,37 @@ def __init__(self) -> None: def reset(self) -> None: self.hasher.reset() - def update(self, tensor: torch.Tensor) -> None: - self.hasher.update(tensor) + def update(self, array: np.ndarray) -> None: + self.hasher.update(torch.from_numpy(array)) def digest(self) -> HashType: return HashType(self.hasher.digest()) -def hash_tensor(tensor: torch.Tensor) -> HashType: - hasher = Hasher() - hasher.update(tensor) - return HashType(hasher.digest()) +_HASHER = Hasher() -def gen_hashes(token_ids: torch.Tensor, tokens_per_block: int, hasher: Optional[Hasher] = None) -> torch.Tensor: - block_hashes = torch.zeros(token_ids.numel() // tokens_per_block, dtype=torch.uint64) +def hash_array(array: np.ndarray) -> HashType: + _HASHER.reset() + _HASHER.update(array) + return HashType(_HASHER.digest()) + +def gen_hashes(token_ids: np.ndarray, tokens_per_block: int, hasher: Optional[Hasher] = None) -> np.ndarray: + block_hashes = np.zeros(token_ids.size // tokens_per_block, dtype=np.uint64) if hasher is None: hasher = Hasher() - c_ext.gen_hashes(hasher.hasher, token_ids, tokens_per_block, block_hashes) + c_ext.gen_hashes(hasher.hasher, torch.from_numpy(token_ids), tokens_per_block, torch.from_numpy(block_hashes)) return block_hashes if __name__ == "__main__": - torch.manual_seed(0) - token_ids = torch.randint(0, 10000, (32000, ), dtype=torch.int64) + np.random.seed(0) + token_ids = np.random.randint(0, 10000, (1000, ), dtype=np.int64) print(f"token ids length: {token_ids.shape[0]}") + result = hash_array(token_ids) start = time.time() - result = hash_tensor(token_ids) - end = time.time() - print(f"tensor hash: {result}, time: {end - start}s") - start = time.time() - result2 = gen_hashes(token_ids, 16) + for i in range(1): + result = hash_array(token_ids) end = time.time() - print(f"block hashes: {result2}, time: {end - start}s") + print(f"array hash: {result}, average time: {(end - start)*1000/5}ms") + # start = time.time() + # result2 = gen_hashes(token_ids, 16) + # end = time.time() + # print(f"block hashes: {result2}, time: {(end - start)*1000}ms") diff --git a/flexkv/common/memory_handle.py b/flexkv/common/memory_handle.py index d1e2838a87..5013308b40 100644 --- a/flexkv/common/memory_handle.py +++ b/flexkv/common/memory_handle.py @@ -3,7 +3,7 @@ import os import pickle import time -from typing import List, Callable, Any, Optional, Tuple, Union +from typing import Callable, Any, Optional, Tuple, 
Union from dataclasses import dataclass import torch @@ -16,7 +16,7 @@ @dataclass class TensorSharedHandle: rebuild_func: Callable - rebuild_args: List[Any] + rebuild_args: Tuple[Any] device: torch.device def __init__(self, tensor: torch.Tensor): @@ -29,7 +29,7 @@ def get_tensor(self) -> torch.Tensor: return tensor @staticmethod - def _export_tensor_handle(tensor: torch.Tensor) -> Tuple[Callable, List[Any], torch.device]: + def _export_tensor_handle(tensor: torch.Tensor) -> Tuple[Callable, Tuple[Any], torch.device]: device = tensor.device rebuild_func, rebuild_args = reductions.reduce_tensor(tensor) @@ -37,7 +37,7 @@ def _export_tensor_handle(tensor: torch.Tensor) -> Tuple[Callable, List[Any], to return rebuild_func, rebuild_args, device @staticmethod - def _import_tensor_handle(rebuild_func: Callable, rebuild_args: List[Any], device: torch.device) -> torch.Tensor: + def _import_tensor_handle(rebuild_func: Callable, rebuild_args: Tuple[Any], device: torch.device) -> torch.Tensor: try: tensor = rebuild_func(*rebuild_args) diff --git a/flexkv/common/request.py b/flexkv/common/request.py index 1c871ea821..ef1c6ca009 100644 --- a/flexkv/common/request.py +++ b/flexkv/common/request.py @@ -1,7 +1,9 @@ from dataclasses import dataclass from enum import Enum +from typing import Callable, List, Optional import torch +import numpy as np class KVRequestType(Enum): @@ -18,3 +20,21 @@ class KVRequest: slot_mapping: torch.Tensor layer_granularity: int = -1 dp_id: int = 0 + + +class KVResponseStatus(Enum): + SUCCESS = "success" + NOTFOUND = "not_found" + UNREADY = "unready" + TIMEOUT = "timeout" + CANCELLED = "cancelled" + FAILED = "failed" + +@dataclass +class KVResponse: + status: KVResponseStatus + task_id: int + return_mask: Optional[np.ndarray] + + def get_mask(self) -> torch.Tensor: + return torch.from_numpy(self.return_mask) if self.return_mask is not None else None diff --git a/flexkv/common/ring_buffer.py b/flexkv/common/ring_buffer.py new file mode 100644 index 0000000000..3f67d612ba --- /dev/null +++ b/flexkv/common/ring_buffer.py @@ -0,0 +1,109 @@ +import torch +import threading +import time +import random + +from collections import OrderedDict,deque +import numpy as np +from flexkv.common.transfer import TransferOp +from flexkv.common.debug import flexkv_logger +from flexkv.common.hash_utils import hash_array + + +class SharedOpPool: + def __init__(self, max_op_num: int, max_block_num: int, dtype = np.int64): + self.max_op_num = max_op_num + self.max_block_num = max_block_num + self.dtype = dtype + # create the buffer tensor + self.buffer_o = torch.empty((self.max_op_num, self.max_block_num), dtype = torch.int64) + # move tensor to share memory + self.buffer = self.buffer_o.share_memory_() + + flexkv_logger.info(f"[SharedOpPool] block ids buffer data_ptr: {self.buffer.storage().data_ptr()}") + + self.free_slots = deque(range(max_op_num)) + self.slot_map = dict() # {slot_hash: slot_id} + + self.slot_ref_count = np.zeros(max_op_num, dtype=np.int32) + self.slot_hashes = [0]*max_op_num + + self.lock = threading.Lock() + + def allocate_slot(self, block_ids: np.ndarray): + """ + Allocating a slot for the given block ids + Params: + block_ids: the block ids of src address or dst address + Returns: + slot_id: the slot which is assigned to the given block ids, -1 if failed + """ + # firstly, determine whether the length of block ids exceeds the limit + num_blocks = block_ids.size + if num_blocks > self.max_block_num or num_blocks == 0: + return -1 + + slot_hash = hash_array(block_ids) + reuse = 
False + + # get the slot of empty buffer + with self.lock: + if slot_hash in self.slot_map: + slot_id = self.slot_map[slot_hash] + reuse = True + else: + if not self.free_slots: + flexkv_logger.info("No empty slot in SharedOpPool") + return -1 + + slot_id = self.free_slots.popleft() + self.slot_map[slot_hash] = slot_id + + # update status managers + self.slot_ref_count[slot_id] += 1 + self.slot_hashes[slot_id] = slot_hash + + # do copy + if not reuse: + self.buffer[slot_id, :num_blocks] = torch.from_numpy(block_ids).to(torch.int64) + + return slot_id + + def free_slot(self, slot_id: int): + """ + Free the relevant resources of corresponding op, called when op transfer completed. + Input: + op_id: the index of current op + Output: + None + """ + with self.lock: + slot_hash = self.slot_hashes[slot_id] + if slot_hash not in self.slot_map: + raise RuntimeError(f"Slot {slot_id} is not in use, double free detected!") + self.slot_ref_count[slot_id] -= 1 + assert self.slot_ref_count[slot_id] >= 0, f"Slot {slot_id} ref count is negative" + if self.slot_ref_count[slot_id] == 0: + self.free_slots.append(slot_id) + del self.slot_map[slot_hash] + + def get_buffer(self): + return self.buffer + + def get_buffer_size(self): + return self.max_op_num, self.max_block_num + + def status(self): + """ + Current status logger + """ + with self.lock: + used = len(self.slot_map) + free = self.max_op_num - used + return {"used_slots": used, + "free_slots": free, + "capacity": self.max_op_num} + + +if __name__ == "__main__": + manager = SharedOpPool(4, 10) diff --git a/flexkv/common/tracer.py b/flexkv/common/tracer.py index 45178e85b4..dff6b1ff3a 100644 --- a/flexkv/common/tracer.py +++ b/flexkv/common/tracer.py @@ -121,7 +121,6 @@ def trace_config(self, model_config, cache_config, gpu_layout=None): "ssd_kv_layout_type": str(cache_config.ssd_kv_layout_type), "remote_kv_layout_type": str(cache_config.remote_kv_layout_type), "use_gds": cache_config.use_gds, - "use_pinned_memory": cache_config.use_pinned_memory, "remote_cache_size_mode": cache_config.remote_cache_size_mode, "num_cpu_blocks": cache_config.num_cpu_blocks, "num_ssd_blocks": cache_config.num_ssd_blocks, @@ -134,6 +133,7 @@ def trace_config(self, model_config, cache_config, gpu_layout=None): "ssd_cache_iouring_flags": cache_config.ssd_cache_iouring_flags, "remote_cache_path": cache_config.remote_cache_path, "remote_config_custom": cache_config.remote_config_custom, + "evict_ratio": cache_config.evict_ratio, } # Convert gpu_layout to dict if provided diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 342f0678c0..91229b7834 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -3,7 +3,7 @@ from enum import Enum from typing import ClassVar, List, Set, Dict -import torch +import numpy as np class DeviceType(Enum): @@ -31,16 +31,6 @@ class PartitionBlockType(Enum): ROUND_ROBIN = 0 SEQUENTIAL = 1 -@dataclass -class TransferDescriptor: - device_type: DeviceType = DeviceType.CPU - device_id: int = 0 - physical_block_ids: torch.Tensor = torch.tensor([], dtype=torch.int64) - - def __post_init__(self) -> None: - assert self.physical_block_ids.ndim == 1 - assert self.physical_block_ids.dtype == torch.int64 - class TransferOpStatus(Enum): PENDING = 0 RUNNING = 1 @@ -54,24 +44,31 @@ class TransferOp: op_id: int = field(init=False) graph_id: int transfer_type: TransferType - layer_id: int - layer_granularity: int - src_descriptor: TransferDescriptor = field(default_factory=TransferDescriptor) - dst_descriptor: TransferDescriptor = 
field(default_factory=TransferDescriptor) + src_block_ids: np.ndarray + dst_block_ids: np.ndarray + layer_id: int = 0 + layer_granularity: int = -1 # this will change dynamically as transfer ops executed predecessors: Set[int] = field(default_factory=set) # this will keep the full info successors: Set[int] = field(default_factory=set) status: TransferOpStatus = TransferOpStatus.PENDING dp_id: int = 0 + # used for get block ids inner worker process + src_slot_id: int = -1 + dst_slot_id: int = -1 + valid_block_num: int = 0 def __post_init__(self) -> None: if self.transfer_type != TransferType.VIRTUAL and \ - len(self.src_descriptor.physical_block_ids) != len(self.dst_descriptor.physical_block_ids): - raise ValueError("src_descriptor and dst_descriptor must have the same number of physical blocks") + self.src_block_ids.size != self.dst_block_ids.size: + raise ValueError("src_block_ids and dst_block_ids must have the same number of physical blocks") with TransferOp._lock: self.op_id = TransferOp._next_op_id TransferOp._next_op_id += 1 + assert self.src_block_ids.dtype == np.int64 + assert self.dst_block_ids.dtype == np.int64 + self.valid_block_num = self.src_block_ids.size class TransferOpGraph: @@ -79,13 +76,14 @@ class TransferOpGraph: _lock = threading.Lock() def __init__(self) -> None: - self.graph_id = self._get_next_graph_id() + self.graph_id = self._get_graph_id() self._op_map: Dict[int, TransferOp] = {} self._ready_ops: Set[int] = set() self._trigger_ops: Set[int] = set() + self._gpu_transfer_op_id: int = -1 @classmethod - def _get_next_graph_id(cls) -> int: + def _get_graph_id(cls) -> int: with cls._lock: graph_id = cls._next_graph_id cls._next_graph_id += 1 @@ -115,6 +113,14 @@ def trigger_op(self, op_id: int) -> None: def add_transfer_op(self, op: TransferOp) -> None: op.graph_id = self.graph_id self._op_map[op.op_id] = op + if op.transfer_type == TransferType.H2D or \ + op.transfer_type == TransferType.D2H or \ + op.transfer_type == TransferType.D2DISK or \ + op.transfer_type == TransferType.DISK2D: + if self._gpu_transfer_op_id == -1: + self._gpu_transfer_op_id = op.op_id + else: + raise ValueError("Only one GPU transfer op is allowed") self._ready_ops.add(op.op_id) def add_dependency(self, successor_op_id: int, predecessor_op_id: int) -> None: @@ -130,8 +136,6 @@ def mark_completed(self, op_id: int) -> None: assert self._op_map[op_id].status == TransferOpStatus.RUNNING self._op_map[op_id].status = TransferOpStatus.COMPLETED my_successors = self._op_map[op_id].successors - if len(my_successors) == 0: - return for successor_id in my_successors: self._op_map[successor_id].predecessors.remove(op_id) @@ -164,6 +168,16 @@ def all_transfer_ops_completed(self) -> bool: return all(op.status == TransferOpStatus.COMPLETED for op in self._op_map.values()) + def set_gpu_blocks(self, gpu_blocks: np.ndarray) -> None: + transfer_type = self._op_map[self._gpu_transfer_op_id].transfer_type + op = self._op_map[self._gpu_transfer_op_id] + if transfer_type.name.endswith("2D"): + op.dst_block_ids = gpu_blocks + else: + op.src_block_ids = gpu_blocks + assert op.src_block_ids.size == op.dst_block_ids.size, \ + f"src_block_ids.size={op.src_block_ids.size}, dst_block_ids.size={op.dst_block_ids.size}" + @property def num_ops(self) -> int: return len(self._op_map) @@ -172,69 +186,6 @@ def bind_to_dp_group(self, dp_id: int) -> None: for op in self._op_map.values(): op.dp_id = dp_id - def print_op_map(self) -> None: - """Print transfer op graph in a visual format showing dependencies. 
- - Example output: - Transfer Graph 5: - ├── Op 1 (H2D) [Completed] - │ └── No successors - ├── Op 2 (D2H) [Pending] - │ └── Followed by: 1 - └── Op 3 (DISK2H) [Pending] - └── Followed by: 1, 2 - """ - print(f"Transfer Graph {self.graph_id}:") - - # get all op ids and sort them - op_ids = sorted(self._op_map.keys()) - - for i, op_id in enumerate(op_ids): - op = self._op_map[op_id] - is_last = (i == len(op_ids) - 1) - - # draw the tree structure branch - prefix = "└── " if is_last else "├── " - - # get the op status - status = "[Completed]" if op.status == TransferOpStatus.COMPLETED else "[Pending]" - - # print the op info - print(f"{prefix}Op {op_id} ({op.transfer_type.name}) {status}") - - if op.transfer_type == TransferType.VIRTUAL: - continue - # print the dependency info - dep_prefix = " " if is_last else "│ " - if not op.successors: - print(f"{dep_prefix}└── No successors") - else: - deps_str = ", ".join(str(dep) for dep in sorted(op.successors)) - print(f"{dep_prefix}└── Followed by: {deps_str}") - - # print the transfer details - src_info = f"From: {op.src_descriptor.device_type.name}:{op.src_descriptor.device_id}" - dst_info = f"To: {op.dst_descriptor.device_type.name}:{op.dst_descriptor.device_id}" - print(f"{dep_prefix} └── {src_info} -> {dst_info}") - - print(f"{dep_prefix} └── layers: {op.layer_id} - {op.layer_id + op.layer_granularity}") - - # if there are physical block ids, also print them - if len(op.src_descriptor.physical_block_ids) > 0: - blocks = op.src_descriptor.physical_block_ids.tolist() - if len(blocks) > 3: - blocks_str = f"{blocks[:3]}... ({len(blocks)} blocks)" - else: - blocks_str = str(blocks) - print(f"{dep_prefix} └── Src Blocks: {blocks_str}") - if len(op.dst_descriptor.physical_block_ids) > 0: - blocks = op.dst_descriptor.physical_block_ids.tolist() - if len(blocks) > 3: - blocks_str = f"{blocks[:3]}... ({len(blocks)} blocks)" - else: - blocks_str = str(blocks) - print(f"{dep_prefix} └── Dst Blocks: {blocks_str}") - def get_nvtx_default_color() -> int: return 0xD3D3D3 diff --git a/flexkv/integration/__init__.py b/flexkv/integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flexkv/integration/config.py b/flexkv/integration/config.py new file mode 100644 index 0000000000..76f27f5b34 --- /dev/null +++ b/flexkv/integration/config.py @@ -0,0 +1,68 @@ + +import json +import os +import torch +import tempfile +from typing import TYPE_CHECKING +from dataclasses import dataclass, field + +from flexkv.common.debug import flexkv_logger + +if TYPE_CHECKING: + from vllm.v1.kv_cache_interface import KVCacheConfig, FullAttentionSpec + from vllm.config import VllmConfig + + +logger = flexkv_logger + +@dataclass +class FlexKVConfig: + #base config + server_recv_port: str + + # cache config + cache_config: dict = field(default_factory=dict) + + # model config + block_size: int = None + num_layers: int = None + num_kv_heads: int = None + head_size: int = None + dtype: torch.dtype = None + use_mla: bool = False + tp_size: int = 1 + + # log config + num_log_interval_requests: int = 200 + + @classmethod + def from_env(cls) -> 'FlexKVConfig': + config_file_path = os.getenv('FLEXKV_CONFIG_PATH', None) + logger.info(f"{config_file_path=}") + if config_file_path is None: + return cls(enable_flexkv=False, + server_recv_port="") + + assert config_file_path.endswith(".json"), "flexkv config must be a json file." 
+ + with open(config_file_path, 'r') as f: + config_dict: dict = json.load(f) + logger.info(f"FlexKV Config Dict: {config_dict}") + + return cls( + server_recv_port=config_dict.get("server_recv_port", f"ipc:///tmp/flexkv_test"), + cache_config=config_dict.get("cache_config", {}), + num_log_interval_requests=config_dict.get("num_log_interval_requests", 200), + ) + + def post_init_from_vllm_config( + self, + vllm_config: "VllmConfig", + ): + self.num_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config) + self.block_size = vllm_config.cache_config.block_size + self.num_kv_heads = vllm_config.model_config.get_total_num_kv_heads() + self.head_size = vllm_config.model_config.get_head_size() + self.dtype = vllm_config.model_config.dtype + self.use_mla = vllm_config.model_config.is_deepseek_mla + self.tp_size = vllm_config.parallel_config.tensor_parallel_size \ No newline at end of file diff --git a/flexkv/integration/stats.py b/flexkv/integration/stats.py new file mode 100644 index 0000000000..3f4d1a70f9 --- /dev/null +++ b/flexkv/integration/stats.py @@ -0,0 +1,100 @@ +import time +from dataclasses import dataclass +from collections import deque + +from flexkv.common.debug import flexkv_logger + +logger = flexkv_logger + + +@dataclass +class FlexKVStats: + num_log_interval_requests: int + + # get info + num_get_requests: int = 0 + num_get_query_tokens: int = 0 + num_gpu_matched_tokens: int = 0 + num_flexkv_matched_tokens: int = 0 + + # put info + num_put_requests: int = 0 + num_put_query_tokens: int = 0 + num_put_unmatched_tokens: int = 0 + + num_failed_requests: int = 0 + + @property + def tatal_num_requests(self) -> int: + return self.num_get_requests + self.num_put_requests + + @property + def get_gpu_match_ratio(self) -> float: + if self.num_get_query_tokens == 0: + return 0.0 + return self.num_gpu_matched_tokens / self.num_get_query_tokens + + @property + def get_flexkv_match_ratio(self) -> float: + if self.num_get_query_tokens == 0: + return 0.0 + return self.num_flexkv_matched_tokens / self.num_get_query_tokens + + @property + def get_put_token_ratio(self) -> float: + if self.num_put_unmatched_tokens == 0: + return 0.0 + return self.num_flexkv_matched_tokens / self.num_put_unmatched_tokens + + def record_get( + self, + num_prompt_tokens: int, + num_gpu_matched_tokens: int, + num_flexkv_matched_tokens: int, + ): + self.num_get_requests += 1 + self.num_get_query_tokens += num_prompt_tokens + self.num_gpu_matched_tokens += num_gpu_matched_tokens + self.num_flexkv_matched_tokens += num_flexkv_matched_tokens + if self.num_get_requests == self.num_log_interval_requests: + self.log() + self.clear() + + def record_put( + self, + num_all_tokens: int, + num_unmatched_tokens: int, + ): + self.num_put_requests += 1 + self.num_put_query_tokens += num_all_tokens + self.num_put_unmatched_tokens += num_unmatched_tokens + + def record_faild( + self, + num_failed_requests: int + ): + self.num_failed_requests += num_failed_requests + + def clear(self): + self.num_get_requests = 0 + self.num_get_query_tokens = 0 + self.num_gpu_matched_tokens = 0 + self.num_flexkv_matched_tokens = 0 + self.num_put_requests = 0 + self.num_put_query_tokens = 0 + self.num_put_unmatched_tokens = 0 + self.num_failed_requests = 0 + + def log(self): + if self.num_put_unmatched_tokens == 0: + get_put_token_ratio_str = "Nan" + else: + get_put_token_ratio_str = f"{self.get_put_token_ratio*100:.2f}%" + logger.info( + f"[FlexKV] Metric of Recent {self.num_log_interval_requests} Requests: " + f"Num Failed Request: 
{self.num_failed_requests}, " + f"Num Get Query Tokens: {self.num_get_query_tokens}, " + f"GPU Hit Ratio: {self.get_gpu_match_ratio*100:.2f}%, " + f"FlexKV Hit Ratio: {self.get_flexkv_match_ratio*100:.2f}%, " + f"Get/Put Token Ratio: {get_put_token_ratio_str}.") + \ No newline at end of file diff --git a/flexkv/integration/utils.py b/flexkv/integration/utils.py new file mode 100644 index 0000000000..9b107b7c23 --- /dev/null +++ b/flexkv/integration/utils.py @@ -0,0 +1,5 @@ + + +def cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return -(a // -b) \ No newline at end of file diff --git a/flexkv/integration/vllm/__init__.py b/flexkv/integration/vllm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flexkv/integration/vllm/vllm_v1_adapter.py b/flexkv/integration/vllm/vllm_v1_adapter.py new file mode 100644 index 0000000000..c129e44563 --- /dev/null +++ b/flexkv/integration/vllm/vllm_v1_adapter.py @@ -0,0 +1,746 @@ +import os +import time +from typing import TYPE_CHECKING, Optional, Literal, Any +from dataclasses import dataclass, field +from abc import ABC, abstractmethod + +import numpy as np +import torch + +from flexkv.kvmanager import KVManager +from flexkv.server.client import KVTPClient +from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.config import ModelConfig, CacheConfig +from flexkv.common.request import KVResponseStatus +from flexkv.common.debug import flexkv_logger +from flexkv.integration.stats import FlexKVStats +from flexkv.integration.utils import cdiv +from flexkv.integration.config import FlexKVConfig + +# vllm +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata, KVConnectorRole) + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + from vllm.v1.outputs import KVConnectorOutput + + +logger = flexkv_logger + + +@dataclass +class FlexKVResponse: + task_id: int + task_type: Literal["get", "put"] + request: "Request" + success: bool + + +@dataclass +class FlexKVTask(ABC): + task_id: int = 0 + request: "Request" = 0 + + # slot mapping + slot_mapping: Optional[np.ndarray] = None + + # timer + match_start_time: float = 0 + match_end_time: float = 0 + task_launch_time: float = 0 + task_finished_time: float = 0 + + @property + def match_cost(self) -> float: + return (self.match_end_time - self.match_start_time) + + @property + def task_execute_cost(self) -> float: + return (self.task_finished_time - self.task_launch_time) + + @property + @abstractmethod + def task_type(self) -> str: + ... 
+ + def __str__(self): + return (f"FlexKVTask(task_id={self.task_id}, " + f"request={self.request.request_id}, " + f"match_cost {self.match_cost*1000:.2f} ms, " + f"task execute cost {self.task_execute_cost*1000:.2f} ms)") + + +@dataclass(kw_only=True) +class FlexKVGetTask(FlexKVTask): + num_computed_tokens: int + num_new_matched_tokens: int + + @property + def task_type(self) -> str: + return "get" + + def __str__(self): + return (f"FlexKVGetTask(task_id={self.task_id}, " + f"request={self.request.request_id}, " + f"num_computed_tokens={self.num_computed_tokens}, " + f"num_new_matched_tokens={self.num_new_matched_tokens}, " + f"match_cost {self.match_cost*1000:.2f} ms, " + f"task execute cost {self.task_execute_cost*1000:.2f} ms)") + + +@dataclass(kw_only=True) +class FlexKVPutTask(FlexKVTask): + num_matched_tokens: int + num_unmatched_tokens: int + + @property + def task_type(self) -> str: + return "put" + + def __str__(self): + return (f"FlexKVPutTask(task_id={self.task_id}, " + f"request={self.request.request_id}, " + f"num_matched_tokens={self.num_matched_tokens}, " + f"num_unmatched_tokens={self.num_unmatched_tokens}, " + f"match_cost {self.match_cost*1000:.2f} ms, " + f"task execute cost {self.task_execute_cost*1000:.2f} ms)") + + +class FlexKVSchedulerConnector: + def __init__( + self, + flexkv_config: FlexKVConfig + ): + logger.info(f"Start init FlexKVSchedulerConnector with {flexkv_config}") + self.flexkv_config = flexkv_config + self.server_recv_port = flexkv_config.server_recv_port + self.tp_size = flexkv_config.tp_size + self.block_size = flexkv_config.block_size + self.model_config = ModelConfig( + num_layers=flexkv_config.num_layers, + num_kv_heads=flexkv_config.num_kv_heads, + head_size=flexkv_config.head_size, + use_mla=flexkv_config.use_mla, + dtype=flexkv_config.dtype, + tp_size=flexkv_config.tp_size, + ) + if "tokens_per_block" in flexkv_config.cache_config: + assert flexkv_config.cache_config.pop("tokens_per_block") == flexkv_config.block_size + self.cache_config = CacheConfig( + tokens_per_block=flexkv_config.block_size, + **flexkv_config.cache_config, + ) + self.flexkv_manager = KVManager(model_config=self.model_config, + cache_config=self.cache_config, + gpu_register_port=flexkv_config.server_recv_port) + self.flexkv_manager.start() + # self.dp_client = KVDPClient(self.server_recv_port, self.model_config) + + # request_id -> task_id + self.req_id_to_task_dict: dict[str, int] = {} + # launched but unfinished tasks + self.get_tasks: dict[int, FlexKVGetTask] = {} + self.put_tasks: dict[int, FlexKVPutTask] = {} + # unlaunched tasks + self.tasks_to_launch: dict[int, FlexKVTask] = {} + self.tasks_to_cancel: dict[int, FlexKVTask] = {} + + self.flexkv_stats = FlexKVStats(flexkv_config.num_log_interval_requests) + + while not self.is_ready(): + logger.info("Waiting for flexkv init...") + time.sleep(5) + + logger.info("Finish init FlexKVSchedulerConnector") + + def is_ready( + self, + ) -> bool: + " Ask flexkv is ready " + return self.flexkv_manager.is_ready() + + def shutdown(self) -> None: + self.flexkv_manager.shutdown() + + @property + def dp_client_id(self) -> int: + return self.flexkv_manager.dp_client_id + + #################### + #### Get Method #### + #################### + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Args: + request: Request to get. + num_computed_tokens: Number of prefix tokens have already been computed, + which means not need to transfer from flexkv. 
+ + Returns: + tuple[int, bool]: A tuple containing two integer values representing the + number of new matched tokens and whether it is necessary + to get the new matched blocks from flexkv. + """ + task_id, num_new_matched_tokens = self._get_match(request=request, + num_computed_tokens=num_computed_tokens) + self.flexkv_stats.record_get(num_prompt_tokens=request.num_tokens, + num_gpu_matched_tokens=num_computed_tokens, + num_flexkv_matched_tokens=num_new_matched_tokens) + + if not self._need_to_get(num_prompt_tokens=request.num_tokens, + num_computed_tokens=num_computed_tokens, + num_new_matched_tokens=num_new_matched_tokens): + return 0, False + + return num_new_matched_tokens, True + + + def _get_match( + self, + request: "Request", + num_computed_tokens: int = 0, + ) -> tuple[int, int]: + """ + Args: + request: Request to get. + num_computed_tokens: Number of prefix tokens have already been computed, + which means not need to transfer from flexkv. + + Returns: + tuple[int, int]: A tuple containing two integer values representing + the task_id and number of new matched tokens. + """ + match_start_time = time.perf_counter() + num_tokens_to_get = (request.num_tokens//self.block_size)*self.block_size + token_ids = request.all_token_ids[:num_tokens_to_get] + + assert num_computed_tokens <= num_tokens_to_get, ( + f"{num_computed_tokens=} must less equal to {num_tokens_to_get=}") + assert num_computed_tokens % self.block_size == 0 + + if num_tokens_to_get == num_computed_tokens: + return -1, 0 + + np_token_ids = np.array(token_ids) + np_token_mask = np.ones_like(np_token_ids, dtype=bool) + np_token_mask[:num_computed_tokens] = False + task_id, matched_mask = self.flexkv_manager.get_match(token_ids=np_token_ids, + token_mask=np_token_mask) + num_new_matched_tokens = matched_mask.sum().item() + + # Auto cancel if not call update_state_after_alloc() + match_end_time = time.perf_counter() + logger.debug(f"Get match cost {(match_end_time-match_start_time)*1000:.2f} ms.") + if num_new_matched_tokens > 0: + self.req_id_to_task_dict[request.request_id] = task_id + self.tasks_to_cancel[task_id] = FlexKVGetTask(task_id=task_id, + request=request, + num_computed_tokens=num_computed_tokens, + num_new_matched_tokens=num_new_matched_tokens, + match_start_time=match_start_time, + match_end_time=match_end_time) + + logger.debug(f"FlexKV create get task: {self.tasks_to_cancel[task_id]}") + + return task_id, num_new_matched_tokens + + def _need_to_get( + self, + num_prompt_tokens: int, + num_computed_tokens: int, + num_new_matched_tokens: int, + ) -> bool: + """ + Determine whether it is necessary to get the new matched blocks from flexkv. + """ + return num_new_matched_tokens > 0 + + def update_state_after_alloc( + self, + request: "Request", + blocks: "KVCacheBlocks", + num_new_matched_tokens: int, + ) -> None: + """ + Compute slot mapping and prepare to launch task. + Only call after get_num_new_matched_tokens(). + + Args: + request: Request to get. + blocks: All blocks of the request. + num_new_matched_tokens: Number of new matched tokens returned by + get_num_new_matched_tokens(). + + Returns: + None. 
+ """ + if num_new_matched_tokens == 0: + return + # prepare to launch task + task_id = self.req_id_to_task_dict[request.request_id] + task: FlexKVGetTask = self.tasks_to_cancel.pop(task_id) + self.tasks_to_launch[task_id] = task + + # compute slot_mapping + num_computed_blocks = task.num_computed_tokens // self.block_size + num_blocks_to_get = num_new_matched_tokens // self.block_size + all_block_ids = blocks.get_block_ids()[0] + block_ids_to_get = all_block_ids[num_computed_blocks:num_computed_blocks+num_blocks_to_get] + task.slot_mapping = np.array(block_ids_to_get).repeat(self.block_size)*self.block_size + + def wait_for_all_get_tasks(self) -> list[FlexKVResponse]: + """ + Blocking wait for all get tasks. + + Returns: + list[FlexKVResponse]: Responses of all get tasks. + """ + return self._blocking_waiting_for_tasks(self.get_tasks) + + #################### + #### Put Method #### + #################### + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> bool: + """ + Args: + request: Request to put. + blocks: All block_ids of the request. + + Returns: + bool: whether thire is unfinished task for this request. + """ + # Task not finished, can't free blocks + if request.request_id in self.req_id_to_task_dict: + return True + + # Abnormal finished, don't put + if not (request.is_finished() and request.get_finished_reason() < 2): + return False + + task_id, num_matched_tokens, num_unmatched_tokens = self._put_match(request=request) + + self.flexkv_stats.record_put(num_all_tokens=request.num_tokens, + num_unmatched_tokens=num_unmatched_tokens) + + if not self._need_to_put(num_all_tokens=request.num_tokens, + num_matched_tokens=num_matched_tokens, + num_unmatched_tokens=num_unmatched_tokens): + return False + + # prepare to launch task + task: FlexKVPutTask = self.tasks_to_cancel.pop(task_id) + self.tasks_to_launch[task_id] = task + + # compute slot mapping + # num_blocks_to_put = (num_matched_tokens+num_unmatched_tokens) // self.block_size + num_matched_blocks = num_matched_tokens // self.block_size + num_unmatched_tokens = num_unmatched_tokens // self.block_size + block_ids_to_put = block_ids[num_matched_blocks:num_matched_blocks+num_unmatched_tokens] + task.slot_mapping = np.array(block_ids_to_put).repeat(self.block_size)*self.block_size + + return True + + def _put_match( + self, + request: "Request" + ) -> tuple[int, int, int]: + """ + Args: + request: Request to put. + + Returns: + tuple[int, int, int]: A tuple containing three integer values representing + the task_id, number of matched tokens and number of unmatched tokens. + """ + match_start_time = time.perf_counter() + num_tokens_to_put = (cdiv(request.num_tokens+1, self.block_size)-1)*self.block_size + token_ids = request.all_token_ids[:num_tokens_to_put] + + if num_tokens_to_put == 0: + return -1, 0, 0 + + np_token_ids = np.array(token_ids) + task_id, unmatched_mask = self.flexkv_manager.put_match(token_ids=np_token_ids) + + num_unmatched_tokens = unmatched_mask.sum().item() + num_matched_tokens = num_tokens_to_put - num_unmatched_tokens + + # Auto cancel if not need to put. 
+ match_end_time = time.perf_counter() + logger.debug(f"Put match cost {(match_end_time-match_start_time)*1000:.2f} ms.") + + if num_unmatched_tokens > 0: + self.req_id_to_task_dict[request.request_id] = task_id + self.tasks_to_cancel[task_id] = FlexKVPutTask(task_id=task_id, + request=request, + num_matched_tokens=num_matched_tokens, + num_unmatched_tokens=num_unmatched_tokens, + match_start_time=match_start_time, + match_end_time=match_end_time) + logger.debug(f"FlexKV create put task: {self.tasks_to_cancel[task_id]}") + + return task_id, num_matched_tokens, num_unmatched_tokens + + def _need_to_put( + self, + num_all_tokens: int, + num_matched_tokens: int, + num_unmatched_tokens: int, + ) -> bool: + """ + Determine whether it is necessary to put the unmatched blocks from flexkv. + """ + return num_unmatched_tokens > 0 + + def wait_for_all_put_tasks(self) -> list[FlexKVResponse]: + """ + Blocking wait for all put tasks. + + Returns: + list[FlexKVResponse]: Responses of all put tasks. + """ + return self._blocking_waiting_for_tasks(self.put_tasks) + + ####################### + #### Common Method #### + ####################### + + def cancel_tasks(self) -> None: + """ + Cancel tasks in self.cancel_tasks. + Call before launch_tasks() to delete req_id in self.req_id_to_task_dict + """ + # TODO: check if this method is inproc. + if len(self.tasks_to_cancel) == 0: + return + for task in self.tasks_to_cancel.values(): + del self.req_id_to_task_dict[task.request.request_id] + logger.info(f"FlexKV Cancel task: {task}") + self.flexkv_manager.cancel(task_ids=list(self.tasks_to_cancel.keys())) + self.tasks_to_cancel.clear() + + def launch_tasks(self) -> None: + """ + Launch tasks in self.unlaunched_tasks + """ + if len(self.tasks_to_launch) == 0: + return + task_launch_time = time.perf_counter() + task_ids: list[int] = [] + slot_mappings: list[np.ndarray] = [] + + for task_id, task in self.tasks_to_launch.items(): + logger.info(f"FlexKV Launch task: {task}") + task.task_launch_time = task_launch_time + task_ids.append(task_id) + slot_mappings.append(task.slot_mapping) + if isinstance(task, FlexKVGetTask): + self.get_tasks[task_id] = task + else: + self.put_tasks[task_id] = task + self.flexkv_manager.launch(task_ids=task_ids, + slot_mappings=slot_mappings) + self.tasks_to_launch.clear() + + def query_finished_task(self) -> tuple[set[str], set[str]]: + """ + Get response of finished task. + + Returns: + list[FlexKVResponse]: Responses of finished tasks. 
+ """ + if len(self.req_id_to_task_dict) == 0: + return set(), set() + logger.debug(f"unfinished task: {self.req_id_to_task_dict}") + task_ids = list(self.get_tasks.keys()) + list(self.put_tasks.keys()) + responses_from_manager = self.flexkv_manager.try_wait(task_ids) + task_finished_time = time.perf_counter() + # responses_to_return: list[FlexKVResponse] = [] + finished_sending = set() + finished_recving = set() + num_failed_tasks = 0 + for task_id, response in responses_from_manager.items(): + success = (response.status == KVResponseStatus.SUCCESS) + if task_id in self.get_tasks: + task = self.get_tasks.pop(task_id) + finished_recving.add(task.request.request_id) + else: + task = self.put_tasks.pop(task_id) + finished_sending.add(task.request.request_id) + del self.req_id_to_task_dict[task.request.request_id] + task.task_finished_time = task_finished_time + if success: + logger.info(f"{task} finished successfully.") + else: + logger.error(f"{task} failed, status: {response.status}.") + num_failed_tasks += 1 + # responses_to_return.append(FlexKVResponse(task_id=task_id, task_type=task.task_type, + # request=task.request, success=success)) + self.flexkv_stats.record_faild(num_failed_requests=num_failed_tasks) + return finished_sending, finished_recving + + def _blocking_waiting_for_tasks(self, task_dict: dict[int, FlexKVTask]) -> list[FlexKVResponse]: + """ + Blocking wait for tasks in task_dict. + + Returns: + list[FlexKVResponse]: Responses of all tasks in task_dict. + """ + if len(task_dict) == 0: + return [] + + task_ids = list(task_dict.keys()) + response_from_manager = self.flexkv_manager.wait(task_ids=task_ids) + task_finished_time = time.perf_counter() + responses_to_return: list[FlexKVResponse] = [] + for task_id, response in response_from_manager.items(): + success = (response.status == KVResponseStatus.SUCCESS) + task = task_dict.pop(task_id) + task.task_finished_time = task_finished_time + if success: + logger.info(f"{task} finished successfully.") + else: + logger.error(f"{task} failed, status: {response.status}.") + responses_to_return.append(FlexKVResponse(task_id=task_id, task_type=task.task_type, + request=task.request, success=success)) + return responses_to_return + + +class FlexKVWorkerConnector: + def __init__( + self, + flexkv_config: FlexKVConfig, + ): + current_device_id = torch.cuda.current_device() + self.flexkv_config = flexkv_config + logger.info(f"Start init FlexKVWorkerConnector to {flexkv_config.server_recv_port}") + self.tp_client = KVTPClient(flexkv_config.server_recv_port, 0, current_device_id) + logger.info("Finish init FlexKVWorkerConnector") + + def register_to_server(self, kv_caches: dict[str, torch.Tensor]): + logger.info("Start register kv_caches") + gpu_blocks = list(kv_caches.values()) + num_layer = len(kv_caches) + if self.flexkv_config.use_mla: + assert gpu_blocks[0].ndim == 3, ( + f"expect kv cached tensor has 3 dim but get shape={gpu_blocks[0].shape}.") + num_blocks = gpu_blocks[0].shape[0] + block_size = gpu_blocks[0].shape[1] + num_kv_heads = 1 + head_size = gpu_blocks[0].shape[2] + else: + assert gpu_blocks[0].ndim == 5, ( + f"expect kv cached tensor has 5 dim but get shape={gpu_blocks[0].shape}.") + num_blocks = gpu_blocks[0].shape[1] + block_size = gpu_blocks[0].shape[2] + num_kv_heads = gpu_blocks[0].shape[3] + head_size = gpu_blocks[0].shape[4] + gpu_layout = KVCacheLayout( + type=KVCacheLayoutType.LAYERWISE, + num_layer=num_layer, + num_block=num_blocks, + tokens_per_block=block_size, + num_head=num_kv_heads, + head_size=head_size, + 
is_mla=self.flexkv_config.use_mla, + ) + self.tp_client.register_to_server(gpu_blocks, gpu_layout) + logger.info("Finish register kv_caches") + + +class FlexKVConnectorV1Impl: + def __init__(self, vllm_config: "VllmConfig", role: "KVConnectorRole"): + self.role = role + flexkv_config = FlexKVConfig.from_env() + flexkv_config.post_init_from_vllm_config(vllm_config) + + if role == KVConnectorRole.SCHEDULER: + self.connector = FlexKVSchedulerConnector(flexkv_config) + elif role == KVConnectorRole.WORKER: + self.connector = FlexKVWorkerConnector(flexkv_config) + else: + raise ValueError(f"Unrecognized KVConnectorRole: {role}.") + + def shutdown(self): + if self.role == KVConnectorRole.SCHEDULER: + self.connector.shutdown() + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + pass + + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + pass + + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + pass + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous transfer + (requests that previously returned True from request_finished()), + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + return None, None + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the + KV Caches in the KVConnector (e.g. for NIXL). 
+ + Args: kv_caches: + dictionary of layer names, kv cache + """ + self.connector.register_to_server(kv_caches) + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + return self.connector.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + """ + self.connector.update_state_after_alloc(request, blocks, num_external_tokens) + + def build_connector_meta( + self, scheduler_output: "SchedulerOutput") -> "KVConnectorMetadata": + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + self.connector.cancel_tasks() + self.connector.launch_tasks() + return KVConnectorMetadata() + + def update_connector_output(self, connector_output: "KVConnectorOutput"): + """ + Update KVConnector state from worker-side connectors output. + + Args: + connector_output (KVConnectorOutput): the worker-side + connectors output. + """ + + finished_sending, finished_recving = self.connector.query_finished_task() + connector_output.finished_sending = finished_sending + connector_output.finished_recving = finished_recving + + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + return self.connector.request_finished(request, block_ids), None diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index fd75750530..8e85ee9959 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -12,566 +12,199 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
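The rewritten KVManager below becomes a thin facade: with dp_size == 1 it drives an in-process KVTaskEngine, otherwise it forwards every call to a shared KVServer through a KVDPClient. A rough usage sketch of the new match/launch/wait flow follows; it assumes model_config and cache_config are already built, and the port strings, shapes, and values are illustrative rather than part of this diff.

import numpy as np
from flexkv.common.request import KVResponseStatus
from flexkv.kvmanager import KVManager

# Sketch only: model_config / cache_config construction is omitted here.
manager = KVManager(model_config, cache_config,
                    gpu_register_port="ipc:///tmp/flexkv_gpu",      # illustrative
                    server_recv_port="ipc:///tmp/flexkv_server")    # illustrative
manager.start()

token_ids = np.arange(256, dtype=np.int64)      # toy token ids
slot_mapping = np.arange(256, dtype=np.int64)   # toy GPU slot mapping

# 1) match only: nothing is transferred yet, a fake slot mapping is used internally
task_id, mask = manager.get_match(token_ids)
if mask.any():
    # 2) launch once the real slot mapping is known
    manager.launch(task_id, slot_mapping)
    # 3) wait (a scheduler loop would use try_wait instead of blocking)
    response = manager.wait(task_id, timeout=20.0)[task_id]
    assert response.status == KVResponseStatus.SUCCESS
else:
    # nothing matched, so no transfer needs to be launched
    manager.cancel(task_id)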
-import multiprocessing as mp -import threading + +from typing import Optional, Tuple, List, Dict, Union, Iterable import time -from dataclasses import dataclass -from queue import Queue -from typing import Dict, Any, Optional -from typing import List, Callable, Union -import nvtx +import numpy as np import torch -from expiring_dict import ExpiringDict -from flexkv.cache.cache_engine import GlobalCacheEngine, TransferOpGraph -from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.server.client import KVDPClient +from flexkv.server.server import KVServer, DPClient +from flexkv.kvtask import KVTaskEngine, KVResponse +from flexkv.common.config import ModelConfig, CacheConfig from flexkv.common.debug import flexkv_logger -from flexkv.common.memory_handle import TensorSharedHandle -from flexkv.common.request import KVRequestType, KVRequest -from flexkv.common.transfer import DeviceType, get_nvtx_range_color, get_nvtx_default_color -from flexkv.common.storage import KVCacheLayout -from flexkv.common.exceptions import LogicError -from flexkv.common.tracer import FlexKVTracer -from flexkv.storage.storage_engine import StorageEngine -from flexkv.transfer.transfer_engine import TransferEngine - - -@dataclass -class RequestTracker: - task_id: int - task_type: KVRequestType - return_mask: torch.Tensor - callback: Optional[Callable] - task_end_ops_ids: List[int] - task_end_ops_status: List[bool] - task_finished: bool = False class KVManager: def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, - gpu_layout: Optional[KVCacheLayout] = None, - gpu_blocks: Optional[Dict[int, List[TensorSharedHandle]]] = None): - - flexkv_logger.info(f"Initializing kvmanager...\nmodel_config: {model_config}\ncache_config: {cache_config}") - - mp.set_start_method('spawn', force=True) - self.init_nvtx_range = nvtx.push_range("Initialize kvmanager", color=get_nvtx_default_color()) - - if not cache_config.enable_cpu: - raise ValueError("enable_cpu must be True") - if cache_config.enable_remote and not cache_config.enable_ssd: - raise ValueError("enable_ssd must be True if enable_remote is True") - if not cache_config.enable_cpu and not cache_config.use_gds: - raise ValueError("use_gds must be True if enable_cpu is False") - self.cache_config = cache_config + gpu_register_port: Optional[str] = None, + server_recv_port: Optional[str] = None, + dp_client_id: int = 0): + flexkv_logger.info(f"{model_config = }") + flexkv_logger.info(f"{cache_config = }") self.model_config = model_config - - self._verify_Model_Cache_config(model_config, cache_config) - self.cache_engine = GlobalCacheEngine(cache_config, model_config) - self.storage_engine = StorageEngine(self.model_config, self.cache_config) - - # Initialize tracer - self.tracer = FlexKVTracer(cache_config) - - # Record configuration in tracer - if gpu_layout is not None: - self.tracer.trace_config(model_config, cache_config, gpu_layout) - - - self.transfer_engine: Optional[TransferEngine] = None - self.gpu_layout: Optional[KVCacheLayout] = gpu_layout - - self.running = False - self.requests_tracker: ExpiringDict[int, RequestTracker] = ExpiringDict(1800) # 30 minutes - self.graph_to_request: Dict[int, int] = {} - self.taskid_to_nvtx_range: Dict[int, Any] = {} - self.graphid_to_nvtx_range: Dict[int, Any] = {} - - self._task_id_counter = 0 - self.task_queue: Queue[KVRequest] = Queue() - - if gpu_blocks is None: - gpu_blocks = {} - - self.num_gpus = len(gpu_blocks) - self.all_gpu_blocks: Dict[int, List[TensorSharedHandle]] = gpu_blocks - - self.lock = 
threading.Lock() - - if self.num_gpus == self.model_config.tp_size * self.model_config.dp_size: - self._init_transfer_engine() - - # Note that for now only after all the gpu blocks are added, we can initialize the transfer engine - def _init_transfer_engine(self) -> None: - assert self.gpu_layout is not None - assert len(self.all_gpu_blocks) == self.model_config.tp_size * self.model_config.dp_size - for device_id, gpu_blocks_wrapper in self.all_gpu_blocks.items(): - self.storage_engine.register_gpu_blocks(gpu_blocks_wrapper, - self.gpu_layout, - device_id, - dtype=self.model_config.dtype) - self.gpu_handles = [ - self.storage_engine.get_storage_handle(DeviceType.GPU, i) - for i in range(self.model_config.tp_size * self.model_config.dp_size) - ] - cpu_handle = self.storage_engine.get_storage_handle(DeviceType.CPU) if self.cache_config.enable_cpu else None - ssd_handle = self.storage_engine.get_storage_handle(DeviceType.SSD) if self.cache_config.enable_ssd else None - remote_handle = ( - self.storage_engine.get_storage_handle(DeviceType.REMOTE) - if self.cache_config.enable_remote - else None - ) - self.transfer_engine = TransferEngine(self.gpu_handles, - self.model_config, - self.cache_config, - cpu_handle, - ssd_handle, - remote_handle) - - nvtx.pop_range(self.init_nvtx_range) - - - def is_ready(self) -> bool: - return self.transfer_engine is not None - - def is_running(self) -> bool: - return self.running + self.cache_config = cache_config + self.gpu_register_port = gpu_register_port + self.server_recv_port = server_recv_port + self.server_client_mode = model_config.dp_size > 1 + self.dp_client_id = dp_client_id + flexkv_logger.info(f"server_client_mode: {self.server_client_mode}") + if self.server_client_mode: + # server should only be created once but kvmanager will init in every dp rank. + if dp_client_id == 0: + self.server_handle = KVServer.create_server(model_config, + cache_config, + gpu_register_port, + server_recv_port) + + else: + self.server_handle = None + self.dp_client = KVDPClient(self.server_recv_port, self.model_config, dp_client_id) + else: + self.server_handle = None + self.kv_task_engine = KVTaskEngine(model_config, cache_config, gpu_register_port) + + @property + def dpclient_id(self) -> int: + return self.dp_client_id def start(self) -> None: - if self.running: - flexkv_logger.warning("kvmanager is already running") - return - if not self.is_ready(): - raise ValueError("transfer engine is not ready, please add all gpu blocks first") - if self.transfer_engine is not None: - self.transfer_engine.start() - self.running = True + if not self.server_client_mode: + self.kv_task_engine.start() else: - raise ValueError("transfer engine is not initialized, please call start() after all gpu blocks are added") - - self._worker_thread = threading.Thread(target=self._worker_loop) - self._worker_thread.start() - flexkv_logger.info("KVManager fully started and running") + # send the start request to the server + self.dp_client.start_server_and_register() - # the gpu_blocks of multiple gpus can be added post initialization. - # the transfer engine will be initialized after we have all the intended gpu handles. 
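With register_single_gpu_blocks removed below, per-GPU KV blocks are now registered from each worker through KVTPClient, the same path FlexKVWorkerConnector.register_to_server takes earlier in this diff. A rough sketch with made-up layer names and tensor shapes (the leading dimension of 2 is assumed to be the K/V pair):

import torch
from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType
from flexkv.server.client import KVTPClient

num_layers, num_blocks, tokens_per_block, num_kv_heads, head_size = 32, 1024, 16, 8, 128
kv_caches = {
    f"layers.{i}.attn": torch.empty(2, num_blocks, tokens_per_block, num_kv_heads, head_size,
                                    dtype=torch.float16, device="cuda")
    for i in range(num_layers)
}

tp_client = KVTPClient("ipc:///tmp/flexkv_server",       # illustrative server_recv_port
                       dp_client_id=0,
                       device_id=torch.cuda.current_device())
gpu_layout = KVCacheLayout(type=KVCacheLayoutType.LAYERWISE,
                           num_layer=num_layers,
                           num_block=num_blocks,
                           tokens_per_block=tokens_per_block,
                           num_head=num_kv_heads,
                           head_size=head_size,
                           is_mla=False)
tp_client.register_to_server(list(kv_caches.values()), gpu_layout)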
- def register_single_gpu_blocks( - self, - gpu_handles: List[TensorSharedHandle], - gpu_layout: KVCacheLayout, - dp_client_id: int = 0, - tp_rank: int = 0, - ) -> None: - if self.transfer_engine is not None: - raise ValueError("we have already get all gpu blocks") - if self.gpu_layout is None: - self.gpu_layout = gpu_layout - self.tracer.trace_config(self.model_config, self.cache_config, self.gpu_layout) + def is_ready(self) -> bool: + if self.server_client_mode: + return self.dp_client.is_ready() else: - assert self.gpu_layout == gpu_layout - self.all_gpu_blocks[tp_rank + dp_client_id * self.model_config.tp_size] = gpu_handles - self.num_gpus += 1 - if self.num_gpus == self.model_config.tp_size * self.model_config.dp_size: - self._init_transfer_engine() - - def _worker_loop(self) -> None: - assert self.transfer_engine is not None - while self.running: - # deal with completed requests from the cache engine - if not self.task_queue.empty(): - request = self.task_queue.get() - if request.request_type == KVRequestType.SHUTDOWN: - self.shutdown() - break - elif request.request_type == KVRequestType.GET: - nvtx.push_range(f"cache_engine.get request_id: {request.request_id}", - color=get_nvtx_default_color()) - graph, return_mask, callback, task_end_ops_ids = self.cache_engine.get(request.request_id, - request.token_ids, - request.token_mask, - request.slot_mapping, - self.model_config.num_layers, - request.layer_granularity, - request.dp_id) - elif request.request_type == KVRequestType.PUT: - nvtx.push_range(f"cache_engine.put request_id: {request.request_id}", - color=get_nvtx_default_color()) - graph, return_mask, callback, task_end_ops_ids = self.cache_engine.put(request.request_id, - request.token_ids, - request.token_mask, - request.slot_mapping, - self.model_config.num_layers, - request.dp_id) - else: - raise ValueError(f"Unknown request type: {request.request_type}") - nvtx.pop_range() - if graph.num_ops == 0: #early return - flexkv_logger.info(f"no transfer: " - f"request_id = {request.request_id}, request_type = {request.request_type}") - layer_op_num = self.model_config.num_layers // request.layer_granularity \ - if request.request_type == KVRequestType.GET else 1 - self.requests_tracker[request.request_id] = RequestTracker(task_id=request.request_id, - task_type=request.request_type, - return_mask=return_mask, - callback=None, - task_end_ops_ids=[-1]*layer_op_num, - task_end_ops_status=[True]*layer_op_num, - task_finished=True) - else: - self.graph_to_request[graph.graph_id] = request.request_id - self.graphid_to_nvtx_range[graph.graph_id] = nvtx.start_range( - f"request id: {request.request_id}, " - f"graph id: {graph.graph_id}", - color=get_nvtx_range_color(graph.graph_id)) - self.requests_tracker[request.request_id] = RequestTracker(task_id=request.request_id, - task_type=request.request_type, - return_mask=return_mask, - callback=callback, - task_end_ops_ids=task_end_ops_ids, - task_end_ops_status=len(task_end_ops_ids)*[False], - task_finished=False) - self.transfer_engine.submit_transfer_graph(graph) - results = self.transfer_engine.get_completed_graphs_and_ops(timeout=0.001) - for completed_graph_id, completed_op_id in results: - request_id = self.graph_to_request[completed_graph_id] - request_tracker = self.requests_tracker[request_id] - if completed_op_id == -1: - if request_tracker.callback: - request_tracker.callback() - nvtx.end_range(self.graphid_to_nvtx_range[completed_graph_id]) - self.graphid_to_nvtx_range.pop(completed_graph_id) - 
self.graph_to_request.pop(completed_graph_id) - nvtx.end_range(self.taskid_to_nvtx_range[request_tracker.task_id]) - self.taskid_to_nvtx_range.pop(request_tracker.task_id) - request_tracker.task_finished = True - elif completed_op_id in request_tracker.task_end_ops_ids: - request_tracker.task_end_ops_status[request_tracker.task_end_ops_ids.index(completed_op_id)] = True - self.requests_tracker[request_id] = request_tracker - time.sleep(0.0001) - - def _get_task_id(self) -> int: - with self.lock: - old_value = self._task_id_counter - self._task_id_counter += 1 - return old_value - - def __del__(self) -> None: - if hasattr(self, 'tracer'): - self.tracer.flush() - if self.running: - self.shutdown() + return self.kv_task_engine.is_ready() def shutdown(self) -> None: - self.running = False - # Flush tracer before shutdown - if hasattr(self, 'tracer'): - self.tracer.flush() - flexkv_logger.info("kvmanager shutdown") - self.task_queue.put(KVRequest( - request_type=KVRequestType.SHUTDOWN, - request_id=-1, - token_ids=torch.empty(0), - token_mask=torch.empty(0), - slot_mapping=torch.empty(0), - )) - self._worker_thread.join() - if self.transfer_engine is not None: - self.transfer_engine.shutdown() + if self.server_client_mode: + self.dp_client.shutdown() + else: + self.kv_task_engine.shutdown() def get_async(self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor] = None, + token_ids: Union[torch.Tensor, np.ndarray], + slot_mapping: Union[torch.Tensor, np.ndarray], + token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, layer_granularity: int = -1, dp_id: int = 0, - task_id: int = -1) -> int: - if not self.running: - raise ValueError("kvmanager is not running, please call start() first") - if token_mask is None: - token_mask = torch.ones_like(token_ids) - if layer_granularity == -1: - layer_granularity = self.model_config.num_layers - if task_id == -1: - task_id = self._get_task_id() - # Trace the request - self.tracer.trace_request( - request_type="GET", - request_id=task_id, - token_ids=token_ids, - slot_mapping=slot_mapping, - token_mask=token_mask, - layer_granularity=layer_granularity, - dp_id=dp_id - ) - nvtx.mark(f"GET request_id: {task_id}") - self.taskid_to_nvtx_range[task_id] = nvtx.start_range(f"GET request_id: {task_id}", - color=get_nvtx_default_color()) - self.task_queue.put(KVRequest( - request_type=KVRequestType.GET, - request_id=task_id, - token_ids=token_ids, - token_mask=token_mask, - slot_mapping=slot_mapping, - layer_granularity=layer_granularity, - dp_id=dp_id, - )) - self.requests_tracker[task_id] = RequestTracker(task_id=task_id, - task_type=KVRequestType.GET, - return_mask=torch.empty(0), - callback=None, - task_end_ops_ids=[], - task_end_ops_status=[], - task_finished=False) + ) -> int: + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.numpy() + if isinstance(slot_mapping, torch.Tensor): + slot_mapping = slot_mapping.numpy() + if isinstance(token_mask, torch.Tensor): + token_mask = token_mask.numpy() + if self.server_client_mode: + task_id = self.dp_client.get_async(token_ids, + slot_mapping, + token_mask, + layer_granularity) + else: + task_id, _ = self.kv_task_engine.get_async(token_ids, + slot_mapping, + token_mask, + layer_granularity, + dp_id) return task_id + def get_match(self, + token_ids: Union[torch.Tensor, np.ndarray], + token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, + layer_granularity: int = -1, + dp_id: int = 0, + ) -> Tuple[int, np.ndarray]: + if isinstance(token_ids, 
torch.Tensor): + token_ids = token_ids.numpy() + if isinstance(token_mask, torch.Tensor): + token_mask = token_mask.numpy() + if self.server_client_mode: + task_id, mask = self.dp_client.get_match(token_ids, + token_mask, + layer_granularity) + else: + task_id, mask = self.kv_task_engine.get_match(token_ids, + token_mask, + layer_granularity, + dp_id) + return task_id, mask + def put_async(self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor] = None, + token_ids: Union[torch.Tensor, np.ndarray], + slot_mapping: Union[torch.Tensor, np.ndarray], + token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, dp_id: int = 0, - task_id: int = -1) -> int: - if not self.running: - raise ValueError("kvmanager is not running, please call start() first") - if token_mask is None: - token_mask = torch.ones_like(token_ids) - if task_id == -1: - task_id = self._get_task_id() - # Trace the request - self.tracer.trace_request( - request_type="PUT", - request_id=task_id, - token_ids=token_ids, - slot_mapping=slot_mapping, - token_mask=token_mask, - dp_id=dp_id - ) - nvtx.mark(f"PUT request_id: {task_id}") - self.taskid_to_nvtx_range[task_id] = nvtx.start_range(f"PUT request_id: {task_id}", - color=get_nvtx_default_color()) - self.task_queue.put(KVRequest( - request_type=KVRequestType.PUT, - request_id=task_id, - token_ids=token_ids, - token_mask=token_mask, - slot_mapping=slot_mapping, - dp_id=dp_id, - )) - self.requests_tracker[task_id] = RequestTracker(task_id=task_id, - task_type=KVRequestType.PUT, - return_mask=torch.empty(0), - callback=None, - task_end_ops_ids=[], - task_end_ops_status=[], - task_finished=False) + ) -> int: + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.numpy() + if isinstance(slot_mapping, torch.Tensor): + slot_mapping = slot_mapping.numpy() + if isinstance(token_mask, torch.Tensor): + token_mask = token_mask.numpy() + if self.server_client_mode: + task_id = self.dp_client.put_async(token_ids, slot_mapping, token_mask) + else: + task_id, _ = self.kv_task_engine.put_async(token_ids, slot_mapping, token_mask, dp_id) return task_id - # wait for the key op to be finished - def wait(self, task_ids: Union[int, List[int]], timeout: float = 20.0) -> Dict[int, torch.Tensor]: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="wait", - task_ids=task_ids, - ) - nvtx.mark(f"wait task_ids: {task_ids}") + def put_match(self, + token_ids: Union[torch.Tensor, np.ndarray], + token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, + dp_id: int = 0, + ) -> Tuple[int, np.ndarray]: + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.numpy() + if isinstance(token_mask, torch.Tensor): + token_mask = token_mask.numpy() + if self.server_client_mode: + task_id, mask = self.dp_client.put_match(token_ids, token_mask) + else: + task_id, mask = self.kv_task_engine.put_match(token_ids, token_mask, dp_id) + return task_id, mask + + def launch(self, + task_ids: Union[int, List[int]], + slot_mappings: Union[np.ndarray, List[np.ndarray], torch.Tensor, List[torch.Tensor]]) -> None: if isinstance(task_ids, int): task_ids = [task_ids] - num_completed_tasks = 0 - num_tasks = len(task_ids) - return_masks = {} - start_time = time.time() - while num_completed_tasks < num_tasks: - finished_task_ids = [] - for task_id in task_ids: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return_masks[task_id] = torch.empty(0, dtype=torch.bool) #if not found in 
tracker, the return mask is an empty tensor - num_completed_tasks += 1 - finished_task_ids.append(task_id) - continue - task_tracker = self.requests_tracker[task_id] - if len(task_tracker.return_mask) == 0: #NOT READY - continue - if all(task_tracker.task_end_ops_status): - num_completed_tasks += 1 - return_masks[task_id] = task_tracker.return_mask - finished_task_ids.append(task_id) - task_ids = [task_id for task_id in task_ids if task_id not in finished_task_ids] - if time.time() - start_time > timeout: - flexkv_logger.warning(f"wait task_ids: {task_ids} timeout, has to return now") - for task_id in task_ids: - return_masks[task_id] = torch.empty(0, dtype=torch.bool) # return mask of timeout task is also an empty tensor - nvtx.mark(f"wait task_ids: {task_ids} timeout") - return return_masks - time.sleep(0.0001) - nvtx.mark(f"wait task_ids: {task_ids} done") - return return_masks + if not isinstance(slot_mappings, List): + slot_mappings = [slot_mappings] + if isinstance(slot_mappings[0], torch.Tensor): + slot_mappings = [slot_mapping.numpy() for slot_mapping in slot_mappings] + if self.server_client_mode: + self.dp_client.launch_tasks(task_ids, slot_mappings) + else: + self.kv_task_engine.launch_tasks(task_ids, slot_mappings) - # wait for the whole task to be finished, including the key op and all other ops - # this function is mainly designed for testing to avoid the frequency of writing is too high to use up memory blocks - def wait_for_graph_finished(self, - task_ids: Union[int, List[int]], - timeout: float = 20.0) -> Dict[int, torch.Tensor]: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="wait_for_graph_finished", - task_ids=task_ids, - ) - nvtx.mark(f"wait task_ids: {task_ids}") + def cancel(self, task_ids: Union[int, List[int]]) -> None: if isinstance(task_ids, int): task_ids = [task_ids] - num_completed_tasks = 0 - return_masks = {} - start_time = time.time() - while num_completed_tasks < len(task_ids): - finished_task_ids = [] - for task_id in task_ids: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return_masks[task_id] = torch.empty(0) #if not found in tracker, the return mask is an empty tensor - num_completed_tasks += 1 - finished_task_ids.append(task_id) - continue - task_tracker = self.requests_tracker[task_id] - if task_tracker.task_finished: - num_completed_tasks += 1 - return_masks[task_id] = task_tracker.return_mask - finished_task_ids.append(task_id) - task_ids = [task_id for task_id in task_ids if task_id not in finished_task_ids] - if time.time() - start_time > timeout: - flexkv_logger.warning(f"wait task_ids: {task_ids} timeout, has to return now") - for task_id in task_ids: - return_masks[task_id] = torch.empty(0) # return mask of timeout task is also an empty tensor - nvtx.mark(f"wait task_ids: {task_ids} timeout") - return return_masks - time.sleep(0.0001) - nvtx.mark(f"wait task_ids: {task_ids} done") - return return_masks + if self.server_client_mode: + self.dp_client.cancel_tasks(task_ids) + else: + self.kv_task_engine.cancel_tasks(task_ids) - # the try_wait api is used for server-client mode: - # server process running the kvmanager should NOT be blocked by any single client - def try_wait(self, task_ids: Union[int, List[int]]) -> Dict[int, torch.Tensor]: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="try_wait", - task_ids=task_ids, - ) - return_masks: Dict[int, torch.Tensor] = {} + def wait(self, + task_ids: Union[int, List[int]], + 
timeout: float = 20.0, + completely: bool = False) -> Dict[int, KVResponse]: if isinstance(task_ids, int): task_ids = [task_ids] - for task_id in task_ids: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return_masks[task_id] = torch.empty(0) #if not found in tracker, the return mask is an empty tensor - continue - task_tracker = self.requests_tracker[task_id] - if len(task_tracker.task_end_ops_ids) == 0: - mask = None - elif all(task_tracker.task_end_ops_status): - mask = task_tracker.return_mask - return_masks[task_id] = mask - else: - mask = None - - return return_masks - - def wait_at_layer_group(self, task_id: int, layer_group_id: int, timeout: float = 20.0) -> torch.Tensor: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="wait_at_layer_group", - task_ids=task_id, - layer_group_id=layer_group_id - ) - nvtx.mark(f"wait task_id: {task_id}, layer_group_id: {layer_group_id}") - start_time = time.time() - while True: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return torch.empty(0) #if not found in tracker, the return mask is an empty tensor - task_tracker = self.requests_tracker[task_id] - if len(task_tracker.task_end_ops_ids) == 0: - continue - if task_tracker.task_end_ops_status[layer_group_id]: - return task_tracker.return_mask - if time.time() - start_time > timeout: - flexkv_logger.warning(f"wait task_id: {task_id}, layer_group_id: {layer_group_id} " - f"timeout, has to return now") - return torch.empty(0) # return mask of timeout task is an empty tensor - time.sleep(0.0001) - - # nvtx.mark(f"wait_at_layer_group task_id: {task_id}, layer_group_id: {layer_group_id} done") - # return return_mask + if self.server_client_mode: + return self.dp_client.wait(task_ids, timeout, completely) + else: + return self.kv_task_engine.wait(task_ids, timeout, completely) - def try_wait_at_layer_group(self, - task_ids: Union[int, List[int]], - layer_group_id: int) -> Dict[int, torch.Tensor]: - # Trace the wait request - self.tracer.trace_wait_request( - wait_type="try_wait_at_layer_group", - task_ids=task_ids, - layer_group_id=layer_group_id, - ) - return_masks: Dict[int, torch.Tensor] = {} + def try_wait(self, task_ids: Union[int, List[int]]) -> Dict[int, KVResponse]: if isinstance(task_ids, int): task_ids = [task_ids] - for task_id in task_ids: - if task_id not in self.requests_tracker: - flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") - return_masks[task_id] = torch.empty(0) #if not found in tracker, the return mask is an empty tensor - continue - task_tracker = self.requests_tracker[task_id] - if len(task_tracker.task_end_ops_ids) == 0: - mask = torch.empty(0) - elif task_tracker.task_end_ops_status[layer_group_id]: - mask = task_tracker.return_mask - else: - mask = torch.empty(0) - return_masks[task_id] = mask - return return_masks - - def _verify_Model_Cache_config(self, - model_config: ModelConfig, - cache_config: CacheConfig): - if cache_config.enable_remote: - if cache_config.remote_cache_path is None: - - if cache_config.remote_file_prefix is None: - raise ValueError("remote_file_prefix must be provided when remote_cache_path is None") - - if cache_config.remote_file_num is None or cache_config.remote_file_num <= 0: - raise ValueError("remote_file_num must be a positive integer") - - cache_config.remote_cache_path = [ - f"{cache_config.remote_file_prefix}_{i}" - for i in range(cache_config.remote_file_num) - ] - 
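The remote-cache sizing logic that follows here, and that is kept in the new _check_config in flexkv/kvtask.py, derives num_remote_blocks from the per-block KV footprint when remote_cache_size_mode is "file_size". A small worked example with illustrative numbers:

# Illustrative only: one KV block of a non-MLA model in fp16.
num_layers, tokens_per_block, num_kv_heads, head_size, itemsize = 32, 16, 8, 128, 2

kv_size = num_layers * 2 * tokens_per_block * num_kv_heads * head_size * itemsize
# 32 * 2 * 16 * 8 * 128 * 2 = 2,097,152 bytes, i.e. 2 MiB per block

remote_file_size = 64 * 1024**3   # 64 GiB per remote file (illustrative)
remote_file_num = 4
num_remote_blocks = remote_file_size // kv_size * remote_file_num
# 32,768 blocks per file * 4 files = 131,072 blocks in total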
- if cache_config.remote_cache_size_mode == "block_num": - if cache_config.num_remote_blocks is None: - raise ValueError("num_remote_blocks must not None if use block_num model") - elif cache_config.remote_cache_size_mode == "file_size": - if cache_config.remote_file_size is None: - raise ValueError("remote_file_size must not None if use file_size model") - if model_config.use_mla: - kv_size = ( - model_config.num_layers - * cache_config.tokens_per_block - * model_config.num_kv_heads - * model_config.head_size - * model_config.dtype.itemsize - ) - else: - kv_size = ( - model_config.num_layers - * 2 - * cache_config.tokens_per_block - * model_config.num_kv_heads - * model_config.head_size - * model_config.dtype.itemsize - ) - cache_config.num_remote_blocks = cache_config.remote_file_size // kv_size * cache_config.remote_file_num + if self.server_client_mode: + return self.dp_client.try_wait(task_ids) + else: + return self.kv_task_engine.try_wait(task_ids) - else: - raise ValueError("remote_cache_size_mode must block_num or file_size model") + # Only for testing + def _clear_cpu_cache(self) -> None: + if self.server_client_mode: + flexkv_logger.error("clear_cache is not supported in server client mode") + return + else: + self.kv_task_engine._clear_cpu_cache() diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py new file mode 100644 index 0000000000..3d661a66bf --- /dev/null +++ b/flexkv/kvtask.py @@ -0,0 +1,518 @@ +import time +from typing import Dict, Optional, List, Union, Tuple +import threading +from enum import Enum +from dataclasses import dataclass +from typing import Callable + + +from expiring_dict import ExpiringDict +import nvtx +import torch +import numpy as np + +from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.debug import flexkv_logger +from flexkv.common.transfer import TransferOpGraph, get_nvtx_default_color +from flexkv.common.tracer import FlexKVTracer +from flexkv.cache.cache_engine import GlobalCacheEngine +from flexkv.transfer_manager import TransferManagerHandle +from flexkv.common.request import KVResponseStatus, KVResponse + +class TaskStatus(Enum): + # slot mapping is not ready + UNREADY = "unready" + # waiting for the task to be launched + READY = "ready" + # in transfer + RUNNING = "running" + # transfer completed + COMPLETED = "completed" + # transfer cancelled + CANCELLED = "cancelled" + # transfer failed + FAILED = "failed" + +class TaskType(Enum): + GET = "get" + PUT = "put" + +@dataclass +class KVTask: + # task descriptor + task_id: int + task_type: TaskType + task_end_op_id: int + task_end_op_finished: bool + status: TaskStatus + + # params + token_ids: np.ndarray + slot_mapping: np.ndarray + token_mask: Optional[np.ndarray] + dp_id: int + + # cache engine return + graph: TransferOpGraph + return_mask: np.ndarray + callback: Optional[Callable] + + def is_completed(self) -> bool: + return self.status in [TaskStatus.COMPLETED, TaskStatus.CANCELLED, TaskStatus.FAILED] + +TASK_STATUS_TO_RESPONSE_STATUS = { + TaskStatus.COMPLETED: KVResponseStatus.SUCCESS, + TaskStatus.CANCELLED: KVResponseStatus.CANCELLED, + TaskStatus.FAILED: KVResponseStatus.FAILED, + TaskStatus.RUNNING: KVResponseStatus.SUCCESS, # for early return: still running, but success +} + +def convert_to_response_status(task_status: TaskStatus) -> KVResponseStatus: + return TASK_STATUS_TO_RESPONSE_STATUS[task_status] + +class KVTaskManager: + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: Optional[str] = None, + 
use_separate_process: bool = True, + ): + if not cache_config.enable_cpu: + raise ValueError("enable_cpu must be True") + if cache_config.enable_remote and not cache_config.enable_ssd: + raise ValueError("enable_ssd must be True if enable_remote is True") + if not cache_config.enable_cpu and not cache_config.use_gds: + raise ValueError("use_gds must be True if enable_cpu is False") + self.cache_config = cache_config + self.model_config = model_config + self._check_config(model_config, cache_config) + + self.cache_engine = GlobalCacheEngine(cache_config, model_config) + + self.transfer_handle = TransferManagerHandle( + self.model_config, + self.cache_config, + use_separate_process=use_separate_process, + gpu_register_port=gpu_register_port + ) + + self.tasks: ExpiringDict[int, KVTask] = ExpiringDict(max_age_seconds=1800, max_len=100000) # 30 minutes + self.graph_to_task: Dict[int, int] = {} + + self.task_id_counter = 0 + self.task_id_lock = threading.Lock() + + self.running_tasks: int = 0 + + def start(self) -> None: + self.transfer_handle.start() + + def is_ready(self) -> bool: + return self.transfer_handle.is_ready() + + def __del__(self) -> None: + self.shutdown() + + def shutdown(self) -> None: + self.transfer_handle.shutdown() + + def create_get_task(self, + task_id: int, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray] = None, + layer_granularity: int = -1, + dp_id: int = 0, + is_fake_slot_mapping: bool = False, + ) -> None: + if task_id in self.tasks: + raise ValueError(f"Task ID {task_id} already exists") + graph, return_mask, callback, task_end_op_id = self.cache_engine.get(task_id, + token_ids, + token_mask, + slot_mapping, + self.model_config.num_layers, + layer_granularity, + dp_id) + self.tasks[task_id] = KVTask( + task_id=task_id, + task_type=TaskType.GET, + task_end_op_id=task_end_op_id, + task_end_op_finished=False, + status=TaskStatus.UNREADY if is_fake_slot_mapping else TaskStatus.READY, + token_ids=token_ids, + slot_mapping=slot_mapping, + token_mask=token_mask, + dp_id=dp_id, + graph=graph, + return_mask=return_mask, + callback=callback) + + self.graph_to_task[graph.graph_id] = task_id + + def create_put_task(self, + task_id: int, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray] = None, + dp_id: int = 0, + is_fake_slot_mapping: bool = False, + ) -> None: + if task_id in self.tasks: + raise ValueError(f"Task ID {task_id} already exists") + graph, return_mask, callback, task_end_op_id = self.cache_engine.put(task_id, + token_ids, + token_mask, + slot_mapping, + self.model_config.num_layers, + dp_id) + self.tasks[task_id] = KVTask( + task_id=task_id, + task_type=TaskType.PUT, + task_end_op_id=task_end_op_id, + task_end_op_finished=False, + status=TaskStatus.UNREADY if is_fake_slot_mapping else TaskStatus.READY, + token_ids=token_ids, + slot_mapping=slot_mapping, + token_mask=token_mask, + dp_id=dp_id, + graph=graph, + return_mask=return_mask, + callback=callback) + self.graph_to_task[graph.graph_id] = task_id + + def _launch_task(self, task_id: int) -> None: + task = self.tasks[task_id] + if task.is_completed(): + return + if task.status != TaskStatus.READY: + raise ValueError(f"Task {task_id} status is {task.status}, cannot launch") + transfer_graph = task.graph + task.status = TaskStatus.RUNNING + nvtx.mark(f"launch task: task_id={task_id}, graph_id={transfer_graph.graph_id}") + if transfer_graph.num_ops > 0: + self.transfer_handle.submit(transfer_graph) + + def _update_tasks(self, timeout: 
float = 0.001) -> None: + completed_ops = self._get_completed_ops(timeout) + for completed_graph_id, completed_op_id in completed_ops: + if completed_graph_id not in self.graph_to_task: + continue + task_id = self.graph_to_task[completed_graph_id] + task = self.tasks[task_id] + if completed_op_id == -1: # the graph is totally finished + self._mark_completed(task_id) + elif completed_op_id == task.task_end_op_id: + self.tasks[task_id].task_end_op_finished = True + + def _cancel_task(self, task_id: int) -> None: + task = self.tasks[task_id] + if task.is_completed(): + flexkv_logger.warning(f"Task {task_id} is already completed, cannot cancel") + return + if task.status == TaskStatus.RUNNING: + flexkv_logger.warning(f"Task {task_id} is running, cannot cancel") + return + if task.status == TaskStatus.CANCELLED: + flexkv_logger.warning(f"Task {task_id} is already cancelled, cannot cancel") + return + task.status = TaskStatus.CANCELLED + self.graph_to_task.pop(task.graph.graph_id, None) + + def check_completed(self, task_id: int, completely: bool = False) -> bool: + self._process_empty_graph(task_id) + task = self.tasks[task_id] + if completely: + return task.is_completed() + return task.is_completed() or task.task_end_op_finished + + def set_slot_mappings(self, + task_ids: List[int], + slot_mappings: List[np.ndarray]) -> None: + for task_id, slot_mapping in zip(task_ids, slot_mappings): + self._set_slot_mapping_impl(task_id, slot_mapping) + + def _set_slot_mapping_impl(self, task_id: int, slot_mapping: np.ndarray) -> None: + task = self.tasks[task_id] + if task.status != TaskStatus.UNREADY: + return + graph_ids = self.cache_engine.slot_mapping_to_block_ids(slot_mapping, + self.cache_config.tokens_per_block) + task.graph.set_gpu_blocks(graph_ids) + task.status = TaskStatus.READY + + def _gen_task_id(self) -> int: + with self.task_id_lock: + old_value = self.task_id_counter + self.task_id_counter += 1 + return old_value + + def _mark_completed(self, task_id: int) -> None: + task = self.tasks[task_id] + if task.is_completed(): + return + if task.callback: + task.callback() + task.status = TaskStatus.COMPLETED + task.task_end_op_finished = True + self.graph_to_task.pop(task.graph.graph_id) + + def _process_empty_graph(self, task_id: int) -> None: + task = self.tasks[task_id] + if task.graph.num_ops == 0: + self._mark_completed(task_id) + + def _get_completed_ops(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + return self.transfer_handle.wait(timeout) + + def _check_config(self, model_config: ModelConfig, cache_config: CacheConfig) -> None: + if cache_config.enable_remote: + if cache_config.remote_cache_path is None: + + if cache_config.remote_file_prefix is None: + raise ValueError("remote_file_prefix must be provided when remote_cache_path is None") + + if cache_config.remote_file_num is None or cache_config.remote_file_num <= 0: + raise ValueError("remote_file_num must be a positive integer") + + cache_config.remote_cache_path = [ + f"{cache_config.remote_file_prefix}_{i}" + for i in range(cache_config.remote_file_num) + ] + + if cache_config.remote_cache_size_mode == "block_num": + if cache_config.num_remote_blocks is None: + raise ValueError("num_remote_blocks must not None if use block_num model") + elif cache_config.remote_cache_size_mode == "file_size": + if cache_config.remote_file_size is None: + raise ValueError("remote_file_size must not None if use file_size model") + if model_config.use_mla: + kv_size = ( + model_config.num_layers + * cache_config.tokens_per_block + 
* model_config.num_kv_heads + * model_config.head_size + * model_config.dtype.itemsize + ) + else: + kv_size = ( + model_config.num_layers + * 2 + * cache_config.tokens_per_block + * model_config.num_kv_heads + * model_config.head_size + * model_config.dtype.itemsize + ) + cache_config.num_remote_blocks = cache_config.remote_file_size // kv_size * cache_config.remote_file_num + + else: + raise ValueError("remote_cache_size_mode must block_num or file_size model") + + +class KVTaskEngine(KVTaskManager): + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: Optional[str] = None, + use_separate_process: bool = True, + ): + super().__init__(model_config, cache_config, gpu_register_port, use_separate_process) + self.tracer = FlexKVTracer(cache_config) + + def get_async(self, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray] = None, + layer_granularity: int = -1, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + task_id, return_mask = self._get_match_impl(token_ids, + slot_mapping, + is_fake_slot_mapping=False, + token_mask=token_mask, + layer_granularity=layer_granularity, + dp_id=dp_id, + task_id=task_id) + self._launch_task(task_id) + return task_id, return_mask + + def put_async(self, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray] = None, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + task_id, return_mask = self._put_match_impl(token_ids, + slot_mapping, + is_fake_slot_mapping=False, + token_mask=token_mask, + dp_id=dp_id, + task_id=task_id) + self._launch_task(task_id) + return task_id, return_mask + + def _wait_impl(self, + task_ids: List[int], + timeout: float = 20.0, + completely: bool = False, + only_return_finished: bool = False, + ) -> Dict[int, KVResponse]: + return_responses = {} + start_time = time.time() + is_timeout = timeout == 0.0 + + self._update_tasks(timeout=0) + + for task_id in task_ids: + while True: + if task_id not in self.tasks: + flexkv_logger.error(f"task_id {task_id} not submitted into flexKV") + return_responses[task_id] = KVResponse( + status=KVResponseStatus.NOTFOUND, + task_id=task_id, + return_mask=None + ) + break + elif self.tasks[task_id].status == TaskStatus.UNREADY: + flexkv_logger.warning(f"task_id {task_id} is unready") + return_responses[task_id] = KVResponse( + status=KVResponseStatus.UNREADY, + task_id=task_id, + return_mask=None + ) + break + elif self.check_completed(task_id, completely=completely): + return_responses[task_id] = KVResponse( + status=convert_to_response_status(self.tasks[task_id].status), + task_id=task_id, + return_mask=self.tasks[task_id].return_mask + ) + break + elif only_return_finished: + break + elif time.time() - start_time > timeout: + is_timeout = True + if is_timeout: + return_responses[task_id] = KVResponse( + status=KVResponseStatus.TIMEOUT, + task_id=task_id, + return_mask=None + ) + break + self._update_tasks(timeout=0.001) + return return_responses + + def try_wait(self, task_ids: Union[int, List[int]]) -> Dict[int, KVResponse]: + if isinstance(task_ids, int): + task_ids = [task_ids] + nvtx.mark(f"try_wait task_ids: {task_ids}") + return_responses = self._wait_impl(task_ids, + completely=False, + only_return_finished=True) + return return_responses + + def wait(self, + task_ids: Union[int, List[int]], + timeout: float = 20.0, + completely: bool = False) -> Dict[int, KVResponse]: + if isinstance(task_ids, int): + task_ids = [task_ids] + 
nvtx.push_range(f"wait task_ids: {task_ids}", color=get_nvtx_default_color()) + return_responses = self._wait_impl(task_ids, timeout, completely=completely) + nvtx.pop_range() + return return_responses + + def get_match(self, + token_ids: np.ndarray, + token_mask: Optional[np.ndarray] = None, + layer_granularity: int = -1, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + if token_mask is None: + token_mask = np.ones_like(token_ids, dtype=bool) + fake_slot_mapping = np.zeros_like(token_ids[token_mask]) + return self._get_match_impl(token_ids, + fake_slot_mapping, + is_fake_slot_mapping=True, + token_mask=token_mask, + layer_granularity=layer_granularity, + dp_id=dp_id, + task_id=task_id) + + def _get_match_impl(self, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + is_fake_slot_mapping: bool = False, + token_mask: Optional[np.ndarray] = None, + layer_granularity: int = -1, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + if token_mask is None: + token_mask = np.ones_like(token_ids) + if layer_granularity == -1: + layer_granularity = self.model_config.num_layers + if task_id == -1: + task_id = self._gen_task_id() + nvtx.push_range(f"get match: task_id={task_id}", color=get_nvtx_default_color()) + self.create_get_task(task_id, + token_ids, + slot_mapping, + token_mask, + layer_granularity, + dp_id, + is_fake_slot_mapping=is_fake_slot_mapping) + self._process_empty_graph(task_id) + nvtx.pop_range() + return task_id, self.tasks[task_id].return_mask + + def put_match(self, + token_ids: np.ndarray, + token_mask: Optional[np.ndarray] = None, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + fake_slot_mapping = np.zeros_like(token_ids) + return self._put_match_impl(token_ids, + fake_slot_mapping, + is_fake_slot_mapping=True, + token_mask=token_mask, + dp_id=dp_id, + task_id=task_id) + + def _put_match_impl(self, + token_ids: np.ndarray, + slot_mapping: np.ndarray, + is_fake_slot_mapping: bool = False, + token_mask: Optional[np.ndarray] = None, + dp_id: int = 0, + task_id: int = -1) -> Tuple[int, np.ndarray]: + if token_mask is None: + token_mask = np.ones_like(token_ids) + if task_id == -1: + task_id = self._gen_task_id() + nvtx.push_range(f"put match: task_id={task_id}", color=get_nvtx_default_color()) + self.create_put_task(task_id, + token_ids, + slot_mapping, + token_mask, + dp_id, + is_fake_slot_mapping=is_fake_slot_mapping) + self._process_empty_graph(task_id) + nvtx.pop_range() + return task_id, self.tasks[task_id].return_mask + + def launch_tasks(self, + task_ids: List[int], + slot_mappings: List[np.ndarray]) -> None: + assert isinstance(slot_mappings[0], np.ndarray) + self.set_slot_mappings(task_ids, slot_mappings) + for task_id in task_ids: + self._launch_task(task_id) + + def cancel_tasks(self, task_ids: Union[int, List[int]]) -> None: + if isinstance(task_ids, int): + task_ids = [task_ids] + for task_id in task_ids: + self._cancel_task(task_id) + + def _clear_cpu_cache(self) -> None: + self.cache_engine.cpu_cache_engine.reset() diff --git a/flexkv/server/client.py b/flexkv/server/client.py index 257c3f2fa3..1643af98a3 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -2,16 +2,18 @@ from multiprocessing import Lock, Queue from multiprocessing.connection import Connection from queue import Queue as ThreadQueue -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Callable import tempfile import torch import zmq +import numpy as np -from flexkv.common.config import 
ModelConfig +from flexkv.common.config import ModelConfig, CacheConfig from flexkv.common.debug import flexkv_logger from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout +from flexkv.common.request import KVResponseStatus, KVResponse from flexkv.server.utils import get_zmq_socket from flexkv.server.request import ( RegisterDPClientRequest, @@ -19,30 +21,36 @@ IsReadyRequest, PutRequest, GetRequest, + PutMatchRequest, + GetMatchRequest, + LaunchTaskRequest, + CancelTaskRequest, WaitRequest, TryWaitRequest, CheckRunningRequest, + StartRequest, ShutdownRequest, Response ) - class KVDPClient: def __init__( self, server_recv_port: str, model_config: ModelConfig, + dp_client_id: int, ): # Init inter-process communication context = zmq.Context(2) self.send_to_server = get_zmq_socket( context, zmq.SocketType.PUSH, server_recv_port, False ) - client_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=True).name}" + self.client_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=True).name}" self.recv_from_server = get_zmq_socket( - context, zmq.SocketType.PULL, client_recv_port, True + context, zmq.SocketType.PULL, self.client_recv_port, True ) - self.dp_client_id = self.register_to_server(model_config, client_recv_port) + self.dp_client_id = dp_client_id + self.model_config = model_config self._task_id_range = (self.dp_client_id * 10000000, (self.dp_client_id + 1) * 10000000) self._task_id_counter = self._task_id_range[0] @@ -57,22 +65,20 @@ def _get_task_id(self) -> int: self._task_id_counter = self._task_id_range[0] return old_value + def start_server_and_register(self) -> None: + #start server and register + req = StartRequest(self.dp_client_id) + self.send_to_server.send_pyobj(req) + self.register_to_server(self.model_config, self.client_recv_port) + def register_to_server( self, model_config: ModelConfig, client_recv_port: str, - ) -> int: - register_req = RegisterDPClientRequest(model_config, client_recv_port) - + ) -> None: + register_req = RegisterDPClientRequest(self.dp_client_id, model_config, client_recv_port) self.send_to_server.send_pyobj(register_req) - # blocking - response: Response = self.recv_from_server.recv_pyobj() - if response.success: - flexkv_logger.info(f"DP client registered successfully! DP client id: {response.dp_client_id}") - return response.dp_client_id - else: - flexkv_logger.error(f"DP client registeration fialed: {response.error_msg}") - raise + flexkv_logger.info(f"DP client {self.dp_client_id} registered to server request sent!") def is_ready( self, @@ -80,61 +86,103 @@ def is_ready( req = IsReadyRequest(self.dp_client_id) self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() - if response.success: - return response.is_ready - else: - flexkv_logger.error(f"is_ready failed: {response.error_msg}") - raise - + return response.is_ready + def put_async( self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor], - ) -> Optional[int]: - # start_time = time.time() + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray], + ) -> int: req = PutRequest(self.dp_client_id, - token_ids.numpy(), - slot_mapping.numpy(), - token_mask.numpy() if token_mask is not None else None, + token_ids, + slot_mapping, + token_mask if token_mask is not None else None, self._get_task_id()) self.send_to_server.send_pyobj(req) - # end_time = time.time() - # print(f"[dpclient] put_async task: {req.task_id} created. 
time: {(end_time - start_time)*1000:.2f}ms") return req.task_id + def put_match( + self, + token_ids: np.ndarray, + token_mask: Optional[np.ndarray], + ) -> Optional[Tuple[int, np.ndarray]]: + req = PutMatchRequest(self.dp_client_id, + token_ids, + token_mask if token_mask is not None else None, + self._get_task_id()) + self.send_to_server.send_pyobj(req) + response: Response = self.recv_from_server.recv_pyobj() + if response.error_msg is None: + return response.task_id, response.mask + else: + flexkv_logger.error(f"put_match failed, error_msg: {response.error_msg}") + return None + def get_async( self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor], - ) -> Optional[int]: - # start_time = time.time() + token_ids: np.ndarray, + slot_mapping: np.ndarray, + token_mask: Optional[np.ndarray], + layer_granularity: int, + ) -> int: req = GetRequest(self.dp_client_id, - token_ids.numpy(), - slot_mapping.numpy(), - token_mask.numpy() if token_mask is not None else None, - self._get_task_id()) - + token_ids, + slot_mapping, + token_mask if token_mask is not None else None, + self._get_task_id(), + layer_granularity) self.send_to_server.send_pyobj(req) - # end_time = time.time() - # print(f"[dpclient] get_async task: {req.task_id} created. time: {(end_time - start_time)*1000:.2f}ms") return req.task_id + def get_match( + self, + token_ids: np.ndarray, + token_mask: Optional[np.ndarray], + layer_granularity: int, + ) -> Optional[Tuple[int, np.ndarray]]: + req = GetMatchRequest(self.dp_client_id, + token_ids, + token_mask if token_mask is not None else None, + layer_granularity, + self._get_task_id()) + self.send_to_server.send_pyobj(req) + response: Response = self.recv_from_server.recv_pyobj() + if response.error_msg is None: + return req.task_id, response.mask + else: + flexkv_logger.error(f"get_match failed, error_msg: {response.error_msg}") + return None + + def launch_tasks( + self, + task_ids: List[int], + slot_mappings: List[np.ndarray], + ) -> None: + req = LaunchTaskRequest(self.dp_client_id, task_ids, slot_mappings) + self.send_to_server.send_pyobj(req) + + def cancel_task( + self, + task_ids: List[int], + ) -> None: + req = CancelTaskRequest(self.dp_client_id, task_ids) + self.send_to_server.send_pyobj(req) + def wait( self, wait_task_ids: List[int], wait_timeout: float = 20.0, - ) -> Optional[Dict[int, torch.Tensor]]: - req = WaitRequest(self.dp_client_id, None, wait_task_ids, wait_timeout) - + completely: bool = False, + ) -> Optional[Dict[int, KVResponse]]: + req = WaitRequest(self.dp_client_id, None, wait_task_ids, wait_timeout, completely) self.send_to_server.send_pyobj(req) response: Response = self.recv_from_server.recv_pyobj() - if response.masks is not None: - response.masks = {k: torch.from_numpy(v) for k, v in response.masks.items()} - if response.success: - # flexkv_logger.info(f"wait tasks: {wait_task_ids} finished.") - return response.masks + if response.status is not None: + for k, v in response.status.items(): + if v.status != KVResponseStatus.SUCCESS: + flexkv_logger.error(f"wait task {k} failed: {v.status}") + return response.status else: flexkv_logger.error(f"wait tasks: {wait_task_ids} in DP {self.dp_client_id} failed.") return None @@ -142,35 +190,23 @@ def wait( def try_wait( self, try_wait_task_ids: List[int], - ) -> Optional[Dict[int, torch.Tensor]]: + ) -> Optional[Dict[int, KVResponse]]: req = TryWaitRequest(self.dp_client_id, None, try_wait_task_ids) self.send_to_server.send_pyobj(req) response: Response = 
self.recv_from_server.recv_pyobj() - if response.masks is not None: - response.masks = {k: torch.from_numpy(v) for k, v in response.masks.items()} - if response.success: - # flexkv_logger.info(f"try_wait tasks: {try_wait_task_ids} finished.") - return response.masks + if response.status is not None: + for k, v in response.status.items(): + if v.status != KVResponseStatus.SUCCESS: + flexkv_logger.error(f"try_wait task {k} failed: {v.status}") + return response.status else: flexkv_logger.error(f"try_wait tasks: {try_wait_task_ids} in DP {self.dp_client_id} failed.") return None - def check_running(self) -> bool: - req = CheckRunningRequest(self.dp_client_id) - self.send_to_server.send_pyobj(req) - response: Response = self.recv_from_server.recv_pyobj() - return response.running - def shutdown(self) -> None: req = ShutdownRequest(self.dp_client_id) self.send_to_server.send_pyobj(req) - response: Response = self.recv_from_server.recv_pyobj() - if response.success: - flexkv_logger.info(f"DP client {self.dp_client_id} shutdown successfully.") - else: - flexkv_logger.error(f"DP client {self.dp_client_id} shutdown failed.") - raise class KVTPClient: def __init__( @@ -178,23 +214,17 @@ def __init__( server_recv_port: str, dp_client_id: int, device_id: int, - tp_rank: int, ): # Init inter-process communication context = zmq.Context(2) self.send_to_server = get_zmq_socket( context, zmq.SocketType.PUSH, server_recv_port, False ) - self.client_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" - self.recv_from_server = get_zmq_socket( - context, zmq.SocketType.PULL, self.client_recv_port, True - ) self.dp_client_id = dp_client_id self.device_id = device_id - self.tp_rank = tp_rank - flexkv_logger.info(f"KVTPClient {tp_rank} of KVDPClient {self.dp_client_id} Initialized!") + flexkv_logger.info(f"KVTPClient {device_id} of KVDPClient {self.dp_client_id} Initialized!") def register_to_server( self, @@ -214,24 +244,12 @@ def register_to_server( register_req = RegisterTPClientRequest( self.dp_client_id, - self.tp_rank, self.device_id, - self.client_recv_port, handles, kv_layout ) - self.send_to_server.send_pyobj(register_req) - # blocking - response: Response = self.recv_from_server.recv_pyobj() - if response.success: - flexkv_logger.info(f"TP client of DP client {self.dp_client_id} registered successfully!") - else: - flexkv_logger.error( - f"TP client of DP client {self.dp_client_id} registeration fialed: {response.error_msg}" - ) - raise - + self.send_to_server.send_pyobj(register_req, flags=zmq.NOBLOCK) if __name__ == "__main__": diff --git a/flexkv/server/request.py b/flexkv/server/request.py index 532feae047..f6df970b17 100644 --- a/flexkv/server/request.py +++ b/flexkv/server/request.py @@ -2,15 +2,16 @@ from typing import Dict, List, Optional, Tuple import numpy as np -import torch from flexkv.common.config import ModelConfig from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout +from flexkv.common.request import KVResponseStatus @dataclass class RegisterDPClientRequest: + dp_client_id: int model_config: ModelConfig client_recv_port: str @@ -18,9 +19,7 @@ class RegisterDPClientRequest: @dataclass class RegisterTPClientRequest: dp_client_id: int - tp_rank: int device_id: int - client_recv_port: str handles: List[TensorSharedHandle] gpu_layout: KVCacheLayout @@ -44,6 +43,33 @@ class GetRequest: slot_mapping: np.ndarray token_mask: Optional[np.ndarray] task_id: int = -1 + layer_granularity: int = -1 + +@dataclass +class 
PutMatchRequest: + dp_client_id: int + token_ids: np.ndarray + token_mask: Optional[np.ndarray] + task_id: int = -1 + +@dataclass +class GetMatchRequest: + dp_client_id: int + token_ids: np.ndarray + token_mask: Optional[np.ndarray] + layer_granularity: int + task_id: int = -1 + +@dataclass +class LaunchTaskRequest: + dp_client_id: int + task_ids: List[int] + slot_mappings: List[np.ndarray] + +@dataclass +class CancelTaskRequest: + dp_client_id: int + task_ids: List[int] @dataclass class WaitRequest: @@ -51,6 +77,7 @@ class WaitRequest: tp_rank: Optional[int] wait_task_ids: List[int] wait_timeout: float = 20.0 + completely: bool = False # Used for async put/get @dataclass @@ -62,19 +89,25 @@ class TryWaitRequest: @dataclass class Response: - dp_client_id: int + dp_client_id: int = -1 task_id: Optional[int] = None - masks: Optional[Dict[int, np.ndarray]] = None - success: bool = True - running: bool = False - error_msg: str = "" + mask: Optional[Dict[int, np.ndarray]] = None + status: Optional[Dict[int, KVResponseStatus]] = None is_ready: bool = False + error_msg: Optional[str] = None + @property + def success(self) -> bool: + return self.status is not None and \ + all(self.status[task_id] == KVResponseStatus.SUCCESS for task_id in self.status.keys()) @dataclass -class ShutdownRequest: +class StartRequest: dp_client_id: int +@dataclass +class ShutdownRequest: + dp_client_id: int @dataclass class CheckRunningRequest: diff --git a/flexkv/server/server.py b/flexkv/server/server.py index 8ea5b91faa..1849c1e304 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -7,12 +7,15 @@ import time import threading from threading import Lock +import multiprocessing as mp +import socket +import os from flexkv.common.config import CacheConfig, ModelConfig from flexkv.common.debug import flexkv_logger from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType -from flexkv.kvmanager import KVManager +from flexkv.kvtask import KVTaskEngine from flexkv.server.utils import get_zmq_socket from flexkv.server.request import ( RegisterDPClientRequest, @@ -20,27 +23,19 @@ IsReadyRequest, PutRequest, GetRequest, + PutMatchRequest, + GetMatchRequest, + LaunchTaskRequest, + CancelTaskRequest, WaitRequest, TryWaitRequest, Response, + StartRequest, ShutdownRequest, CheckRunningRequest, ) import contextlib - -class TPClient: - def __init__( - self, - send_to_client: zmq.Socket, - tp_rank: int = 0, - device_id: int = 0, - ): - self.tp_rank = tp_rank - self.device_id = device_id - self.send_to_client = send_to_client - - class DPClient: def __init__( self, @@ -50,40 +45,11 @@ def __init__( ): self.client_id = client_id self.tp_size = tp_size - self.tp_client_dict: Dict[int, TPClient] = {} self.send_to_client = send_to_client self.is_ready: bool = False - def register_tp_client( - self, - context: zmq.Context, - client_recv_port: str, - tp_rank: int = 0, - device_id: int = 0, - ) -> None: - if tp_rank in self.tp_client_dict: - flexkv_logger.error(f"TP rank: {tp_rank} in DP client: {self.client_id} has already registered.") - raise - if tp_rank >= self.tp_size: - flexkv_logger.error(f"TP rank: {tp_rank} is larger than TP size of DP client: {self.client_id}.") - raise - - send_to_client = get_zmq_socket( - context, zmq.SocketType.PUSH, client_recv_port, False - ) - - self.tp_client_dict[tp_rank] = TPClient(send_to_client, tp_rank, device_id) - - flexkv_logger.info(f"TP rank: {tp_rank} in DP client: {self.client_id} registered successfully.") - 
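
For orientation, the match-then-launch flow this patch adds on the client side (get_match/put_match returning a task id plus hit mask, launch_tasks to start the copies, wait returning per-task KVResponse status) can be driven roughly as sketched below. This is an illustrative sketch only, not part of the patch: the import path and the way slot_mapping is derived from the hit mask are assumptions, while the method signatures follow the definitions above.

    import numpy as np
    from flexkv.server.client import KVDPClient  # import path assumed for illustration

    def prefetch_from_cache(client: KVDPClient, token_ids: np.ndarray):
        # 1) Ask which tokens can be served from the cache; no data moves yet.
        #    layer_granularity=-1 means "all layers", as in GetRequest's default.
        match = client.get_match(token_ids, token_mask=None, layer_granularity=-1)
        if match is None:
            return None
        task_id, hit_mask = match
        # 2) Build a slot_mapping for the matched tokens (placeholder logic) and start the copy.
        slot_mapping = np.arange(int(hit_mask.sum()), dtype=np.int64)
        client.launch_tasks([task_id], [slot_mapping])
        # 3) Block until the task reports a status; wait() logs non-SUCCESS statuses itself.
        return client.wait([task_id], wait_timeout=20.0, completely=False)
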
- if len(self.tp_client_dict) == self.tp_size: - self.is_ready = True - flexkv_logger.info(f"All the TP clients in DP client: {self.client_id} has registered. " - f"DP client: {self.client_id} is ready!") - - class ClientManager: def __init__( self, @@ -116,20 +82,6 @@ def register_dp_client( return client_id - def register_tp_client( - self, - context: zmq.Context, - dp_client_id: int, - client_recv_port: str, - tp_rank: int, - device_id: int - ) -> None: - if dp_client_id not in self.client_dict: - flexkv_logger.error(f"DP client: {dp_client_id} has not registered.") - raise - self.client_dict[dp_client_id].register_tp_client( - context, client_recv_port, tp_rank, device_id) - def delete_dp_client(self, client_id: int) -> None: if client_id not in self.client_dict: flexkv_logger.error(f"DP client: {client_id} dosen't exist. Delete failed.") @@ -150,166 +102,127 @@ def is_dp_client_ready(self, dp_client_id: int) -> bool: return self.client_dict[dp_client_id].is_ready return False +class KVServerHandle: + def __init__(self, process: mp.Process): + self.process = process + + def shutdown(self) -> None: + self.process.join(timeout=5) + if self.process.is_alive(): + flexkv_logger.info("force terminate the server process") + self.process.terminate() + self.process.join() + + def __del__(self) -> None: + if self.process.is_alive(): + self.shutdown() class KVServer: def __init__( self, model_config: ModelConfig, cache_config: CacheConfig, - server_recv_port: Optional[str] = None, + gpu_register_port: str, + server_recv_port: str ): # Init inter-process communication self.context = zmq.Context(2) - if server_recv_port is None: - server_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" self.recv_from_client = get_zmq_socket( self.context, zmq.SocketType.PULL, server_recv_port, True) self.client_manager = ClientManager(max_num_dp_client=model_config.dp_size) - self.kvmanager = KVManager(model_config, cache_config) - - if self.kvmanager.is_ready(): - flexkv_logger.info("KVManager is ready, starting with worker initialization...") - self.kvmanager.start() + self.kv_task_engine = KVTaskEngine(model_config, cache_config, gpu_register_port, False) self.req_counter = 0 + self._is_ready = False + self._running = False - flexkv_logger.info(f"Server Initialized! 
[Recv Port]: {server_recv_port}") - # self._running = True - + # Request handler dispatch table + self.request_handlers = { + StartRequest: self._handle_start_request, + RegisterDPClientRequest: self._handle_register_dp_client_request, + IsReadyRequest: self._handle_is_ready_request, + GetRequest: self._handle_get_request, + PutRequest: self._handle_put_request, + GetMatchRequest: self._handle_get_match_request, + PutMatchRequest: self._handle_put_match_request, + WaitRequest: self._handle_wait_request, + LaunchTaskRequest: self._handle_launch_task_request, + CancelTaskRequest: self._handle_cancel_task_request, + TryWaitRequest: self._handle_try_wait_request, + ShutdownRequest: self._handle_shutdown_request, + } + + def is_ready(self) -> bool: + return self._is_ready + + def start_server(self) -> None: + self.kv_task_engine.start() + self._is_ready = True + + @staticmethod + def _server_process(model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str, + server_recv_port: str) -> None: + + server = KVServer(model_config, cache_config, gpu_register_port, server_recv_port) + server.run() + + @classmethod + def create_server(cls, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str, + server_recv_port: Optional[str] = None) -> 'KVServerHandle': + #if server_recv_port is None: + # server_recv_port = f"ipc:///tmp/flexkv_srv_{uuid.uuid4().hex[:8]}" #TODO unify this + + # Set spawn method for CUDA compatibility + with contextlib.suppress(RuntimeError): + mp.set_start_method("spawn") + process = mp.Process(target=cls._server_process, + args=(model_config, cache_config, gpu_register_port, server_recv_port)) + process.start() + flexkv_logger.info(f"KVServer process started, PID: {process.pid}") + + return KVServerHandle(process) def run(self) -> None: """Main server loop""" # TODO: handle error and return error response # TODO: support check finish + flexkv_logger.info("Servering waiting to be started") + req = self.recv_from_client.recv_pyobj() + if isinstance(req, StartRequest): + flexkv_logger.info(f"Received start request from DP client {req.dp_client_id}, " + f"Starting server...") + self.start_server() + else: + raise TypeError(f"Received RequestType: {type(req)} from DP client " + f"{req.dp_client_id} before the start request") self._running = True while self._running: try: - flexkv_logger.info("start wait for req") + flexkv_logger.info("start waiting for req") req = self.recv_from_client.recv_pyobj() flexkv_logger.info(f"recv req: {type(req)}") - # register dp client - if isinstance(req, RegisterDPClientRequest): - self._verify_model_config(req.model_config) - client_id = self.client_manager.register_dp_client( - self.context, - req.client_recv_port, - req.model_config.tp_size - ) - response = Response(client_id) - result_zmq = self.client_manager.get_zmq(client_id) - result_zmq.send_pyobj(response) - - - elif isinstance(req, RegisterTPClientRequest): - self.client_manager.register_tp_client( - self.context, - req.dp_client_id, - req.client_recv_port, - req.tp_rank, - req.device_id, - ) - - # register GPU Memory - self.kvmanager.register_single_gpu_blocks(req.handles, - req.gpu_layout, - req.dp_client_id, - req.tp_rank) - - response = Response(req.dp_client_id) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id, req.tp_rank) - result_zmq.send_pyobj(response) - - if self.kvmanager.is_ready(): - flexkv_logger.info("All TP clients registered, starting KVManager...") - self.kvmanager.start() - - elif isinstance(req, 
IsReadyRequest): - is_ready = self.kvmanager.is_ready() - response = Response(req.dp_client_id, is_ready=is_ready) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, GetRequest): - assert self.client_manager.is_dp_client_ready(req.dp_client_id) - req_id = self.kvmanager.get_async( - token_ids=torch.from_numpy(req.token_ids), - slot_mapping=torch.from_numpy(req.slot_mapping), - token_mask=torch.from_numpy(req.token_mask) if req.token_mask is not None else None, - layer_granularity=-1, - dp_id=req.dp_client_id, - task_id=req.task_id, - ) - if req.task_id == -1: - response = Response(req.dp_client_id, req_id) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, PutRequest): - assert self.client_manager.is_dp_client_ready(req.dp_client_id) - req_id = self.kvmanager.put_async( - token_ids=torch.from_numpy(req.token_ids), - slot_mapping=torch.from_numpy(req.slot_mapping), - token_mask=torch.from_numpy(req.token_mask) if req.token_mask is not None else None, - dp_id=req.dp_client_id, - task_id=req.task_id, - ) - if req.task_id == -1: - response = Response(req.dp_client_id, req_id) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, WaitRequest): - # TODO: support TP client wait - masks = self.kvmanager.wait( - req.wait_task_ids, - timeout=req.wait_timeout, - ) - if masks is not None: - # Convert to numpy arrays for serialization - masks = {k: v.numpy() if isinstance(v, torch.Tensor) else v for k, v in masks.items()} - response = Response(req.dp_client_id, masks=masks) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, TryWaitRequest): - # TODO: support TP client try_wait - masks = self.kvmanager.try_wait( - req.try_wait_task_ids, - ) - if masks is not None: - # Convert to numpy arrays for serialization - masks = {k: v.numpy() if isinstance(v, torch.Tensor) else v for k, v in masks.items()} - response = Response(req.dp_client_id, masks=masks) - result_zmq = self.client_manager.get_zmq( - req.dp_client_id) - result_zmq.send_pyobj(response) - - elif isinstance(req, ShutdownRequest): - flexkv_logger.info(f"Received shutdown request from DP client {req.dp_client_id}") - # Gracefully shutdown the server - self._running = False - # Send response back to client - response = Response(req.dp_client_id, success=True) - result_zmq = self.client_manager.get_zmq(req.dp_client_id) - result_zmq.send_pyobj(response) - break + # Use dispatch table for request handling + req_type = type(req) + handler = self.request_handlers.get(req_type) - elif isinstance(req, CheckRunningRequest): - response = Response(req.dp_client_id, success=True, running=self.kvmanager.is_running()) - result_zmq = self.client_manager.get_zmq(req.dp_client_id) - result_zmq.send_pyobj(response) + if handler is None: + raise TypeError(f"Unrecognized RequestType: {req_type}") - else: - raise TypeError(f"Unregonized RequestType: {type(req)}") + # Call the corresponding handler method + handler(req) + + # If the request is a shutdown request, exit the loop + if req_type == ShutdownRequest: + break except zmq.ZMQError as e: flexkv_logger.error(f"ZMQ Error: {e}", exc_info=True) @@ -329,339 +242,110 @@ def _verify_model_config( # TODO return True - def __del__(self) -> None: - self.kvmanager.shutdown() - - -class SchedulerServer: - """ - Scheduler server that merges the 
functionality of KVServer and KVDPClient. - Note that this class is ONLY FOR CASES WHEN DP_SIZE = 1. - - This class can: - 1. Directly call KVManager methods to avoid inter-process communication latency - 2. Accept registration requests from TPClient - 3. Provide the same interface as KVDPClient (put_async, get_async, wait, try_wait) - """ - - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - server_recv_port: Optional[str] = None, - ): - self.model_config = model_config - self.cache_config = cache_config - - # Initialize KVManager (similar to KVServer) - self.kvmanager = KVManager(model_config, cache_config) - - # Start KVManager if it's ready (e.g., when no TP clients are needed) - if self.kvmanager.is_ready(): - try: - self.kvmanager.start() - flexkv_logger.info("KVManager started during initialization") - except Exception as e: - flexkv_logger.warning(f"KVManager start failed during initialization: {e}") + # Request Handler Methods - # For TPClient compatibility, we need a server to receive TPClient registration requests - self.context = zmq.Context(2) - if server_recv_port is None: - server_recv_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" - - self.server_recv_port = server_recv_port - self.recv_from_client = get_zmq_socket( - self.context, zmq.SocketType.PULL, server_recv_port, True) - - # Manage TP clients - self.tp_size = model_config.tp_size - self.tp_client_dict: Dict[int, TPClient] = {} - self.is_ready: bool = False - - # DP client related - self.dp_client_id = 0 # Fixed to 0 because we merged scheduler and server - self._task_id_range = (self.dp_client_id * 10000000, (self.dp_client_id + 1) * 10000000) - self._task_id_counter = self._task_id_range[0] - self._task_id_lock = Lock() - - # Server thread control - self._running = False - self._server_thread = None + def _handle_start_request(self, req: StartRequest) -> None: + """Handle start request""" + flexkv_logger.info(f"Received start request from DP client {req.dp_client_id}") - flexkv_logger.info(f"SchedulerServer Initialized! 
[Recv Port]: {server_recv_port}") - - def _get_task_id(self) -> int: - """Generate unique task ID""" - with self._task_id_lock: - old_value = self._task_id_counter - self._task_id_counter += 1 - if self._task_id_counter >= self._task_id_range[1]: - self._task_id_counter = self._task_id_range[0] - return old_value - - def start_server_thread(self) -> None: - """Start background server thread to handle TPClient requests""" - if self._server_thread is not None and self._server_thread.is_alive(): - flexkv_logger.warning("Server thread is already running") - return - - self._running = True - self._server_thread = threading.Thread(target=self._server_loop, daemon=True) - self._server_thread.start() - flexkv_logger.info("SchedulerServer background thread started") - - def _server_loop(self) -> None: - """Background server loop to handle requests from TPClient""" - while self._running: - try: - # Set non-blocking receive to allow checking _running status - try: - req = self.recv_from_client.recv_pyobj(zmq.NOBLOCK) - except zmq.Again: - time.sleep(0.001) # Brief sleep to avoid busy waiting - continue - - flexkv_logger.info(f"SchedulerServer received request: {type(req)}") - - if isinstance(req, RegisterTPClientRequest): - self._handle_tp_registration(req) - elif isinstance(req, ShutdownRequest): - flexkv_logger.info("Received shutdown request from TP client") - response = Response(req.dp_client_id, success=True) - # Since we don't know which TP client sent the shutdown request, - # we send response to all registered TP clients - self._running = False - for tp_client in self.tp_client_dict.values(): - tp_client.send_to_client.send_pyobj(response) - break - else: - flexkv_logger.error(f"Unrecognized RequestType in SchedulerServer: {type(req)}") - - except zmq.ZMQError as e: - if e.errno == zmq.ETERM: - break # Context terminated - flexkv_logger.error(f"ZMQ Error in SchedulerServer: {e}", exc_info=True) - except Exception as e: - flexkv_logger.error(f"Error in SchedulerServer: {e}", exc_info=True) - time.sleep(0.0001) - - flexkv_logger.info("SchedulerServer background thread stopped") - - def _handle_tp_registration(self, req: RegisterTPClientRequest) -> None: - """Handle TP Client registration request""" - tp_rank = req.tp_rank - - if tp_rank in self.tp_client_dict: - flexkv_logger.error(f"TP rank: {tp_rank} has already registered.") - response = Response(req.dp_client_id, success=False, - error_msg=f"TP rank {tp_rank} already registered") - elif tp_rank >= self.tp_size: - flexkv_logger.error(f"TP rank: {tp_rank} is larger than TP size: {self.tp_size}.") - response = Response(req.dp_client_id, success=False, - error_msg=f"TP rank {tp_rank} exceeds TP size {self.tp_size}") - else: - try: - # Create connection to TP client - send_to_client = get_zmq_socket( - self.context, zmq.SocketType.PUSH, req.client_recv_port, False - ) - - self.tp_client_dict[tp_rank] = TPClient(send_to_client, tp_rank, req.device_id) - - # Register GPU Memory to KVManager - self.kvmanager.register_single_gpu_blocks( - req.handles, - req.gpu_layout, - self.dp_client_id, # Use fixed dp_client_id = 0 - req.tp_rank - ) - - flexkv_logger.info(f"TP rank: {tp_rank} registered successfully.") - - # Check if all TP clients have registered - if len(self.tp_client_dict) == self.tp_size: - self.is_ready = True - # Always start kvmanager when all TP clients are registered - try: - flexkv_logger.info("All TP clients registered, starting KVManager...") - self.kvmanager.start() - flexkv_logger.info("KVManager started successfully") - except 
Exception as e: - flexkv_logger.warning(f"KVManager start failed or already started: {e}") - flexkv_logger.info("All TP clients registered. SchedulerServer is ready!") - - response = Response(req.dp_client_id, success=True) - - except Exception as e: - flexkv_logger.error(f"Failed to register TP client {tp_rank}: {e}") - response = Response(req.dp_client_id, success=False, error_msg=str(e)) + def _handle_register_dp_client_request(self, req: RegisterDPClientRequest) -> None: + """Handle DP client registration request""" + self._verify_model_config(req.model_config) + client_id = self.client_manager.register_dp_client( + self.context, + req.client_recv_port, + req.model_config.tp_size + ) + flexkv_logger.info(f"DP client {client_id} registered successfully") - # Send response to TP client - if tp_rank in self.tp_client_dict: - self.tp_client_dict[tp_rank].send_to_client.send_pyobj(response) + def _handle_is_ready_request(self, req: IsReadyRequest) -> None: + """Handle ready state check request""" + is_ready = self.kv_task_engine.is_ready() + response = Response(req.dp_client_id, is_ready=is_ready) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) + + def _handle_get_request(self, req: GetRequest) -> None: + """Handle Get request""" + req_id = self.kv_task_engine.get_async( + task_id=req.task_id, + token_ids=req.token_ids, + slot_mapping=req.slot_mapping, + token_mask=req.token_mask, + layer_granularity=req.layer_granularity, + dp_id=req.dp_client_id, + ) - def put_async( - self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor] = None, - ) -> Optional[int]: - """ - Asynchronous PUT operation, directly calling KVManager (no network communication required) - - Args: - token_ids: Token IDs tensor - slot_mapping: Slot mapping tensor - token_mask: Optional token mask tensor - - Returns: - Task ID if successful, None otherwise - """ - start_time = time.time() - - if not self.is_ready: - flexkv_logger.error("SchedulerServer is not ready (not all TP clients registered)") - return None - - try: - task_id = self._get_task_id() - req_id = self.kvmanager.put_async( - token_ids=token_ids, - slot_mapping=slot_mapping, - token_mask=token_mask, - dp_id=self.dp_client_id, - task_id=task_id, - ) - - end_time = time.time() - flexkv_logger.info(f"[SchedulerServer] put_async task: {task_id} created. " - f"time: {(end_time - start_time)*1000:.2f}ms") - return task_id - - except Exception as e: - flexkv_logger.error(f"put_async failed: {e}") - return None - - def get_async( - self, - token_ids: torch.Tensor, - slot_mapping: torch.Tensor, - token_mask: Optional[torch.Tensor] = None, - ) -> Optional[int]: - """ - Asynchronous GET operation, directly calling KVManager (no network communication required) - - Args: - token_ids: Token IDs tensor - slot_mapping: Slot mapping tensor - token_mask: Optional token mask tensor - - Returns: - Task ID if successful, None otherwise - """ - start_time = time.time() - - if not self.is_ready: - flexkv_logger.error("SchedulerServer is not ready (not all TP clients registered)") - return None - - try: - task_id = self._get_task_id() - req_id = self.kvmanager.get_async( - token_ids=token_ids, - slot_mapping=slot_mapping, - token_mask=token_mask, - layer_granularity=-1, - dp_id=self.dp_client_id, - task_id=task_id, - ) - - end_time = time.time() - flexkv_logger.info(f"[SchedulerServer] get_async task: {task_id} created. 
" - f"time: {(end_time - start_time)*1000:.2f}ms") - return task_id - - except Exception as e: - flexkv_logger.error(f"get_async failed: {e}") - return None - - def wait( - self, - wait_task_ids: List[int], - wait_timeout: float = 20.0, - ) -> Optional[Dict[int, torch.Tensor]]: - """ - Wait for specified tasks to complete, directly calling KVManager (no network communication required) - - Args: - wait_task_ids: List of task IDs to wait for - - Returns: - Dictionary mapping task IDs to result masks, None if failed - """ - try: - masks = self.kvmanager.wait(wait_task_ids, timeout=wait_timeout) - flexkv_logger.info(f"[SchedulerServer] wait tasks: {wait_task_ids} finished.") - return masks - - except Exception as e: - flexkv_logger.error(f"wait failed: {e}") - return None - - def try_wait( - self, - try_wait_task_ids: List[int], - ) -> Optional[Dict[int, torch.Tensor]]: - """ - Non-blocking wait for specified tasks, directly calling KVManager (no network communication required) - - Args: - try_wait_task_ids: List of task IDs to try waiting for - - Returns: - Dictionary mapping task IDs to result masks, None if not ready or failed - """ - try: - masks = self.kvmanager.try_wait(try_wait_task_ids) - if masks is not None: - flexkv_logger.info(f"[SchedulerServer] try_wait tasks: {try_wait_task_ids} finished.") - return masks - - except Exception as e: - flexkv_logger.error(f"try_wait failed: {e}") - return None - - def check_running(self) -> bool: - return self.kvmanager.is_running() + def _handle_put_request(self, req: PutRequest) -> None: + """Handle Put request""" + req_id = self.kv_task_engine.put_async( + token_ids=req.token_ids, + slot_mapping=req.slot_mapping, + token_mask=req.token_mask, + dp_id=req.dp_client_id, + task_id=req.task_id, + ) - def shutdown(self) -> None: - """Shutdown SchedulerServer""" - flexkv_logger.info("Shutting down SchedulerServer...") + def _handle_get_match_request(self, req: GetMatchRequest) -> None: + """Handle GetMatch request""" + req_id, mask = self.kv_task_engine.get_match( + token_ids=req.token_ids, + token_mask=req.token_mask, + layer_granularity=req.layer_granularity, + dp_id=req.dp_client_id, + task_id=req.task_id, + ) + response = Response(req.dp_client_id, task_id=req_id, mask=mask) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) + + def _handle_put_match_request(self, req: PutMatchRequest) -> None: + """Handle PutMatch request""" + req_id, mask = self.kv_task_engine.put_match( + token_ids=req.token_ids, + token_mask=req.token_mask, + dp_id=req.dp_client_id, + task_id=req.task_id, + ) + response = Response(req.dp_client_id, task_id=req_id, mask=mask) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) + + def _handle_launch_task_request(self, req: LaunchTaskRequest) -> None: + """Handle LaunchTask request""" + self.kv_task_engine.launch_tasks(req.task_ids, req.slot_mappings) + + def _handle_cancel_task_request(self, req: CancelTaskRequest) -> None: + """Handle CancelTask request""" + self.kv_task_engine.cancel_tasks(req.task_ids) + + def _handle_wait_request(self, req: WaitRequest) -> None: + """Handle Wait request""" + kv_responses = self.kv_task_engine.wait( + req.wait_task_ids, + timeout=req.wait_timeout, + completely=req.completely, + ) + response = Response(req.dp_client_id, status=kv_responses) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) + + def _handle_try_wait_request(self, req: TryWaitRequest) -> None: + 
"""Handle TryWait request""" + kv_responses = self.kv_task_engine.try_wait( + req.try_wait_task_ids, + ) + response = Response(req.dp_client_id, status=kv_responses) + result_zmq = self.client_manager.get_zmq(req.dp_client_id) + result_zmq.send_pyobj(response) - # Stop server thread + def _handle_shutdown_request(self, req: ShutdownRequest) -> None: + """Handle shutdown request""" + flexkv_logger.info(f"Received shutdown request from DP client {req.dp_client_id}") self._running = False - if self._server_thread is not None and self._server_thread.is_alive(): - self._server_thread.join(timeout=5.0) - - # Shutdown KVManager - if hasattr(self, 'kvmanager'): - self.kvmanager.shutdown() - - # Close ZMQ context - #if hasattr(self, 'context'): - # self.context.term() - - flexkv_logger.info("SchedulerServer shutdown complete") - - def get_server_port(self) -> str: - """Get server receive port for TPClient to use""" - return self.server_recv_port def __del__(self) -> None: - """Destructor""" - with contextlib.suppress(Exception): - self.shutdown() - + self.kv_task_engine.shutdown() if __name__ == "__main__": import torch @@ -694,7 +378,6 @@ def __del__(self) -> None: enable_ssd=False, enable_remote=False, use_gds=False, - use_pinned_memory=True, tokens_per_block=tokens_per_block, num_cpu_blocks=num_cpu_blocks,) diff --git a/flexkv/storage/allocator.py b/flexkv/storage/allocator.py index 7cd38156e0..ed683e6505 100644 --- a/flexkv/storage/allocator.py +++ b/flexkv/storage/allocator.py @@ -95,7 +95,6 @@ def allocate(cls, layout: KVCacheLayout, dtype: torch.dtype, **kwargs: Any) -> StorageHandle: - pin_memory = kwargs.get("pin_memory", True) total_size = layout.get_total_elements() # although the kv layout may have multiple dimensions, we only have one-dim CPU tensor flexkv_logger.info(f"CPU allocate total_size: {2 * total_size/1024/1024/1024} GB") @@ -103,7 +102,7 @@ def allocate(cls, size=(total_size,), dtype=dtype, device="cpu", - pin_memory=pin_memory, + pin_memory=False, ) return StorageHandle( handle_type=AccessHandleType.TENSOR, diff --git a/flexkv/storage/storage_engine.py b/flexkv/storage/storage_engine.py index 0762b0062d..0d48fe6230 100644 --- a/flexkv/storage/storage_engine.py +++ b/flexkv/storage/storage_engine.py @@ -35,7 +35,6 @@ def __init__(self, device_type=DeviceType.CPU, layout=self._cpu_layout, dtype=self._model_config.dtype, - pin_memory=self._cache_config.use_pinned_memory, ) if self._cache_config.enable_ssd: if not self._cache_config.ssd_kv_layout_type == self._cpu_layout.type: diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 30516e6c3f..1b4036cbe3 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -37,8 +37,19 @@ tpGPUCPUTransferWorker, ) from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.ring_buffer import SharedOpPool +def register_op_to_buffer(op: TransferOp, pin_buffer: SharedOpPool) -> None: + op.src_slot_id = pin_buffer.allocate_slot(op.src_block_ids) + op.dst_slot_id = pin_buffer.allocate_slot(op.dst_block_ids) + +def free_op_from_buffer(op: TransferOp, pin_buffer: SharedOpPool) -> None: + if op.src_slot_id != -1: + pin_buffer.free_slot(op.src_slot_id) + if op.dst_slot_id != -1: + pin_buffer.free_slot(op.dst_slot_id) + class TransferEngine: def __init__(self, gpu_handles: List[StorageHandle], @@ -71,6 +82,8 @@ def __init__(self, self._remote_handle = remote_handle self._cache_config = cache_config + self.pin_buffer = SharedOpPool(2048, 
self.model_config.max_req_tokens // self.cache_config.tokens_per_block) + self.op_id_to_nvtx_range: Dict[int, str] = {} self.dp_size = model_config.dp_size @@ -88,8 +101,8 @@ def _init_workers(self) -> None: if self.tp_size == 1: self.gpucpu_workers: List[WorkerHandle] = [ GPUCPUTransferWorker.create_worker( - worker_id=i, finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor = self.pin_buffer.get_buffer(), gpu_blocks=self.gpu_handles[i].get_tensor_handle_list(), cpu_blocks=self._cpu_handle.get_tensor(), gpu_kv_layout=self.gpu_handles[i].kv_layout, @@ -106,12 +119,12 @@ def _init_workers(self) -> None: else: self.gpucpu_workers = [ tpGPUCPUTransferWorker.create_worker( - worker_id=i, finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor = self.pin_buffer.get_buffer(), gpu_blocks=[self.gpu_handles[j].get_tensor_handle_list() \ for j in range(i * self.tp_size, (i + 1) * self.tp_size)], cpu_blocks=self._cpu_handle.get_tensor(), - gpu_kv_layout=self.gpu_handles[i].kv_layout, + gpu_kv_layouts=[self.gpu_handles[i].kv_layout for i in range(i * self.tp_size, (i + 1) * self.tp_size)], cpu_kv_layout=self._cpu_handle.kv_layout, dtype=self.gpu_handles[i].dtype, tp_group_size=self.tp_size, @@ -128,8 +141,8 @@ def _init_workers(self) -> None: if self._ssd_handle is not None and self._cpu_handle is not None: self.cpussd_read_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( - worker_id=10, finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor = self.pin_buffer.get_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), ssd_files=self._ssd_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -139,8 +152,8 @@ def _init_workers(self) -> None: cache_config=self._cache_config, ) self.cpussd_write_worker: WorkerHandle = CPUSSDDiskTransferWorker.create_worker( - worker_id=11, finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor = self.pin_buffer.get_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), ssd_files=self._ssd_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -153,8 +166,8 @@ def _init_workers(self) -> None: self._worker_map[TransferType.DISK2H] = self.cpussd_read_worker if self._remote_handle is not None and self._cpu_handle is not None: self.remotecpu_read_worker: WorkerHandle = CPURemoteTransferWorker.create_worker( - worker_id=20, finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor = self.pin_buffer.get_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), remote_file=self._remote_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -163,8 +176,8 @@ def _init_workers(self) -> None: remote_config_custom=self._remote_handle.remote_config_custom, ) self.remotecpu_write_worker: WorkerHandle = CPURemoteTransferWorker.create_worker( - worker_id=21, finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor = self.pin_buffer.get_buffer(), cpu_blocks=self._cpu_handle.get_tensor(), remote_file=self._remote_handle.get_file_list(), cpu_kv_layout=self._cpu_handle.kv_layout, @@ -177,18 +190,19 @@ def _init_workers(self) -> None: if len(self._worker_map) == 0: raise ValueError("No workers initialized, please check the config") # Wait for all workers to ready - for worker in self._worker_map.values(): + for transfer_type, worker in self._worker_map.items(): if isinstance(worker, List): for w in worker: - w.ready_event.wait(timeout=60) + flexkv_logger.info(f"waiting for {transfer_type.name} worker {w.worker_id} to ready") + w.ready_event.wait() + flexkv_logger.info(f"{transfer_type.name} worker 
{w.worker_id} is ready") else: - flexkv_logger.info(f"waiting for worker {worker} to ready") - worker.ready_event.wait(timeout=60) - flexkv_logger.info(f"worker {worker} is ready") + flexkv_logger.info(f"waiting for {transfer_type.name} worker {worker.worker_id} to ready") + worker.ready_event.wait() + flexkv_logger.info(f"{transfer_type.name} worker {worker.worker_id} is ready") # Start scheduler thread self._running = True self._scheduler_thread = threading.Thread(target=self._scheduler_loop) - flexkv_logger.info("TransferEngine initialized and running") self._scheduler_thread.start() def start(self) -> None: @@ -212,6 +226,7 @@ def _scheduler_loop(self) -> None: try: op_id = self.finished_ops_queue.get_nowait() op = self.op_id_to_op[op_id] + free_op_from_buffer(op, self.pin_buffer) self.completed_queue.put((op.graph_id, op.op_id)) finished_ops.append(op) del self.op_id_to_op[op_id] @@ -229,6 +244,8 @@ def _scheduler_loop(self) -> None: self.completed_queue.put((op.graph_id, op.op_id)) else: self.op_id_to_op[op.op_id] = op + # copy block ids into buffer and update slot id info + register_op_to_buffer(op, self.pin_buffer) self._assign_op_to_worker(op) # Handle completed graphs for graph_id in completed_graph_ids: @@ -286,6 +303,8 @@ def get_completed_graphs_and_ops(self, timeout: Optional[float] = None) -> List[ def shutdown(self) -> None: """Shutdown the transfer engine""" try: + if not self._running: + return self._running = False self._scheduler_thread.join(timeout=5) diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index 4d72a0fb93..2dbbc1be86 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -1,10 +1,10 @@ import copy -import multiprocessing as mp +import torch.multiprocessing as mp import threading import time from abc import ABC, abstractmethod from dataclasses import dataclass -from multiprocessing import Queue as MPQueue, Pipe as MPPipe +from torch.multiprocessing import Queue as MPQueue, Pipe as MPPipe from multiprocessing.connection import Connection from threading import Thread from typing import List, Any, Dict, Union, Optional @@ -51,31 +51,55 @@ class WorkerTransferOp: transfer_op_id: int transfer_graph_id: int transfer_type: TransferType - src_block_ids: np.ndarray - dst_block_ids: np.ndarray layer_id: int layer_granularity: int + src_slot_id: int + dst_slot_id: int + valid_block_num: int + src_block_ids: np.ndarray + dst_block_ids: np.ndarray # successors: List[int] def __init__(self, transfer_op: TransferOp): self.transfer_op_id = transfer_op.op_id self.transfer_graph_id = transfer_op.graph_id self.transfer_type = transfer_op.transfer_type - self.src_block_ids = transfer_op.src_descriptor.physical_block_ids.numpy() - self.dst_block_ids = transfer_op.dst_descriptor.physical_block_ids.numpy() self.layer_id = transfer_op.layer_id self.layer_granularity = transfer_op.layer_granularity + self.src_slot_id = transfer_op.src_slot_id + self.dst_slot_id = transfer_op.dst_slot_id + self.valid_block_num = transfer_op.valid_block_num + if self.src_slot_id == -1: + self.src_block_ids = transfer_op.src_block_ids + self.dst_block_ids = transfer_op.dst_block_ids + else: + self.src_block_ids = np.empty(0) + self.dst_block_ids = np.empty(0) # self.successors = list(transfer_op.successors) # for nvtx class TransferWorkerBase(ABC): + _worker_id_counter = 0 + _worker_id_lock = threading.Lock() + def __init__(self, worker_id: int, transfer_conn: Connection, # receive end of pipe - finished_ops_queue: MPQueue): + finished_ops_queue: MPQueue, + 
op_buffer_tensor: torch.Tensor): self.worker_id = worker_id self.transfer_conn = transfer_conn # receive end of pipe self.finished_ops_queue: MPQueue[int] = finished_ops_queue + self.op_buffer_tensor = op_buffer_tensor + cudaHostRegister(self.op_buffer_tensor) + + @classmethod + def _get_worker_id(cls) -> int: + with cls._worker_id_lock: + worker_id = cls._worker_id_counter + cls._worker_id_counter += 1 + return worker_id + def _get_layer_ptrs(self, layer_blocks: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: if isinstance(layer_blocks, torch.Tensor): layer_blocks = [layer_blocks] @@ -90,14 +114,18 @@ def _get_layer_ptrs(self, layer_blocks: Union[List[torch.Tensor], torch.Tensor]) return layer_ptrs @classmethod - def create_worker(cls, worker_id: int, finished_ops_queue: MPQueue, *args: Any, **kwargs: Any) -> 'WorkerHandle': + def create_worker(cls, + finished_ops_queue: MPQueue, + op_buffer_tensor: torch.Tensor, + *args: Any, **kwargs: Any) -> 'WorkerHandle': """Generic worker creation template method""" parent_conn, child_conn = MPPipe() # create pipe ready_event = mp.Event() + worker_id = cls._get_worker_id() process = mp.Process( target=cls._worker_process, - args=(worker_id, child_conn, finished_ops_queue, ready_event, *args), + args=(worker_id, child_conn, finished_ops_queue, op_buffer_tensor, ready_event, *args), kwargs=kwargs, daemon=True ) @@ -107,16 +135,16 @@ def create_worker(cls, worker_id: int, finished_ops_queue: MPQueue, *args: Any, @classmethod def _worker_process(cls, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, - ready_event: Any, *args: Any, **kwargs: Any) -> None: - worker = cls(worker_id, transfer_conn, finished_ops_queue, *args, **kwargs) + op_buffer_tensor: torch.Tensor, ready_event: Any, *args: Any, **kwargs: Any) -> None: + worker = cls(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor, *args, **kwargs) ready_event.set() worker.run() @abstractmethod def _transfer_impl( self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, @@ -124,6 +152,35 @@ def _transfer_impl( ) -> None: pass + def get_transfer_block_ids(self, + transfer_op: WorkerTransferOp, + pinned: bool = True) ->tuple[torch.Tensor, torch.Tensor]: + """ + Get transfer block ids from op buffer tensor or directly from op + Args: + transfer_op: WorkerTransferOp + pinned: whether to pin the block ids tensor + Returns: + tuple[torch.Tensor, torch.Tensor]: src_block_ids and dst_block_ids + """ + src_slot_id = transfer_op.src_slot_id + dst_slot_id = transfer_op.dst_slot_id + valid_block_num = transfer_op.valid_block_num + if src_slot_id == -1: + src_block_ids = torch.from_numpy(transfer_op.src_block_ids).to(dtype=torch.int64) + if pinned: + src_block_ids = src_block_ids.pin_memory() + else: + src_block_ids = self.op_buffer_tensor[src_slot_id, :valid_block_num] + if dst_slot_id == -1: + dst_block_ids = torch.from_numpy(transfer_op.dst_block_ids).to(dtype=torch.int64) + if pinned: + dst_block_ids = dst_block_ids.pin_memory() + else: + dst_block_ids = self.op_buffer_tensor[dst_slot_id, :valid_block_num] + + return src_block_ids, dst_block_ids + def _log_transfer_performance(self, transfer_op: WorkerTransferOp, transfer_size: int, @@ -159,7 +216,7 @@ def run(self) -> None: try: nvtx.push_range(f"launch {op.transfer_type.name} op_id: {op.transfer_op_id}, " f"graph_id: {op.transfer_graph_id}, " - f"num_blocks: 
{len(op.src_block_ids)}", + f"num_blocks: {op.valid_block_num}", color=get_nvtx_range_color(op.transfer_graph_id)) self.launch_transfer(op) nvtx.pop_range() @@ -180,7 +237,7 @@ class WorkerHandle: """handle for worker process""" def __init__(self, worker_id: int, transfer_conn: Connection, process: mp.Process, ready_event: Any): self.worker_id = worker_id - self.transfer_conn = transfer_conn # send end of pipe + self.transfer_conn = transfer_conn self.process = process self.ready_event = ready_event @@ -209,6 +266,7 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, + op_buffer_tensor: torch.Tensor, gpu_blocks: List[TensorSharedHandle], cpu_blocks: torch.Tensor, gpu_kv_layout: KVCacheLayout, @@ -220,7 +278,7 @@ def __init__(self, transfer_sms_h2d: int = 8, transfer_sms_d2h: int = 8) -> None: # initialize worker in a new process - super().__init__(worker_id, transfer_conn, finished_ops_queue) + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) # Register CPU tensors with CUDA cudaHostRegister(cpu_blocks) self.gpu_blocks = [wrapper.get_tensor() for wrapper in gpu_blocks] @@ -258,33 +316,30 @@ def __init__(self, def _transfer_impl( self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, **kwargs: Any, ) -> None: - assert src_block_ids.dtype == np.int64 - assert dst_block_ids.dtype == np.int64 + assert src_block_ids.dtype == torch.int64 + assert dst_block_ids.dtype == torch.int64 assert len(src_block_ids) == len(dst_block_ids) if transfer_type == TransferType.H2D: - gpu_block_ids = dst_block_ids - cpu_block_ids = src_block_ids + gpu_block_id_list = dst_block_ids + cpu_block_id_list = src_block_ids use_ce_transfer = self.use_ce_transfer_h2d transfer_sms = self.transfer_sms_h2d elif transfer_type == TransferType.D2H: - gpu_block_ids = src_block_ids - cpu_block_ids = dst_block_ids + gpu_block_id_list = src_block_ids + cpu_block_id_list = dst_block_ids use_ce_transfer = self.use_ce_transfer_d2h transfer_sms = self.transfer_sms_d2h else: raise ValueError(f"Invalid transfer type: {transfer_type} for GPUCPUTransferWorker") - gpu_block_id_list = torch.from_numpy(gpu_block_ids).to(dtype=torch.int64).pin_memory() - cpu_block_id_list = torch.from_numpy(cpu_block_ids).to(dtype=torch.int64).pin_memory() - assert len(gpu_block_id_list) == len(cpu_block_id_list) if len(gpu_block_id_list) == 0: @@ -320,11 +375,13 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: if layer_granularity == -1: layer_granularity = self.num_layers + src_block_ids, dst_block_ids = self.get_transfer_block_ids(transfer_op) + with torch.cuda.stream(self.transfer_stream): start_time = time.time() self._transfer_impl( - transfer_op.src_block_ids, - transfer_op.dst_block_ids, + src_block_ids, + dst_block_ids, transfer_op.transfer_type, layer_id, layer_granularity, @@ -332,7 +389,7 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: end_time = time.time() kv_dim = 2 if not self.is_mla else 1 - transfer_size = self.chunk_size_in_bytes * layer_granularity * len(transfer_op.src_block_ids) * kv_dim + transfer_size = self.chunk_size_in_bytes * layer_granularity * transfer_op.valid_block_num * kv_dim self._log_transfer_performance( transfer_op, @@ -346,9 +403,10 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, + op_buffer_tensor: torch.Tensor, gpu_blocks: 
List[List[TensorSharedHandle]], cpu_blocks: torch.Tensor, - gpu_kv_layout: KVCacheLayout, + gpu_kv_layouts: List[KVCacheLayout], cpu_kv_layout: KVCacheLayout, dtype: torch.dtype, tp_group_size: int, @@ -358,7 +416,7 @@ def __init__(self, transfer_sms_h2d: int = 8, transfer_sms_d2h: int = 8): - super().__init__(worker_id, transfer_conn, finished_ops_queue) + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) assert len(gpu_blocks) == tp_group_size # Handle tensor import for multi-process case imported_gpu_blocks = [] @@ -369,7 +427,7 @@ def __init__(self, imported_gpu_blocks.append(blocks_in_one_gpu) self.gpu_blocks = imported_gpu_blocks self.dtype = dtype - self.is_mla = gpu_kv_layout.is_mla + self.is_mla = gpu_kv_layouts[0].is_mla self.num_gpus = len(self.gpu_blocks) self.tp_group_size = tp_group_size @@ -377,19 +435,22 @@ def __init__(self, cudaHostRegister(cpu_blocks) - self.num_layers = gpu_kv_layout.num_layer - gpu_kv_layout_per_layer = gpu_kv_layout.div_layer(self.num_layers) + self.num_layers = gpu_kv_layouts[0].num_layer + gpu_kv_layouts_per_layer = [gpu_kv_layout.div_layer(self.num_layers) for gpu_kv_layout in gpu_kv_layouts] - self.gpu_chunk_size_in_bytes = gpu_kv_layout_per_layer.get_chunk_size() * self.dtype.itemsize - self.gpu_kv_stride_in_bytes = gpu_kv_layout_per_layer.get_kv_stride() * self.dtype.itemsize - self.gpu_block_stride_in_bytes = gpu_kv_layout_per_layer.get_block_stride() * self.dtype.itemsize + self.gpu_chunk_sizes_in_bytes = [gpu_kv_layout_per_layer.get_chunk_size() * self.dtype.itemsize \ + for gpu_kv_layout_per_layer in gpu_kv_layouts_per_layer] + self.gpu_kv_strides_in_bytes = [gpu_kv_layout_per_layer.get_kv_stride() * self.dtype.itemsize \ + for gpu_kv_layout_per_layer in gpu_kv_layouts_per_layer] + self.gpu_block_strides_in_bytes = [gpu_kv_layout_per_layer.get_block_stride() * self.dtype.itemsize \ + for gpu_kv_layout_per_layer in gpu_kv_layouts_per_layer] self.cpu_chunk_size_in_bytes = cpu_kv_layout.get_chunk_size() * self.dtype.itemsize self.cpu_layer_stride_in_bytes = cpu_kv_layout.get_layer_stride() * self.dtype.itemsize self.cpu_kv_stride_in_bytes = cpu_kv_layout.get_kv_stride() * self.dtype.itemsize self.cpu_block_stride_in_bytes = cpu_kv_layout.get_block_stride() * self.dtype.itemsize - if not gpu_kv_layout.type == KVCacheLayoutType.LAYERWISE: + if not gpu_kv_layouts[0].type == KVCacheLayoutType.LAYERWISE: raise ValueError("Only layerwise layout is supported for GPU") self.transfer_sms_h2d = transfer_sms_h2d @@ -397,46 +458,47 @@ def __init__(self, self.use_ce_transfer_h2d = use_ce_transfer_h2d self.use_ce_transfer_d2h = use_ce_transfer_d2h - self.tp_transfer_thread_group = TPTransferThreadGroup(self.num_gpus, self.gpu_blocks, cpu_blocks, dp_group_id) + gpu_kv_strides_tensor = torch.tensor(self.gpu_kv_strides_in_bytes, dtype=torch.int64) + gpu_block_strides_tensor = torch.tensor(self.gpu_block_strides_in_bytes, dtype=torch.int64) + gpu_chunk_sizes_tensor = torch.tensor(self.gpu_chunk_sizes_in_bytes, dtype=torch.int64) + + self.tp_transfer_thread_group = TPTransferThreadGroup(self.num_gpus, self.gpu_blocks, cpu_blocks, dp_group_id, + gpu_kv_strides_tensor, gpu_block_strides_tensor, gpu_chunk_sizes_tensor) + def _transfer_impl(self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, **kwargs: Any, )->None: - assert src_block_ids.dtype == np.int64 - assert dst_block_ids.dtype == np.int64 + 
assert src_block_ids.dtype == torch.int64 + assert dst_block_ids.dtype == torch.int64 assert len(src_block_ids) == len(dst_block_ids) if transfer_type == TransferType.H2D: - gpu_block_ids = dst_block_ids - cpu_block_ids = src_block_ids + gpu_block_id_list = dst_block_ids + cpu_block_id_list = src_block_ids use_ce_transfer = self.use_ce_transfer_h2d transfer_sms = self.transfer_sms_h2d elif transfer_type == TransferType.D2H: - gpu_block_ids = src_block_ids - cpu_block_ids = dst_block_ids + gpu_block_id_list = src_block_ids + cpu_block_id_list = dst_block_ids use_ce_transfer = self.use_ce_transfer_d2h transfer_sms = self.transfer_sms_d2h else: raise ValueError(f"Invalid transfer type: {transfer_type} for tpGPUCPUTransferWorker") - gpu_block_id_list = torch.from_numpy(gpu_block_ids).to(dtype=torch.int64).pin_memory() - cpu_block_id_list = torch.from_numpy(cpu_block_ids).to(dtype=torch.int64).pin_memory() assert len(gpu_block_id_list) == len(cpu_block_id_list) if len(gpu_block_id_list) == 0: return - + self.tp_transfer_thread_group.tp_group_transfer( gpu_block_id_list, - self.gpu_kv_stride_in_bytes, - self.gpu_block_stride_in_bytes, - self.gpu_chunk_size_in_bytes, cpu_block_id_list, self.cpu_kv_stride_in_bytes, self.cpu_layer_stride_in_bytes, @@ -459,10 +521,12 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: if layer_granularity == -1: layer_granularity = self.num_layers + src_block_ids, dst_block_ids = self.get_transfer_block_ids(transfer_op) + start_time = time.time() self._transfer_impl( - transfer_op.src_block_ids, - transfer_op.dst_block_ids, + src_block_ids, + dst_block_ids, transfer_op.transfer_type, layer_id, layer_granularity, @@ -470,7 +534,7 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: end_time = time.time() kv_dim = 2 if not self.is_mla else 1 - transfer_size = self.cpu_chunk_size_in_bytes * layer_granularity * len(transfer_op.src_block_ids) * kv_dim + transfer_size = self.cpu_chunk_size_in_bytes * layer_granularity * transfer_op.valid_block_num * kv_dim self._log_transfer_performance( transfer_op, @@ -484,6 +548,7 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, + op_buffer_tensor: torch.Tensor, cpu_blocks: torch.Tensor, ssd_files: Dict[int, List[str]], # ssd_device_id -> file_paths cpu_kv_layout: KVCacheLayout, @@ -491,7 +556,7 @@ def __init__(self, dtype: torch.dtype, num_blocks_per_file: int, cache_config: CacheConfig): - super().__init__(worker_id, transfer_conn, finished_ops_queue) + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) self.ssd_files = ssd_files self.num_blocks_per_file = num_blocks_per_file self.num_files = sum(len(file_list) for file_list in ssd_files.values()) @@ -528,30 +593,26 @@ def __init__(self, def _transfer_impl( self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, **kwargs: Any, ) -> None: - assert src_block_ids.dtype == np.int64 - assert dst_block_ids.dtype == np.int64 + assert src_block_ids.dtype == torch.int64 + assert dst_block_ids.dtype == torch.int64 assert len(src_block_ids) == len(dst_block_ids) if transfer_type == TransferType.H2DISK: - ssd_block_ids = dst_block_ids - cpu_block_ids = src_block_ids + ssd_block_id_list = dst_block_ids + cpu_block_id_list = src_block_ids elif transfer_type == TransferType.DISK2H: - ssd_block_ids = src_block_ids - cpu_block_ids = dst_block_ids + 
ssd_block_id_list = src_block_ids + cpu_block_id_list = dst_block_ids else: raise ValueError(f"Invalid transfer type: {transfer_type} for CPUSSDDiskTransferWorker") - # this means partial read hit cpu and other hit ssd - # or partial write hit ssd and none hit cpu - ssd_block_id_list = torch.from_numpy(ssd_block_ids).to(dtype=torch.int64) - cpu_block_id_list = torch.from_numpy(cpu_block_ids).to(dtype=torch.int64) layer_id_list = torch.arange(layer_id, layer_id + layer_granularity, dtype=torch.int32) @@ -581,10 +642,13 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: layer_id = 0 if layer_granularity == -1: layer_granularity = self.num_layers + + src_block_ids , dst_block_ids = self.get_transfer_block_ids(transfer_op) + start_time = time.time() self._transfer_impl( - transfer_op.src_block_ids, - transfer_op.dst_block_ids, + src_block_ids, + dst_block_ids, transfer_op.transfer_type, transfer_op.layer_id, transfer_op.layer_granularity, @@ -592,7 +656,7 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: end_time = time.time() kv_dim = 2 if not self.is_mla else 1 - transfer_size = self.chunk_size_in_bytes * layer_granularity * len(transfer_op.src_block_ids) * kv_dim + transfer_size = self.chunk_size_in_bytes * layer_granularity * transfer_op.valid_block_num * kv_dim self._log_transfer_performance( transfer_op, @@ -606,6 +670,7 @@ def __init__(self, worker_id: int, transfer_conn: Connection, finished_ops_queue: MPQueue, + op_buffer_tensor: torch.Tensor, cpu_blocks: List[torch.Tensor], remote_file: List[str], cpu_kv_layout: KVCacheLayout, @@ -614,7 +679,7 @@ def __init__(self, remote_config_custom: Dict[str, Any]): if transfer_kv_blocks_remote is None: raise RuntimeError("transfer_kv_blocks_remote not available, please build with FLEXKV_ENABLE_CFS=1") - super().__init__(worker_id, transfer_conn, finished_ops_queue) + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) self.cpu_layer_ptrs = self._get_layer_ptrs(cpu_blocks) self.remote_files = remote_file @@ -687,15 +752,15 @@ def __init__(self, def _transfer_impl( self, - src_block_ids: np.ndarray, - dst_block_ids: np.ndarray, + src_block_ids: torch.Tensor, + dst_block_ids: torch.Tensor, transfer_type: TransferType, layer_id: int, layer_granularity: int, **kwargs: Any ) -> None: - assert dst_block_ids.dtype == np.int64 - assert src_block_ids.dtype == np.int64 + assert src_block_ids.dtype == torch.int64 + assert dst_block_ids.dtype == torch.int64 assert len(src_block_ids) == len(dst_block_ids) if layer_id == -1: @@ -707,17 +772,14 @@ def _transfer_impl( # or partial write hit remote and none hit cpu if transfer_type == TransferType.H2REMOTE: - remote_block_ids = dst_block_ids - cpu_block_ids = src_block_ids + remote_block_id_list = dst_block_ids + cpu_block_id_list = src_block_ids elif transfer_type == TransferType.REMOTE2H: - remote_block_ids = src_block_ids - cpu_block_ids = dst_block_ids + remote_block_id_list = src_block_ids + cpu_block_id_list = dst_block_ids else: raise ValueError(f"Invalid transfer type: {transfer_type} for CPUSSDDiskTransferWorker") - remote_block_id_list = torch.from_numpy(remote_block_ids).pin_memory().to(dtype=torch.int64) - cpu_block_id_list = torch.from_numpy(cpu_block_ids).pin_memory().to(dtype=torch.int64) - layer_id_list = torch.arange(layer_id, layer_id + layer_granularity, dtype=torch.int32) transfer_kv_blocks_remote( file_nodeid_list=self.file_nodeid_list, @@ -748,10 +810,13 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: 
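
Putting the preceding pieces together: register_op_to_buffer copies each op's block ids into a slot of the SharedOpPool, the worker rebuilds the id tensors from that shared buffer in get_transfer_block_ids (falling back to the arrays carried on the op when slot_id is -1), and free_op_from_buffer recycles the slot once the op is reported finished. The sketch below is illustrative only; the SharedOpPool constructor argument meanings and the int64 row layout are inferred from how the buffer is indexed in this patch.

    import numpy as np
    from flexkv.common.ring_buffer import SharedOpPool  # import path as used in transfer_engine.py above

    # Sizes are placeholders; TransferEngine itself uses 2048 slots and
    # max_req_tokens // tokens_per_block ids per slot.
    pool = SharedOpPool(2048, 64)
    op_buffer = pool.get_buffer()  # shared tensor handed to every worker process

    # Producer side (scheduler thread): stash the block ids, ship only the slot id with the op.
    block_ids = np.arange(32, dtype=np.int64)
    slot_id = pool.allocate_slot(block_ids)

    # Consumer side (worker process): read a view of the slot instead of deserializing an array,
    # mirroring get_transfer_block_ids(); the row is assumed to hold int64 block ids.
    valid_block_num = len(block_ids)
    ids_view = op_buffer[slot_id, :valid_block_num]

    # Once the finished-ops queue reports the op, the slot is recycled.
    pool.free_slot(slot_id)
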
layer_id = 0 if layer_granularity == -1: layer_granularity = self.num_layers + + src_block_ids, dst_block_ids = self.get_transfer_block_ids(transfer_op) + start_time = time.time() self._transfer_impl( - transfer_op.src_block_ids, - transfer_op.dst_block_ids, + src_block_ids, + dst_block_ids, transfer_op.transfer_type, transfer_op.layer_id, transfer_op.layer_granularity, @@ -759,7 +824,7 @@ def launch_transfer(self, transfer_op: WorkerTransferOp) -> None: end_time = time.time() kv_dim = 2 if not self.is_mla else 1 - transfer_size = self.chunk_size_in_bytes * layer_granularity * len(transfer_op.src_block_ids) * kv_dim + transfer_size = self.chunk_size_in_bytes * layer_granularity * transfer_op.valid_block_num * kv_dim self._log_transfer_performance( transfer_op, diff --git a/flexkv/transfer_manager.py b/flexkv/transfer_manager.py new file mode 100644 index 0000000000..fab65f345c --- /dev/null +++ b/flexkv/transfer_manager.py @@ -0,0 +1,322 @@ +import multiprocessing as mp +import time +import queue +from queue import Queue +from typing import Dict, Optional, List, Tuple +from abc import ABC, abstractmethod +from multiprocessing import Process, Pipe, Event +import zmq +import tempfile +import threading +import numpy as np + +from flexkv.common.transfer import TransferOpGraph +from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.debug import flexkv_logger +from flexkv.common.memory_handle import TensorSharedHandle +from flexkv.common.transfer import DeviceType +from flexkv.common.storage import KVCacheLayout +from flexkv.storage.storage_engine import StorageEngine +from flexkv.transfer.transfer_engine import TransferEngine +from flexkv.server.utils import get_zmq_socket +from flexkv.server.request import RegisterTPClientRequest, Response + + +class TransferManager: + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str): + self.model_config = model_config + self.cache_config = cache_config + self.gpu_register_port = gpu_register_port + + self.all_gpu_layouts: Dict[int, KVCacheLayout] = {} + self.all_gpu_blocks: Dict[int, List[TensorSharedHandle]] = {} # device_id -> gpu_blocks + + self.context = zmq.Context(2) + self.recv_from_client = get_zmq_socket( + self.context, zmq.SocketType.PULL, gpu_register_port, True) + + self.transfer_engine: Optional[TransferEngine] = None + self.storage_engine = StorageEngine(self.model_config, self.cache_config) + + def _handle_gpu_blocks_registration(self, req: RegisterTPClientRequest) -> None: + device_id = req.device_id + + if device_id in self.all_gpu_blocks: + flexkv_logger.error(f"GPU {device_id} has already registered.") + elif device_id >= self.model_config.tp_size * self.model_config.dp_size: + flexkv_logger.error(f"GPU {device_id} is larger than TP size: " + f"{self.model_config.tp_size * self.model_config.dp_size}.") + else: + try: + self.all_gpu_blocks[device_id] = req.handles + self.all_gpu_layouts[device_id] = req.gpu_layout + flexkv_logger.info(f"GPU {device_id} registered successfully") + except Exception as e: + flexkv_logger.error(f"Failed to register GPU {device_id}: {e}") + + + def _register_gpu_blocks_via_socket(self) -> None: + try: + flexkv_logger.info(f"GPU tensor registration server started on port {self.gpu_register_port}") + + expected_gpus = self.model_config.tp_size * self.model_config.dp_size + + while len(self.all_gpu_blocks) < expected_gpus: + try: + req = self.recv_from_client.recv_pyobj(zmq.NOBLOCK) + except zmq.Again: + time.sleep(0.001) + continue + + if 
isinstance(req, RegisterTPClientRequest): + flexkv_logger.info(f"Received GPU blocks registration request: {type(req)}") + self._handle_gpu_blocks_registration(req) + else: + flexkv_logger.error(f"Unrecognized RequestType in SchedulerServer: {type(req)}") + + flexkv_logger.info(f"All {expected_gpus} GPUs registered successfully") + + except Exception as e: + flexkv_logger.error(f"Error in GPU registration server: {e}") + raise + finally: + pass + # TODO: fix the socket close issue + # self.recv_from_client.close() + # self.context.term() + + def initialize_transfer_engine(self) -> None: + self._register_gpu_blocks_via_socket() + + assert len(self.all_gpu_layouts) == self.model_config.tp_size * self.model_config.dp_size + assert len(self.all_gpu_blocks) == self.model_config.tp_size * self.model_config.dp_size + for device_id, gpu_blocks_wrapper in self.all_gpu_blocks.items(): + self.storage_engine.register_gpu_blocks(gpu_blocks_wrapper, + self.all_gpu_layouts[device_id], + device_id, + dtype=self.model_config.dtype) + self.gpu_handles = [ + self.storage_engine.get_storage_handle(DeviceType.GPU, i) + for i in range(self.model_config.tp_size * self.model_config.dp_size) + ] + cpu_handle = self.storage_engine.get_storage_handle(DeviceType.CPU) \ + if self.cache_config.enable_cpu else None + ssd_handle = self.storage_engine.get_storage_handle(DeviceType.SSD) \ + if self.cache_config.enable_ssd else None + remote_handle = ( + self.storage_engine.get_storage_handle(DeviceType.REMOTE) \ + if self.cache_config.enable_remote \ + else None + ) + self.transfer_engine = TransferEngine(gpu_handles=self.gpu_handles, + model_config=self.model_config, + cache_config=self.cache_config, + cpu_handle=cpu_handle, + ssd_handle=ssd_handle, + remote_handle=remote_handle) + + def submit(self, transfer_graph: TransferOpGraph) -> None: + self.transfer_engine.submit_transfer_graph(transfer_graph) + + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + return self.transfer_engine.get_completed_graphs_and_ops(timeout) + + def start(self) -> None: + self.transfer_engine.start() + + def shutdown(self) -> None: + self.transfer_engine.shutdown() + + +class TransferManagerHandleBase(ABC): + @abstractmethod + def start(self) -> None: + pass + + @abstractmethod + def is_ready(self) -> bool: + pass + + @abstractmethod + def submit(self, transfer_graph: TransferOpGraph) -> None: + pass + + @abstractmethod + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def shutdown(self) -> None: + pass + + +class TransferManagerIntraProcessHandle(TransferManagerHandleBase): + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str): + self.transfer_manager = TransferManager(model_config, cache_config, gpu_register_port) + self._is_ready = False + + def start(self) -> None: + self.transfer_manager.initialize_transfer_engine() + self.transfer_manager.start() + self._is_ready = True + + def is_ready(self) -> bool: + return self._is_ready + + def submit(self, transfer_graph: TransferOpGraph) -> None: + self.transfer_manager.submit(transfer_graph) + + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + return self.transfer_manager.wait(timeout) + + def shutdown(self) -> None: + self.transfer_manager.shutdown() + + +class TransferManagerInterProcessHandle(TransferManagerHandleBase): + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + gpu_register_port: str): + 
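+        # CUDA contexts and IPC memory handles do not survive fork(), and the worker
+        # process created below registers GPU blocks, so the 'spawn' start method is
+        # forced here (the equivalent global setting in tests/conftest.py is removed
+        # elsewhere in this patch).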
mp.set_start_method('spawn', force=True) + + self.model_config = model_config + self.cache_config = cache_config + self.gpu_register_port = gpu_register_port + + self.command_parent_conn, self.command_child_conn = Pipe() + self.result_parent_conn, self.result_child_conn = Pipe() + + self.process: Optional[Process] = None + self.ready_event = Event() + + self._completed_results: List[Tuple[int, int]] = [] + + def _start_process(self) -> None: + if self.process is not None and self.process.is_alive(): + return + + self.process = Process( + target=self._process_worker, + args=(self.model_config, + self.cache_config, + self.command_child_conn, + self.result_child_conn, + self.gpu_register_port, + self.ready_event), + daemon=False + ) + self.process.start() + + def _process_worker(self, + model_config: ModelConfig, + cache_config: CacheConfig, + command_conn, + result_conn, + gpu_register_port: str, + ready_event) -> None: + try: + transfer_manager = TransferManager(model_config, cache_config, gpu_register_port) + transfer_manager.initialize_transfer_engine() + transfer_manager.start() + ready_event.set() + while True: + try: + if command_conn.poll(timeout=0.0001): + request = command_conn.recv() + request_type = request.get('type') + if request_type == 'submit': + transfer_manager.submit(request['transfer_graph']) + else: + flexkv_logger.error(f"Unrecognized request type: {request_type}") + try: + finished_ops = transfer_manager.wait(0.0001) + if finished_ops: + result_conn.send(finished_ops) + except queue.Empty: + pass + except Exception as e: + flexkv_logger.error(f"Error in transfer manager process: {e}") + + except Exception as e: + flexkv_logger.error(f"Failed to initialize transfer manager process: {e}") + finally: + command_conn.close() + result_conn.close() + + def start(self) -> None: + self._start_process() + + def is_ready(self) -> bool: + return self.ready_event.is_set() + + def submit(self, transfer_graph: TransferOpGraph) -> None: + self.command_parent_conn.send({ + 'type': 'submit', + 'transfer_graph': transfer_graph + }) + + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + finished_ops: List[Tuple[int, int]] = [] + try: + if self.result_parent_conn.poll(timeout=timeout): + finished_ops += self.result_parent_conn.recv() + while self.result_parent_conn.poll(): + finished_ops += self.result_parent_conn.recv() + except EOFError: + pass + + return finished_ops + + def shutdown(self) -> None: + if self.process is not None: + self.process.terminate() + self.process.join(timeout=5.0) + if self.process.is_alive(): + self.process.kill() + self.process.join() + + self.command_parent_conn.close() + self.result_parent_conn.close() + + def __del__(self): + self.shutdown() + + +class TransferManagerHandle: + def __init__(self, + model_config: ModelConfig, + cache_config: CacheConfig, + use_separate_process: bool = True, + gpu_register_port: Optional[str] = None): + if gpu_register_port is None: + gpu_register_port = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}" + if use_separate_process: + self._handle: TransferManagerHandleBase = TransferManagerInterProcessHandle( + model_config, cache_config, gpu_register_port + ) + else: + self._handle: TransferManagerHandleBase = TransferManagerIntraProcessHandle( + model_config, cache_config, gpu_register_port + ) + + def start(self) -> None: + self._handle.start() + + def is_ready(self) -> bool: + return self._handle.is_ready() + + def submit(self, transfer_graph: TransferOpGraph) -> None: + 
self._handle.submit(transfer_graph) + + def wait(self, timeout: Optional[float] = None) -> List[Tuple[int, int]]: + return self._handle.wait(timeout) + + def shutdown(self) -> None: + self._handle.shutdown() diff --git a/setup.py b/setup.py index f69e28b923..fcd0f97a34 100755 --- a/setup.py +++ b/setup.py @@ -2,13 +2,16 @@ import shutil import sys -from Cython.Build import cythonize + from setuptools import find_packages, setup from setuptools.command.build_ext import build_ext from torch.utils import cpp_extension +def get_version(): + with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f: + return f.read().strip() -build_dir = os.path.abspath("build") +build_dir = "build" os.makedirs(build_dir, exist_ok=True) # Check if we're in debug mode using environment variable @@ -25,17 +28,19 @@ "csrc/hash.cpp", "csrc/tp_transfer_thread_group.cpp", "csrc/transfer_ssd.cpp", + "csrc/radix_tree.cpp", ] hpp_sources = [ "csrc/cache_utils.h", "csrc/tp_transfer_thread_group.h", "csrc/transfer_ssd.h", + "csrc/radix_tree.h", ] extra_link_args = ["-lcuda", "-lxxhash", "-lpthread", "-lrt", "-luring"] extra_compile_args = ["-std=c++17"] -include_dirs = [os.path.join(build_dir, "include")] +include_dirs = [os.path.abspath(os.path.join(build_dir, "include"))] # Add rpath to find libraries at runtime lib_dir = os.path.join(build_dir, "lib") @@ -76,6 +81,8 @@ "flexkv/**/benchmark_*.py", "flexkv/benchmark/**/*.py", "flexkv/benchmark/test_kvmanager.py"] + # Import cython when debug is turned off. + from Cython.Build import cythonize cythonized_modules = cythonize( python_files, exclude=excluded_files, @@ -84,6 +91,7 @@ "boundscheck": False, "wraparound": False, "initializedcheck": False, + "profile": True, }, build_dir=build_dir, # Direct Cython to use the build directory ) @@ -125,16 +133,17 @@ def copy_shared_libraries(self): setup( name="flexkv", description="A global KV-Cache manager for LLM inference", - version="0.1.0", + version=get_version(), packages=find_packages(exclude=("benchmarks", "csrc", "examples", "tests")), package_data={ - "flexkv": ["lib/*.so", "lib/*.so.*"], + "flexkv": ["*.so", "lib/*.so", "lib/*.so.*"], }, include_package_data=True, install_requires=install_requires, ext_modules=ext_modules, # Now contains both C++ and Cython modules as needed cmdclass={ "build_ext": CustomBuildExt.with_options( + include_dirs=os.path.join(build_dir, "include"), # Include directory for xxhash no_python_abi_suffix=True, build_temp=os.path.join(build_dir, "temp"), # Temporary build files ) diff --git a/tests/conftest.py b/tests/conftest.py index e04cf78e50..2aba477a6f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,9 +4,3 @@ """ # Import fixtures from test_utils so pytest can discover them from test_utils import model_config, cache_config, test_config - -import multiprocessing as mp - -# Set the start method for multiprocessing to 'spawn' -# This ensures consistent behavior across different platforms -mp.set_start_method("spawn", force=True) diff --git a/tests/replay_from_tracer.py b/tests/replay_from_tracer.py index c9de75b720..fad6a20ea9 100644 --- a/tests/replay_from_tracer.py +++ b/tests/replay_from_tracer.py @@ -24,7 +24,7 @@ from flexkv.common.config import CacheConfig, ModelConfig from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType from flexkv.common.memory_handle import TensorSharedHandle -from flexkv.kvmanager import KVManager +from flexkv.kvtask import KVTaskEngine class FlexKVReplayEngine: @@ -113,7 +113,6 @@ def parse_config_event(self, event: Dict[str, 
Any]): ssd_kv_layout_type=self._parse_layout_type(cache_config_data['ssd_kv_layout_type']), remote_kv_layout_type=self._parse_layout_type(cache_config_data['remote_kv_layout_type']), use_gds=cache_config_data['use_gds'], - use_pinned_memory=False,#cache_config_data['use_pinned_memory'], # for local test remote_cache_size_mode=cache_config_data['remote_cache_size_mode'], num_cpu_blocks=cache_config_data['num_cpu_blocks'], num_ssd_blocks=cache_config_data['num_ssd_blocks'], @@ -204,7 +203,7 @@ def create_kvmanager(self,): ) # Create KVManager - self.kvmanager = KVManager( + self.kvmanager = KVTaskEngine( model_config=self.model_config, cache_config=self.cache_config, gpu_layout=self.gpu_layout, @@ -274,10 +273,6 @@ def replay_wait_event(self, event: Dict[str, Any]): result = self.kvmanager.wait_for_graph_finished(task_ids) elif wait_type == "try_wait": result = self.kvmanager.try_wait(task_ids) - elif wait_type == "wait_at_layer_group": - result = self.kvmanager.wait_at_layer_group(task_ids[0], layer_group_id) - elif wait_type == "try_wait_at_layer_group": - result = self.kvmanager.try_wait_at_layer_group(task_ids, layer_group_id) else: raise ValueError(f"Unknown wait type: {wait_type}") successed_elements = [] diff --git a/tests/test_cache_engine.py b/tests/test_cache_engine.py index 133dfb470c..e224d4305b 100644 --- a/tests/test_cache_engine.py +++ b/tests/test_cache_engine.py @@ -1,7 +1,7 @@ import random import pytest -import torch +import numpy as np from flexkv.cache.mempool import Mempool from flexkv.cache.cache_engine import CacheEngine @@ -16,6 +16,7 @@ def cache_engine(request: pytest.FixtureRequest) -> CacheEngine: 'device_type': DeviceType.CPU, 'num_total_blocks': 64, 'tokens_per_block': 4, + 'evict_ratio': 0.05, } default_config_kwargs.update(param) return CacheEngine(**default_config_kwargs) @@ -23,11 +24,11 @@ def cache_engine(request: pytest.FixtureRequest) -> CacheEngine: @pytest.mark.parametrize( "config, should_raise", [ - ({'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': DeviceType.CPU}, False), - ({'num_total_blocks': 0, 'tokens_per_block': 4, 'device_type': DeviceType.GPU}, True), - ({'num_total_blocks': 64, 'tokens_per_block': 0, 'device_type': DeviceType.SSD}, True), - ({'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': 'Unknown'}, True), - ({'num_total_blocks': 64, 'tokens_per_block': 3, 'device_type': DeviceType.CPU}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': DeviceType.CPU}, False), + ({'evict_ratio': 0.05, 'num_total_blocks': 0, 'tokens_per_block': 4, 'device_type': DeviceType.GPU}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 0, 'device_type': DeviceType.SSD}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': 'Unknown'}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 3, 'device_type': DeviceType.CPU}, True), ] ) def test_config_init(config: dict, should_raise: bool): @@ -42,14 +43,14 @@ def test_mempool(): mempool = Mempool(num_total_blocks=64) assert mempool.num_free_blocks == 64 block_ids = mempool.allocate_blocks(16) - assert isinstance(block_ids, torch.Tensor) - assert block_ids.dtype == torch.int64 + assert isinstance(block_ids, np.ndarray) + assert block_ids.dtype == np.int64 assert block_ids.shape == (16,) assert mempool.num_free_blocks == 48 mempool.recycle_blocks(block_ids) assert mempool.num_free_blocks == 64 - block_ids = torch.cat([mempool.allocate_blocks(16), + block_ids 
= np.concatenate([mempool.allocate_blocks(16), mempool.allocate_blocks(16), mempool.allocate_blocks(16), mempool.allocate_blocks(16)]) @@ -63,21 +64,21 @@ def test_mempool(): empty_blocks = mempool.allocate_blocks(0) assert empty_blocks.shape == (0, ) - assert empty_blocks.dtype == torch.int64 + assert empty_blocks.dtype == np.int64 assert mempool.num_free_blocks == 64 with pytest.raises(ValueError): mempool.allocate_blocks(-1) - mempool.recycle_blocks(torch.tensor([], dtype=torch.int64)) + mempool.recycle_blocks(np.array([], dtype=np.int64)) assert mempool.num_free_blocks == 64 with pytest.raises(ValueError): - mempool.recycle_blocks(torch.tensor([1, 2, 3], dtype=torch.int32)) + mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int32)) with pytest.raises(ValueError): - mempool.recycle_blocks(torch.tensor([1, 2, 3], dtype=torch.int64)) + mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int64)) with pytest.raises(ValueError): - mempool.recycle_blocks(torch.tensor([[1, 2, 3]], dtype=torch.int64)) + mempool.recycle_blocks(np.array([[1, 2, 3]], dtype=np.int64)) def test_reset(cache_engine: CacheEngine): cache_engine.reset() @@ -101,22 +102,22 @@ def test_reset(cache_engine: CacheEngine): [1, 10, 16, 32, 10000], ) def test_match_and_insert(cache_engine: CacheEngine, num_insert: int, seq_len: int): - base_token_ids = torch.randint(0, 10000, (seq_len, ), dtype=torch.int64) + base_token_ids = np.random.randint(0, 10000, (seq_len, ), dtype=np.int64) base_num_blocks = seq_len // cache_engine.tokens_per_block cache_engine.insert(SequenceMeta(token_ids=base_token_ids, tokens_per_block=cache_engine.tokens_per_block), - torch.arange(base_num_blocks, dtype=torch.int64), + np.arange(base_num_blocks, dtype=np.int64), is_ready=True) cur_cached_blocks = base_num_blocks for i in range(num_insert): prefix_ratio = random.random() prefix_len = int(len(base_token_ids)*prefix_ratio) num_prefix_blocks = prefix_len // cache_engine.tokens_per_block - token_ids = torch.cat([base_token_ids[:prefix_len], - torch.randint(10000 + i * seq_len, + token_ids = np.concatenate([base_token_ids[:prefix_len], + np.random.randint(10000 + i * seq_len, 10000 + (i+1) * seq_len, (seq_len-prefix_len, ), - dtype=torch.int64)]) + dtype=np.int64)]) insert_sequence_meta = SequenceMeta(token_ids=token_ids, tokens_per_block=cache_engine.tokens_per_block) match_result = cache_engine.match(insert_sequence_meta) @@ -125,11 +126,11 @@ def test_match_and_insert(cache_engine: CacheEngine, num_insert: int, seq_len: i assert match_result.last_ready_node is not None assert match_result.last_node is not None assert match_result.physical_blocks.shape == (num_prefix_blocks, ) - assert match_result.physical_blocks.dtype == torch.int64 + assert match_result.physical_blocks.dtype == np.int64 num_insert_blocks = insert_sequence_meta.num_blocks - num_prefix_blocks cache_engine.insert(insert_sequence_meta, - torch.arange(num_insert_blocks, dtype=torch.int64), + np.arange(num_insert_blocks, dtype=np.int64), is_ready=True, match_result=match_result) cur_cached_blocks += num_insert_blocks @@ -150,7 +151,7 @@ def test_take_and_recycle(cache_engine: CacheEngine): num_total_blocks = cache_engine.num_total_blocks tokens_per_block = cache_engine.tokens_per_block seq_blocks = 10 - token_ids = torch.randint(0, 10000, (seq_blocks * tokens_per_block, ), dtype=torch.int64) + token_ids = np.random.randint(0, 10000, (seq_blocks * tokens_per_block, ), dtype=np.int64) sequence_meta = SequenceMeta(token_ids=token_ids, tokens_per_block=tokens_per_block) physical_blocks = 
cache_engine.take(seq_blocks) @@ -159,7 +160,7 @@ def test_take_and_recycle(cache_engine: CacheEngine): empty_blocks = cache_engine.take(0) assert empty_blocks.shape == (0, ) - assert empty_blocks.dtype == torch.int64 + assert empty_blocks.dtype == np.int64 with pytest.raises(ValueError): cache_engine.take(-1) @@ -168,7 +169,7 @@ def test_take_and_recycle(cache_engine: CacheEngine): physical_blocks2 = cache_engine.take(num_total_blocks, protected_node=radixnode, strict=False) assert physical_blocks2.shape == (num_total_blocks - seq_blocks, ) - assert physical_blocks2.dtype == torch.int64 + assert physical_blocks2.dtype == np.int64 cache_engine.recycle(physical_blocks2) @@ -193,22 +194,22 @@ def test_cleanup(cache_engine: CacheEngine): if cache_engine.tokens_per_block != 1: pytest.skip("tokens_per_block != 1") tokens_per_block = cache_engine.tokens_per_block - token_ids_list = [torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64), - torch.tensor([0, 1, 2, 3, 17, 15, 19, 20], dtype=torch.int64), - torch.tensor([0, 23, 22, 21], dtype=torch.int64)] + token_ids_list = [np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64), + np.array([0, 1, 2, 3, 17, 15, 19, 20], dtype=np.int64), + np.array([0, 23, 22, 21], dtype=np.int64)] sequence_meta_list = [SequenceMeta(token_ids=token_ids, tokens_per_block=tokens_per_block) for token_ids in token_ids_list] num_insert_blocks0 = sequence_meta_list[0].num_blocks radixnode0 = cache_engine.insert(sequence_meta_list[0], - torch.arange(num_insert_blocks0, dtype=torch.int64), + np.arange(num_insert_blocks0, dtype=np.int64), is_ready=False) cache_engine.lock_node(radixnode0) radixnode0_size = radixnode0.size() match_result = cache_engine.match(sequence_meta_list[1]) num_insert_blocks1 = sequence_meta_list[1].num_blocks - match_result.num_matched_blocks radixnode1 = cache_engine.insert(sequence_meta_list[1], - torch.arange(num_insert_blocks1, dtype=torch.int64), + np.arange(num_insert_blocks1, dtype=np.int64), match_result=match_result, is_ready=False) cache_engine.lock_node(radixnode1) @@ -216,7 +217,7 @@ def test_cleanup(cache_engine: CacheEngine): match_result = cache_engine.match(sequence_meta_list[2]) num_insert_blocks2 = sequence_meta_list[2].num_blocks - match_result.num_matched_blocks radixnode2 = cache_engine.insert(sequence_meta_list[2], - torch.arange(num_insert_blocks2, dtype=torch.int64), + np.arange(num_insert_blocks2, dtype=np.int64), match_result=match_result, is_ready=False) cache_engine.lock_node(radixnode2) diff --git a/tests/test_cache_engine_accel.py b/tests/test_cache_engine_accel.py new file mode 100644 index 0000000000..15fef43ec3 --- /dev/null +++ b/tests/test_cache_engine_accel.py @@ -0,0 +1,232 @@ +import random + +import pytest +import numpy as np + +from flexkv.cache.mempool import Mempool +from flexkv.cache.cache_engine import CacheEngineAccel +from flexkv.common.transfer import DeviceType +from flexkv.common.exceptions import InvalidConfigError, NotEnoughSpaceError +from flexkv.common.block import SequenceMeta + +@pytest.fixture +def cache_engine(request: pytest.FixtureRequest) -> CacheEngineAccel: + param = request.param if hasattr(request, 'param') else {} + default_config_kwargs = { + 'device_type': DeviceType.CPU, + 'num_total_blocks': 64, + 'tokens_per_block': 4, + 'evict_ratio': 0.05, + } + default_config_kwargs.update(param) + return CacheEngineAccel(**default_config_kwargs) + +@pytest.mark.parametrize( + "config, should_raise", + [ + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 4, 
'device_type': DeviceType.CPU}, False), + ({'evict_ratio': 0.05, 'num_total_blocks': 0, 'tokens_per_block': 4, 'device_type': DeviceType.GPU}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 0, 'device_type': DeviceType.SSD}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 4, 'device_type': 'Unknown'}, True), + ({'evict_ratio': 0.05, 'num_total_blocks': 64, 'tokens_per_block': 3, 'device_type': DeviceType.CPU}, True), + ] +) +def test_config_init(config: dict, should_raise: bool): + if should_raise: + with pytest.raises(InvalidConfigError) as e: + CacheEngineAccel(**config) + else: + engine = CacheEngineAccel(**config) + assert isinstance(engine, CacheEngineAccel) + +def test_mempool(): + mempool = Mempool(num_total_blocks=64) + assert mempool.num_free_blocks == 64 + block_ids = mempool.allocate_blocks(16) + assert isinstance(block_ids, np.ndarray) + assert block_ids.dtype == np.int64 + assert block_ids.shape == (16,) + assert mempool.num_free_blocks == 48 + mempool.recycle_blocks(block_ids) + assert mempool.num_free_blocks == 64 + + block_ids = np.concatenate([mempool.allocate_blocks(16), + mempool.allocate_blocks(16), + mempool.allocate_blocks(16), + mempool.allocate_blocks(16)]) + assert mempool.num_free_blocks == 0 + + with pytest.raises(NotEnoughSpaceError): + mempool.allocate_blocks(1) + + mempool.recycle_blocks(block_ids) + assert mempool.num_free_blocks == 64 + + empty_blocks = mempool.allocate_blocks(0) + assert empty_blocks.shape == (0, ) + assert empty_blocks.dtype == np.int64 + assert mempool.num_free_blocks == 64 + + with pytest.raises(ValueError): + mempool.allocate_blocks(-1) + + mempool.recycle_blocks(np.array([], dtype=np.int64)) + assert mempool.num_free_blocks == 64 + + with pytest.raises(ValueError): + mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int32)) + with pytest.raises(ValueError): + mempool.recycle_blocks(np.array([1, 2, 3], dtype=np.int64)) + with pytest.raises(ValueError): + mempool.recycle_blocks(np.array([[1, 2, 3]], dtype=np.int64)) + +def test_reset(cache_engine: CacheEngineAccel): + cache_engine.reset() + assert cache_engine.index.is_empty() + assert cache_engine.mempool.num_used_blocks == 0 + +@pytest.mark.parametrize( + "cache_engine", + [ + {'num_total_blocks': 10000000, 'tokens_per_block': 1, 'device_type': DeviceType.CPU}, + {'num_total_blocks': 10000000, 'tokens_per_block': 16, 'device_type': DeviceType.CPU}, + ], + indirect=True +) +@pytest.mark.parametrize( + "num_insert", + [100], +) +@pytest.mark.parametrize( + "seq_len", + [1, 10, 16, 32, 10000], +) +def test_match_and_insert(cache_engine: CacheEngineAccel, num_insert: int, seq_len: int): + base_token_ids = np.random.randint(0, 10000, (seq_len, ), dtype=np.int64) + base_num_blocks = seq_len // cache_engine.tokens_per_block + cache_engine.insert(SequenceMeta(token_ids=base_token_ids, + tokens_per_block=cache_engine.tokens_per_block), + np.arange(base_num_blocks, dtype=np.int64), + is_ready=True) + cur_cached_blocks = base_num_blocks + for i in range(num_insert): + prefix_ratio = random.random() + prefix_len = int(len(base_token_ids)*prefix_ratio) + num_prefix_blocks = prefix_len // cache_engine.tokens_per_block + token_ids = np.concatenate([base_token_ids[:prefix_len], + np.random.randint(10000 + i * seq_len, + 10000 + (i+1) * seq_len, + (seq_len-prefix_len, ), + dtype=np.int64)]) + insert_sequence_meta = SequenceMeta(token_ids=token_ids, + tokens_per_block=cache_engine.tokens_per_block) + match_result = 
cache_engine.match(insert_sequence_meta) + assert match_result.num_ready_matched_blocks == num_prefix_blocks + assert match_result.num_matched_blocks == num_prefix_blocks + + num_insert_blocks = insert_sequence_meta.num_blocks - num_prefix_blocks + cache_engine.insert(insert_sequence_meta, + np.arange(num_insert_blocks, dtype=np.int64), + is_ready=True, + match_result=match_result) + cur_cached_blocks += num_insert_blocks + assert cache_engine.index.total_cached_blocks() == cur_cached_blocks + + match_result = cache_engine.match(insert_sequence_meta) + assert match_result.num_matched_blocks == insert_sequence_meta.num_blocks + assert match_result.num_ready_matched_blocks == insert_sequence_meta.num_blocks + +@pytest.mark.parametrize( + "cache_engine", + [ + {'num_total_blocks': 100, 'tokens_per_block': 16, 'device_type': DeviceType.CPU}, + ], + indirect=True +) +def test_take_and_recycle(cache_engine: CacheEngineAccel): + num_total_blocks = cache_engine.num_total_blocks + tokens_per_block = cache_engine.tokens_per_block + seq_blocks = 10 + token_ids = np.random.randint(0, 10000, (seq_blocks * tokens_per_block, ), dtype=np.int64) + sequence_meta = SequenceMeta(token_ids=token_ids, + tokens_per_block=tokens_per_block) + physical_blocks = cache_engine.take(seq_blocks) + radixnode = cache_engine.insert(sequence_meta, physical_blocks, is_ready=True) + assert cache_engine.index.total_cached_blocks() == seq_blocks + + empty_blocks = cache_engine.take(0) + assert empty_blocks.shape == (0, ) + assert empty_blocks.dtype == np.int64 + + with pytest.raises(ValueError): + cache_engine.take(-1) + with pytest.raises(NotEnoughSpaceError): + cache_engine.take(num_total_blocks, protected_node=radixnode, strict=True) + + physical_blocks2 = cache_engine.take(num_total_blocks, protected_node=radixnode, strict=False) + assert physical_blocks2.shape == (num_total_blocks - seq_blocks, ) + assert physical_blocks2.dtype == np.int64 + + cache_engine.recycle(physical_blocks2) + + cache_engine.lock_node(radixnode) + with pytest.raises(NotEnoughSpaceError): + cache_engine.take(num_total_blocks, protected_node=radixnode, strict=True) + cache_engine.cleanup(radixnode, radixnode.size()) + + physical_blocks = cache_engine.take(num_total_blocks, protected_node=None, strict=True) + assert physical_blocks.shape == (num_total_blocks, ) + assert cache_engine.index.total_cached_blocks() == 0 + +@pytest.mark.parametrize( + "cache_engine", + [ + {'num_total_blocks': 100, 'tokens_per_block': 1, 'device_type': DeviceType.CPU}, + ], + indirect=True +) +def test_cleanup(cache_engine: CacheEngineAccel): + if cache_engine.tokens_per_block != 1: + pytest.skip("tokens_per_block != 1") + tokens_per_block = cache_engine.tokens_per_block + token_ids_list = [np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64), + np.array([0, 1, 2, 3, 17, 15, 19, 20], dtype=np.int64), + np.array([0, 23, 22, 21], dtype=np.int64)] + sequence_meta_list = [SequenceMeta(token_ids=token_ids, + tokens_per_block=tokens_per_block) + for token_ids in token_ids_list] + num_insert_blocks0 = sequence_meta_list[0].num_blocks + radixnode0 = cache_engine.insert(sequence_meta_list[0], + np.arange(num_insert_blocks0, dtype=np.int64), + is_ready=False) + cache_engine.lock_node(radixnode0) + radixnode0_size = radixnode0.size() + match_result = cache_engine.match(sequence_meta_list[1]) + num_insert_blocks1 = sequence_meta_list[1].num_blocks - match_result.num_matched_blocks + radixnode1 = cache_engine.insert(sequence_meta_list[1], + np.arange(num_insert_blocks1, 
dtype=np.int64), + match_result=match_result, + is_ready=False) + cache_engine.lock_node(radixnode1) + radixnode1_size = radixnode1.size() + match_result = cache_engine.match(sequence_meta_list[2]) + num_insert_blocks2 = sequence_meta_list[2].num_blocks - match_result.num_matched_blocks + radixnode2 = cache_engine.insert(sequence_meta_list[2], + np.arange(num_insert_blocks2, dtype=np.int64), + match_result=match_result, + is_ready=False) + cache_engine.lock_node(radixnode2) + radixnode2_size = radixnode2.size() + total_insert_blocks = num_insert_blocks0 + num_insert_blocks1 + num_insert_blocks2 + assert cache_engine.index.total_cached_blocks() == total_insert_blocks + assert cache_engine.index.total_unready_blocks() == total_insert_blocks + assert cache_engine.index.total_ready_blocks() == 0 + + cache_engine.cleanup(radixnode2, radixnode2_size) + assert cache_engine.index.total_ready_blocks() == num_insert_blocks2 + + cache_engine.cleanup(radixnode1, radixnode1_size) + assert cache_engine.index.total_ready_blocks() == num_insert_blocks1 + num_insert_blocks2 + + cache_engine.cleanup(radixnode0, radixnode0_size) + assert cache_engine.index.total_ready_blocks() == num_insert_blocks0 + num_insert_blocks1 + num_insert_blocks2 diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index c01bee7aaa..f8e74d45e2 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -4,19 +4,67 @@ import pytest import torch +import multiprocessing as mp +from multiprocessing import Process, Pipe from flexkv.common.config import ModelConfig, CacheConfig from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.request import KVResponseStatus +from flexkv.kvtask import KVTaskEngine from flexkv.kvmanager import KVManager +from flexkv.common.memory_handle import TensorSharedHandle +from flexkv.server.client import KVTPClient +from flexkv.common.debug import flexkv_logger # Import utilities from test_utils from test_utils import ( DEFAULT_MODEL_CONFIG, DEFAULT_CACHE_CONFIG, DEFAULT_TEST_CONFIG, generate_request_pair, verify_data, block_ids_2_slot_mapping, generate_gpu_blocks_with_ground_truth, skip_if_insufficient_gpus, - create_kvmanager_with_mode, create_gpu_kv_layout + create_gpu_kv_layout, GPUKVCacheVerifier ) +def run_tp_client(dp_client_id, tp_rank, server_recv_port, model_config, cache_config, num_gpu_blocks, child_conn): + """Run tp_client process""" + try: + device_id = tp_rank + dp_client_id * model_config.tp_size + tp_client = KVTPClient(server_recv_port, dp_client_id, device_id) + + gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks) + + # Create GPU blocks for this tp_rank in the tp_client process + gpu_blocks_for_tp = [] + for _ in range(model_config.num_layers): + gpu_blocks_for_tp.append( + torch.empty(size=tuple(gpu_kv_layout.kv_shape[1:]), dtype=model_config.dtype).cuda(device_id) + ) + tp_client.register_to_server(gpu_blocks_for_tp, gpu_kv_layout) + + # Send GPU blocks back to main process via pipe if connection provided + if child_conn is not None: + print(f"[TP Client {tp_rank}] Converting {len(gpu_blocks_for_tp)} GPU blocks to TensorSharedHandle") + shared_gpu_blocks = [TensorSharedHandle(tensor) for tensor in gpu_blocks_for_tp] + child_conn.send(shared_gpu_blocks) + print(f"[TP Client {tp_rank}] Sent GPU blocks to main process via pipe") + child_conn.close() + + # Keep the process running + while True: + time.sleep(1) + except Exception as e: + if child_conn is not None: + child_conn.send(None) + child_conn.close() + 
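A minimal, self-contained sketch (not part of this patch) of the handle exchange that run_tp_client above and the main test below rely on: the process that owns the GPU blocks wraps them in TensorSharedHandle objects and ships them through a multiprocessing Pipe, and the receiving process re-maps the same memory with get_tensor(). TensorSharedHandle is assumed to behave as it is used elsewhere in this diff; the helper name _worker and the tensor shape are purely illustrative, and running it requires flexkv and a CUDA device.

import torch
from multiprocessing import Pipe, Process, set_start_method

from flexkv.common.memory_handle import TensorSharedHandle


def _worker(conn):
    # The worker owns the device memory and exports a handle instead of the tensor,
    # just as run_tp_client does for each layer's KV block tensor.
    block = torch.zeros(4, 16, dtype=torch.float16, device="cuda:0")
    conn.send(TensorSharedHandle(block))
    conn.recv()  # stay alive until the receiver is done with the memory
    conn.close()


if __name__ == "__main__":
    set_start_method("spawn", force=True)  # CUDA memory cannot be shared across fork()
    parent_conn, child_conn = Pipe()
    proc = Process(target=_worker, args=(child_conn,))
    proc.start()
    handle = parent_conn.recv()   # TensorSharedHandle sent by the worker
    view = handle.get_tensor()    # re-map the same GPU block in this process
    print(view.shape, view.dtype)
    parent_conn.send("done")
    proc.join()

Sending handles rather than tensors keeps a single owner of the device memory while letting other processes address the same blocks, which is the pattern both the transfer manager and the GPUKVCacheVerifier below depend on.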
+def shutdown_tp_client(tp_client_processes): + for tp_process in tp_client_processes: + if tp_process.is_alive(): + tp_process.terminate() + tp_process.join(timeout=5) + if tp_process.is_alive(): + print(f"Force killing tp_client process {tp_process.pid}") + tp_process.kill() + tp_process.join(timeout=2) @pytest.mark.parametrize("model_config", [ {'tp_size': 1, 'dp_size': 1}, @@ -31,23 +79,18 @@ {'enable_cpu': True, 'enable_ssd': True, 'enable_remote': False, 'ssd_cache_iouring_entries': 512}, {'enable_cpu': True, 'enable_ssd': True, 'enable_remote': True, 'num_ssd_blocks': 256, 'num_remote_blocks': 512}, {'enable_cpu': True, 'enable_ssd': True, 'enable_remote': True, - 'num_ssd_blocks': 256, 'num_remote_blocks': 512, 'ssd_cache_iouring_entries': 512}, + 'num_ssd_blocks': 256, 'num_remote_blocks': 512, 'ssd_cache_iouring_entries': 512}, ], indirect=True) @pytest.mark.parametrize("test_config", [ - {'num_gpu_blocks': 512, 'requests_per_block': 16, 'initial_write_ratio': 0.4, 'use_server_client': False}, - {'num_gpu_blocks': 512, 'requests_per_block': 16, 'initial_write_ratio': 0.4, 'use_server_client': True}, + {'num_gpu_blocks': 512, 'requests_per_block': 16, 'initial_write_ratio': 0.4}, ], indirect=True) @pytest.mark.parametrize("flex_kv_layout_type", [ KVCacheLayoutType.LAYERWISE, KVCacheLayoutType.BLOCKWISE, ]) def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type): - num_layers = model_config.num_layers - num_kv_heads = model_config.num_kv_heads - head_size = model_config.head_size tp_size = model_config.tp_size dp_size = model_config.dp_size - use_mla = model_config.use_mla tokens_per_block = cache_config.tokens_per_block num_cpu_blocks = cache_config.num_cpu_blocks @@ -64,7 +107,6 @@ def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type) num_gpu_blocks = test_config["num_gpu_blocks"] block_per_request = test_config['requests_per_block'] initial_write_ratio = test_config['initial_write_ratio'] - use_server_client = test_config.get('use_server_client', False) num_requests = num_gpu_blocks // block_per_request @@ -74,50 +116,97 @@ def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type) if enable_remote: pytest.skip("skip because enable_remote is not supported") - if use_server_client and dp_size > 1: - pytest.skip("skip because server-client mode is not supported for dp_size > 1 IN THIS TEST SCRIPT now") + if dp_size > 1: + #note that for now only dp_size=1 is supported + pytest.skip("skip because server-client mode is not ready for dp_size > 1") + + import uuid + gpu_register_port = f"ipc:///tmp/flexkv_gpu_{uuid.uuid4().hex[:8]}" + server_recv_port = f"ipc:///tmp/flexkv_srv_{uuid.uuid4().hex[:8]}" + kvmanager = KVManager(model_config, cache_config, gpu_register_port, server_recv_port) + kvmanager.start() + + # Create pipes for each tp_client to send GPU blocks back + pipe_connections = [] + tp_client_processes = [] + + for tp_rank in range(tp_size): + parent_conn, child_conn = Pipe() + pipe_connections.append(parent_conn) - if use_server_client: - # In server-client mode, GPU blocks are created in tp_client processes - # We only need the layout for initialization + tp_client_process = Process( + target=run_tp_client, + args=(0, tp_rank, gpu_register_port, model_config, cache_config, num_gpu_blocks + tp_rank, child_conn), + daemon=True + ) + tp_client_processes.append(tp_client_process) + tp_client_process.start() + + # Collect GPU blocks from all tp_client processes + print(f"[Main Process] Waiting to receive 
GPU blocks from {tp_size} TP client processes...") + all_gpu_blocks = [] + + for tp_rank, parent_conn in enumerate(pipe_connections): + try: + shared_gpu_blocks = parent_conn.recv() + if shared_gpu_blocks is not None: + all_gpu_blocks.append(shared_gpu_blocks) + print(f"[Main Process] Received GPU blocks from TP client {tp_rank}") + else: + print(f"[Main Process] TP client {tp_rank} failed to create GPU blocks") + parent_conn.close() + except Exception as e: + print(f"[Main Process] Error receiving from TP client {tp_rank}: {e}") + + # Create GPUKVCacheVerifier with collected GPU blocks + if all_gpu_blocks and len(all_gpu_blocks) == tp_size: + print(f"[Main Process] Creating GPUKVCacheVerifier with GPU blocks from {len(all_gpu_blocks)} TP clients") + + # Get gpu_kv_layout from cache_config for GPUKVCacheVerifier gpu_kv_layout = create_gpu_kv_layout(model_config, cache_config, num_gpu_blocks) - gpu_blocks = None # Not used in server-client mode - dp_wise_gpu_blocks_gt = None # Not used in server-client mode + + gpu_kv_verifier = GPUKVCacheVerifier( + shared_gpu_blocks=all_gpu_blocks, + gpu_kv_layout=gpu_kv_layout, + tp_size=model_config.tp_size, + tokens_per_block=cache_config.tokens_per_block, + dtype=model_config.dtype + ) + print("[Main Process] GPUKVCacheVerifier created successfully") else: - # In direct mode, create GPU blocks in current process - gpu_blocks, dp_wise_gpu_blocks_gt, gpu_kv_layout = \ - generate_gpu_blocks_with_ground_truth(model_config, cache_config, test_config) + print(f"[Main Process] Failed to collect GPU blocks from all TP clients. " + f"Got {len(all_gpu_blocks)} out of {tp_size}") + gpu_kv_verifier = None - kvmanager = create_kvmanager_with_mode(model_config, cache_config, gpu_kv_layout, gpu_blocks, use_server_client) + while not kvmanager.is_ready(): + time.sleep(1) + flexkv_logger.info("waiting for flexkv to be ready") - # put this after KVManager() num_remote_blocks = cache_config.num_remote_blocks - assert kvmanager.is_ready() - kvmanager.start() request_pairs = [generate_request_pair(i, block_per_request, num_gpu_blocks, tokens_per_block, dp_size) for i in range(num_requests)] initial_write_num = int(num_requests * initial_write_ratio) print("writing initial data...") + put_ids = [] for token_ids, block_ids, dp_id in request_pairs[:initial_write_num]: + if gpu_kv_verifier is not None: + gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) write_request = kvmanager.put_async( token_ids=token_ids, slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), + token_mask=None, dp_id=dp_id, ) - kvmanager.wait_for_graph_finished(write_request) - if not use_server_client: - # In direct mode, update GPU blocks for verification - for gpu in range(dp_id * tp_size, (dp_id + 1) * tp_size): - for i in range(num_layers): - gpu_blocks[gpu][i][:, block_ids, :, :, :] = 0 + kvmanager.wait([write_request], completely=True) #corner case: input token length for put is less than tokens_per_block write_request = kvmanager.put_async( token_ids=torch.randint(0, 100, size=(8,), dtype=torch.int64), slot_mapping=block_ids_2_slot_mapping(torch.arange(0,1, dtype=torch.int64), tokens_per_block, actual_length=8), + token_mask=None, dp_id=0, ) - kvmanager.wait_for_graph_finished(write_request) + kvmanager.wait([write_request], completely=True) #corner case: input token length is long enough, but the mask is less than tokens_per_block #my_mask = torch.zeros(16, dtype=torch.bool) #my_mask[0:8] = True @@ -134,44 +223,77 @@ def test_kvmanager(model_config, cache_config, test_config, 
flex_kv_layout_type) total_cache_miss = 0 running_get_requests = [] running_put_requests = [] + req_id2block_ids = {} + req_id2token_ids = {} + flexkv_id2req_id = {} start_time = time.time() print(f"the initial {initial_write_num} write done,performing mixed read/write...") for i in range(initial_write_num, num_requests): print(f"performing mixed read/write {i} / {num_requests} ...") read_idx = i - initial_write_num token_ids, block_ids, dp_id = request_pairs[read_idx] - request_id = kvmanager.get_async( + slot_mapping = block_ids_2_slot_mapping(block_ids, tokens_per_block) + request_id, _ = kvmanager.get_match( token_ids=token_ids, - slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), layer_granularity=-1, + token_mask=None, dp_id=dp_id, ) + kvmanager.launch(request_id, slot_mapping) + flexkv_id2req_id[request_id] = read_idx running_get_requests.append(request_id) + req_id2block_ids[request_id] = block_ids + req_id2token_ids[request_id] = token_ids token_ids, block_ids, dp_id = request_pairs[i] + if gpu_kv_verifier is not None: + gpu_kv_verifier.fill_gpu_blocks(token_ids, block_ids) request_id = kvmanager.put_async( token_ids=token_ids, slot_mapping=block_ids_2_slot_mapping(block_ids, tokens_per_block), + token_mask=None, dp_id=dp_id, ) + flexkv_id2req_id[request_id] = i + print(f"write flexkv request_id {request_id} to req_id {i}") running_put_requests.append(request_id) min_block_num = min(num_cpu_blocks, num_gpu_blocks) if (len(running_get_requests) + len(running_put_requests) >= min_block_num // block_per_request - 2 or i % initial_write_num == initial_write_num - 1 or i == num_requests - 1): if len(running_put_requests) > 0: - kvmanager.wait_for_graph_finished(running_put_requests) + kvmanager.wait(running_put_requests, completely=True) if len(running_get_requests) > 0: - return_masks = kvmanager.wait(running_get_requests) - for return_mask in return_masks.values(): - total_cache_hit += return_mask.sum() - total_cache_miss += len(return_mask) - return_mask.sum() + return_results = kvmanager.wait(running_get_requests, completely=True) + if gpu_kv_verifier is not None: + for req_id, kvresponse in return_results.items(): + assert kvresponse.status == KVResponseStatus.SUCCESS + valid_fetched_tokens = kvresponse.return_mask.sum().item() // \ + tokens_per_block * tokens_per_block + token_ids = req_id2token_ids[req_id] + block_ids = req_id2block_ids[req_id] + assert gpu_kv_verifier.verify_kv_blocks( + token_ids[:valid_fetched_tokens], + block_ids[:valid_fetched_tokens//tokens_per_block]) + for kvresponse in return_results.values(): + assert kvresponse.status == KVResponseStatus.SUCCESS + total_cache_hit += kvresponse.return_mask.sum().item() + total_cache_miss += len(kvresponse.return_mask) - kvresponse.return_mask.sum().item() running_get_requests = [] running_put_requests = [] if len(running_get_requests) > 0: - kvmanager.wait(running_get_requests) + return_results = kvmanager.wait(running_get_requests, completely=True) + if gpu_kv_verifier is not None: + for req_id, kvresponse in return_results.items(): + assert kvresponse.status == KVResponseStatus.SUCCESS + valid_fetched_tokens = kvresponse.return_mask.sum().item() // tokens_per_block * tokens_per_block + token_ids = req_id2token_ids[req_id] + block_ids = req_id2block_ids[req_id] + assert gpu_kv_verifier.verify_kv_blocks( + token_ids[:valid_fetched_tokens], + block_ids[:valid_fetched_tokens//tokens_per_block]) running_get_requests = [] if len(running_put_requests) > 0: - 
kvmanager.wait_for_graph_finished(running_put_requests) + kvmanager.wait(running_put_requests, completely=True) running_put_requests = [] print("mixed read/write done") end_time = time.time() @@ -182,12 +304,12 @@ def test_kvmanager(model_config, cache_config, test_config, flex_kv_layout_type) enable_ssd and num_ssd_blocks >= num_gpu_blocks or \ enable_remote and num_remote_blocks >= num_gpu_blocks: assert total_cache_miss == 0 + shutdown_tp_client(tp_client_processes) kvmanager.shutdown() - if total_cache_miss == 0 and not use_server_client: - # Only verify data in direct mode - verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, num_kv_heads, tp_size, dp_size, num_layers, use_mla) + # Only verify data in direct mode + # verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, num_kv_heads, tp_size, dp_size, num_layers, use_mla) + if total_cache_miss == 0: + return elif total_cache_miss > 0: - print(f"verify skipped, because of total_cache_miss={total_cache_miss}>0") - elif use_server_client: - print("verify skipped in server-client mode (verification happens in tp_client processes)") + print(f"verify skipped, because of total_cache_miss={total_cache_miss} > 0") diff --git a/tests/test_transfer_engine.py b/tests/test_transfer_engine.py deleted file mode 100644 index 73ee107864..0000000000 --- a/tests/test_transfer_engine.py +++ /dev/null @@ -1,411 +0,0 @@ -""" -Transfer Engine Unit Tests - -This module contains comprehensive unit tests for the TransferEngine component, -which handles data transfers between different storage tiers (GPU, CPU, SSD). - -Test Functions Overview: -1. test_gpu_cpu_round_trip: Tests round-trip data transfers between GPU and CPU - - Parameterized by: tp_size, dp_size, num_gpu_blocks, transfer_block_num - - Validates data consistency after GPU->CPU->GPU transfers - -2. test_ssd_round_trip: Tests round-trip data transfers involving SSD storage - - Parameterized by: num_gpu_blocks, transfer_block_num, enable_ssd_cache - - Validates data consistency after GPU->CPU->SSD->CPU->GPU transfers - -3. test_concurrent_mixed_transfers: Tests multiple concurrent read/write transfers - - Parameterized by: num_concurrent_transfers, blocks_per_transfer, include_ssd - - Validates correctness of mixed read/write transfer graphs running concurrently - -usage example: - python -m pytest tests/test_transfer_engine.py::test_gpu_cpu_round_trip -v --tb=short -Each test validates both transfer completion and data correctness to ensure -the TransferEngine maintains data integrity across all transfer operations. 
-""" - -import os -import time -import tempfile -from typing import List, Dict, Tuple -import multiprocessing as mp -from contextlib import suppress - -import pytest -import torch - -from flexkv.cache.transfer_pattern import ( - create_read_graph_cpu_storage, - create_write_graph_cpu_storage, -) -from flexkv.common.config import ModelConfig, CacheConfig -from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType -from flexkv.common.transfer import DeviceType -from flexkv.storage.storage_engine import StorageEngine -from flexkv.transfer.transfer_engine import TransferEngine - -# Import utilities from test_utils -from test_utils import ( - wait_for_transfer_completion, - skip_if_no_cuda, - skip_if_insufficient_gpus, - generate_gpu_blocks_with_ground_truth, - verify_data -) - -@pytest.mark.parametrize("tp_size,dp_size", [(1, 1), (2, 1), (2, 2)]) -@pytest.mark.parametrize("num_gpu_blocks", [128]) -@pytest.mark.parametrize("transfer_block_num", [16]) -@pytest.mark.parametrize("use_mla", [False, True]) -@pytest.mark.parametrize("underlying_layout_type", [KVCacheLayoutType.LAYERWISE, KVCacheLayoutType.BLOCKWISE]) -def test_gpu_cpu_round_trip(model_config, - cache_config, - test_config, - tp_size, - dp_size, - num_gpu_blocks, - transfer_block_num, - use_mla, - underlying_layout_type): - """ - Test round-trip data transfers between GPU and CPU - - This test validates: - 1. GPU -> CPU transfer correctness - 2. CPU -> GPU transfer correctness - 3. Round-trip data consistency (GPU -> CPU -> GPU) - - Parameterized by: - - tp_size, dp_size: Tensor and data parallelism configurations - - num_gpu_blocks: Number of GPU blocks to test with - - transfer_block_num: Number of blocks to transfer in each operation - """ - total_gpus = tp_size * dp_size - skip_if_insufficient_gpus(total_gpus) - - if transfer_block_num > num_gpu_blocks: - pytest.skip(f"transfer_block_num ({transfer_block_num}) > num_gpu_blocks ({num_gpu_blocks})") - - # Update model config - model_config.use_mla = use_mla - model_config.tp_size = tp_size - model_config.dp_size = dp_size - - # Create a copy of test_config to avoid modifying the fixture - test_config_copy = test_config.copy() - test_config_copy['num_gpu_blocks'] = num_gpu_blocks - - cache_config.cpu_kv_layout_type = underlying_layout_type - # Setup configurations - cache_config.enable_ssd = False - - gpu_blocks, dp_wise_gpu_blocks_gt, gpu_kv_layout = \ - generate_gpu_blocks_with_ground_truth(model_config, cache_config, test_config_copy) - # Setup storage engine and transfer engine - storage_engine = StorageEngine(model_config, cache_config) - for gpu_id, gpu_block in gpu_blocks.items(): - storage_engine.register_gpu_blocks(gpu_block, gpu_kv_layout, device_id=gpu_id, dtype=model_config.dtype) - gpu_handles = [storage_engine.get_storage_handle(DeviceType.GPU, i) for i in range(total_gpus)] - cpu_handle = storage_engine.get_storage_handle(DeviceType.CPU) - - transfer_engine = TransferEngine( - gpu_handles=gpu_handles, - model_config=model_config, - cache_config=cache_config, - cpu_handle=cpu_handle - ) - transfer_engine.start() - - # Test each DP group separately - for dp_id in range(dp_size): - gpu_block_ids = torch.arange(0, transfer_block_num, dtype=torch.int64) - cpu_block_ids = torch.arange(dp_id * transfer_block_num, (dp_id + 1) * transfer_block_num, dtype=torch.int64) - - # Step 1: GPU -> CPU transfer - write_graph, _ = create_write_graph_cpu_storage( - gpu_blocks=gpu_block_ids, - cpu_blocks=cpu_block_ids, - ssd_blocks=torch.tensor([], dtype=torch.int64), - 
gpu_device_id=dp_id * tp_size, - layer_num=model_config.num_layers - ) - write_graph.bind_to_dp_group(dp_id) - transfer_engine.submit_transfer_graph(write_graph) - - # Wait for write completion - assert wait_for_transfer_completion(transfer_engine, [write_graph.graph_id]), \ - f"GPU->CPU transfer failed for DP group {dp_id}" - - # Clear GPU blocks for read test - for tp_id in range(tp_size): - global_gpu_id = dp_id * tp_size + tp_id - for layer_id in range(model_config.num_layers): - gpu_blocks[global_gpu_id][layer_id][:,gpu_block_ids].zero_() - - # Step 2: CPU -> GPU transfer - read_graph, _ = create_read_graph_cpu_storage( - gpu_blocks=gpu_block_ids, - cpu_blocks=cpu_block_ids, - ssd_blocks=torch.tensor([], dtype=torch.int64), - gpu_device_id=dp_id * tp_size, - layer_num=model_config.num_layers - ) - read_graph.bind_to_dp_group(dp_id) - transfer_engine.submit_transfer_graph(read_graph) - # Wait for read completion - assert wait_for_transfer_completion(transfer_engine, [read_graph.graph_id]), \ - f"CPU->GPU transfer failed for DP group {dp_id}" - - verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, model_config.num_kv_heads, - model_config.tp_size, model_config.dp_size, model_config.num_layers, model_config.use_mla) - # Cleanup - transfer_engine.shutdown() - - -@pytest.mark.parametrize("num_gpu_blocks", [64, 128]) -@pytest.mark.parametrize("transfer_block_num", [8, 16]) -@pytest.mark.parametrize("use_mla", [True, False]) -@pytest.mark.parametrize("iouring_entries", [0, 512]) -def test_ssd_round_trip(model_config, - cache_config, - test_config, - num_gpu_blocks, - transfer_block_num, - use_mla, - iouring_entries): - """ - Test round-trip data transfers involving SSD storage - - This test validates: - 1. GPU -> CPU -> SSD transfer chain - 2. SSD -> CPU -> GPU transfer chain - 3. 
-    Full round-trip data consistency
-
-    Parameterized by:
-    - num_gpu_blocks: Number of GPU blocks to test with
-    - transfer_block_num: Number of blocks to transfer
-    """
-    skip_if_no_cuda()
-
-    if transfer_block_num > num_gpu_blocks:
-        pytest.skip(f"transfer_block_num ({transfer_block_num}) > num_gpu_blocks ({num_gpu_blocks})")
-
-    # Setup configurations
-    cache_config.enable_ssd = True
-    cache_config.ssd_cache_iouring_entries = iouring_entries
-    model_config.use_mla = use_mla
-
-    # Create a copy of test_config to avoid modifying the fixture
-    test_config_copy = test_config.copy()
-    test_config_copy['num_gpu_blocks'] = num_gpu_blocks
-
-    gpu_blocks, dp_wise_gpu_blocks_gt, gpu_kv_layout = \
-        generate_gpu_blocks_with_ground_truth(model_config, cache_config, test_config_copy)
-    if (model_config.tp_size * model_config.dp_size) > 1:
-        pytest.skip("SSD transfer test is not supported for multi-GPU")
-
-    # Setup storage engine and transfer engine
-    storage_engine = StorageEngine(model_config, cache_config)
-    for gpu_id, gpu_block in gpu_blocks.items():
-        storage_engine.register_gpu_blocks(gpu_block, gpu_kv_layout, device_id=gpu_id, dtype=model_config.dtype)
-    gpu_handles = [storage_engine.get_storage_handle(DeviceType.GPU, i)
-                   for i in range(model_config.tp_size * model_config.dp_size)]
-    cpu_handle = storage_engine.get_storage_handle(DeviceType.CPU)
-    ssd_handle = storage_engine.get_storage_handle(DeviceType.SSD)
-
-    transfer_engine = TransferEngine(
-        gpu_handles=gpu_handles,
-        model_config=model_config,
-        cache_config=cache_config,
-        cpu_handle=cpu_handle,
-        ssd_handle=ssd_handle
-    )
-    transfer_engine.start()
-    # Prepare transfer block IDs
-    gpu_block_ids = torch.arange(0, transfer_block_num, dtype=torch.int64)
-    cpu_block_ids = torch.arange(0, transfer_block_num, dtype=torch.int64)
-    ssd_block_ids = torch.arange(0, transfer_block_num, dtype=torch.int64)
-
-    # Step 1: GPU -> CPU -> SSD write
-    write_graph, _ = create_write_graph_cpu_storage(
-        gpu_blocks=gpu_block_ids,
-        cpu_blocks=cpu_block_ids,
-        ssd_blocks=ssd_block_ids,
-        gpu_device_id=0,
-        layer_num=model_config.num_layers
-    )
-    write_graph.bind_to_dp_group(0)
-    transfer_engine.submit_transfer_graph(write_graph)
-
-    # Wait for write completion
-    assert wait_for_transfer_completion(transfer_engine, [write_graph.graph_id]), \
-        "GPU->CPU->SSD write transfer failed"
-
-    # Clear GPU blocks for read test
-    for gpu_id in range(model_config.tp_size * model_config.dp_size):
-        for layer_id in range(model_config.num_layers):
-            gpu_blocks[gpu_id][layer_id][:,gpu_block_ids].zero_()
-
-    # Step 2: SSD -> CPU -> GPU read
-    read_graph, _ = create_read_graph_cpu_storage(
-        gpu_blocks=gpu_block_ids,
-        cpu_blocks=cpu_block_ids,
-        ssd_blocks=ssd_block_ids,
-        gpu_device_id=0,
-        layer_num=model_config.num_layers
-    )
-    read_graph.bind_to_dp_group(0)
-    transfer_engine.submit_transfer_graph(read_graph)
-
-    # Wait for read completion
-    assert wait_for_transfer_completion(transfer_engine, [read_graph.graph_id]), \
-        "SSD->CPU->GPU read transfer failed"
-
-    verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, model_config.num_kv_heads,
-                model_config.tp_size, model_config.dp_size, model_config.num_layers, model_config.use_mla)
-
-    # Cleanup
-    transfer_engine.shutdown()
-
-
-@pytest.mark.parametrize("num_concurrent_transfers", [4])
-@pytest.mark.parametrize("blocks_per_transfer", [16])
-@pytest.mark.parametrize("include_ssd", [True, False])
-@pytest.mark.parametrize("use_mla", [True, False])
-@pytest.mark.parametrize("iouring_entries", [0, 512])
-def test_concurrent_mixed_transfers(model_config,
-                                    cache_config,
-                                    test_config,
-                                    num_concurrent_transfers,
-                                    blocks_per_transfer,
-                                    include_ssd,
-                                    use_mla,
-                                    iouring_entries):
-    """
-    Test multiple concurrent read/write transfers
-
-    This test validates:
-    1. Multiple write transfers running concurrently
-    2. Multiple read transfers running concurrently
-    3. Mixed read/write transfers running concurrently
-    4. Data correctness across all concurrent operations
-
-    Parameterized by:
-    - num_concurrent_transfers: Number of concurrent transfer graphs
-    - blocks_per_transfer: Number of blocks per transfer operation
-    - include_ssd: Whether to include SSD in transfer operations
-    """
-    model_config.use_mla = use_mla
-    skip_if_no_cuda()
-
-    if (model_config.tp_size * model_config.dp_size) > 1:
-        pytest.skip("Concurrent transfer test is not supported for multi-GPU")
-
-    total_blocks_needed = num_concurrent_transfers * blocks_per_transfer * 2 # For both read and write
-    num_gpu_blocks = max(128, total_blocks_needed)
-
-    cache_config.num_cpu_blocks = num_gpu_blocks
-    cache_config.num_ssd_blocks = num_gpu_blocks
-    cache_config.ssd_cache_iouring_entries = iouring_entries
-
-    # Setup configurations
-    cache_config.enable_ssd = include_ssd
-
-    # Create a copy of test_config to avoid modifying the fixture
-    test_config_copy = test_config.copy()
-    test_config_copy['num_gpu_blocks'] = num_gpu_blocks
-
-    gpu_blocks, dp_wise_gpu_blocks_gt, gpu_kv_layout = \
-        generate_gpu_blocks_with_ground_truth(model_config, cache_config, test_config_copy)
-
-    # Setup storage engine and transfer engine
-    storage_engine = StorageEngine(model_config, cache_config)
-    for gpu_id, gpu_block in gpu_blocks.items():
-        storage_engine.register_gpu_blocks(gpu_block, gpu_kv_layout, device_id=gpu_id, dtype=model_config.dtype)
-    gpu_handles = [storage_engine.get_storage_handle(DeviceType.GPU, i)
-                   for i in range(model_config.tp_size * model_config.dp_size)]
-    cpu_handle = storage_engine.get_storage_handle(DeviceType.CPU)
-    ssd_handle = storage_engine.get_storage_handle(DeviceType.SSD) if include_ssd else None
-
-    transfer_engine = TransferEngine(
-        gpu_handles=gpu_handles,
-        model_config=model_config,
-        cache_config=cache_config,
-        cpu_handle=cpu_handle,
-        ssd_handle=ssd_handle
-    )
-
-    transfer_engine.start()
-    # Create concurrent write transfers
-    write_graphs = []
-
-    for i in range(num_concurrent_transfers):
-        start_block = i * blocks_per_transfer
-        end_block = start_block + blocks_per_transfer
-
-        gpu_block_ids = torch.arange(start_block, end_block, dtype=torch.int64)
-        cpu_block_ids = torch.arange(start_block, end_block, dtype=torch.int64)
-        ssd_block_ids = torch.arange(start_block, end_block, dtype=torch.int64) \
-            if include_ssd else torch.tensor([], dtype=torch.int64)
-
-
-        write_graph, _ = create_write_graph_cpu_storage(
-            gpu_blocks=gpu_block_ids,
-            cpu_blocks=cpu_block_ids,
-            ssd_blocks=ssd_block_ids,
-            gpu_device_id=0,
-            layer_num=model_config.num_layers
-        )
-        write_graph.bind_to_dp_group(0)
-        write_graphs.append(write_graph)
-
-    # Submit all write transfers
-    for graph in write_graphs:
-        transfer_engine.submit_transfer_graph(graph)
-
-    # Wait for all writes to complete
-    write_graph_ids = [graph.graph_id for graph in write_graphs]
-    assert wait_for_transfer_completion(transfer_engine, write_graph_ids, max_wait_time=20.0), \
-        "Concurrent write transfers failed to complete"
-
-    # Clear GPU blocks for read test
-    for gpu_id in range(model_config.tp_size * model_config.dp_size):
-        gpu_block_ids = torch.arange(0, (num_concurrent_transfers + 1) * blocks_per_transfer, dtype=torch.int64)
-        for layer_id in range(model_config.num_layers):
-            gpu_blocks[gpu_id][layer_id][:,gpu_block_ids].zero_()
-
-    # Create concurrent read transfers (using different GPU blocks)
-    read_graphs = []
-
-    for i in range(num_concurrent_transfers):
-        gpu_block_ids = torch.arange(i * blocks_per_transfer, (i + 1) * blocks_per_transfer, dtype=torch.int64)
-        cpu_block_ids = torch.arange(i * blocks_per_transfer, (i + 1) * blocks_per_transfer, dtype=torch.int64)
-        ssd_block_ids = torch.arange(i * blocks_per_transfer, (i + 1) * blocks_per_transfer, dtype=torch.int64) \
-            if include_ssd else torch.tensor([], dtype=torch.int64)
-
-        read_graph, _ = create_read_graph_cpu_storage(
-            gpu_blocks=gpu_block_ids,
-            cpu_blocks=cpu_block_ids,
-            ssd_blocks=ssd_block_ids,
-            gpu_device_id=0,
-            layer_num=model_config.num_layers
-        )
-        read_graph.bind_to_dp_group(0)
-        read_graphs.append(read_graph)
-
-    # Submit all read transfers
-    for graph in read_graphs:
-        transfer_engine.submit_transfer_graph(graph)
-
-    # Wait for all reads to complete
-    read_graph_ids = [graph.graph_id for graph in read_graphs]
-    assert wait_for_transfer_completion(transfer_engine, read_graph_ids, max_wait_time=20.0), \
-        "Concurrent read transfers failed to complete"
-
-    verify_data(gpu_blocks, dp_wise_gpu_blocks_gt, model_config.num_kv_heads,
-                model_config.tp_size, model_config.dp_size, model_config.num_layers, model_config.use_mla)
-
-    # Cleanup
-    transfer_engine.shutdown()
-
-if __name__ == "__main__":
-    pytest.main([__file__, "-v"])
diff --git a/tests/test_utils.py b/tests/test_utils.py
index a7e2f166c5..93541b612b 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,14 +1,16 @@
 import time
 import os
 import shutil
-from typing import List, Dict, Tuple
-from multiprocessing import Process
-
+from typing import List, Dict, Tuple, Optional, Union
+from multiprocessing import Process, Pipe, Queue
+import pickle
+import multiprocessing as mp
 import pytest
 import torch
 
 from flexkv.common.config import ModelConfig, CacheConfig
 from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType
+from flexkv.common.memory_handle import TensorSharedHandle
 
 
 # Default configurations
@@ -36,9 +38,8 @@
     'remote_file_prefix': "remote_cache",
     'use_gds': False,
     'enable_trace': False,
-    'use_pinned_memory': False,
     'ssd_cache_dir': ["./ssd_cache", "./ssd_cache2/"],
-    'ssd_cache_iouring_entries': 0,
+    'ssd_cache_iouring_entries': 32,
     'remote_cache_path': ["remote_cache1", "remote_cache2"],
     'remote_config_custom': {
         "pcfs_fsid": "f_l91fz6",
@@ -47,9 +48,9 @@
         "pcfs_parent_nodeid": 144115188075855883 # Using transfer engine value for consistency
     },
     'use_ce_transfer_h2d': False,
-    'use_ce_transfer_d2h': True,
-    'transfer_sms_h2d': 4,
-    'transfer_sms_d2h': 4,
+    'use_ce_transfer_d2h': False,
+    'transfer_sms_h2d': 8,
+    'transfer_sms_d2h': 8,
 }
 
 DEFAULT_TEST_CONFIG = {
@@ -292,7 +293,7 @@ def __init__(self, model_config, cache_config, gpu_kv_layout, gpu_blocks):
             tp_client_process = Process(
                 target=KVManagerServerClient._run_tp_client,
                 args=(self.dp_client.dp_client_id, tp_rank, device_id, self.server_recv_port,
-                      model_config.num_layers, str(model_config.dtype),
+                      model_config.num_layers, str(model_config.dtype),
                       list(gpu_kv_layout.kv_shape[1:]), model_config.use_mla),
                 daemon=True
             )
@@ -340,7 +341,7 @@ def _run_tp_client(dp_client_id, tp_rank, device_id, server_recv_port, num_layer
         from flexkv.server.client import KVTPClient
         from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType
 
-        tp_client = KVTPClient(server_recv_port, dp_client_id, device_id, tp_rank)
+        tp_client = KVTPClient(server_recv_port, dp_client_id, device_id)
         # Convert dtype string back to torch dtype
         if dtype_str == "torch.float16":
             dtype = torch.float16
@@ -449,11 +450,310 @@ def shutdown(self):
         print("KVManagerServerClient shutdown complete")
 
 
-def create_kvmanager_with_mode(model_config, cache_config, gpu_kv_layout, gpu_blocks, use_server_client=False):
-    """Create KVManager with optional server-client mode"""
-    if use_server_client:
-        print("Using server-client mode")
-        return KVManagerServerClient(model_config, cache_config, gpu_kv_layout, gpu_blocks)
-    else:
-        from flexkv.kvmanager import KVManager
-        return KVManager(model_config, cache_config, gpu_kv_layout, gpu_blocks)
+class GPUKVCacheVerifier:
+    def __init__(self,
+                 shared_gpu_blocks: Union[List[torch.Tensor], List[TensorSharedHandle], List[List[TensorSharedHandle]]],
+                 gpu_kv_layout: KVCacheLayout,
+                 tp_size: int,
+                 tokens_per_block: int,
+                 dtype: torch.dtype)->None:
+        self.gpu_kv_layout = gpu_kv_layout
+        self.num_layers = gpu_kv_layout.num_layer
+        # we have to map the exported gpu blocks into the virtual space of current process
+        if isinstance(shared_gpu_blocks[0], torch.Tensor):
+            self.gpu_blocks = shared_gpu_blocks
+        elif isinstance(shared_gpu_blocks[0], TensorSharedHandle):
+            self.gpu_blocks = [wrapper.get_tensor() for wrapper in shared_gpu_blocks]
+        else:
+            imported_gpu_blocks = []
+            for handles_in_one_gpu in shared_gpu_blocks:
+                blocks_in_one_gpu = []
+                for handle in handles_in_one_gpu:
+                    blocks_in_one_gpu.append(handle.get_tensor())
+                imported_gpu_blocks.append(blocks_in_one_gpu)
+            self.gpu_blocks = imported_gpu_blocks
+        self.gpu_block_num = gpu_kv_layout.num_block
+        self.tp_size = tp_size
+        self.is_mla = gpu_kv_layout.is_mla
+        self.tokens_per_block = tokens_per_block
+        self.dtype = dtype
+
+
+    def hash_all_values(self, layer_id, kv_id, token_ids, head_id):
+        base_hash = hash((layer_id, kv_id, head_id))
+
+        if isinstance(token_ids, torch.Tensor):
+            token_ids = token_ids.tolist()
+
+        token_hash = 0
+        prime = 31
+        for i, token_id in enumerate(token_ids):
+            token_hash += (token_id * (prime ** i)) % (2**31 - 1)
+
+        combined_hash = (base_hash + token_hash) % (2**31 - 1)
+
+        normalized_value = (combined_hash % 1000000) / 1000000.0
+
+        return torch.tensor(normalized_value, dtype=self.dtype).item()
+
+    def fill_gpu_blocks(self, token_ids, block_ids):
+        assert len(token_ids) == len(block_ids) * self.tokens_per_block
+
+        # Ensure token_ids is in tensor format
+        if not isinstance(token_ids, torch.Tensor):
+            token_ids = torch.tensor(token_ids, dtype=torch.int64)
+        if not isinstance(block_ids, torch.Tensor):
+            block_ids = torch.tensor(block_ids, dtype=torch.int64)
+
+        for layer_id in range(self.num_layers):
+            kv_num = 2 if not self.is_mla else 1
+            for kv_id in range(kv_num):
+                for tp_id in range(self.tp_size):
+                    if isinstance(self.gpu_blocks[0], list):
+                        # multiple gpu: gpu_blocks[tp_id][layer_id]
+                        gpu_tensor = self.gpu_blocks[tp_id][layer_id]
+                    else:
+                        # single gpu: gpu_blocks[layer_id]
+                        gpu_tensor = self.gpu_blocks[layer_id]
+
+                    for head_id in range(self.gpu_kv_layout.num_head):
+                        actual_head_id = tp_id * self.gpu_kv_layout.num_head + head_id if not self.is_mla else head_id
+
+                        for block_idx, block_id in enumerate(block_ids):
+                            start_token_idx = block_idx * self.tokens_per_block
+                            end_token_idx = start_token_idx + self.tokens_per_block
+                            hash_value = self.hash_all_values(layer_id,
+                                                              kv_id,
+                                                              token_ids[start_token_idx:end_token_idx],
+                                                              actual_head_id)
+                            # GPU tensor dim: [kv_dim, num_block, tokens_per_block, num_head, head_size]
+                            gpu_tensor[kv_id, block_id, :, head_id, :] = hash_value
+
+    def verify_kv_blocks(self, token_ids, block_ids)->bool:
+        assert len(token_ids) == len(block_ids) * self.tokens_per_block
+
+        if not isinstance(token_ids, torch.Tensor):
+            token_ids = torch.tensor(token_ids, dtype=torch.int64)
+        if not isinstance(block_ids, torch.Tensor):
+            block_ids = torch.tensor(block_ids, dtype=torch.int64)
+
+        verification_passed = True
+        errors = []
+
+        for layer_id in range(self.num_layers):
+            kv_num = 2 if not self.is_mla else 1
+            for kv_id in range(kv_num):
+                for tp_id in range(self.tp_size):
+                    if isinstance(self.gpu_blocks[0], list):
+                        gpu_tensor = self.gpu_blocks[tp_id][layer_id]
+                    else:
+                        gpu_tensor = self.gpu_blocks[layer_id]
+
+                    for head_id in range(self.gpu_kv_layout.num_head):
+                        actual_head_id = tp_id * self.gpu_kv_layout.num_head + head_id if not self.is_mla else head_id
+                        for block_idx, block_id in enumerate(block_ids):
+                            start_token_idx = block_idx * self.tokens_per_block
+                            end_token_idx = start_token_idx + self.tokens_per_block
+                            expected_hash_value = self.hash_all_values(layer_id, kv_id,
+                                                                       token_ids[start_token_idx:end_token_idx],
+                                                                       actual_head_id)
+
+                            actual_values = gpu_tensor[kv_id, block_id, :, head_id, :]
+
+                            if not torch.allclose(actual_values,
+                                                  torch.full_like(actual_values, expected_hash_value),
+                                                  rtol=1e-5, atol=1e-6):
+                                verification_passed = False
+                                errors.append(
+                                    f"Mismatch at layer={layer_id}, kv={kv_id}, tp={tp_id}, "
+                                    f"head={head_id}, block={block_id}: "
+                                    f"expected={expected_hash_value}, got={actual_values[0, 0].item()}"
+                                )
+
+        if not verification_passed:
+            print(f"Verification failed with {len(errors)} errors:")
+            for error in errors[:10]:
+                print(f"  {error}")
+            if len(errors) > 10:
+                print(f"  ... and {len(errors) - 10} more errors")
+        else:
+            print("KV blocks verification passed!")
+
+        return verification_passed
+
+
+def gpu_blocks_worker_process(conn, model_config, cache_config, gpu_kv_layout):
+    try:
+        print(f"[Worker Process {os.getpid()}] Starting to create GPU blocks...")
+
+        # Create GPU blocks in subprocess
+        gpu_blocks = []
+        for layer_id in range(model_config.num_layers):
+            # LAYERWISE format: [kv_dim, num_block, tokens_per_block, num_head, head_size]
+            kv_dim = 2 if not model_config.use_mla else 1
+            gpu_tensor = torch.zeros(
+                kv_dim,
+                gpu_kv_layout.num_block,
+                gpu_kv_layout.tokens_per_block,
+                gpu_kv_layout.num_head,
+                gpu_kv_layout.head_size,
+                dtype=model_config.dtype,
+                device='cuda:0' if torch.cuda.is_available() else 'cpu'
+            )
+            gpu_blocks.append(gpu_tensor)
+
+        print(f"[Worker Process {os.getpid()}] Successfully created {len(gpu_blocks)} GPU blocks")
+
+        # Convert to TensorSharedHandle
+        shared_gpu_blocks = [TensorSharedHandle(tensor) for tensor in gpu_blocks]
+        print(f"[Worker Process {os.getpid()}] Successfully converted to {len(shared_gpu_blocks)} TensorSharedHandles")
+
+        # Send to main process via pipe
+        conn.send(shared_gpu_blocks)
+        print(f"[Worker Process {os.getpid()}] Sent TensorSharedHandle list to main process via pipe")
+
+        #while True:
+        #    time.sleep(1)
+        conn.close()
+
+    except Exception as e:
+        print(f"[Worker Process {os.getpid()}] Error occurred: {e}")
+        conn.send(None)
+        conn.close()
+
+
+# Usage examples
+def example_usage_gpu_kv_cache_verifier():
+    """Demonstrates three ways to initialize GPUKVCacheVerifier"""
+    import torch
+    from flexkv.common.config import ModelConfig, CacheConfig
+    from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType
+    from flexkv.common.memory_handle import TensorSharedHandle
+
+    # Create example configurations
+    model_config = ModelConfig(
+        num_layers=2,
+        num_kv_heads=8,
+        head_size=64,
+        use_mla=False,
+        dtype=torch.float16,
+        tp_size=1,
+        dp_size=1
+    )
+
+    cache_config = CacheConfig(
+        tokens_per_block=16
+    )
+
+    # Create GPU KV layout
+    gpu_kv_layout = KVCacheLayout(
+        type=KVCacheLayoutType.LAYERWISE,
+        num_layer=model_config.num_layers,
+        num_block=64, # Assume 64 blocks
+        tokens_per_block=cache_config.tokens_per_block,
+        num_head=model_config.num_kv_heads,
+        head_size=model_config.head_size,
+        is_mla=model_config.use_mla
+    )
+
+    # Create mock GPU blocks
+    gpu_blocks = []
+    for layer_id in range(model_config.num_layers):
+        # LAYERWISE format: [kv_dim, num_block, tokens_per_block, num_head, head_size]
+        kv_dim = 2 if not model_config.use_mla else 1
+        gpu_tensor = torch.zeros(
+            kv_dim,
+            gpu_kv_layout.num_block,
+            gpu_kv_layout.tokens_per_block,
+            gpu_kv_layout.num_head,
+            gpu_kv_layout.head_size,
+            dtype=model_config.dtype,
+            device='cuda:0' if torch.cuda.is_available() else 'cpu'
+        )
+        gpu_blocks.append(gpu_tensor)
+
+    print("=== Method 1: Direct Tensor List ===")
+    verifier1 = GPUKVCacheVerifier(
+        shared_gpu_blocks=gpu_blocks, # Pass tensor list directly
+        gpu_kv_layout=gpu_kv_layout,
+        tp_size=model_config.tp_size,
+        tokens_per_block=cache_config.tokens_per_block,
+        dtype=model_config.dtype,
+    )
+
+    print("=== Method 2: Using TensorSharedHandle (Multi-process version) ===")
+    mp.set_start_method('spawn')
+    # Create pipe for inter-process communication
+    parent_conn, child_conn = Pipe()
+    print(f"[Main Process {os.getpid()}] Successfully created pipe connection")
+
+    # Start worker process to create GPU blocks and TensorSharedHandle
+    worker_process = Process(
+        target=gpu_blocks_worker_process,
+        args=(child_conn, model_config, cache_config, gpu_kv_layout)
+    )
+
+    print(f"[Main Process {os.getpid()}] Starting worker process...")
+    worker_process.start()
+
+    # Wait to receive TensorSharedHandle created by worker process
+    print(f"[Main Process {os.getpid()}] Waiting to receive results from worker process...")
+    shared_gpu_blocks = parent_conn.recv()
+
+    # Wait for worker process to complete
+
+
+    if shared_gpu_blocks is None:
+        raise RuntimeError("Worker process failed to create GPU blocks")
+
+    print(f"[Main Process {os.getpid()}] Successfully received {len(shared_gpu_blocks)} TensorSharedHandles")
+    verifier2 = GPUKVCacheVerifier(
+        shared_gpu_blocks=shared_gpu_blocks,
+        gpu_kv_layout=gpu_kv_layout,
+        tp_size=model_config.tp_size,
+        tokens_per_block=cache_config.tokens_per_block,
+        dtype=model_config.dtype,
+    )
+
+    # Prepare test data - Note: now hash is calculated per block
+    token_ids = torch.randint(0, 1000, (32,), dtype=torch.int64) # 32 tokens (2 blocks)
+    block_ids = torch.tensor([0, 1], dtype=torch.int64) # Use blocks 0 and 1
+
+    print(f"Token IDs shape: {token_ids.shape}")
+    print(f"Block IDs: {block_ids}")
+    print(f"Tokens per block: {cache_config.tokens_per_block}")
+
+    # Test method 1
+    print("\n=== Testing Method 1 (Direct Tensor) ===")
+    print("Starting to fill GPU blocks...")
+    verifier1.fill_gpu_blocks(token_ids, block_ids)
+    print("Filling completed!")
+
+    print("Starting data verification...")
+    is_valid1 = verifier1.verify_kv_blocks(token_ids, block_ids)
+    print(f"Verification result: {'PASSED' if is_valid1 else 'FAILED'}")
+
+    # Test method 2
+    print("\n=== Testing Method 2 (SharedHandle) ===")
+    print("Starting to fill GPU blocks...")
+    verifier2.fill_gpu_blocks(token_ids, block_ids)
+    print("Filling completed!")
+
+    print("Starting data verification...")
+    is_valid2 = verifier2.verify_kv_blocks(token_ids, block_ids)
+    print(f"Verification result: {'PASSED' if is_valid2 else 'FAILED'}")
+
+    # Demonstrate hash calculation changes: now each block has independent hash values
+    print("\n=== Hash Calculation Demo ===")
+    for block_idx, block_id in enumerate(block_ids):
+        start_idx = block_idx * cache_config.tokens_per_block
+        end_idx = start_idx + cache_config.tokens_per_block
+        block_tokens = token_ids[start_idx:end_idx]
+        hash_value = verifier1.hash_all_values(0, 0, block_tokens, 0)
+        print(f"Block {block_id} tokens: {block_tokens.tolist()[:5]}... -> hash: {hash_value:.6f}")
+    worker_process.join()
+    parent_conn.close()
+    return verifier1, token_ids, block_ids
+
+if __name__ == "__main__":
+    example_usage_gpu_kv_cache_verifier()